diff --git a/cmd/prdescribe.go b/cmd/prdescribe.go index d489801..4313900 100644 --- a/cmd/prdescribe.go +++ b/cmd/prdescribe.go @@ -14,17 +14,23 @@ var prDescribeCmd = &cobra.Command{ Run: runPRDescribe, } +func init() { + prDescribeCmd.Flags().StringP("base", "b", "master", "Base branch to compare against") +} + func runPRDescribe(cmd *cobra.Command, args []string) { - diff, err := gitRun([]string{"diff", "main..HEAD", "--no-color"}) + base, _ := cmd.Flags().GetString("base") + + diff, err := gitRun([]string{"diff", fmt.Sprintf("%s..HEAD", base), "--no-color"}) if err != nil || diff == "" { - diff, err = gitRun([]string{"diff", "origin/main..HEAD", "--no-color"}) + diff, err = gitRun([]string{"diff", fmt.Sprintf("origin/%s..HEAD", base), "--no-color"}) if err != nil { color.Red("Failed to get branch diff: %v", err) return } } if diff == "" { - color.Yellow("No changes on this branch compared to main/origin/main.") + color.Yellow("No changes on this branch compared to %s/origin/%s.", base, base) return } modelFlag, _ := cmd.Flags().GetString("model") diff --git a/cmd/run_test.go b/cmd/run_test.go index 770d170..859f40d 100644 --- a/cmd/run_test.go +++ b/cmd/run_test.go @@ -47,10 +47,11 @@ func withMockGit(fn func([]string) (string, error)) func() { return func() { gitRun = orig } } -// testCmd returns a minimal cobra command with the model flag registered. +// testCmd returns a minimal cobra command with common flags registered. func testCmd() *cobra.Command { c := &cobra.Command{} c.Flags().String("model", "", "") + c.Flags().String("base", "master", "") return c } @@ -308,22 +309,62 @@ func TestRunPRDescribe(t *testing.T) { } }) - t.Run("second diff error — skips AI", func(t *testing.T) { - mock := &mockStreamer{} + t.Run("uses custom base branch", func(t *testing.T) { + mock := &mockStreamer{response: "PR description"} defer withMockClient(mock)() - callCount := 0 + var capturedArgs []string defer withMockGit(func(args []string) (string, error) { - callCount++ - if callCount == 2 { - return "", errors.New("no remote") + capturedArgs = args + return "diff content", nil + })() + + cmd := testCmd() + if err := cmd.Flags().Set("base", "develop"); err != nil { + t.Fatal(err) + } + runPRDescribe(cmd, nil) + + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + // Expect "diff", "develop..HEAD", "--no-color" + expectedArg := "develop..HEAD" + found := false + for _, arg := range capturedArgs { + if arg == expectedArg { + found = true + break } - return "", nil + } + if !found { + t.Errorf("expected arg %q not found in %v", expectedArg, capturedArgs) + } + }) + + t.Run("defaults to master", func(t *testing.T) { + mock := &mockStreamer{response: "PR description"} + defer withMockClient(mock)() + var capturedArgs []string + defer withMockGit(func(args []string) (string, error) { + capturedArgs = args + return "diff content", nil })() runPRDescribe(testCmd(), nil) - if mock.calls != 0 { - t.Errorf("expected 0 AI calls, got %d", mock.calls) + if mock.calls != 1 { + t.Errorf("expected 1 AI call, got %d", mock.calls) + } + expectedArg := "master..HEAD" + found := false + for _, arg := range capturedArgs { + if arg == expectedArg { + found = true + break + } + } + if !found { + t.Errorf("expected arg %q not found in %v", expectedArg, capturedArgs) } }) }