feat(testgen): add AI unit test generation command
- Implement `grokkit testgen` for Go/Python/C/C++ files - Add language-specific prompts and test file conventions - Include backups, previews, auto-apply flag - Update README with docs and examples - Add unit tests for helper functions - Mark todo as completed
This commit is contained in:
parent
599b478a17
commit
c54bc511c9
25
README.md
25
README.md
@ -45,6 +45,7 @@ grokkit version
|
||||
- [history](#-grokkit-history)
|
||||
- [lint](#-grokkit-lint-file)
|
||||
- [docs](#-grokkit-docs-file)
|
||||
- [testgen](#-grokkit-testgen-paths)
|
||||
- [agent](#-grokkit-agent)
|
||||
- [Configuration](#configuration)
|
||||
- [Workflows](#workflows)
|
||||
@ -192,6 +193,30 @@ grokkit docs app.py -m grok-4
|
||||
- Shows first 50 lines of documented code as preview
|
||||
- Requires confirmation (unless `--auto-apply`)
|
||||
|
||||
### 🧪 `grokkit testgen PATHS...`
|
||||
|
||||
**Description**: Generate comprehensive unit tests for Go/Python/C/C++ files using AI.
|
||||
|
||||
**Benefits**:
|
||||
- Go: Table-driven `t.Parallel()` matching codebase.
|
||||
- Python: Pytest with `@pytest.mark.parametrize`.
|
||||
- C: Check framework suites.
|
||||
- C++: Google Test `EXPECT_*`.
|
||||
- Boosts coverage; safe preview/backup.
|
||||
|
||||
**CLI examples**:
|
||||
```bash
|
||||
grokkit testgen internal/grok/client.go
|
||||
grokkit testgen app.py --yes
|
||||
grokkit testgen foo.c bar.cpp
|
||||
```
|
||||
|
||||
**Safety features**:
|
||||
- Lang detection via `internal/linter`.
|
||||
- Backs up an existing test file to `<test file>.bak` (e.g. `client_test.go.bak`, `test_app.py.bak`).
|
||||
- Preview of the full generated test file (with diff-style `---`/`+++` headers) before anything is written.
|
||||
- Y/N (--yes auto).
|
||||
|
||||
### 🤖 `grokkit agent`
|
||||
Multi-file agent for complex refactoring (experimental).
|
||||
|
||||
|
||||
@ -57,6 +57,7 @@ func init() {
|
||||
rootCmd.AddCommand(completionCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(docsCmd)
|
||||
rootCmd.AddCommand(testgenCmd)
|
||||
|
||||
// Add model flag to all commands
|
||||
rootCmd.PersistentFlags().StringP("model", "m", "", "Grok model to use (overrides config)")
|
||||
|
||||
278
cmd/testgen.go
Normal file
278
cmd/testgen.go
Normal file
@ -0,0 +1,278 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"gmgauthier.com/grokkit/config"
|
||||
"gmgauthier.com/grokkit/internal/grok"
|
||||
"gmgauthier.com/grokkit/internal/linter"
|
||||
"gmgauthier.com/grokkit/internal/logger"
|
||||
)
|
||||
|
||||
var testgenCmd = &cobra.Command{
|
||||
Use: "testgen PATHS...",
|
||||
Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)",
|
||||
Long: `Generates comprehensive unit tests matching language conventions.
|
||||
Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test).
|
||||
|
||||
Examples:
|
||||
grokkit testgen internal/grok/client.go
|
||||
grokkit testgen app.py
|
||||
grokkit testgen foo.c --yes`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
yesFlag, _ := cmd.Flags().GetBool("yes")
|
||||
modelFlag, _ := cmd.Flags().GetString("model")
|
||||
model := config.GetModel("testgen", modelFlag)
|
||||
|
||||
logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", yesFlag)
|
||||
|
||||
client := grok.NewClient()
|
||||
|
||||
supportedLangs := map[string]bool{"Go": true, "Python": true, "C": true, "C++": true}
|
||||
|
||||
allSuccess := true
|
||||
for _, filePath := range args {
|
||||
langObj, err := linter.DetectLanguage(filePath)
|
||||
if err != nil {
|
||||
color.Red("Failed to detect language for %s: %v", filePath, err)
|
||||
allSuccess = false
|
||||
continue
|
||||
}
|
||||
lang := langObj.Name
|
||||
if !supportedLangs[lang] {
|
||||
color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, filePath)
|
||||
allSuccess = false
|
||||
continue
|
||||
}
|
||||
|
||||
prompt := getTestPrompt(lang)
|
||||
if err := processTestgenFile(client, filePath, lang, prompt, model, yesFlag); err != nil {
|
||||
allSuccess = false
|
||||
color.Red("Failed %s: %v", filePath, err)
|
||||
logger.Error("processTestgenFile failed", "file", filePath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if allSuccess {
|
||||
color.Green("\n✅ All test generations complete!")
|
||||
color.Yellow("Next steps:\n make test\n make test-cover")
|
||||
} else {
|
||||
color.Red("\n❌ Some files failed.")
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error {
|
||||
// Validate source file
|
||||
srcInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("source file not found: %s", filePath)
|
||||
}
|
||||
if srcInfo.IsDir() {
|
||||
return fmt.Errorf("directories not supported: %s (use files only)", filePath)
|
||||
}
|
||||
|
||||
origSrc, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read source: %w", err)
|
||||
}
|
||||
cleanSrc := removeSourceComments(string(origSrc), lang)
|
||||
|
||||
testPath := getTestFilePath(filePath, lang)
|
||||
|
||||
// Handle existing test file
|
||||
var origTest []byte
|
||||
testExists := true
|
||||
testInfo, err := os.Stat(testPath)
|
||||
if os.IsNotExist(err) {
|
||||
testExists = false
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("stat test file: %w", err)
|
||||
} else if testInfo.IsDir() {
|
||||
return fmt.Errorf("test path is dir: %s", testPath)
|
||||
} else {
|
||||
origTest, err = os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read existing test: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Backup existing test
|
||||
backupPath := testPath + ".bak"
|
||||
if testExists {
|
||||
if err := os.WriteFile(backupPath, origTest, 0644); err != nil {
|
||||
return fmt.Errorf("backup test file: %w", err)
|
||||
}
|
||||
color.Yellow("💾 Backup: %s", backupPath)
|
||||
}
|
||||
|
||||
// Generate tests
|
||||
codeLang := getCodeLang(lang)
|
||||
messages := []map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": systemPrompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the language-appropriate test file.", filepath.Base(filePath), lang, codeLang, cleanSrc),
|
||||
},
|
||||
}
|
||||
|
||||
color.Yellow("🤖 Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath))
|
||||
rawResponse := client.StreamSilent(messages, model)
|
||||
newTestCode := grok.CleanCodeResponse(rawResponse)
|
||||
|
||||
if len(newTestCode) == 0 || strings.TrimSpace(newTestCode) == "" {
|
||||
return fmt.Errorf("empty generation response")
|
||||
}
|
||||
|
||||
// Preview
|
||||
color.Cyan("\n┌─ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath))
|
||||
if testExists {
|
||||
fmt.Println("--- a/" + filepath.Base(testPath))
|
||||
} else {
|
||||
fmt.Println("--- /dev/null")
|
||||
}
|
||||
fmt.Println("+++ b/" + filepath.Base(testPath))
|
||||
fmt.Print(newTestCode)
|
||||
color.Cyan("\n└────────────────────────────────────────────────────────────────")
|
||||
|
||||
if !yesFlag {
|
||||
fmt.Print("\nApply? [y/N]: ")
|
||||
var confirm string
|
||||
_, err := fmt.Scanln(&confirm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("input error: %w", err)
|
||||
}
|
||||
confirm = strings.TrimSpace(strings.ToLower(confirm))
|
||||
if confirm != "y" && confirm != "yes" {
|
||||
color.Yellow("⏭️ Skipped %s (backup: %s)", testPath, backupPath)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Apply
|
||||
if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil {
|
||||
return fmt.Errorf("write test file: %w", err)
|
||||
}
|
||||
|
||||
color.Green("✅ Wrote %s (%d bytes)", testPath, len(newTestCode))
|
||||
logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc))
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeSourceComments drops marker lines from content before it is sent to
// the model and returns the remaining lines joined with "\n".
//
// A line is dropped when it:
//   - contains "Last modified" or "Generated by" (any language), or
//   - is a Python line whose trimmed text starts with "# testgen:", or
//   - is a C/C++ line containing "/* testgen */".
//
// Matching is by substring on the whole line, so ordinary code lines that
// happen to contain those phrases are removed as well.
func removeSourceComments(content, lang string) string {
	lines := strings.Split(content, "\n")
	kept := make([]string, 0, len(lines))
	isCFamily := lang == "C" || lang == "C++"

	for _, line := range lines {
		switch {
		// "Generated by" already covers "Generated by testgen", so no
		// separate check is needed for the latter.
		case strings.Contains(line, "Last modified"), strings.Contains(line, "Generated by"):
			continue
		case lang == "Python" && strings.HasPrefix(strings.TrimSpace(line), "# testgen:"):
			continue
		case isCFamily && strings.Contains(line, "/* testgen */"):
			continue
		}
		kept = append(kept, line)
	}
	return strings.Join(kept, "\n")
}
|
||||
|
||||
// getTestPrompt returns the system prompt used to generate tests for lang.
// Recognised values are "Go", "Python", "C" and "C++"; any other value yields
// the empty string.
//
// The C prompt names the Check framework's actual entry points
// (START_TEST/END_TEST, suite_create, tcase_create, tcase_add_test,
// srunner_*), so the model is not steered towards identifiers that do not
// exist in <check.h>.
func getTestPrompt(lang string) string {
	switch lang {
	case "Go":
		return `You are an expert Go testing specialist. Generate COMPLETE, production-ready unit tests for the provided Go source:

- Table-driven with t.Run(subtest, func(t *testing.T)) and t.Parallel()
- Match exact style of gmgauthier/grokkit/internal/version/version_test.go
- Use t.Context() for ctx-aware tests
- Modern Go 1.24+: slices.Contains/IndexFunc, maps.Keys, errors.Is/Join, any, etc.
- Cover ALL public funcs/methods/fields: happy path, edges, errors, panics, zero values
- Realistic inputs, no external deps/mocks unless code requires
- func TestXxx(t *testing.T) { ... } only

Respond with ONLY the full *_test.go file:
- Correct package name (infer from code)
- Necessary imports
- No benchmarks unless obvious perf func
- NO prose, explanations, markdown, code blocks, or extra text. Pure Go test file.`
	case "Python":
		return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source.

- Use pytest fixtures where appropriate
- @pytest.mark.parametrize for tables
- Cover ALL functions/classes/methods: happy/edge/error cases
- pytest.raises for exceptions
- Modern Python 3.12+: type hints, match/case if applicable
- NO external deps unless source requires

Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.`
	case "C":
		return `You are a C unit testing expert using Check framework. Generate COMPLETE unit tests for C source.

- Use Check: START_TEST/END_TEST, suite_create, tcase_create, tcase_add_test, ck_assert_* macros
- Cover ALL functions: happy/edge/error cases
- #include <check.h>
- main() runner using srunner_create/srunner_run_all/srunner_ntests_failed; return non-zero on failure

Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.`
	case "C++":
		return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source.

- Use TEST/TEST_F, EXPECT_*/ASSERT_*
- TYPED_TEST_SUITE if templates
- Cover ALL classes/fns: happy/edge/error
- #include <gtest/gtest.h>

Respond ONLY full test_*.cpp: includes, tests. Pure C++.`
	default:
		return ""
	}
}
|
||||
|
||||
func getTestFilePath(filePath, lang string) string {
|
||||
ext := filepath.Ext(filePath)
|
||||
base := strings.TrimSuffix(filepath.Base(filePath), ext)
|
||||
switch lang {
|
||||
case "Go":
|
||||
dir := filepath.Dir(filePath)
|
||||
return filepath.Join(dir, base+"_test.go")
|
||||
case "Python":
|
||||
return filepath.Join(filepath.Dir(filePath), "test_"+base+".py")
|
||||
case "C":
|
||||
return filepath.Join(filepath.Dir(filePath), "test_"+base+".c")
|
||||
case "C++":
|
||||
return filepath.Join(filepath.Dir(filePath), "test_"+base+".cpp")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// getCodeLang maps a detected language name to the tag used on the Markdown
// code fence that wraps the source in the user message. C and C++ both map to
// "c"; unrecognised languages map to "text".
func getCodeLang(lang string) string {
	fenceTags := map[string]string{
		"Go":     "go",
		"Python": "python",
		"C":      "c",
		"C++":    "c",
	}
	if tag, known := fenceTags[lang]; known {
		return tag
	}
	return "text"
}
|
||||
|
||||
func init() {
|
||||
testgenCmd.Flags().BoolP("yes", "y", false, "Auto-apply without confirmation")
|
||||
}
|
||||
161
cmd/testgen_test.go
Normal file
161
cmd/testgen_test.go
Normal file
@ -0,0 +1,161 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRemoveSourceComments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
lang string
|
||||
}{
|
||||
{
|
||||
name: "no comments",
|
||||
input: `package cmd
|
||||
import "testing"
|
||||
|
||||
func Foo() {}`,
|
||||
want: `package cmd
|
||||
import "testing"
|
||||
|
||||
func Foo() {}`,
|
||||
lang: "Go",
|
||||
},
|
||||
{
|
||||
name: "last modified",
|
||||
input: `// Last modified: 2026-03-02
|
||||
package cmd`,
|
||||
want: `package cmd`,
|
||||
lang: "Go",
|
||||
},
|
||||
{
|
||||
name: "generated by",
|
||||
input: `// Generated by grokkit testgen
|
||||
package cmd`,
|
||||
want: `package cmd`,
|
||||
lang: "Go",
|
||||
},
|
||||
{
|
||||
name: "multiple removable lines",
|
||||
input: `line1
|
||||
// Last modified: foo
|
||||
line3
|
||||
// Generated by: bar
|
||||
line5`,
|
||||
want: `line1
|
||||
line3
|
||||
line5`,
|
||||
lang: "Go",
|
||||
},
|
||||
{
|
||||
name: "partial match no remove",
|
||||
input: `// Modified something else
|
||||
package cmd`,
|
||||
want: `// Modified something else
|
||||
package cmd`,
|
||||
lang: "Go",
|
||||
},
|
||||
{
|
||||
name: "python testgen",
|
||||
input: `# testgen: generated
|
||||
def foo(): pass`,
|
||||
want: `def foo(): pass`,
|
||||
lang: "Python",
|
||||
},
|
||||
{
|
||||
name: "c testgen",
|
||||
input: `/* testgen */
|
||||
int foo() {}`,
|
||||
want: `int foo() {}`,
|
||||
lang: "C",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := removeSourceComments(tt.input, tt.lang)
|
||||
if got != tt.want {
|
||||
t.Errorf("removeSourceComments() =\n%q\nwant\n%q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTestPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
lang string
|
||||
wantPrefix string
|
||||
}{
|
||||
{"Go", "You are an expert Go testing specialist."},
|
||||
{"Python", "You are a pytest expert."},
|
||||
{"C", "You are a C unit testing expert using Check framework."},
|
||||
{"C++", "You are a Google Test expert."},
|
||||
{"Invalid", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.lang, func(t *testing.T) {
|
||||
got := getTestPrompt(tt.lang)
|
||||
if tt.wantPrefix != "" && !strings.HasPrefix(got, tt.wantPrefix) {
|
||||
t.Errorf("getTestPrompt(%q) prefix =\n%q\nwant %q", tt.lang, got[:100], tt.wantPrefix)
|
||||
}
|
||||
if tt.wantPrefix == "" && got != "" {
|
||||
t.Errorf("getTestPrompt(%q) = %q, want empty", tt.lang, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTestFilePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
filePath string
|
||||
lang string
|
||||
want string
|
||||
}{
|
||||
{"foo.go", "Go", "foo_test.go"},
|
||||
{"dir/foo.py", "Python", "dir/test_foo.py"},
|
||||
{"bar.c", "C", "test_bar.c"},
|
||||
{"baz.cpp", "C++", "test_baz.cpp"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filePath+"_"+tt.lang, func(t *testing.T) {
|
||||
got := getTestFilePath(tt.filePath, tt.lang)
|
||||
if got != tt.want {
|
||||
t.Errorf("getTestFilePath(%q, %q) = %q, want %q", tt.filePath, tt.lang, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCodeLang(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
lang string
|
||||
want string
|
||||
}{
|
||||
{"Go", "go"},
|
||||
{"Python", "python"},
|
||||
{"C", "c"},
|
||||
{"C++", "c"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.lang, func(t *testing.T) {
|
||||
got := getCodeLang(tt.lang)
|
||||
if got != tt.want {
|
||||
t.Errorf("getCodeLang(%q) = %q, want %q", tt.lang, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user