grokkit/internal/grok/client.go

139 lines
3.4 KiB
Go

package grok
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"github.com/fatih/color"
)
// Client holds the credentials and endpoint root used to call the xAI
// chat-completions API.
type Client struct {
	// APIKey is sent as a bearer token in the Authorization header.
	APIKey string
	// BaseURL is the API root; request paths such as "/chat/completions"
	// are appended to it.
	BaseURL string
}
// NewClient returns a Client that authenticates with the XAI_API_KEY
// environment variable and targets https://api.x.ai/v1.
//
// It is intended for CLI use: if XAI_API_KEY is unset or empty it prints an
// error in red and terminates the process with exit status 1 rather than
// returning an error.
func NewClient() *Client {
	c := &Client{BaseURL: "https://api.x.ai/v1"}
	if c.APIKey = os.Getenv("XAI_API_KEY"); c.APIKey == "" {
		color.Red("Error: XAI_API_KEY environment variable not set")
		os.Exit(1)
	}
	return c
}
// maxStreamLineBytes caps the length of a single server-sent-event line.
// bufio.Scanner's default limit is 64 KiB; a longer line would otherwise stop
// the scan with bufio.ErrTooLong and truncate the reply.
const maxStreamLineBytes = 1 << 20

// streamChunk is the subset of a streamed chat-completion chunk this client
// reads: choices[0].delta.content. Decoding into a typed struct means a chunk
// with an unexpected shape fails to unmarshal (and is skipped) instead of
// panicking on a failed type assertion.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// Stream sends messages to the chat-completions endpoint with streaming
// enabled. It prints each content delta to stdout as it arrives, then prints
// a trailing newline, and returns the concatenated reply.
//
// It is meant for CLI use. If the request cannot be built or sent, or the API
// answers with a non-200 status, it prints an error in red and exits the
// process with status 1. A read error part-way through the stream is reported
// in red, but the partial reply received so far is still returned.
func (c *Client) Stream(messages []map[string]string, model string) string {
	resp, err := c.openStream(messages, model)
	if err != nil {
		color.Red("Request failed: %v", err)
		os.Exit(1)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		reportAPIError(resp)
		os.Exit(1)
	}

	var fullReply strings.Builder
	readErr := readStreamDeltas(resp.Body, func(content string) {
		fmt.Print(content)
		fullReply.WriteString(content)
	})
	fmt.Println()
	if readErr != nil {
		color.Red("Stream read failed: %v", readErr)
	}
	return fullReply.String()
}

// StreamChan is the asynchronous variant of Stream. It returns a channel
// (buffered to 100 items) that receives each non-empty content delta and is
// closed when the stream ends. Errors are printed in red and end the stream
// early instead of exiting the process, so the consumer observes them only
// as the channel closing.
//
// NOTE(review): the sending goroutine blocks if the consumer stops reading
// before the channel is closed; making it cancellable would need a
// context.Context parameter.
func (c *Client) StreamChan(messages []map[string]string, model string) <-chan string {
	ch := make(chan string, 100)
	go func() {
		defer close(ch)
		resp, err := c.openStream(messages, model)
		if err != nil {
			color.Red("Request failed: %v", err)
			return
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			reportAPIError(resp)
			return
		}
		if err := readStreamDeltas(resp.Body, func(content string) { ch <- content }); err != nil {
			color.Red("Stream read failed: %v", err)
		}
	}()
	return ch
}

// openStream builds the streaming chat-completion request and sends it with
// http.DefaultClient, which applies no timeout. On success the caller owns
// resp.Body and must close it.
func (c *Client) openStream(messages []map[string]string, model string) (*http.Response, error) {
	payload := map[string]interface{}{
		"model":       model,
		"messages":    messages,
		"temperature": 0.7,
		"stream":      true,
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("encoding request: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, c.BaseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		// e.g. a malformed BaseURL; previously this left req nil and panicked.
		return nil, fmt.Errorf("building request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.APIKey)
	req.Header.Set("Content-Type", "application/json")
	return http.DefaultClient.Do(req)
}

// reportAPIError prints the status code and body of a non-200 API response
// in red. It reads resp.Body but does not close it.
func reportAPIError(resp *http.Response) {
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		color.Red("Failed to read response body: %v", err)
	}
	color.Red("API failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}

// readStreamDeltas reads a server-sent-event body line by line. For every
// "data:" event it calls emit with choices[0].delta.content, skipping empty
// content and payloads that are not valid chunk JSON. It stops at the
// "[DONE]" sentinel or at end of input. It returns any scanner error, such as
// a line longer than maxStreamLineBytes or a failed read.
func readStreamDeltas(r io.Reader, emit func(string)) error {
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxStreamLineBytes)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data:") {
			continue // blank separators and non-data SSE lines
		}
		// The SSE format allows one optional space after the field colon.
		data := strings.TrimPrefix(line[len("data:"):], " ")
		if data == "[DONE]" {
			return nil
		}
		var chunk streamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil || len(chunk.Choices) == 0 {
			continue
		}
		if content := chunk.Choices[0].Delta.Content; content != "" {
			emit(content)
		}
	}
	return scanner.Err()
}