|
| 1 | +package diagnostic |
| 2 | + |
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
| 12 | + |
// LLMClient talks to a locally hosted LLM server (Ollama/vLLM) over HTTP.
//
// All fields are exported so callers can adjust them after construction;
// NewLLMClient fills in the defaults.
type LLMClient struct {
	// BaseURL is the server root, e.g. "http://localhost:11434". The API
	// path is appended to it when a request is made.
	BaseURL string

	// Model is the model identifier sent with every request.
	Model string

	// Temperature is forwarded to the server as the sampling temperature.
	Temperature float64

	// MaxTokens is forwarded to the server as the generation limit.
	MaxTokens int

	// HTTPClient is the client used to perform requests.
	HTTPClient *http.Client
}
| 21 | + |
| 22 | +// NewLLMClient creates a new LLM client. |
| 23 | +// Example: |
| 24 | +// |
| 25 | +// client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b") |
| 26 | +func NewLLMClient(baseURL, model string) *LLMClient { |
| 27 | + return &LLMClient{ |
| 28 | + BaseURL: baseURL, |
| 29 | + Model: model, |
| 30 | + Temperature: 0.0, // Deterministic |
| 31 | + MaxTokens: 2000, |
| 32 | + HTTPClient: &http.Client{ |
| 33 | + Timeout: 120 * time.Second, |
| 34 | + }, |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +// AnalyzeFunction sends a function to the LLM for pattern discovery and test generation. |
| 39 | +// Returns structured analysis result or error. |
| 40 | +// |
| 41 | +// Performance: ~2-5 seconds per function (depends on function size) |
| 42 | +// |
| 43 | +// Example: |
| 44 | +// |
| 45 | +// client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b") |
| 46 | +// result, err := client.AnalyzeFunction(functionMetadata) |
| 47 | +// if err != nil { |
| 48 | +// log.Printf("LLM analysis failed: %v", err) |
| 49 | +// return nil, err |
| 50 | +// } |
| 51 | +// fmt.Printf("Found %d sources, %d sinks, %d test cases\n", |
| 52 | +// len(result.DiscoveredPatterns.Sources), |
| 53 | +// len(result.DiscoveredPatterns.Sinks), |
| 54 | +// len(result.DataflowTestCases)) |
| 55 | +func (c *LLMClient) AnalyzeFunction(fn *FunctionMetadata) (*LLMAnalysisResult, error) { |
| 56 | + startTime := time.Now() |
| 57 | + |
| 58 | + // Build prompt |
| 59 | + prompt := BuildAnalysisPrompt(fn.SourceCode) |
| 60 | + |
| 61 | + // Call LLM |
| 62 | + responseText, err := c.callOllama(prompt) |
| 63 | + if err != nil { |
| 64 | + return nil, fmt.Errorf("LLM call failed: %w", err) |
| 65 | + } |
| 66 | + |
| 67 | + // Parse JSON response |
| 68 | + var result LLMAnalysisResult |
| 69 | + err = json.Unmarshal([]byte(responseText), &result) |
| 70 | + if err != nil { |
| 71 | + return nil, fmt.Errorf("failed to parse LLM response: %w\nResponse: %s", err, responseText) |
| 72 | + } |
| 73 | + |
| 74 | + // Add metadata |
| 75 | + result.FunctionFQN = fn.FQN |
| 76 | + result.AnalysisMetadata.ProcessingTime = time.Since(startTime).String() |
| 77 | + result.AnalysisMetadata.ModelUsed = c.Model |
| 78 | + |
| 79 | + // Validate result |
| 80 | + if err := c.validateResult(&result); err != nil { |
| 81 | + return nil, fmt.Errorf("invalid LLM result: %w", err) |
| 82 | + } |
| 83 | + |
| 84 | + return &result, nil |
| 85 | +} |
| 86 | + |
| 87 | +// callOllama makes HTTP request to Ollama API. |
| 88 | +func (c *LLMClient) callOllama(prompt string) (string, error) { |
| 89 | + // Ollama API format |
| 90 | + requestBody := map[string]interface{}{ |
| 91 | + "model": c.Model, |
| 92 | + "prompt": prompt, |
| 93 | + "stream": false, |
| 94 | + "options": map[string]interface{}{ |
| 95 | + "temperature": c.Temperature, |
| 96 | + "num_predict": c.MaxTokens, |
| 97 | + }, |
| 98 | + "format": "json", // Request JSON output |
| 99 | + } |
| 100 | + |
| 101 | + jsonBody, err := json.Marshal(requestBody) |
| 102 | + if err != nil { |
| 103 | + return "", fmt.Errorf("failed to marshal request: %w", err) |
| 104 | + } |
| 105 | + |
| 106 | + // Make request |
| 107 | + url := c.BaseURL + "/api/generate" |
| 108 | + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewBuffer(jsonBody)) |
| 109 | + if err != nil { |
| 110 | + return "", fmt.Errorf("failed to create request: %w", err) |
| 111 | + } |
| 112 | + req.Header.Set("Content-Type", "application/json") |
| 113 | + |
| 114 | + resp, err := c.HTTPClient.Do(req) |
| 115 | + if err != nil { |
| 116 | + return "", fmt.Errorf("HTTP request failed: %w", err) |
| 117 | + } |
| 118 | + defer resp.Body.Close() |
| 119 | + |
| 120 | + if resp.StatusCode != http.StatusOK { |
| 121 | + bodyBytes, _ := io.ReadAll(resp.Body) |
| 122 | + return "", fmt.Errorf("LLM returned status %d: %s", resp.StatusCode, string(bodyBytes)) |
| 123 | + } |
| 124 | + |
| 125 | + // Read response |
| 126 | + bodyBytes, err := io.ReadAll(resp.Body) |
| 127 | + if err != nil { |
| 128 | + return "", fmt.Errorf("failed to read response: %w", err) |
| 129 | + } |
| 130 | + |
| 131 | + // Parse Ollama response format |
| 132 | + var ollamaResp struct { |
| 133 | + Response string `json:"response"` |
| 134 | + Done bool `json:"done"` |
| 135 | + } |
| 136 | + err = json.Unmarshal(bodyBytes, &ollamaResp) |
| 137 | + if err != nil { |
| 138 | + return "", fmt.Errorf("failed to parse Ollama response: %w", err) |
| 139 | + } |
| 140 | + |
| 141 | + return ollamaResp.Response, nil |
| 142 | +} |
| 143 | + |
| 144 | +// validateResult checks that LLM result has required fields. |
| 145 | +func (c *LLMClient) validateResult(result *LLMAnalysisResult) error { |
| 146 | + if result.AnalysisMetadata.Confidence < 0.0 || result.AnalysisMetadata.Confidence > 1.0 { |
| 147 | + return fmt.Errorf("invalid confidence: %f", result.AnalysisMetadata.Confidence) |
| 148 | + } |
| 149 | + |
| 150 | + // Validate test cases |
| 151 | + for i, tc := range result.DataflowTestCases { |
| 152 | + if tc.Source.Line <= 0 { |
| 153 | + return fmt.Errorf("test case %d: invalid source line %d", i, tc.Source.Line) |
| 154 | + } |
| 155 | + if tc.Sink.Line <= 0 { |
| 156 | + return fmt.Errorf("test case %d: invalid sink line %d", i, tc.Sink.Line) |
| 157 | + } |
| 158 | + if tc.Confidence < 0.0 || tc.Confidence > 1.0 { |
| 159 | + return fmt.Errorf("test case %d: invalid confidence %f", i, tc.Confidence) |
| 160 | + } |
| 161 | + } |
| 162 | + |
| 163 | + return nil |
| 164 | +} |
| 165 | + |
| 166 | +// AnalyzeBatch analyzes multiple functions in parallel. |
| 167 | +// Returns results map (FQN -> result) and errors map (FQN -> error). |
| 168 | +// |
| 169 | +// Performance: 4-8 parallel workers, ~30-60 minutes for 10k functions |
| 170 | +// |
| 171 | +// Example: |
| 172 | +// |
| 173 | +// client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b") |
| 174 | +// results, errors := client.AnalyzeBatch(functions, 4) |
| 175 | +// fmt.Printf("Analyzed %d functions, %d errors\n", len(results), len(errors)) |
| 176 | +func (c *LLMClient) AnalyzeBatch(functions []*FunctionMetadata, concurrency int) (map[string]*LLMAnalysisResult, map[string]error) { |
| 177 | + results := make(map[string]*LLMAnalysisResult) |
| 178 | + errors := make(map[string]error) |
| 179 | + |
| 180 | + // Channel for work |
| 181 | + workChan := make(chan *FunctionMetadata, len(functions)) |
| 182 | + resultChan := make(chan struct { |
| 183 | + fqn string |
| 184 | + result *LLMAnalysisResult |
| 185 | + err error |
| 186 | + }, len(functions)) |
| 187 | + |
| 188 | + // Start workers |
| 189 | + for i := 0; i < concurrency; i++ { |
| 190 | + go func() { |
| 191 | + for fn := range workChan { |
| 192 | + result, err := c.AnalyzeFunction(fn) |
| 193 | + resultChan <- struct { |
| 194 | + fqn string |
| 195 | + result *LLMAnalysisResult |
| 196 | + err error |
| 197 | + }{fn.FQN, result, err} |
| 198 | + } |
| 199 | + }() |
| 200 | + } |
| 201 | + |
| 202 | + // Send work |
| 203 | + for _, fn := range functions { |
| 204 | + workChan <- fn |
| 205 | + } |
| 206 | + close(workChan) |
| 207 | + |
| 208 | + // Collect results |
| 209 | + for i := 0; i < len(functions); i++ { |
| 210 | + res := <-resultChan |
| 211 | + if res.err != nil { |
| 212 | + errors[res.fqn] = res.err |
| 213 | + } else { |
| 214 | + results[res.fqn] = res.result |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + return results, errors |
| 219 | +} |
0 commit comments