From 99ef10b16b13b1338603bddaa5bde1ab2d00cc97 Mon Sep 17 00:00:00 2001 From: Greg Gauthier Date: Mon, 2 Mar 2026 20:47:16 +0000 Subject: [PATCH] refactor(cmd): extract run funcs and add injectable deps for testability - Introduce newGrokClient and gitRun vars to allow mocking in tests. - Refactor commit, commitmsg, history, prdescribe, and review cmds into separate run funcs. - Update docs, lint, and review to use newGrokClient. - Add comprehensive unit tests in run_test.go covering happy paths, errors, and edge cases. - Expand grok client tests with SSE server mocks for Stream* methods. --- cmd/client.go | 16 ++ cmd/commit.go | 70 +++--- cmd/commitmsg.go | 45 ++-- cmd/docs.go | 4 +- cmd/history.go | 38 +-- cmd/lint.go | 2 +- cmd/prdescribe.go | 40 ++-- cmd/review.go | 40 ++-- cmd/run_test.go | 453 +++++++++++++++++++++++++++++++++++ internal/grok/client_test.go | 78 ++++++ 10 files changed, 664 insertions(+), 122 deletions(-) create mode 100644 cmd/client.go create mode 100644 cmd/run_test.go diff --git a/cmd/client.go b/cmd/client.go new file mode 100644 index 0000000..b965025 --- /dev/null +++ b/cmd/client.go @@ -0,0 +1,16 @@ +package cmd + +import ( + "gmgauthier.com/grokkit/internal/git" + "gmgauthier.com/grokkit/internal/grok" +) + +// newGrokClient is the factory for the AI client. +// Tests replace this to inject a mock without making real API calls. +var newGrokClient = func() grok.AIClient { + return grok.NewClient() +} + +// gitRun is the git command runner. +// Tests replace this to inject controlled git output. +var gitRun = git.Run diff --git a/cmd/commit.go b/cmd/commit.go index 3fec2b2..a549592 100644 --- a/cmd/commit.go +++ b/cmd/commit.go @@ -7,49 +7,49 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "gmgauthier.com/grokkit/config" - "gmgauthier.com/grokkit/internal/git" - "gmgauthier.com/grokkit/internal/grok" ) var commitCmd = &cobra.Command{ Use: "commit", Short: "Generate message and commit staged changes", - Run: func(cmd *cobra.Command, args []string) { - diff, err := git.Run([]string{"diff", "--cached", "--no-color"}) - if err != nil { - color.Red("Failed to get staged changes: %v", err) - return - } - if diff == "" { - color.Yellow("No staged changes!") - return - } - modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel("commit", modelFlag) + Run: runCommit, +} - client := grok.NewClient() - messages := buildCommitMessages(diff) - color.Yellow("Generating commit message...") - msg := client.Stream(messages, model) +func runCommit(cmd *cobra.Command, args []string) { + diff, err := gitRun([]string{"diff", "--cached", "--no-color"}) + if err != nil { + color.Red("Failed to get staged changes: %v", err) + return + } + if diff == "" { + color.Yellow("No staged changes!") + return + } + modelFlag, _ := cmd.Flags().GetString("model") + model := config.GetModel("commit", modelFlag) - color.Cyan("\nProposed commit message:\n%s", msg) - var confirm string - color.Yellow("Commit with this message? (y/n): ") - if _, err := fmt.Scanln(&confirm); err != nil { - color.Red("Failed to read input: %v", err) - return - } - if confirm != "y" && confirm != "Y" { - color.Yellow("Aborted.") - return - } + client := newGrokClient() + messages := buildCommitMessages(diff) + color.Yellow("Generating commit message...") + msg := client.Stream(messages, model) - if err := exec.Command("git", "commit", "-m", msg).Run(); err != nil { - color.Red("Git commit failed") - } else { - color.Green("✅ Committed successfully!") - } - }, + color.Cyan("\nProposed commit message:\n%s", msg) + var confirm string + color.Yellow("Commit with this message? (y/n): ") + if _, err := fmt.Scanln(&confirm); err != nil { + color.Red("Failed to read input: %v", err) + return + } + if confirm != "y" && confirm != "Y" { + color.Yellow("Aborted.") + return + } + + if err := exec.Command("git", "commit", "-m", msg).Run(); err != nil { + color.Red("Git commit failed") + } else { + color.Green("✅ Committed successfully!") + } } func buildCommitMessages(diff string) []map[string]string { diff --git a/cmd/commitmsg.go b/cmd/commitmsg.go index e3fd05f..e736088 100644 --- a/cmd/commitmsg.go +++ b/cmd/commitmsg.go @@ -1,37 +1,32 @@ package cmd import ( - "fmt" - "github.com/fatih/color" "github.com/spf13/cobra" "gmgauthier.com/grokkit/config" - "gmgauthier.com/grokkit/internal/git" - "gmgauthier.com/grokkit/internal/grok" ) var commitMsgCmd = &cobra.Command{ Use: "commit-msg", Short: "Generate conventional commit message from staged changes", - Run: func(cmd *cobra.Command, args []string) { - diff, err := git.Run([]string{"diff", "--cached", "--no-color"}) - if err != nil { - color.Red("Failed to get staged changes: %v", err) - return - } - if diff == "" { - color.Yellow("No staged changes!") - return - } - modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel("commitmsg", modelFlag) - - client := grok.NewClient() - messages := []map[string]string{ - {"role": "system", "content": "Return ONLY a conventional commit message (type(scope): subject\n\nbody)."}, - {"role": "user", "content": fmt.Sprintf("Staged changes:\n%s", diff)}, - } - color.Yellow("Generating commit message...") - client.Stream(messages, model) - }, + Run: runCommitMsg, +} + +func runCommitMsg(cmd *cobra.Command, args []string) { + diff, err := gitRun([]string{"diff", "--cached", "--no-color"}) + if err != nil { + color.Red("Failed to get staged changes: %v", err) + return + } + if diff == "" { + color.Yellow("No staged changes!") + return + } + modelFlag, _ := cmd.Flags().GetString("model") + model := config.GetModel("commitmsg", modelFlag) + + client := newGrokClient() + messages := buildCommitMessages(diff) + color.Yellow("Generating commit message...") + client.Stream(messages, model) } diff --git a/cmd/docs.go b/cmd/docs.go index f9f92f5..cc253e1 100644 --- a/cmd/docs.go +++ b/cmd/docs.go @@ -47,13 +47,13 @@ func runDocs(cmd *cobra.Command, args []string) { modelFlag, _ := cmd.Flags().GetString("model") model := config.GetModel("docs", modelFlag) - client := grok.NewClient() + client := newGrokClient() for _, filePath := range args { processDocsFile(client, model, filePath) } } -func processDocsFile(client *grok.Client, model, filePath string) { +func processDocsFile(client grok.AIClient, model, filePath string) { logger.Info("starting docs operation", "file", filePath) if _, err := os.Stat(filePath); os.IsNotExist(err) { diff --git a/cmd/history.go b/cmd/history.go index 2ea064b..a11100d 100644 --- a/cmd/history.go +++ b/cmd/history.go @@ -4,32 +4,32 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "gmgauthier.com/grokkit/config" - "gmgauthier.com/grokkit/internal/git" - "gmgauthier.com/grokkit/internal/grok" ) var historyCmd = &cobra.Command{ Use: "history", Short: "Summarize recent git history", - Run: func(cmd *cobra.Command, args []string) { - log, err := git.Run([]string{"log", "--oneline", "-10"}) - if err != nil { - color.Red("Failed to get git log: %v", err) - return - } - if log == "" { - color.Yellow("No commits found.") - return - } + Run: runHistory, +} - modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel("history", modelFlag) +func runHistory(cmd *cobra.Command, args []string) { + log, err := gitRun([]string{"log", "--oneline", "-10"}) + if err != nil { + color.Red("Failed to get git log: %v", err) + return + } + if log == "" { + color.Yellow("No commits found.") + return + } - client := grok.NewClient() - messages := buildHistoryMessages(log) - color.Yellow("Summarizing recent commits...") - client.Stream(messages, model) - }, + modelFlag, _ := cmd.Flags().GetString("model") + model := config.GetModel("history", modelFlag) + + client := newGrokClient() + messages := buildHistoryMessages(log) + color.Yellow("Summarizing recent commits...") + client.Stream(messages, model) } func buildHistoryMessages(log string) []map[string]string { diff --git a/cmd/lint.go b/cmd/lint.go index 578dd06..47a4baf 100644 --- a/cmd/lint.go +++ b/cmd/lint.go @@ -118,7 +118,7 @@ func runLint(cmd *cobra.Command, args []string) { modelFlag, _ := cmd.Flags().GetString("model") model := config.GetModel("lint", modelFlag) - client := grok.NewClient() + client := newGrokClient() messages := buildLintFixMessages(result, string(originalContent)) response := client.StreamSilent(messages, model) diff --git a/cmd/prdescribe.go b/cmd/prdescribe.go index 038ff5f..d489801 100644 --- a/cmd/prdescribe.go +++ b/cmd/prdescribe.go @@ -6,34 +6,34 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "gmgauthier.com/grokkit/config" - "gmgauthier.com/grokkit/internal/git" - "gmgauthier.com/grokkit/internal/grok" ) var prDescribeCmd = &cobra.Command{ Use: "pr-describe", Short: "Generate full PR description from current branch", - Run: func(cmd *cobra.Command, args []string) { - diff, err := git.Run([]string{"diff", "main..HEAD", "--no-color"}) - if err != nil || diff == "" { - diff, err = git.Run([]string{"diff", "origin/main..HEAD", "--no-color"}) - if err != nil { - color.Red("Failed to get branch diff: %v", err) - return - } - } - if diff == "" { - color.Yellow("No changes on this branch compared to main/origin/main.") + Run: runPRDescribe, +} + +func runPRDescribe(cmd *cobra.Command, args []string) { + diff, err := gitRun([]string{"diff", "main..HEAD", "--no-color"}) + if err != nil || diff == "" { + diff, err = gitRun([]string{"diff", "origin/main..HEAD", "--no-color"}) + if err != nil { + color.Red("Failed to get branch diff: %v", err) return } - modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel("prdescribe", modelFlag) + } + if diff == "" { + color.Yellow("No changes on this branch compared to main/origin/main.") + return + } + modelFlag, _ := cmd.Flags().GetString("model") + model := config.GetModel("prdescribe", modelFlag) - client := grok.NewClient() - messages := buildPRDescribeMessages(diff) - color.Yellow("Writing PR description...") - client.Stream(messages, model) - }, + client := newGrokClient() + messages := buildPRDescribeMessages(diff) + color.Yellow("Writing PR description...") + client.Stream(messages, model) } func buildPRDescribeMessages(diff string) []map[string]string { diff --git a/cmd/review.go b/cmd/review.go index 1194c54..e48ba36 100644 --- a/cmd/review.go +++ b/cmd/review.go @@ -6,33 +6,33 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "gmgauthier.com/grokkit/config" - "gmgauthier.com/grokkit/internal/git" - "gmgauthier.com/grokkit/internal/grok" ) var reviewCmd = &cobra.Command{ Use: "review [path]", Short: "Review the current repository or directory", - Run: func(cmd *cobra.Command, args []string) { - modelFlag, _ := cmd.Flags().GetString("model") - model := config.GetModel("review", modelFlag) + Run: runReview, +} - client := grok.NewClient() - diff, err := git.Run([]string{"diff", "--no-color"}) - if err != nil { - color.Red("Failed to get git diff: %v", err) - return - } - status, err := git.Run([]string{"status", "--short"}) - if err != nil { - color.Red("Failed to get git status: %v", err) - return - } +func runReview(cmd *cobra.Command, args []string) { + modelFlag, _ := cmd.Flags().GetString("model") + model := config.GetModel("review", modelFlag) - messages := buildReviewMessages(status, diff) - color.Yellow("Grok is reviewing the repo...") - client.Stream(messages, model) - }, + diff, err := gitRun([]string{"diff", "--no-color"}) + if err != nil { + color.Red("Failed to get git diff: %v", err) + return + } + status, err := gitRun([]string{"status", "--short"}) + if err != nil { + color.Red("Failed to get git status: %v", err) + return + } + + client := newGrokClient() + messages := buildReviewMessages(status, diff) + color.Yellow("Grok is reviewing the repo...") + client.Stream(messages, model) } func buildReviewMessages(status, diff string) []map[string]string { diff --git a/cmd/run_test.go b/cmd/run_test.go new file mode 100644 index 0000000..9308e2a --- /dev/null +++ b/cmd/run_test.go @@ -0,0 +1,453 @@ +package cmd + +import ( + "errors" + "os" + "testing" + + "github.com/spf13/cobra" + "gmgauthier.com/grokkit/internal/grok" +) + +// mockStreamer records calls made to it and returns a canned response. +type mockStreamer struct { + response string + calls int +} + +func (m *mockStreamer) Stream(messages []map[string]string, model string) string { + m.calls++ + return m.response +} + +func (m *mockStreamer) StreamWithTemp(messages []map[string]string, model string, temp float64) string { + m.calls++ + return m.response +} + +func (m *mockStreamer) StreamSilent(messages []map[string]string, model string) string { + m.calls++ + return m.response +} + +// Ensure mockStreamer satisfies the interface at compile time. +var _ grok.AIClient = (*mockStreamer)(nil) + +// withMockClient injects mock and returns a restore func. +func withMockClient(mock grok.AIClient) func() { + orig := newGrokClient + newGrokClient = func() grok.AIClient { return mock } + return func() { newGrokClient = orig } +} + +// withMockGit injects a fake git runner and returns a restore func. +func withMockGit(fn func([]string) (string, error)) func() { + orig := gitRun + gitRun = fn + return func() { gitRun = orig } +} + +// testCmd returns a minimal cobra command with the model flag registered. +func testCmd() *cobra.Command { + c := &cobra.Command{} + c.Flags().String("model", "", "") + return c +} + +// --- runHistory --- + +func TestRunHistory(t *testing.T) { + t.Run("calls AI with log output", func(t *testing.T) { + mock := &mockStreamer{response: "3 commits: feat, fix, chore"} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "abc1234 feat: add thing\ndef5678 fix: bug", nil + })() + + runHistory(testCmd(), nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + }) + + t.Run("no commits — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "", nil + })() + + runHistory(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) + + t.Run("git error — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "", errors.New("not a git repo") + })() + + runHistory(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) +} + +// --- runReview --- + +func TestRunReview(t *testing.T) { + t.Run("reviews with diff and status", func(t *testing.T) { + mock := &mockStreamer{response: "looks good"} + defer withMockClient(mock)() + callCount := 0 + defer withMockGit(func(args []string) (string, error) { + callCount++ + switch args[0] { + case "diff": + return "diff content", nil + case "status": + return "M main.go", nil + } + return "", nil + })() + + runReview(testCmd(), nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + if callCount != 2 { + t.Errorf("expected 2 git calls (diff + status), got %d", callCount) + } + }) + + t.Run("git diff error — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + if args[0] == "diff" { + return "", errors.New("git error") + } + return "", nil + })() + + runReview(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) + + t.Run("git status error — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + if args[0] == "status" { + return "", errors.New("git error") + } + return "diff content", nil + })() + + runReview(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) +} + +// --- runCommit --- + +func TestRunCommit(t *testing.T) { + t.Run("no staged changes — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "", nil + })() + + runCommit(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) + + t.Run("git error — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "", errors.New("not a git repo") + })() + + runCommit(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) + + t.Run("with staged changes — calls AI then cancels via stdin", func(t *testing.T) { + mock := &mockStreamer{response: "feat(cmd): add thing"} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "diff --git a/foo.go b/foo.go\n+func bar() {}", nil + })() + + // Pipe "n\n" so the confirmation prompt returns without committing. + r, w, _ := os.Pipe() + origStdin := os.Stdin + os.Stdin = r + w.WriteString("n\n") + w.Close() + defer func() { os.Stdin = origStdin }() + + runCommit(testCmd(), nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + }) +} + +// --- runCommitMsg --- + +func TestRunCommitMsg(t *testing.T) { + t.Run("no staged changes — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "", nil + })() + + runCommitMsg(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) + + t.Run("with staged changes — calls AI and prints message", func(t *testing.T) { + mock := &mockStreamer{response: "feat(api): add endpoint"} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "diff --git a/api.go b/api.go", nil + })() + + runCommitMsg(testCmd(), nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + }) +} + +// --- runPRDescribe --- + +func TestRunPRDescribe(t *testing.T) { + t.Run("no changes on branch — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + defer withMockGit(func(args []string) (string, error) { + return "", nil // both diff calls return empty + })() + + runPRDescribe(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) + + t.Run("first diff succeeds — calls AI", func(t *testing.T) { + mock := &mockStreamer{response: "## PR Title\n\nDescription"} + defer withMockClient(mock)() + callCount := 0 + defer withMockGit(func(args []string) (string, error) { + callCount++ + if callCount == 1 { + return "diff --git a/foo.go b/foo.go", nil + } + return "", nil + })() + + runPRDescribe(testCmd(), nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + }) + + t.Run("first diff empty, second succeeds — calls AI", func(t *testing.T) { + mock := &mockStreamer{response: "PR description"} + defer withMockClient(mock)() + callCount := 0 + defer withMockGit(func(args []string) (string, error) { + callCount++ + if callCount == 2 { + return "diff --git a/bar.go b/bar.go", nil + } + return "", nil + })() + + runPRDescribe(testCmd(), nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + }) + + t.Run("second diff error — skips AI", func(t *testing.T) { + mock := &mockStreamer{} + defer withMockClient(mock)() + callCount := 0 + defer withMockGit(func(args []string) (string, error) { + callCount++ + if callCount == 2 { + return "", errors.New("no remote") + } + return "", nil + })() + + runPRDescribe(testCmd(), nil) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls, got %d", mock.calls) + } + }) +} + +// --- runLint / processDocsFile (error paths only — linter/file I/O) --- + +func TestRunLintFileNotFound(t *testing.T) { + // Reset flags to defaults + dryRun = false + autoFix = false + + // Pass a non-existent file; should return without calling AI. + mock := &mockStreamer{} + defer withMockClient(mock)() + + runLint(testCmd(), []string{"/nonexistent/path/file.go"}) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls for missing file, got %d", mock.calls) + } +} + +func TestProcessDocsFileNotFound(t *testing.T) { + mock := &mockStreamer{} + processDocsFile(mock, "grok-4", "/nonexistent/path/file.go") + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls for missing file, got %d", mock.calls) + } +} + +func TestProcessDocsFileUnsupportedLanguage(t *testing.T) { + // Create a temp file with an unsupported extension. + f, err := os.CreateTemp("", "test*.xyz") + if err != nil { + t.Fatal(err) + } + f.Close() + defer os.Remove(f.Name()) + + mock := &mockStreamer{} + processDocsFile(mock, "grok-4", f.Name()) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls for unsupported language, got %d", mock.calls) + } +} + +func TestProcessDocsFilePreviewAndCancel(t *testing.T) { + f, err := os.CreateTemp("", "test*.go") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString("package main\n\nfunc Foo() {}\n"); err != nil { + t.Fatal(err) + } + f.Close() + defer os.Remove(f.Name()) + defer os.Remove(f.Name() + ".bak") + + mock := &mockStreamer{response: "package main\n\n// Foo does nothing.\nfunc Foo() {}\n"} + + origAutoApply := autoApply + autoApply = false + defer func() { autoApply = origAutoApply }() + + // Pipe "n\n" so confirmation returns without writing. + r, w, _ := os.Pipe() + origStdin := os.Stdin + os.Stdin = r + w.WriteString("n\n") + w.Close() + defer func() { os.Stdin = origStdin }() + + processDocsFile(mock, "grok-4", f.Name()) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } +} + +func TestProcessDocsFileAutoApply(t *testing.T) { + f, err := os.CreateTemp("", "test*.go") + if err != nil { + t.Fatal(err) + } + original := "package main\n\nfunc Bar() {}\n" + if _, err := f.WriteString(original); err != nil { + t.Fatal(err) + } + f.Close() + defer os.Remove(f.Name()) + defer os.Remove(f.Name() + ".bak") + + // CleanCodeResponse will trim the trailing newline from the AI response. + aiResponse := "package main\n\n// Bar does nothing.\nfunc Bar() {}\n" + documented := "package main\n\n// Bar does nothing.\nfunc Bar() {}" + mock := &mockStreamer{response: aiResponse} + + origAutoApply := autoApply + autoApply = true + defer func() { autoApply = origAutoApply }() + + processDocsFile(mock, "grok-4", f.Name()) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + // Verify file was rewritten with documented content. + content, _ := os.ReadFile(f.Name()) + if string(content) != documented { + t.Errorf("file content = %q, want %q", string(content), documented) + } + // Verify backup was created with original content. + backup, _ := os.ReadFile(f.Name() + ".bak") + if string(backup) != original { + t.Errorf("backup content = %q, want %q", string(backup), original) + } +} + +func TestRunDocs(t *testing.T) { + // runDocs with a missing file: should call processDocsFile which returns early. + mock := &mockStreamer{} + defer withMockClient(mock)() + + runDocs(testCmd(), []string{"/nonexistent/file.go"}) + + if mock.calls != 0 { + t.Errorf("expected 0 AI calls for missing file, got %d", mock.calls) + } +} diff --git a/internal/grok/client_test.go b/internal/grok/client_test.go index b0a2cb3..feb0567 100644 --- a/internal/grok/client_test.go +++ b/internal/grok/client_test.go @@ -1,6 +1,9 @@ package grok import ( + "fmt" + "net/http" + "net/http/httptest" "os" "testing" ) @@ -48,6 +51,81 @@ func TestCleanCodeResponse(t *testing.T) { } } +func sseServer(chunks []string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for _, c := range chunks { + fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", c) + } + fmt.Fprintf(w, "data: [DONE]\n\n") + })) +} + +func TestStreamSilent(t *testing.T) { + srv := sseServer([]string{"Hello", " ", "World"}) + defer srv.Close() + + client := &Client{APIKey: "test-key", BaseURL: srv.URL} + got := client.StreamSilent([]map[string]string{{"role": "user", "content": "hi"}}, "test-model") + if got != "Hello World" { + t.Errorf("StreamSilent() = %q, want %q", got, "Hello World") + } +} + +func TestStream(t *testing.T) { + srv := sseServer([]string{"foo", "bar"}) + defer srv.Close() + + client := &Client{APIKey: "test-key", BaseURL: srv.URL} + got := client.Stream([]map[string]string{{"role": "user", "content": "hi"}}, "test-model") + if got != "foobar" { + t.Errorf("Stream() = %q, want %q", got, "foobar") + } +} + +func TestStreamWithTemp(t *testing.T) { + srv := sseServer([]string{"response"}) + defer srv.Close() + + client := &Client{APIKey: "test-key", BaseURL: srv.URL} + got := client.StreamWithTemp([]map[string]string{{"role": "user", "content": "hi"}}, "test-model", 0.5) + if got != "response" { + t.Errorf("StreamWithTemp() = %q, want %q", got, "response") + } +} + +func TestStreamDoneSignal(t *testing.T) { + // Verifies that [DONE] stops processing and non-content chunks are skipped + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n") + fmt.Fprintf(w, "data: [DONE]\n\n") + // This line should never be processed + fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"extra\"}}]}\n\n") + })) + defer srv.Close() + + client := &Client{APIKey: "test-key", BaseURL: srv.URL} + got := client.StreamSilent(nil, "test-model") + if got != "ok" { + t.Errorf("got %q, want %q", got, "ok") + } +} + +func TestStreamEmptyResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer srv.Close() + + client := &Client{APIKey: "test-key", BaseURL: srv.URL} + got := client.StreamSilent(nil, "test-model") + if got != "" { + t.Errorf("got %q, want empty string", got) + } +} + func TestNewClient(t *testing.T) { // Save and restore env oldKey := os.Getenv("XAI_API_KEY")