grokkit/cmd/testgen.go
Gregory Gauthier d1eaa5234b
All checks were successful
CI / Test (push) Successful in 28s
CI / Lint (push) Successful in 18s
CI / Build (push) Successful in 15s
refactor(testgen): update Go prompt to enforce unit + live test pattern
- Modify user message to reference new pattern in system prompt
- Revise Go test prompt for exact unit/live structure matching scaffold_test.go
- Update testgen_test.go to match new prompt prefix
2026-03-03 16:52:50 +00:00

290 lines
8.9 KiB
Go

package cmd
import (
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"github.com/fatih/color"
	"github.com/spf13/cobra"
	"gmgauthier.com/grokkit/config"
	"gmgauthier.com/grokkit/internal/grok"
	"gmgauthier.com/grokkit/internal/linter"
	"gmgauthier.com/grokkit/internal/logger"
)
// testgenCmd defines the "testgen" subcommand. For every path argument it
// detects the language, rejects anything other than Go/Python/C/C++, and hands
// the file to processTestgenFile together with the language-specific system
// prompt. A failure on one file does not stop the others; if any file failed
// the process exits with status 1 once all paths have been visited.
var testgenCmd = &cobra.Command{
	Use:   "testgen PATHS...",
	Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)",
	Long: `Generates comprehensive unit tests matching language conventions.
Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test).
Examples:
grokkit testgen internal/grok/client.go
grokkit testgen app.py
grokkit testgen foo.c --yes`,
	Args:         cobra.MinimumNArgs(1),
	SilenceUsage: true,
	Run: func(cmd *cobra.Command, args []string) {
		// Flag lookup errors are deliberately ignored: the zero values
		// (false / "") are the intended fallbacks.
		// NOTE(review): "model" is not registered in this file's init —
		// presumably it is a persistent flag defined elsewhere; confirm.
		autoApply, _ := cmd.Flags().GetBool("yes")
		modelOverride, _ := cmd.Flags().GetString("model")
		model := config.GetModel("testgen", modelOverride)

		logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", autoApply)

		client := grok.NewClient()
		anyFailed := false

		for _, path := range args {
			detected, err := linter.DetectLanguage(path)
			if err != nil {
				color.Red("Failed to detect language for %s: %v", path, err)
				anyFailed = true
				continue
			}

			lang := detected.Name
			switch lang {
			case "Go", "Python", "C", "C++":
				// supported; fall out of the switch and generate
			default:
				color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, path)
				anyFailed = true
				continue
			}

			err = processTestgenFile(client, path, lang, getTestPrompt(lang), model, autoApply)
			if err != nil {
				anyFailed = true
				color.Red("Failed %s: %v", path, err)
				logger.Error("processTestgenFile failed", "file", path, "error", err)
			}
		}

		if anyFailed {
			color.Red("\n❌ Some files failed.")
			os.Exit(1)
		}
		color.Green("\n✅ All test generations complete!")
		color.Yellow("Next steps:\n make test\n make test-cover")
	},
}
// processTestgenFile generates a test file for a single source file.
//
// It validates and reads filePath, strips marker lines via
// removeSourceComments, backs up an existing test file to "<testPath>.bak",
// sends the source to the model with systemPrompt, prints a preview of the
// generated code and, unless yesFlag is set, asks for confirmation before
// writing the result to the path returned by getTestFilePath.
//
// It returns nil when the test file was written or the user declined, and an
// error when the source or test path is unusable, the model returned nothing,
// reading the answer failed, or a file could not be written.
func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error {
	// Validate source file
	srcInfo, err := os.Stat(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("source file not found: %s", filePath)
		}
		// Permission or I/O failures are not "not found"; keep the cause.
		return fmt.Errorf("stat source file: %w", err)
	}
	if srcInfo.IsDir() {
		return fmt.Errorf("directories not supported: %s (use files only)", filePath)
	}
	origSrc, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("read source: %w", err)
	}
	cleanSrc := removeSourceComments(string(origSrc), lang)
	testPath := getTestFilePath(filePath, lang)

	// Handle existing test file
	var origTest []byte
	testExists := false
	testInfo, err := os.Stat(testPath)
	switch {
	case os.IsNotExist(err):
		// No previous test file: nothing to read or back up.
	case err != nil:
		return fmt.Errorf("stat test file: %w", err)
	case testInfo.IsDir():
		return fmt.Errorf("test path is dir: %s", testPath)
	default:
		origTest, err = os.ReadFile(testPath)
		if err != nil {
			return fmt.Errorf("read existing test: %w", err)
		}
		testExists = true
	}

	// Backup existing test. This happens before generation, so the backup is
	// in place even when the user later declines to apply the new file.
	backupPath := testPath + ".bak"
	if testExists {
		if err := os.WriteFile(backupPath, origTest, 0644); err != nil {
			return fmt.Errorf("backup test file: %w", err)
		}
		color.Yellow("💾 Backup: %s", backupPath)
	}

	// Generate tests
	codeLang := getCodeLang(lang)
	messages := []map[string]string{
		{
			"role":    "system",
			"content": systemPrompt,
		},
		{
			"role":    "user",
			"content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the test file using the new Unit + Live integration pattern described in the system prompt.", filepath.Base(filePath), lang, codeLang, cleanSrc),
		},
	}
	color.Yellow("🤖 Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath))
	rawResponse := client.StreamSilent(messages, model)
	newTestCode := grok.CleanCodeResponse(rawResponse)
	if strings.TrimSpace(newTestCode) == "" {
		return fmt.Errorf("empty generation response")
	}

	// Preview: diff-style headers followed by the complete generated file.
	color.Cyan("\n┌─ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath))
	if testExists {
		fmt.Println("--- a/" + filepath.Base(testPath))
	} else {
		fmt.Println("--- /dev/null")
	}
	fmt.Println("+++ b/" + filepath.Base(testPath))
	fmt.Print(newTestCode)
	color.Cyan("\n└────────────────────────────────────────────────────────────────")

	if !yesFlag {
		fmt.Print("\nApply? [y/N]: ")
		// A bare Enter must select the advertised default (N). fmt.Scanln
		// reports an empty line as an error, so the line is read directly.
		answer, err := readTestgenAnswer(os.Stdin)
		if err != nil {
			return fmt.Errorf("input error: %w", err)
		}
		switch strings.ToLower(strings.TrimSpace(answer)) {
		case "y", "yes":
			// confirmed; continue to the write below
		default:
			if testExists {
				color.Yellow("⏭️ Skipped %s (backup: %s)", testPath, backupPath)
			} else {
				// No backup was written, so do not point the user at one.
				color.Yellow("⏭️ Skipped %s", testPath)
			}
			return nil
		}
	}

	// Apply
	if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil {
		return fmt.Errorf("write test file: %w", err)
	}
	color.Green("✅ Wrote %s (%d bytes)", testPath, len(newTestCode))
	logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc))
	return nil
}

// readTestgenAnswer reads one line from r and returns it without the trailing
// newline. It reads a single byte at a time so that nothing after the newline
// is consumed and a later prompt on the same stream still sees its own input.
// A final line terminated by EOF instead of a newline is returned as-is; EOF
// before any byte was read is returned as io.EOF.
func readTestgenAnswer(r io.Reader) (string, error) {
	var line strings.Builder
	var buf [1]byte
	for {
		n, err := r.Read(buf[:])
		if n > 0 {
			if buf[0] == '\n' {
				return line.String(), nil
			}
			line.WriteByte(buf[0])
		}
		if err != nil {
			if err == io.EOF && line.Len() > 0 {
				return line.String(), nil
			}
			return "", err
		}
	}
}
// removeSourceComments drops marker lines from content before it is sent to
// the model and returns the remaining lines joined with "\n".
//
// A line is dropped when it contains "Last modified" or "Generated by"
// anywhere (the latter also covers "Generated by testgen"), when lang is
// "Python" and the trimmed line starts with "# testgen:", or when lang is "C"
// or "C++" and the line contains "/* testgen */". Matching is on the whole
// line, so a code line that happens to contain one of the markers is dropped
// as well.
func removeSourceComments(content, lang string) string {
	lines := strings.Split(content, "\n")
	kept := make([]string, 0, len(lines))
	cFamily := lang == "C" || lang == "C++"

	for _, line := range lines {
		switch {
		case strings.Contains(line, "Last modified"), strings.Contains(line, "Generated by"):
			continue
		case lang == "Python" && strings.HasPrefix(strings.TrimSpace(line), "# testgen:"):
			continue
		case cFamily && strings.Contains(line, "/* testgen */"):
			continue
		}
		kept = append(kept, line)
	}
	return strings.Join(kept, "\n")
}
// getTestPrompt returns the system prompt used to generate tests for lang.
// Recognised values are "Go", "Python", "C" and "C++"; any other value yields
// the empty string.
//
// The prompts are sent to the model verbatim (no templating is applied), so
// they must not contain unresolved placeholders.
func getTestPrompt(lang string) string {
	switch lang {
	case "Go":
		return `You are an expert Go test writer for the Grokkit project.
Generate a COMPLETE, production-ready *_test.go file that follows this EXACT pattern:
1. Fast unit test(s) named TestXXX_Unit (or just TestXXX for the main test)
- Zero API calls
- Uses t.TempDir() + os.Chdir with proper defer restore (and error logging)
- Uses testify/assert and require
- Always runs instantly on "make test"
2. Optional live integration test named TestXXX_Live
- Skipped by default with a clear t.Skip message showing the exact command to run it
- Only runs when the user passes -short
- Does real Grok API calls (exactly like the scaffold live test)
- Includes t.Log messages so it never looks hung
Match the style of cmd/scaffold_test.go exactly:
- Clear t.Logf messages
- Table-driven where it makes sense
- Proper temp-dir isolation
- No unused imports
- defer restore of working directory with error logging
Return ONLY the full Go test file. No explanations, no markdown, no backticks, no extra text whatsoever.`
	case "Python":
		return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source.
- Use pytest fixtures where appropriate
- @pytest.mark.parametrize for tables
- Cover ALL functions/classes/methods: happy/edge/error cases
- pytest.raises for exceptions
- Modern Python 3.12+: type hints, match/case if applicable
- NO external deps unless source requires
Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.`
	case "C":
		// Names below are the real Check API (suite_create, tcase_create, ...).
		return `You are a C unit testing expert using Check framework. Generate COMPLETE unit tests for C source.
- Use Check suite: Suite/TCase via suite_create/tcase_create/tcase_add_test, START_TEST/END_TEST, ck_assert_* macros
- Cover ALL functions: happy/edge/error cases
- #include <check.h> (Check only; do not mix in other test frameworks)
- main() runner (srunner_create/srunner_run_all) if needed.
Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.`
	case "C++":
		return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source.
- Use TEST/TEST_F, EXPECT_*/ASSERT_*
- TYPED_TEST_SUITE if templates
- Cover ALL classes/fns: happy/edge/error
- #include <gtest/gtest.h>
Respond ONLY full test_*.cpp: includes, tests. Pure C++.`
	default:
		return ""
	}
}
// getTestFilePath derives the test file path for filePath. The test file
// lives in the same directory as the source and is named after the source's
// base name without its extension:
//
//	Go      <name>_test.go
//	Python  test_<name>.py
//	C       test_<name>.c
//	C++     test_<name>.cpp
//
// Any other lang yields the empty string.
func getTestFilePath(filePath, lang string) string {
	stem := strings.TrimSuffix(filepath.Base(filePath), filepath.Ext(filePath))

	var testName string
	switch lang {
	case "Go":
		testName = stem + "_test.go"
	case "Python":
		testName = "test_" + stem + ".py"
	case "C":
		testName = "test_" + stem + ".c"
	case "C++":
		testName = "test_" + stem + ".cpp"
	default:
		return ""
	}
	return filepath.Join(filepath.Dir(filePath), testName)
}
// getCodeLang maps a detected language name to the info string used on the
// Markdown code fence that wraps the source in the user message. C++ gets its
// own "cpp" tag rather than being labelled as C; unknown languages fall back
// to "text".
func getCodeLang(lang string) string {
	switch lang {
	case "Go":
		return "go"
	case "Python":
		return "python"
	case "C":
		return "c"
	case "C++":
		return "cpp"
	default:
		return "text"
	}
}
// init registers the flags owned by the testgen command: --yes / -y
// (default false) writes the generated file without asking for confirmation.
func init() {
	flags := testgenCmd.Flags()
	flags.BoolP("yes", "y", false, "Auto-apply without confirmation")
}