Introduce support for per-command model defaults in config.toml that override the global default when set. Update GetModel to accept a command name and resolve the model with the priority flag > command default > global default. Add an example config file, update all commands to pass their name, and update tests accordingly.
72 lines
1.7 KiB
Go
72 lines
1.7 KiB
Go
package config
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
func Load() {
|
|
home, err := os.UserHomeDir()
|
|
if err != nil {
|
|
// Fall back to current directory if home not found
|
|
home = "."
|
|
}
|
|
configPath := filepath.Join(home, ".config", "grokkit")
|
|
|
|
viper.SetConfigName("config")
|
|
viper.SetConfigType("toml")
|
|
viper.AddConfigPath(configPath)
|
|
viper.AddConfigPath(".")
|
|
viper.AutomaticEnv()
|
|
|
|
viper.SetDefault("default_model", "grok-4")
|
|
viper.SetDefault("temperature", 0.7)
|
|
viper.SetDefault("log_level", "info")
|
|
viper.SetDefault("timeout", 60)
|
|
|
|
viper.SetDefault("commands.agent.model", "grok-4")
|
|
viper.SetDefault("commands.chat.model", "grok-4")
|
|
viper.SetDefault("commands.commit.model", "grok-4")
|
|
viper.SetDefault("commands.commitmsg.model", "grok-4")
|
|
viper.SetDefault("commands.edit.model", "grok-4")
|
|
viper.SetDefault("commands.history.model", "grok-4")
|
|
viper.SetDefault("commands.lint.model", "grok-4-1-fast-non-reasoning")
|
|
viper.SetDefault("commands.prdescribe.model", "grok-4")
|
|
viper.SetDefault("commands.review.model", "grok-4")
|
|
|
|
// Config file is optional, so we ignore read errors
|
|
_ = viper.ReadInConfig()
|
|
}
|
|
|
|
func GetModel(commandName string, flagModel string) string {
|
|
if flagModel != "" {
|
|
if alias := viper.GetString("aliases." + flagModel); alias != "" {
|
|
return alias
|
|
}
|
|
return flagModel
|
|
}
|
|
cmdModel := viper.GetString("commands." + commandName + ".model")
|
|
if cmdModel != "" {
|
|
return cmdModel
|
|
}
|
|
return viper.GetString("default_model")
|
|
}
|
|
|
|
func GetTemperature() float64 {
|
|
return viper.GetFloat64("temperature")
|
|
}
|
|
|
|
func GetTimeout() int {
|
|
timeout := viper.GetInt("timeout")
|
|
if timeout <= 0 {
|
|
return 60 // Default 60 seconds
|
|
}
|
|
return timeout
|
|
}
|
|
|
|
func GetLogLevel() string {
|
|
return viper.GetString("log_level")
|
|
}
|