feat(config): add per-command model defaults
All checks were successful
CI / Test (push) Successful in 30s
CI / Lint (push) Successful in 25s
CI / Build (push) Successful in 19s

Introduce support for per-command model defaults in config.toml, overriding global default if set. Update GetModel to accept command name and prioritize: flag > command default > global default. Add example config file and adjust all commands to pass their name. Update tests accordingly.
This commit is contained in:
Gregory Gauthier 2026-03-02 16:56:56 +00:00
parent 06917f93d8
commit 24be047322
12 changed files with 99 additions and 12 deletions

View File

@ -20,7 +20,7 @@ var agentCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
instruction := args[0] instruction := args[0]
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("agent", modelFlag)
client := grok.NewClient() client := grok.NewClient()

View File

@ -63,7 +63,7 @@ var chatCmd = &cobra.Command{
Short: "Simple interactive CLI chat with Grok (full history + streaming)", Short: "Simple interactive CLI chat with Grok (full history + streaming)",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("chat", modelFlag)
client := grok.NewClient() client := grok.NewClient()

View File

@ -25,7 +25,7 @@ var commitCmd = &cobra.Command{
return return
} }
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("commit", modelFlag)
client := grok.NewClient() client := grok.NewClient()
messages := []map[string]string{ messages := []map[string]string{

View File

@ -24,7 +24,7 @@ var commitMsgCmd = &cobra.Command{
return return
} }
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("commitmsg", modelFlag)
client := grok.NewClient() client := grok.NewClient()
messages := []map[string]string{ messages := []map[string]string{

View File

@ -22,7 +22,7 @@ var editCmd = &cobra.Command{
instruction := args[1] instruction := args[1]
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("edit", modelFlag)
logger.Info("edit command started", logger.Info("edit command started",
"file", filePath, "file", filePath,

View File

@ -23,7 +23,7 @@ var historyCmd = &cobra.Command{
} }
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("history", modelFlag)
client := grok.NewClient() client := grok.NewClient()
messages := []map[string]string{ messages := []map[string]string{

View File

@ -116,7 +116,7 @@ func runLint(cmd *cobra.Command, args []string) {
logger.Info("requesting AI fixes", "file", absPath, "original_size", len(originalContent)) logger.Info("requesting AI fixes", "file", absPath, "original_size", len(originalContent))
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("lint", modelFlag)
client := grok.NewClient() client := grok.NewClient()
messages := buildLintFixMessages(result, string(originalContent)) messages := buildLintFixMessages(result, string(originalContent))

View File

@ -27,7 +27,7 @@ var prDescribeCmd = &cobra.Command{
return return
} }
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("prdescribe", modelFlag)
client := grok.NewClient() client := grok.NewClient()
messages := []map[string]string{ messages := []map[string]string{

View File

@ -15,7 +15,7 @@ var reviewCmd = &cobra.Command{
Short: "Review the current repository or directory", Short: "Review the current repository or directory",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
modelFlag, _ := cmd.Flags().GetString("model") modelFlag, _ := cmd.Flags().GetString("model")
model := config.GetModel(modelFlag) model := config.GetModel("review", modelFlag)
client := grok.NewClient() client := grok.NewClient()
diff, err := git.Run([]string{"diff", "--no-color"}) diff, err := git.Run([]string{"diff", "--no-color"})

29
config.toml.example Normal file
View File

@ -0,0 +1,29 @@
# Example configuration file for Grokkit
# Copy this to ~/.config/grokkit/config.toml and customize as needed.
default_model = "grok-4"
temperature = 0.7
log_level = "info"
timeout = 60
# Model aliases (shorthand names)
[aliases]
beta = "grok-beta-2"
fast = "grok-4-1-fast-non-reasoning"
# Per-command model defaults (overrides code defaults if set)
[commands]
lint.model = "grok-4-1-fast-non-reasoning" # Fast model for code fixes
agent.model = "grok-4" # Reasoning model for agent tasks
chat.model = "grok-4"
commit.model = "grok-4"
commitmsg.model = "grok-4"
edit.model = "grok-4-1-fast-non-reasoning"
history.model = "grok-4"
prdescribe.model = "grok-4"
review.model = "grok-4"
# Chat history settings
[chat]
history_file = "~/.config/grokkit/chat_history.json"

View File

@ -26,17 +26,31 @@ func Load() {
viper.SetDefault("log_level", "info") viper.SetDefault("log_level", "info")
viper.SetDefault("timeout", 60) viper.SetDefault("timeout", 60)
viper.SetDefault("commands.agent.model", "grok-4")
viper.SetDefault("commands.chat.model", "grok-4")
viper.SetDefault("commands.commit.model", "grok-4")
viper.SetDefault("commands.commitmsg.model", "grok-4")
viper.SetDefault("commands.edit.model", "grok-4")
viper.SetDefault("commands.history.model", "grok-4")
viper.SetDefault("commands.lint.model", "grok-4-1-fast-non-reasoning")
viper.SetDefault("commands.prdescribe.model", "grok-4")
viper.SetDefault("commands.review.model", "grok-4")
// Config file is optional, so we ignore read errors // Config file is optional, so we ignore read errors
_ = viper.ReadInConfig() _ = viper.ReadInConfig()
} }
func GetModel(flagModel string) string { func GetModel(commandName string, flagModel string) string {
if flagModel != "" { if flagModel != "" {
if alias := viper.GetString("aliases." + flagModel); alias != "" { if alias := viper.GetString("aliases." + flagModel); alias != "" {
return alias return alias
} }
return flagModel return flagModel
} }
cmdModel := viper.GetString("commands." + commandName + ".model")
if cmdModel != "" {
return cmdModel
}
return viper.GetString("default_model") return viper.GetString("default_model")
} }

View File

@ -30,7 +30,7 @@ func TestGetModel(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := GetModel(tt.flagModel) result := GetModel("", tt.flagModel)
if result != tt.expected { if result != tt.expected {
t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected) t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected)
} }
@ -43,10 +43,54 @@ func TestGetModelWithAlias(t *testing.T) {
viper.Set("aliases.beta", "grok-beta-2") viper.Set("aliases.beta", "grok-beta-2")
viper.SetDefault("default_model", "grok-4") viper.SetDefault("default_model", "grok-4")
result := GetModel("beta") result := GetModel("", "beta")
expected := "grok-beta-2" expected := "grok-beta-2"
if result != expected { if result != expected {
t.Errorf("GetModel('beta') = %q, want %q", result, expected) t.Errorf("GetModel('beta') = %q, want %q", result, expected)
}
}
func TestGetCommandModel(t *testing.T) {
viper.Reset()
viper.SetDefault("default_model", "grok-4")
viper.Set("commands.lint.model", "grok-4-1-fast-non-reasoning")
viper.Set("commands.other.model", "grok-other")
tests := []struct {
command string
flagModel string
expected string
}{
{
command: "lint",
flagModel: "",
expected: "grok-4-1-fast-non-reasoning",
},
{
command: "lint",
flagModel: "override",
expected: "override",
},
{
command: "other",
flagModel: "",
expected: "grok-other",
},
{
command: "unknown",
flagModel: "",
expected: "grok-4",
},
}
for _, tt := range tests {
t.Run(tt.command+"_"+tt.flagModel, func(t *testing.T) {
result := GetModel(tt.command, tt.flagModel)
if result != tt.expected {
t.Errorf("GetModel(%q, %q) = %q, want %q", tt.command, tt.flagModel, result, tt.expected)
}
})
} }
} }