diff --git a/sourcecode-parser/cmd/diagnose.go b/sourcecode-parser/cmd/diagnose.go new file mode 100644 index 00000000..16f90acf --- /dev/null +++ b/sourcecode-parser/cmd/diagnose.go @@ -0,0 +1,238 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/diagnostic" + "github.com/spf13/cobra" +) + +var diagnoseCmd = &cobra.Command{ + Use: "diagnose", + Short: "Validate intra-procedural taint analysis against LLM ground truth", + Long: `The diagnose command validates the accuracy of intra-procedural taint analysis +by comparing tool results against LLM-based ground truth analysis. + +It extracts functions, runs both tool and LLM analysis, compares results, +and generates diagnostic reports with precision, recall, and failure analysis.`, + Run: func(cmd *cobra.Command, _ []string) { + projectInput := cmd.Flag("project").Value.String() + llmURL := cmd.Flag("llm-url").Value.String() + modelName := cmd.Flag("model").Value.String() + provider := cmd.Flag("provider").Value.String() + apiKey := cmd.Flag("api-key").Value.String() + outputDir := cmd.Flag("output").Value.String() + maxFunctions, _ := cmd.Flags().GetInt("max-functions") + concurrency, _ := cmd.Flags().GetInt("concurrency") + + if projectInput == "" { + fmt.Println("Error: --project flag is required") + return + } + + startTime := time.Now() + + // Create LLM client based on provider + var llmClient *diagnostic.LLMClient + if provider == "openai" { + if apiKey == "" { + fmt.Println("Error: --api-key is required for OpenAI-compatible providers") + return + } + llmClient = diagnostic.NewOpenAIClient(llmURL, modelName, apiKey) + } else { + llmClient = diagnostic.NewLLMClient(llmURL, modelName) + } + + // Step 1: Extract functions + fmt.Println("===============================================================================") + fmt.Println(" DIAGNOSTIC VALIDATION STARTING") + fmt.Println("===============================================================================") + fmt.Println() + fmt.Printf("Project: %s\n", projectInput) + fmt.Printf("LLM Endpoint: %s\n", llmURL) + fmt.Printf("Model: %s\n", modelName) + fmt.Printf("Provider: %s\n", provider) + fmt.Printf("Max Functions: %d\n", maxFunctions) + fmt.Printf("Concurrency: %d\n", concurrency) + fmt.Println() + + fmt.Println("Step 1/4: Extracting functions from codebase...") + + functions, err := diagnostic.ExtractAllFunctions(projectInput) + if err != nil { + fmt.Printf("Error extracting functions: %v\n", err) + return + } + + // Limit to maxFunctions if specified + if maxFunctions > 0 && len(functions) > maxFunctions { + functions = functions[:maxFunctions] + } + + fmt.Printf("✓ Extracted %d functions\n", len(functions)) + fmt.Println() + + // Step 2: LLM Analysis + fmt.Println("Step 2/4: Running LLM analysis (this may take a while)...") + llmResults, llmErrors := llmClient.AnalyzeBatch(functions, concurrency) + fmt.Printf("✓ Analyzed %d functions (%d errors)\n", len(llmResults), len(llmErrors)) + + // Print errors (always show if there are any) + if len(llmErrors) > 0 { + fmt.Println("\n⚠️ LLM Analysis Errors:") + count := 0 + for fqn, err := range llmErrors { + if count >= 5 { + fmt.Printf(" ... and %d more errors\n", len(llmErrors)-5) + break + } + fmt.Printf(" ❌ %s:\n", fqn) + fmt.Printf(" %v\n", err) + count++ + } + fmt.Printf("\n💡 Tip: Failed responses saved to %s/llm_errors.txt\n", outputDir) + } + fmt.Println() + + // Step 3: Tool Analysis + Comparison + fmt.Println("Step 3/4: Running tool analysis and comparison...") + + comparisons := []*diagnostic.DualLevelComparison{} + functionsMap := make(map[string]*diagnostic.FunctionMetadata) + + for _, fn := range functions { + functionsMap[fn.FQN] = fn + + llmResult, hasLLM := llmResults[fn.FQN] + if !hasLLM { + continue // Skip functions with LLM errors + } + + // Extract unique source/sink/sanitizer patterns from LLM-discovered patterns + sourcePatterns := make(map[string]bool) + sinkPatterns := make(map[string]bool) + sanitizerPatterns := make(map[string]bool) + + for _, src := range llmResult.DiscoveredPatterns.Sources { + sourcePatterns[src.Pattern] = true + } + for _, snk := range llmResult.DiscoveredPatterns.Sinks { + sinkPatterns[snk.Pattern] = true + } + for _, san := range llmResult.DiscoveredPatterns.Sanitizers { + sanitizerPatterns[san.Pattern] = true + } + + // Convert to slices and clean patterns + // Strip () from patterns since tool matching doesn't expect them + sources := []string{} + for pattern := range sourcePatterns { + cleanPattern := strings.TrimSuffix(pattern, "()") + sources = append(sources, cleanPattern) + } + sinks := []string{} + for pattern := range sinkPatterns { + cleanPattern := strings.TrimSuffix(pattern, "()") + sinks = append(sinks, cleanPattern) + } + sanitizers := []string{} + for pattern := range sanitizerPatterns { + cleanPattern := strings.TrimSuffix(pattern, "()") + sanitizers = append(sanitizers, cleanPattern) + } + + if verboseFlag { + fmt.Printf(" %s: LLM found %d sources, %d sinks, %d sanitizers\n", + fn.FQN, len(sources), len(sinks), len(sanitizers)) + if len(sources) > 0 { + fmt.Printf(" Sources: %v\n", sources) + } + if len(sinks) > 0 { + fmt.Printf(" Sinks: %v\n", sinks) + } + } + + // If no patterns discovered, use empty lists (tool will find nothing, matching LLM) + if len(sources) == 0 && len(sinks) == 0 { + // No patterns = no flows expected + toolResult := &diagnostic.FunctionTaintResult{ + FunctionFQN: fn.FQN, + HasTaintFlow: false, + TaintFlows: []diagnostic.ToolTaintFlow{}, + } + comparison := diagnostic.CompareFunctionResults(fn, toolResult, llmResult) + comparisons = append(comparisons, comparison) + continue + } + + // Run tool with LLM-discovered patterns + toolResult, err := diagnostic.AnalyzeSingleFunction(fn, sources, sinks, sanitizers) + if err != nil { + if verboseFlag { + fmt.Printf(" Tool error for %s: %v\n", fn.FQN, err) + } + continue + } + + if verboseFlag && toolResult != nil { + fmt.Printf(" Tool found %d flows (HasTaintFlow=%v)\n", + len(toolResult.TaintFlows), toolResult.HasTaintFlow) + } + + comparison := diagnostic.CompareFunctionResults(fn, toolResult, llmResult) + comparisons = append(comparisons, comparison) + } + + fmt.Printf("✓ Compared %d functions\n", len(comparisons)) + fmt.Println() + + // Step 4: Generate Reports + fmt.Println("Step 4/4: Generating reports...") + metrics := diagnostic.CalculateOverallMetrics(comparisons, startTime) + metrics.TopFailures = diagnostic.ExtractTopFailures(comparisons, functionsMap, 5) + + // Console report + err = diagnostic.GenerateConsoleReport(metrics, outputDir) + if err != nil { + fmt.Printf("Error generating console report: %v\n", err) + return + } + + // JSON report + if outputDir != "" { + err = os.MkdirAll(outputDir, 0755) + if err != nil { + fmt.Printf("Error creating output directory: %v\n", err) + return + } + + jsonPath := filepath.Join(outputDir, "diagnostic_report.json") + err = diagnostic.GenerateJSONReport(metrics, comparisons, jsonPath) + if err != nil { + fmt.Printf("Error generating JSON report: %v\n", err) + return + } + + fmt.Printf("✓ JSON report saved to: %s\n", jsonPath) + fmt.Println() + } + }, +} + +func init() { + rootCmd.AddCommand(diagnoseCmd) + diagnoseCmd.Flags().StringP("project", "p", "", "Project directory to analyze (required)") + diagnoseCmd.Flags().String("llm-url", "http://localhost:11434", "LLM endpoint base URL") + diagnoseCmd.Flags().String("model", "qwen2.5-coder:3b", "LLM model name") + diagnoseCmd.Flags().String("provider", "ollama", "LLM provider: ollama, openai (for xAI Grok, vLLM, etc.)") + diagnoseCmd.Flags().String("api-key", "", "API key for OpenAI-compatible providers (e.g., xAI Grok)") + diagnoseCmd.Flags().StringP("output", "o", "./diagnostic_output", "Output directory for reports") + diagnoseCmd.Flags().IntP("max-functions", "m", 50, "Maximum functions to analyze") + diagnoseCmd.Flags().IntP("concurrency", "c", 3, "LLM request concurrency") + diagnoseCmd.MarkFlagRequired("project") //nolint:all +} diff --git a/sourcecode-parser/diagnostic/analyzer.go b/sourcecode-parser/diagnostic/analyzer.go new file mode 100644 index 00000000..19eaff23 --- /dev/null +++ b/sourcecode-parser/diagnostic/analyzer.go @@ -0,0 +1,250 @@ +package diagnostic + +import ( + "fmt" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph" +) + +// FunctionTaintResult represents the structured taint analysis result for a single function. +// This is the internal API (not user-facing) used for diagnostic comparison. +type FunctionTaintResult struct { + // FunctionFQN identifies the function + FunctionFQN string + + // HasTaintFlow indicates if ANY taint flow was detected (binary result) + HasTaintFlow bool + + // TaintFlows contains all detected flows (detailed result) + TaintFlows []ToolTaintFlow + + // AnalysisError indicates if analysis failed + AnalysisError bool + + // ErrorMessage if AnalysisError == true + ErrorMessage string +} + +// ToolTaintFlow represents a single taint flow detected by our tool. +type ToolTaintFlow struct { + // Source information + SourceLine int + SourceVariable string + SourceType string // e.g., "request.GET['username']" + SourceCategory string // e.g., "user_input" (semantic) + + // Sink information + SinkLine int + SinkVariable string + SinkType string // e.g., "sqlite3.execute" + SinkCategory string // e.g., "sql_execution" (semantic) + + // Flow details + FlowPath []FlowStep + + // Metadata + VulnerabilityType string // e.g., "SQL_INJECTION" + Confidence float64 // 0.0-1.0 + IsSanitized bool // If sanitizer detected in path +} + +// AnalyzeSingleFunction runs intra-procedural taint analysis on a single function. +// This wraps existing taint analysis logic but: +// 1. Analyzes ONLY the specified function (not whole codebase) +// 2. Returns structured result (not text) +// 3. Filters to intra-procedural flows only +// +// Performance: ~1-5ms per function (depends on function size) +// +// Example: +// +// result, err := AnalyzeSingleFunction(functionMetadata, sources, sinks, sanitizers) +// if err != nil { +// log.Printf("Analysis failed: %v", err) +// return nil, err +// } +// if result.HasTaintFlow { +// fmt.Printf("Found %d flows\n", len(result.TaintFlows)) +// } +func AnalyzeSingleFunction( + fn *FunctionMetadata, + sources []string, + sinks []string, + sanitizers []string, +) (*FunctionTaintResult, error) { + result := &FunctionTaintResult{ + FunctionFQN: fn.FQN, + HasTaintFlow: false, + TaintFlows: []ToolTaintFlow{}, + } + + // Parse function source code + sourceCode := []byte(fn.SourceCode) + tree, err := callgraph.ParsePythonFile(sourceCode) + if err != nil { + result.AnalysisError = true + result.ErrorMessage = fmt.Sprintf("Parse error: %v", err) + return result, nil // Return result with error flag (don't fail) + } + + // Find the function node + functionNode := findFunctionNodeByName(tree.RootNode(), fn.FunctionName, sourceCode) + if functionNode == nil { + result.AnalysisError = true + result.ErrorMessage = "Function node not found in AST" + return result, nil + } + + // Extract statements (using existing logic from statement_extraction.go) + statements, err := callgraph.ExtractStatements(fn.FilePath, sourceCode, functionNode) + if err != nil { + result.AnalysisError = true + result.ErrorMessage = fmt.Sprintf("Statement extraction error: %v", err) + return result, nil + } + + // Build def-use chains (using existing logic from statement.go) + defUseChain := callgraph.BuildDefUseChains(statements) + + // Run taint analysis (using existing logic from taint.go) + taintSummary := callgraph.AnalyzeIntraProceduralTaint( + fn.FQN, + statements, + defUseChain, + sources, + sinks, + sanitizers, + ) + + // Check if any flows detected + if !taintSummary.HasDetections() { + return result, nil // No flows, return empty result + } + + // Convert TaintSummary detections to ToolTaintFlow + result.HasTaintFlow = true + for _, detection := range taintSummary.Detections { + // Only include if both source and sink are within function boundaries + if detection.SourceLine >= uint32(fn.StartLine) && + detection.SourceLine <= uint32(fn.EndLine) && + detection.SinkLine >= uint32(fn.StartLine) && + detection.SinkLine <= uint32(fn.EndLine) { + + flow := ToolTaintFlow{ + SourceLine: int(detection.SourceLine), + SourceVariable: detection.SourceVar, + SinkLine: int(detection.SinkLine), + SinkVariable: detection.SinkVar, + SinkType: detection.SinkCall, + Confidence: detection.Confidence, + IsSanitized: detection.Sanitized, + } + + // Build flow path from propagation path + flow.FlowPath = []FlowStep{} + for _, varName := range detection.PropagationPath { + flow.FlowPath = append(flow.FlowPath, FlowStep{ + Variable: varName, + Operation: "propagate", + }) + } + + // Categorize source and sink (semantic mapping) + flow.SourceCategory = categorizePattern(flow.SourceType, sources) + flow.SinkCategory = categorizePattern(flow.SinkType, sinks) + flow.VulnerabilityType = inferVulnerabilityType(flow.SourceCategory, flow.SinkCategory) + + result.TaintFlows = append(result.TaintFlows, flow) + } + } + + return result, nil +} + +// findFunctionNodeByName finds a function_definition node by name in the AST. +// Helper for AnalyzeSingleFunction. +func findFunctionNodeByName(node *sitter.Node, functionName string, sourceCode []byte) *sitter.Node { + if node == nil { + return nil + } + + if node.Type() == "function_definition" { + nameNode := node.ChildByFieldName("name") + if nameNode != nil && nameNode.Content(sourceCode) == functionName { + return node + } + } + + // Recurse into children + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child != nil { + result := findFunctionNodeByName(child, functionName, sourceCode) + if result != nil { + return result + } + } + } + + return nil +} + +// categorizePattern maps a pattern to a semantic category. +// Example: "request.GET" → "user_input", "os.system" → "command_exec". +func categorizePattern(pattern string, _ []string) string { + patternLower := strings.ToLower(pattern) + + // User input sources + if strings.Contains(patternLower, "request.get") || + strings.Contains(patternLower, "request.post") || + strings.Contains(patternLower, "input(") { + return "user_input" + } + + // File operations + if strings.Contains(patternLower, "open(") || + strings.Contains(patternLower, "file") { + return "file_operation" + } + + // SQL sinks + if strings.Contains(patternLower, "execute") || + strings.Contains(patternLower, "cursor") || + strings.Contains(patternLower, "sql") { + return "sql_execution" + } + + // Command execution sinks + if strings.Contains(patternLower, "system") || + strings.Contains(patternLower, "subprocess") || + strings.Contains(patternLower, "popen") { + return "command_exec" + } + + // Code execution sinks + if strings.Contains(patternLower, "eval") || + strings.Contains(patternLower, "exec") { + return "code_exec" + } + + return "other" +} + +// inferVulnerabilityType maps source+sink categories to vulnerability type. +func inferVulnerabilityType(sourceCategory, sinkCategory string) string { + if sourceCategory == "user_input" && sinkCategory == "sql_execution" { + return "SQL_INJECTION" + } + if sourceCategory == "user_input" && sinkCategory == "command_exec" { + return "COMMAND_INJECTION" + } + if sourceCategory == "user_input" && sinkCategory == "code_exec" { + return "CODE_INJECTION" + } + if sourceCategory == "user_input" && sinkCategory == "file_operation" { + return "PATH_TRAVERSAL" + } + return "TAINT_FLOW" +} diff --git a/sourcecode-parser/diagnostic/analyzer_test.go b/sourcecode-parser/diagnostic/analyzer_test.go new file mode 100644 index 00000000..56917dd0 --- /dev/null +++ b/sourcecode-parser/diagnostic/analyzer_test.go @@ -0,0 +1,149 @@ +package diagnostic + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestAnalyzeSingleFunction_SimpleFlow tests basic taint flow detection. +func TestAnalyzeSingleFunction_SimpleFlow(t *testing.T) { + fn := &FunctionMetadata{ + FQN: "test.simple_flow", + FunctionName: "simple_flow", + FilePath: "test.py", + SourceCode: `def simple_flow(): + user_input = input() + eval(user_input)`, + StartLine: 1, + EndLine: 3, + } + + result, err := AnalyzeSingleFunction(fn, []string{"input"}, []string{"eval"}, []string{}) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "test.simple_flow", result.FunctionFQN) + assert.True(t, result.HasTaintFlow) + assert.False(t, result.AnalysisError) + assert.GreaterOrEqual(t, len(result.TaintFlows), 1) +} + +// TestAnalyzeSingleFunction_NoFlow tests function with no taint flows. +func TestAnalyzeSingleFunction_NoFlow(t *testing.T) { + fn := &FunctionMetadata{ + FQN: "test.no_flow", + FunctionName: "no_flow", + FilePath: "test.py", + SourceCode: `def no_flow(): + x = 1 + return x`, + StartLine: 1, + EndLine: 3, + } + + result, err := AnalyzeSingleFunction(fn, []string{"input"}, []string{"eval"}, []string{}) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "test.no_flow", result.FunctionFQN) + assert.False(t, result.HasTaintFlow) + assert.False(t, result.AnalysisError) + assert.Equal(t, 0, len(result.TaintFlows)) +} + +// TestAnalyzeSingleFunction_ParseError tests error handling for invalid syntax. +func TestAnalyzeSingleFunction_ParseError(t *testing.T) { + fn := &FunctionMetadata{ + FQN: "test.invalid", + FunctionName: "invalid", + FilePath: "test.py", + SourceCode: `def invalid( # missing closing paren`, + StartLine: 1, + EndLine: 1, + } + + result, err := AnalyzeSingleFunction(fn, []string{"input"}, []string{"eval"}, []string{}) + require.NoError(t, err) // Should not error, but set AnalysisError flag + require.NotNil(t, result) + + assert.True(t, result.AnalysisError) + // Can be either parse error or function not found (depends on tree-sitter recovery) + assert.NotEmpty(t, result.ErrorMessage) +} + +// TestCategorizePattern tests semantic pattern categorization. +func TestCategorizePattern(t *testing.T) { + tests := []struct { + pattern string + expected string + }{ + {"request.GET", "user_input"}, + {"request.POST", "user_input"}, + {"input()", "user_input"}, + {"os.system", "command_exec"}, + {"subprocess.call", "command_exec"}, + {"eval", "code_exec"}, + {"execute", "sql_execution"}, + {"open()", "file_operation"}, + {"unknown", "other"}, + } + + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + result := categorizePattern(tt.pattern, []string{}) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestInferVulnerabilityType tests vulnerability type inference. +func TestInferVulnerabilityType(t *testing.T) { + tests := []struct { + source string + sink string + expected string + }{ + {"user_input", "sql_execution", "SQL_INJECTION"}, + {"user_input", "command_exec", "COMMAND_INJECTION"}, + {"user_input", "code_exec", "CODE_INJECTION"}, + {"user_input", "file_operation", "PATH_TRAVERSAL"}, + {"other", "other", "TAINT_FLOW"}, + } + + for _, tt := range tests { + t.Run(tt.source+"_to_"+tt.sink, func(t *testing.T) { + result := inferVulnerabilityType(tt.source, tt.sink) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestFindFunctionNodeByName tests AST function finding. +func TestFindFunctionNodeByName(t *testing.T) { + sourceCode := []byte(` +def function_one(): + pass + +def function_two(): + pass +`) + + tree, err := callgraph.ParsePythonFile(sourceCode) + require.NoError(t, err) + require.NotNil(t, tree) + + // Should find function_one + node := findFunctionNodeByName(tree.RootNode(), "function_one", sourceCode) + assert.NotNil(t, node) + + // Should find function_two + node = findFunctionNodeByName(tree.RootNode(), "function_two", sourceCode) + assert.NotNil(t, node) + + // Should not find non-existent function + node = findFunctionNodeByName(tree.RootNode(), "function_three", sourceCode) + assert.Nil(t, node) +} diff --git a/sourcecode-parser/diagnostic/comparator.go b/sourcecode-parser/diagnostic/comparator.go new file mode 100644 index 00000000..62bb0894 --- /dev/null +++ b/sourcecode-parser/diagnostic/comparator.go @@ -0,0 +1,253 @@ +package diagnostic + +import ( + "strings" +) + +// DualLevelComparison represents comparison results at both binary and detailed levels. +type DualLevelComparison struct { + FunctionFQN string + + // Level 1: Binary classification + BinaryToolResult bool // Tool says: has flow + BinaryLLMResult bool // LLM says: has flow + BinaryAgreement bool // Do they agree? + + // Level 2: Detailed flow comparison (only if both found flows) + DetailedComparison *FlowComparisonResult // nil if N/A + + // Metrics + Precision float64 + Recall float64 + F1Score float64 + + // Categorization (if disagreement) + FailureCategory string // "control_flow", "sanitizer", etc. + FailureReason string // From LLM reasoning +} + +// FlowComparisonResult contains detailed flow-by-flow comparison. +type FlowComparisonResult struct { + ToolFlows []NormalizedTaintFlow + LLMFlows []NormalizedTaintFlow + + Matches []FlowMatch // TP: Both found + UnmatchedTool []NormalizedTaintFlow // FP: Tool only + UnmatchedLLM []NormalizedTaintFlow // FN: LLM only + + FlowPrecision float64 // Matches / ToolFlows + FlowRecall float64 // Matches / LLMFlows + FlowF1Score float64 // 2PR/(P+R) +} + +// FlowMatch represents a matched flow between tool and LLM. +type FlowMatch struct { + ToolFlow NormalizedTaintFlow + LLMFlow NormalizedTaintFlow + ToolIndex int + LLMIndex int +} + +// CompareFunctionResults performs dual-level comparison between tool and LLM results. +// +// Performance: ~1ms per function +// +// Example: +// +// comparison := CompareFunctionResults(fn, toolResult, llmResult) +// if comparison.BinaryAgreement { +// fmt.Println("✅ Agreement on binary level") +// } +// if comparison.DetailedComparison != nil { +// fmt.Printf("Flow precision: %.2f%%\n", comparison.Precision*100) +// } +func CompareFunctionResults( + fn *FunctionMetadata, + toolResult *FunctionTaintResult, + llmResult *LLMAnalysisResult, +) *DualLevelComparison { + // Determine if LLM detected any dataflow (not just dangerous flows) + llmHasFlow := llmResult.AnalysisMetadata.TotalFlows > 0 + + comparison := &DualLevelComparison{ + FunctionFQN: fn.FQN, + BinaryToolResult: toolResult.HasTaintFlow, + BinaryLLMResult: llmHasFlow, + BinaryAgreement: toolResult.HasTaintFlow == llmHasFlow, + } + + // Level 2: Detailed comparison (only if both found flows) + if toolResult.HasTaintFlow && llmHasFlow { + toolNorm := NormalizeToolResult(toolResult) + llmNorm := NormalizeLLMResult(llmResult) + + flowComparison := CompareNormalizedFlows(toolNorm, llmNorm, DefaultMatchConfig()) + + comparison.DetailedComparison = flowComparison + comparison.Precision = flowComparison.FlowPrecision + comparison.Recall = flowComparison.FlowRecall + comparison.F1Score = flowComparison.FlowF1Score + } else if !comparison.BinaryAgreement { + // Binary disagreement: Categorize failure + comparison.FailureCategory = categorizeFailureFromLLM(llmResult) + comparison.FailureReason = extractReasoningFromLLM(llmResult) + } + + return comparison +} + +// CompareNormalizedFlows performs detailed flow-by-flow comparison with fuzzy matching. +func CompareNormalizedFlows( + toolFlows, llmFlows []NormalizedTaintFlow, + config MatchConfig, +) *FlowComparisonResult { + result := &FlowComparisonResult{ + ToolFlows: toolFlows, + LLMFlows: llmFlows, + Matches: []FlowMatch{}, + } + + matched := make(map[int]bool) // Track which LLM flows are matched + + // For each tool flow, try to find matching LLM flow + for i, toolFlow := range toolFlows { + foundMatch := false + for j, llmFlow := range llmFlows { + if matched[j] { + continue // Already matched + } + if FlowsMatch(toolFlow, llmFlow, config) { + result.Matches = append(result.Matches, FlowMatch{ + ToolFlow: toolFlow, + LLMFlow: llmFlow, + ToolIndex: i, + LLMIndex: j, + }) + matched[j] = true + foundMatch = true + break + } + } + + if !foundMatch { + result.UnmatchedTool = append(result.UnmatchedTool, toolFlow) + } + } + + // Identify unmatched LLM flows (FN) + for i, llmFlow := range llmFlows { + if !matched[i] { + result.UnmatchedLLM = append(result.UnmatchedLLM, llmFlow) + } + } + + // Calculate metrics + if len(toolFlows) > 0 { + result.FlowPrecision = float64(len(result.Matches)) / float64(len(toolFlows)) + } + if len(llmFlows) > 0 { + result.FlowRecall = float64(len(result.Matches)) / float64(len(llmFlows)) + } + if result.FlowPrecision+result.FlowRecall > 0 { + result.FlowF1Score = 2 * result.FlowPrecision * result.FlowRecall / + (result.FlowPrecision + result.FlowRecall) + } + + return result +} + +// categorizeFailureFromLLM extracts failure category from LLM analysis. +// First tries to use LLM-provided category, falls back to keyword matching. +func categorizeFailureFromLLM(llmResult *LLMAnalysisResult) string { + // Strategy 1: Use LLM-provided category (most reliable) + for _, testCase := range llmResult.DataflowTestCases { + if testCase.FailureCategory != "" && testCase.FailureCategory != "none" { + return testCase.FailureCategory + } + } + + // Strategy 2: Fallback to keyword matching (for older LLM responses) + for _, testCase := range llmResult.DataflowTestCases { + reasoning := strings.ToLower(testCase.Reasoning) + + // Check sanitizers first (high priority issue) + if strings.Contains(reasoning, "sanitiz") || strings.Contains(reasoning, "escape") || + strings.Contains(reasoning, "quote") || strings.Contains(reasoning, "clean") || + strings.Contains(reasoning, "filter") || strings.Contains(reasoning, "validate") { + return "sanitizer_missed" + } + + // Control flow branches (high priority - common limitation) + if strings.Contains(reasoning, "if ") || strings.Contains(reasoning, "branch") || + strings.Contains(reasoning, "conditional") || strings.Contains(reasoning, "else") || + strings.Contains(reasoning, "inside") { + return "control_flow_branch" + } + + // Field sensitivity (object attribute tracking) + if (strings.Contains(reasoning, "field") || strings.Contains(reasoning, "attribute") || + strings.Contains(reasoning, "self.") || strings.Contains(reasoning, "obj.")) && + !strings.Contains(reasoning, "dict") { + return "field_sensitivity" + } + + // Container operations (list/dict/set) + if strings.Contains(reasoning, "list") || strings.Contains(reasoning, "dict") || + strings.Contains(reasoning, "append") || strings.Contains(reasoning, "array") || + strings.Contains(reasoning, "container") || strings.Contains(reasoning, "[") { + return "container_operation" + } + + // String formatting operations + if strings.Contains(reasoning, "f-string") || strings.Contains(reasoning, "format") || + strings.Contains(reasoning, "concatenat") || strings.Contains(reasoning, "join") || + strings.Contains(reasoning, "%s") || strings.Contains(reasoning, ".format(") { + return "string_formatting" + } + + // Method call propagation + if strings.Contains(reasoning, "method") || strings.Contains(reasoning, ".upper()") || + strings.Contains(reasoning, ".lower()") || strings.Contains(reasoning, ".strip()") || + strings.Contains(reasoning, "string method") { + return "method_call_propagation" + } + + // Assignment chain tracking + if strings.Contains(reasoning, "assignment") && (strings.Contains(reasoning, "chain") || + strings.Contains(reasoning, "through") || strings.Contains(reasoning, "via")) { + return "assignment_chain" + } + + // Return flow tracking + if strings.Contains(reasoning, "return") && strings.Contains(reasoning, "flow") { + return "return_flow" + } + + // Function parameter flow + if strings.Contains(reasoning, "parameter") && strings.Contains(reasoning, "flow") { + return "parameter_flow" + } + + // Complex expressions (method chains, nested calls) + if strings.Contains(reasoning, "complex") || strings.Contains(reasoning, "nested") || + strings.Contains(reasoning, "chain") || strings.Contains(reasoning, "multiple") { + return "complex_expression" + } + + // Inter-procedural (out of scope for intra-procedural analysis) + if strings.Contains(reasoning, "function call") || strings.Contains(reasoning, "called function") || + strings.Contains(reasoning, "inter-procedural") || strings.Contains(reasoning, "cross-function") { + return "context_required" + } + } + + return "unknown" +} + +// extractReasoningFromLLM gets the reasoning from first test case. +func extractReasoningFromLLM(llmResult *LLMAnalysisResult) string { + if len(llmResult.DataflowTestCases) > 0 { + return llmResult.DataflowTestCases[0].Reasoning + } + return "" +} diff --git a/sourcecode-parser/diagnostic/comparator_test.go b/sourcecode-parser/diagnostic/comparator_test.go new file mode 100644 index 00000000..bf53745c --- /dev/null +++ b/sourcecode-parser/diagnostic/comparator_test.go @@ -0,0 +1,270 @@ +package diagnostic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCompareFunctionResults_BinaryTP tests binary true positive (both detect flow). +func TestCompareFunctionResults_BinaryTP(t *testing.T) { + fn := &FunctionMetadata{FQN: "test.func"} + + toolResult := &FunctionTaintResult{ + HasTaintFlow: true, + TaintFlows: []ToolTaintFlow{ + { + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + }, + }, + } + + llmResult := &LLMAnalysisResult{ + AnalysisMetadata: AnalysisMetadata{ + DangerousFlows: 1, + TotalFlows: 1, + }, + DataflowTestCases: []DataflowTestCase{ + { + Source: TestCaseSource{Line: 10, Variable: "x"}, + Sink: TestCaseSink{Line: 20, Variable: "y"}, + ExpectedDetection: true, + }, + }, + } + + comparison := CompareFunctionResults(fn, toolResult, llmResult) + + assert.True(t, comparison.BinaryAgreement) + assert.True(t, comparison.BinaryToolResult) + assert.True(t, comparison.BinaryLLMResult) + assert.NotNil(t, comparison.DetailedComparison) +} + +// TestCompareFunctionResults_BinaryTN tests binary true negative (both say no flow). +func TestCompareFunctionResults_BinaryTN(t *testing.T) { + fn := &FunctionMetadata{FQN: "test.func"} + + toolResult := &FunctionTaintResult{ + HasTaintFlow: false, + } + + llmResult := &LLMAnalysisResult{ + AnalysisMetadata: AnalysisMetadata{ + DangerousFlows: 0, + }, + } + + comparison := CompareFunctionResults(fn, toolResult, llmResult) + + assert.True(t, comparison.BinaryAgreement) + assert.False(t, comparison.BinaryToolResult) + assert.False(t, comparison.BinaryLLMResult) + assert.Nil(t, comparison.DetailedComparison) +} + +// TestCompareFunctionResults_BinaryFP tests false positive (tool detects, LLM doesn't). +func TestCompareFunctionResults_BinaryFP(t *testing.T) { + fn := &FunctionMetadata{FQN: "test.func"} + + toolResult := &FunctionTaintResult{ + HasTaintFlow: true, + } + + llmResult := &LLMAnalysisResult{ + AnalysisMetadata: AnalysisMetadata{ + DangerousFlows: 0, + }, + DataflowTestCases: []DataflowTestCase{ + { + ExpectedDetection: false, + Reasoning: "This flow is sanitized", + }, + }, + } + + comparison := CompareFunctionResults(fn, toolResult, llmResult) + + assert.False(t, comparison.BinaryAgreement) + assert.True(t, comparison.BinaryToolResult) + assert.False(t, comparison.BinaryLLMResult) + assert.NotEmpty(t, comparison.FailureReason) +} + +// TestCompareFunctionResults_BinaryFN tests false negative (LLM detects, tool doesn't). +func TestCompareFunctionResults_BinaryFN(t *testing.T) { + fn := &FunctionMetadata{FQN: "test.func"} + + toolResult := &FunctionTaintResult{ + HasTaintFlow: false, + } + + llmResult := &LLMAnalysisResult{ + AnalysisMetadata: AnalysisMetadata{ + DangerousFlows: 1, + TotalFlows: 1, + }, + DataflowTestCases: []DataflowTestCase{ + { + ExpectedDetection: true, + Reasoning: "Flow should be detected through control flow", + FailureCategory: "control_flow_branch", + }, + }, + } + + comparison := CompareFunctionResults(fn, toolResult, llmResult) + + assert.False(t, comparison.BinaryAgreement) + assert.False(t, comparison.BinaryToolResult) + assert.True(t, comparison.BinaryLLMResult) + // Failure category should be set (might be "control_flow_branch" or "unknown" depending on reasoning) + assert.NotEmpty(t, comparison.FailureCategory) +} + +// TestCompareNormalizedFlows_AllMatch tests when all flows match. +func TestCompareNormalizedFlows_AllMatch(t *testing.T) { + toolFlows := []NormalizedTaintFlow{ + { + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + }, + } + + llmFlows := []NormalizedTaintFlow{ + { + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + }, + } + + result := CompareNormalizedFlows(toolFlows, llmFlows, DefaultMatchConfig()) + + assert.Equal(t, 1, len(result.Matches)) + assert.Equal(t, 0, len(result.UnmatchedTool)) + assert.Equal(t, 0, len(result.UnmatchedLLM)) + assert.Equal(t, 1.0, result.FlowPrecision) + assert.Equal(t, 1.0, result.FlowRecall) + assert.Equal(t, 1.0, result.FlowF1Score) +} + +// TestCompareNormalizedFlows_PartialMatch tests partial matching. +func TestCompareNormalizedFlows_PartialMatch(t *testing.T) { + toolFlows := []NormalizedTaintFlow{ + { + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + }, + { + SourceLine: 30, + SourceVariable: "a", + SourceCategory: "user_input", + SinkLine: 40, + SinkVariable: "b", + SinkCategory: "command_exec", + VulnerabilityType: "COMMAND_INJECTION", + }, + } + + llmFlows := []NormalizedTaintFlow{ + { + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + }, + } + + result := CompareNormalizedFlows(toolFlows, llmFlows, DefaultMatchConfig()) + + assert.Equal(t, 1, len(result.Matches)) + assert.Equal(t, 1, len(result.UnmatchedTool)) // Tool found extra flow + assert.Equal(t, 0, len(result.UnmatchedLLM)) + assert.Equal(t, 0.5, result.FlowPrecision) // 1/2 + assert.Equal(t, 1.0, result.FlowRecall) // 1/1 +} + +// TestCategorizeFailureFromLLM tests failure categorization. +func TestCategorizeFailureFromLLM(t *testing.T) { + tests := []struct { + reasoning string + expectedCategory string + }{ + {"Flow depends on if condition", "control_flow_branch"}, + {"Field access through self.field", "field_sensitivity"}, + {"Data is sanitized by escape function", "sanitizer_missed"}, + {"Flow through list append", "container_operation"}, + {"Flow through f-string formatting", "string_formatting"}, + {"Unknown reason", "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.expectedCategory, func(t *testing.T) { + llmResult := &LLMAnalysisResult{ + DataflowTestCases: []DataflowTestCase{ + {Reasoning: tt.reasoning}, + }, + } + + category := categorizeFailureFromLLM(llmResult) + assert.Equal(t, tt.expectedCategory, category) + }) + } +} + +// TestExtractReasoningFromLLM tests reasoning extraction. +func TestExtractReasoningFromLLM(t *testing.T) { + llmResult := &LLMAnalysisResult{ + DataflowTestCases: []DataflowTestCase{ + {Reasoning: "First reasoning"}, + {Reasoning: "Second reasoning"}, + }, + } + + reasoning := extractReasoningFromLLM(llmResult) + assert.Equal(t, "First reasoning", reasoning) + + // Empty case + emptyResult := &LLMAnalysisResult{} + reasoning = extractReasoningFromLLM(emptyResult) + assert.Empty(t, reasoning) +} + +// TestCompareNormalizedFlows_EmptyFlows tests with empty flow lists. +func TestCompareNormalizedFlows_EmptyFlows(t *testing.T) { + toolFlows := []NormalizedTaintFlow{} + llmFlows := []NormalizedTaintFlow{} + + result := CompareNormalizedFlows(toolFlows, llmFlows, DefaultMatchConfig()) + + require.NotNil(t, result) + assert.Equal(t, 0, len(result.Matches)) + assert.Equal(t, 0.0, result.FlowPrecision) + assert.Equal(t, 0.0, result.FlowRecall) + assert.Equal(t, 0.0, result.FlowF1Score) +} diff --git a/sourcecode-parser/diagnostic/llm.go b/sourcecode-parser/diagnostic/llm.go index b617e3e0..4b9d5ff9 100644 --- a/sourcecode-parser/diagnostic/llm.go +++ b/sourcecode-parser/diagnostic/llm.go @@ -7,24 +7,37 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "time" ) -// LLMClient handles communication with local LLM (Ollama/vLLM). +// LLMProvider represents the type of LLM provider. +type LLMProvider string + +const ( + ProviderOllama LLMProvider = "ollama" + ProviderOpenAI LLMProvider = "openai" // Also compatible with xAI Grok, vLLM, etc. +) + +// LLMClient handles communication with LLM providers (Ollama, OpenAI-compatible APIs). type LLMClient struct { + Provider LLMProvider BaseURL string Model string Temperature float64 MaxTokens int + APIKey string // For OpenAI-compatible APIs (xAI Grok, etc.) HTTPClient *http.Client } -// NewLLMClient creates a new LLM client. +// NewLLMClient creates a new LLM client for Ollama. // Example: // // client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b") func NewLLMClient(baseURL, model string) *LLMClient { return &LLMClient{ + Provider: ProviderOllama, BaseURL: baseURL, Model: model, Temperature: 0.0, // Deterministic @@ -35,6 +48,24 @@ func NewLLMClient(baseURL, model string) *LLMClient { } } +// NewOpenAIClient creates a new OpenAI-compatible client (xAI Grok, vLLM, etc.). +// Example: +// +// client := NewOpenAIClient("https://api.x.ai/v1", "grok-beta", "xai-YOUR_API_KEY") +func NewOpenAIClient(baseURL, model, apiKey string) *LLMClient { + return &LLMClient{ + Provider: ProviderOpenAI, + BaseURL: baseURL, + Model: model, + APIKey: apiKey, + Temperature: 0.0, // Deterministic + MaxTokens: 4000, // Increased for complex functions + HTTPClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + // AnalyzeFunction sends a function to the LLM for pattern discovery and test generation. // Returns structured analysis result or error. // @@ -58,8 +89,18 @@ func (c *LLMClient) AnalyzeFunction(fn *FunctionMetadata) (*LLMAnalysisResult, e // Build prompt prompt := BuildAnalysisPrompt(fn.SourceCode) - // Call LLM - responseText, err := c.callOllama(prompt) + // Call LLM based on provider + var responseText string + var err error + switch c.Provider { + case ProviderOllama: + responseText, err = c.callOllama(prompt) + case ProviderOpenAI: + responseText, err = c.callOpenAI(prompt) + default: + return nil, fmt.Errorf("unsupported provider: %s", c.Provider) + } + if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } @@ -68,7 +109,20 @@ func (c *LLMClient) AnalyzeFunction(fn *FunctionMetadata) (*LLMAnalysisResult, e var result LLMAnalysisResult err = json.Unmarshal([]byte(responseText), &result) if err != nil { - return nil, fmt.Errorf("failed to parse LLM response: %w\nResponse: %s", err, responseText) + // Try to extract JSON from markdown code blocks if present + responseText = extractJSONFromMarkdown(responseText) + err = json.Unmarshal([]byte(responseText), &result) + if err != nil { + // Save failed response to debug file + c.saveFailedResponse(fn.FQN, responseText, err) + + // Log first 500 chars for debugging + preview := responseText + if len(preview) > 500 { + preview = preview[:500] + } + return nil, fmt.Errorf("failed to parse LLM response: %w\nResponse preview: %s", err, preview) + } } // Add metadata @@ -141,6 +195,73 @@ func (c *LLMClient) callOllama(prompt string) (string, error) { return ollamaResp.Response, nil } +// callOpenAI makes HTTP request to OpenAI-compatible API (xAI Grok, vLLM, etc.). +func (c *LLMClient) callOpenAI(prompt string) (string, error) { + // OpenAI API format + requestBody := map[string]interface{}{ + "model": c.Model, + "messages": []map[string]string{ + { + "role": "user", + "content": prompt, + }, + }, + "temperature": c.Temperature, + "max_tokens": c.MaxTokens, + "response_format": map[string]string{"type": "json_object"}, // Request JSON output + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + // Make request + url := c.BaseURL + "/chat/completions" + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewBuffer(jsonBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return "", fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("LLM returned status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + // Read response + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + // Parse OpenAI response format + var openaiResp struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + err = json.Unmarshal(bodyBytes, &openaiResp) + if err != nil { + return "", fmt.Errorf("failed to parse OpenAI response: %w", err) + } + + if len(openaiResp.Choices) == 0 { + return "", fmt.Errorf("no choices in OpenAI response") + } + + return openaiResp.Choices[0].Message.Content, nil +} + // validateResult checks that LLM result has required fields. func (c *LLMClient) validateResult(result *LLMAnalysisResult) error { if result.AnalysisMetadata.Confidence < 0.0 || result.AnalysisMetadata.Confidence > 1.0 { @@ -217,3 +338,66 @@ func (c *LLMClient) AnalyzeBatch(functions []*FunctionMetadata, concurrency int) return results, errors } + +// extractJSONFromMarkdown extracts JSON from markdown code blocks. +func extractJSONFromMarkdown(text string) string { + // Try to find JSON between ```json and ``` + start := -1 + end := -1 + + // Look for ```json + jsonMarker := "```json" + idx := len(jsonMarker) + if len(text) > idx && text[:idx] == jsonMarker { + start = idx + } + + // Look for closing ``` + if start != -1 { + closingMarker := "```" + closeIdx := len(text) - len(closingMarker) + if closeIdx > start && text[closeIdx:] == closingMarker { + end = closeIdx + } + } + + if start != -1 && end != -1 { + return text[start:end] + } + + // Try plain ``` markers + markers := []int{} + for i := 0; i < len(text)-2; i++ { + if text[i:i+3] == "```" { + markers = append(markers, i) + } + } + + if len(markers) >= 2 { + // Return content between first and last markers + return text[markers[0]+3 : markers[len(markers)-1]] + } + + return text +} + +// saveFailedResponse saves failed LLM response to debug file. +func (c *LLMClient) saveFailedResponse(fqn, responseText string, parseErr error) { + // Create debug directory + debugDir := "/tmp/diagnostic_llm_errors" + os.MkdirAll(debugDir, 0755) + + // Create timestamped filename + timestamp := time.Now().Format("20060102_150405") + filename := filepath.Join(debugDir, fmt.Sprintf("error_%s.txt", timestamp)) + + // Write error details + content := fmt.Sprintf("=== LLM Response Parse Error ===\n") + content += fmt.Sprintf("Function: %s\n", fqn) + content += fmt.Sprintf("Error: %v\n", parseErr) + content += fmt.Sprintf("Provider: %s\n", c.Provider) + content += fmt.Sprintf("Model: %s\n", c.Model) + content += fmt.Sprintf("\n=== Full Response ===\n%s\n", responseText) + + os.WriteFile(filename, []byte(content), 0644) +} diff --git a/sourcecode-parser/diagnostic/llm_test.go b/sourcecode-parser/diagnostic/llm_test.go index 49bbc07d..b3545692 100644 --- a/sourcecode-parser/diagnostic/llm_test.go +++ b/sourcecode-parser/diagnostic/llm_test.go @@ -10,15 +10,31 @@ import ( "github.com/stretchr/testify/require" ) -// TestNewLLMClient tests client creation. +// TestNewLLMClient tests Ollama client creation. func TestNewLLMClient(t *testing.T) { client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b") assert.NotNil(t, client) + assert.Equal(t, ProviderOllama, client.Provider) assert.Equal(t, "http://localhost:11434", client.BaseURL) assert.Equal(t, "qwen3-coder:32b", client.Model) assert.Equal(t, 0.0, client.Temperature) assert.Equal(t, 2000, client.MaxTokens) + assert.Equal(t, "", client.APIKey) + assert.NotNil(t, client.HTTPClient) +} + +// TestNewOpenAIClient tests OpenAI-compatible client creation. +func TestNewOpenAIClient(t *testing.T) { + client := NewOpenAIClient("https://api.x.ai/v1", "grok-beta", "test-api-key") + + assert.NotNil(t, client) + assert.Equal(t, ProviderOpenAI, client.Provider) + assert.Equal(t, "https://api.x.ai/v1", client.BaseURL) + assert.Equal(t, "grok-beta", client.Model) + assert.Equal(t, "test-api-key", client.APIKey) + assert.Equal(t, 0.0, client.Temperature) + assert.Equal(t, 4000, client.MaxTokens) assert.NotNil(t, client.HTTPClient) } @@ -390,3 +406,146 @@ func findSubstring(s, substr string) bool { } return false } + +// TestAnalyzeFunction_OpenAI tests successful OpenAI API analysis. +func TestAnalyzeFunction_OpenAI(t *testing.T) { + mockResponse := LLMAnalysisResult{ + DiscoveredPatterns: DiscoveredPatterns{ + Sources: []PatternLocation{ + {Pattern: "input", Lines: []int{1}, Variables: []string{"x"}}, + }, + }, + AnalysisMetadata: AnalysisMetadata{ + Confidence: 0.9, + }, + } + + // Create mock OpenAI server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify OpenAI request format + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + assert.Equal(t, "grok-test", reqBody["model"]) + assert.NotNil(t, reqBody["messages"]) + assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + // Return OpenAI format response + responseBytes, _ := json.Marshal(mockResponse) + openaiResp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]string{ + "content": string(responseBytes), + }, + }, + }, + } + json.NewEncoder(w).Encode(openaiResp) + })) + defer server.Close() + + client := NewOpenAIClient(server.URL, "grok-test", "test-key") + + fn := &FunctionMetadata{ + FQN: "test.func", + SourceCode: "def func(): pass", + } + + result, err := client.AnalyzeFunction(fn) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 1, len(result.DiscoveredPatterns.Sources)) +} + +// TestAnalyzeFunction_OpenAI_HTTPError tests OpenAI API error handling. +func TestAnalyzeFunction_OpenAI_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "Invalid API key"}`)) + })) + defer server.Close() + + client := NewOpenAIClient(server.URL, "grok-test", "bad-key") + + fn := &FunctionMetadata{ + FQN: "test.func", + SourceCode: "def func(): pass", + } + + result, err := client.AnalyzeFunction(fn) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "status 401") +} + +// TestAnalyzeFunction_OpenAI_NoChoices tests OpenAI response with no choices. +func TestAnalyzeFunction_OpenAI_NoChoices(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + openaiResp := map[string]interface{}{ + "choices": []map[string]interface{}{}, + } + json.NewEncoder(w).Encode(openaiResp) + })) + defer server.Close() + + client := NewOpenAIClient(server.URL, "grok-test", "test-key") + + fn := &FunctionMetadata{ + FQN: "test.func", + SourceCode: "def func(): pass", + } + + result, err := client.AnalyzeFunction(fn) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "no choices") +} + +// TestExtractJSONFromMarkdown tests JSON extraction from markdown code blocks. +func TestExtractJSONFromMarkdown(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "json code block", + input: "```json\n{\"key\": \"value\"}\n```", + expected: "\n{\"key\": \"value\"}\n", + }, + { + name: "plain code block", + input: "```\n{\"key\": \"value\"}\n```", + expected: "\n{\"key\": \"value\"}\n", + }, + { + name: "no code block", + input: "{\"key\": \"value\"}", + expected: "{\"key\": \"value\"}", + }, + { + name: "multiple code blocks", + input: "```\nfirst\n```\nmiddle\n```\nsecond\n```", + expected: "\nfirst\n```\nmiddle\n```\nsecond\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractJSONFromMarkdown(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestSaveFailedResponse tests error logging functionality. +func TestSaveFailedResponse(t *testing.T) { + client := NewLLMClient("http://localhost:11434", "test-model") + + // This test just ensures the function doesn't panic + // Actual file creation is tested in integration tests + client.saveFailedResponse("test.func", `{"incomplete": `, assert.AnError) + + // No assertions needed - just verify no panic +} diff --git a/sourcecode-parser/diagnostic/metrics.go b/sourcecode-parser/diagnostic/metrics.go new file mode 100644 index 00000000..3601b507 --- /dev/null +++ b/sourcecode-parser/diagnostic/metrics.go @@ -0,0 +1,165 @@ +package diagnostic + +import ( + "time" +) + +// OverallMetrics contains aggregated metrics across all functions. +type OverallMetrics struct { + // Total functions analyzed + TotalFunctions int + + // Confusion Matrix + TruePositives int // Tool detected, LLM confirmed ✅ + FalsePositives int // Tool detected, LLM says safe ⚠️ + FalseNegatives int // Tool missed, LLM found vuln ❌ + TrueNegatives int // Tool skipped, LLM confirmed safe ✅ + + // Metrics + Precision float64 // TP / (TP + FP) + Recall float64 // TP / (TP + FN) + F1Score float64 // 2 * (P * R) / (P + R) + Agreement float64 // (TP + TN) / Total + + // Processing stats + LLMProcessingTime string + TotalProcessingTime string + FunctionsPerSecond float64 + + // Failure breakdown + FailuresByCategory map[string]int + TopFailures []FailureExample +} + +// FailureExample represents a specific failure case. +type FailureExample struct { + Type string // "FALSE_POSITIVE", "FALSE_NEGATIVE" + FunctionFQN string + FunctionFile string + FunctionLine int + Category string // "control_flow", "sanitizer", etc. + Reason string // From LLM + Flow *NormalizedTaintFlow // Flow details (if applicable) +} + +// CalculateOverallMetrics aggregates metrics from all function comparisons. +// +// Performance: O(n) where n = number of comparisons +// +// Example: +// +// metrics := CalculateOverallMetrics(comparisons, startTime) +// fmt.Printf("Precision: %.1f%%\n", metrics.Precision*100) +// fmt.Printf("Recall: %.1f%%\n", metrics.Recall*100) +// fmt.Printf("F1 Score: %.1f%%\n", metrics.F1Score*100) +func CalculateOverallMetrics( + comparisons []*DualLevelComparison, + startTime time.Time, +) *OverallMetrics { + metrics := &OverallMetrics{ + TotalFunctions: len(comparisons), + FailuresByCategory: make(map[string]int), + TopFailures: []FailureExample{}, + } + + // Calculate confusion matrix + for _, cmp := range comparisons { + switch { + case cmp.BinaryToolResult && cmp.BinaryLLMResult: + metrics.TruePositives++ + case cmp.BinaryToolResult && !cmp.BinaryLLMResult: + metrics.FalsePositives++ + case !cmp.BinaryToolResult && cmp.BinaryLLMResult: + metrics.FalseNegatives++ + default: + metrics.TrueNegatives++ + } + + // Track failure categories + if !cmp.BinaryAgreement && cmp.FailureCategory != "" { + metrics.FailuresByCategory[cmp.FailureCategory]++ + } + } + + // Calculate metrics + if metrics.TruePositives+metrics.FalsePositives > 0 { + metrics.Precision = float64(metrics.TruePositives) / + float64(metrics.TruePositives+metrics.FalsePositives) + } + + if metrics.TruePositives+metrics.FalseNegatives > 0 { + metrics.Recall = float64(metrics.TruePositives) / + float64(metrics.TruePositives+metrics.FalseNegatives) + } + + if metrics.Precision+metrics.Recall > 0 { + metrics.F1Score = 2 * metrics.Precision * metrics.Recall / + (metrics.Precision + metrics.Recall) + } + + if metrics.TotalFunctions > 0 { + metrics.Agreement = float64(metrics.TruePositives+metrics.TrueNegatives) / + float64(metrics.TotalFunctions) + } + + // Processing stats + totalDuration := time.Since(startTime) + metrics.TotalProcessingTime = totalDuration.String() + if totalDuration.Seconds() > 0 { + metrics.FunctionsPerSecond = float64(metrics.TotalFunctions) / totalDuration.Seconds() + } + + return metrics +} + +// ExtractTopFailures extracts the most important failure examples. +// Returns up to maxPerType failures of each type (FP/FN). +func ExtractTopFailures( + comparisons []*DualLevelComparison, + functionsMap map[string]*FunctionMetadata, + maxPerType int, +) []FailureExample { + failures := []FailureExample{} + + fpCount := 0 + fnCount := 0 + + for _, cmp := range comparisons { + fn := functionsMap[cmp.FunctionFQN] + if fn == nil { + continue + } + + // False Positives + if cmp.BinaryToolResult && !cmp.BinaryLLMResult { + if fpCount < maxPerType { + failures = append(failures, FailureExample{ + Type: "FALSE_POSITIVE", + FunctionFQN: cmp.FunctionFQN, + FunctionFile: fn.FilePath, + FunctionLine: fn.StartLine, + Category: cmp.FailureCategory, + Reason: cmp.FailureReason, + }) + fpCount++ + } + } + + // False Negatives + if !cmp.BinaryToolResult && cmp.BinaryLLMResult { + if fnCount < maxPerType { + failures = append(failures, FailureExample{ + Type: "FALSE_NEGATIVE", + FunctionFQN: cmp.FunctionFQN, + FunctionFile: fn.FilePath, + FunctionLine: fn.StartLine, + Category: cmp.FailureCategory, + Reason: cmp.FailureReason, + }) + fnCount++ + } + } + } + + return failures +} diff --git a/sourcecode-parser/diagnostic/metrics_test.go b/sourcecode-parser/diagnostic/metrics_test.go new file mode 100644 index 00000000..977c88e6 --- /dev/null +++ b/sourcecode-parser/diagnostic/metrics_test.go @@ -0,0 +1,377 @@ +package diagnostic + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCalculateOverallMetrics_AllTruePositives tests perfect agreement (all TP). +func TestCalculateOverallMetrics_AllTruePositives(t *testing.T) { + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.func1", + BinaryToolResult: true, + BinaryLLMResult: true, + BinaryAgreement: true, + }, + { + FunctionFQN: "test.func2", + BinaryToolResult: true, + BinaryLLMResult: true, + BinaryAgreement: true, + }, + } + + startTime := time.Now().Add(-2 * time.Second) + metrics := CalculateOverallMetrics(comparisons, startTime) + + assert.Equal(t, 2, metrics.TotalFunctions) + assert.Equal(t, 2, metrics.TruePositives) + assert.Equal(t, 0, metrics.FalsePositives) + assert.Equal(t, 0, metrics.FalseNegatives) + assert.Equal(t, 0, metrics.TrueNegatives) + assert.Equal(t, 1.0, metrics.Precision) + assert.Equal(t, 1.0, metrics.Recall) + assert.Equal(t, 1.0, metrics.F1Score) + assert.Equal(t, 1.0, metrics.Agreement) + assert.Greater(t, metrics.FunctionsPerSecond, 0.0) +} + +// TestCalculateOverallMetrics_AllTrueNegatives tests perfect agreement (all TN). +func TestCalculateOverallMetrics_AllTrueNegatives(t *testing.T) { + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.func1", + BinaryToolResult: false, + BinaryLLMResult: false, + BinaryAgreement: true, + }, + { + FunctionFQN: "test.func2", + BinaryToolResult: false, + BinaryLLMResult: false, + BinaryAgreement: true, + }, + } + + startTime := time.Now().Add(-1 * time.Second) + metrics := CalculateOverallMetrics(comparisons, startTime) + + assert.Equal(t, 2, metrics.TotalFunctions) + assert.Equal(t, 0, metrics.TruePositives) + assert.Equal(t, 0, metrics.FalsePositives) + assert.Equal(t, 0, metrics.FalseNegatives) + assert.Equal(t, 2, metrics.TrueNegatives) + assert.Equal(t, 0.0, metrics.Precision) // No TP + FP + assert.Equal(t, 0.0, metrics.Recall) // No TP + FN + assert.Equal(t, 0.0, metrics.F1Score) + assert.Equal(t, 1.0, metrics.Agreement) +} + +// TestCalculateOverallMetrics_MixedResults tests confusion matrix calculation. +func TestCalculateOverallMetrics_MixedResults(t *testing.T) { + comparisons := []*DualLevelComparison{ + // TP + { + FunctionFQN: "test.tp", + BinaryToolResult: true, + BinaryLLMResult: true, + BinaryAgreement: true, + }, + // FP + { + FunctionFQN: "test.fp", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + }, + // FN + { + FunctionFQN: "test.fn", + BinaryToolResult: false, + BinaryLLMResult: true, + BinaryAgreement: false, + FailureCategory: "control_flow_branch", + }, + // TN + { + FunctionFQN: "test.tn", + BinaryToolResult: false, + BinaryLLMResult: false, + BinaryAgreement: true, + }, + } + + startTime := time.Now().Add(-5 * time.Second) + metrics := CalculateOverallMetrics(comparisons, startTime) + + assert.Equal(t, 4, metrics.TotalFunctions) + assert.Equal(t, 1, metrics.TruePositives) + assert.Equal(t, 1, metrics.FalsePositives) + assert.Equal(t, 1, metrics.FalseNegatives) + assert.Equal(t, 1, metrics.TrueNegatives) + + // Precision = TP / (TP + FP) = 1 / 2 = 0.5 + assert.Equal(t, 0.5, metrics.Precision) + + // Recall = TP / (TP + FN) = 1 / 2 = 0.5 + assert.Equal(t, 0.5, metrics.Recall) + + // F1 = 2 * P * R / (P + R) = 2 * 0.5 * 0.5 / 1.0 = 0.5 + assert.Equal(t, 0.5, metrics.F1Score) + + // Agreement = (TP + TN) / Total = 2 / 4 = 0.5 + assert.Equal(t, 0.5, metrics.Agreement) + + // Failure categories + assert.Equal(t, 2, len(metrics.FailuresByCategory)) + assert.Equal(t, 1, metrics.FailuresByCategory["sanitizer_missed"]) + assert.Equal(t, 1, metrics.FailuresByCategory["control_flow_branch"]) +} + +// TestCalculateOverallMetrics_ZeroDivision tests edge case with no detections. +func TestCalculateOverallMetrics_ZeroDivision(t *testing.T) { + comparisons := []*DualLevelComparison{} + + startTime := time.Now() + metrics := CalculateOverallMetrics(comparisons, startTime) + + assert.Equal(t, 0, metrics.TotalFunctions) + assert.Equal(t, 0.0, metrics.Precision) + assert.Equal(t, 0.0, metrics.Recall) + assert.Equal(t, 0.0, metrics.F1Score) + assert.Equal(t, 0.0, metrics.Agreement) +} + +// TestCalculateOverallMetrics_FailureCategoryCounting tests category aggregation. +func TestCalculateOverallMetrics_FailureCategoryCounting(t *testing.T) { + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.fp1", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + }, + { + FunctionFQN: "test.fp2", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + }, + { + FunctionFQN: "test.fn1", + BinaryToolResult: false, + BinaryLLMResult: true, + BinaryAgreement: false, + FailureCategory: "control_flow_branch", + }, + { + FunctionFQN: "test.tp", + BinaryToolResult: true, + BinaryLLMResult: true, + BinaryAgreement: true, + }, + } + + startTime := time.Now() + metrics := CalculateOverallMetrics(comparisons, startTime) + + require.NotNil(t, metrics.FailuresByCategory) + assert.Equal(t, 2, metrics.FailuresByCategory["sanitizer_missed"]) + assert.Equal(t, 1, metrics.FailuresByCategory["control_flow_branch"]) +} + +// TestExtractTopFailures tests failure example extraction. +func TestExtractTopFailures(t *testing.T) { + functionsMap := map[string]*FunctionMetadata{ + "test.fp1": { + FQN: "test.fp1", + FilePath: "test.py", + StartLine: 10, + }, + "test.fp2": { + FQN: "test.fp2", + FilePath: "test.py", + StartLine: 20, + }, + "test.fn1": { + FQN: "test.fn1", + FilePath: "test.py", + StartLine: 30, + }, + } + + comparisons := []*DualLevelComparison{ + // FP + { + FunctionFQN: "test.fp1", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Data is sanitized", + }, + // FP + { + FunctionFQN: "test.fp2", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Escape function used", + }, + // FN + { + FunctionFQN: "test.fn1", + BinaryToolResult: false, + BinaryLLMResult: true, + BinaryAgreement: false, + FailureCategory: "control_flow_branch", + FailureReason: "Flow through if branch", + }, + } + + failures := ExtractTopFailures(comparisons, functionsMap, 2) + + require.NotNil(t, failures) + // Should return up to 2 of each type (FP, FN) + assert.LessOrEqual(t, len(failures), 4) + + // Count types + fpCount := 0 + fnCount := 0 + for _, f := range failures { + if f.Type == "FALSE_POSITIVE" { + fpCount++ + assert.Contains(t, []string{"test.fp1", "test.fp2"}, f.FunctionFQN) + } else if f.Type == "FALSE_NEGATIVE" { + fnCount++ + assert.Equal(t, "test.fn1", f.FunctionFQN) + } + } + + assert.LessOrEqual(t, fpCount, 2) + assert.LessOrEqual(t, fnCount, 2) +} + +// TestExtractTopFailures_EmptyComparisons tests with no failures. +func TestExtractTopFailures_EmptyComparisons(t *testing.T) { + functionsMap := map[string]*FunctionMetadata{} + comparisons := []*DualLevelComparison{} + + failures := ExtractTopFailures(comparisons, functionsMap, 5) + + require.NotNil(t, failures) + assert.Equal(t, 0, len(failures)) +} + +// TestExtractTopFailures_LimitPerType tests maxPerType limit. +func TestExtractTopFailures_LimitPerType(t *testing.T) { + functionsMap := map[string]*FunctionMetadata{ + "test.fp1": {FQN: "test.fp1", FilePath: "test.py", StartLine: 10}, + "test.fp2": {FQN: "test.fp2", FilePath: "test.py", StartLine: 20}, + "test.fp3": {FQN: "test.fp3", FilePath: "test.py", StartLine: 30}, + } + + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.fp1", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Sanitized 1", + }, + { + FunctionFQN: "test.fp2", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Sanitized 2", + }, + { + FunctionFQN: "test.fp3", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Sanitized 3", + }, + } + + // Limit to 2 per type + failures := ExtractTopFailures(comparisons, functionsMap, 2) + + require.NotNil(t, failures) + // Should only return 2 FPs (limited by maxPerType) + assert.Equal(t, 2, len(failures)) + + for _, f := range failures { + assert.Equal(t, "FALSE_POSITIVE", f.Type) + } +} + +// TestExtractTopFailures_MissingMetadata tests handling of missing function metadata. +func TestExtractTopFailures_MissingMetadata(t *testing.T) { + functionsMap := map[string]*FunctionMetadata{ + // Only test.fp1 has metadata + "test.fp1": { + FQN: "test.fp1", + FilePath: "test.py", + StartLine: 10, + }, + } + + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.fp1", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Sanitized", + }, + { + FunctionFQN: "test.fp2", // No metadata + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Sanitized", + }, + } + + failures := ExtractTopFailures(comparisons, functionsMap, 5) + + require.NotNil(t, failures) + // Should only include test.fp1 (has metadata) + assert.Equal(t, 1, len(failures)) + assert.Equal(t, "test.fp1", failures[0].FunctionFQN) +} + +// TestCalculateOverallMetrics_ProcessingTime tests timing calculation. +func TestCalculateOverallMetrics_ProcessingTime(t *testing.T) { + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.func1", + BinaryToolResult: true, + BinaryLLMResult: true, + BinaryAgreement: true, + }, + } + + startTime := time.Now().Add(-3 * time.Second) + metrics := CalculateOverallMetrics(comparisons, startTime) + + assert.NotEmpty(t, metrics.TotalProcessingTime) + assert.Greater(t, metrics.FunctionsPerSecond, 0.0) + // Should process ~0.33 functions/second (1 function in 3 seconds) + assert.Less(t, metrics.FunctionsPerSecond, 1.0) +} diff --git a/sourcecode-parser/diagnostic/normalizer.go b/sourcecode-parser/diagnostic/normalizer.go new file mode 100644 index 00000000..5c9be2b2 --- /dev/null +++ b/sourcecode-parser/diagnostic/normalizer.go @@ -0,0 +1,232 @@ +package diagnostic + +import ( + "strings" +) + +// NormalizedTaintFlow is the common format for comparison. +// Both tool and LLM results are converted to this format. +type NormalizedTaintFlow struct { + SourceLine int + SourceVariable string + SourceCategory string // Semantic: "user_input", "file_read", etc. + + SinkLine int + SinkVariable string + SinkCategory string // Semantic: "sql_execution", "command_exec", etc. + + VulnerabilityType string // "SQL_INJECTION", "XSS", etc. + Confidence float64 +} + +// NormalizeToolResult converts our tool's result to normalized format. +func NormalizeToolResult(toolResult *FunctionTaintResult) []NormalizedTaintFlow { + normalized := make([]NormalizedTaintFlow, 0, len(toolResult.TaintFlows)) + + for _, flow := range toolResult.TaintFlows { + normalized = append(normalized, NormalizedTaintFlow{ + SourceLine: flow.SourceLine, + SourceVariable: flow.SourceVariable, + SourceCategory: flow.SourceCategory, + SinkLine: flow.SinkLine, + SinkVariable: flow.SinkVariable, + SinkCategory: flow.SinkCategory, + VulnerabilityType: flow.VulnerabilityType, + Confidence: flow.Confidence, + }) + } + + return normalized +} + +// NormalizeLLMResult converts LLM test cases to normalized format. +func NormalizeLLMResult(llmResult *LLMAnalysisResult) []NormalizedTaintFlow { + normalized := make([]NormalizedTaintFlow, 0, len(llmResult.DataflowTestCases)) + + for _, testCase := range llmResult.DataflowTestCases { + // Only include test cases where LLM expects detection + if testCase.ExpectedDetection { + normalized = append(normalized, NormalizedTaintFlow{ + SourceLine: testCase.Source.Line, + SourceVariable: testCase.Source.Variable, + SourceCategory: categorizeLLMPattern(testCase.Source.Pattern), + SinkLine: testCase.Sink.Line, + SinkVariable: testCase.Sink.Variable, + SinkCategory: categorizeLLMPattern(testCase.Sink.Pattern), + VulnerabilityType: normalizeVulnType(testCase.VulnerabilityType), + Confidence: testCase.Confidence, + }) + } + } + + return normalized +} + +// categorizeLLMPattern maps LLM pattern string to semantic category. +func categorizeLLMPattern(pattern string) string { + patternLower := strings.ToLower(pattern) + + // User input + if strings.Contains(patternLower, "request") || + strings.Contains(patternLower, "input") || + strings.Contains(patternLower, "get[") || + strings.Contains(patternLower, "post[") { + return "user_input" + } + + // SQL + if strings.Contains(patternLower, "execute") || + strings.Contains(patternLower, "sql") || + strings.Contains(patternLower, "cursor") { + return "sql_execution" + } + + // Command execution + if strings.Contains(patternLower, "system") || + strings.Contains(patternLower, "subprocess") || + strings.Contains(patternLower, "popen") || + strings.Contains(patternLower, "call") { + return "command_exec" + } + + // Code execution + if strings.Contains(patternLower, "eval") || + strings.Contains(patternLower, "exec") { + return "code_exec" + } + + return "other" +} + +// normalizeVulnType normalizes vulnerability type names. +func normalizeVulnType(vulnType string) string { + normalized := strings.ToUpper(strings.ReplaceAll(vulnType, " ", "_")) + + // Handle variations + equivalenceMap := map[string]string{ + "SQLI": "SQL_INJECTION", + "SQL INJECTION": "SQL_INJECTION", + "CMD_INJECTION": "COMMAND_INJECTION", + "COMMAND INJECTION": "COMMAND_INJECTION", + "OS_COMMAND_INJECTION": "COMMAND_INJECTION", + "CODE INJECTION": "CODE_INJECTION", + "CROSS_SITE_SCRIPTING": "XSS", + "CROSS SITE SCRIPTING": "XSS", + "PATH_TRAVERSAL": "PATH_TRAVERSAL", + "DIRECTORY_TRAVERSAL": "PATH_TRAVERSAL", + } + + if canonical, ok := equivalenceMap[normalized]; ok { + return canonical + } + + return normalized +} + +// MatchConfig specifies how lenient fuzzy matching should be. +type MatchConfig struct { + // LineThreshold: Accept matches within ±N lines (default: 2) + LineThreshold int + + // AllowVariableAliases: Match user_input vs user_input_1 (SSA) (default: true) + AllowVariableAliases bool + + // SemanticVulnTypes: "SQL_INJECTION" == "sqli" (default: true) + SemanticVulnTypes bool +} + +// DefaultMatchConfig returns default fuzzy matching configuration. +func DefaultMatchConfig() MatchConfig { + return MatchConfig{ + LineThreshold: 2, + AllowVariableAliases: true, + SemanticVulnTypes: true, + } +} + +// FlowsMatch checks if two normalized flows match (fuzzy matching). +func FlowsMatch(f1, f2 NormalizedTaintFlow, config MatchConfig) bool { + // 1. Line numbers within threshold + sourceLineMatch := abs(f1.SourceLine-f2.SourceLine) <= config.LineThreshold + sinkLineMatch := abs(f1.SinkLine-f2.SinkLine) <= config.LineThreshold + + if !sourceLineMatch || !sinkLineMatch { + return false + } + + // 2. Variable names match (with optional aliases) + sourceVarMatch := variablesMatch(f1.SourceVariable, f2.SourceVariable, config.AllowVariableAliases) + sinkVarMatch := variablesMatch(f1.SinkVariable, f2.SinkVariable, config.AllowVariableAliases) + + if !sourceVarMatch || !sinkVarMatch { + return false + } + + // 3. Categories match (semantic comparison) + categoryMatch := (f1.SourceCategory == f2.SourceCategory) && + (f1.SinkCategory == f2.SinkCategory) + + if !categoryMatch { + return false + } + + // 4. Vulnerability type match (with semantic equivalence) + vulnMatch := vulnTypesMatch(f1.VulnerabilityType, f2.VulnerabilityType, config.SemanticVulnTypes) + + return vulnMatch +} + +// variablesMatch checks if two variable names match (with optional alias support). +func variablesMatch(v1, v2 string, allowAliases bool) bool { + if v1 == v2 { + return true + } + + if allowAliases { + // Strip SSA suffixes: user_input_1 → user_input + base1 := stripSSASuffix(v1) + base2 := stripSSASuffix(v2) + return base1 == base2 + } + + return false +} + +// stripSSASuffix removes SSA renaming suffix. +// Example: "user_input_1" → "user_input". +func stripSSASuffix(varName string) string { + // Simple heuristic: remove _N suffix where N is digit + parts := strings.Split(varName, "_") + if len(parts) >= 2 { + lastPart := parts[len(parts)-1] + // Check if last part is a number + if len(lastPart) > 0 && lastPart[0] >= '0' && lastPart[0] <= '9' { + return strings.Join(parts[:len(parts)-1], "_") + } + } + return varName +} + +// vulnTypesMatch checks if two vulnerability types match semantically. +func vulnTypesMatch(t1, t2 string, semantic bool) bool { + if t1 == t2 { + return true + } + + if semantic { + // Normalize both + t1Norm := normalizeVulnType(t1) + t2Norm := normalizeVulnType(t2) + return t1Norm == t2Norm + } + + return false +} + +// abs returns absolute value of an int. +func abs(x int) int { + if x < 0 { + return -x + } + return x +} diff --git a/sourcecode-parser/diagnostic/normalizer_test.go b/sourcecode-parser/diagnostic/normalizer_test.go new file mode 100644 index 00000000..855e0444 --- /dev/null +++ b/sourcecode-parser/diagnostic/normalizer_test.go @@ -0,0 +1,245 @@ +package diagnostic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestNormalizeToolResult tests tool result normalization. +func TestNormalizeToolResult(t *testing.T) { + toolResult := &FunctionTaintResult{ + FunctionFQN: "test.func", + HasTaintFlow: true, + TaintFlows: []ToolTaintFlow{ + { + SourceLine: 10, + SourceVariable: "user_input", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "query", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + Confidence: 0.95, + }, + }, + } + + normalized := NormalizeToolResult(toolResult) + + assert.Equal(t, 1, len(normalized)) + assert.Equal(t, 10, normalized[0].SourceLine) + assert.Equal(t, "user_input", normalized[0].SourceVariable) + assert.Equal(t, "user_input", normalized[0].SourceCategory) + assert.Equal(t, 20, normalized[0].SinkLine) + assert.Equal(t, "SQL_INJECTION", normalized[0].VulnerabilityType) +} + +// TestNormalizeLLMResult tests LLM result normalization. +func TestNormalizeLLMResult(t *testing.T) { + llmResult := &LLMAnalysisResult{ + DataflowTestCases: []DataflowTestCase{ + { + Source: TestCaseSource{ + Pattern: "request.GET['cmd']", + Line: 5, + Variable: "cmd", + }, + Sink: TestCaseSink{ + Pattern: "os.system", + Line: 10, + Variable: "cmd", + }, + ExpectedDetection: true, + VulnerabilityType: "COMMAND_INJECTION", + Confidence: 0.9, + }, + { + Source: TestCaseSource{ + Pattern: "input()", + Line: 15, + Variable: "safe", + }, + Sink: TestCaseSink{ + Pattern: "print", + Line: 16, + Variable: "safe", + }, + ExpectedDetection: false, // Should NOT be included + }, + }, + } + + normalized := NormalizeLLMResult(llmResult) + + // Only expected detections are included + assert.Equal(t, 1, len(normalized)) + assert.Equal(t, 5, normalized[0].SourceLine) + assert.Equal(t, "cmd", normalized[0].SourceVariable) + assert.Equal(t, "user_input", normalized[0].SourceCategory) + assert.Equal(t, "command_exec", normalized[0].SinkCategory) +} + +// TestFlowsMatch_LineThreshold tests line number fuzzy matching. +func TestFlowsMatch_LineThreshold(t *testing.T) { + config := DefaultMatchConfig() + + f1 := NormalizedTaintFlow{ + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + } + + // Within ±2 lines: SHOULD match + f2 := NormalizedTaintFlow{ + SourceLine: 11, // +1 line + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 19, // -1 line + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + } + + assert.True(t, FlowsMatch(f1, f2, config)) + + // Outside threshold: should NOT match + f3 := NormalizedTaintFlow{ + SourceLine: 15, // +5 lines + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + } + + assert.False(t, FlowsMatch(f1, f3, config)) +} + +// TestFlowsMatch_SSAVariables tests SSA variable alias matching. +func TestFlowsMatch_SSAVariables(t *testing.T) { + config := DefaultMatchConfig() + + f1 := NormalizedTaintFlow{ + SourceLine: 10, + SourceVariable: "user_input", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "query", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + } + + // With SSA suffix _1: SHOULD match + f2 := NormalizedTaintFlow{ + SourceLine: 10, + SourceVariable: "user_input_1", // SSA renamed + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "query_2", // SSA renamed + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + } + + assert.True(t, FlowsMatch(f1, f2, config)) +} + +// TestFlowsMatch_VulnTypes tests semantic vulnerability type matching. +func TestFlowsMatch_VulnTypes(t *testing.T) { + config := DefaultMatchConfig() + + f1 := NormalizedTaintFlow{ + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "SQL_INJECTION", + } + + // Different but equivalent vuln type: SHOULD match + f2 := NormalizedTaintFlow{ + SourceLine: 10, + SourceVariable: "x", + SourceCategory: "user_input", + SinkLine: 20, + SinkVariable: "y", + SinkCategory: "sql_execution", + VulnerabilityType: "sqli", // Lowercase variant + } + + assert.True(t, FlowsMatch(f1, f2, config)) +} + +// TestStripSSASuffix tests SSA suffix removal. +func TestStripSSASuffix(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"user_input_1", "user_input"}, + {"user_input_2", "user_input"}, + {"query_10", "query"}, + {"normal_var", "normal_var"}, + {"x", "x"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := stripSSASuffix(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestNormalizeVulnType tests vulnerability type normalization. +func TestNormalizeVulnType(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"SQL_INJECTION", "SQL_INJECTION"}, + {"sqli", "SQL_INJECTION"}, + {"SQL INJECTION", "SQL_INJECTION"}, + {"COMMAND_INJECTION", "COMMAND_INJECTION"}, + {"cmd_injection", "COMMAND_INJECTION"}, + {"XSS", "XSS"}, + {"cross_site_scripting", "XSS"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeVulnType(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestCategorizeLLMPattern tests LLM pattern categorization. +func TestCategorizeLLMPattern(t *testing.T) { + tests := []struct { + pattern string + expected string + }{ + {"request.GET['cmd']", "user_input"}, + {"request.POST", "user_input"}, + {"input()", "user_input"}, + {"cursor.execute", "sql_execution"}, + {"os.system", "command_exec"}, + {"subprocess.call", "command_exec"}, + {"eval", "code_exec"}, + } + + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + result := categorizeLLMPattern(tt.pattern) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/sourcecode-parser/diagnostic/prompt.go b/sourcecode-parser/diagnostic/prompt.go index b1b24205..b8adf269 100644 --- a/sourcecode-parser/diagnostic/prompt.go +++ b/sourcecode-parser/diagnostic/prompt.go @@ -7,29 +7,38 @@ import ( // BuildAnalysisPrompt constructs the prompt for LLM pattern discovery and test generation. // Based on diagnostic-tech-proposal.md Section 3.3 (LLM Prompt Design). func BuildAnalysisPrompt(sourceCode string) string { - return fmt.Sprintf(`You are a dataflow analysis expert. Analyze this Python function to discover all dataflow patterns and generate test cases. + return fmt.Sprintf(`You are an intra-procedural dataflow analysis expert. Your goal is to validate if a static analysis tool correctly tracks data flow within THIS SINGLE FUNCTION ONLY. **FUNCTION TO ANALYZE**: `+"```python\n%s\n```"+` -**YOUR TASK**: - -1. **DISCOVER PATTERNS** - Identify all dataflow patterns in THIS function: - - **Sources**: Any operation that introduces new data (user input, file reads, network, env vars, function params, etc.) - - **Sinks**: Any operation that consumes data (output, storage, exec, system calls, returns, etc.) - - **Sanitizers**: Any operation that transforms/validates data (escape, quote, validate, cast, etc.) - - **Propagators**: Operations that pass data along (assignments, calls, returns) - -2. **TRACE DATAFLOWS** - For each discovered source: - - Track where the data flows (which variables, which lines) - - Identify if it reaches any sinks - - Note if any sanitizers are applied - - Track through: assignments, calls, returns, branches, containers, attributes - -3. **GENERATE TEST CASES** - Create test cases our tool should pass: - - Expected flows (source → sink paths) - - Expected sanitizer detection - - Expected variable tracking +**YOUR TASK**: Test if the tool can track how data flows between variables in this function. + +1. **DISCOVER DATA SOURCES** - Where does data originate in this function: + - Function parameters (any param used in the function body) + - Local variable assignments (x = "value") + - Function calls that return data (result = some_function()) + - Literals/constants that get assigned to variables + - Dictionary/list access (val = dict['key']) + +2. **DISCOVER DATA SINKS** - Where does data get consumed/used: + - Return statements (return x) + - Function calls that use data (print(x), logger.info(x)) + - Assignments to data structures (dict[key] = x, list.append(x)) + - Operations that use variables (y = x + 1) + +3. **TRACE INTRA-PROCEDURAL FLOWS** - For EACH source variable: + - Track ALL assignments: if 'a = source', then 'b = a' means b is tainted from source + - Track through operations: if 'b = a + "suffix"', b is still tainted from a + - Track through containers: if 'list = [a]', then 'b = list[0]' means b is tainted + - Track through branches: if inside 'if True: b = a', then b is tainted + - Track through method calls: if 'b = a.upper()', b is still tainted from a + - ONLY track within this function (do NOT analyze called functions) + +4. **GENERATE TEST CASES** - Create test cases for flows YOU FOUND: + - Only include flows where source → sink connection exists within this function + - Set expected_detection=true if there IS a flow path + - Set expected_detection=false if source and sink are INDEPENDENT variables **OUTPUT FORMAT** (JSON): `+"```json"+` @@ -95,7 +104,8 @@ func BuildAnalysisPrompt(sourceCode string) string { "expected_detection": true, "vulnerability_type": "COMMAND_INJECTION", "confidence": 0.95, - "reasoning": "Direct flow from user input to OS command without sanitization" + "reasoning": "Direct flow from user input to OS command without sanitization", + "failure_category": "none" } ], @@ -128,35 +138,63 @@ func BuildAnalysisPrompt(sourceCode string) string { **IMPORTANT GUIDELINES**: -1. **NO PREDEFINED PATTERNS**: Discover patterns from the code itself, don't assume -2. **BE SPECIFIC**: Include exact line numbers, variable names, code snippets -3. **TRACK EVERYTHING**: Even non-security dataflows (var assignments, returns, etc.) -4. **SANITIZER EFFECTIVENESS**: Note what each sanitizer actually blocks -5. **GENERATE TESTS**: Each test case should be independently verifiable -6. **CONFIDENCE SCORES**: Rate how confident you are (0.0-1.0) -7. **EXPLAIN REASONING**: Why you think a flow exists or doesn't exist - -**EXAMPLE PATTERNS TO DISCOVER**: - -Security: -- request.GET/POST/COOKIES → eval/exec/os.system -- input() → open() -- socket.recv() → subprocess.call() - -Generic Dataflow: -- function_param → return value -- config['key'] → database.save() -- user.name → logger.info() -- x = calculate() → result = process(x) - -**FOCUS**: Validate dataflow tracking algorithm: -- ✅ Track variables through assignments -- ✅ Detect def-use chains correctly -- ✅ Identify taint propagation paths -- ✅ Recognize sanitizers -- ✅ Handle control flow (if/else) -- ✅ Track container operations -- ✅ Track attribute access +1. **RESPOND WITH PURE JSON ONLY**: No markdown, no code blocks, no explanations - just the raw JSON object +2. **INTRA-PROCEDURAL ONLY**: Only analyze flows within this function, ignore inter-procedural flows +3. **BE SPECIFIC**: Include exact line numbers, variable names from the function +4. **TRACK SIMPLE DATAFLOWS**: Focus on variable assignments, not complex analysis +5. **CONFIDENCE SCORES**: Rate how confident you are (0.0-1.0) +6. **EXPLAIN REASONING**: Describe the assignment chain (e.g., "a→b→c→return") +7. **FAILURE CATEGORY**: For each test case, specify why a static analysis tool might miss this flow: + - "none" - Simple flow, tool should detect easily + - "control_flow_branch" - Flow through if/else/while branches + - "field_sensitivity" - Flow through object fields (self.x, obj.attr) + - "sanitizer_missed" - Has sanitizer that tool might not recognize + - "container_operation" - Flow through list/dict/array operations + - "string_formatting" - Flow through f-strings, .format(), concatenation + - "method_call_propagation" - Flow through method calls like .upper(), .strip() + - "assignment_chain" - Long chain of assignments (a=b; c=a; d=c) + - "return_flow" - Flow from variable to return statement + - "parameter_flow" - Flow from function parameter to usage + - "complex_expression" - Nested calls, multiple operations + - "context_required" - Requires analyzing called functions (out of scope) + +**EXAMPLE DATAFLOW PATTERNS WITH CATEGORIES**: + +Simple assignments (failure_category: "none"): +- param → local_var → return +- x = param; y = x; return y +- result = function(param); return result + +Assignment chains (failure_category: "assignment_chain"): +- x = param; y = x; z = y; w = z; return w + +Through method calls (failure_category: "method_call_propagation"): +- x = param.upper(); return x +- y = param.strip().lower(); return y + +Through string operations (failure_category: "string_formatting"): +- x = param + "suffix"; return x +- z = f"{param}"; return z +- w = "%%s" %% param; return w + +Through containers (failure_category: "container_operation"): +- list = [param]; x = list[0]; return x +- dict = {"key": param}; y = dict["key"]; return y + +Through branches (failure_category: "control_flow_branch"): +- if condition: x = param; return x +- if True: y = param; else: y = ""; return y + +No flow (failure_category: "none"): +- x = param; y = "constant"; return y (y NOT tainted by param) + +**FOCUS**: Validate intra-procedural dataflow tracking: +- ✅ Direct assignments (a = b) +- ✅ Assignment chains (a = b; c = a) +- ✅ Operations preserving taint (x = a + "text") +- ✅ Container flows (list[0] = a; b = list[0]) +- ✅ Control flow branches +- ✅ Independent variables (no flow) Output ONLY the JSON, no additional text.`, sourceCode) } diff --git a/sourcecode-parser/diagnostic/prompt_test.go b/sourcecode-parser/diagnostic/prompt_test.go index 3be7c4af..d8c62292 100644 --- a/sourcecode-parser/diagnostic/prompt_test.go +++ b/sourcecode-parser/diagnostic/prompt_test.go @@ -15,16 +15,16 @@ func TestBuildAnalysisPrompt(t *testing.T) { // Verify prompt contains key elements assert.Contains(t, prompt, "dataflow analysis expert") assert.Contains(t, prompt, sourceCode) - assert.Contains(t, prompt, "DISCOVER PATTERNS") - assert.Contains(t, prompt, "TRACE DATAFLOWS") + assert.Contains(t, prompt, "DISCOVER DATA SOURCES") + assert.Contains(t, prompt, "TRACE INTRA-PROCEDURAL FLOWS") assert.Contains(t, prompt, "GENERATE TEST CASES") assert.Contains(t, prompt, "discovered_patterns") assert.Contains(t, prompt, "dataflow_test_cases") assert.Contains(t, prompt, "JSON") - assert.Contains(t, prompt, "Sources") - assert.Contains(t, prompt, "Sinks") - assert.Contains(t, prompt, "Sanitizers") - assert.Contains(t, prompt, "Propagators") + assert.Contains(t, prompt, "sources") + assert.Contains(t, prompt, "sinks") + assert.Contains(t, prompt, "sanitizers") + assert.Contains(t, prompt, "propagators") } // TestBuildAnalysisPrompt_ContainsExamples tests that prompt includes examples. @@ -37,17 +37,17 @@ func TestBuildAnalysisPrompt_ContainsExamples(t *testing.T) { assert.Contains(t, prompt, "COMMAND_INJECTION") // Check for generic dataflow examples - assert.Contains(t, prompt, "function_param") - assert.Contains(t, prompt, "return value") + assert.Contains(t, prompt, "param") + assert.Contains(t, prompt, "return") } // TestBuildAnalysisPrompt_ContainsGuidelines tests that prompt includes important guidelines. func TestBuildAnalysisPrompt_ContainsGuidelines(t *testing.T) { prompt := BuildAnalysisPrompt("def dummy(): pass") - assert.Contains(t, prompt, "NO PREDEFINED PATTERNS") + assert.Contains(t, prompt, "INTRA-PROCEDURAL ONLY") assert.Contains(t, prompt, "BE SPECIFIC") - assert.Contains(t, prompt, "TRACK EVERYTHING") + assert.Contains(t, prompt, "TRACK SIMPLE DATAFLOWS") assert.Contains(t, prompt, "CONFIDENCE SCORES") assert.Contains(t, prompt, "Output ONLY the JSON") } @@ -76,7 +76,7 @@ func TestBuildAnalysisPrompt_EmptySourceCode(t *testing.T) { prompt := BuildAnalysisPrompt("") // Should still generate valid prompt structure - assert.Contains(t, prompt, "DISCOVER PATTERNS") + assert.Contains(t, prompt, "DISCOVER DATA SOURCES") assert.Contains(t, prompt, "GENERATE TEST CASES") assert.NotEmpty(t, prompt) } diff --git a/sourcecode-parser/diagnostic/reporter.go b/sourcecode-parser/diagnostic/reporter.go new file mode 100644 index 00000000..da140700 --- /dev/null +++ b/sourcecode-parser/diagnostic/reporter.go @@ -0,0 +1,191 @@ +package diagnostic + +import ( + "encoding/json" + "fmt" + "os" + "sort" + "strings" + "time" +) + +// GenerateConsoleReport prints human-readable report to stdout. +func GenerateConsoleReport(metrics *OverallMetrics, outputDir string) error { + fmt.Println("===============================================================================") + fmt.Println(" INTRA-PROCEDURAL TAINT ANALYSIS DIAGNOSTIC") + fmt.Println("===============================================================================") + fmt.Println() + + // Overall stats + fmt.Printf("Functions Analyzed: %d\n", metrics.TotalFunctions) + fmt.Printf("Processing Time: %s\n", metrics.TotalProcessingTime) + fmt.Printf("Speed: %.1f functions/second\n", metrics.FunctionsPerSecond) + fmt.Println() + + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println("OVERALL METRICS") + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println() + + fmt.Printf("Agreement with LLM: %.1f%% (%d / %d)\n", + metrics.Agreement*100, + metrics.TruePositives+metrics.TrueNegatives, + metrics.TotalFunctions) + fmt.Printf("Precision: %.1f%% (%d / %d)\n", + metrics.Precision*100, + metrics.TruePositives, + metrics.TruePositives+metrics.FalsePositives) + fmt.Printf("Recall: %.1f%% (%d / %d)\n", + metrics.Recall*100, + metrics.TruePositives, + metrics.TruePositives+metrics.FalseNegatives) + fmt.Printf("F1 Score: %.1f%%\n", metrics.F1Score*100) + fmt.Println() + + fmt.Println("Confusion Matrix:") + fmt.Printf(" True Positives: %-6d (Tool detected, LLM confirmed)\n", metrics.TruePositives) + fmt.Printf(" False Positives: %-6d (Tool detected, LLM says safe)\n", metrics.FalsePositives) + fmt.Printf(" False Negatives: %-6d (Tool missed, LLM found vuln)\n", metrics.FalseNegatives) + fmt.Printf(" True Negatives: %-6d (Tool skipped, LLM confirmed safe)\n", metrics.TrueNegatives) + fmt.Println() + + // Failure breakdown + if len(metrics.FailuresByCategory) > 0 { + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println("FAILURE BREAKDOWN") + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println() + + // Sort categories by count + type categoryCount struct { + category string + count int + } + categories := []categoryCount{} + for cat, count := range metrics.FailuresByCategory { + categories = append(categories, categoryCount{cat, count}) + } + sort.Slice(categories, func(i, j int) bool { + return categories[i].count > categories[j].count + }) + + totalFailures := metrics.FalsePositives + metrics.FalseNegatives + + fmt.Printf("Top Failure Categories: %d total\n", totalFailures) + for i, cc := range categories { + percentage := 0.0 + if totalFailures > 0 { + percentage = float64(cc.count) / float64(totalFailures) * 100 + } + marker := "" + if i == 0 { + marker = " <- FIX THIS FIRST" + } + fmt.Printf(" %d. %-25s %d cases (%.1f%%)%s\n", + i+1, cc.category+":", cc.count, percentage, marker) + } + fmt.Println() + } + + // Top failures + if len(metrics.TopFailures) > 0 { + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println("TOP FAILURE EXAMPLES") + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println() + + count := 0 + for _, failure := range metrics.TopFailures { + if count >= 5 { + break + } + count++ + + typeStr := "FALSE NEGATIVE" + if failure.Type == "FALSE_POSITIVE" { + typeStr = "FALSE POSITIVE" + } + + fmt.Printf("%d. %s (%s)\n", count, failure.FunctionFQN, failure.Category) + fmt.Printf(" File: %s:%d\n", failure.FunctionFile, failure.FunctionLine) + fmt.Printf(" Type: %s\n", typeStr) + fmt.Println() + if failure.Reason != "" { + fmt.Printf(" Reason: %s\n", wrapText(failure.Reason, 70, " ")) + fmt.Println() + } + } + } + + // Recommendations + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println("NEXT STEPS") + fmt.Println("-------------------------------------------------------------------------------") + fmt.Println() + fmt.Println("1. Review top failure examples above") + fmt.Println("2. Focus on the top failure category to maximize impact") + fmt.Println("3. Re-run diagnostic after improvements to measure progress") + fmt.Println() + + // Save location + if outputDir != "" { + fmt.Printf("Full report saved to: %s/\n", outputDir) + fmt.Println() + } + + fmt.Println("===============================================================================") + + return nil +} + +// GenerateJSONReport writes machine-readable JSON report. +func GenerateJSONReport( + metrics *OverallMetrics, + comparisons []*DualLevelComparison, + outputPath string, +) error { + report := map[string]interface{}{ + "metrics": metrics, + "comparisons": comparisons, + "timestamp": time.Now().Format(time.RFC3339), + } + + jsonBytes, err := json.MarshalIndent(report, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + + err = os.WriteFile(outputPath, jsonBytes, 0644) + if err != nil { + return fmt.Errorf("failed to write JSON: %w", err) + } + + return nil +} + +// Helper functions + +func wrapText(text string, width int, prefix string) string { + words := strings.Fields(text) + if len(words) == 0 { + return "" + } + + var result strings.Builder + lineLength := 0 + + for _, word := range words { + if lineLength+len(word)+1 > width { + result.WriteString("\n" + prefix) + lineLength = 0 + } + if lineLength > 0 { + result.WriteString(" ") + lineLength++ + } + result.WriteString(word) + lineLength += len(word) + } + + return result.String() +} diff --git a/sourcecode-parser/diagnostic/reporter_test.go b/sourcecode-parser/diagnostic/reporter_test.go new file mode 100644 index 00000000..270ae864 --- /dev/null +++ b/sourcecode-parser/diagnostic/reporter_test.go @@ -0,0 +1,351 @@ +package diagnostic + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGenerateConsoleReport tests human-readable console output. +func TestGenerateConsoleReport(t *testing.T) { + metrics := &OverallMetrics{ + TotalFunctions: 100, + TruePositives: 60, + FalsePositives: 10, + FalseNegatives: 20, + TrueNegatives: 10, + Precision: 0.857, + Recall: 0.75, + F1Score: 0.8, + Agreement: 0.7, + TotalProcessingTime: "5m30s", + FunctionsPerSecond: 0.303, + FailuresByCategory: map[string]int{ + "control_flow_branch": 15, + "sanitizer_missed": 10, + "field_sensitivity": 5, + }, + TopFailures: []FailureExample{ + { + Type: "FALSE_POSITIVE", + FunctionFQN: "test.example", + FunctionFile: "test.py", + FunctionLine: 42, + Category: "sanitizer_missed", + Reason: "Data is sanitized by escape function", + }, + }, + } + + err := GenerateConsoleReport(metrics, "") + assert.NoError(t, err) +} + +// TestGenerateConsoleReport_WithOutputDir tests console report with output directory. +func TestGenerateConsoleReport_WithOutputDir(t *testing.T) { + metrics := &OverallMetrics{ + TotalFunctions: 50, + TruePositives: 30, + FalsePositives: 5, + FalseNegatives: 10, + TrueNegatives: 5, + Precision: 0.857, + Recall: 0.75, + F1Score: 0.8, + Agreement: 0.7, + TotalProcessingTime: "2m15s", + FunctionsPerSecond: 0.370, + FailuresByCategory: map[string]int{}, + TopFailures: []FailureExample{}, + } + + err := GenerateConsoleReport(metrics, "/tmp/diagnostic_output") + assert.NoError(t, err) +} + +// TestGenerateConsoleReport_NoFailures tests report with no failures. +func TestGenerateConsoleReport_NoFailures(t *testing.T) { + metrics := &OverallMetrics{ + TotalFunctions: 10, + TruePositives: 5, + FalsePositives: 0, + FalseNegatives: 0, + TrueNegatives: 5, + Precision: 1.0, + Recall: 1.0, + F1Score: 1.0, + Agreement: 1.0, + TotalProcessingTime: "30s", + FunctionsPerSecond: 0.333, + FailuresByCategory: map[string]int{}, + TopFailures: []FailureExample{}, + } + + err := GenerateConsoleReport(metrics, "") + assert.NoError(t, err) +} + +// TestGenerateJSONReport tests machine-readable JSON output. +func TestGenerateJSONReport(t *testing.T) { + tempDir := t.TempDir() + jsonPath := filepath.Join(tempDir, "report.json") + + metrics := &OverallMetrics{ + TotalFunctions: 100, + TruePositives: 60, + FalsePositives: 10, + FalseNegatives: 20, + TrueNegatives: 10, + Precision: 0.857, + Recall: 0.75, + F1Score: 0.8, + Agreement: 0.7, + TotalProcessingTime: "5m30s", + FunctionsPerSecond: 0.303, + FailuresByCategory: map[string]int{ + "control_flow_branch": 15, + }, + TopFailures: []FailureExample{ + { + Type: "FALSE_POSITIVE", + FunctionFQN: "test.example", + FunctionFile: "test.py", + FunctionLine: 42, + Category: "sanitizer_missed", + Reason: "Data is sanitized", + }, + }, + } + + comparisons := []*DualLevelComparison{ + { + FunctionFQN: "test.func1", + BinaryToolResult: true, + BinaryLLMResult: true, + BinaryAgreement: true, + }, + { + FunctionFQN: "test.func2", + BinaryToolResult: true, + BinaryLLMResult: false, + BinaryAgreement: false, + FailureCategory: "sanitizer_missed", + FailureReason: "Sanitized", + }, + } + + err := GenerateJSONReport(metrics, comparisons, jsonPath) + require.NoError(t, err) + + // Verify file was created + _, err = os.Stat(jsonPath) + assert.NoError(t, err) + + // Verify JSON structure + jsonBytes, err := os.ReadFile(jsonPath) + require.NoError(t, err) + + var report map[string]interface{} + err = json.Unmarshal(jsonBytes, &report) + require.NoError(t, err) + + assert.Contains(t, report, "metrics") + assert.Contains(t, report, "comparisons") + assert.Contains(t, report, "timestamp") + + // Verify metrics structure + metricsData := report["metrics"].(map[string]interface{}) + assert.Equal(t, float64(100), metricsData["TotalFunctions"]) + assert.Equal(t, float64(60), metricsData["TruePositives"]) + + // Verify comparisons array + comparisonsData := report["comparisons"].([]interface{}) + assert.Equal(t, 2, len(comparisonsData)) +} + +// TestGenerateJSONReport_EmptyData tests JSON report with empty data. +func TestGenerateJSONReport_EmptyData(t *testing.T) { + tempDir := t.TempDir() + jsonPath := filepath.Join(tempDir, "empty_report.json") + + metrics := &OverallMetrics{ + TotalFunctions: 0, + FailuresByCategory: map[string]int{}, + TopFailures: []FailureExample{}, + } + + comparisons := []*DualLevelComparison{} + + err := GenerateJSONReport(metrics, comparisons, jsonPath) + require.NoError(t, err) + + // Verify file was created + _, err = os.Stat(jsonPath) + assert.NoError(t, err) + + // Verify JSON is valid + jsonBytes, err := os.ReadFile(jsonPath) + require.NoError(t, err) + + var report map[string]interface{} + err = json.Unmarshal(jsonBytes, &report) + require.NoError(t, err) +} + +// TestGenerateJSONReport_InvalidPath tests error handling for invalid path. +func TestGenerateJSONReport_InvalidPath(t *testing.T) { + metrics := &OverallMetrics{ + TotalFunctions: 10, + } + comparisons := []*DualLevelComparison{} + + // Use invalid path (directory doesn't exist and is deeply nested) + invalidPath := "/nonexistent/deeply/nested/path/report.json" + + err := GenerateJSONReport(metrics, comparisons, invalidPath) + assert.Error(t, err) +} + +// TestGenerateJSONReport_TimestampFormat tests timestamp is in RFC3339 format. +func TestGenerateJSONReport_TimestampFormat(t *testing.T) { + tempDir := t.TempDir() + jsonPath := filepath.Join(tempDir, "timestamp_report.json") + + metrics := &OverallMetrics{ + TotalFunctions: 10, + } + comparisons := []*DualLevelComparison{} + + err := GenerateJSONReport(metrics, comparisons, jsonPath) + require.NoError(t, err) + + jsonBytes, err := os.ReadFile(jsonPath) + require.NoError(t, err) + + var report map[string]interface{} + err = json.Unmarshal(jsonBytes, &report) + require.NoError(t, err) + + timestampStr := report["timestamp"].(string) + assert.NotEmpty(t, timestampStr) + + // Verify RFC3339 format + _, err = time.Parse(time.RFC3339, timestampStr) + assert.NoError(t, err) +} + +// TestWrapText tests text wrapping helper. +func TestWrapText(t *testing.T) { + tests := []struct { + name string + text string + width int + prefix string + expected string + }{ + { + name: "short text", + text: "Short text", + width: 20, + prefix: "", + expected: "Short text", + }, + { + name: "long text with wrapping", + text: "This is a very long text that should be wrapped at the specified width", + width: 20, + prefix: " ", + expected: "This is a very long\n text that should be\n wrapped at the\n specified width", + }, + { + name: "empty text", + text: "", + width: 20, + prefix: "", + expected: "", + }, + { + name: "single word", + text: "Word", + width: 10, + prefix: "", + expected: "Word", + }, + { + name: "exact width", + text: "Exactly twenty chars", + width: 20, + prefix: "", + expected: "Exactly twenty chars", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := wrapText(tt.text, tt.width, tt.prefix) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestGenerateConsoleReport_MultipleFailureCategories tests report with many categories. +func TestGenerateConsoleReport_MultipleFailureCategories(t *testing.T) { + metrics := &OverallMetrics{ + TotalFunctions: 100, + TruePositives: 50, + FalsePositives: 25, + FalseNegatives: 25, + TrueNegatives: 0, + Precision: 0.667, + Recall: 0.667, + F1Score: 0.667, + Agreement: 0.5, + FailuresByCategory: map[string]int{ + "control_flow_branch": 20, + "sanitizer_missed": 15, + "field_sensitivity": 10, + "container_operation": 3, + "string_formatting": 2, + }, + TopFailures: []FailureExample{}, + TotalProcessingTime: "10m", + FunctionsPerSecond: 0.167, + } + + err := GenerateConsoleReport(metrics, "") + assert.NoError(t, err) +} + +// TestGenerateConsoleReport_LongReasonText tests wrapping of long failure reasons. +func TestGenerateConsoleReport_LongReasonText(t *testing.T) { + metrics := &OverallMetrics{ + TotalFunctions: 10, + TruePositives: 5, + FalsePositives: 5, + Precision: 0.5, + Recall: 1.0, + F1Score: 0.667, + Agreement: 0.5, + TopFailures: []FailureExample{ + { + Type: "FALSE_POSITIVE", + FunctionFQN: "test.long_reason", + FunctionFile: "test.py", + FunctionLine: 100, + Category: "sanitizer_missed", + Reason: "This is a very long reason that explains in great detail why the tool incorrectly flagged this function as vulnerable when in reality it is perfectly safe because the data is sanitized by multiple layers of validation and escaping before being used in the SQL query", + }, + }, + FailuresByCategory: map[string]int{}, + TotalProcessingTime: "1m", + FunctionsPerSecond: 0.167, + } + + err := GenerateConsoleReport(metrics, "") + assert.NoError(t, err) +} diff --git a/sourcecode-parser/diagnostic/types.go b/sourcecode-parser/diagnostic/types.go index 9ba79d28..b16a606b 100644 --- a/sourcecode-parser/diagnostic/types.go +++ b/sourcecode-parser/diagnostic/types.go @@ -43,149 +43,166 @@ type FunctionMetadata struct { } // LLMAnalysisResult contains the LLM's analysis of a function. +// +//nolint:tagliatelle // LLM API uses snake_case JSON tags type LLMAnalysisResult struct { // FunctionFQN identifies which function was analyzed - FunctionFQN string + FunctionFQN string `json:"function_fqn,omitempty"` // DiscoveredPatterns contains sources/sinks/sanitizers found by LLM - DiscoveredPatterns DiscoveredPatterns + DiscoveredPatterns DiscoveredPatterns `json:"discovered_patterns"` // DataflowTestCases are test cases generated by LLM // Each test case specifies expected dataflow behavior - DataflowTestCases []DataflowTestCase + DataflowTestCases []DataflowTestCase `json:"dataflow_test_cases"` // VariableTracking shows how LLM traced variables through the function - VariableTracking []VariableTrack + VariableTracking []VariableTrack `json:"variable_tracking"` // Metadata about the analysis - AnalysisMetadata AnalysisMetadata + AnalysisMetadata AnalysisMetadata `json:"analysis_metadata"` } // DiscoveredPatterns contains all patterns discovered by LLM in the function. type DiscoveredPatterns struct { - Sources []PatternLocation - Sinks []PatternLocation - Sanitizers []PatternLocation - Propagators []PropagatorOperation + Sources []PatternLocation `json:"sources"` + Sinks []PatternLocation `json:"sinks"` + Sanitizers []PatternLocation `json:"sanitizers"` + Propagators []PropagatorOperation `json:"propagators"` } // PatternLocation describes where a pattern (source/sink/sanitizer) was found. type PatternLocation struct { // Pattern is the code pattern (e.g., "request.GET", "os.system") - Pattern string + Pattern string `json:"pattern"` // Lines where this pattern appears - Lines []int + Lines []int `json:"lines"` // Variables involved - Variables []string + Variables []string `json:"variables"` // Category for semantic grouping // Examples: "user_input", "file_read", "sql_execution", "command_exec" - Category string + Category string `json:"category"` // Description of what this pattern does - Description string + Description string `json:"description"` // Severity (for sinks): CRITICAL, HIGH, MEDIUM, LOW - Severity string + Severity string `json:"severity,omitempty"` } // PropagatorOperation describes how data propagates. +// +//nolint:tagliatelle // LLM API uses snake_case JSON tags type PropagatorOperation struct { // Type: "assignment", "function_call", "return" - Type string + Type string `json:"type"` // Line number - Line int + Line int `json:"line"` // Source variable - FromVar string + FromVar string `json:"from_var"` // Destination variable - ToVar string + ToVar string `json:"to_var"` // Function name (if Type == "function_call") - Function string + Function string `json:"function,omitempty"` } // DataflowTestCase is a test case generated by LLM. // This is what we validate our tool against. -type DataflowTestCase struct { +// +//nolint:tagliatelle // LLM API uses snake_case JSON tags +type DataflowTestCase struct{ // TestID for reference - TestID int + TestID int `json:"test_id"` // Description of what this test validates - Description string + Description string `json:"description"` // Source information - Source TestCaseSource + Source TestCaseSource `json:"source"` // Sink information - Sink TestCaseSink + Sink TestCaseSink `json:"sink"` // Flow path (sequence of variables/operations) - FlowPath []FlowStep + FlowPath []FlowStep `json:"flow_path"` // Sanitizers in the path (if any) - SanitizersInPath []string + SanitizersInPath []string `json:"sanitizers_in_path"` // Expected detection result // true: Our tool SHOULD detect this flow // false: Our tool should NOT detect (e.g., sanitized) - ExpectedDetection bool + ExpectedDetection bool `json:"expected_detection"` // Vulnerability type (if ExpectedDetection == true) - VulnerabilityType string + VulnerabilityType string `json:"vulnerability_type"` // Confidence score (0.0-1.0) - Confidence float64 + Confidence float64 `json:"confidence"` // Reasoning for this test case - Reasoning string + Reasoning string `json:"reasoning"` + + // Failure category (if tool might miss this) + // Categories: control_flow_branch, field_sensitivity, sanitizer_missed, + // container_operation, string_formatting, method_call_propagation, + // assignment_chain, return_flow, parameter_flow, complex_expression, + // context_required, none + FailureCategory string `json:"failure_category,omitempty"` } // TestCaseSource describes the source in a test case. type TestCaseSource struct { - Pattern string // e.g., "request.GET['cmd']" - Line int - Variable string + Pattern string `json:"pattern"` // e.g., "request.GET['cmd']" + Line int `json:"line"` + Variable string `json:"variable"` } // TestCaseSink describes the sink in a test case. type TestCaseSink struct { - Pattern string // e.g., "os.system" - Line int - Variable string + Pattern string `json:"pattern"` // e.g., "os.system" + Line int `json:"line"` + Variable string `json:"variable"` } // FlowStep describes one step in a dataflow path. type FlowStep struct { - Line int - Variable string - Operation string // "source", "assignment", "call", "sanitizer", "sink" + Line int `json:"line"` + Variable string `json:"variable"` + Operation string `json:"operation"` // "source", "assignment", "call", "sanitizer", "sink" } // VariableTrack shows how LLM traced a variable. +// +//nolint:tagliatelle // LLM API uses snake_case JSON tags type VariableTrack struct { - Variable string - FirstDefined int - LastUsed int - Aliases []string // Other variable names that hold the same data - FlowsToLines []int - FlowsToVars []string + Variable string `json:"variable"` + FirstDefined int `json:"first_defined"` + LastUsed int `json:"last_used"` + Aliases []string `json:"aliases"` // Other variable names that hold the same data + FlowsToLines []int `json:"flows_to_lines"` + FlowsToVars []string `json:"flows_to_vars"` } // AnalysisMetadata contains metadata about the LLM analysis. +// +//nolint:tagliatelle // LLM API uses snake_case JSON tags type AnalysisMetadata struct { - TotalSources int - TotalSinks int - TotalSanitizers int - TotalFlows int - DangerousFlows int - SafeFlows int - Confidence float64 - Limitations []string - ProcessingTime string - ModelUsed string + TotalSources int `json:"total_sources"` + TotalSinks int `json:"total_sinks"` + TotalSanitizers int `json:"total_sanitizers"` + TotalFlows int `json:"total_flows"` + DangerousFlows int `json:"dangerous_flows"` + SafeFlows int `json:"safe_flows"` + Confidence float64 `json:"confidence"` + Limitations []string `json:"limitations"` + ProcessingTime string `json:"processing_time,omitempty"` + ModelUsed string `json:"model_used,omitempty"` } diff --git a/sourcecode-parser/main_test.go b/sourcecode-parser/main_test.go index fc27616f..33cae261 100644 --- a/sourcecode-parser/main_test.go +++ b/sourcecode-parser/main_test.go @@ -23,7 +23,7 @@ func TestExecute(t *testing.T) { { name: "Successful execution", mockExecuteErr: nil, - expectedOutput: "Code Pathfinder is designed for identifying vulnerabilities in source code.\n\nUsage:\n pathfinder [command]\n\nAvailable Commands:\n analyze Analyze source code for security vulnerabilities using call graph\n ci Scan a project for vulnerabilities with ruleset in ci mode\n completion Generate the autocompletion script for the specified shell\n help Help about any command\n query Execute queries on the source code\n resolution-report Generate a diagnostic report on call resolution statistics\n scan Scan a project for vulnerabilities with ruleset\n version Print the version and commit information\n\nFlags:\n --disable-metrics Disable metrics collection\n -h, --help help for pathfinder\n --verbose Verbose output\n\nUse \"pathfinder [command] --help\" for more information about a command.\n", + expectedOutput: "Code Pathfinder is designed for identifying vulnerabilities in source code.\n\nUsage:\n pathfinder [command]\n\nAvailable Commands:\n analyze Analyze source code for security vulnerabilities using call graph\n ci Scan a project for vulnerabilities with ruleset in ci mode\n completion Generate the autocompletion script for the specified shell\n diagnose Validate intra-procedural taint analysis against LLM ground truth\n help Help about any command\n query Execute queries on the source code\n resolution-report Generate a diagnostic report on call resolution statistics\n scan Scan a project for vulnerabilities with ruleset\n version Print the version and commit information\n\nFlags:\n --disable-metrics Disable metrics collection\n -h, --help help for pathfinder\n --verbose Verbose output\n\nUse \"pathfinder [command] --help\" for more information about a command.\n", expectedExit: 0, }, }