grokkit/cmd/testgen.go

279 lines
8.5 KiB
Go
Raw Permalink Normal View History

package cmd
import (
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"github.com/fatih/color"
	"github.com/spf13/cobra"
	"gmgauthier.com/grokkit/config"
	"gmgauthier.com/grokkit/internal/grok"
	"gmgauthier.com/grokkit/internal/linter"
	"gmgauthier.com/grokkit/internal/logger"
)
// testgenCmd implements `grokkit testgen PATHS...`: for every path it detects
// the language, rejects anything that is not Go/Python/C/C++, and hands the
// file to processTestgenFile with the language's prompt. Any per-file failure
// is reported and, once all paths have been attempted, the process exits with
// status 1.
var testgenCmd = &cobra.Command{
	Use:   "testgen PATHS...",
	Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)",
	Long: `Generates comprehensive unit tests matching language conventions.
Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test).
Examples:
grokkit testgen internal/grok/client.go
grokkit testgen app.py
grokkit testgen foo.c --yes`,
	Args:         cobra.MinimumNArgs(1),
	SilenceUsage: true,
	Run: func(cmd *cobra.Command, args []string) {
		flags := cmd.Flags()
		autoApply, _ := flags.GetBool("yes")
		modelOverride, _ := flags.GetString("model")
		model := config.GetModel("testgen", modelOverride)
		logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", autoApply)

		client := grok.NewClient()

		// generateFor runs detection + generation for one path and reports
		// whether it succeeded; all user-facing messages are emitted here.
		generateFor := func(path string) bool {
			detected, err := linter.DetectLanguage(path)
			if err != nil {
				color.Red("Failed to detect language for %s: %v", path, err)
				return false
			}
			lang := detected.Name
			switch lang {
			case "Go", "Python", "C", "C++":
				// supported
			default:
				color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, path)
				return false
			}
			err = processTestgenFile(client, path, lang, getTestPrompt(lang), model, autoApply)
			if err == nil {
				return true
			}
			color.Red("Failed %s: %v", path, err)
			logger.Error("processTestgenFile failed", "file", path, "error", err)
			return false
		}

		failures := 0
		for _, path := range args {
			if !generateFor(path) {
				failures++
			}
		}

		if failures > 0 {
			color.Red("\n❌ Some files failed.")
			os.Exit(1)
		}
		color.Green("\n✅ All test generations complete!")
		color.Yellow("Next steps:\n make test\n make test-cover")
	},
}
// processTestgenFile generates a test file for one source file, shows it as a
// preview, and writes it next to the source once applied.
//
// Parameters:
//   - client: Grok client used for the (silent) streaming request.
//   - filePath: path to the source file; must exist and must not be a directory.
//   - lang: language name ("Go", "Python", "C", "C++") used to pick the test
//     file location, comment filtering and code-fence tag.
//   - systemPrompt: language-specific system prompt.
//   - model: model name passed through to the client.
//   - yesFlag: when true, apply without asking for confirmation.
//
// If a test file already exists at the conventional location it is copied to
// "<testPath>.bak" before generation; the generated content then overwrites
// the test file when applied.
//
// At the "Apply? [y/N]" prompt an empty answer (bare Enter) selects the
// default "No" and the file is skipped; a skip is not an error. End of input
// or an answer with extra words is still reported as an input error.
//
// Returns an error when validation, file I/O, generation or input fails.
func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error {
	// Validate the source file.
	srcInfo, err := os.Stat(filePath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return fmt.Errorf("source file not found: %s", filePath)
		}
		// Permission and other stat failures are not "not found"; keep the cause.
		return fmt.Errorf("stat source %s: %w", filePath, err)
	}
	if srcInfo.IsDir() {
		return fmt.Errorf("directories not supported: %s (use files only)", filePath)
	}
	origSrc, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("read source: %w", err)
	}
	cleanSrc := removeSourceComments(string(origSrc), lang)

	testPath := getTestFilePath(filePath, lang)
	if testPath == "" {
		// Fail before spending an API call on a language with no test-file convention.
		return fmt.Errorf("no test file convention for language %q", lang)
	}

	// Inspect any existing test file so it can be backed up.
	var origTest []byte
	testExists := false
	testInfo, err := os.Stat(testPath)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Nothing to back up; the test file will be created.
	case err != nil:
		return fmt.Errorf("stat test file: %w", err)
	case testInfo.IsDir():
		return fmt.Errorf("test path is dir: %s", testPath)
	default:
		testExists = true
		if origTest, err = os.ReadFile(testPath); err != nil {
			return fmt.Errorf("read existing test: %w", err)
		}
	}

	// Back up the existing test before anything can overwrite it.
	backupPath := testPath + ".bak"
	if testExists {
		if err := os.WriteFile(backupPath, origTest, 0644); err != nil {
			return fmt.Errorf("backup test file: %w", err)
		}
		color.Yellow("💾 Backup: %s", backupPath)
	}

	// Generate tests.
	codeLang := getCodeLang(lang)
	messages := []map[string]string{
		{
			"role":    "system",
			"content": systemPrompt,
		},
		{
			"role":    "user",
			"content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the language-appropriate test file.", filepath.Base(filePath), lang, codeLang, cleanSrc),
		},
	}
	color.Yellow("🤖 Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath))
	rawResponse := client.StreamSilent(messages, model)
	newTestCode := grok.CleanCodeResponse(rawResponse)
	if strings.TrimSpace(newTestCode) == "" {
		return fmt.Errorf("empty generation response")
	}

	// Preview (full new content with a diff-style header; not a real diff).
	color.Cyan("\n┌─ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath))
	if testExists {
		fmt.Println("--- a/" + filepath.Base(testPath))
	} else {
		fmt.Println("--- /dev/null")
	}
	fmt.Println("+++ b/" + filepath.Base(testPath))
	fmt.Print(newTestCode)
	color.Cyan("\n└────────────────────────────────────────────────────────────────")

	if !yesFlag {
		fmt.Print("\nApply? [y/N]: ")
		var confirm string
		if _, err := fmt.Scanln(&confirm); err != nil {
			// fmt.Scanln reports a bare Enter as an "unexpected newline" error
			// with nothing scanned: that is the prompt's default answer (No),
			// not a failure. EOF or a malformed answer is still an error.
			if errors.Is(err, io.EOF) || confirm != "" {
				return fmt.Errorf("input error: %w", err)
			}
		}
		confirm = strings.TrimSpace(strings.ToLower(confirm))
		if confirm != "y" && confirm != "yes" {
			// Only mention the backup if one was actually written.
			if testExists {
				color.Yellow("⏭️ Skipped %s (backup: %s)", testPath, backupPath)
			} else {
				color.Yellow("⏭️ Skipped %s", testPath)
			}
			return nil
		}
	}

	// Apply.
	if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil {
		return fmt.Errorf("write test file: %w", err)
	}
	color.Green("✅ Wrote %s (%d bytes)", testPath, len(newTestCode))
	logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc))
	return nil
}
// removeSourceComments strips metadata/marker lines from source text before
// it is sent to the model. A line is dropped when:
//   - it contains "Last modified" or "Generated by" (any language), or
//   - lang is "Python" and the trimmed line starts with "# testgen:", or
//   - lang is "C" or "C++" and the line contains "/* testgen */".
//
// NOTE: the first rule matches anywhere in the line, so code lines (not just
// comments) containing those phrases are dropped as well.
// Remaining lines are rejoined with "\n" in their original order.
func removeSourceComments(content, lang string) string {
	isCFamily := lang == "C" || lang == "C++"

	drop := func(line string) bool {
		switch {
		case strings.Contains(line, "Last modified"), strings.Contains(line, "Generated by"):
			// "Generated by" also covers the "Generated by testgen" marker.
			return true
		case lang == "Python":
			return strings.HasPrefix(strings.TrimSpace(line), "# testgen:")
		case isCFamily:
			return strings.Contains(line, "/* testgen */")
		default:
			return false
		}
	}

	var kept []string
	for _, line := range strings.Split(content, "\n") {
		if drop(line) {
			continue
		}
		kept = append(kept, line)
	}
	return strings.Join(kept, "\n")
}
// getTestPrompt returns the system prompt used to ask the model for a test
// file in the given language: Go (table-driven testing), Python (pytest),
// C (Check framework) or C++ (Google Test). It returns "" for any other
// language.
//
// The C prompt names the actual Check API (START_TEST/END_TEST, suite_create,
// tcase_create, tcase_add_test, suite_add_tcase, srunner_*) so the model does
// not invent functions or mix in another framework.
func getTestPrompt(lang string) string {
	switch lang {
	case "Go":
		return `You are an expert Go testing specialist. Generate COMPLETE, production-ready unit tests for the provided Go source:
- Table-driven with t.Run(subtest, func(t *testing.T)) and t.Parallel()
- Match exact style of gmgauthier/grokkit/internal/version/version_test.go
- Use t.Context() for ctx-aware tests
- Modern Go 1.24+: slices.Contains/IndexFunc, maps.Keys, errors.Is/Join, any, etc.
- Cover ALL public funcs/methods/fields: happy path, edges, errors, panics, zero values
- Realistic inputs, no external deps/mocks unless code requires
- func TestXxx(t *testing.T) { ... } only
Respond with ONLY the full *_test.go file:
- Correct package name (infer from code)
- Necessary imports
- No benchmarks unless obvious perf func
- NO prose, explanations, markdown, code blocks, or extra text. Pure Go test file.`
	case "Python":
		return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source.
- Use pytest fixtures where appropriate
- @pytest.mark.parametrize for tables
- Cover ALL functions/classes/methods: happy/edge/error cases
- pytest.raises for exceptions
- Modern Python 3.12+: type hints, match/case if applicable
- NO external deps unless source requires
Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.`
	case "C":
		return `You are a C unit testing expert using the Check framework. Generate COMPLETE unit tests for the C source.
- #include <check.h> (Check only; no other test framework)
- Define each test with START_TEST(name) ... END_TEST and assert with ck_assert_* macros
- Build suites with suite_create, tcase_create, tcase_add_test, suite_add_tcase
- Cover ALL functions: happy/edge/error cases
- main() runner: srunner_create, srunner_run_all(sr, CK_NORMAL), srunner_ntests_failed, srunner_free; return EXIT_SUCCESS only if no test failed
Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.`
	case "C++":
		return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source.
- Use TEST/TEST_F, EXPECT_*/ASSERT_*
- TYPED_TEST_SUITE if templates
- Cover ALL classes/fns: happy/edge/error
- #include <gtest/gtest.h>
Respond ONLY full test_*.cpp: includes, tests. Pure C++.`
	default:
		return ""
	}
}
// getTestFilePath returns where the generated test for filePath should live:
// always a sibling in the same directory, named after the source file's stem
// (base name without extension):
//
//	Go     -> <stem>_test.go
//	Python -> test_<stem>.py
//	C      -> test_<stem>.c
//	C++    -> test_<stem>.cpp
//
// Any other language yields "".
func getTestFilePath(filePath, lang string) string {
	stem := strings.TrimSuffix(filepath.Base(filePath), filepath.Ext(filePath))

	var name string
	switch lang {
	case "Go":
		name = stem + "_test.go"
	case "Python":
		name = "test_" + stem + ".py"
	case "C":
		name = "test_" + stem + ".c"
	case "C++":
		name = "test_" + stem + ".cpp"
	default:
		return ""
	}
	return filepath.Join(filepath.Dir(filePath), name)
}
// getCodeLang maps a language name to the Markdown code-fence info string
// used when embedding the source in the user prompt. C++ gets "cpp" (not "c")
// so the fence correctly labels C++ code; unknown languages get "text".
func getCodeLang(lang string) string {
	switch lang {
	case "Go":
		return "go"
	case "Python":
		return "python"
	case "C":
		return "c"
	case "C++":
		return "cpp"
	default:
		return "text"
	}
}
// init registers testgen's local flags: --yes/-y skips the interactive apply
// prompt. The "model" flag read in Run is not registered here.
// NOTE(review): presumably it is a persistent flag on a parent command — confirm.
func init() {
	flags := testgenCmd.Flags()
	flags.BoolP("yes", "y", false, "Auto-apply without confirmation")
}