From c54bc511c936a486712e1078c0222c9218957250 Mon Sep 17 00:00:00 2001
From: Greg Gauthier
Date: Mon, 2 Mar 2026 21:57:33 +0000
Subject: [PATCH] feat(testgen): add AI unit test generation command

- Implement `grokkit testgen` for Go/Python/C/C++ files
- Add language-specific prompts and test file conventions
- Include backups, previews, auto-apply flag
- Update README with docs and examples
- Add unit tests for helper functions
- Mark todo as completed
---
 README.md                             |  27 ++-
 cmd/root.go                           |   1 +
 cmd/testgen.go                        | 306 ++++++++++++++++++++++++++
 cmd/testgen_test.go                   | 172 +++++++++++++++
 todo/{queued => completed}/testgen.md |   0
 5 files changed, 505 insertions(+), 1 deletion(-)
 create mode 100644 cmd/testgen.go
 create mode 100644 cmd/testgen_test.go
 rename todo/{queued => completed}/testgen.md (100%)

diff --git a/README.md b/README.md
index f946bc6..e9e0a10 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,7 @@ grokkit version
   - [history](#-grokkit-history)
   - [lint](#-grokkit-lint-file)
   - [docs](#-grokkit-docs-file)
+  - [testgen](#-grokkit-testgen-paths)
   - [agent](#-grokkit-agent)
 - [Configuration](#configuration)
 - [Workflows](#workflows)
@@ -191,7 +192,31 @@ grokkit docs app.py -m grok-4
 - Creates `.bak` backup before any changes
 - Shows first 50 lines of documented code as preview
 - Requires confirmation (unless `--auto-apply`)
- 
+
+### 🧪 `grokkit testgen PATHS...`
+
+**Description**: Generate comprehensive unit tests for Go/Python/C/C++ files using AI.
+
+**Benefits**:
+- Go: Table-driven `t.Parallel()` matching codebase.
+- Python: Pytest with `@pytest.mark.parametrize`.
+- C: Check framework suites.
+- C++: Google Test `EXPECT_*`.
+- Boosts coverage; safe preview/backup.
+
+**CLI examples**:
+```bash
+grokkit testgen internal/grok/client.go
+grokkit testgen app.py --yes
+grokkit testgen foo.c bar.cpp
+```
+
+**Safety features**:
+- Lang detection via `internal/linter`.
+- Backs up an existing test file to `<test file>.bak`.
+- Previews the full generated test file.
+- Asks for confirmation (`--yes` skips the prompt).
+
 ### 🤖 `grokkit agent`
 
 Multi-file agent for complex refactoring (experimental).
diff --git a/cmd/root.go b/cmd/root.go
index 0e0fac9..bb27044 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -57,6 +57,7 @@ func init() {
 	rootCmd.AddCommand(completionCmd)
 	rootCmd.AddCommand(versionCmd)
 	rootCmd.AddCommand(docsCmd)
+	rootCmd.AddCommand(testgenCmd)
 
 	// Add model flag to all commands
 	rootCmd.PersistentFlags().StringP("model", "m", "", "Grok model to use (overrides config)")
diff --git a/cmd/testgen.go b/cmd/testgen.go
new file mode 100644
index 0000000..5654798
--- /dev/null
+++ b/cmd/testgen.go
@@ -0,0 +1,306 @@
+package cmd
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/fatih/color"
+	"github.com/spf13/cobra"
+
+	"gmgauthier.com/grokkit/config"
+	"gmgauthier.com/grokkit/internal/grok"
+	"gmgauthier.com/grokkit/internal/linter"
+	"gmgauthier.com/grokkit/internal/logger"
+)
+
+// testgenCmd generates unit tests for each path given on the command line.
+// Files whose language cannot be detected or is unsupported are reported and
+// counted as failures; the command exits with status 1 if any file failed.
+var testgenCmd = &cobra.Command{
+	Use:   "testgen PATHS...",
+	Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)",
+	Long: `Generates comprehensive unit tests matching language conventions.
+Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test).
+
+Examples:
+  grokkit testgen internal/grok/client.go
+  grokkit testgen app.py
+  grokkit testgen foo.c --yes`,
+	Args:         cobra.MinimumNArgs(1),
+	SilenceUsage: true,
+	Run: func(cmd *cobra.Command, args []string) {
+		yesFlag, _ := cmd.Flags().GetBool("yes")
+		modelFlag, _ := cmd.Flags().GetString("model")
+		model := config.GetModel("testgen", modelFlag)
+
+		logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", yesFlag)
+
+		client := grok.NewClient()
+
+		supportedLangs := map[string]bool{"Go": true, "Python": true, "C": true, "C++": true}
+
+		allSuccess := true
+		for _, filePath := range args {
+			langObj, err := linter.DetectLanguage(filePath)
+			if err != nil {
+				color.Red("Failed to detect language for %s: %v", filePath, err)
+				allSuccess = false
+				continue
+			}
+			lang := langObj.Name
+			if !supportedLangs[lang] {
+				color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, filePath)
+				allSuccess = false
+				continue
+			}
+
+			prompt := getTestPrompt(lang)
+			if err := processTestgenFile(client, filePath, lang, prompt, model, yesFlag); err != nil {
+				allSuccess = false
+				color.Red("Failed %s: %v", filePath, err)
+				logger.Error("processTestgenFile failed", "file", filePath, "error", err)
+			}
+		}
+
+		if allSuccess {
+			color.Green("\n✅ All test generations complete!")
+			color.Yellow("Next steps:\n make test\n make test-cover")
+		} else {
+			color.Red("\n❌ Some files failed.")
+			os.Exit(1)
+		}
+	},
+}
+
+// processTestgenFile generates a test file for filePath and, after showing a
+// preview, writes it next to the source.
+//
+// An existing test file is first copied to "<test file>.bak". The new file is
+// written only if yesFlag is set or the user answers "y" or "yes"; any other
+// answer, including an empty line, skips the file and returns nil. An error is
+// returned when the source or test file cannot be read, backed up or written,
+// or when the model response is empty.
+func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error {
+	// Validate source file
+	srcInfo, err := os.Stat(filePath)
+	if err != nil {
+		return fmt.Errorf("source file: %w", err)
+	}
+	if srcInfo.IsDir() {
+		return fmt.Errorf("directories not supported: %s (use files only)", filePath)
+	}
+
+	origSrc, err := os.ReadFile(filePath)
+	if err != nil {
+		return fmt.Errorf("read source: %w", err)
+	}
+	cleanSrc := removeSourceComments(string(origSrc), lang)
+
+	testPath := getTestFilePath(filePath, lang)
+
+	// Handle existing test file
+	var origTest []byte
+	testExists := true
+	testInfo, err := os.Stat(testPath)
+	if os.IsNotExist(err) {
+		testExists = false
+	} else if err != nil {
+		return fmt.Errorf("stat test file: %w", err)
+	} else if testInfo.IsDir() {
+		return fmt.Errorf("test path is dir: %s", testPath)
+	} else {
+		origTest, err = os.ReadFile(testPath)
+		if err != nil {
+			return fmt.Errorf("read existing test: %w", err)
+		}
+	}
+
+	// Backup existing test
+	backupPath := testPath + ".bak"
+	if testExists {
+		if err := os.WriteFile(backupPath, origTest, 0644); err != nil {
+			return fmt.Errorf("backup test file: %w", err)
+		}
+		color.Yellow("💾 Backup: %s", backupPath)
+	}
+
+	// Generate tests
+	codeLang := getCodeLang(lang)
+	messages := []map[string]string{
+		{
+			"role":    "system",
+			"content": systemPrompt,
+		},
+		{
+			"role":    "user",
+			"content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the language-appropriate test file.", filepath.Base(filePath), lang, codeLang, cleanSrc),
+		},
+	}
+
+	color.Yellow("🤖 Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath))
+	rawResponse := client.StreamSilent(messages, model)
+	newTestCode := grok.CleanCodeResponse(rawResponse)
+
+	if strings.TrimSpace(newTestCode) == "" {
+		return fmt.Errorf("empty generation response")
+	}
+
+	// Preview
+	color.Cyan("\n┌─ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath))
+	if testExists {
+		fmt.Println("--- a/" + filepath.Base(testPath))
+	} else {
+		fmt.Println("--- /dev/null")
+	}
+	fmt.Println("+++ b/" + filepath.Base(testPath))
+	fmt.Print(newTestCode)
+	color.Cyan("\n└─────────────────────────────────────────────────────────────────")
+
+	if !yesFlag {
+		fmt.Print("\nApply? [y/N]: ")
+		var confirm string
+		// Scanln fails on an empty line or EOF. Both mean the default answer
+		// "N", so the error is deliberately treated as "no", not as a failure.
+		if _, err := fmt.Scanln(&confirm); err != nil {
+			confirm = ""
+		}
+		confirm = strings.TrimSpace(strings.ToLower(confirm))
+		if confirm != "y" && confirm != "yes" {
+			if testExists {
+				color.Yellow("⏭️ Skipped %s (backup: %s)", testPath, backupPath)
+			} else {
+				color.Yellow("⏭️ Skipped %s", testPath)
+			}
+			return nil
+		}
+	}
+
+	// Apply
+	if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil {
+		return fmt.Errorf("write test file: %w", err)
+	}
+
+	color.Green("✅ Wrote %s (%d bytes)", testPath, len(newTestCode))
+	logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc))
+	return nil
+}
+
+// removeSourceComments drops bookkeeping lines from content before it is sent
+// to the model: any line containing "Last modified" or "Generated by", Python
+// lines starting with "# testgen:" and C/C++ lines containing "/* testgen */".
+// The remaining lines are returned joined with "\n".
+func removeSourceComments(content, lang string) string {
+	lines := strings.Split(content, "\n")
+	cleanedLines := make([]string, 0, len(lines))
+	for _, line := range lines {
+		// "Generated by" also matches the "Generated by testgen" marker.
+		if strings.Contains(line, "Last modified") || strings.Contains(line, "Generated by") {
+			continue
+		}
+		if lang == "Python" && strings.HasPrefix(strings.TrimSpace(line), "# testgen:") {
+			continue
+		}
+		if (lang == "C" || lang == "C++") && strings.Contains(line, "/* testgen */") {
+			continue
+		}
+		cleanedLines = append(cleanedLines, line)
+	}
+	return strings.Join(cleanedLines, "\n")
+}
+
+// getTestPrompt returns the system prompt used to generate tests for lang
+// ("Go", "Python", "C" or "C++"), or "" for any other language.
+func getTestPrompt(lang string) string {
+	switch lang {
+	case "Go":
+		return `You are an expert Go testing specialist. Generate COMPLETE, production-ready unit tests for the provided Go source:
+
+- Table-driven with t.Run(subtest, func(t *testing.T)) and t.Parallel()
+- Match exact style of gmgauthier/grokkit/internal/version/version_test.go
+- Use t.Context() for ctx-aware tests
+- Modern Go 1.24+: slices.Contains/IndexFunc, maps.Keys, errors.Is/Join, any, etc.
+- Cover ALL public funcs/methods/fields: happy path, edges, errors, panics, zero values
+- Realistic inputs, no external deps/mocks unless code requires
+- func TestXxx(t *testing.T) { ... } only
+
+Respond with ONLY the full *_test.go file:
+- Correct package name (infer from code)
+- Necessary imports
+- No benchmarks unless obvious perf func
+- NO prose, explanations, markdown, code blocks, or extra text. Pure Go test file.`
+	case "Python":
+		return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source.
+
+- Use pytest fixtures where appropriate
+- @pytest.mark.parametrize for tables
+- Cover ALL functions/classes/methods: happy/edge/error cases
+- pytest.raises for exceptions
+- Modern Python 3.12+: type hints, match/case if applicable
+- NO external deps unless source requires
+
+Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.`
+	case "C":
+		return `You are a C unit testing expert using Check framework. Generate COMPLETE unit tests for C source.
+
+- Use Check suite: START_TEST/END_TEST, suite_create, tcase_create, tcase_add_test, ck_assert_* macros
+- Cover ALL functions: happy/edge/error cases
+- Include #include <check.h>; use standard Check macros.
+- main() runner if needed.
+
+Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.`
+	case "C++":
+		return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source.
+
+- Use TEST/TEST_F, EXPECT_*/ASSERT_*
+- TYPED_TEST_SUITE if templates
+- Cover ALL classes/fns: happy/edge/error
+- #include <gtest/gtest.h>
+
+Respond ONLY full test_*.cpp: includes, tests. Pure C++.`
+	default:
+		return ""
+	}
+}
+
+// getTestFilePath returns the conventional test file path for filePath, placed
+// in the same directory as the source: <base>_test.go for Go and
+// test_<base>.py/.c/.cpp for Python, C and C++. Other languages yield "".
+func getTestFilePath(filePath, lang string) string {
+	ext := filepath.Ext(filePath)
+	base := strings.TrimSuffix(filepath.Base(filePath), ext)
+	switch lang {
+	case "Go":
+		dir := filepath.Dir(filePath)
+		return filepath.Join(dir, base+"_test.go")
+	case "Python":
+		return filepath.Join(filepath.Dir(filePath), "test_"+base+".py")
+	case "C":
+		return filepath.Join(filepath.Dir(filePath), "test_"+base+".c")
+	case "C++":
+		return filepath.Join(filepath.Dir(filePath), "test_"+base+".cpp")
+	default:
+		return ""
+	}
+}
+
+// getCodeLang returns the Markdown code-fence tag used when embedding source
+// of the given language in the prompt; unknown languages map to "text".
+func getCodeLang(lang string) string {
+	switch lang {
+	case "Go":
+		return "go"
+	case "Python":
+		return "python"
+	case "C", "C++":
+		return "c"
+	default:
+		return "text"
+	}
+}
+
+// init registers the --yes/-y flag on testgenCmd.
+func init() {
+	testgenCmd.Flags().BoolP("yes", "y", false, "Auto-apply without confirmation")
+}
diff --git a/cmd/testgen_test.go b/cmd/testgen_test.go
new file mode 100644
index 0000000..5751045
--- /dev/null
+++ b/cmd/testgen_test.go
@@ -0,0 +1,172 @@
+package cmd
+
+import (
+	"path/filepath"
+	"strings"
+	"testing"
+)
+
+// TestRemoveSourceComments checks which marker lines removeSourceComments
+// strips for each language and that unrelated lines are left untouched.
+func TestRemoveSourceComments(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		input string
+		want  string
+		lang  string
+	}{
+		{
+			name: "no comments",
+			input: `package cmd
+import "testing"
+
+func Foo() {}`,
+			want: `package cmd
+import "testing"
+
+func Foo() {}`,
+			lang: "Go",
+		},
+		{
+			name: "last modified",
+			input: `// Last modified: 2026-03-02
+package cmd`,
+			want: `package cmd`,
+			lang: "Go",
+		},
+		{
+			name: "generated by",
+			input: `// Generated by grokkit testgen
+package cmd`,
+			want: `package cmd`,
+			lang: "Go",
+		},
+		{
+			name: "multiple removable lines",
+			input: `line1
+// Last modified: foo
+line3
+// Generated by: bar
+line5`,
+			want: `line1
+line3
+line5`,
+			lang: "Go",
+		},
+		{
+			name: "partial match no remove",
+			input: `// Modified something else
+package cmd`,
+			want: `// Modified something else
+package cmd`,
+			lang: "Go",
+		},
+		{
+			name: "python testgen",
+			input: `# testgen: generated
+def foo(): pass`,
+			want: `def foo(): pass`,
+			lang: "Python",
+		},
+		{
+			name: "c testgen",
+			input: `/* testgen */
+int foo() {}`,
+			want: `int foo() {}`,
+			lang: "C",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := removeSourceComments(tt.input, tt.lang)
+			if got != tt.want {
+				t.Errorf("removeSourceComments() =\n%q\nwant\n%q", got, tt.want)
+			}
+		})
+	}
+}
+
+// TestGetTestPrompt checks that each supported language gets its own prompt
+// and that an unknown language yields an empty one.
+func TestGetTestPrompt(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		lang       string
+		wantPrefix string
+	}{
+		{"Go", "You are an expert Go testing specialist."},
+		{"Python", "You are a pytest expert."},
+		{"C", "You are a C unit testing expert using Check framework."},
+		{"C++", "You are a Google Test expert."},
+		{"Invalid", ""},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.lang, func(t *testing.T) {
+			got := getTestPrompt(tt.lang)
+			if tt.wantPrefix != "" && !strings.HasPrefix(got, tt.wantPrefix) {
+				// Do not slice got here: a short or empty prompt would panic.
+				t.Errorf("getTestPrompt(%q) =\n%q\nwant prefix %q", tt.lang, got, tt.wantPrefix)
+			}
+			if tt.wantPrefix == "" && got != "" {
+				t.Errorf("getTestPrompt(%q) = %q, want empty", tt.lang, got)
+			}
+		})
+	}
+}
+
+// TestGetTestFilePath checks the test file name derived for each language.
+// Expected paths are written with "/" and converted with filepath.FromSlash
+// because getTestFilePath builds paths with the OS-specific separator.
+func TestGetTestFilePath(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		filePath string
+		lang     string
+		want     string
+	}{
+		{"foo.go", "Go", "foo_test.go"},
+		{"dir/foo.py", "Python", "dir/test_foo.py"},
+		{"bar.c", "C", "test_bar.c"},
+		{"baz.cpp", "C++", "test_baz.cpp"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.filePath+"_"+tt.lang, func(t *testing.T) {
+			got := getTestFilePath(tt.filePath, tt.lang)
+			want := filepath.FromSlash(tt.want)
+			if got != want {
+				t.Errorf("getTestFilePath(%q, %q) = %q, want %q", tt.filePath, tt.lang, got, want)
+			}
+		})
+	}
+}
+
+// TestGetCodeLang checks the code-fence tag chosen for each supported language.
+func TestGetCodeLang(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		lang string
+		want string
+	}{
+		{"Go", "go"},
+		{"Python", "python"},
+		{"C", "c"},
+		{"C++", "c"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.lang, func(t *testing.T) {
+			got := getCodeLang(tt.lang)
+			if got != tt.want {
+				t.Errorf("getCodeLang(%q) = %q, want %q", tt.lang, got, tt.want)
+			}
+		})
+	}
+}
diff --git a/todo/queued/testgen.md b/todo/completed/testgen.md
similarity index 100%
rename from todo/queued/testgen.md
rename to todo/completed/testgen.md