refactor(cmd): extract run funcs and add injectable deps for testability
- Introduce newGrokClient and gitRun vars to allow mocking in tests. - Refactor commit, commitmsg, history, prdescribe, and review cmds into separate run funcs. - Update docs, lint, and review to use newGrokClient. - Add comprehensive unit tests in run_test.go covering happy paths, errors, and edge cases. - Expand grok client tests with SSE server mocks for Stream* methods.
This commit is contained in:
parent
f763976a27
commit
99ef10b16b
16
cmd/client.go
Normal file
16
cmd/client.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gmgauthier.com/grokkit/internal/git"
|
||||||
|
"gmgauthier.com/grokkit/internal/grok"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newGrokClient is the factory for the AI client.
|
||||||
|
// Tests replace this to inject a mock without making real API calls.
|
||||||
|
var newGrokClient = func() grok.AIClient {
|
||||||
|
return grok.NewClient()
|
||||||
|
}
|
||||||
|
|
||||||
|
// gitRun is the git command runner.
|
||||||
|
// Tests replace this to inject controlled git output.
|
||||||
|
var gitRun = git.Run
|
||||||
@ -7,49 +7,49 @@ import (
|
|||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gmgauthier.com/grokkit/config"
|
"gmgauthier.com/grokkit/config"
|
||||||
"gmgauthier.com/grokkit/internal/git"
|
|
||||||
"gmgauthier.com/grokkit/internal/grok"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var commitCmd = &cobra.Command{
|
var commitCmd = &cobra.Command{
|
||||||
Use: "commit",
|
Use: "commit",
|
||||||
Short: "Generate message and commit staged changes",
|
Short: "Generate message and commit staged changes",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: runCommit,
|
||||||
diff, err := git.Run([]string{"diff", "--cached", "--no-color"})
|
}
|
||||||
if err != nil {
|
|
||||||
color.Red("Failed to get staged changes: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if diff == "" {
|
|
||||||
color.Yellow("No staged changes!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
|
||||||
model := config.GetModel("commit", modelFlag)
|
|
||||||
|
|
||||||
client := grok.NewClient()
|
func runCommit(cmd *cobra.Command, args []string) {
|
||||||
messages := buildCommitMessages(diff)
|
diff, err := gitRun([]string{"diff", "--cached", "--no-color"})
|
||||||
color.Yellow("Generating commit message...")
|
if err != nil {
|
||||||
msg := client.Stream(messages, model)
|
color.Red("Failed to get staged changes: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if diff == "" {
|
||||||
|
color.Yellow("No staged changes!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
|
model := config.GetModel("commit", modelFlag)
|
||||||
|
|
||||||
color.Cyan("\nProposed commit message:\n%s", msg)
|
client := newGrokClient()
|
||||||
var confirm string
|
messages := buildCommitMessages(diff)
|
||||||
color.Yellow("Commit with this message? (y/n): ")
|
color.Yellow("Generating commit message...")
|
||||||
if _, err := fmt.Scanln(&confirm); err != nil {
|
msg := client.Stream(messages, model)
|
||||||
color.Red("Failed to read input: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if confirm != "y" && confirm != "Y" {
|
|
||||||
color.Yellow("Aborted.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := exec.Command("git", "commit", "-m", msg).Run(); err != nil {
|
color.Cyan("\nProposed commit message:\n%s", msg)
|
||||||
color.Red("Git commit failed")
|
var confirm string
|
||||||
} else {
|
color.Yellow("Commit with this message? (y/n): ")
|
||||||
color.Green("✅ Committed successfully!")
|
if _, err := fmt.Scanln(&confirm); err != nil {
|
||||||
}
|
color.Red("Failed to read input: %v", err)
|
||||||
},
|
return
|
||||||
|
}
|
||||||
|
if confirm != "y" && confirm != "Y" {
|
||||||
|
color.Yellow("Aborted.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := exec.Command("git", "commit", "-m", msg).Run(); err != nil {
|
||||||
|
color.Red("Git commit failed")
|
||||||
|
} else {
|
||||||
|
color.Green("✅ Committed successfully!")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildCommitMessages(diff string) []map[string]string {
|
func buildCommitMessages(diff string) []map[string]string {
|
||||||
|
|||||||
@ -1,37 +1,32 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gmgauthier.com/grokkit/config"
|
"gmgauthier.com/grokkit/config"
|
||||||
"gmgauthier.com/grokkit/internal/git"
|
|
||||||
"gmgauthier.com/grokkit/internal/grok"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var commitMsgCmd = &cobra.Command{
|
var commitMsgCmd = &cobra.Command{
|
||||||
Use: "commit-msg",
|
Use: "commit-msg",
|
||||||
Short: "Generate conventional commit message from staged changes",
|
Short: "Generate conventional commit message from staged changes",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: runCommitMsg,
|
||||||
diff, err := git.Run([]string{"diff", "--cached", "--no-color"})
|
}
|
||||||
if err != nil {
|
|
||||||
color.Red("Failed to get staged changes: %v", err)
|
func runCommitMsg(cmd *cobra.Command, args []string) {
|
||||||
return
|
diff, err := gitRun([]string{"diff", "--cached", "--no-color"})
|
||||||
}
|
if err != nil {
|
||||||
if diff == "" {
|
color.Red("Failed to get staged changes: %v", err)
|
||||||
color.Yellow("No staged changes!")
|
return
|
||||||
return
|
}
|
||||||
}
|
if diff == "" {
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
color.Yellow("No staged changes!")
|
||||||
model := config.GetModel("commitmsg", modelFlag)
|
return
|
||||||
|
}
|
||||||
client := grok.NewClient()
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
messages := []map[string]string{
|
model := config.GetModel("commitmsg", modelFlag)
|
||||||
{"role": "system", "content": "Return ONLY a conventional commit message (type(scope): subject\n\nbody)."},
|
|
||||||
{"role": "user", "content": fmt.Sprintf("Staged changes:\n%s", diff)},
|
client := newGrokClient()
|
||||||
}
|
messages := buildCommitMessages(diff)
|
||||||
color.Yellow("Generating commit message...")
|
color.Yellow("Generating commit message...")
|
||||||
client.Stream(messages, model)
|
client.Stream(messages, model)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -47,13 +47,13 @@ func runDocs(cmd *cobra.Command, args []string) {
|
|||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel("docs", modelFlag)
|
model := config.GetModel("docs", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := newGrokClient()
|
||||||
for _, filePath := range args {
|
for _, filePath := range args {
|
||||||
processDocsFile(client, model, filePath)
|
processDocsFile(client, model, filePath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func processDocsFile(client *grok.Client, model, filePath string) {
|
func processDocsFile(client grok.AIClient, model, filePath string) {
|
||||||
logger.Info("starting docs operation", "file", filePath)
|
logger.Info("starting docs operation", "file", filePath)
|
||||||
|
|
||||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
|
|||||||
@ -4,32 +4,32 @@ import (
|
|||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gmgauthier.com/grokkit/config"
|
"gmgauthier.com/grokkit/config"
|
||||||
"gmgauthier.com/grokkit/internal/git"
|
|
||||||
"gmgauthier.com/grokkit/internal/grok"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var historyCmd = &cobra.Command{
|
var historyCmd = &cobra.Command{
|
||||||
Use: "history",
|
Use: "history",
|
||||||
Short: "Summarize recent git history",
|
Short: "Summarize recent git history",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: runHistory,
|
||||||
log, err := git.Run([]string{"log", "--oneline", "-10"})
|
}
|
||||||
if err != nil {
|
|
||||||
color.Red("Failed to get git log: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if log == "" {
|
|
||||||
color.Yellow("No commits found.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
func runHistory(cmd *cobra.Command, args []string) {
|
||||||
model := config.GetModel("history", modelFlag)
|
log, err := gitRun([]string{"log", "--oneline", "-10"})
|
||||||
|
if err != nil {
|
||||||
|
color.Red("Failed to get git log: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if log == "" {
|
||||||
|
color.Yellow("No commits found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
client := grok.NewClient()
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
messages := buildHistoryMessages(log)
|
model := config.GetModel("history", modelFlag)
|
||||||
color.Yellow("Summarizing recent commits...")
|
|
||||||
client.Stream(messages, model)
|
client := newGrokClient()
|
||||||
},
|
messages := buildHistoryMessages(log)
|
||||||
|
color.Yellow("Summarizing recent commits...")
|
||||||
|
client.Stream(messages, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildHistoryMessages(log string) []map[string]string {
|
func buildHistoryMessages(log string) []map[string]string {
|
||||||
|
|||||||
@ -118,7 +118,7 @@ func runLint(cmd *cobra.Command, args []string) {
|
|||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
model := config.GetModel("lint", modelFlag)
|
model := config.GetModel("lint", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := newGrokClient()
|
||||||
messages := buildLintFixMessages(result, string(originalContent))
|
messages := buildLintFixMessages(result, string(originalContent))
|
||||||
response := client.StreamSilent(messages, model)
|
response := client.StreamSilent(messages, model)
|
||||||
|
|
||||||
|
|||||||
@ -6,34 +6,34 @@ import (
|
|||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gmgauthier.com/grokkit/config"
|
"gmgauthier.com/grokkit/config"
|
||||||
"gmgauthier.com/grokkit/internal/git"
|
|
||||||
"gmgauthier.com/grokkit/internal/grok"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var prDescribeCmd = &cobra.Command{
|
var prDescribeCmd = &cobra.Command{
|
||||||
Use: "pr-describe",
|
Use: "pr-describe",
|
||||||
Short: "Generate full PR description from current branch",
|
Short: "Generate full PR description from current branch",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: runPRDescribe,
|
||||||
diff, err := git.Run([]string{"diff", "main..HEAD", "--no-color"})
|
}
|
||||||
if err != nil || diff == "" {
|
|
||||||
diff, err = git.Run([]string{"diff", "origin/main..HEAD", "--no-color"})
|
func runPRDescribe(cmd *cobra.Command, args []string) {
|
||||||
if err != nil {
|
diff, err := gitRun([]string{"diff", "main..HEAD", "--no-color"})
|
||||||
color.Red("Failed to get branch diff: %v", err)
|
if err != nil || diff == "" {
|
||||||
return
|
diff, err = gitRun([]string{"diff", "origin/main..HEAD", "--no-color"})
|
||||||
}
|
if err != nil {
|
||||||
}
|
color.Red("Failed to get branch diff: %v", err)
|
||||||
if diff == "" {
|
|
||||||
color.Yellow("No changes on this branch compared to main/origin/main.")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
}
|
||||||
model := config.GetModel("prdescribe", modelFlag)
|
if diff == "" {
|
||||||
|
color.Yellow("No changes on this branch compared to main/origin/main.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
|
model := config.GetModel("prdescribe", modelFlag)
|
||||||
|
|
||||||
client := grok.NewClient()
|
client := newGrokClient()
|
||||||
messages := buildPRDescribeMessages(diff)
|
messages := buildPRDescribeMessages(diff)
|
||||||
color.Yellow("Writing PR description...")
|
color.Yellow("Writing PR description...")
|
||||||
client.Stream(messages, model)
|
client.Stream(messages, model)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildPRDescribeMessages(diff string) []map[string]string {
|
func buildPRDescribeMessages(diff string) []map[string]string {
|
||||||
|
|||||||
@ -6,33 +6,33 @@ import (
|
|||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gmgauthier.com/grokkit/config"
|
"gmgauthier.com/grokkit/config"
|
||||||
"gmgauthier.com/grokkit/internal/git"
|
|
||||||
"gmgauthier.com/grokkit/internal/grok"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var reviewCmd = &cobra.Command{
|
var reviewCmd = &cobra.Command{
|
||||||
Use: "review [path]",
|
Use: "review [path]",
|
||||||
Short: "Review the current repository or directory",
|
Short: "Review the current repository or directory",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: runReview,
|
||||||
modelFlag, _ := cmd.Flags().GetString("model")
|
}
|
||||||
model := config.GetModel("review", modelFlag)
|
|
||||||
|
|
||||||
client := grok.NewClient()
|
func runReview(cmd *cobra.Command, args []string) {
|
||||||
diff, err := git.Run([]string{"diff", "--no-color"})
|
modelFlag, _ := cmd.Flags().GetString("model")
|
||||||
if err != nil {
|
model := config.GetModel("review", modelFlag)
|
||||||
color.Red("Failed to get git diff: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
status, err := git.Run([]string{"status", "--short"})
|
|
||||||
if err != nil {
|
|
||||||
color.Red("Failed to get git status: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := buildReviewMessages(status, diff)
|
diff, err := gitRun([]string{"diff", "--no-color"})
|
||||||
color.Yellow("Grok is reviewing the repo...")
|
if err != nil {
|
||||||
client.Stream(messages, model)
|
color.Red("Failed to get git diff: %v", err)
|
||||||
},
|
return
|
||||||
|
}
|
||||||
|
status, err := gitRun([]string{"status", "--short"})
|
||||||
|
if err != nil {
|
||||||
|
color.Red("Failed to get git status: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := newGrokClient()
|
||||||
|
messages := buildReviewMessages(status, diff)
|
||||||
|
color.Yellow("Grok is reviewing the repo...")
|
||||||
|
client.Stream(messages, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildReviewMessages(status, diff string) []map[string]string {
|
func buildReviewMessages(status, diff string) []map[string]string {
|
||||||
|
|||||||
453
cmd/run_test.go
Normal file
453
cmd/run_test.go
Normal file
@ -0,0 +1,453 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"gmgauthier.com/grokkit/internal/grok"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockStreamer records calls made to it and returns a canned response.
|
||||||
|
type mockStreamer struct {
|
||||||
|
response string
|
||||||
|
calls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamer) Stream(messages []map[string]string, model string) string {
|
||||||
|
m.calls++
|
||||||
|
return m.response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamer) StreamWithTemp(messages []map[string]string, model string, temp float64) string {
|
||||||
|
m.calls++
|
||||||
|
return m.response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamer) StreamSilent(messages []map[string]string, model string) string {
|
||||||
|
m.calls++
|
||||||
|
return m.response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure mockStreamer satisfies the interface at compile time.
|
||||||
|
var _ grok.AIClient = (*mockStreamer)(nil)
|
||||||
|
|
||||||
|
// withMockClient injects mock and returns a restore func.
|
||||||
|
func withMockClient(mock grok.AIClient) func() {
|
||||||
|
orig := newGrokClient
|
||||||
|
newGrokClient = func() grok.AIClient { return mock }
|
||||||
|
return func() { newGrokClient = orig }
|
||||||
|
}
|
||||||
|
|
||||||
|
// withMockGit injects a fake git runner and returns a restore func.
|
||||||
|
func withMockGit(fn func([]string) (string, error)) func() {
|
||||||
|
orig := gitRun
|
||||||
|
gitRun = fn
|
||||||
|
return func() { gitRun = orig }
|
||||||
|
}
|
||||||
|
|
||||||
|
// testCmd returns a minimal cobra command with the model flag registered.
|
||||||
|
func testCmd() *cobra.Command {
|
||||||
|
c := &cobra.Command{}
|
||||||
|
c.Flags().String("model", "", "")
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- runHistory ---
|
||||||
|
|
||||||
|
func TestRunHistory(t *testing.T) {
|
||||||
|
t.Run("calls AI with log output", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{response: "3 commits: feat, fix, chore"}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "abc1234 feat: add thing\ndef5678 fix: bug", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runHistory(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no commits — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runHistory(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("git error — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "", errors.New("not a git repo")
|
||||||
|
})()
|
||||||
|
|
||||||
|
runHistory(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- runReview ---
|
||||||
|
|
||||||
|
func TestRunReview(t *testing.T) {
|
||||||
|
t.Run("reviews with diff and status", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{response: "looks good"}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
callCount := 0
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
callCount++
|
||||||
|
switch args[0] {
|
||||||
|
case "diff":
|
||||||
|
return "diff content", nil
|
||||||
|
case "status":
|
||||||
|
return "M main.go", nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runReview(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
if callCount != 2 {
|
||||||
|
t.Errorf("expected 2 git calls (diff + status), got %d", callCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("git diff error — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
if args[0] == "diff" {
|
||||||
|
return "", errors.New("git error")
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runReview(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("git status error — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
if args[0] == "status" {
|
||||||
|
return "", errors.New("git error")
|
||||||
|
}
|
||||||
|
return "diff content", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runReview(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- runCommit ---
|
||||||
|
|
||||||
|
func TestRunCommit(t *testing.T) {
|
||||||
|
t.Run("no staged changes — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runCommit(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("git error — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "", errors.New("not a git repo")
|
||||||
|
})()
|
||||||
|
|
||||||
|
runCommit(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with staged changes — calls AI then cancels via stdin", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{response: "feat(cmd): add thing"}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "diff --git a/foo.go b/foo.go\n+func bar() {}", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Pipe "n\n" so the confirmation prompt returns without committing.
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
origStdin := os.Stdin
|
||||||
|
os.Stdin = r
|
||||||
|
w.WriteString("n\n")
|
||||||
|
w.Close()
|
||||||
|
defer func() { os.Stdin = origStdin }()
|
||||||
|
|
||||||
|
runCommit(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- runCommitMsg ---
|
||||||
|
|
||||||
|
func TestRunCommitMsg(t *testing.T) {
|
||||||
|
t.Run("no staged changes — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runCommitMsg(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with staged changes — calls AI and prints message", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{response: "feat(api): add endpoint"}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "diff --git a/api.go b/api.go", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runCommitMsg(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- runPRDescribe ---
|
||||||
|
|
||||||
|
func TestRunPRDescribe(t *testing.T) {
|
||||||
|
t.Run("no changes on branch — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
return "", nil // both diff calls return empty
|
||||||
|
})()
|
||||||
|
|
||||||
|
runPRDescribe(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("first diff succeeds — calls AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{response: "## PR Title\n\nDescription"}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
callCount := 0
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
callCount++
|
||||||
|
if callCount == 1 {
|
||||||
|
return "diff --git a/foo.go b/foo.go", nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runPRDescribe(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("first diff empty, second succeeds — calls AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{response: "PR description"}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
callCount := 0
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
callCount++
|
||||||
|
if callCount == 2 {
|
||||||
|
return "diff --git a/bar.go b/bar.go", nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runPRDescribe(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("second diff error — skips AI", func(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
callCount := 0
|
||||||
|
defer withMockGit(func(args []string) (string, error) {
|
||||||
|
callCount++
|
||||||
|
if callCount == 2 {
|
||||||
|
return "", errors.New("no remote")
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
})()
|
||||||
|
|
||||||
|
runPRDescribe(testCmd(), nil)
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- runLint / processDocsFile (error paths only — linter/file I/O) ---
|
||||||
|
|
||||||
|
func TestRunLintFileNotFound(t *testing.T) {
|
||||||
|
// Reset flags to defaults
|
||||||
|
dryRun = false
|
||||||
|
autoFix = false
|
||||||
|
|
||||||
|
// Pass a non-existent file; should return without calling AI.
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
|
||||||
|
runLint(testCmd(), []string{"/nonexistent/path/file.go"})
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls for missing file, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessDocsFileNotFound(t *testing.T) {
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
processDocsFile(mock, "grok-4", "/nonexistent/path/file.go")
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls for missing file, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessDocsFileUnsupportedLanguage(t *testing.T) {
|
||||||
|
// Create a temp file with an unsupported extension.
|
||||||
|
f, err := os.CreateTemp("", "test*.xyz")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
processDocsFile(mock, "grok-4", f.Name())
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls for unsupported language, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessDocsFilePreviewAndCancel(t *testing.T) {
|
||||||
|
f, err := os.CreateTemp("", "test*.go")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := f.WriteString("package main\n\nfunc Foo() {}\n"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
defer os.Remove(f.Name() + ".bak")
|
||||||
|
|
||||||
|
mock := &mockStreamer{response: "package main\n\n// Foo does nothing.\nfunc Foo() {}\n"}
|
||||||
|
|
||||||
|
origAutoApply := autoApply
|
||||||
|
autoApply = false
|
||||||
|
defer func() { autoApply = origAutoApply }()
|
||||||
|
|
||||||
|
// Pipe "n\n" so confirmation returns without writing.
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
origStdin := os.Stdin
|
||||||
|
os.Stdin = r
|
||||||
|
w.WriteString("n\n")
|
||||||
|
w.Close()
|
||||||
|
defer func() { os.Stdin = origStdin }()
|
||||||
|
|
||||||
|
processDocsFile(mock, "grok-4", f.Name())
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessDocsFileAutoApply(t *testing.T) {
|
||||||
|
f, err := os.CreateTemp("", "test*.go")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
original := "package main\n\nfunc Bar() {}\n"
|
||||||
|
if _, err := f.WriteString(original); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
defer os.Remove(f.Name() + ".bak")
|
||||||
|
|
||||||
|
// CleanCodeResponse will trim the trailing newline from the AI response.
|
||||||
|
aiResponse := "package main\n\n// Bar does nothing.\nfunc Bar() {}\n"
|
||||||
|
documented := "package main\n\n// Bar does nothing.\nfunc Bar() {}"
|
||||||
|
mock := &mockStreamer{response: aiResponse}
|
||||||
|
|
||||||
|
origAutoApply := autoApply
|
||||||
|
autoApply = true
|
||||||
|
defer func() { autoApply = origAutoApply }()
|
||||||
|
|
||||||
|
processDocsFile(mock, "grok-4", f.Name())
|
||||||
|
|
||||||
|
if mock.calls != 1 {
|
||||||
|
t.Errorf("expected 1 AI call, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
// Verify file was rewritten with documented content.
|
||||||
|
content, _ := os.ReadFile(f.Name())
|
||||||
|
if string(content) != documented {
|
||||||
|
t.Errorf("file content = %q, want %q", string(content), documented)
|
||||||
|
}
|
||||||
|
// Verify backup was created with original content.
|
||||||
|
backup, _ := os.ReadFile(f.Name() + ".bak")
|
||||||
|
if string(backup) != original {
|
||||||
|
t.Errorf("backup content = %q, want %q", string(backup), original)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDocs(t *testing.T) {
|
||||||
|
// runDocs with a missing file: should call processDocsFile which returns early.
|
||||||
|
mock := &mockStreamer{}
|
||||||
|
defer withMockClient(mock)()
|
||||||
|
|
||||||
|
runDocs(testCmd(), []string{"/nonexistent/file.go"})
|
||||||
|
|
||||||
|
if mock.calls != 0 {
|
||||||
|
t.Errorf("expected 0 AI calls for missing file, got %d", mock.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,6 +1,9 @@
|
|||||||
package grok
|
package grok
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -48,6 +51,81 @@ func TestCleanCodeResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sseServer(chunks []string) *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
for _, c := range chunks {
|
||||||
|
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", c)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamSilent(t *testing.T) {
|
||||||
|
srv := sseServer([]string{"Hello", " ", "World"})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &Client{APIKey: "test-key", BaseURL: srv.URL}
|
||||||
|
got := client.StreamSilent([]map[string]string{{"role": "user", "content": "hi"}}, "test-model")
|
||||||
|
if got != "Hello World" {
|
||||||
|
t.Errorf("StreamSilent() = %q, want %q", got, "Hello World")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStream(t *testing.T) {
|
||||||
|
srv := sseServer([]string{"foo", "bar"})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &Client{APIKey: "test-key", BaseURL: srv.URL}
|
||||||
|
got := client.Stream([]map[string]string{{"role": "user", "content": "hi"}}, "test-model")
|
||||||
|
if got != "foobar" {
|
||||||
|
t.Errorf("Stream() = %q, want %q", got, "foobar")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamWithTemp(t *testing.T) {
|
||||||
|
srv := sseServer([]string{"response"})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &Client{APIKey: "test-key", BaseURL: srv.URL}
|
||||||
|
got := client.StreamWithTemp([]map[string]string{{"role": "user", "content": "hi"}}, "test-model", 0.5)
|
||||||
|
if got != "response" {
|
||||||
|
t.Errorf("StreamWithTemp() = %q, want %q", got, "response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamDoneSignal(t *testing.T) {
|
||||||
|
// Verifies that [DONE] stops processing and non-content chunks are skipped
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n")
|
||||||
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||||
|
// This line should never be processed
|
||||||
|
fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":\"extra\"}}]}\n\n")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &Client{APIKey: "test-key", BaseURL: srv.URL}
|
||||||
|
got := client.StreamSilent(nil, "test-model")
|
||||||
|
if got != "ok" {
|
||||||
|
t.Errorf("got %q, want %q", got, "ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamEmptyResponse(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &Client{APIKey: "test-key", BaseURL: srv.URL}
|
||||||
|
got := client.StreamSilent(nil, "test-model")
|
||||||
|
if got != "" {
|
||||||
|
t.Errorf("got %q, want empty string", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewClient(t *testing.T) {
|
func TestNewClient(t *testing.T) {
|
||||||
// Save and restore env
|
// Save and restore env
|
||||||
oldKey := os.Getenv("XAI_API_KEY")
|
oldKey := os.Getenv("XAI_API_KEY")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user