diff --git a/cmd/agent.go b/cmd/agent.go index c8a6166..0d91304 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -20,7 +20,7 @@ var agentCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { instruction := args[0] modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("agent", modelFlag) client := grok.NewClient() diff --git a/cmd/chat.go b/cmd/chat.go index 4a555cc..e6be2f0 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -63,7 +63,7 @@ var chatCmd = &cobra.Command{ Short: "Simple interactive CLI chat with Grok (full history + streaming)", Run: func(cmd *cobra.Command, args []string) { modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("chat", modelFlag) client := grok.NewClient() diff --git a/cmd/commit.go b/cmd/commit.go index 716486a..83aa2c8 100644 --- a/cmd/commit.go +++ b/cmd/commit.go @@ -25,7 +25,7 @@ var commitCmd = &cobra.Command{ return } modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("commit", modelFlag) client := grok.NewClient() messages := []map[string]string{ diff --git a/cmd/commitmsg.go b/cmd/commitmsg.go index 3d3f990..e3fd05f 100644 --- a/cmd/commitmsg.go +++ b/cmd/commitmsg.go @@ -24,7 +24,7 @@ var commitMsgCmd = &cobra.Command{ return } modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("commitmsg", modelFlag) client := grok.NewClient() messages := []map[string]string{ diff --git a/cmd/edit.go b/cmd/edit.go index 69df99b..2df21d1 100644 --- a/cmd/edit.go +++ b/cmd/edit.go @@ -22,7 +22,7 @@ var editCmd = &cobra.Command{ instruction := args[1] modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("edit", modelFlag) logger.Info("edit command started", "file", filePath, diff --git a/cmd/history.go b/cmd/history.go index 2cf3558..0cb498c 100644 --- a/cmd/history.go +++ b/cmd/history.go @@ -23,7 +23,7 @@ var historyCmd = &cobra.Command{ } modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("history", modelFlag) client := grok.NewClient() messages := []map[string]string{ diff --git a/cmd/lint.go b/cmd/lint.go index 322aec5..578dd06 100644 --- a/cmd/lint.go +++ b/cmd/lint.go @@ -116,7 +116,7 @@ func runLint(cmd *cobra.Command, args []string) { logger.Info("requesting AI fixes", "file", absPath, "original_size", len(originalContent)) modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("lint", modelFlag) client := grok.NewClient() messages := buildLintFixMessages(result, string(originalContent)) diff --git a/cmd/prdescribe.go b/cmd/prdescribe.go index bd83d8a..e237ac7 100644 --- a/cmd/prdescribe.go +++ b/cmd/prdescribe.go @@ -27,7 +27,7 @@ var prDescribeCmd = &cobra.Command{ return } modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("prdescribe", modelFlag) client := grok.NewClient() messages := []map[string]string{ diff --git a/cmd/review.go b/cmd/review.go index 98d0ef4..38a2a29 100644 --- a/cmd/review.go +++ b/cmd/review.go @@ -15,7 +15,7 @@ var reviewCmd = &cobra.Command{ Short: "Review the current repository or directory", Run: func(cmd *cobra.Command, args []string) { modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel(modelFlag) + model := config.GetModel("review", modelFlag) client := grok.NewClient() diff, err := git.Run([]string{"diff", "--no-color"}) diff --git a/config.toml.example b/config.toml.example new file mode 100644 index 0000000..1929a0c --- /dev/null +++ b/config.toml.example @@ -0,0 +1,29 @@ +# Example configuration file for Grokkit +# Copy this to ~/.config/grokkit/config.toml and customize as needed. + +default_model = "grok-4" + +temperature = 0.7 +log_level = "info" +timeout = 60 + +# Model aliases (shorthand names) +[aliases] +beta = "grok-beta-2" +fast = "grok-4-1-fast-non-reasoning" + +# Per-command model defaults (overrides code defaults if set) +[commands] + lint.model = "grok-4-1-fast-non-reasoning" # Fast model for code fixes + agent.model = "grok-4" # Reasoning model for agent tasks + chat.model = "grok-4" + commit.model = "grok-4" + commitmsg.model = "grok-4" + edit.model = "grok-4-1-fast-non-reasoning" + history.model = "grok-4" + prdescribe.model = "grok-4" + review.model = "grok-4" + +# Chat history settings +[chat] +history_file = "~/.config/grokkit/chat_history.json" diff --git a/config/config.go b/config/config.go index 8b52dae..84a8f58 100644 --- a/config/config.go +++ b/config/config.go @@ -26,17 +26,31 @@ func Load() { viper.SetDefault("log_level", "info") viper.SetDefault("timeout", 60) + viper.SetDefault("commands.agent.model", "grok-4") + viper.SetDefault("commands.chat.model", "grok-4") + viper.SetDefault("commands.commit.model", "grok-4") + viper.SetDefault("commands.commitmsg.model", "grok-4") + viper.SetDefault("commands.edit.model", "grok-4") + viper.SetDefault("commands.history.model", "grok-4") + viper.SetDefault("commands.lint.model", "grok-4-1-fast-non-reasoning") + viper.SetDefault("commands.prdescribe.model", "grok-4") + viper.SetDefault("commands.review.model", "grok-4") + // Config file is optional, so we ignore read errors _ = viper.ReadInConfig() } -func GetModel(flagModel string) string { +func GetModel(commandName string, flagModel string) string { if flagModel != "" { if alias := viper.GetString("aliases." + flagModel); alias != "" { return alias } return flagModel } + cmdModel := viper.GetString("commands." + commandName + ".model") + if cmdModel != "" { + return cmdModel + } return viper.GetString("default_model") } diff --git a/config/config_test.go b/config/config_test.go index a92dd66..d645079 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -30,7 +30,7 @@ func TestGetModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := GetModel(tt.flagModel) + result := GetModel("", tt.flagModel) if result != tt.expected { t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected) } @@ -43,10 +43,54 @@ func TestGetModelWithAlias(t *testing.T) { viper.Set("aliases.beta", "grok-beta-2") viper.SetDefault("default_model", "grok-4") - result := GetModel("beta") + result := GetModel("", "beta") expected := "grok-beta-2" if result != expected { t.Errorf("GetModel('beta') = %q, want %q", result, expected) + + } +} + +func TestGetCommandModel(t *testing.T) { + viper.Reset() + viper.SetDefault("default_model", "grok-4") + viper.Set("commands.lint.model", "grok-4-1-fast-non-reasoning") + viper.Set("commands.other.model", "grok-other") + + tests := []struct { + command string + flagModel string + expected string + }{ + { + command: "lint", + flagModel: "", + expected: "grok-4-1-fast-non-reasoning", + }, + { + command: "lint", + flagModel: "override", + expected: "override", + }, + { + command: "other", + flagModel: "", + expected: "grok-other", + }, + { + command: "unknown", + flagModel: "", + expected: "grok-4", + }, + } + + for _, tt := range tests { + t.Run(tt.command+"_"+tt.flagModel, func(t *testing.T) { + result := GetModel(tt.command, tt.flagModel) + if result != tt.expected { + t.Errorf("GetModel(%q, %q) = %q, want %q", tt.command, tt.flagModel, result, tt.expected) + } + }) } }