feat(testgen): add AI unit test generation command
Some checks failed
CI / Test (push) Failing after 27s
CI / Lint (push) Has been skipped
CI / Build (push) Has been skipped

- Implement `grokkit testgen` for Go/Python/C/C++ files
- Add language-specific prompts and test file conventions
- Include backups, previews, auto-apply flag
- Update README with docs and examples
- Add unit tests for helper functions
- Mark todo as completed
This commit is contained in:
Greg Gauthier 2026-03-02 21:57:33 +00:00
parent 599b478a17
commit c54bc511c9
5 changed files with 466 additions and 1 deletions

View File

@ -45,6 +45,7 @@ grokkit version
- [history](#-grokkit-history)
- [lint](#-grokkit-lint-file)
- [docs](#-grokkit-docs-file)
- [testgen](#-grokkit-testgen-paths)
- [agent](#-grokkit-agent)
- [Configuration](#configuration)
- [Workflows](#workflows)
@ -191,7 +192,31 @@ grokkit docs app.py -m grok-4
- Creates `.bak` backup before any changes
- Shows first 50 lines of documented code as preview
- Requires confirmation (unless `--auto-apply`)
### 🧪 `grokkit testgen PATHS...`
**Description**: Generate comprehensive unit tests for Go/Python/C/C++ files using AI.
**Benefits**:
- Go: Table-driven `t.Parallel()` matching codebase.
- Python: Pytest with `@pytest.mark.parametrize`.
- C: Check framework suites.
- C++: Google Test `EXPECT_*`.
- Boosts coverage; safe preview/backup.
**CLI examples**:
```bash
grokkit testgen internal/grok/client.go
grokkit testgen app.py --yes
grokkit testgen foo.c bar.cpp
```
**Safety features**:
- Lang detection via `internal/linter`.
- Backs up an existing test file to `<test file>.bak` (e.g. `client_test.go.bak`, `test_app.py.bak`).
- Previews the generated test file (with diff-style headers) before the test file is written.
- Asks for confirmation (`y/N`); `--yes` applies without prompting.
### 🤖 `grokkit agent`
Multi-file agent for complex refactoring (experimental).

View File

@ -57,6 +57,7 @@ func init() {
rootCmd.AddCommand(completionCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(docsCmd)
rootCmd.AddCommand(testgenCmd)
// Add model flag to all commands
rootCmd.PersistentFlags().StringP("model", "m", "", "Grok model to use (overrides config)")

278
cmd/testgen.go Normal file
View File

@ -0,0 +1,278 @@
package cmd
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/fatih/color"
"github.com/spf13/cobra"
"gmgauthier.com/grokkit/config"
"gmgauthier.com/grokkit/internal/grok"
"gmgauthier.com/grokkit/internal/linter"
"gmgauthier.com/grokkit/internal/logger"
)
// testgenCmd implements `grokkit testgen PATHS...`.
//
// For every path it detects the language via linter.DetectLanguage, rejects
// anything other than Go/Python/C/C++, and hands the file to
// processTestgenFile together with the language-specific system prompt.
// All paths are attempted; if at least one of them failed the process exits
// with status 1 after the loop.
var testgenCmd = &cobra.Command{
	Use:   "testgen PATHS...",
	Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)",
	Long: `Generates comprehensive unit tests matching language conventions.
Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test).
Examples:
grokkit testgen internal/grok/client.go
grokkit testgen app.py
grokkit testgen foo.c --yes`,
	Args:         cobra.MinimumNArgs(1),
	SilenceUsage: true,
	Run: func(cmd *cobra.Command, args []string) {
		// Lookup errors are ignored: "yes" is registered on this command and
		// "model" is a persistent flag on the root command.
		autoApply, _ := cmd.Flags().GetBool("yes")
		modelOverride, _ := cmd.Flags().GetString("model")
		model := config.GetModel("testgen", modelOverride)
		logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", autoApply)
		client := grok.NewClient()

		failures := 0
		for _, path := range args {
			detected, err := linter.DetectLanguage(path)
			if err != nil {
				color.Red("Failed to detect language for %s: %v", path, err)
				failures++
				continue
			}

			lang := detected.Name
			switch lang {
			case "Go", "Python", "C", "C++":
				// supported, fall out of the switch and generate
			default:
				color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, path)
				failures++
				continue
			}

			err = processTestgenFile(client, path, lang, getTestPrompt(lang), model, autoApply)
			if err != nil {
				failures++
				color.Red("Failed %s: %v", path, err)
				logger.Error("processTestgenFile failed", "file", path, "error", err)
			}
		}

		if failures > 0 {
			color.Red("\n❌ Some files failed.")
			os.Exit(1)
		}
		color.Green("\n✅ All test generations complete!")
		color.Yellow("Next steps:\n make test\n make test-cover")
	},
}
// processTestgenFile generates a test file for one source file and, after
// showing a preview, writes it next to the source.
//
// The steps are: validate and read filePath, strip marker comments with
// removeSourceComments, resolve the target path with getTestFilePath, copy an
// already existing test file to "<test path>.bak", request the tests through
// client.StreamSilent, print a preview, and write the result once confirmed.
//
// Parameters:
//   - client: Grok client used for the generation request.
//   - filePath: source file to generate tests for (directories are rejected).
//   - lang: language name ("Go", "Python", "C" or "C++").
//   - systemPrompt: system message sent to the model.
//   - model: model identifier passed to the client.
//   - yesFlag: when true the generated file is written without prompting.
//
// It returns nil when the test file was written or the user declined, and an
// error for validation failures, I/O failures or an empty model response.
func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error {
	// Validate source file. The underlying error is wrapped so that causes
	// other than "does not exist" (e.g. permission denied) are not misreported.
	srcInfo, err := os.Stat(filePath)
	if err != nil {
		return fmt.Errorf("stat source file %s: %w", filePath, err)
	}
	if srcInfo.IsDir() {
		return fmt.Errorf("directories not supported: %s (use files only)", filePath)
	}
	origSrc, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("read source: %w", err)
	}
	cleanSrc := removeSourceComments(string(origSrc), lang)

	// getTestFilePath returns "" for languages it has no convention for;
	// fail early instead of stat-ing and writing to an empty path.
	testPath := getTestFilePath(filePath, lang)
	if testPath == "" {
		return fmt.Errorf("no test file convention for language %q", lang)
	}

	// Handle existing test file
	var origTest []byte
	testExists := true
	testInfo, err := os.Stat(testPath)
	switch {
	case os.IsNotExist(err):
		testExists = false
	case err != nil:
		return fmt.Errorf("stat test file: %w", err)
	case testInfo.IsDir():
		return fmt.Errorf("test path is dir: %s", testPath)
	default:
		origTest, err = os.ReadFile(testPath)
		if err != nil {
			return fmt.Errorf("read existing test: %w", err)
		}
	}

	// Backup existing test
	backupPath := testPath + ".bak"
	if testExists {
		if err := os.WriteFile(backupPath, origTest, 0644); err != nil {
			return fmt.Errorf("backup test file: %w", err)
		}
		color.Yellow("💾 Backup: %s", backupPath)
	}

	// Generate tests
	codeLang := getCodeLang(lang)
	messages := []map[string]string{
		{
			"role":    "system",
			"content": systemPrompt,
		},
		{
			"role":    "user",
			"content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the language-appropriate test file.", filepath.Base(filePath), lang, codeLang, cleanSrc),
		},
	}
	color.Yellow("🤖 Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath))
	rawResponse := client.StreamSilent(messages, model)
	newTestCode := grok.CleanCodeResponse(rawResponse)
	if strings.TrimSpace(newTestCode) == "" {
		return fmt.Errorf("empty generation response")
	}

	// Preview
	color.Cyan("\n┌─ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath))
	if testExists {
		fmt.Println("--- a/" + filepath.Base(testPath))
	} else {
		fmt.Println("--- /dev/null")
	}
	fmt.Println("+++ b/" + filepath.Base(testPath))
	fmt.Print(newTestCode)
	color.Cyan("\n└────────────────────────────────────────────────────────────────")

	if !yesFlag {
		fmt.Print("\nApply? [y/N]: ")
		var confirm string
		// Scanln returns an error for a bare Enter, for EOF and for trailing
		// extra words. The prompt advertises "N" as the default, so anything
		// that is not a clean "y"/"yes" is treated as a decline rather than as
		// a failure of this file (which would make the command exit non-zero).
		_, scanErr := fmt.Scanln(&confirm)
		confirm = strings.TrimSpace(strings.ToLower(confirm))
		if scanErr != nil || (confirm != "y" && confirm != "yes") {
			if testExists {
				color.Yellow("⏭️ Skipped %s (backup: %s)", testPath, backupPath)
			} else {
				// No previous test file, hence no backup to point at.
				color.Yellow("⏭️ Skipped %s", testPath)
			}
			return nil
		}
	}

	// Apply
	if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil {
		return fmt.Errorf("write test file: %w", err)
	}
	color.Green("✅ Wrote %s (%d bytes)", testPath, len(newTestCode))
	logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc))
	return nil
}
// removeSourceComments drops bookkeeping comment lines from content before the
// source is sent to the model and returns the remaining lines joined by "\n".
//
// A line is removed when it is
//   - a comment line mentioning "Last modified" or "Generated by",
//   - a Python line starting with "# testgen:", or
//   - a C/C++ line containing "/* testgen */".
//
// The "Last modified"/"Generated by" rule only applies to comment lines, so
// code that merely contains those words (a string literal, a #define, a
// statement with a trailing comment) is kept intact.
func removeSourceComments(content, lang string) string {
	lines := strings.Split(content, "\n")
	cleaned := make([]string, 0, len(lines))
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if testgenIsCommentLine(trimmed, lang) &&
			(strings.Contains(trimmed, "Last modified") || strings.Contains(trimmed, "Generated by")) {
			continue
		}
		if lang == "Python" && strings.HasPrefix(trimmed, "# testgen:") {
			continue
		}
		if (lang == "C" || lang == "C++") && strings.Contains(line, "/* testgen */") {
			continue
		}
		cleaned = append(cleaned, line)
	}
	return strings.Join(cleaned, "\n")
}

// testgenIsCommentLine reports whether trimmed (a line without surrounding
// whitespace) starts with a comment marker of lang: "#" for Python, and
// "//", "/*" or a block-comment continuation "*" for every other language.
func testgenIsCommentLine(trimmed, lang string) bool {
	if lang == "Python" {
		return strings.HasPrefix(trimmed, "#")
	}
	return strings.HasPrefix(trimmed, "//") ||
		strings.HasPrefix(trimmed, "/*") ||
		strings.HasPrefix(trimmed, "*")
}
// getTestPrompt returns the system prompt used to generate tests for lang.
//
// Recognised values are "Go", "Python", "C" and "C++"; any other value yields
// the empty string. Each prompt names the test framework to use and instructs
// the model to answer with the test file only.
func getTestPrompt(lang string) string {
	switch lang {
	case "Go":
		return `You are an expert Go testing specialist. Generate COMPLETE, production-ready unit tests for the provided Go source:
- Table-driven with t.Run(subtest, func(t *testing.T)) and t.Parallel()
- Match exact style of gmgauthier/grokkit/internal/version/version_test.go
- Use t.Context() for ctx-aware tests
- Modern Go 1.24+: slices.Contains/IndexFunc, maps.Keys, errors.Is/Join, any, etc.
- Cover ALL public funcs/methods/fields: happy path, edges, errors, panics, zero values
- Realistic inputs, no external deps/mocks unless code requires
- func TestXxx(t *testing.T) { ... } only
Respond with ONLY the full *_test.go file:
- Correct package name (infer from code)
- Necessary imports
- No benchmarks unless obvious perf func
- NO prose, explanations, markdown, code blocks, or extra text. Pure Go test file.`
	case "Python":
		return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source.
- Use pytest fixtures where appropriate
- @pytest.mark.parametrize for tables
- Cover ALL functions/classes/methods: happy/edge/error cases
- pytest.raises for exceptions
- Modern Python 3.12+: type hints, match/case if applicable
- NO external deps unless source requires
Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.`
	case "C":
		// The Check API has no tcase_begin/tcase_end; tests are declared with
		// START_TEST/END_TEST and registered through suite_create,
		// tcase_create, tcase_add_test and suite_add_tcase.
		return `You are a C unit testing expert using Check framework. Generate COMPLETE unit tests for C source.
- Use the Check API: START_TEST/END_TEST, suite_create, tcase_create, tcase_add_test, suite_add_tcase, ck_assert_* macros
- Cover ALL functions: happy/edge/error cases
- #include <check.h> plus the header(s) declaring the functions under test
- main() runner (srunner_create, srunner_run_all, srunner_ntests_failed) if needed.
Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.`
	case "C++":
		return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source.
- Use TEST/TEST_F, EXPECT_*/ASSERT_*
- TYPED_TEST_SUITE if templates
- Cover ALL classes/fns: happy/edge/error
- #include <gtest/gtest.h>
Respond ONLY full test_*.cpp: includes, tests. Pure C++.`
	default:
		return ""
	}
}
// getTestFilePath maps a source file path to the path of its test file, placed
// in the same directory as the source:
//
//	Go     -> <name>_test.go
//	Python -> test_<name>.py
//	C      -> test_<name>.c
//	C++    -> test_<name>.cpp
//
// where <name> is the source file name without its extension. An unrecognised
// lang yields the empty string.
func getTestFilePath(filePath, lang string) string {
	stem := strings.TrimSuffix(filepath.Base(filePath), filepath.Ext(filePath))

	var testName string
	switch lang {
	case "Go":
		testName = stem + "_test.go"
	case "Python":
		testName = "test_" + stem + ".py"
	case "C":
		testName = "test_" + stem + ".c"
	case "C++":
		testName = "test_" + stem + ".cpp"
	default:
		return ""
	}
	return filepath.Join(filepath.Dir(filePath), testName)
}
// getCodeLang returns the tag placed after the opening code fence when the
// source is embedded in the user message: "go" for Go, "python" for Python,
// "c" for both C and C++, and "text" for anything else.
func getCodeLang(lang string) string {
	if lang == "Go" {
		return "go"
	}
	if lang == "Python" {
		return "python"
	}
	if lang == "C" || lang == "C++" {
		return "c"
	}
	return "text"
}
// init registers the flags specific to the testgen command: --yes/-y writes
// the generated tests without asking for confirmation (default false).
func init() {
	flags := testgenCmd.Flags()
	flags.BoolP("yes", "y", false, "Auto-apply without confirmation")
}

161
cmd/testgen_test.go Normal file
View File

@ -0,0 +1,161 @@
package cmd
import (
"strings"
"testing"
)
func TestRemoveSourceComments(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
lang string
}{
{
name: "no comments",
input: `package cmd
import "testing"
func Foo() {}`,
want: `package cmd
import "testing"
func Foo() {}`,
lang: "Go",
},
{
name: "last modified",
input: `// Last modified: 2026-03-02
package cmd`,
want: `package cmd`,
lang: "Go",
},
{
name: "generated by",
input: `// Generated by grokkit testgen
package cmd`,
want: `package cmd`,
lang: "Go",
},
{
name: "multiple removable lines",
input: `line1
// Last modified: foo
line3
// Generated by: bar
line5`,
want: `line1
line3
line5`,
lang: "Go",
},
{
name: "partial match no remove",
input: `// Modified something else
package cmd`,
want: `// Modified something else
package cmd`,
lang: "Go",
},
{
name: "python testgen",
input: `# testgen: generated
def foo(): pass`,
want: `def foo(): pass`,
lang: "Python",
},
{
name: "c testgen",
input: `/* testgen */
int foo() {}`,
want: `int foo() {}`,
lang: "C",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := removeSourceComments(tt.input, tt.lang)
if got != tt.want {
t.Errorf("removeSourceComments() =\n%q\nwant\n%q", got, tt.want)
}
})
}
}
// TestGetTestPrompt verifies that getTestPrompt returns a prompt starting with
// the expected sentence for every supported language and an empty string for
// an unknown one.
func TestGetTestPrompt(t *testing.T) {
	t.Parallel()
	tests := []struct {
		lang       string
		wantPrefix string
	}{
		{"Go", "You are an expert Go testing specialist."},
		{"Python", "You are a pytest expert."},
		{"C", "You are a C unit testing expert using Check framework."},
		{"C++", "You are a Google Test expert."},
		{"Invalid", ""},
	}
	for _, tt := range tests {
		t.Run(tt.lang, func(t *testing.T) {
			got := getTestPrompt(tt.lang)
			if tt.wantPrefix != "" && !strings.HasPrefix(got, tt.wantPrefix) {
				// Show at most the first 100 bytes, but never slice past the
				// end: a short or empty prompt must be reported as a test
				// failure, not crash the test with a slice-bounds panic.
				shown := got
				if len(shown) > 100 {
					shown = shown[:100]
				}
				t.Errorf("getTestPrompt(%q) prefix =\n%q\nwant %q", tt.lang, shown, tt.wantPrefix)
			}
			if tt.wantPrefix == "" && got != "" {
				t.Errorf("getTestPrompt(%q) = %q, want empty", tt.lang, got)
			}
		})
	}
}
// TestGetTestFilePath checks the test-file naming convention produced by
// getTestFilePath for each supported language.
func TestGetTestFilePath(t *testing.T) {
	t.Parallel()

	type testCase struct {
		filePath string
		lang     string
		want     string
	}
	cases := []testCase{
		{filePath: "foo.go", lang: "Go", want: "foo_test.go"},
		{filePath: "dir/foo.py", lang: "Python", want: "dir/test_foo.py"},
		{filePath: "bar.c", lang: "C", want: "test_bar.c"},
		{filePath: "baz.cpp", lang: "C++", want: "test_baz.cpp"},
	}

	for _, tc := range cases {
		subtest := tc.filePath + "_" + tc.lang
		t.Run(subtest, func(t *testing.T) {
			if got := getTestFilePath(tc.filePath, tc.lang); got != tc.want {
				t.Errorf("getTestFilePath(%q, %q) = %q, want %q", tc.filePath, tc.lang, got, tc.want)
			}
		})
	}
}
// TestGetCodeLang checks the code-fence tag returned by getCodeLang for each
// supported language; C and C++ are both expected to map to "c".
func TestGetCodeLang(t *testing.T) {
	t.Parallel()

	type testCase struct {
		lang string
		want string
	}
	cases := []testCase{
		{lang: "Go", want: "go"},
		{lang: "Python", want: "python"},
		{lang: "C", want: "c"},
		{lang: "C++", want: "c"},
	}

	for _, tc := range cases {
		t.Run(tc.lang, func(t *testing.T) {
			if got := getCodeLang(tc.lang); got != tc.want {
				t.Errorf("getCodeLang(%q) = %q, want %q", tc.lang, got, tc.want)
			}
		})
	}
}