package cmd import ( "fmt" "os" "path/filepath" "strings" "github.com/fatih/color" "github.com/spf13/cobra" "gmgauthier.com/grokkit/config" "gmgauthier.com/grokkit/internal/grok" "gmgauthier.com/grokkit/internal/linter" "gmgauthier.com/grokkit/internal/logger" ) var testgenCmd = &cobra.Command{ Use: "testgen PATHS...", Short: "Generate AI unit tests for files (Go/Python/C/C++, preview/apply)", Long: `Generates comprehensive unit tests matching language conventions. Supported: Go (table-driven), Python (pytest), C (Check), C++ (Google Test). Examples: grokkit testgen internal/grok/client.go grokkit testgen app.py grokkit testgen foo.c --yes`, Args: cobra.MinimumNArgs(1), SilenceUsage: true, Run: func(cmd *cobra.Command, args []string) { yesFlag, _ := cmd.Flags().GetBool("yes") modelFlag, _ := cmd.Flags().GetString("model") model := config.GetModel("testgen", modelFlag) logger.Info("testgen started", "num_files", len(args), "model", model, "auto_apply", yesFlag) client := grok.NewClient() supportedLangs := map[string]bool{"Go": true, "Python": true, "C": true, "C++": true} allSuccess := true for _, filePath := range args { langObj, err := linter.DetectLanguage(filePath) if err != nil { color.Red("Failed to detect language for %s: %v", filePath, err) allSuccess = false continue } lang := langObj.Name if !supportedLangs[lang] { color.Yellow("Unsupported lang '%s' for %s (supported: Go/Python/C/C++)", lang, filePath) allSuccess = false continue } prompt := getTestPrompt(lang) if err := processTestgenFile(client, filePath, lang, prompt, model, yesFlag); err != nil { allSuccess = false color.Red("Failed %s: %v", filePath, err) logger.Error("processTestgenFile failed", "file", filePath, "error", err) } } if allSuccess { color.Green("\nāœ… All test generations complete!") color.Yellow("Next steps:\n make test\n make test-cover") } else { color.Red("\nāŒ Some files failed.") os.Exit(1) } }, } func processTestgenFile(client *grok.Client, filePath, lang, systemPrompt, model string, yesFlag bool) error { // Validate source file srcInfo, err := os.Stat(filePath) if err != nil { return fmt.Errorf("source file not found: %s", filePath) } if srcInfo.IsDir() { return fmt.Errorf("directories not supported: %s (use files only)", filePath) } origSrc, err := os.ReadFile(filePath) if err != nil { return fmt.Errorf("read source: %w", err) } cleanSrc := removeSourceComments(string(origSrc), lang) testPath := getTestFilePath(filePath, lang) // Handle existing test file var origTest []byte testExists := true testInfo, err := os.Stat(testPath) if os.IsNotExist(err) { testExists = false } else if err != nil { return fmt.Errorf("stat test file: %w", err) } else if testInfo.IsDir() { return fmt.Errorf("test path is dir: %s", testPath) } else { origTest, err = os.ReadFile(testPath) if err != nil { return fmt.Errorf("read existing test: %w", err) } } // Backup existing test backupPath := testPath + ".bak" if testExists { if err := os.WriteFile(backupPath, origTest, 0644); err != nil { return fmt.Errorf("backup test file: %w", err) } color.Yellow("šŸ’¾ Backup: %s", backupPath) } // Generate tests codeLang := getCodeLang(lang) messages := []map[string]string{ { "role": "system", "content": systemPrompt, }, { "role": "user", "content": fmt.Sprintf("Source file %s (%s):\n```%s\n%s\n```\n\nGenerate the language-appropriate test file.", filepath.Base(filePath), lang, codeLang, cleanSrc), }, } color.Yellow("šŸ¤– Generating tests for %s → %s...", filepath.Base(filePath), filepath.Base(testPath)) rawResponse := client.StreamSilent(messages, model) newTestCode := grok.CleanCodeResponse(rawResponse) if len(newTestCode) == 0 || strings.TrimSpace(newTestCode) == "" { return fmt.Errorf("empty generation response") } // Preview color.Cyan("\nā”Œā”€ Preview: %s ────────────────────────────────────────────────", filepath.Base(testPath)) if testExists { fmt.Println("--- a/" + filepath.Base(testPath)) } else { fmt.Println("--- /dev/null") } fmt.Println("+++ b/" + filepath.Base(testPath)) fmt.Print(newTestCode) color.Cyan("\n└────────────────────────────────────────────────────────────────") if !yesFlag { fmt.Print("\nApply? [y/N]: ") var confirm string _, err := fmt.Scanln(&confirm) if err != nil { return fmt.Errorf("input error: %w", err) } confirm = strings.TrimSpace(strings.ToLower(confirm)) if confirm != "y" && confirm != "yes" { color.Yellow("ā­ļø Skipped %s (backup: %s)", testPath, backupPath) return nil } } // Apply if err := os.WriteFile(testPath, []byte(newTestCode), 0644); err != nil { return fmt.Errorf("write test file: %w", err) } color.Green("āœ… Wrote %s (%d bytes)", testPath, len(newTestCode)) logger.Debug("testgen applied", "test_file", testPath, "bytes", len(newTestCode), "source_bytes", len(origSrc)) return nil } func removeSourceComments(content, lang string) string { lines := strings.Split(content, "\n") var cleanedLines []string for _, line := range lines { if strings.Contains(line, "Last modified") || strings.Contains(line, "Generated by") || strings.Contains(line, "Generated by testgen") { continue } if lang == "Python" && strings.HasPrefix(strings.TrimSpace(line), "# testgen:") { continue } if (lang == "C" || lang == "C++") && strings.Contains(line, "/* testgen */") { continue } cleanedLines = append(cleanedLines, line) } return strings.Join(cleanedLines, "\n") } func getTestPrompt(lang string) string { switch lang { case "Go": return `You are an expert Go testing specialist. Generate COMPLETE, production-ready unit tests for the provided Go source: - Table-driven with t.Run(subtest, func(t *testing.T)) and t.Parallel() - Match exact style of gmgauthier/grokkit/internal/version/version_test.go - Use t.Context() for ctx-aware tests - Modern Go 1.24+: slices.Contains/IndexFunc, maps.Keys, errors.Is/Join, any, etc. - Cover ALL public funcs/methods/fields: happy path, edges, errors, panics, zero values - Realistic inputs, no external deps/mocks unless code requires - func TestXxx(t *testing.T) { ... } only Respond with ONLY the full *_test.go file: - Correct package name (infer from code) - Necessary imports - No benchmarks unless obvious perf func - NO prose, explanations, markdown, code blocks, or extra text. Pure Go test file.` case "Python": return `You are a pytest expert. Generate COMPLETE pytest unit tests for the Python source. - Use pytest fixtures where appropriate - @pytest.mark.parametrize for tables - Cover ALL functions/classes/methods: happy/edge/error cases - pytest.raises for exceptions - Modern Python 3.12+: type hints, match/case if applicable - NO external deps unless source requires Respond ONLY with full test_*.py file: imports, fixtures, tests. Pure Python test code.` case "C": return `You are a C unit testing expert using Check framework. Generate COMPLETE unit tests for C source. - Use Check suite: Suite, tcase_begin/end, ck_assert_* macros - Cover ALL functions: happy/edge/error cases - Include #include ? Use Check std. - main() runner if needed. Respond ONLY full test_*.c: headers, suite funcs, main. Pure C.` case "C++": return `You are a Google Test expert. Generate COMPLETE gtest unit tests for C++ source. - Use TEST/TEST_F, EXPECT_*/ASSERT_* - TYPED_TEST_SUITE if templates - Cover ALL classes/fns: happy/edge/error - #include Respond ONLY full test_*.cpp: includes, tests. Pure C++.` default: return "" } } func getTestFilePath(filePath, lang string) string { ext := filepath.Ext(filePath) base := strings.TrimSuffix(filepath.Base(filePath), ext) switch lang { case "Go": dir := filepath.Dir(filePath) return filepath.Join(dir, base+"_test.go") case "Python": return filepath.Join(filepath.Dir(filePath), "test_"+base+".py") case "C": return filepath.Join(filepath.Dir(filePath), "test_"+base+".c") case "C++": return filepath.Join(filepath.Dir(filePath), "test_"+base+".cpp") default: return "" } } func getCodeLang(lang string) string { switch lang { case "Go": return "go" case "Python": return "python" case "C", "C++": return "c" default: return "text" } } func init() { testgenCmd.Flags().BoolP("yes", "y", false, "Auto-apply without confirmation") }