package config import ( "testing" "github.com/spf13/viper" ) func TestGetModel(t *testing.T) { // Reset viper for testing viper.Reset() viper.SetDefault("default_model", "grok-4") tests := []struct { name string flagModel string expected string }{ { name: "returns flag model when provided", flagModel: "grok-beta", expected: "grok-beta", }, { name: "returns default when flag empty", flagModel: "", expected: "grok-4", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetModel("", tt.flagModel) if result != tt.expected { t.Errorf("GetModel(%q) = %q, want %q", tt.flagModel, result, tt.expected) } }) } } func TestGetModelWithAlias(t *testing.T) { viper.Reset() viper.Set("aliases.beta", "grok-beta-2") viper.SetDefault("default_model", "grok-4") result := GetModel("", "beta") expected := "grok-beta-2" if result != expected { t.Errorf("GetModel('beta') = %q, want %q", result, expected) } } func TestGetCommandModel(t *testing.T) { viper.Reset() viper.SetDefault("default_model", "grok-4") viper.Set("commands.lint.model", "grok-4-1-fast-non-reasoning") viper.Set("commands.other.model", "grok-other") tests := []struct { command string flagModel string expected string }{ { command: "lint", flagModel: "", expected: "grok-4-1-fast-non-reasoning", }, { command: "lint", flagModel: "override", expected: "override", }, { command: "other", flagModel: "", expected: "grok-other", }, { command: "unknown", flagModel: "", expected: "grok-4", }, } for _, tt := range tests { t.Run(tt.command+"_"+tt.flagModel, func(t *testing.T) { result := GetModel(tt.command, tt.flagModel) if result != tt.expected { t.Errorf("GetModel(%q, %q) = %q, want %q", tt.command, tt.flagModel, result, tt.expected) } }) } } func TestLoad(t *testing.T) { // Just ensure Load doesn't panic Load() } func TestGetTemperature(t *testing.T) { viper.Reset() viper.SetDefault("temperature", 0.7) got := GetTemperature() if got != 0.7 { t.Errorf("GetTemperature() default = %v, want 0.7", got) } viper.Set("temperature", 0.8) got = GetTemperature() if got != 0.8 { t.Errorf("GetTemperature() custom = %v, want 0.8", got) } } func TestGetTimeout(t *testing.T) { viper.Reset() viper.SetDefault("timeout", 60) got := GetTimeout() if got != 60 { t.Errorf("GetTimeout() default = %d, want 60", got) } viper.Set("timeout", 30) got = GetTimeout() if got != 30 { t.Errorf("GetTimeout() = %d, want 30", got) } viper.Set("timeout", 0) got = GetTimeout() if got != 60 { t.Errorf("GetTimeout() invalid = %d, want 60", got) } } func TestGetLogLevel(t *testing.T) { viper.Reset() viper.SetDefault("log_level", "info") got := GetLogLevel() if got != "info" { t.Errorf("GetLogLevel() default = %q, want info", got) } viper.Set("log_level", "debug") got = GetLogLevel() if got != "debug" { t.Errorf("GetLogLevel() custom = %q, want debug", got) } }