refactor(tests): improve error handling and defer usage

- Add error checking for os.Setenv and io operations in test files
- Use anonymous functions in defer to ignore errors from os.Remove, os.Setenv, etc.
- Minor formatting and consistency fixes in tests and client code
This commit is contained in:
Greg Gauthier 2026-03-02 21:33:11 +00:00
parent 032301e041
commit e7407aa991
8 changed files with 76 additions and 54 deletions

View File

@ -9,10 +9,12 @@ import (
func TestGetChatHistoryFile(t *testing.T) { func TestGetChatHistoryFile(t *testing.T) {
// Save original HOME // Save original HOME
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
tmpDir := t.TempDir() tmpDir := t.TempDir()
os.Setenv("HOME", tmpDir) if err := os.Setenv("HOME", tmpDir); err != nil {
t.Fatal(err)
}
histFile := getChatHistoryFile() histFile := getChatHistoryFile()
expected := filepath.Join(tmpDir, ".config", "grokkit", "chat_history.json") expected := filepath.Join(tmpDir, ".config", "grokkit", "chat_history.json")
@ -22,11 +24,18 @@ func TestGetChatHistoryFile(t *testing.T) {
} }
} }
func setHomeForTest(t *testing.T, dir string) {
t.Helper()
if err := os.Setenv("HOME", dir); err != nil {
t.Fatal(err)
}
}
func TestLoadChatHistory_NoFile(t *testing.T) { func TestLoadChatHistory_NoFile(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHomeForTest(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
history := loadChatHistory() history := loadChatHistory()
if history != nil { if history != nil {
@ -37,8 +46,8 @@ func TestLoadChatHistory_NoFile(t *testing.T) {
func TestSaveAndLoadChatHistory(t *testing.T) { func TestSaveAndLoadChatHistory(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHomeForTest(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
// Create test messages // Create test messages
messages := []map[string]string{ messages := []map[string]string{
@ -77,8 +86,8 @@ func TestSaveAndLoadChatHistory(t *testing.T) {
func TestLoadChatHistory_InvalidJSON(t *testing.T) { func TestLoadChatHistory_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHomeForTest(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
// Create invalid JSON file // Create invalid JSON file
histDir := filepath.Join(tmpDir, ".config", "grokkit") histDir := filepath.Join(tmpDir, ".config", "grokkit")

View File

@ -7,9 +7,9 @@ import (
func TestBuildDocsMessages(t *testing.T) { func TestBuildDocsMessages(t *testing.T) {
tests := []struct { tests := []struct {
language string language string
code string code string
styleCheck string styleCheck string
}{ }{
{language: "Go", code: "package main\nfunc Foo() {}", styleCheck: "godoc"}, {language: "Go", code: "package main\nfunc Foo() {}", styleCheck: "godoc"},
{language: "Python", code: "def foo():\n pass", styleCheck: "PEP 257"}, {language: "Python", code: "def foo():\n pass", styleCheck: "PEP 257"},

View File

@ -14,7 +14,7 @@ func TestEditCommand(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.Remove(tmpfile.Name()) defer func() { _ = os.Remove(tmpfile.Name()) }()
original := []byte("package main\n\nfunc hello() {}\n") original := []byte("package main\n\nfunc hello() {}\n")
if err := os.WriteFile(tmpfile.Name(), original, 0644); err != nil { if err := os.WriteFile(tmpfile.Name(), original, 0644); err != nil {

View File

@ -205,8 +205,12 @@ func TestRunCommit(t *testing.T) {
r, w, _ := os.Pipe() r, w, _ := os.Pipe()
origStdin := os.Stdin origStdin := os.Stdin
os.Stdin = r os.Stdin = r
w.WriteString("n\n") if _, err := w.WriteString("n\n"); err != nil {
w.Close() t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
defer func() { os.Stdin = origStdin }() defer func() { os.Stdin = origStdin }()
runCommit(testCmd(), nil) runCommit(testCmd(), nil)
@ -357,8 +361,10 @@ func TestProcessDocsFileUnsupportedLanguage(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
f.Close() if err := f.Close(); err != nil {
defer os.Remove(f.Name()) t.Fatal(err)
}
defer func() { _ = os.Remove(f.Name()) }()
mock := &mockStreamer{} mock := &mockStreamer{}
processDocsFile(mock, "grok-4", f.Name()) processDocsFile(mock, "grok-4", f.Name())
@ -376,9 +382,11 @@ func TestProcessDocsFilePreviewAndCancel(t *testing.T) {
if _, err := f.WriteString("package main\n\nfunc Foo() {}\n"); err != nil { if _, err := f.WriteString("package main\n\nfunc Foo() {}\n"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
f.Close() if err := f.Close(); err != nil {
defer os.Remove(f.Name()) t.Fatal(err)
defer os.Remove(f.Name() + ".bak") }
defer func() { _ = os.Remove(f.Name()) }()
defer func() { _ = os.Remove(f.Name() + ".bak") }()
mock := &mockStreamer{response: "package main\n\n// Foo does nothing.\nfunc Foo() {}\n"} mock := &mockStreamer{response: "package main\n\n// Foo does nothing.\nfunc Foo() {}\n"}
@ -390,8 +398,12 @@ func TestProcessDocsFilePreviewAndCancel(t *testing.T) {
r, w, _ := os.Pipe() r, w, _ := os.Pipe()
origStdin := os.Stdin origStdin := os.Stdin
os.Stdin = r os.Stdin = r
w.WriteString("n\n") if _, err := w.WriteString("n\n"); err != nil {
w.Close() t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
defer func() { os.Stdin = origStdin }() defer func() { os.Stdin = origStdin }()
processDocsFile(mock, "grok-4", f.Name()) processDocsFile(mock, "grok-4", f.Name())
@ -410,9 +422,11 @@ func TestProcessDocsFileAutoApply(t *testing.T) {
if _, err := f.WriteString(original); err != nil { if _, err := f.WriteString(original); err != nil {
t.Fatal(err) t.Fatal(err)
} }
f.Close() if err := f.Close(); err != nil {
defer os.Remove(f.Name()) t.Fatal(err)
defer os.Remove(f.Name() + ".bak") }
defer func() { _ = os.Remove(f.Name()) }()
defer func() { _ = os.Remove(f.Name() + ".bak") }()
// CleanCodeResponse will trim the trailing newline from the AI response. // CleanCodeResponse will trim the trailing newline from the AI response.
aiResponse := "package main\n\n// Bar does nothing.\nfunc Bar() {}\n" aiResponse := "package main\n\n// Bar does nothing.\nfunc Bar() {}\n"

View File

@ -106,7 +106,7 @@ func (c *Client) streamInternal(messages []map[string]string, model string, temp
color.Red("Request failed: %v", err) color.Red("Request failed: %v", err)
os.Exit(1) os.Exit(1)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
logger.Debug("API response received", logger.Debug("API response received",
"status", resp.Status, "status", resp.Status,
@ -173,4 +173,4 @@ func CleanCodeResponse(text string) string {
text = strings.TrimSpace(text) text = strings.TrimSpace(text)
return text return text
} }

View File

@ -55,9 +55,9 @@ func sseServer(chunks []string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
for _, c := range chunks { for _, c := range chunks {
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", c) _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", c)
} }
fmt.Fprintf(w, "data: [DONE]\n\n") _, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
})) }))
} }
@ -98,10 +98,10 @@ func TestStreamDoneSignal(t *testing.T) {
// Verifies that [DONE] stops processing and non-content chunks are skipped // Verifies that [DONE] stops processing and non-content chunks are skipped
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n") _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n")
fmt.Fprintf(w, "data: [DONE]\n\n") _, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
// This line should never be processed // This line should never be processed
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"extra\"}}]}\n\n") _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"extra\"}}]}\n\n")
})) }))
defer srv.Close() defer srv.Close()
@ -115,7 +115,7 @@ func TestStreamDoneSignal(t *testing.T) {
func TestStreamEmptyResponse(t *testing.T) { func TestStreamEmptyResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprintf(w, "data: [DONE]\n\n") _, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
})) }))
defer srv.Close() defer srv.Close()
@ -131,13 +131,15 @@ func TestNewClient(t *testing.T) {
oldKey := os.Getenv("XAI_API_KEY") oldKey := os.Getenv("XAI_API_KEY")
defer func() { defer func() {
if oldKey != "" { if oldKey != "" {
os.Setenv("XAI_API_KEY", oldKey) _ = os.Setenv("XAI_API_KEY", oldKey)
} else { } else {
os.Unsetenv("XAI_API_KEY") _ = os.Unsetenv("XAI_API_KEY")
} }
}() }()
os.Setenv("XAI_API_KEY", "test-key") if err := os.Setenv("XAI_API_KEY", "test-key"); err != nil {
t.Fatal(err)
}
client := NewClient() client := NewClient()
if client.APIKey != "test-key" { if client.APIKey != "test-key" {

View File

@ -321,7 +321,7 @@ func TestLanguageStructure(t *testing.T) {
func contains(s, substr string) bool { func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
containsMiddle(s, substr))) containsMiddle(s, substr)))
} }
func containsMiddle(s, substr string) bool { func containsMiddle(s, substr string) bool {

View File

@ -7,12 +7,18 @@ import (
"testing" "testing"
) )
func setHome(t *testing.T, dir string) {
t.Helper()
if err := os.Setenv("HOME", dir); err != nil {
t.Fatal(err)
}
}
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHome(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
tests := []struct { tests := []struct {
name string name string
@ -32,7 +38,6 @@ func TestInit(t *testing.T) {
t.Errorf("Init(%q) unexpected error: %v", tt.logLevel, err) t.Errorf("Init(%q) unexpected error: %v", tt.logLevel, err)
} }
// Check log file was created
logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log") logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log")
if _, err := os.Stat(logFile); os.IsNotExist(err) { if _, err := os.Stat(logFile); os.IsNotExist(err) {
t.Errorf("Log file not created at %s", logFile) t.Errorf("Log file not created at %s", logFile)
@ -44,20 +49,18 @@ func TestInit(t *testing.T) {
func TestLogging(t *testing.T) { func TestLogging(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHome(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
if err := Init("debug"); err != nil { if err := Init("debug"); err != nil {
t.Fatalf("Init() failed: %v", err) t.Fatalf("Init() failed: %v", err)
} }
// Test all log levels with structured fields
Debug("test debug message", "key", "value") Debug("test debug message", "key", "value")
Info("test info message", "count", 42) Info("test info message", "count", 42)
Warn("test warn message", "enabled", true) Warn("test warn message", "enabled", true)
Error("test error message", "error", "something went wrong") Error("test error message", "error", "something went wrong")
// Verify log file has content
logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log") logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log")
content, err := os.ReadFile(logFile) content, err := os.ReadFile(logFile)
if err != nil { if err != nil {
@ -67,7 +70,6 @@ func TestLogging(t *testing.T) {
t.Errorf("Log file is empty") t.Errorf("Log file is empty")
} }
// Check for JSON structure (slog uses JSON handler)
contentStr := string(content) contentStr := string(content)
if !strings.Contains(contentStr, `"level"`) { if !strings.Contains(contentStr, `"level"`) {
t.Errorf("Log content doesn't contain JSON level field") t.Errorf("Log content doesn't contain JSON level field")
@ -80,20 +82,16 @@ func TestLogging(t *testing.T) {
func TestSetLevel(t *testing.T) { func TestSetLevel(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHome(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
if err := Init("info"); err != nil { if err := Init("info"); err != nil {
t.Fatalf("Init() failed: %v", err) t.Fatalf("Init() failed: %v", err)
} }
// Change level to debug
SetLevel("debug") SetLevel("debug")
// Log at debug level
Debug("debug after level change", "test", true) Debug("debug after level change", "test", true)
// Verify log file has the debug message
logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log") logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log")
content, err := os.ReadFile(logFile) content, err := os.ReadFile(logFile)
if err != nil { if err != nil {
@ -108,14 +106,13 @@ func TestSetLevel(t *testing.T) {
func TestWith(t *testing.T) { func TestWith(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
oldHome := os.Getenv("HOME") oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir) setHome(t, tmpDir)
defer os.Setenv("HOME", oldHome) defer func() { _ = os.Setenv("HOME", oldHome) }()
if err := Init("info"); err != nil { if err := Init("info"); err != nil {
t.Fatalf("Init() failed: %v", err) t.Fatalf("Init() failed: %v", err)
} }
// Create logger with context
contextLogger := With("request_id", "123", "user", "testuser") contextLogger := With("request_id", "123", "user", "testuser")
if contextLogger == nil { if contextLogger == nil {
t.Errorf("With() returned nil logger") t.Errorf("With() returned nil logger")