feat(config): add per-command model defaults
Introduce support for per-command model defaults in config.toml, overriding global default if set. Update GetModel to accept command name and prioritize: flag > command default > global default. Add example config file and adjust all commands to pass their name. Update tests accordingly.
This commit is contained in:
parent
06917f93d8
commit
24be047322
@ -20,7 +20,7 @@ var agentCmd = &cobra.Command{
|
|||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
instruction := args[0]
|
instruction := args[0]
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("agent", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
|
|
||||||
|
|||||||
@ -63,7 +63,7 @@ var chatCmd = &cobra.Command{
|
|||||||
Short: "Simple interactive CLI chat with Grok (full history + streaming)",
|
Short: "Simple interactive CLI chat with Grok (full history + streaming)",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("chat", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ var commitCmd = &cobra.Command{
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("commit", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
messages := []map[string]string{
|
messages := []map[string]string{
|
||||||
|
|||||||
@ -24,7 +24,7 @@ var commitMsgCmd = &cobra.Command{
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("commitmsg", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
messages := []map[string]string{
|
messages := []map[string]string{
|
||||||
|
|||||||
@ -22,7 +22,7 @@ var editCmd = &cobra.Command{
|
|||||||
instruction := args[1]
|
instruction := args[1]
|
||||||
|
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("edit", modelFlag)
|
||||||
|
|
||||||
logger.Info("edit command started",
|
logger.Info("edit command started",
|
||||||
"file", filePath,
|
"file", filePath,
|
||||||
|
|||||||
@ -23,7 +23,7 @@ var historyCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("history", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
messages := []map[string]string{
|
messages := []map[string]string{
|
||||||
|
|||||||
@ -116,7 +116,7 @@ func runLint(cmd *cobra.Command, args []string) {
|
|||||||
logger.Info("requesting AI fixes", "file", absPath, "original_size", len(originalContent))
|
logger.Info("requesting AI fixes", "file", absPath, "original_size", len(originalContent))
|
||||||
|
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("lint", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
messages := buildLintFixMessages(result, string(originalContent))
|
messages := buildLintFixMessages(result, string(originalContent))
|
||||||
|
|||||||
@ -27,7 +27,7 @@ var prDescribeCmd = &cobra.Command{
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("prdescribe", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
messages := []map[string]string{
|
messages := []map[string]string{
|
||||||
|
|||||||
@ -15,7 +15,7 @@ var reviewCmd = &cobra.Command{
|
|||||||
Short: "Review the current repository or directory",
|
Short: "Review the current repository or directory",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel(modelFlag)
|
model := config.GetModel("review", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := grok.NewClient()
|
||||||
diff, err := git.Run([]string{"diff", "--no-color"})
|
diff, err := git.Run([]string{"diff", "--no-color"})
|
||||||
|
|||||||
29
config.toml.example
Normal file
29
config.toml.example
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Example configuration file for Grokkit
|
||||||
|
# Copy this to ~/.config/grokkit/config.toml and customize as needed.
|
||||||
|
|
||||||
|
default_model = "grok-4"
|
||||||
|
|
||||||
|
temperature = 0.7
|
||||||
|
log_level = "info"
|
||||||
|
timeout = 60
|
||||||
|
|
||||||
|
# Model aliases (shorthand names)
|
||||||
|
[aliases]
|
||||||
|
beta = "grok-beta-2"
|
||||||
|
fast = "grok-4-1-fast-non-reasoning"
|
||||||
|
|
||||||
|
# Per-command model defaults (overrides code defaults if set)
|
||||||
|
[commands]
|
||||||
|
lint.model = "grok-4-1-fast-non-reasoning" # Fast model for code fixes
|
||||||
|
agent.model = "grok-4" # Reasoning model for agent tasks
|
||||||
|
chat.model = "grok-4"
|
||||||
|
commit.model = "grok-4"
|
||||||
|
commitmsg.model = "grok-4"
|
||||||
|
edit.model = "grok-4-1-fast-non-reasoning"
|
||||||
|
history.model = "grok-4"
|
||||||
|
prdescribe.model = "grok-4"
|
||||||
|
review.model = "grok-4"
|
||||||
|
|
||||||
|
# Chat history settings
|
||||||
|
[chat]
|
||||||
|
history_file = "~/.config/grokkit/chat_history.json"
|
||||||
@ -26,17 +26,31 @@ func Load() {
|
|||||||
viper.SetDefault("log_level", "info")
|
viper.SetDefault("log_level", "info")
|
||||||
viper.SetDefault("timeout", 60)
|
viper.SetDefault("timeout", 60)
|
||||||
|
|
||||||
|
viper.SetDefault("commands.agent.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.chat.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.commit.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.commitmsg.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.edit.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.history.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.lint.model", "grok-4-1-fast-non-reasoning")
|
||||||
|
viper.SetDefault("commands.prdescribe.model", "grok-4")
|
||||||
|
viper.SetDefault("commands.review.model", "grok-4")
|
||||||
|
|
||||||
// Config file is optional, so we ignore read errors
|
// Config file is optional, so we ignore read errors
|
||||||
_ = viper.ReadInConfig()
|
_ = viper.ReadInConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModel(flagModel string) string {
|
func GetModel(commandName string, flagModel string) string {
|
||||||
if flagModel != "" {
|
if flagModel != "" {
|
||||||
if alias := viper.GetString("aliases." + flagModel); alias != "" {
|
if alias := viper.GetString("aliases." + flagModel); alias != "" {
|
||||||
return alias
|
return alias
|
||||||
}
|
}
|
||||||
return flagModel
|
return flagModel
|
||||||
}
|
}
|
||||||
|
cmdModel := viper.GetString("commands." + commandName + ".model")
|
||||||
|
if cmdModel != "" {
|
||||||
|
return cmdModel
|
||||||
|
}
|
||||||
return viper.GetString("default_model")
|
return viper.GetString("default_model")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ func TestGetModel(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := GetModel(tt.flagModel)
|
result := GetModel("", tt.flagModel)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected)
|
t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected)
|
||||||
}
|
}
|
||||||
@ -43,10 +43,54 @@ func TestGetModelWithAlias(t *testing.T) {
|
|||||||
viper.Set("aliases.beta", "grok-beta-2")
|
viper.Set("aliases.beta", "grok-beta-2")
|
||||||
viper.SetDefault("default_model", "grok-4")
|
viper.SetDefault("default_model", "grok-4")
|
||||||
|
|
||||||
result := GetModel("beta")
|
result := GetModel("", "beta")
|
||||||
expected := "grok-beta-2"
|
expected := "grok-beta-2"
|
||||||
if result != expected {
|
if result != expected {
|
||||||
t.Errorf("GetModel('beta') = %q, want %q", result, expected)
|
t.Errorf("GetModel('beta') = %q, want %q", result, expected)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommandModel(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
viper.SetDefault("default_model", "grok-4")
|
||||||
|
viper.Set("commands.lint.model", "grok-4-1-fast-non-reasoning")
|
||||||
|
viper.Set("commands.other.model", "grok-other")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
flagModel string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
command: "lint",
|
||||||
|
flagModel: "",
|
||||||
|
expected: "grok-4-1-fast-non-reasoning",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
command: "lint",
|
||||||
|
flagModel: "override",
|
||||||
|
expected: "override",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
command: "other",
|
||||||
|
flagModel: "",
|
||||||
|
expected: "grok-other",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
command: "unknown",
|
||||||
|
flagModel: "",
|
||||||
|
expected: "grok-4",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.command+"_"+tt.flagModel, func(t *testing.T) {
|
||||||
|
result := GetModel(tt.command, tt.flagModel)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetModel(%q, %q) = %q, want %q", tt.command, tt.flagModel, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user