feat(config): add per-command model defaults
Introduce support for per-command model defaults in config.toml, overriding global default if set. Update GetModel to accept command name and prioritize: flag > command default > global default. Add example config file and adjust all commands to pass their name. Update tests accordingly.
This commit is contained in:
parent
06917f93d8
commit
24be047322
@ -20,7 +20,7 @@ var agentCmd = &cobra.Command{
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
instruction := args[0]
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("agent", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ var chatCmd = &cobra.Command{
|
||||
Short: "Simple interactive CLI chat with Grok (full history + streaming)",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("chat", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ var commitCmd = &cobra.Command{
|
||||
return
|
||||
}
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("commit", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
messages := []map[string]string{
|
||||
|
||||
@ -24,7 +24,7 @@ var commitMsgCmd = &cobra.Command{
|
||||
return
|
||||
}
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("commitmsg", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
messages := []map[string]string{
|
||||
|
||||
@ -22,7 +22,7 @@ var editCmd = &cobra.Command{
|
||||
instruction := args[1]
|
||||
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("edit", modelFlag)
|
||||
|
||||
logger.Info("edit command started",
|
||||
"file", filePath,
|
||||
|
||||
@ -23,7 +23,7 @@ var historyCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("history", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
messages := []map[string]string{
|
||||
|
||||
@ -116,7 +116,7 @@ func runLint(cmd *cobra.Command, args []string) {
|
||||
logger.Info("requesting AI fixes", "file", absPath, "original_size", len(originalContent))
|
||||
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("lint", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
messages := buildLintFixMessages(result, string(originalContent))
|
||||
|
||||
@ -27,7 +27,7 @@ var prDescribeCmd = &cobra.Command{
|
||||
return
|
||||
}
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("prdescribe", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
messages := []map[string]string{
|
||||
|
||||
@ -15,7 +15,7 @@ var reviewCmd = &cobra.Command{
|
||||
Short: "Review the current repository or directory",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel(modelFlag)
|
||||
model := config.GetModel("review", modelFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
diff, err := git.Run([]string{"diff", "--no-color"})
|
||||
|
||||
29
config.toml.example
Normal file
29
config.toml.example
Normal file
@ -0,0 +1,29 @@
|
||||
# Example configuration file for Grokkit
|
||||
# Copy this to ~/.config/grokkit/config.toml and customize as needed.
|
||||
|
||||
default_model = "grok-4"
|
||||
|
||||
temperature = 0.7
|
||||
log_level = "info"
|
||||
timeout = 60
|
||||
|
||||
# Model aliases (shorthand names)
|
||||
[aliases]
|
||||
beta = "grok-beta-2"
|
||||
fast = "grok-4-1-fast-non-reasoning"
|
||||
|
||||
# Per-command model defaults (overrides code defaults if set)
|
||||
[commands]
|
||||
lint.model = "grok-4-1-fast-non-reasoning" # Fast model for code fixes
|
||||
agent.model = "grok-4" # Reasoning model for agent tasks
|
||||
chat.model = "grok-4"
|
||||
commit.model = "grok-4"
|
||||
commitmsg.model = "grok-4"
|
||||
edit.model = "grok-4-1-fast-non-reasoning"
|
||||
history.model = "grok-4"
|
||||
prdescribe.model = "grok-4"
|
||||
review.model = "grok-4"
|
||||
|
||||
# Chat history settings
|
||||
[chat]
|
||||
history_file = "~/.config/grokkit/chat_history.json"
|
||||
@ -26,17 +26,31 @@ func Load() {
|
||||
viper.SetDefault("log_level", "info")
|
||||
viper.SetDefault("timeout", 60)
|
||||
|
||||
viper.SetDefault("commands.agent.model", "grok-4")
|
||||
viper.SetDefault("commands.chat.model", "grok-4")
|
||||
viper.SetDefault("commands.commit.model", "grok-4")
|
||||
viper.SetDefault("commands.commitmsg.model", "grok-4")
|
||||
viper.SetDefault("commands.edit.model", "grok-4")
|
||||
viper.SetDefault("commands.history.model", "grok-4")
|
||||
viper.SetDefault("commands.lint.model", "grok-4-1-fast-non-reasoning")
|
||||
viper.SetDefault("commands.prdescribe.model", "grok-4")
|
||||
viper.SetDefault("commands.review.model", "grok-4")
|
||||
|
||||
// Config file is optional, so we ignore read errors
|
||||
_ = viper.ReadInConfig()
|
||||
}
|
||||
|
||||
func GetModel(flagModel string) string {
|
||||
func GetModel(commandName string, flagModel string) string {
|
||||
if flagModel != "" {
|
||||
if alias := viper.GetString("aliases." + flagModel); alias != "" {
|
||||
return alias
|
||||
}
|
||||
return flagModel
|
||||
}
|
||||
cmdModel := viper.GetString("commands." + commandName + ".model")
|
||||
if cmdModel != "" {
|
||||
return cmdModel
|
||||
}
|
||||
return viper.GetString("default_model")
|
||||
}
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ func TestGetModel(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetModel(tt.flagModel)
|
||||
result := GetModel("", tt.flagModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected)
|
||||
}
|
||||
@ -43,10 +43,54 @@ func TestGetModelWithAlias(t *testing.T) {
|
||||
viper.Set("aliases.beta", "grok-beta-2")
|
||||
viper.SetDefault("default_model", "grok-4")
|
||||
|
||||
result := GetModel("beta")
|
||||
result := GetModel("", "beta")
|
||||
expected := "grok-beta-2"
|
||||
if result != expected {
|
||||
t.Errorf("GetModel('beta') = %q, want %q", result, expected)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCommandModel(t *testing.T) {
|
||||
viper.Reset()
|
||||
viper.SetDefault("default_model", "grok-4")
|
||||
viper.Set("commands.lint.model", "grok-4-1-fast-non-reasoning")
|
||||
viper.Set("commands.other.model", "grok-other")
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
flagModel string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
command: "lint",
|
||||
flagModel: "",
|
||||
expected: "grok-4-1-fast-non-reasoning",
|
||||
},
|
||||
{
|
||||
command: "lint",
|
||||
flagModel: "override",
|
||||
expected: "override",
|
||||
},
|
||||
{
|
||||
command: "other",
|
||||
flagModel: "",
|
||||
expected: "grok-other",
|
||||
},
|
||||
{
|
||||
command: "unknown",
|
||||
flagModel: "",
|
||||
expected: "grok-4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.command+"_"+tt.flagModel, func(t *testing.T) {
|
||||
result := GetModel(tt.command, tt.flagModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetModel(%q, %q) = %q, want %q", tt.command, tt.flagModel, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user