diff --git a/sourcecode-parser/cmd/analyze.go b/sourcecode-parser/cmd/analyze.go new file mode 100644 index 00000000..4828d60e --- /dev/null +++ b/sourcecode-parser/cmd/analyze.go @@ -0,0 +1,128 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph" + "github.com/spf13/cobra" +) + +var analyzeCmd = &cobra.Command{ + Use: "analyze", + Short: "Analyze source code for security vulnerabilities using call graph", + Run: func(cmd *cobra.Command, _ []string) { + projectInput := cmd.Flag("project").Value.String() + + if projectInput == "" { + fmt.Println("Error: --project flag is required") + return + } + + fmt.Println("Building code graph...") + codeGraph := graph.Initialize(projectInput) + + fmt.Println("Building call graph and analyzing security patterns...") + cg, registry, patternRegistry, err := callgraph.InitializeCallGraph(codeGraph, projectInput) + if err != nil { + fmt.Println("Error building call graph:", err) + return + } + + fmt.Printf("Call graph built successfully: %d functions indexed\n", len(cg.Functions)) + fmt.Printf("Module registry: %d modules\n", len(registry.Modules)) + + // Debug: Print call graph details (commented out for production) + // fmt.Printf("\nDEBUG: Call graph statistics:\n") + // fmt.Printf(" Functions indexed: %d\n", len(cg.Functions)) + // for fqn := range cg.Functions { + // fmt.Printf(" - %s\n", fqn) + // } + // fmt.Printf(" Call sites: %d callers\n", len(cg.CallSites)) + // for caller, sites := range cg.CallSites { + // fmt.Printf(" %s makes %d calls:\n", caller, len(sites)) + // for _, site := range sites { + // fmt.Printf(" - Target: %s, TargetFQN: %s, Resolved: %v\n", site.Target, site.TargetFQN, site.Resolved) + // } + // } + // fmt.Println() + + // Run security analysis + matches := callgraph.AnalyzePatterns(cg, patternRegistry) + + if len(matches) == 0 { + fmt.Println("\n✓ No security issues 
found!") + return + } + + fmt.Printf("\n⚠ Found %d potential security issues:\n\n", len(matches)) + for i, match := range matches { + fmt.Printf("%d. [%s] %s\n", i+1, match.Severity, match.PatternName) + fmt.Printf(" Description: %s\n", match.Description) + fmt.Printf(" CWE: %s, OWASP: %s\n\n", match.CWE, match.OWASP) + + // Display source information + if match.SourceFQN != "" { + if match.SourceCall != "" { + fmt.Printf(" Source: %s() calls %s()\n", match.SourceFQN, match.SourceCall) + } else { + fmt.Printf(" Source: %s\n", match.SourceFQN) + } + if match.SourceFile != "" { + fmt.Printf(" at %s:%d\n", match.SourceFile, match.SourceLine) + if match.SourceCode != "" { + printCodeSnippet(match.SourceCode, int(match.SourceLine)) + } + } + fmt.Println() + } + + // Display sink information + if match.SinkFQN != "" { + if match.SinkCall != "" { + fmt.Printf(" Sink: %s() calls %s()\n", match.SinkFQN, match.SinkCall) + } else { + fmt.Printf(" Sink: %s\n", match.SinkFQN) + } + if match.SinkFile != "" { + fmt.Printf(" at %s:%d\n", match.SinkFile, match.SinkLine) + if match.SinkCode != "" { + printCodeSnippet(match.SinkCode, int(match.SinkLine)) + } + } + fmt.Println() + } + + // Display data flow path + if len(match.DataFlowPath) > 0 { + fmt.Printf(" Data flow path (%d steps):\n", len(match.DataFlowPath)) + for j, step := range match.DataFlowPath { + if j == 0 { + fmt.Printf(" %s (source)\n", step) + } else if j == len(match.DataFlowPath)-1 { + fmt.Printf(" └─> %s (sink)\n", step) + } else { + fmt.Printf(" └─> %s\n", step) + } + } + fmt.Println() + } + } + }, +} + +func printCodeSnippet(code string, startLine int) { + lines := strings.Split(code, "\n") + for i, line := range lines { + if line != "" { + fmt.Printf(" %4d | %s\n", startLine+i, line) + } + } +} + +func init() { + rootCmd.AddCommand(analyzeCmd) + analyzeCmd.Flags().StringP("project", "p", "", "Project directory to analyze (required)") + analyzeCmd.MarkFlagRequired("project") //nolint:all +} diff --git 
a/sourcecode-parser/cmd/resolution_report.go b/sourcecode-parser/cmd/resolution_report.go new file mode 100644 index 00000000..9f9c611e --- /dev/null +++ b/sourcecode-parser/cmd/resolution_report.go @@ -0,0 +1,223 @@ +package cmd + +import ( + "fmt" + "sort" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph" + "github.com/spf13/cobra" +) + +var resolutionReportCmd = &cobra.Command{ + Use: "resolution-report", + Short: "Generate a diagnostic report on call resolution statistics", + Long: `Analyze the call graph and generate a detailed report showing: + - Overall resolution statistics (resolved vs unresolved) + - Breakdown by failure category + - Top unresolved patterns with occurrence counts + +This helps identify why calls are not being resolved and prioritize +improvements to the resolution logic.`, + Run: func(cmd *cobra.Command, _ []string) { + projectInput := cmd.Flag("project").Value.String() + + if projectInput == "" { + fmt.Println("Error: --project flag is required") + return + } + + fmt.Println("Building code graph...") + codeGraph := graph.Initialize(projectInput) + + fmt.Println("Building call graph...") + cg, registry, _, err := callgraph.InitializeCallGraph(codeGraph, projectInput) + if err != nil { + fmt.Printf("Error building call graph: %v\n", err) + return + } + + fmt.Printf("\nResolution Report for %s\n", projectInput) + fmt.Println("===============================================") + + // Collect statistics + stats := aggregateResolutionStatistics(cg) + + // Print overall statistics + printOverallStatistics(stats) + fmt.Println() + + // Print failure breakdown + printFailureBreakdown(stats) + fmt.Println() + + // Print top unresolved patterns + printTopUnresolvedPatterns(stats, 20) + fmt.Println() + + fmt.Printf("Module registry: %d modules\n", len(registry.Modules)) + }, +} + +// resolutionStatistics holds aggregated statistics about call resolution. 
+type resolutionStatistics struct { + TotalCalls int + ResolvedCalls int + UnresolvedCalls int + FailuresByReason map[string]int // Category -> count + PatternCounts map[string]int // Target pattern -> count + FrameworkCounts map[string]int // Framework prefix -> count (for external_framework category) + UnresolvedByFQN map[string]callgraph.CallSite // For detailed inspection +} + +// aggregateResolutionStatistics analyzes the call graph and collects statistics. +func aggregateResolutionStatistics(cg *callgraph.CallGraph) *resolutionStatistics { + stats := &resolutionStatistics{ + FailuresByReason: make(map[string]int), + PatternCounts: make(map[string]int), + FrameworkCounts: make(map[string]int), + UnresolvedByFQN: make(map[string]callgraph.CallSite), + } + + // Iterate through all call sites + for _, callSites := range cg.CallSites { + for _, site := range callSites { + stats.TotalCalls++ + + if site.Resolved { + stats.ResolvedCalls++ + } else { + stats.UnresolvedCalls++ + + // Count by failure reason + if site.FailureReason != "" { + stats.FailuresByReason[site.FailureReason]++ + } else { + stats.FailuresByReason["uncategorized"]++ + } + + // Count pattern occurrences + stats.PatternCounts[site.Target]++ + + // For external frameworks, track which framework + if site.FailureReason == "external_framework" { + // Extract framework prefix (first component before dot) + for idx := 0; idx < len(site.TargetFQN); idx++ { + if site.TargetFQN[idx] == '.' { + framework := site.TargetFQN[:idx] + stats.FrameworkCounts[framework]++ + break + } + } + } + + // Store for detailed inspection + stats.UnresolvedByFQN[site.TargetFQN] = site + } + } + } + + return stats +} + +// printOverallStatistics prints the overall resolution statistics. 
+func printOverallStatistics(stats *resolutionStatistics) { + fmt.Println("Overall Statistics:") + fmt.Printf(" Total calls: %d\n", stats.TotalCalls) + fmt.Printf(" Resolved: %d (%.1f%%)\n", + stats.ResolvedCalls, + percentage(stats.ResolvedCalls, stats.TotalCalls)) + fmt.Printf(" Unresolved: %d (%.1f%%)\n", + stats.UnresolvedCalls, + percentage(stats.UnresolvedCalls, stats.TotalCalls)) +} + +// printFailureBreakdown prints the breakdown of failures by category. +func printFailureBreakdown(stats *resolutionStatistics) { + fmt.Println("Failure Breakdown:") + + // Sort categories by count (descending) + type categoryCount struct { + category string + count int + } + categories := make([]categoryCount, 0, len(stats.FailuresByReason)) + for cat, count := range stats.FailuresByReason { + categories = append(categories, categoryCount{cat, count}) + } + sort.Slice(categories, func(i, j int) bool { + return categories[i].count > categories[j].count + }) + + // Print each category + for _, cc := range categories { + fmt.Printf(" %-20s %d (%.1f%%)\n", + cc.category+":", + cc.count, + percentage(cc.count, stats.TotalCalls)) + + // For external frameworks, show framework breakdown + if cc.category == "external_framework" && len(stats.FrameworkCounts) > 0 { + // Sort frameworks by count + type frameworkCount struct { + framework string + count int + } + var frameworks []frameworkCount + for fw, count := range stats.FrameworkCounts { + frameworks = append(frameworks, frameworkCount{fw, count}) + } + sort.Slice(frameworks, func(i, j int) bool { + return frameworks[i].count > frameworks[j].count + }) + + // Print top 5 frameworks + for i, fc := range frameworks { + if i >= 5 { + break + } + fmt.Printf(" %s.*: %d\n", fc.framework, fc.count) + } + } + } +} + +// printTopUnresolvedPatterns prints the most common unresolved patterns. 
+func printTopUnresolvedPatterns(stats *resolutionStatistics, topN int) { + fmt.Printf("Top %d Unresolved Patterns:\n", topN) + + // Sort patterns by count (descending) + type patternCount struct { + pattern string + count int + } + patterns := make([]patternCount, 0, len(stats.PatternCounts)) + for pattern, count := range stats.PatternCounts { + patterns = append(patterns, patternCount{pattern, count}) + } + sort.Slice(patterns, func(i, j int) bool { + return patterns[i].count > patterns[j].count + }) + + // Print top N patterns + for i, pc := range patterns { + if i >= topN { + break + } + fmt.Printf(" %2d. %-40s %d occurrences\n", i+1, pc.pattern, pc.count) + } +} + +// percentage calculates the percentage of part out of total. +func percentage(part, total int) float64 { + if total == 0 { + return 0.0 + } + return float64(part) * 100.0 / float64(total) +} + +func init() { + rootCmd.AddCommand(resolutionReportCmd) + resolutionReportCmd.Flags().StringP("project", "p", "", "Project root directory") + resolutionReportCmd.MarkFlagRequired("project") +} diff --git a/sourcecode-parser/cmd/resolution_report_test.go b/sourcecode-parser/cmd/resolution_report_test.go new file mode 100644 index 00000000..9585638f --- /dev/null +++ b/sourcecode-parser/cmd/resolution_report_test.go @@ -0,0 +1,88 @@ +package cmd + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph" + "github.com/stretchr/testify/assert" +) + +// Note: categorizeResolutionFailure is in graph/callgraph/builder.go, not cmd package +// This test file validates the resolution report output formatting + +func TestAggregateResolutionStatistics(t *testing.T) { + // Create a mock call graph with various call sites + cg := callgraph.NewCallGraph() + + // Add resolved call sites + cg.AddCallSite("test.func1", callgraph.CallSite{ + Target: "print", + Resolved: true, + TargetFQN: "builtins.print", + }) + + // Add unresolved call sites with different failure reasons + 
cg.AddCallSite("test.func2", callgraph.CallSite{ + Target: "models.ForeignKey", + Resolved: false, + TargetFQN: "django.db.models.ForeignKey", + FailureReason: "external_framework", + }) + + cg.AddCallSite("test.func3", callgraph.CallSite{ + Target: "Task.objects.filter", + Resolved: false, + TargetFQN: "tasks.models.Task.objects.filter", + FailureReason: "orm_pattern", + }) + + cg.AddCallSite("test.func4", callgraph.CallSite{ + Target: "response.json", + Resolved: false, + TargetFQN: "response.json", + FailureReason: "variable_method", + }) + + // Aggregate statistics + stats := aggregateResolutionStatistics(cg) + + // Validate overall counts + assert.Equal(t, 4, stats.TotalCalls) + assert.Equal(t, 1, stats.ResolvedCalls) + assert.Equal(t, 3, stats.UnresolvedCalls) + + // Validate failure breakdown + assert.Equal(t, 1, stats.FailuresByReason["external_framework"]) + assert.Equal(t, 1, stats.FailuresByReason["orm_pattern"]) + assert.Equal(t, 1, stats.FailuresByReason["variable_method"]) + + // Validate pattern counts + assert.Equal(t, 1, stats.PatternCounts["models.ForeignKey"]) + assert.Equal(t, 1, stats.PatternCounts["Task.objects.filter"]) + assert.Equal(t, 1, stats.PatternCounts["response.json"]) + + // Validate framework counts + assert.Equal(t, 1, stats.FrameworkCounts["django"]) +} + +func TestPercentage(t *testing.T) { + tests := []struct { + name string + part int + total int + expected float64 + }{ + {"50 percent", 50, 100, 50.0}, + {"zero percent", 0, 100, 0.0}, + {"hundred percent", 100, 100, 100.0}, + {"zero total", 10, 0, 0.0}, + {"decimal result", 1, 3, 33.333333333333336}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := percentage(tt.part, tt.total) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/sourcecode-parser/graph/callgraph/benchmark_test.go b/sourcecode-parser/graph/callgraph/benchmark_test.go new file mode 100644 index 00000000..592938da --- /dev/null +++ 
b/sourcecode-parser/graph/callgraph/benchmark_test.go @@ -0,0 +1,356 @@ +package callgraph + +import ( + "os" + "testing" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" +) + +// Benchmark project paths +// These paths are used for performance testing against real-world codebases. +const ( + // Small project: ~5 Python files, simple imports. + smallProjectPath = "../../../test-src/python/simple_project" + + // Medium project: label-studio (~1000 Python files, complex imports). + mediumProjectPath = "/Users/shiva/src/label-studio/label_studio" + + // Large project: salt (~10,000 Python files, very complex imports). + largeProjectPath = "/Users/shiva/src/shivasurya/salt/salt" +) + +// BenchmarkBuildModuleRegistry_Small measures module registry performance on a small codebase. +// Target: <10ms +// +// This benchmark tests Pass 1 of the 3-pass algorithm on a minimal project. +// It measures the overhead of directory walking and module path conversion. +func BenchmarkBuildModuleRegistry_Small(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + registry, err := BuildModuleRegistry(smallProjectPath) + if err != nil { + b.Fatalf("Failed to build module registry: %v", err) + } + if len(registry.Modules) == 0 { + b.Fatal("Expected modules to be registered") + } + } +} + +// BenchmarkBuildModuleRegistry_Medium measures module registry performance on a medium codebase. +// Target: <500ms +// +// This benchmark tests Pass 1 against label-studio, a real-world Django application. +// It stresses the directory walking and file filtering logic. 
+func BenchmarkBuildModuleRegistry_Medium(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + registry, err := BuildModuleRegistry(mediumProjectPath) + if err != nil { + b.Fatalf("Failed to build module registry: %v", err) + } + if len(registry.Modules) == 0 { + b.Fatal("Expected modules to be registered") + } + } +} + +// BenchmarkBuildModuleRegistry_Large measures module registry performance on a large codebase. +// Target: <2s +// +// This benchmark tests Pass 1 against salt, a massive Python project with thousands of modules. +// It validates that the algorithm scales to production-sized codebases. +func BenchmarkBuildModuleRegistry_Large(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + registry, err := BuildModuleRegistry(largeProjectPath) + if err != nil { + b.Fatalf("Failed to build module registry: %v", err) + } + if len(registry.Modules) == 0 { + b.Fatal("Expected modules to be registered") + } + } +} + +// BenchmarkExtractImports_Small measures import extraction performance on a small project. +// Target: <20ms +// +// This benchmark tests Pass 2A (import extraction) using tree-sitter parsing. +// It measures parser initialization and AST traversal overhead. 
+func BenchmarkExtractImports_Small(b *testing.B) { + // Pre-build registry to isolate import extraction performance + registry, err := BuildModuleRegistry(smallProjectPath) + if err != nil { + b.Fatalf("Failed to build module registry: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Extract imports from all files in the small project + for modulePath, filePath := range registry.Modules { + sourceCode, readErr := os.ReadFile(filePath) + if readErr != nil { + b.Fatalf("Failed to read file %s: %v", filePath, readErr) + } + + _, extractErr := ExtractImports(filePath, sourceCode, registry) + if extractErr != nil { + b.Fatalf("Failed to extract imports from %s: %v", modulePath, extractErr) + } + } + } +} + +// BenchmarkExtractImports_Medium measures import extraction performance on a medium project. +// Target: <2s +// +// This benchmark tests Pass 2A against label-studio's import patterns. +// It validates that tree-sitter parsing scales to production projects. +func BenchmarkExtractImports_Medium(b *testing.B) { + // Pre-build registry to isolate import extraction performance + registry, err := BuildModuleRegistry(mediumProjectPath) + if err != nil { + b.Fatalf("Failed to build module registry: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Extract imports from all files in the medium project + for _, filePath := range registry.Modules { + sourceCode, readErr := os.ReadFile(filePath) + if readErr != nil { + // Skip files that can't be read (permissions, etc.) + continue + } + + _, extractErr := ExtractImports(filePath, sourceCode, registry) + if extractErr != nil { + // Skip files with parse errors (syntax errors, etc.) + continue + } + } + } +} + +// BenchmarkExtractCallSites_Small measures call site extraction performance on a small project. +// Target: <30ms +// +// This benchmark tests Pass 2B (call site extraction) using tree-sitter. 
+// It measures the overhead of finding all function/method calls in the AST. +func BenchmarkExtractCallSites_Small(b *testing.B) { + // Pre-build registry and import maps + registry, err := BuildModuleRegistry(smallProjectPath) + if err != nil { + b.Fatalf("Failed to build module registry: %v", err) + } + + // Build import maps for all files + importMaps := make(map[string]*ImportMap) + for modulePath, filePath := range registry.Modules { + sourceCode, readErr := os.ReadFile(filePath) + if readErr != nil { + b.Fatalf("Failed to read file %s: %v", filePath, readErr) + } + + importMap, extractErr := ExtractImports(filePath, sourceCode, registry) + if extractErr != nil { + b.Fatalf("Failed to extract imports from %s: %v", modulePath, extractErr) + } + importMaps[filePath] = importMap + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Extract call sites from all files + for _, filePath := range registry.Modules { + sourceCode, readErr := os.ReadFile(filePath) + if readErr != nil { + b.Fatalf("Failed to read file %s: %v", filePath, readErr) + } + + importMap := importMaps[filePath] + _, extractErr := ExtractCallSites(filePath, sourceCode, importMap) + if extractErr != nil { + b.Fatalf("Failed to extract call sites from %s: %v", filePath, extractErr) + } + } + } +} + +// BenchmarkBuildCallGraph_Small measures end-to-end call graph construction on a small project. +// Target: <100ms +// +// This benchmark tests the complete 3-pass algorithm: +// - Pass 1: Module registry +// - Pass 2A: Import extraction +// - Pass 2B: Call site extraction +// - Pass 3: Call graph construction +// +// Note: Currently skipped because BuildCallGraph expects codeGraph to have functions pre-indexed +// which requires full AST parsing. Use BenchmarkInitializeCallGraph_Small instead which +// includes the full pipeline. 
func BenchmarkBuildCallGraph_Small(b *testing.B) {
	b.Skip(buildCallGraphSkipReason)
}

// buildCallGraphSkipReason is the message reported by every
// BenchmarkBuildCallGraph_* benchmark while they remain disabled.
const buildCallGraphSkipReason = "Skipping: BuildCallGraph requires codeGraph with pre-indexed functions"

// BenchmarkBuildCallGraph_Medium is the placeholder for end-to-end call graph
// construction on a medium project (label-studio).
// Target: <5s
//
// Note: Currently skipped. Use BenchmarkInitializeCallGraph_Medium instead.
func BenchmarkBuildCallGraph_Medium(b *testing.B) {
	b.Skip(buildCallGraphSkipReason)
}

// BenchmarkBuildCallGraph_Large is the placeholder for end-to-end call graph
// construction on a large project (salt).
// Target: <30s
//
// Note: Currently skipped. Use BenchmarkInitializeCallGraph_Large instead (when enabled).
func BenchmarkBuildCallGraph_Large(b *testing.B) {
	b.Skip(buildCallGraphSkipReason)
}

// BenchmarkInitializeCallGraph_Small measures the full initialization pipeline on a small project.
+// Target: <150ms +// +// This benchmark tests InitializeCallGraph(), which includes: +// - Code graph initialization (AST parsing) +// - Module registry building +// - Call graph construction +// - Pattern registry loading +func BenchmarkInitializeCallGraph_Small(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Include full pipeline: graph initialization + call graph initialization + codeGraph := graph.Initialize(smallProjectPath) + callGraph, registry, patternRegistry, err := InitializeCallGraph(codeGraph, smallProjectPath) + if err != nil { + b.Fatalf("Failed to initialize call graph: %v", err) + } + + // Validate results (don't fail if Functions is empty since it depends on graph.Initialize) + if len(registry.Modules) == 0 { + b.Fatal("Expected modules to be registered") + } + if len(patternRegistry.Patterns) == 0 { + b.Fatal("Expected patterns to be loaded") + } + _ = callGraph // Use callGraph to avoid unused variable + } +} + +// BenchmarkInitializeCallGraph_Medium measures the full initialization pipeline on a medium project. +// Target: <10s +// +// Note: Disabled by default due to long runtime. Enable manually to test medium project performance. +func BenchmarkInitializeCallGraph_Medium(b *testing.B) { + b.Skip("Skipping: Medium project benchmarks take >10s, enable manually") + + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + codeGraph := graph.Initialize(mediumProjectPath) + callGraph, registry, patternRegistry, err := InitializeCallGraph(codeGraph, mediumProjectPath) + if err != nil { + b.Fatalf("Failed to initialize call graph: %v", err) + } + + if len(registry.Modules) == 0 { + b.Fatal("Expected modules to be registered") + } + if len(patternRegistry.Patterns) == 0 { + b.Fatal("Expected patterns to be loaded") + } + _ = callGraph + } +} + +// BenchmarkPatternMatching_Small measures security pattern analysis performance on a small project. 
+// Target: <50ms +// +// This benchmark tests the pattern matching engine against a small call graph. +func BenchmarkPatternMatching_Small(b *testing.B) { + // Pre-build call graph and pattern registry + codeGraph := graph.Initialize(smallProjectPath) + callGraph, registry, patternRegistry, err := InitializeCallGraph(codeGraph, smallProjectPath) + if err != nil { + b.Fatalf("Failed to initialize call graph: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + matches := AnalyzePatterns(callGraph, patternRegistry) + _ = matches // Use matches to avoid compiler optimization + } + + _ = registry // Silence unused variable warning +} + +// BenchmarkPatternMatching_Medium measures security pattern analysis performance on a medium project. +// Target: <2s +// +// This benchmark validates that pattern matching scales to label-studio's call graph. +func BenchmarkPatternMatching_Medium(b *testing.B) { + // Pre-build call graph and pattern registry + codeGraph := graph.Initialize(mediumProjectPath) + callGraph, registry, patternRegistry, err := InitializeCallGraph(codeGraph, mediumProjectPath) + if err != nil { + b.Fatalf("Failed to initialize call graph: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + matches := AnalyzePatterns(callGraph, patternRegistry) + _ = matches + } + + _ = registry +} + +// BenchmarkResolveCallTarget measures call target resolution performance. +// Target: <1µs per call +// +// This benchmark tests the hot path for resolving function calls to FQNs. +// It's critical for overall performance since it's called for every call site. 
+func BenchmarkResolveCallTarget(b *testing.B) { + // Setup test data + registry := NewModuleRegistry() + registry.AddModule("myapp.utils", "/project/myapp/utils.py") + registry.AddModule("myapp.helpers", "/project/myapp/helpers.py") + + importMap := NewImportMap("/project/myapp/main.py") + importMap.AddImport("utils", "myapp.utils") + importMap.AddImport("helper", "myapp.helpers") + + currentModule := "myapp.main" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Test simple attribute access (most common case) + _, _ = resolveCallTarget("utils.process_data", importMap, registry, currentModule) + + // Test aliased import + _, _ = resolveCallTarget("helper.format", importMap, registry, currentModule) + + // Test fully qualified name + _, _ = resolveCallTarget("myapp.utils.validate", importMap, registry, currentModule) + } +} diff --git a/sourcecode-parser/graph/callgraph/builder.go b/sourcecode-parser/graph/callgraph/builder.go index 9a3e05d9..f3388760 100644 --- a/sourcecode-parser/graph/callgraph/builder.go +++ b/sourcecode-parser/graph/callgraph/builder.go @@ -4,10 +4,97 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" ) +// ImportMapCache provides thread-safe caching of ImportMap instances. +// This avoids re-parsing imports from the same file multiple times. +// +// The cache uses a read-write mutex to allow concurrent reads while +// ensuring safe writes. 
This is critical for performance since: +// - Import extraction involves tree-sitter parsing (expensive) +// - Many files may import the same modules +// - Build call graph processes files sequentially (for now) +// +// Example usage: +// +// cache := NewImportMapCache() +// importMap := cache.GetOrExtract(filePath, sourceCode, registry) +type ImportMapCache struct { + cache map[string]*ImportMap // Maps file path to ImportMap + mu sync.RWMutex // Protects cache map +} + +// NewImportMapCache creates a new empty import map cache. +func NewImportMapCache() *ImportMapCache { + return &ImportMapCache{ + cache: make(map[string]*ImportMap), + } +} + +// Get retrieves an ImportMap from the cache if it exists. +// +// Parameters: +// - filePath: absolute path to the Python file +// +// Returns: +// - ImportMap and true if found in cache, nil and false otherwise +func (c *ImportMapCache) Get(filePath string) (*ImportMap, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + importMap, ok := c.cache[filePath] + return importMap, ok +} + +// Put stores an ImportMap in the cache. +// +// Parameters: +// - filePath: absolute path to the Python file +// - importMap: the extracted ImportMap to cache +func (c *ImportMapCache) Put(filePath string, importMap *ImportMap) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cache[filePath] = importMap +} + +// GetOrExtract retrieves an ImportMap from cache or extracts it if not cached. +// This is the main entry point for using the cache. 
+// +// Parameters: +// - filePath: absolute path to the Python file +// - sourceCode: file contents (only used if extraction needed) +// - registry: module registry for resolving imports +// +// Returns: +// - ImportMap from cache or newly extracted +// - error if extraction fails (cache misses only) +// +// Thread-safety: +// - Multiple goroutines can safely call GetOrExtract concurrently +// - First caller for a file will extract and cache +// - Subsequent callers will get cached result +func (c *ImportMapCache) GetOrExtract(filePath string, sourceCode []byte, registry *ModuleRegistry) (*ImportMap, error) { + // Try to get from cache (fast path with read lock) + if importMap, ok := c.Get(filePath); ok { + return importMap, nil + } + + // Cache miss - extract imports (expensive operation) + importMap, err := ExtractImports(filePath, sourceCode, registry) + if err != nil { + return nil, err + } + + // Store in cache for future use + c.Put(filePath, importMap) + + return importMap, nil +} + // BuildCallGraph constructs the complete call graph for a Python project. 
// This is Pass 3 of the 3-pass algorithm: // - Pass 1: BuildModuleRegistry - map files to modules @@ -47,6 +134,10 @@ import ( func BuildCallGraph(codeGraph *graph.CodeGraph, registry *ModuleRegistry, projectRoot string) (*CallGraph, error) { callGraph := NewCallGraph() + // Initialize import map cache for performance + // This avoids re-parsing imports from the same file multiple times + importCache := NewImportMapCache() + // First, index all function definitions from the code graph // This builds the Functions map for quick lookup indexFunctions(codeGraph, callGraph, registry) @@ -65,8 +156,8 @@ func BuildCallGraph(codeGraph *graph.CodeGraph, registry *ModuleRegistry, projec continue } - // Extract imports to build ImportMap for this file - importMap, err := ExtractImports(filePath, sourceCode, registry) + // Extract imports using cache (avoids re-parsing if already cached) + importMap, err := importCache.GetOrExtract(filePath, sourceCode, registry) if err != nil { // Skip files with import errors continue @@ -98,6 +189,11 @@ func BuildCallGraph(codeGraph *graph.CodeGraph, registry *ModuleRegistry, projec callSite.TargetFQN = targetFQN callSite.Resolved = resolved + // If resolution failed, categorize the failure reason + if !resolved { + callSite.FailureReason = categorizeResolutionFailure(callSite.Target, targetFQN) + } + // Add call site to graph (dereference pointer) callGraph.AddCallSite(callerFQN, *callSite) @@ -232,12 +328,130 @@ func findContainingFunction(location Location, functions []*graph.Node, modulePa // // target="obj.method", imports={} // → "obj.method", false (needs type inference) + +// Python built-in functions that should not be resolved as module functions. +var pythonBuiltins = map[string]bool{ + "eval": true, + "exec": true, + "input": true, + "raw_input": true, + "compile": true, + "__import__": true, +} + +// categorizeResolutionFailure determines why a call target failed to resolve. 
+// This enables diagnostic reporting to understand resolution gaps. +// +// Categories: +// - "external_framework" - Known external frameworks (Django, REST, pytest, stdlib) +// - "orm_pattern" - Django ORM patterns (Model.objects.*, queryset.*) +// - "attribute_chain" - Method calls on objects/return values +// - "variable_method" - Method calls that appear to be on variables +// - "super_call" - Calls via super() mechanism +// - "not_in_imports" - Simple name not found in imports +// - "unknown" - Other unresolved patterns +// +// Parameters: +// - target: original call target string (e.g., "models.ForeignKey") +// - targetFQN: resolved fully qualified name (e.g., "django.db.models.ForeignKey") +// +// Returns: +// - category string describing the failure reason +func categorizeResolutionFailure(target, targetFQN string) string { + // Check for external frameworks (common patterns) + if strings.HasPrefix(targetFQN, "django.") || + strings.HasPrefix(targetFQN, "rest_framework.") || + strings.HasPrefix(targetFQN, "pytest.") || + strings.HasPrefix(targetFQN, "unittest.") || + strings.HasPrefix(targetFQN, "json.") || + strings.HasPrefix(targetFQN, "logging.") || + strings.HasPrefix(targetFQN, "os.") || + strings.HasPrefix(targetFQN, "sys.") || + strings.HasPrefix(targetFQN, "re.") || + strings.HasPrefix(targetFQN, "pathlib.") || + strings.HasPrefix(targetFQN, "collections.") || + strings.HasPrefix(targetFQN, "datetime.") { + return "external_framework" + } + + // Check for Django ORM patterns + if strings.Contains(target, ".objects.") || + strings.HasSuffix(target, ".objects") || + (strings.Contains(target, ".") && (strings.HasSuffix(target, ".filter") || + strings.HasSuffix(target, ".get") || + strings.HasSuffix(target, ".create") || + strings.HasSuffix(target, ".update") || + strings.HasSuffix(target, ".delete") || + strings.HasSuffix(target, ".all") || + strings.HasSuffix(target, ".first") || + strings.HasSuffix(target, ".last") || + strings.HasSuffix(target, 
".count") || + strings.HasSuffix(target, ".exists"))) { + return "orm_pattern" + } + + // Check for super() calls + if strings.HasPrefix(target, "super(") || strings.HasPrefix(target, "super.") { + return "super_call" + } + + // Check for attribute chains (has dots, looks like obj.method()) + // Heuristic: lowercase first component likely means variable/object + if dotIndex := strings.Index(target, "."); dotIndex != -1 { + firstComponent := target[:dotIndex] + // If starts with lowercase and not a known module pattern, likely attribute chain + if len(firstComponent) > 0 && firstComponent[0] >= 'a' && firstComponent[0] <= 'z' { + // Could be variable method or attribute chain + // Check common variable-like patterns + if firstComponent == "self" || firstComponent == "cls" || + firstComponent == "request" || firstComponent == "response" || + firstComponent == "queryset" || firstComponent == "user" || + firstComponent == "obj" || firstComponent == "value" || + firstComponent == "data" || firstComponent == "result" { + return "variable_method" + } + return "attribute_chain" + } + } + + // Simple name (no dots) - not in imports + if !strings.Contains(target, ".") { + return "not_in_imports" + } + + // Everything else + return "unknown" +} + func resolveCallTarget(target string, importMap *ImportMap, registry *ModuleRegistry, currentModule string) (string, bool) { + // Handle self.method() calls - resolve to current module + if strings.HasPrefix(target, "self.") { + methodName := strings.TrimPrefix(target, "self.") + // Resolve to module.method + moduleFQN := currentModule + "." + methodName + // Validate exists + if validateFQN(moduleFQN, registry) { + return moduleFQN, true + } + // Return unresolved but with module prefix + return moduleFQN, false + } + // Handle simple names (no dots) if !strings.Contains(target, ".") { + // Check if it's a Python built-in + if pythonBuiltins[target] { + // Return as builtins.function for pattern matching + return "builtins." 
+ target, true + } + // Try to resolve through imports if fqn, ok := importMap.Resolve(target); ok { // Found in imports - return the FQN + // Check if it's a known framework + if isKnown, _ := IsKnownFramework(fqn); isKnown { + return fqn, true + } // Validate if it exists in registry resolved := validateFQN(fqn, registry) return fqn, resolved @@ -261,6 +475,10 @@ func resolveCallTarget(target string, importMap *ImportMap, registry *ModuleRegi // Try to resolve base through imports if baseFQN, ok := importMap.Resolve(base); ok { fullFQN := baseFQN + "." + rest + // Check if it's a known framework + if isKnown, _ := IsKnownFramework(fullFQN); isKnown { + return fullFQN, true + } if validateFQN(fullFQN, registry) { return fullFQN, true } diff --git a/sourcecode-parser/graph/callgraph/builder_framework_test.go b/sourcecode-parser/graph/callgraph/builder_framework_test.go new file mode 100644 index 00000000..7988d8be --- /dev/null +++ b/sourcecode-parser/graph/callgraph/builder_framework_test.go @@ -0,0 +1,255 @@ +package callgraph + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestFrameworkResolution validates that known frameworks are resolved correctly. 
+func TestFrameworkResolution(t *testing.T) { + // Create temporary test directory + tmpDir := t.TempDir() + + // Create test files + testFile := filepath.Join(tmpDir, "test_frameworks.py") + testCode := ` +import django.db.models as models +from rest_framework import serializers +import pytest +import json +import logging + +def test_django(): + # Django ORM call (should be resolved as external framework) + user = models.User.objects.get(id=1) + return user + +def test_rest_framework(): + # REST framework call (should be resolved as external framework) + serializer = serializers.ModelSerializer() + return serializer + +def test_pytest(): + # pytest call (should be resolved as external framework) + fixture = pytest.fixture() + return fixture + +def test_stdlib(): + # stdlib calls (should be resolved as external framework) + data = json.loads('{}') + logger = logging.getLogger(__name__) + return data +` + err := os.WriteFile(testFile, []byte(testCode), 0644) + assert.NoError(t, err) + + // Build module registry + registry, err := BuildModuleRegistry(tmpDir) + assert.NoError(t, err) + + // Build import map cache + cache := NewImportMapCache() + sourceCode, err := os.ReadFile(testFile) + assert.NoError(t, err) + + importMap, err := cache.GetOrExtract(testFile, sourceCode, registry) + assert.NoError(t, err) + + // Get module path + modulePath, ok := registry.FileToModule[testFile] + assert.True(t, ok) + + // Test Django models resolution + targetFQN, resolved := resolveCallTarget("models.User", importMap, registry, modulePath) + assert.True(t, resolved, "Django models.User should be resolved") + assert.Equal(t, "django.db.models.User", targetFQN) + + // Test REST framework resolution + targetFQN, resolved = resolveCallTarget("serializers.ModelSerializer", importMap, registry, modulePath) + assert.True(t, resolved, "REST framework serializers should be resolved") + assert.Equal(t, "rest_framework.serializers.ModelSerializer", targetFQN) + + // Test pytest resolution + 
targetFQN, resolved = resolveCallTarget("pytest.fixture", importMap, registry, modulePath) + assert.True(t, resolved, "pytest.fixture should be resolved") + assert.Equal(t, "pytest.fixture", targetFQN) + + // Test json (stdlib) resolution + targetFQN, resolved = resolveCallTarget("json.loads", importMap, registry, modulePath) + assert.True(t, resolved, "json.loads should be resolved") + assert.Equal(t, "json.loads", targetFQN) + + // Test logging (stdlib) resolution + targetFQN, resolved = resolveCallTarget("logging.getLogger", importMap, registry, modulePath) + assert.True(t, resolved, "logging.getLogger should be resolved") + assert.Equal(t, "logging.getLogger", targetFQN) +} + +// TestNonFrameworkResolution ensures non-framework calls still work correctly. +func TestNonFrameworkResolution(t *testing.T) { + // Create temporary test directory + tmpDir := t.TempDir() + + // Create utils module + utilsFile := filepath.Join(tmpDir, "utils.py") + utilsCode := ` +def sanitize(value): + return value.strip() + +def validate(data): + return True +` + err := os.WriteFile(utilsFile, []byte(utilsCode), 0644) + assert.NoError(t, err) + + // Create test file that imports utils + testFile := filepath.Join(tmpDir, "test.py") + testCode := ` +from utils import sanitize, validate + +def process(): + result = sanitize(" test ") + valid = validate(result) + return valid +` + err = os.WriteFile(testFile, []byte(testCode), 0644) + assert.NoError(t, err) + + // Build module registry + registry, err := BuildModuleRegistry(tmpDir) + assert.NoError(t, err) + + // Build import map + cache := NewImportMapCache() + sourceCode, err := os.ReadFile(testFile) + assert.NoError(t, err) + + importMap, err := cache.GetOrExtract(testFile, sourceCode, registry) + assert.NoError(t, err) + + // Get module path + modulePath, ok := registry.FileToModule[testFile] + assert.True(t, ok) + + // Test local function resolution (should resolve to local module) + targetFQN, resolved := 
resolveCallTarget("sanitize", importMap, registry, modulePath) + assert.True(t, resolved, "Local function sanitize should be resolved") + assert.Contains(t, targetFQN, "utils.sanitize") + + targetFQN, resolved = resolveCallTarget("validate", importMap, registry, modulePath) + assert.True(t, resolved, "Local function validate should be resolved") + assert.Contains(t, targetFQN, "utils.validate") +} + +// TestFrameworkVsLocalPrecedence ensures local definitions take precedence over frameworks. +func TestFrameworkVsLocalPrecedence(t *testing.T) { + // Create temporary test directory + tmpDir := t.TempDir() + + // Create a local module named "json" (shadowing stdlib) + jsonFile := filepath.Join(tmpDir, "json.py") + jsonCode := ` +def loads(data): + return "custom loads" +` + err := os.WriteFile(jsonFile, []byte(jsonCode), 0644) + assert.NoError(t, err) + + // Create test file that imports local json + testFile := filepath.Join(tmpDir, "test.py") + testCode := ` +from json import loads + +def process(): + return loads('{}') +` + err = os.WriteFile(testFile, []byte(testCode), 0644) + assert.NoError(t, err) + + // Build module registry + registry, err := BuildModuleRegistry(tmpDir) + assert.NoError(t, err) + + // Build import map + cache := NewImportMapCache() + sourceCode, err := os.ReadFile(testFile) + assert.NoError(t, err) + + importMap, err := cache.GetOrExtract(testFile, sourceCode, registry) + assert.NoError(t, err) + + // Get module path + modulePath, ok := registry.FileToModule[testFile] + assert.True(t, ok) + + // Test that local json takes precedence over stdlib + targetFQN, resolved := resolveCallTarget("loads", importMap, registry, modulePath) + assert.True(t, resolved, "Local json.loads should be resolved") + // When there's a local module that shadows stdlib, it resolves to local + // The FQN will be json.loads but from the local module, not stdlib + assert.Contains(t, targetFQN, "json.loads", "Should resolve to json.loads") + + // Verify it's actually from 
local module by checking registry + _, localExists := registry.Modules[targetFQN[:strings.LastIndex(targetFQN, ".")]] + assert.True(t, localExists, "Should resolve to local json module in registry") +} + +// TestMixedFrameworkAndLocalCalls validates correct resolution in mixed scenarios. +func TestMixedFrameworkAndLocalCalls(t *testing.T) { + // Create temporary test directory + tmpDir := t.TempDir() + + // Create local utils + utilsFile := filepath.Join(tmpDir, "utils.py") + utilsCode := ` +def helper(): + pass +` + err := os.WriteFile(utilsFile, []byte(utilsCode), 0644) + assert.NoError(t, err) + + // Create test file with mixed calls + testFile := filepath.Join(tmpDir, "test.py") + testCode := ` +import json +from utils import helper + +def process(): + # Local call + helper() + # Framework call + data = json.loads('{}') + return data +` + err = os.WriteFile(testFile, []byte(testCode), 0644) + assert.NoError(t, err) + + // Build module registry + registry, err := BuildModuleRegistry(tmpDir) + assert.NoError(t, err) + + // Build import map + cache := NewImportMapCache() + sourceCode, err := os.ReadFile(testFile) + assert.NoError(t, err) + + importMap, err := cache.GetOrExtract(testFile, sourceCode, registry) + assert.NoError(t, err) + + modulePath, ok := registry.FileToModule[testFile] + assert.True(t, ok) + + // Test local function resolution + targetFQN, resolved := resolveCallTarget("helper", importMap, registry, modulePath) + assert.True(t, resolved, "Local helper should be resolved") + assert.Contains(t, targetFQN, "utils.helper") + + // Test framework resolution + targetFQN, resolved = resolveCallTarget("json.loads", importMap, registry, modulePath) + assert.True(t, resolved, "json.loads should be resolved as framework") + assert.Equal(t, "json.loads", targetFQN) +} diff --git a/sourcecode-parser/graph/callgraph/cache_test.go b/sourcecode-parser/graph/callgraph/cache_test.go new file mode 100644 index 00000000..82b29310 --- /dev/null +++ 
b/sourcecode-parser/graph/callgraph/cache_test.go @@ -0,0 +1,211 @@ +package callgraph + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewImportMapCache(t *testing.T) { + cache := NewImportMapCache() + assert.NotNil(t, cache) + assert.NotNil(t, cache.cache) + assert.Equal(t, 0, len(cache.cache)) +} + +func TestImportMapCache_GetEmpty(t *testing.T) { + cache := NewImportMapCache() + + importMap, ok := cache.Get("/nonexistent/file.py") + assert.False(t, ok) + assert.Nil(t, importMap) +} + +func TestImportMapCache_PutAndGet(t *testing.T) { + cache := NewImportMapCache() + filePath := "/test/file.py" + + // Create a test ImportMap + testImportMap := NewImportMap(filePath) + testImportMap.AddImport("os", "os") + testImportMap.AddImport("json", "json") + + // Put in cache + cache.Put(filePath, testImportMap) + + // Get from cache + retrieved, ok := cache.Get(filePath) + assert.True(t, ok) + assert.NotNil(t, retrieved) + assert.Equal(t, filePath, retrieved.FilePath) + assert.Equal(t, "os", retrieved.Imports["os"]) + assert.Equal(t, "json", retrieved.Imports["json"]) +} + +func TestImportMapCache_GetOrExtract_CacheHit(t *testing.T) { + cache := NewImportMapCache() + registry := NewModuleRegistry() + filePath := "/test/file.py" + + // Pre-populate cache + cachedImportMap := NewImportMap(filePath) + cachedImportMap.AddImport("cached", "cached.module") + cache.Put(filePath, cachedImportMap) + + // GetOrExtract should return cached version (sourceCode won't be used) + result, err := cache.GetOrExtract(filePath, []byte("# dummy code"), registry) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, "cached.module", result.Imports["cached"]) +} + +func TestImportMapCache_GetOrExtract_CacheMiss(t *testing.T) { + cache := NewImportMapCache() + registry := NewModuleRegistry() + filePath := "../../../test-src/python/imports_test/simple_imports.py" + + // Read test file + sourceCode, err := readFileBytes(filePath) + 
assert.NoError(t, err) + + // GetOrExtract should extract and cache + result, err := cache.GetOrExtract(filePath, sourceCode, registry) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Verify it's now in cache + cached, ok := cache.Get(filePath) + assert.True(t, ok) + assert.Equal(t, result, cached) +} + +func TestImportMapCache_Concurrent(t *testing.T) { + cache := NewImportMapCache() + registry := NewModuleRegistry() + filePath := "../../../test-src/python/imports_test/simple_imports.py" + + sourceCode, err := readFileBytes(filePath) + assert.NoError(t, err) + + // Launch multiple goroutines to access cache concurrently + const numGoroutines = 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + errors := make([]error, numGoroutines) + results := make([]*ImportMap, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(index int) { + defer wg.Done() + result, getErr := cache.GetOrExtract(filePath, sourceCode, registry) + errors[index] = getErr + results[index] = result + }(i) + } + + wg.Wait() + + // All goroutines should succeed + for i := 0; i < numGoroutines; i++ { + assert.NoError(t, errors[i], "Goroutine %d should not error", i) + assert.NotNil(t, results[i], "Goroutine %d should return a result", i) + } + + // All results should be identical (same cached instance or semantically equal) + for i := 1; i < numGoroutines; i++ { + assert.Equal(t, results[0].FilePath, results[i].FilePath) + assert.Equal(t, len(results[0].Imports), len(results[i].Imports)) + } + + // Cache should only contain one entry + assert.Equal(t, 1, len(cache.cache)) +} + +func TestImportMapCache_MultipleFiles(t *testing.T) { + cache := NewImportMapCache() + + file1 := "/test/file1.py" + file2 := "/test/file2.py" + file3 := "/test/file3.py" + + // Add multiple entries + cache.Put(file1, NewImportMap(file1)) + cache.Put(file2, NewImportMap(file2)) + cache.Put(file3, NewImportMap(file3)) + + // Verify all are cached + _, ok1 := cache.Get(file1) + _, ok2 := 
cache.Get(file2) + _, ok3 := cache.Get(file3) + + assert.True(t, ok1) + assert.True(t, ok2) + assert.True(t, ok3) + assert.Equal(t, 3, len(cache.cache)) +} + +func TestImportMapCache_OverwriteExisting(t *testing.T) { + cache := NewImportMapCache() + filePath := "/test/file.py" + + // Add first version + firstMap := NewImportMap(filePath) + firstMap.AddImport("first", "first.module") + cache.Put(filePath, firstMap) + + // Overwrite with second version + secondMap := NewImportMap(filePath) + secondMap.AddImport("second", "second.module") + cache.Put(filePath, secondMap) + + // Should have second version + result, ok := cache.Get(filePath) + assert.True(t, ok) + assert.Equal(t, "second.module", result.Imports["second"]) + assert.NotContains(t, result.Imports, "first") +} + +func BenchmarkImportMapCache_Get(b *testing.B) { + cache := NewImportMapCache() + filePath := "/test/file.py" + testMap := NewImportMap(filePath) + cache.Put(filePath, testMap) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = cache.Get(filePath) + } +} + +func BenchmarkImportMapCache_Put(b *testing.B) { + cache := NewImportMapCache() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + filePath := "/test/file.py" + testMap := NewImportMap(filePath) + cache.Put(filePath, testMap) + } +} + +func BenchmarkImportMapCache_ConcurrentGet(b *testing.B) { + cache := NewImportMapCache() + filePath := "/test/file.py" + testMap := NewImportMap(filePath) + cache.Put(filePath, testMap) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = cache.Get(filePath) + } + }) +} diff --git a/sourcecode-parser/graph/callgraph/frameworks.go b/sourcecode-parser/graph/callgraph/frameworks.go new file mode 100644 index 00000000..c4cb7007 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/frameworks.go @@ -0,0 +1,413 @@ +package callgraph + +import ( + "strings" +) + +// FrameworkDefinition represents a known external 
framework or library. +// This is used to mark calls to external code as resolved, even though +// we don't have the source code for these frameworks. +type FrameworkDefinition struct { + Name string // Display name (e.g., "Django") + Prefixes []string // Module prefixes to match (e.g., ["django.", "django"]) + Description string // Human-readable description + Category string // Category: "web", "orm", "testing", "stdlib", etc. +} + +// builtinFrameworks contains the list of known Python frameworks and libraries. +// This list focuses on the most common frameworks found in Python projects. +var builtinFrameworks = []FrameworkDefinition{ + // Web Frameworks + { + Name: "Django", + Prefixes: []string{"django."}, + Description: "Django web framework", + Category: "web", + }, + { + Name: "Django REST Framework", + Prefixes: []string{"rest_framework."}, + Description: "Django REST framework for building Web APIs", + Category: "web", + }, + { + Name: "Flask", + Prefixes: []string{"flask."}, + Description: "Flask web framework", + Category: "web", + }, + { + Name: "FastAPI", + Prefixes: []string{"fastapi."}, + Description: "FastAPI web framework", + Category: "web", + }, + { + Name: "Starlette", + Prefixes: []string{"starlette."}, + Description: "Starlette ASGI framework", + Category: "web", + }, + { + Name: "Tornado", + Prefixes: []string{"tornado."}, + Description: "Tornado web framework", + Category: "web", + }, + { + Name: "Pyramid", + Prefixes: []string{"pyramid."}, + Description: "Pyramid web framework", + Category: "web", + }, + { + Name: "Bottle", + Prefixes: []string{"bottle."}, + Description: "Bottle web framework", + Category: "web", + }, + + // ORM and Database + { + Name: "SQLAlchemy", + Prefixes: []string{"sqlalchemy."}, + Description: "SQLAlchemy ORM", + Category: "orm", + }, + { + Name: "Peewee", + Prefixes: []string{"peewee."}, + Description: "Peewee ORM", + Category: "orm", + }, + { + Name: "Tortoise ORM", + Prefixes: []string{"tortoise."}, + 
Description: "Tortoise ORM", + Category: "orm", + }, + { + Name: "Pony ORM", + Prefixes: []string{"pony."}, + Description: "Pony ORM", + Category: "orm", + }, + + // Testing Frameworks + { + Name: "pytest", + Prefixes: []string{"pytest.", "_pytest."}, + Description: "pytest testing framework", + Category: "testing", + }, + { + Name: "unittest", + Prefixes: []string{"unittest."}, + Description: "Python unittest framework", + Category: "testing", + }, + { + Name: "nose", + Prefixes: []string{"nose."}, + Description: "nose testing framework", + Category: "testing", + }, + { + Name: "mock", + Prefixes: []string{"mock.", "unittest.mock."}, + Description: "Python mock library", + Category: "testing", + }, + + // HTTP and Requests + { + Name: "requests", + Prefixes: []string{"requests."}, + Description: "HTTP library for Python", + Category: "http", + }, + { + Name: "httpx", + Prefixes: []string{"httpx."}, + Description: "Async HTTP client", + Category: "http", + }, + { + Name: "urllib3", + Prefixes: []string{"urllib3."}, + Description: "HTTP client library", + Category: "http", + }, + { + Name: "aiohttp", + Prefixes: []string{"aiohttp."}, + Description: "Async HTTP client/server", + Category: "http", + }, + + // Data Science and ML + { + Name: "numpy", + Prefixes: []string{"numpy.", "np."}, + Description: "Numerical computing library", + Category: "data_science", + }, + { + Name: "pandas", + Prefixes: []string{"pandas.", "pd."}, + Description: "Data analysis library", + Category: "data_science", + }, + { + Name: "scikit-learn", + Prefixes: []string{"sklearn.", "scikit_learn."}, + Description: "Machine learning library", + Category: "data_science", + }, + { + Name: "tensorflow", + Prefixes: []string{"tensorflow.", "tf."}, + Description: "TensorFlow ML framework", + Category: "data_science", + }, + { + Name: "pytorch", + Prefixes: []string{"torch."}, + Description: "PyTorch ML framework", + Category: "data_science", + }, + + // Async and Concurrency + { + Name: "asyncio", 
+ Prefixes: []string{"asyncio."}, + Description: "Async I/O library", + Category: "async", + }, + { + Name: "celery", + Prefixes: []string{"celery."}, + Description: "Distributed task queue", + Category: "async", + }, + + // Serialization and Data Formats + { + Name: "json", + Prefixes: []string{"json."}, + Description: "JSON encoder/decoder", + Category: "serialization", + }, + { + Name: "pickle", + Prefixes: []string{"pickle.", "_pickle."}, + Description: "Python object serialization", + Category: "serialization", + }, + { + Name: "yaml", + Prefixes: []string{"yaml.", "pyyaml."}, + Description: "YAML parser", + Category: "serialization", + }, + { + Name: "xml", + Prefixes: []string{"xml."}, + Description: "XML processing", + Category: "serialization", + }, + + // Logging and Monitoring + { + Name: "logging", + Prefixes: []string{"logging."}, + Description: "Python logging library", + Category: "logging", + }, + { + Name: "sentry", + Prefixes: []string{"sentry_sdk."}, + Description: "Sentry error tracking", + Category: "logging", + }, + + // Utilities + { + Name: "datetime", + Prefixes: []string{"datetime."}, + Description: "Date and time types", + Category: "stdlib", + }, + { + Name: "collections", + Prefixes: []string{"collections."}, + Description: "Container datatypes", + Category: "stdlib", + }, + { + Name: "itertools", + Prefixes: []string{"itertools."}, + Description: "Iterator functions", + Category: "stdlib", + }, + { + Name: "functools", + Prefixes: []string{"functools."}, + Description: "Higher-order functions", + Category: "stdlib", + }, + { + Name: "os", + Prefixes: []string{"os."}, + Description: "Operating system interfaces", + Category: "stdlib", + }, + { + Name: "sys", + Prefixes: []string{"sys."}, + Description: "System-specific parameters", + Category: "stdlib", + }, + { + Name: "pathlib", + Prefixes: []string{"pathlib."}, + Description: "Object-oriented filesystem paths", + Category: "stdlib", + }, + { + Name: "re", + Prefixes: []string{"re."}, 
+ Description: "Regular expressions", + Category: "stdlib", + }, + { + Name: "subprocess", + Prefixes: []string{"subprocess."}, + Description: "Subprocess management", + Category: "stdlib", + }, + { + Name: "threading", + Prefixes: []string{"threading."}, + Description: "Thread-based parallelism", + Category: "stdlib", + }, + { + Name: "multiprocessing", + Prefixes: []string{"multiprocessing."}, + Description: "Process-based parallelism", + Category: "stdlib", + }, + { + Name: "socket", + Prefixes: []string{"socket."}, + Description: "Low-level networking", + Category: "stdlib", + }, + { + Name: "http", + Prefixes: []string{"http."}, + Description: "HTTP modules", + Category: "stdlib", + }, + { + Name: "urllib", + Prefixes: []string{"urllib."}, + Description: "URL handling modules", + Category: "stdlib", + }, + { + Name: "email", + Prefixes: []string{"email."}, + Description: "Email and MIME handling", + Category: "stdlib", + }, + { + Name: "hashlib", + Prefixes: []string{"hashlib."}, + Description: "Secure hash and message digest", + Category: "stdlib", + }, + { + Name: "hmac", + Prefixes: []string{"hmac."}, + Description: "Keyed-hashing for message authentication", + Category: "stdlib", + }, + { + Name: "secrets", + Prefixes: []string{"secrets."}, + Description: "Generate secure random numbers", + Category: "stdlib", + }, + { + Name: "typing", + Prefixes: []string{"typing."}, + Description: "Type hints support", + Category: "stdlib", + }, + { + Name: "dataclasses", + Prefixes: []string{"dataclasses."}, + Description: "Data classes", + Category: "stdlib", + }, + { + Name: "abc", + Prefixes: []string{"abc."}, + Description: "Abstract base classes", + Category: "stdlib", + }, +} + +// LoadFrameworks returns the list of known frameworks. +// This function provides an extensibility hook for future enhancements +// where frameworks might be loaded from a configuration file. 
+func LoadFrameworks() []FrameworkDefinition { + return builtinFrameworks +} + +// IsKnownFramework checks if the given fully qualified name (FQN) +// belongs to a known external framework or standard library. +// +// Parameters: +// - fqn: fully qualified name (e.g., "django.db.models.ForeignKey") +// +// Returns: +// - true if the FQN matches any known framework +// - the matching framework definition +func IsKnownFramework(fqn string) (bool, *FrameworkDefinition) { + frameworks := LoadFrameworks() + + for i := range frameworks { + framework := &frameworks[i] + for _, prefix := range framework.Prefixes { + // Check for exact match or prefix match + if fqn == prefix || strings.HasPrefix(fqn, prefix) { + return true, framework + } + } + } + + return false, nil +} + +// GetFrameworkCategory returns the category of a framework given its FQN. +// Returns empty string if not a known framework. +func GetFrameworkCategory(fqn string) string { + isKnown, framework := IsKnownFramework(fqn) + if isKnown { + return framework.Category + } + return "" +} + +// GetFrameworkName returns the name of a framework given its FQN. +// Returns empty string if not a known framework. 
+func GetFrameworkName(fqn string) string { + isKnown, framework := IsKnownFramework(fqn) + if isKnown { + return framework.Name + } + return "" +} diff --git a/sourcecode-parser/graph/callgraph/frameworks_test.go b/sourcecode-parser/graph/callgraph/frameworks_test.go new file mode 100644 index 00000000..943edecc --- /dev/null +++ b/sourcecode-parser/graph/callgraph/frameworks_test.go @@ -0,0 +1,377 @@ +package callgraph + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsKnownFramework_Django(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "Django core", + fqn: "django.db.models.Model", + expected: true, + category: "web", + }, + { + name: "Django REST framework", + fqn: "rest_framework.serializers.ModelSerializer", + expected: true, + category: "web", + }, + { + name: "Django forms", + fqn: "django.forms.Form", + expected: true, + category: "web", + }, + { + name: "Django ORM", + fqn: "django.db.models.ForeignKey", + expected: true, + category: "web", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_Testing(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "pytest", + fqn: "pytest.fixture", + expected: true, + category: "testing", + }, + { + name: "unittest", + fqn: "unittest.TestCase", + expected: true, + category: "testing", + }, + { + name: "mock", + fqn: "unittest.mock.patch", + expected: true, + category: "testing", + }, + { + name: "_pytest internal", + fqn: "_pytest.fixtures", + expected: true, + category: "testing", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) 
+ assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_HTTP(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "requests library", + fqn: "requests.get", + expected: true, + category: "http", + }, + { + name: "httpx", + fqn: "httpx.AsyncClient", + expected: true, + category: "http", + }, + { + name: "urllib3", + fqn: "urllib3.PoolManager", + expected: true, + category: "http", + }, + { + name: "aiohttp", + fqn: "aiohttp.ClientSession", + expected: true, + category: "http", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_DataScience(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "numpy", + fqn: "numpy.array", + expected: true, + category: "data_science", + }, + { + name: "pandas", + fqn: "pandas.DataFrame", + expected: true, + category: "data_science", + }, + { + name: "sklearn", + fqn: "sklearn.ensemble.RandomForestClassifier", + expected: true, + category: "data_science", + }, + { + name: "tensorflow", + fqn: "tensorflow.keras.Model", + expected: true, + category: "data_science", + }, + { + name: "pytorch", + fqn: "torch.nn.Module", + expected: true, + category: "data_science", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_Stdlib(t *testing.T) { + tests := []struct { + name string + fqn string + 
expected bool + category string + }{ + { + name: "json", + fqn: "json.loads", + expected: true, + category: "serialization", + }, + { + name: "pickle", + fqn: "pickle.dumps", + expected: true, + category: "serialization", + }, + { + name: "logging", + fqn: "logging.getLogger", + expected: true, + category: "logging", + }, + { + name: "datetime", + fqn: "datetime.datetime", + expected: true, + category: "stdlib", + }, + { + name: "collections", + fqn: "collections.defaultdict", + expected: true, + category: "stdlib", + }, + { + name: "os", + fqn: "os.path.join", + expected: true, + category: "stdlib", + }, + { + name: "subprocess", + fqn: "subprocess.run", + expected: true, + category: "stdlib", + }, + { + name: "hashlib", + fqn: "hashlib.sha256", + expected: true, + category: "stdlib", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_NotFound(t *testing.T) { + tests := []struct { + name string + fqn string + }{ + { + name: "Custom application module", + fqn: "myapp.utils.helpers", + }, + { + name: "Custom package", + fqn: "internal.services.auth", + }, + { + name: "Unknown framework", + fqn: "unknownframework.module", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.False(t, isKnown) + assert.Nil(t, framework) + }) + } +} + +func TestGetFrameworkCategory(t *testing.T) { + tests := []struct { + name string + fqn string + expected string + }{ + { + name: "Django web framework", + fqn: "django.http.HttpResponse", + expected: "web", + }, + { + name: "pytest testing", + fqn: "pytest.mark.parametrize", + expected: "testing", + }, + { + name: "requests http", + fqn: "requests.post", + expected: "http", + }, + { + name: "Unknown 
framework", + fqn: "myapp.custom", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + category := GetFrameworkCategory(tt.fqn) + assert.Equal(t, tt.expected, category) + }) + } +} + +func TestGetFrameworkName(t *testing.T) { + tests := []struct { + name string + fqn string + expected string + }{ + { + name: "Django", + fqn: "django.contrib.auth", + expected: "Django", + }, + { + name: "Flask", + fqn: "flask.Flask", + expected: "Flask", + }, + { + name: "FastAPI", + fqn: "fastapi.FastAPI", + expected: "FastAPI", + }, + { + name: "Unknown", + fqn: "myapp.unknown", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := GetFrameworkName(tt.fqn) + assert.Equal(t, tt.expected, name) + }) + } +} + +func TestLoadFrameworks(t *testing.T) { + frameworks := LoadFrameworks() + + // Should have at least 50 frameworks defined + assert.GreaterOrEqual(t, len(frameworks), 50) + + // Check that all frameworks have required fields + for _, fw := range frameworks { + assert.NotEmpty(t, fw.Name, "Framework should have a name") + assert.NotEmpty(t, fw.Prefixes, "Framework should have at least one prefix") + assert.NotEmpty(t, fw.Category, "Framework should have a category") + assert.NotEmpty(t, fw.Description, "Framework should have a description") + } +} diff --git a/sourcecode-parser/graph/callgraph/integration.go b/sourcecode-parser/graph/callgraph/integration.go new file mode 100644 index 00000000..1c776a0f --- /dev/null +++ b/sourcecode-parser/graph/callgraph/integration.go @@ -0,0 +1,124 @@ +package callgraph + +import ( + "time" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" +) + +// InitializeCallGraph builds the call graph from a code graph. +// This integrates the 3-pass algorithm into the main initialization pipeline. +// +// Algorithm: +// 1. Build module registry from project directory +// 2. Build call graph from code graph using registry +// 3. 
Load default security patterns +// 4. Return integrated result +// +// Parameters: +// - codeGraph: the parsed code graph from Initialize() +// - projectRoot: absolute path to project root directory +// +// Returns: +// - CallGraph: complete call graph with edges and call sites +// - ModuleRegistry: module path mappings +// - PatternRegistry: loaded security patterns +// - error: if any step fails +func InitializeCallGraph(codeGraph *graph.CodeGraph, projectRoot string) (*CallGraph, *ModuleRegistry, *PatternRegistry, error) { + // Pass 1: Build module registry + startRegistry := time.Now() + registry, err := BuildModuleRegistry(projectRoot) + if err != nil { + return nil, nil, nil, err + } + elapsedRegistry := time.Since(startRegistry) + + // Pass 2-3: Build call graph (includes import extraction and call site extraction) + startCallGraph := time.Now() + callGraph, err := BuildCallGraph(codeGraph, registry, projectRoot) + if err != nil { + return nil, nil, nil, err + } + elapsedCallGraph := time.Since(startCallGraph) + + // Load security patterns + startPatterns := time.Now() + patternRegistry := NewPatternRegistry() + patternRegistry.LoadDefaultPatterns() + elapsedPatterns := time.Since(startPatterns) + + // Log timing information + graph.Log("Module registry built in:", elapsedRegistry) + graph.Log("Call graph built in:", elapsedCallGraph) + graph.Log("Patterns loaded in:", elapsedPatterns) + + return callGraph, registry, patternRegistry, nil +} + +// AnalyzePatterns runs pattern matching against the call graph. +// Returns a list of matched patterns with their details. 
+func AnalyzePatterns(callGraph *CallGraph, patternRegistry *PatternRegistry) []PatternMatch { + var matches []PatternMatch + + for _, pattern := range patternRegistry.Patterns { + details := patternRegistry.MatchPattern(pattern, callGraph) + if details != nil && details.Matched { + match := PatternMatch{ + PatternID: pattern.ID, + PatternName: pattern.Name, + Description: pattern.Description, + Severity: pattern.Severity, + CWE: pattern.CWE, + OWASP: pattern.OWASP, + SourceFQN: details.SourceFQN, + SourceCall: details.SourceCall, + SinkFQN: details.SinkFQN, + SinkCall: details.SinkCall, + DataFlowPath: details.DataFlowPath, + } + + // Lookup source function details from call graph + if sourceNode, ok := callGraph.Functions[details.SourceFQN]; ok { + match.SourceFile = sourceNode.File + match.SourceLine = sourceNode.LineNumber + match.SourceCode = sourceNode.GetCodeSnippet() + } + + // Lookup sink function details from call graph + if sinkNode, ok := callGraph.Functions[details.SinkFQN]; ok { + match.SinkFile = sinkNode.File + match.SinkLine = sinkNode.LineNumber + match.SinkCode = sinkNode.GetCodeSnippet() + } + + matches = append(matches, match) + } + } + + return matches +} + +// PatternMatch represents a detected security pattern in the code. 
// PatternMatch represents a detected security pattern in the code.
//
// It combines the static description of the pattern (ID, name, severity,
// CWE/OWASP references) with where the match was found. The file, line and
// code fields are only populated when the corresponding FQN is an indexed
// function in the call graph; otherwise they keep their zero values.
type PatternMatch struct {
	PatternID   string   // Pattern identifier
	PatternName string   // Human-readable name
	Description string   // What was detected
	Severity    Severity // Risk level
	CWE         string   // CWE identifier
	OWASP       string   // OWASP category

	// Vulnerability location details
	SourceFQN  string // Fully qualified name of the source function
	SourceCall string // The actual dangerous call (e.g., "input", "request.GET")
	SourceFile string // File path where source is located
	SourceLine uint32 // Line number of source function
	SourceCode string // Code snippet of source function

	SinkFQN  string // Fully qualified name of the sink function
	SinkCall string // The actual dangerous call (e.g., "eval", "exec")
	SinkFile string // File path where sink is located
	SinkLine uint32 // Line number of sink function
	SinkCode string // Code snippet of sink function

	DataFlowPath []string // Complete path from source to sink (FQNs)
}
+ testFile := filepath.Join(tmpDir, "test.py") + sourceCode := []byte(` +def get_input(): + return input() + +def process(): + data = get_input() + eval(data) +`) + err := os.WriteFile(testFile, sourceCode, 0644) + require.NoError(t, err) + + // Create a minimal code graph + codeGraph := &graph.CodeGraph{ + Nodes: map[string]*graph.Node{ + "node1": { + ID: "node1", + Type: "function_definition", + Name: "get_input", + File: testFile, + LineNumber: 2, + }, + "node2": { + ID: "node2", + Type: "function_definition", + Name: "process", + File: testFile, + LineNumber: 5, + }, + }, + Edges: []*graph.Edge{}, + } + + callGraph, registry, patternRegistry, err := InitializeCallGraph(codeGraph, tmpDir) + + require.NoError(t, err) + assert.NotNil(t, callGraph) + assert.NotNil(t, registry) + assert.NotNil(t, patternRegistry) + + // Verify module registry was built + assert.NotEmpty(t, registry.Modules) + + // Verify functions were indexed + assert.NotEmpty(t, callGraph.Functions) + + // Verify patterns were loaded + assert.NotEmpty(t, patternRegistry.Patterns) +} + +func TestAnalyzePatterns_NoMatches(t *testing.T) { + // Create call graph with safe functions + callGraph := NewCallGraph() + callGraph.AddCallSite("myapp.safe_function", CallSite{ + Target: "print", + TargetFQN: "builtins.print", + }) + + patternRegistry := NewPatternRegistry() + patternRegistry.LoadDefaultPatterns() + + matches := AnalyzePatterns(callGraph, patternRegistry) + + assert.Empty(t, matches) +} + +func TestAnalyzePatterns_WithMatch(t *testing.T) { + // Create call graph that matches code injection pattern + callGraph := NewCallGraph() + + // Source: get_input() calls input() + callGraph.AddCallSite("myapp.get_input", CallSite{ + Target: "input", + TargetFQN: "builtins.input", + }) + + // Sink: process() calls eval() + callGraph.AddCallSite("myapp.process", CallSite{ + Target: "eval", + TargetFQN: "builtins.eval", + }) + + // Create path from source to sink + callGraph.AddEdge("myapp.get_input", 
"myapp.process") + + patternRegistry := NewPatternRegistry() + patternRegistry.LoadDefaultPatterns() + + matches := AnalyzePatterns(callGraph, patternRegistry) + + require.Len(t, matches, 1) + assert.Equal(t, "CODE-INJECTION-001", matches[0].PatternID) + assert.Equal(t, "Code injection via eval with user input", matches[0].PatternName) + assert.Equal(t, SeverityCritical, matches[0].Severity) + assert.Equal(t, "CWE-94", matches[0].CWE) +} + +func TestAnalyzePatterns_WithSanitizer(t *testing.T) { + // Create call graph with sanitizer in the path + callGraph := NewCallGraph() + + // Source: get_input() calls input() + callGraph.AddCallSite("myapp.get_input", CallSite{ + Target: "input", + TargetFQN: "builtins.input", + }) + + // Sanitizer: sanitize_input() calls sanitize() + callGraph.AddCallSite("myapp.sanitize_input", CallSite{ + Target: "sanitize", + TargetFQN: "myapp.utils.sanitize", + }) + + // Sink: process() calls eval() + callGraph.AddCallSite("myapp.process", CallSite{ + Target: "eval", + TargetFQN: "builtins.eval", + }) + + // Create path with sanitizer: source -> sanitizer -> sink + callGraph.AddEdge("myapp.get_input", "myapp.sanitize_input") + callGraph.AddEdge("myapp.sanitize_input", "myapp.process") + + patternRegistry := NewPatternRegistry() + patternRegistry.LoadDefaultPatterns() + + matches := AnalyzePatterns(callGraph, patternRegistry) + + // Should not match because sanitizer is present + assert.Empty(t, matches) +} + +func TestPatternMatch_Structure(t *testing.T) { + match := PatternMatch{ + PatternID: "TEST-001", + PatternName: "Test Pattern", + Description: "Test description", + Severity: SeverityHigh, + CWE: "CWE-123", + OWASP: "A01:2021-Test", + } + + assert.Equal(t, "TEST-001", match.PatternID) + assert.Equal(t, "Test Pattern", match.PatternName) + assert.Equal(t, "Test description", match.Description) + assert.Equal(t, SeverityHigh, match.Severity) + assert.Equal(t, "CWE-123", match.CWE) + assert.Equal(t, "A01:2021-Test", match.OWASP) +} + 
+func TestInitializeCallGraph_Integration(t *testing.T) { + // End-to-end integration test + tmpDir := t.TempDir() + + // Create a Python package structure + utilsDir := filepath.Join(tmpDir, "utils") + err := os.MkdirAll(utilsDir, 0755) + require.NoError(t, err) + + // Create utils/helpers.py + helpersFile := filepath.Join(utilsDir, "helpers.py") + helpersCode := []byte(` +def sanitize(data): + return data.strip() +`) + err = os.WriteFile(helpersFile, helpersCode, 0644) + require.NoError(t, err) + + // Create main.py + mainFile := filepath.Join(tmpDir, "main.py") + mainCode := []byte(` +from utils.helpers import sanitize + +def get_input(): + return input() + +def process(): + data = get_input() + clean_data = sanitize(data) + eval(clean_data) +`) + err = os.WriteFile(mainFile, mainCode, 0644) + require.NoError(t, err) + + // Create code graph + codeGraph := &graph.CodeGraph{ + Nodes: map[string]*graph.Node{ + "node1": { + ID: "node1", + Type: "function_definition", + Name: "sanitize", + File: helpersFile, + LineNumber: 2, + }, + "node2": { + ID: "node2", + Type: "function_definition", + Name: "get_input", + File: mainFile, + LineNumber: 4, + }, + "node3": { + ID: "node3", + Type: "function_definition", + Name: "process", + File: mainFile, + LineNumber: 7, + }, + }, + Edges: []*graph.Edge{}, + } + + // Initialize call graph + callGraph, registry, patternRegistry, err := InitializeCallGraph(codeGraph, tmpDir) + + require.NoError(t, err) + assert.NotNil(t, callGraph) + assert.NotNil(t, registry) + assert.NotNil(t, patternRegistry) + + // Verify modules were registered + assert.Contains(t, registry.Modules, "utils.helpers") + assert.Contains(t, registry.Modules, "main") + + // Verify functions were indexed + assert.NotEmpty(t, callGraph.Functions) +} diff --git a/sourcecode-parser/graph/callgraph/patterns.go b/sourcecode-parser/graph/callgraph/patterns.go index d0d49e09..5706a687 100644 --- a/sourcecode-parser/graph/callgraph/patterns.go +++ 
// PatternRegistry manages security patterns.
// It keeps two indexes over the patterns: one keyed by pattern ID and one
// grouping patterns by their PatternType.
type PatternRegistry struct {
	Patterns       map[string]*Pattern        // Pattern ID -> Pattern
	PatternsByType map[PatternType][]*Pattern // Type -> Patterns
}
+type PatternMatchDetails struct { + Matched bool + SourceFQN string // Fully qualified name of function containing the source call + SourceCall string // The actual dangerous call (e.g., "input", "request.GET") + SinkFQN string // Fully qualified name of function containing the sink call + SinkCall string // The actual dangerous call (e.g., "eval", "exec") + DataFlowPath []string // Complete path from source to sink +} + // matchDangerousFunction checks if any dangerous function is called. -func (pr *PatternRegistry) matchDangerousFunction(pattern *Pattern, callGraph *CallGraph) bool { - for _, callSites := range callGraph.CallSites { +func (pr *PatternRegistry) matchDangerousFunction(pattern *Pattern, callGraph *CallGraph) *PatternMatchDetails { + for caller, callSites := range callGraph.CallSites { for _, callSite := range callSites { for _, dangerousFunc := range pattern.DangerousFunctions { if matchesFunctionName(callSite.TargetFQN, dangerousFunc) || matchesFunctionName(callSite.Target, dangerousFunc) { - return true + return &PatternMatchDetails{ + Matched: true, + SourceFQN: caller, + SinkFQN: callSite.TargetFQN, + DataFlowPath: []string{caller, callSite.TargetFQN}, + } } } } } - return false + return &PatternMatchDetails{Matched: false} } // matchSourceSink checks if there's a path from source to sink. 
-func (pr *PatternRegistry) matchSourceSink(pattern *Pattern, callGraph *CallGraph) bool { +func (pr *PatternRegistry) matchSourceSink(pattern *Pattern, callGraph *CallGraph) *PatternMatchDetails { sourceCalls := pr.findCallsByFunctions(pattern.Sources, callGraph) if len(sourceCalls) == 0 { - return false + return &PatternMatchDetails{Matched: false} } sinkCalls := pr.findCallsByFunctions(pattern.Sinks, callGraph) if len(sinkCalls) == 0 { - return false + return &PatternMatchDetails{Matched: false} } for _, source := range sourceCalls { for _, sink := range sinkCalls { - if pr.hasPath(source.caller, sink.caller, callGraph) { - return true + path := pr.findPath(source.caller, sink.caller, callGraph) + if len(path) > 0 { + return &PatternMatchDetails{ + Matched: true, + SourceFQN: source.caller, + SinkFQN: sink.caller, + DataFlowPath: path, + } } } } - return false + return &PatternMatchDetails{Matched: false} } // matchMissingSanitizer checks if there's a path from source to sink without sanitization. 
-func (pr *PatternRegistry) matchMissingSanitizer(pattern *Pattern, callGraph *CallGraph) bool { +func (pr *PatternRegistry) matchMissingSanitizer(pattern *Pattern, callGraph *CallGraph) *PatternMatchDetails { sourceCalls := pr.findCallsByFunctions(pattern.Sources, callGraph) if len(sourceCalls) == 0 { - return false + return &PatternMatchDetails{Matched: false} } sinkCalls := pr.findCallsByFunctions(pattern.Sinks, callGraph) if len(sinkCalls) == 0 { - return false + return &PatternMatchDetails{Matched: false} } sanitizerCalls := pr.findCallsByFunctions(pattern.Sanitizers, callGraph) + // Sort for deterministic results + sortCallInfo(sourceCalls) + sortCallInfo(sinkCalls) + for _, source := range sourceCalls { for _, sink := range sinkCalls { - if pr.hasPath(source.caller, sink.caller, callGraph) { + // Skip false positives where source and sink are in the same function + if source.caller == sink.caller { + continue + } + + path := pr.findPath(source.caller, sink.caller, callGraph) + if len(path) > 1 { // Require at least 2 functions in path + // Check if any sanitizer is on the path hasSanitizer := false for _, sanitizer := range sanitizerCalls { - if pr.hasPath(source.caller, sanitizer.caller, callGraph) && - pr.hasPath(sanitizer.caller, sink.caller, callGraph) { - hasSanitizer = true + // Check if sanitizer is in the path + for _, pathFunc := range path { + if pathFunc == sanitizer.caller { + hasSanitizer = true + break + } + } + if hasSanitizer { break } } if !hasSanitizer { - return true + return &PatternMatchDetails{ + Matched: true, + SourceFQN: source.caller, + SourceCall: source.target, + SinkFQN: sink.caller, + SinkCall: sink.target, + DataFlowPath: path, + } } } } } - return false + return &PatternMatchDetails{Matched: false} } // callInfo stores information about a function call location. 
@@ -242,20 +287,95 @@ func (pr *PatternRegistry) dfsPath(current, target string, callGraph *CallGraph, return false } +// findPath finds the complete path from source to sink in the call graph. +// Returns the path as a slice of function FQNs, or empty slice if no path exists. +func (pr *PatternRegistry) findPath(from, to string, callGraph *CallGraph) []string { + if from == to { + return []string{from} + } + + visited := make(map[string]bool) + path := make([]string, 0) + + if pr.dfsPathWithTrace(from, to, callGraph, visited, &path) { + return path + } + + return []string{} +} + +// dfsPathWithTrace performs depth-first search and captures the path. +func (pr *PatternRegistry) dfsPathWithTrace(current, target string, callGraph *CallGraph, visited map[string]bool, path *[]string) bool { + *path = append(*path, current) + + if current == target { + return true + } + + if visited[current] { + *path = (*path)[:len(*path)-1] // backtrack + return false + } + + visited[current] = true + + callees := callGraph.GetCallees(current) + for _, callee := range callees { + if pr.dfsPathWithTrace(callee, target, callGraph, visited, path) { + return true + } + } + + // Backtrack if no path found + *path = (*path)[:len(*path)-1] + return false +} + +// sortCallInfo sorts callInfo slices by caller FQN for deterministic results. +func sortCallInfo(calls []callInfo) { + // Simple bubble sort - good enough for small slices + for i := 0; i < len(calls); i++ { + for j := i + 1; j < len(calls); j++ { + if calls[i].caller > calls[j].caller { + calls[i], calls[j] = calls[j], calls[i] + } + } + } +} + // matchesFunctionName checks if a function name matches a pattern. -// Supports exact matches and suffix matches. +// Supports exact matches, suffix matches, and prefix matches. 
// matchesFunctionName checks if a function name matches a pattern.
// Supports exact matches, suffix matches, and prefix matches, always on
// dot-separated component boundaries.
// Examples:
//   - "eval" matches pattern "eval" (exact match)
//   - "builtins.eval" matches pattern "eval" (suffix match)
//   - "vulnerable_app.eval" matches pattern "eval" (suffix match on the last component)
//   - "request.GET.get" matches pattern "request.GET" (prefix match for sources)
//   - "executor" and "evaluation" do NOT match "exec" / "eval"
//
// An empty pattern never matches. Without that guard an empty pattern would
// match an empty fqn, which is what an unresolved call's FQN looks like.
func matchesFunctionName(fqn, pattern string) bool {
	if pattern == "" {
		return false
	}

	switch {
	case fqn == pattern:
		// Exact match: "eval" == "eval"
		return true
	case strings.HasSuffix(fqn, "."+pattern):
		// Suffix match: "builtins.eval" ends with ".eval". This also covers
		// the "last component equals pattern" case, so no separate check is
		// needed for it.
		return true
	case strings.HasPrefix(fqn, pattern+"."):
		// Prefix match: "request.GET.get" starts with "request.GET."
		// This handles attribute access chains for sources.
		return true
	default:
		return false
	}
}
assert.NotNil(t, matched) + assert.True(t, matched.Matched) } func TestPatternRegistry_MatchDangerousFunction_NoMatch(t *testing.T) { @@ -135,7 +136,12 @@ func TestPatternRegistry_MatchDangerousFunction_NoMatch(t *testing.T) { }) matched := registry.MatchPattern(pattern, callGraph) - assert.False(t, matched) + if matched == nil || !matched.Matched { + // Pattern didn't match (expected) + assert.True(t, true) + } else { + assert.Fail(t, "Expected no match but found one") + } } func TestPatternRegistry_MatchSourceSink(t *testing.T) { @@ -165,7 +171,8 @@ func TestPatternRegistry_MatchSourceSink(t *testing.T) { callGraph.AddEdge("myapp.process", "myapp.execute_code") matched := registry.MatchPattern(pattern, callGraph) - assert.True(t, matched) + assert.NotNil(t, matched) + assert.True(t, matched.Matched) } func TestPatternRegistry_MatchMissingSanitizer_WithSanitizer(t *testing.T) { @@ -200,7 +207,12 @@ func TestPatternRegistry_MatchMissingSanitizer_WithSanitizer(t *testing.T) { callGraph.AddEdge("myapp.sanitize_input", "myapp.execute_code") matched := registry.MatchPattern(pattern, callGraph) - assert.False(t, matched) // Should not match because sanitizer is present + if matched == nil || !matched.Matched { + // Pattern didn't match (expected) + assert.True(t, true) + } else { + assert.Fail(t, "Expected no match but found one") + } // Should not match because sanitizer is present } func TestPatternRegistry_MatchMissingSanitizer_WithoutSanitizer(t *testing.T) { @@ -229,7 +241,8 @@ func TestPatternRegistry_MatchMissingSanitizer_WithoutSanitizer(t *testing.T) { callGraph.AddEdge("myapp.get_input", "myapp.execute_code") matched := registry.MatchPattern(pattern, callGraph) - assert.True(t, matched) // Should match because sanitizer is missing + assert.NotNil(t, matched) + assert.True(t, matched.Matched) // Should match because sanitizer is missing } func TestPatternRegistry_HasPath(t *testing.T) { diff --git a/sourcecode-parser/graph/callgraph/types.go 
// CallSite represents a function call location in the source code.
// It captures both the syntactic information (where the call is) and
// semantic information (what is being called and with what arguments).
type CallSite struct {
	Target    string     // The name of the function being called (e.g., "eval", "utils.sanitize")
	Location  Location   // Where this call occurs in the source code
	Arguments []Argument // Arguments passed to the call
	Resolved  bool       // Whether we successfully resolved this call to a definition
	TargetFQN string     // Fully qualified name after resolution (e.g., "myapp.utils.sanitize")

	// FailureReason says why resolution failed (empty if Resolved=true).
	// Categories used for diagnostics: "external_framework", "orm_pattern",
	// "attribute_chain", "variable_method", "super_call", "not_in_imports",
	// and "unknown".
	FailureReason string
}
type Argument struct { diff --git a/sourcecode-parser/main_test.go b/sourcecode-parser/main_test.go index 5a072f0e..fc27616f 100644 --- a/sourcecode-parser/main_test.go +++ b/sourcecode-parser/main_test.go @@ -23,7 +23,7 @@ func TestExecute(t *testing.T) { { name: "Successful execution", mockExecuteErr: nil, - expectedOutput: "Code Pathfinder is designed for identifying vulnerabilities in source code.\n\nUsage:\n pathfinder [command]\n\nAvailable Commands:\n ci Scan a project for vulnerabilities with ruleset in ci mode\n completion Generate the autocompletion script for the specified shell\n help Help about any command\n query Execute queries on the source code\n scan Scan a project for vulnerabilities with ruleset\n version Print the version and commit information\n\nFlags:\n --disable-metrics Disable metrics collection\n -h, --help help for pathfinder\n --verbose Verbose output\n\nUse \"pathfinder [command] --help\" for more information about a command.\n", + expectedOutput: "Code Pathfinder is designed for identifying vulnerabilities in source code.\n\nUsage:\n pathfinder [command]\n\nAvailable Commands:\n analyze Analyze source code for security vulnerabilities using call graph\n ci Scan a project for vulnerabilities with ruleset in ci mode\n completion Generate the autocompletion script for the specified shell\n help Help about any command\n query Execute queries on the source code\n resolution-report Generate a diagnostic report on call resolution statistics\n scan Scan a project for vulnerabilities with ruleset\n version Print the version and commit information\n\nFlags:\n --disable-metrics Disable metrics collection\n -h, --help help for pathfinder\n --verbose Verbose output\n\nUse \"pathfinder [command] --help\" for more information about a command.\n", expectedExit: 0, }, }