refactor(tests): improve error handling and defer usage
All checks were successful
CI / Test (push) Successful in 31s
CI / Lint (push) Successful in 25s
CI / Build (push) Successful in 22s
Release / Create Release (push) Successful in 36s

- Add error checking for os.Setenv and io operations in test files
- Use anonymous functions in defer to ignore errors from os.Remove, os.Setenv, etc.
- Minor formatting and consistency fixes in tests and client code
This commit is contained in:
Greg Gauthier 2026-03-02 21:33:11 +00:00
parent 032301e041
commit 918ccc01c8
8 changed files with 76 additions and 54 deletions

View File

@ -9,10 +9,12 @@ import (
func TestGetChatHistoryFile(t *testing.T) {
// Save original HOME
oldHome := os.Getenv("HOME")
defer os.Setenv("HOME", oldHome)
defer func() { _ = os.Setenv("HOME", oldHome) }()
tmpDir := t.TempDir()
os.Setenv("HOME", tmpDir)
if err := os.Setenv("HOME", tmpDir); err != nil {
t.Fatal(err)
}
histFile := getChatHistoryFile()
expected := filepath.Join(tmpDir, ".config", "grokkit", "chat_history.json")
@ -22,11 +24,18 @@ func TestGetChatHistoryFile(t *testing.T) {
}
}
func setHomeForTest(t *testing.T, dir string) {
t.Helper()
if err := os.Setenv("HOME", dir); err != nil {
t.Fatal(err)
}
}
func TestLoadChatHistory_NoFile(t *testing.T) {
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHomeForTest(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
history := loadChatHistory()
if history != nil {
@ -37,8 +46,8 @@ func TestLoadChatHistory_NoFile(t *testing.T) {
func TestSaveAndLoadChatHistory(t *testing.T) {
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHomeForTest(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
// Create test messages
messages := []map[string]string{
@ -77,8 +86,8 @@ func TestSaveAndLoadChatHistory(t *testing.T) {
func TestLoadChatHistory_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHomeForTest(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
// Create invalid JSON file
histDir := filepath.Join(tmpDir, ".config", "grokkit")

View File

@ -7,9 +7,9 @@ import (
func TestBuildDocsMessages(t *testing.T) {
tests := []struct {
language string
code string
styleCheck string
language string
code string
styleCheck string
}{
{language: "Go", code: "package main\nfunc Foo() {}", styleCheck: "godoc"},
{language: "Python", code: "def foo():\n pass", styleCheck: "PEP 257"},

View File

@ -14,7 +14,7 @@ func TestEditCommand(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
defer func() { _ = os.Remove(tmpfile.Name()) }()
original := []byte("package main\n\nfunc hello() {}\n")
if err := os.WriteFile(tmpfile.Name(), original, 0644); err != nil {

View File

@ -205,8 +205,12 @@ func TestRunCommit(t *testing.T) {
r, w, _ := os.Pipe()
origStdin := os.Stdin
os.Stdin = r
w.WriteString("n\n")
w.Close()
if _, err := w.WriteString("n\n"); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
defer func() { os.Stdin = origStdin }()
runCommit(testCmd(), nil)
@ -357,8 +361,10 @@ func TestProcessDocsFileUnsupportedLanguage(t *testing.T) {
if err != nil {
t.Fatal(err)
}
f.Close()
defer os.Remove(f.Name())
if err := f.Close(); err != nil {
t.Fatal(err)
}
defer func() { _ = os.Remove(f.Name()) }()
mock := &mockStreamer{}
processDocsFile(mock, "grok-4", f.Name())
@ -376,9 +382,11 @@ func TestProcessDocsFilePreviewAndCancel(t *testing.T) {
if _, err := f.WriteString("package main\n\nfunc Foo() {}\n"); err != nil {
t.Fatal(err)
}
f.Close()
defer os.Remove(f.Name())
defer os.Remove(f.Name() + ".bak")
if err := f.Close(); err != nil {
t.Fatal(err)
}
defer func() { _ = os.Remove(f.Name()) }()
defer func() { _ = os.Remove(f.Name() + ".bak") }()
mock := &mockStreamer{response: "package main\n\n// Foo does nothing.\nfunc Foo() {}\n"}
@ -390,8 +398,12 @@ func TestProcessDocsFilePreviewAndCancel(t *testing.T) {
r, w, _ := os.Pipe()
origStdin := os.Stdin
os.Stdin = r
w.WriteString("n\n")
w.Close()
if _, err := w.WriteString("n\n"); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
defer func() { os.Stdin = origStdin }()
processDocsFile(mock, "grok-4", f.Name())
@ -410,9 +422,11 @@ func TestProcessDocsFileAutoApply(t *testing.T) {
if _, err := f.WriteString(original); err != nil {
t.Fatal(err)
}
f.Close()
defer os.Remove(f.Name())
defer os.Remove(f.Name() + ".bak")
if err := f.Close(); err != nil {
t.Fatal(err)
}
defer func() { _ = os.Remove(f.Name()) }()
defer func() { _ = os.Remove(f.Name() + ".bak") }()
// CleanCodeResponse will trim the trailing newline from the AI response.
aiResponse := "package main\n\n// Bar does nothing.\nfunc Bar() {}\n"

View File

@ -106,7 +106,7 @@ func (c *Client) streamInternal(messages []map[string]string, model string, temp
color.Red("Request failed: %v", err)
os.Exit(1)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
logger.Debug("API response received",
"status", resp.Status,
@ -173,4 +173,4 @@ func CleanCodeResponse(text string) string {
text = strings.TrimSpace(text)
return text
}
}

View File

@ -55,9 +55,9 @@ func sseServer(chunks []string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
for _, c := range chunks {
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", c)
_, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", c)
}
fmt.Fprintf(w, "data: [DONE]\n\n")
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
}))
}
@ -98,10 +98,10 @@ func TestStreamDoneSignal(t *testing.T) {
// Verifies that [DONE] stops processing and non-content chunks are skipped
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n")
fmt.Fprintf(w, "data: [DONE]\n\n")
_, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n")
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
// This line should never be processed
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"extra\"}}]}\n\n")
_, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"extra\"}}]}\n\n")
}))
defer srv.Close()
@ -115,7 +115,7 @@ func TestStreamDoneSignal(t *testing.T) {
func TestStreamEmptyResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprintf(w, "data: [DONE]\n\n")
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
}))
defer srv.Close()
@ -131,13 +131,15 @@ func TestNewClient(t *testing.T) {
oldKey := os.Getenv("XAI_API_KEY")
defer func() {
if oldKey != "" {
os.Setenv("XAI_API_KEY", oldKey)
_ = os.Setenv("XAI_API_KEY", oldKey)
} else {
os.Unsetenv("XAI_API_KEY")
_ = os.Unsetenv("XAI_API_KEY")
}
}()
os.Setenv("XAI_API_KEY", "test-key")
if err := os.Setenv("XAI_API_KEY", "test-key"); err != nil {
t.Fatal(err)
}
client := NewClient()
if client.APIKey != "test-key" {

View File

@ -321,7 +321,7 @@ func TestLanguageStructure(t *testing.T) {
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
containsMiddle(s, substr)))
containsMiddle(s, substr)))
}
func containsMiddle(s, substr string) bool {

View File

@ -7,12 +7,18 @@ import (
"testing"
)
func setHome(t *testing.T, dir string) {
t.Helper()
if err := os.Setenv("HOME", dir); err != nil {
t.Fatal(err)
}
}
func TestInit(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHome(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
tests := []struct {
name string
@ -32,7 +38,6 @@ func TestInit(t *testing.T) {
t.Errorf("Init(%q) unexpected error: %v", tt.logLevel, err)
}
// Check log file was created
logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log")
if _, err := os.Stat(logFile); os.IsNotExist(err) {
t.Errorf("Log file not created at %s", logFile)
@ -44,20 +49,18 @@ func TestInit(t *testing.T) {
func TestLogging(t *testing.T) {
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHome(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
if err := Init("debug"); err != nil {
t.Fatalf("Init() failed: %v", err)
}
// Test all log levels with structured fields
Debug("test debug message", "key", "value")
Info("test info message", "count", 42)
Warn("test warn message", "enabled", true)
Error("test error message", "error", "something went wrong")
// Verify log file has content
logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log")
content, err := os.ReadFile(logFile)
if err != nil {
@ -67,7 +70,6 @@ func TestLogging(t *testing.T) {
t.Errorf("Log file is empty")
}
// Check for JSON structure (slog uses JSON handler)
contentStr := string(content)
if !strings.Contains(contentStr, `"level"`) {
t.Errorf("Log content doesn't contain JSON level field")
@ -80,20 +82,16 @@ func TestLogging(t *testing.T) {
func TestSetLevel(t *testing.T) {
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHome(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
if err := Init("info"); err != nil {
t.Fatalf("Init() failed: %v", err)
}
// Change level to debug
SetLevel("debug")
// Log at debug level
Debug("debug after level change", "test", true)
// Verify log file has the debug message
logFile := filepath.Join(tmpDir, ".config", "grokkit", "grokkit.log")
content, err := os.ReadFile(logFile)
if err != nil {
@ -108,14 +106,13 @@ func TestSetLevel(t *testing.T) {
func TestWith(t *testing.T) {
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
setHome(t, tmpDir)
defer func() { _ = os.Setenv("HOME", oldHome) }()
if err := Init("info"); err != nil {
t.Fatalf("Init() failed: %v", err)
}
// Create logger with context
contextLogger := With("request_id", "123", "user", "testuser")
if contextLogger == nil {
t.Errorf("With() returned nil logger")