- Remove Grokkit-specific references to make the prompt more versatile. - Update unit and live test patterns for broader applicability, including standard testing.T usage. - Adjust test name derivation and skip messages for consistency. - Sync test assertions in testgen_test.go with the updated prompt.
295 lines
9.2 KiB
Go
295 lines
9.2 KiB
Go
package cmd
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/fatih/color"
|
|
"github.com/spf13/cobra"
|
|
|
|
"gmgauthier.com/grokkit/config"
|
|
"gmgauthier.com/grokkit/internal/grok"
|
|
"gmgauthier.com/grokkit/internal/linter"
|
|
"gmgauthier.com/grokkit/internal/logger"
|
|
)
|
|
|
|
var testgenCmd = &cobra.Command{
|
|
Use: "testgen PATHS...",
|
|
Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)",
|
|
Long: `Generates comprehensive unit tests matching language conventions.
|
|
Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test).
|
|
|
|
Examples:
|
|
grokkit testgen internal/grok/client.go
|
|
grokkit testgen app.py
|
|
grokkit testgen foo.c --yes`,
|
|
Args: cobra.MinimumNArgs(1),
|
|
SilenceUsage: true,
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
yesFlag, _ := cmd.Flags().GetBool("yes")
|
|
modelFlag, _ := cmd.Flags().GetString("model")
|
|
model := config.GetModel("testgen", modelFlag)
|
|
|
|
logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", yesFlag)
|
|
|
|
client := grok.NewClient()
|
|
|
|
supportedLangs := map[string]bool{"Go": true, "Python": true, "C": true, "C++": true}
|
|
|
|
allSuccess := true
|
|
for _, filePath := range args {
|
|
langObj, err := linter.DetectLanguage(filePath)
|
|
if err != nil {
|
|
color.Red("Failed to detect language for %s: %v", filePath, err)
|
|
allSuccess = false
|
|
continue
|
|
}
|
|
lang := langObj.Name
|
|
if !supportedLangs[lang] {
|
|
color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, filePath)
|
|
allSuccess = false
|
|
continue
|
|
}
|
|
|
|
prompt := getTestPrompt(lang)
|
|
if err := processTestgenFile(client, filePath, lang, prompt, model, yesFlag); err != nil {
|
|
allSuccess = false
|
|
color.Red("Failed %s: %v", filePath, err)
|
|
logger.Error("processTestgenFile failed", "file", filePath, "error", err)
|
|
}
|
|
}
|
|
|
|
if allSuccess {
|
|
color.Green("\n✅ All test generations complete!")
|
|
color.Yellow("Next steps:\n make test\n make test-cover")
|
|
} else {
|
|
color.Red("\n❌ Some files failed.")
|
|
os.Exit(1)
|
|
}
|
|
},
|
|
}
|
|
|
|
func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error {
|
|
// Validate source file
|
|
srcInfo, err := os.Stat(filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("source file not found: %s", filePath)
|
|
}
|
|
if srcInfo.IsDir() {
|
|
return fmt.Errorf("directories not supported: %s (use files only)", filePath)
|
|
}
|
|
|
|
origSrc, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("read source: %w", err)
|
|
}
|
|
cleanSrc := removeSourceComments(string(origSrc), lang)
|
|
|
|
testPath := getTestFilePath(filePath, lang)
|
|
|
|
// Handle existing test file
|
|
var origTest []byte
|
|
testExists := true
|
|
testInfo, err := os.Stat(testPath)
|
|
if os.IsNotExist(err) {
|
|
testExists = false
|
|
} else if err != nil {
|
|
return fmt.Errorf("stat test file: %w", err)
|
|
} else if testInfo.IsDir() {
|
|
return fmt.Errorf("test path is dir: %s", testPath)
|
|
} else {
|
|
origTest, err = os.ReadFile(testPath)
|
|
if err != nil {
|
|
return fmt.Errorf("read existing test: %w", err)
|
|
}
|
|
}
|
|
|
|
// Backup existing test
|
|
backupPath := testPath + ".bak"
|
|
if testExists {
|
|
if err := os.WriteFile(backupPath, origTest, 0644); err != nil {
|
|
return fmt.Errorf("backup test file: %w", err)
|
|
}
|
|
color.Yellow("💾 Backup: %s", backupPath)
|
|
}
|
|
|
|
// Generate tests
|
|
codeLang := getCodeLang(lang)
|
|
messages := []map[string]string{
|
|
{
|
|
"role": "system",
|
|
"content": systemPrompt,
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the test file using the new Unit + Live integration pattern described in the system prompt.", filepath.Base(filePath), lang, codeLang, cleanSrc)},
|
|
}
|
|
|
|
color.Yellow("🤖 Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath))
|
|
rawResponse := client.StreamSilent(messages, model)
|
|
newTestCode := grok.CleanCodeResponse(rawResponse)
|
|
|
|
if len(newTestCode) == 0 || strings.TrimSpace(newTestCode) == "" {
|
|
return fmt.Errorf("empty generation response")
|
|
}
|
|
|
|
// Preview
|
|
color.Cyan("\n┌─ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath))
|
|
if testExists {
|
|
fmt.Println("--- a/" + filepath.Base(testPath))
|
|
} else {
|
|
fmt.Println("--- /dev/null")
|
|
}
|
|
fmt.Println("+++ b/" + filepath.Base(testPath))
|
|
fmt.Print(newTestCode)
|
|
color.Cyan("\n└────────────────────────────────────────────────────────────────")
|
|
|
|
if !yesFlag {
|
|
fmt.Print("\nApply? [y/N]: ")
|
|
var confirm string
|
|
_, err := fmt.Scanln(&confirm)
|
|
if err != nil {
|
|
return fmt.Errorf("input error: %w", err)
|
|
}
|
|
confirm = strings.TrimSpace(strings.ToLower(confirm))
|
|
if confirm != "y" && confirm != "yes" {
|
|
color.Yellow("⏭️ Skipped %s (backup: %s)", testPath, backupPath)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Apply
|
|
if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil {
|
|
return fmt.Errorf("write test file: %w", err)
|
|
}
|
|
|
|
color.Green("✅ Wrote %s (%d bytes)", testPath, len(newTestCode))
|
|
logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc))
|
|
return nil
|
|
}
|
|
|
|
// removeSourceComments drops tool-generated marker lines from content before
// it is sent to the model. A line is removed when:
//   - it contains "Last modified" or "Generated by" (any language);
//   - lang is "Python" and its trimmed text starts with "# testgen:";
//   - lang is "C" or "C++" and it contains "/* testgen */".
//
// All other lines are kept in order and rejoined with "\n". Only "\n" is
// split on, so "\r" characters survive untouched.
//
// NOTE(review): matching is by substring on the whole line, so a code line
// (e.g. a string literal) containing "Generated by" is dropped as well.
// Confirm this is acceptable.
func removeSourceComments(content, lang string) string {
	lines := strings.Split(content, "\n")
	kept := make([]string, 0, len(lines))
	isCFamily := lang == "C" || lang == "C++"

	for _, line := range lines {
		switch {
		// "Generated by" already covers "Generated by testgen"; the old
		// separate check for the longer string could never fire on its own.
		case strings.Contains(line, "Last modified"), strings.Contains(line, "Generated by"):
			continue
		case lang == "Python" && strings.HasPrefix(strings.TrimSpace(line), "# testgen:"):
			continue
		case isCFamily && strings.Contains(line, "/* testgen */"):
			continue
		}
		kept = append(kept, line)
	}
	return strings.Join(kept, "\n")
}
|
|
|
|
// getTestPrompt returns the system prompt used to generate tests for lang
// ("Go", "Python", "C" or "C++"), or "" for any other language.
//
// The prompts are raw strings and must not contain backticks.
func getTestPrompt(lang string) string {
	switch lang {
	case "Go":
		// The unit template asks for t.Parallel(), and the bullet list suggests
		// os.Chdir. The working directory is process-wide, so the extra rule
		// below stops the model from combining the two into racy tests.
		//
		// NOTE(review): the live template opts in with -short
		// (if !testing.Short() { t.Skip(...) }), which inverts the usual meaning
		// of "go test -short" (run fewer/faster tests). It is kept because it is
		// the documented pattern here; confirm this is intended.
		return `You are an expert Go test writer.

Generate a COMPLETE, production-ready *_test.go file that follows this modern Unit + Live test pattern:

1. Fast unit test (always runs on "go test"):
   func TestXXX_Unit(t *testing.T) { // or TestCreateIniFile if it fits better
       t.Parallel()
       t.Log("✓ Fast XXX unit test")

       // Zero external calls
       // Use t.TempDir() + os.Chdir with proper defer restore + error logging
       // Use standard testing.T (add testify/assert only if the project already uses it)
       // Cover happy path, errors, edge cases
       // Table-driven where it makes sense
   }

2. Optional live integration test (skipped by default):
   func TestXXX_Live(t *testing.T) {
       if !testing.Short() {
           t.Skip("skipping live integration test. Run with:\n go test -run TestXXX_Live -short -v")
       }
       t.Log("🧪 Running live integration test...")
       // Real behavior (file I/O, exec, API calls, etc.)
       // Use t.Logf messages so it never looks hung
   }

Exact rules:
- Derive sensible test names from the source filename and functions (e.g. filer.go → TestCreateIniFile_Unit / TestCreateIniFile_Live)
- The XXX in the t.Skip command MUST exactly match the live test function name you created
- t.Parallel() on unit tests only
- If a unit test calls os.Chdir, omit t.Parallel() for that test: the working directory is process-wide, so parallel tests that change it race each other
- NO unused imports
- Return ONLY the full test file. No explanations, no markdown, no backticks, no extra text whatsoever.`
	case "Python":
		return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source.

- Use pytest fixtures where appropriate
- @pytest.mark.parametrize for tables
- Cover ALL functions/classes/methods: happy/edge/error cases
- pytest.raises for exceptions
- Modern Python 3.12+: type hints, match/case if applicable
- NO external deps unless source requires

Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.`
	case "C":
		// Check has no tcase_begin/tcase_end. Its tests are declared with
		// START_TEST/END_TEST and registered via suite_create, tcase_create and
		// tcase_add_test, so the prompt names those.
		return `You are a C unit testing expert using Check framework. Generate COMPLETE unit tests for C source.

- Use Check suite: Suite, suite_create/tcase_create/tcase_add_test, START_TEST/END_TEST, ck_assert_* macros
- Cover ALL functions: happy/edge/error cases
- Include #include <check.h>
- main() runner if needed.

Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.`
	case "C++":
		return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source.

- Use TEST/TEST_F, EXPECT_*/ASSERT_*
- TYPED_TEST_SUITE if templates
- Cover ALL classes/fns: happy/edge/error
- #include <gtest/gtest.h>

Respond ONLY full test_*.cpp: includes, tests. Pure C++.`
	default:
		return ""
	}
}
|
|
|
|
// getTestFilePath returns the path where the generated test for filePath is
// written, always in the same directory as the source:
//
//	Go      <stem>_test.go
//	Python  test_<stem>.py
//	C       test_<stem>.c
//	C++     test_<stem>.cpp
//
// <stem> is the source file name without its extension. Any other language
// yields "".
func getTestFilePath(filePath, lang string) string {
	stem := strings.TrimSuffix(filepath.Base(filePath), filepath.Ext(filePath))

	var name string
	switch lang {
	case "Go":
		name = stem + "_test.go"
	case "Python":
		name = "test_" + stem + ".py"
	case "C":
		name = "test_" + stem + ".c"
	case "C++":
		name = "test_" + stem + ".cpp"
	default:
		return ""
	}
	return filepath.Join(filepath.Dir(filePath), name)
}
|
|
|
|
// getCodeLang maps a detected language name to the info string placed on the
// Markdown code fence that wraps the source in the user prompt. Languages
// without a mapping fall back to "text".
func getCodeLang(lang string) string {
	switch lang {
	case "Go":
		return "go"
	case "Python":
		return "python"
	case "C":
		return "c"
	case "C++":
		// C++ used to share the "c" tag, which framed C++ sources as C.
		// "cpp" is the conventional fence label for C++.
		return "cpp"
	default:
		return "text"
	}
}
|
|
|
|
func init() {
|
|
testgenCmd.Flags().BoolP("yes", "y", false, "Auto-apply without confirmation")
|
|
}
|