// diff --git a/sourcecode-parser/graph/callgraph/statement.go b/sourcecode-parser/graph/callgraph/statement.go
// new file mode 100644
// index 00000000..1205bf60
// --- /dev/null
// +++ b/sourcecode-parser/graph/callgraph/statement.go
// @@ -0,0 +1,207 @@

package callgraph

// StatementType represents the type of statement in the code.
type StatementType string

const (
	// StatementTypeAssignment represents variable assignments: x = expr.
	StatementTypeAssignment StatementType = "assignment"

	// StatementTypeCall represents function/method calls: foo(), obj.method().
	StatementTypeCall StatementType = "call"

	// StatementTypeReturn represents return statements: return expr.
	StatementTypeReturn StatementType = "return"

	// StatementTypeIf represents conditional statements: if condition: ...
	StatementTypeIf StatementType = "if"

	// StatementTypeFor represents loop statements: for x in iterable: ...
	StatementTypeFor StatementType = "for"

	// StatementTypeWhile represents while loop statements: while condition: ...
	StatementTypeWhile StatementType = "while"

	// StatementTypeWith represents context manager statements: with expr as var: ...
	StatementTypeWith StatementType = "with"

	// StatementTypeTry represents exception handling: try: ... except: ...
	StatementTypeTry StatementType = "try"

	// StatementTypeRaise represents exception raising: raise Exception().
	StatementTypeRaise StatementType = "raise"

	// StatementTypeImport represents import statements: import module, from module import name.
	StatementTypeImport StatementType = "import"

	// StatementTypeExpression represents expression statements (calls, attribute access, etc.).
	StatementTypeExpression StatementType = "expression"
)

// Statement represents a single statement in the code with def-use information.
type Statement struct {
	// Type is the kind of statement (assignment, call, return, etc.)
	Type StatementType

	// LineNumber is the source line number for this statement (1-indexed)
	LineNumber uint32

	// Def is the variable being defined by this statement (if any)
	// For assignments: the left-hand side variable
	// For for loops: the loop variable
	// For with statements: the as variable
	// Empty string if no definition
	Def string

	// Uses is the list of variables used/read by this statement
	// For assignments: variables in the right-hand side expression
	// For calls: variables used in arguments
	// For conditions: variables in the condition expression
	Uses []string

	// CallTarget is the function/method being called (if Type == StatementTypeCall)
	// Format: "function_name" for direct calls, "obj.method" for method calls
	// Empty string for non-call statements
	CallTarget string

	// CallArgs are the argument variables passed to the call (if Type == StatementTypeCall)
	// Only includes variable names, not literals
	CallArgs []string

	// NestedStatements contains statements inside this statement's body
	// Used for if/for/while/with/try blocks
	// Empty for simple statements like assignments
	NestedStatements []*Statement

	// ElseBranch contains statements in the else branch (if applicable)
	// Used for if/try statements
	ElseBranch []*Statement
}

// GetDef returns the variable defined by this statement, or empty string if none.
func (s *Statement) GetDef() string {
	return s.Def
}

// GetUses returns the list of variables used by this statement.
// The returned slice is the statement's own Uses slice, not a copy.
func (s *Statement) GetUses() []string {
	return s.Uses
}

// IsCall returns true if this statement is treated as a function/method call.
// Both StatementTypeCall and StatementTypeExpression statements qualify.
func (s *Statement) IsCall() bool {
	return s.Type == StatementTypeCall || s.Type == StatementTypeExpression
}

// IsAssignment returns true if this statement is a variable assignment.
func (s *Statement) IsAssignment() bool {
	return s.Type == StatementTypeAssignment
}

// IsControlFlow returns true if this statement is a control flow construct
// (if, for, while, with or try).
func (s *Statement) IsControlFlow() bool {
	switch s.Type {
	case StatementTypeIf, StatementTypeFor, StatementTypeWhile, StatementTypeWith, StatementTypeTry:
		return true
	default:
		return false
	}
}

// HasNestedStatements returns true if this statement contains nested statements,
// either in its body (NestedStatements) or in its else branch (ElseBranch).
func (s *Statement) HasNestedStatements() bool {
	return len(s.NestedStatements) > 0 || len(s.ElseBranch) > 0
}

// AllStatements returns a flattened list of this statement and all nested statements.
// Performs depth-first (pre-order) traversal: the statement itself first, then
// everything under NestedStatements, then everything under ElseBranch.
// Nil entries in either list are skipped, and a nil receiver yields nil.
func (s *Statement) AllStatements() []*Statement {
	if s == nil {
		return nil
	}

	// One shared accumulator avoids re-copying each subtree's result at every
	// level of the recursion.
	all := make([]*Statement, 0, 1+len(s.NestedStatements)+len(s.ElseBranch))

	var walk func(stmt *Statement)
	walk = func(stmt *Statement) {
		if stmt == nil {
			return
		}
		all = append(all, stmt)
		for _, child := range stmt.NestedStatements {
			walk(child)
		}
		for _, child := range stmt.ElseBranch {
			walk(child)
		}
	}
	walk(s)

	return all
}

// DefUseChain represents the def-use relationships for all variables in a function.
type DefUseChain struct {
	// Defs maps variable names to all statements that define them.
	// A variable can have multiple definitions across different code paths.
	Defs map[string][]*Statement

	// Uses maps variable names to all statements that use them.
	// A variable can be used in multiple places.
	Uses map[string][]*Statement
}

// NewDefUseChain creates an empty def-use chain.
func NewDefUseChain() *DefUseChain {
	return &DefUseChain{
		Defs: make(map[string][]*Statement),
		Uses: make(map[string][]*Statement),
	}
}

// AddDef registers a statement as defining a variable.
// Empty variable names and nil statements are ignored.
func (chain *DefUseChain) AddDef(varName string, stmt *Statement) {
	if varName == "" || stmt == nil {
		return
	}
	if chain.Defs == nil {
		// Writing to a nil map panics; allocate lazily so a zero-value
		// DefUseChain (not built with NewDefUseChain) is usable.
		chain.Defs = make(map[string][]*Statement)
	}
	chain.Defs[varName] = append(chain.Defs[varName], stmt)
}

// AddUse registers a statement as using a variable.
// Empty variable names and nil statements are ignored.
func (chain *DefUseChain) AddUse(varName string, stmt *Statement) {
	if varName == "" || stmt == nil {
		return
	}
	if chain.Uses == nil {
		// Same lazy allocation as AddDef, for zero-value chains.
		chain.Uses = make(map[string][]*Statement)
	}
	chain.Uses[varName] = append(chain.Uses[varName], stmt)
}

// GetDefs returns all statements that define a given variable.
// Returns nil if the variable is never defined.
func (chain *DefUseChain) GetDefs(varName string) []*Statement {
	return chain.Defs[varName]
}

// GetUses returns all statements that use a given variable.
// Returns nil if the variable is never used.
func (chain *DefUseChain) GetUses(varName string) []*Statement {
	return chain.Uses[varName]
}

// IsDefined returns true if the variable has at least one definition.
func (chain *DefUseChain) IsDefined(varName string) bool {
	return len(chain.Defs[varName]) > 0
}

// IsUsed returns true if the variable has at least one use.
func (chain *DefUseChain) IsUsed(varName string) bool {
	return len(chain.Uses[varName]) > 0
}

// AllVariables returns a list of all variable names in the def-use chain.
+func (chain *DefUseChain) AllVariables() []string { + varSet := make(map[string]bool) + + for varName := range chain.Defs { + varSet[varName] = true + } + + for varName := range chain.Uses { + varSet[varName] = true + } + + result := make([]string, 0, len(varSet)) + for varName := range varSet { + result = append(result, varName) + } + + return result +} diff --git a/sourcecode-parser/graph/callgraph/statement_test.go b/sourcecode-parser/graph/callgraph/statement_test.go new file mode 100644 index 00000000..0c8a64b4 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/statement_test.go @@ -0,0 +1,525 @@ +package callgraph + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStatementGetDef(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected string + }{ + { + name: "assignment with def", + stmt: &Statement{ + Type: StatementTypeAssignment, + Def: "x", + }, + expected: "x", + }, + { + name: "call without def", + stmt: &Statement{ + Type: StatementTypeCall, + Def: "", + }, + expected: "", + }, + { + name: "for loop with def", + stmt: &Statement{ + Type: StatementTypeFor, + Def: "item", + }, + expected: "item", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.GetDef() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementGetUses(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected []string + }{ + { + name: "assignment with uses", + stmt: &Statement{ + Type: StatementTypeAssignment, + Uses: []string{"a", "b"}, + }, + expected: []string{"a", "b"}, + }, + { + name: "call with no uses", + stmt: &Statement{ + Type: StatementTypeCall, + Uses: []string{}, + }, + expected: []string{}, + }, + { + name: "if statement with condition uses", + stmt: &Statement{ + Type: StatementTypeIf, + Uses: []string{"flag", "count"}, + }, + expected: []string{"flag", "count"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + result := tt.stmt.GetUses() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementIsCall(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "call statement", + stmt: &Statement{Type: StatementTypeCall}, + expected: true, + }, + { + name: "expression statement", + stmt: &Statement{Type: StatementTypeExpression}, + expected: true, + }, + { + name: "assignment statement", + stmt: &Statement{Type: StatementTypeAssignment}, + expected: false, + }, + { + name: "return statement", + stmt: &Statement{Type: StatementTypeReturn}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.IsCall() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementIsAssignment(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "assignment statement", + stmt: &Statement{Type: StatementTypeAssignment}, + expected: true, + }, + { + name: "call statement", + stmt: &Statement{Type: StatementTypeCall}, + expected: false, + }, + { + name: "for statement", + stmt: &Statement{Type: StatementTypeFor}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.IsAssignment() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementIsControlFlow(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "if statement", + stmt: &Statement{Type: StatementTypeIf}, + expected: true, + }, + { + name: "for statement", + stmt: &Statement{Type: StatementTypeFor}, + expected: true, + }, + { + name: "while statement", + stmt: &Statement{Type: StatementTypeWhile}, + expected: true, + }, + { + name: "with statement", + stmt: &Statement{Type: StatementTypeWith}, + expected: true, + }, + { + name: "try statement", + stmt: &Statement{Type: StatementTypeTry}, + expected: true, + }, + { + 
name: "assignment statement", + stmt: &Statement{Type: StatementTypeAssignment}, + expected: false, + }, + { + name: "call statement", + stmt: &Statement{Type: StatementTypeCall}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.IsControlFlow() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementHasNestedStatements(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "if with nested statements", + stmt: &Statement{ + Type: StatementTypeIf, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment}, + }, + }, + expected: true, + }, + { + name: "if with else branch", + stmt: &Statement{ + Type: StatementTypeIf, + ElseBranch: []*Statement{ + {Type: StatementTypeReturn}, + }, + }, + expected: true, + }, + { + name: "simple assignment", + stmt: &Statement{ + Type: StatementTypeAssignment, + }, + expected: false, + }, + { + name: "if with both nested and else", + stmt: &Statement{ + Type: StatementTypeIf, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment}, + }, + ElseBranch: []*Statement{ + {Type: StatementTypeReturn}, + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.HasNestedStatements() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementAllStatements(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expectedCount int + }{ + { + name: "simple statement", + stmt: &Statement{ + Type: StatementTypeAssignment, + LineNumber: 1, + }, + expectedCount: 1, + }, + { + name: "if with one nested statement", + stmt: &Statement{ + Type: StatementTypeIf, + LineNumber: 1, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment, LineNumber: 2}, + }, + }, + expectedCount: 2, + }, + { + name: "if with nested and else", + stmt: &Statement{ + Type: StatementTypeIf, + LineNumber: 1, + NestedStatements: 
[]*Statement{ + {Type: StatementTypeAssignment, LineNumber: 2}, + {Type: StatementTypeCall, LineNumber: 3}, + }, + ElseBranch: []*Statement{ + {Type: StatementTypeReturn, LineNumber: 5}, + }, + }, + expectedCount: 4, + }, + { + name: "deeply nested statements", + stmt: &Statement{ + Type: StatementTypeIf, + LineNumber: 1, + NestedStatements: []*Statement{ + { + Type: StatementTypeFor, + LineNumber: 2, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment, LineNumber: 3}, + {Type: StatementTypeCall, LineNumber: 4}, + }, + }, + }, + }, + expectedCount: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.AllStatements() + assert.Equal(t, tt.expectedCount, len(result)) + + // Verify first statement is always the root + assert.Equal(t, tt.stmt, result[0]) + }) + } +} + +func TestNewDefUseChain(t *testing.T) { + chain := NewDefUseChain() + + assert.NotNil(t, chain) + assert.NotNil(t, chain.Defs) + assert.NotNil(t, chain.Uses) + assert.Equal(t, 0, len(chain.Defs)) + assert.Equal(t, 0, len(chain.Uses)) +} + +func TestDefUseChainAddDef(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Def: "x"} + + // Add first definition + chain.AddDef("x", stmt1) + assert.Equal(t, 1, len(chain.Defs["x"])) + assert.Equal(t, stmt1, chain.Defs["x"][0]) + + // Add second definition of same variable + chain.AddDef("x", stmt2) + assert.Equal(t, 2, len(chain.Defs["x"])) + assert.Equal(t, stmt2, chain.Defs["x"][1]) + + // Add definition for different variable + stmt3 := &Statement{Type: StatementTypeAssignment, LineNumber: 3, Def: "y"} + chain.AddDef("y", stmt3) + assert.Equal(t, 1, len(chain.Defs["y"])) + assert.Equal(t, stmt3, chain.Defs["y"][0]) + + // Test empty variable name (should be ignored) + chain.AddDef("", stmt1) + _, exists := chain.Defs[""] + assert.False(t, exists) +} + +func 
TestDefUseChainAddUse(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeCall, LineNumber: 1, Uses: []string{"x"}} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Uses: []string{"x"}} + + // Add first use + chain.AddUse("x", stmt1) + assert.Equal(t, 1, len(chain.Uses["x"])) + assert.Equal(t, stmt1, chain.Uses["x"][0]) + + // Add second use of same variable + chain.AddUse("x", stmt2) + assert.Equal(t, 2, len(chain.Uses["x"])) + assert.Equal(t, stmt2, chain.Uses["x"][1]) + + // Add use for different variable + stmt3 := &Statement{Type: StatementTypeReturn, LineNumber: 3, Uses: []string{"y"}} + chain.AddUse("y", stmt3) + assert.Equal(t, 1, len(chain.Uses["y"])) + assert.Equal(t, stmt3, chain.Uses["y"][0]) + + // Test empty variable name (should be ignored) + chain.AddUse("", stmt1) + _, exists := chain.Uses[""] + assert.False(t, exists) +} + +func TestDefUseChainGetDefs(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Def: "x"} + + chain.AddDef("x", stmt1) + chain.AddDef("x", stmt2) + + defs := chain.GetDefs("x") + assert.Equal(t, 2, len(defs)) + assert.Equal(t, stmt1, defs[0]) + assert.Equal(t, stmt2, defs[1]) + + // Test non-existent variable + nonExistent := chain.GetDefs("nonexistent") + assert.Nil(t, nonExistent) +} + +func TestDefUseChainGetUses(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeCall, LineNumber: 1, Uses: []string{"x"}} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Uses: []string{"x"}} + + chain.AddUse("x", stmt1) + chain.AddUse("x", stmt2) + + uses := chain.GetUses("x") + assert.Equal(t, 2, len(uses)) + assert.Equal(t, stmt1, uses[0]) + assert.Equal(t, stmt2, uses[1]) + + // Test non-existent variable + nonExistent := chain.GetUses("nonexistent") + assert.Nil(t, nonExistent) +} + +func 
TestDefUseChainIsDefined(t *testing.T) { + chain := NewDefUseChain() + stmt := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + + assert.False(t, chain.IsDefined("x")) + + chain.AddDef("x", stmt) + assert.True(t, chain.IsDefined("x")) + assert.False(t, chain.IsDefined("y")) +} + +func TestDefUseChainIsUsed(t *testing.T) { + chain := NewDefUseChain() + stmt := &Statement{Type: StatementTypeCall, LineNumber: 1, Uses: []string{"x"}} + + assert.False(t, chain.IsUsed("x")) + + chain.AddUse("x", stmt) + assert.True(t, chain.IsUsed("x")) + assert.False(t, chain.IsUsed("y")) +} + +func TestDefUseChainAllVariables(t *testing.T) { + chain := NewDefUseChain() + + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeCall, LineNumber: 2, Uses: []string{"y"}} + stmt3 := &Statement{Type: StatementTypeAssignment, LineNumber: 3, Def: "z", Uses: []string{"x"}} + + chain.AddDef("x", stmt1) + chain.AddUse("y", stmt2) + chain.AddDef("z", stmt3) + chain.AddUse("x", stmt3) + + vars := chain.AllVariables() + assert.Equal(t, 3, len(vars)) + + // Create a set to check presence + varSet := make(map[string]bool) + for _, v := range vars { + varSet[v] = true + } + + assert.True(t, varSet["x"]) + assert.True(t, varSet["y"]) + assert.True(t, varSet["z"]) +} + +func TestDefUseChainComplexScenario(t *testing.T) { + // Simulate a real code scenario: + // 1: x = 5 + // 2: y = x + 10 + // 3: if y > 15: + // 4: z = x * 2 + // 5: print(z) + + chain := NewDefUseChain() + + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Def: "y", Uses: []string{"x"}} + stmt3 := &Statement{Type: StatementTypeIf, LineNumber: 3, Uses: []string{"y"}} + stmt4 := &Statement{Type: StatementTypeAssignment, LineNumber: 4, Def: "z", Uses: []string{"x"}} + stmt5 := &Statement{Type: StatementTypeCall, LineNumber: 5, Uses: []string{"z"}} + + 
chain.AddDef("x", stmt1) + + chain.AddDef("y", stmt2) + chain.AddUse("x", stmt2) + + chain.AddUse("y", stmt3) + + chain.AddDef("z", stmt4) + chain.AddUse("x", stmt4) + + chain.AddUse("z", stmt5) + + // Verify x: 1 def, 2 uses + assert.Equal(t, 1, len(chain.GetDefs("x"))) + assert.Equal(t, 2, len(chain.GetUses("x"))) + + // Verify y: 1 def, 1 use + assert.Equal(t, 1, len(chain.GetDefs("y"))) + assert.Equal(t, 1, len(chain.GetUses("y"))) + + // Verify z: 1 def, 1 use + assert.Equal(t, 1, len(chain.GetDefs("z"))) + assert.Equal(t, 1, len(chain.GetUses("z"))) + + // All variables + vars := chain.AllVariables() + assert.Equal(t, 3, len(vars)) +} diff --git a/sourcecode-parser/graph/callgraph/taint_summary.go b/sourcecode-parser/graph/callgraph/taint_summary.go new file mode 100644 index 00000000..b277891e --- /dev/null +++ b/sourcecode-parser/graph/callgraph/taint_summary.go @@ -0,0 +1,238 @@ +package callgraph + +// TaintInfo represents detailed taint tracking information for a single detection. 
type TaintInfo struct {
	// SourceLine is the line number where taint originated (1-indexed)
	SourceLine uint32

	// SourceVar is the variable name at the taint source
	SourceVar string

	// SinkLine is the line number where tainted data reaches a dangerous sink (1-indexed)
	SinkLine uint32

	// SinkVar is the variable name at the sink
	SinkVar string

	// SinkCall is the dangerous function/method call at the sink
	// Examples: "execute", "eval", "os.system"
	SinkCall string

	// PropagationPath is the list of variables through which taint propagated
	// Example: ["user_input", "data", "query"] shows user_input -> data -> query
	PropagationPath []string

	// Confidence is a score from 0.0 to 1.0 indicating detection confidence
	// 1.0 = direct flow
	// 0.7 = through stdlib function
	// 0.5 = through third-party library
	// 0.0 = no taint detected
	//
	// The classification helpers bucket scores as follows:
	// high is >= 0.8, medium is [0.5, 0.8), low is (0.0, 0.5).
	// A score of exactly 0.5 is therefore reported by IsMediumConfidence,
	// not by IsLowConfidence.
	Confidence float64

	// Sanitized indicates if a sanitizer was detected in the propagation path
	// If true, the taint was neutralized and should not trigger a finding
	Sanitized bool

	// SanitizerLine is the line number where sanitization occurred (if Sanitized == true)
	SanitizerLine uint32

	// SanitizerCall is the sanitizer function that was called
	// Examples: "escape_html", "quote_sql", "validate_email"
	SanitizerCall string
}

// IsTainted returns true if this TaintInfo represents actual taint:
// confidence > 0 and the flow was not sanitized. A nil receiver is not tainted.
func (ti *TaintInfo) IsTainted() bool {
	return ti != nil && ti.Confidence > 0.0 && !ti.Sanitized
}

// IsHighConfidence returns true if confidence >= 0.8.
// A nil receiver reports false.
func (ti *TaintInfo) IsHighConfidence() bool {
	return ti != nil && ti.Confidence >= 0.8
}

// IsMediumConfidence returns true if 0.5 <= confidence < 0.8.
// A nil receiver reports false.
func (ti *TaintInfo) IsMediumConfidence() bool {
	return ti != nil && ti.Confidence >= 0.5 && ti.Confidence < 0.8
}

// IsLowConfidence returns true if 0.0 < confidence < 0.5.
// A nil receiver reports false.
func (ti *TaintInfo) IsLowConfidence() bool {
	return ti != nil && ti.Confidence > 0.0 && ti.Confidence < 0.5
}

// TaintSummary represents the complete taint analysis results for a function.
type TaintSummary struct {
	// FunctionFQN is the fully qualified name of the analyzed function
	// Format: "module.Class.method" or "module.function"
	FunctionFQN string

	// TaintedVars maps variable names to their taint information
	// If a variable is not in this map, it is considered untainted
	// Multiple TaintInfo entries indicate multiple taint paths to the same variable
	TaintedVars map[string][]*TaintInfo

	// Detections is a list of all taint flows that reached a dangerous sink
	// These represent potential security vulnerabilities
	Detections []*TaintInfo

	// TaintedParams tracks which function parameters are tainted (by parameter name)
	// Used for inter-procedural analysis
	TaintedParams []string

	// TaintedReturn indicates if the function's return value is tainted
	TaintedReturn bool

	// ReturnTaintInfo provides details if TaintedReturn is true
	ReturnTaintInfo *TaintInfo

	// AnalysisError indicates if the analysis failed for this function
	// If true, the summary is incomplete
	AnalysisError bool

	// ErrorMessage contains the error description if AnalysisError is true
	ErrorMessage string
}

// NewTaintSummary creates an empty taint summary for a function.
// The map and both slices are allocated, so they are non-nil and empty.
func NewTaintSummary(functionFQN string) *TaintSummary {
	return &TaintSummary{
		FunctionFQN:   functionFQN,
		TaintedVars:   make(map[string][]*TaintInfo),
		Detections:    make([]*TaintInfo, 0),
		TaintedParams: make([]string, 0),
	}
}

// AddTaintedVar records taint information for a variable.
// Empty variable names and nil taint info are ignored.
func (ts *TaintSummary) AddTaintedVar(varName string, taintInfo *TaintInfo) {
	if varName == "" || taintInfo == nil {
		return
	}
	if ts.TaintedVars == nil {
		// Writing to a nil map panics; allocate lazily so a zero-value
		// TaintSummary (not built with NewTaintSummary) is usable.
		ts.TaintedVars = make(map[string][]*TaintInfo)
	}
	ts.TaintedVars[varName] = append(ts.TaintedVars[varName], taintInfo)
}

// GetTaintInfo retrieves all taint information for a variable.
// Returns nil if variable is not tainted.
func (ts *TaintSummary) GetTaintInfo(varName string) []*TaintInfo {
	return ts.TaintedVars[varName]
}

// IsTainted checks if a variable is tainted (has at least one unsanitized taint path).
func (ts *TaintSummary) IsTainted(varName string) bool {
	for _, info := range ts.TaintedVars[varName] {
		if info.IsTainted() {
			return true
		}
	}
	return false
}

// AddDetection records a taint flow that reached a dangerous sink.
// A nil detection is ignored.
func (ts *TaintSummary) AddDetection(detection *TaintInfo) {
	if detection == nil {
		return
	}
	ts.Detections = append(ts.Detections, detection)
}

// HasDetections returns true if any taint flows reached dangerous sinks.
func (ts *TaintSummary) HasDetections() bool {
	return len(ts.Detections) > 0
}

// detectionsMatching returns, in their original order, the detections for
// which keep reports true. The result is always non-nil.
func (ts *TaintSummary) detectionsMatching(keep func(*TaintInfo) bool) []*TaintInfo {
	matched := make([]*TaintInfo, 0)
	for _, detection := range ts.Detections {
		if keep(detection) {
			matched = append(matched, detection)
		}
	}
	return matched
}

// GetHighConfidenceDetections returns detections with confidence >= 0.8.
func (ts *TaintSummary) GetHighConfidenceDetections() []*TaintInfo {
	return ts.detectionsMatching((*TaintInfo).IsHighConfidence)
}

// GetMediumConfidenceDetections returns detections with 0.5 <= confidence < 0.8.
func (ts *TaintSummary) GetMediumConfidenceDetections() []*TaintInfo {
	return ts.detectionsMatching((*TaintInfo).IsMediumConfidence)
}

// GetLowConfidenceDetections returns detections with 0.0 < confidence < 0.5.
func (ts *TaintSummary) GetLowConfidenceDetections() []*TaintInfo {
	return ts.detectionsMatching((*TaintInfo).IsLowConfidence)
}

// MarkTaintedParam marks a function parameter as tainted.
// Empty names are ignored and a parameter is recorded at most once.
func (ts *TaintSummary) MarkTaintedParam(paramName string) {
	if paramName == "" || ts.IsParamTainted(paramName) {
		return
	}
	ts.TaintedParams = append(ts.TaintedParams, paramName)
}

// IsParamTainted checks if a function parameter is tainted.
func (ts *TaintSummary) IsParamTainted(paramName string) bool {
	for _, p := range ts.TaintedParams {
		if p == paramName {
			return true
		}
	}
	return false
}

// MarkReturnTainted marks the function's return value as tainted.
// taintInfo is stored as given; if it is nil, TaintedReturn is still set and
// ReturnTaintInfo is nil.
func (ts *TaintSummary) MarkReturnTainted(taintInfo *TaintInfo) {
	ts.TaintedReturn = true
	ts.ReturnTaintInfo = taintInfo
}

// SetError marks the analysis as failed with an error message.
func (ts *TaintSummary) SetError(errorMsg string) {
	ts.AnalysisError = true
	ts.ErrorMessage = errorMsg
}

// IsComplete returns true if analysis completed without errors.
func (ts *TaintSummary) IsComplete() bool {
	return !ts.AnalysisError
}

// GetTaintedVarCount returns the number of distinct tainted variables.
// A variable counts once, and only if at least one of its taint paths is
// unsanitized with confidence > 0.
func (ts *TaintSummary) GetTaintedVarCount() int {
	count := 0
	for varName := range ts.TaintedVars {
		if ts.IsTainted(varName) {
			count++
		}
	}
	return count
}

// GetDetectionCount returns the total number of detections.
+func (ts *TaintSummary) GetDetectionCount() int { + return len(ts.Detections) +} diff --git a/sourcecode-parser/graph/callgraph/taint_summary_test.go b/sourcecode-parser/graph/callgraph/taint_summary_test.go new file mode 100644 index 00000000..8ddb4d90 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/taint_summary_test.go @@ -0,0 +1,507 @@ +package callgraph + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTaintInfoIsTainted(t *testing.T) { + tests := []struct { + name string + info *TaintInfo + expected bool + }{ + { + name: "high confidence taint", + info: &TaintInfo{ + Confidence: 1.0, + Sanitized: false, + }, + expected: true, + }, + { + name: "medium confidence taint", + info: &TaintInfo{ + Confidence: 0.7, + Sanitized: false, + }, + expected: true, + }, + { + name: "sanitized taint", + info: &TaintInfo{ + Confidence: 1.0, + Sanitized: true, + }, + expected: false, + }, + { + name: "zero confidence", + info: &TaintInfo{ + Confidence: 0.0, + Sanitized: false, + }, + expected: false, + }, + { + name: "low confidence but not sanitized", + info: &TaintInfo{ + Confidence: 0.3, + Sanitized: false, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.info.IsTainted() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTaintInfoIsHighConfidence(t *testing.T) { + tests := []struct { + name string + confidence float64 + expected bool + }{ + {name: "perfect confidence", confidence: 1.0, expected: true}, + {name: "high confidence", confidence: 0.9, expected: true}, + {name: "exactly 0.8", confidence: 0.8, expected: true}, + {name: "just below threshold", confidence: 0.79, expected: false}, + {name: "medium confidence", confidence: 0.6, expected: false}, + {name: "low confidence", confidence: 0.3, expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &TaintInfo{Confidence: tt.confidence} + result := 
info.IsHighConfidence() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTaintInfoIsMediumConfidence(t *testing.T) { + tests := []struct { + name string + confidence float64 + expected bool + }{ + {name: "high confidence", confidence: 1.0, expected: false}, + {name: "just below high", confidence: 0.79, expected: true}, + {name: "mid range", confidence: 0.6, expected: true}, + {name: "exactly 0.5", confidence: 0.5, expected: true}, + {name: "just below medium", confidence: 0.49, expected: false}, + {name: "low confidence", confidence: 0.3, expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &TaintInfo{Confidence: tt.confidence} + result := info.IsMediumConfidence() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTaintInfoIsLowConfidence(t *testing.T) { + tests := []struct { + name string + confidence float64 + expected bool + }{ + {name: "medium confidence", confidence: 0.6, expected: false}, + {name: "just below medium", confidence: 0.49, expected: true}, + {name: "low confidence", confidence: 0.3, expected: true}, + {name: "very low", confidence: 0.1, expected: true}, + {name: "zero confidence", confidence: 0.0, expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &TaintInfo{Confidence: tt.confidence} + result := info.IsLowConfidence() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestNewTaintSummary(t *testing.T) { + summary := NewTaintSummary("module.Class.method") + + assert.Equal(t, "module.Class.method", summary.FunctionFQN) + assert.NotNil(t, summary.TaintedVars) + assert.Equal(t, 0, len(summary.TaintedVars)) + assert.NotNil(t, summary.Detections) + assert.Equal(t, 0, len(summary.Detections)) + assert.NotNil(t, summary.TaintedParams) + assert.Equal(t, 0, len(summary.TaintedParams)) + assert.False(t, summary.TaintedReturn) + assert.Nil(t, summary.ReturnTaintInfo) + assert.False(t, summary.AnalysisError) + assert.Equal(t, 
"", summary.ErrorMessage)
}

// TestTaintSummaryAddTaintedVar verifies that AddTaintedVar appends taint
// records per variable name and ignores empty names and nil taint info.
func TestTaintSummaryAddTaintedVar(t *testing.T) {
	summary := NewTaintSummary("test.function")

	taint1 := &TaintInfo{
		SourceLine: 1,
		SourceVar:  "input",
		Confidence: 1.0,
	}

	taint2 := &TaintInfo{
		SourceLine: 2,
		SourceVar:  "input2",
		Confidence: 0.7,
	}

	// Add first taint.
	// assert.Len is non-fatal, so only index when it passed; a regression
	// then reports an assertion failure instead of an index panic.
	summary.AddTaintedVar("x", taint1)
	if assert.Len(t, summary.TaintedVars["x"], 1) {
		assert.Equal(t, taint1, summary.TaintedVars["x"][0])
	}

	// Add second taint to same variable.
	summary.AddTaintedVar("x", taint2)
	if assert.Len(t, summary.TaintedVars["x"], 2) {
		assert.Equal(t, taint2, summary.TaintedVars["x"][1])
	}

	// Empty variable name should be ignored.
	summary.AddTaintedVar("", taint1)
	_, exists := summary.TaintedVars[""]
	assert.False(t, exists)

	// Nil taint info should be ignored.
	summary.AddTaintedVar("y", nil)
	_, exists = summary.TaintedVars["y"]
	assert.False(t, exists)
}

// TestTaintSummaryGetTaintInfo verifies that GetTaintInfo returns the records
// stored for a variable and nil for a variable that was never added.
func TestTaintSummaryGetTaintInfo(t *testing.T) {
	summary := NewTaintSummary("test.function")

	taint := &TaintInfo{
		SourceLine: 1,
		SourceVar:  "input",
		Confidence: 1.0,
	}

	summary.AddTaintedVar("x", taint)

	// Get existing taint.
	result := summary.GetTaintInfo("x")
	assert.NotNil(t, result)
	if assert.Len(t, result, 1) {
		assert.Equal(t, taint, result[0])
	}

	// Get non-existent variable.
	nonExistent := summary.GetTaintInfo("nonexistent")
	assert.Nil(t, nonExistent)
}

// TestTaintSummaryIsTainted verifies that IsTainted reports true when at least
// one unsanitized record exists for a variable, and false when all records are
// sanitized or the variable is unknown.
func TestTaintSummaryIsTainted(t *testing.T) {
	summary := NewTaintSummary("test.function")

	// Add tainted variable.
	taint1 := &TaintInfo{
		Confidence: 1.0,
		Sanitized:  false,
	}
	summary.AddTaintedVar("x", taint1)
	assert.True(t, summary.IsTainted("x"))

	// Add sanitized taint.
	taint2 := &TaintInfo{
		Confidence: 1.0,
		Sanitized:  true,
	}
	summary.AddTaintedVar("y", taint2)
	assert.False(t, summary.IsTainted("y"))

	// Add variable with both tainted and sanitized paths.
	summary.AddTaintedVar("z", taint1)     // tainted
	summary.AddTaintedVar("z", taint2)     // sanitized
	assert.True(t, summary.IsTainted("z")) // Should return true if ANY path is tainted

	// Check non-existent variable.
	assert.False(t, summary.IsTainted("nonexistent"))
}

// TestTaintSummaryAddDetection verifies that AddDetection appends detections
// in call order and ignores nil.
func TestTaintSummaryAddDetection(t *testing.T) {
	summary := NewTaintSummary("test.function")

	detection1 := &TaintInfo{
		SourceLine: 1,
		SinkLine:   5,
		SinkCall:   "execute",
		Confidence: 1.0,
	}

	detection2 := &TaintInfo{
		SourceLine: 2,
		SinkLine:   6,
		SinkCall:   "eval",
		Confidence: 0.8,
	}

	summary.AddDetection(detection1)
	if assert.Len(t, summary.Detections, 1) {
		assert.Equal(t, detection1, summary.Detections[0])
	}

	summary.AddDetection(detection2)
	if assert.Len(t, summary.Detections, 2) {
		assert.Equal(t, detection2, summary.Detections[1])
	}

	// Nil detection should be ignored.
	summary.AddDetection(nil)
	assert.Len(t, summary.Detections, 2)
}

// TestTaintSummaryHasDetections verifies that HasDetections is false on a new
// summary and true once a detection has been added.
func TestTaintSummaryHasDetections(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.False(t, summary.HasDetections())

	detection := &TaintInfo{
		SourceLine: 1,
		SinkLine:   5,
		Confidence: 1.0,
	}
	summary.AddDetection(detection)

	assert.True(t, summary.HasDetections())
}

// TestTaintSummaryGetHighConfidenceDetections verifies that, of detections
// with confidence 1.0, 0.9, 0.6 and 0.3, only the 1.0 and 0.9 entries are
// returned, in insertion order.
func TestTaintSummaryGetHighConfidenceDetections(t *testing.T) {
	summary := NewTaintSummary("test.function")

	high1 := &TaintInfo{Confidence: 1.0}
	high2 := &TaintInfo{Confidence: 0.9}
	medium := &TaintInfo{Confidence: 0.6}
	low := &TaintInfo{Confidence: 0.3}

	summary.AddDetection(high1)
	summary.AddDetection(medium)
	summary.AddDetection(high2)
	summary.AddDetection(low)

	highConf := summary.GetHighConfidenceDetections()
	if assert.Len(t, highConf, 2) {
		assert.Equal(t, high1, highConf[0])
		assert.Equal(t, high2, highConf[1])
	}
}

// TestTaintSummaryGetMediumConfidenceDetections verifies that, of detections
// with confidence 1.0, 0.7, 0.5 and 0.3, only the 0.7 and 0.5 entries are
// returned, in insertion order.
func TestTaintSummaryGetMediumConfidenceDetections(t *testing.T) {
	summary := NewTaintSummary("test.function")

	high := &TaintInfo{Confidence: 1.0}
	medium1 := &TaintInfo{Confidence: 0.7}
	medium2 := &TaintInfo{Confidence: 0.5}
	low := &TaintInfo{Confidence: 0.3}

	summary.AddDetection(high)
	summary.AddDetection(medium1)
	summary.AddDetection(low)
	summary.AddDetection(medium2)

	mediumConf := summary.GetMediumConfidenceDetections()
	if assert.Len(t, mediumConf, 2) {
		assert.Equal(t, medium1, mediumConf[0])
		assert.Equal(t, medium2, mediumConf[1])
	}
}

// TestTaintSummaryGetLowConfidenceDetections verifies that, of detections with
// confidence 1.0, 0.6, 0.4 and 0.1, only the 0.4 and 0.1 entries are returned,
// in insertion order.
func TestTaintSummaryGetLowConfidenceDetections(t *testing.T) {
	summary := NewTaintSummary("test.function")

	high := &TaintInfo{Confidence: 1.0}
	medium := &TaintInfo{Confidence: 0.6}
	low1 := &TaintInfo{Confidence: 0.4}
	low2 := &TaintInfo{Confidence: 0.1}

	summary.AddDetection(high)
	summary.AddDetection(low1)
	summary.AddDetection(medium)
	summary.AddDetection(low2)

	lowConf := summary.GetLowConfidenceDetections()
	if assert.Len(t, lowConf, 2) {
		assert.Equal(t, low1, lowConf[0])
		assert.Equal(t, low2, lowConf[1])
	}
}

// TestTaintSummaryMarkTaintedParam verifies that MarkTaintedParam records
// parameters in call order, does not duplicate an existing entry, and ignores
// the empty name.
func TestTaintSummaryMarkTaintedParam(t *testing.T) {
	summary := NewTaintSummary("test.function")

	// Mark first param.
	summary.MarkTaintedParam("param1")
	if assert.Len(t, summary.TaintedParams, 1) {
		assert.Equal(t, "param1", summary.TaintedParams[0])
	}

	// Mark second param.
	summary.MarkTaintedParam("param2")
	if assert.Len(t, summary.TaintedParams, 2) {
		assert.Equal(t, "param2", summary.TaintedParams[1])
	}

	// Marking the same param again should not duplicate it.
	summary.MarkTaintedParam("param1")
	assert.Len(t, summary.TaintedParams, 2)

	// Empty param name should be ignored.
	summary.MarkTaintedParam("")
	assert.Len(t, summary.TaintedParams, 2)
}

// TestTaintSummaryIsParamTainted verifies that IsParamTainted is true only for
// parameters previously passed to MarkTaintedParam.
func TestTaintSummaryIsParamTainted(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.False(t, summary.IsParamTainted("param1"))

	summary.MarkTaintedParam("param1")
	assert.True(t, summary.IsParamTainted("param1"))
	assert.False(t, summary.IsParamTainted("param2"))

	summary.MarkTaintedParam("param2")
	assert.True(t, summary.IsParamTainted("param2"))
}

// TestTaintSummaryMarkReturnTainted verifies that MarkReturnTainted sets
// TaintedReturn and stores the supplied info in ReturnTaintInfo.
func TestTaintSummaryMarkReturnTainted(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.False(t, summary.TaintedReturn)
	assert.Nil(t, summary.ReturnTaintInfo)

	taint := &TaintInfo{
		SourceLine: 1,
		Confidence: 1.0,
	}

	summary.MarkReturnTainted(taint)
	assert.True(t, summary.TaintedReturn)
	assert.Equal(t, taint, summary.ReturnTaintInfo)
}

// TestTaintSummarySetError verifies that SetError sets AnalysisError and
// stores the message in ErrorMessage.
func TestTaintSummarySetError(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.False(t, summary.AnalysisError)
	assert.Empty(t, summary.ErrorMessage)

	summary.SetError("parse error")
	assert.True(t, summary.AnalysisError)
	assert.Equal(t, "parse error", summary.ErrorMessage)
}

// TestTaintSummaryIsComplete verifies that IsComplete is true on a new summary
// and false after SetError has been called.
func TestTaintSummaryIsComplete(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.True(t, summary.IsComplete())

	summary.SetError("error")
	assert.False(t, summary.IsComplete())
}

// TestTaintSummaryGetTaintedVarCount verifies that GetTaintedVarCount counts
// each variable with unsanitized taint once and skips variables whose only
// record is sanitized.
func TestTaintSummaryGetTaintedVarCount(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.Equal(t, 0, summary.GetTaintedVarCount())

	// Add tainted variable.
	taint1 := &TaintInfo{Confidence: 1.0, Sanitized: false}
	summary.AddTaintedVar("x", taint1)
	assert.Equal(t, 1, summary.GetTaintedVarCount())

	// Another taint on the same variable should still count as 1.
	taint2 := &TaintInfo{Confidence: 0.7, Sanitized: false}
	summary.AddTaintedVar("x", taint2)
	assert.Equal(t, 1, summary.GetTaintedVarCount())

	// Add tainted second variable.
	summary.AddTaintedVar("y", taint1)
	assert.Equal(t, 2, summary.GetTaintedVarCount())

	// A sanitized variable should not count.
	sanitized := &TaintInfo{Confidence: 1.0, Sanitized: true}
	summary.AddTaintedVar("z", sanitized)
	assert.Equal(t, 2, summary.GetTaintedVarCount())
}

// TestTaintSummaryGetDetectionCount verifies that GetDetectionCount tracks the
// number of detections added.
func TestTaintSummaryGetDetectionCount(t *testing.T) {
	summary := NewTaintSummary("test.function")

	assert.Equal(t, 0, summary.GetDetectionCount())

	detection1 := &TaintInfo{Confidence: 1.0}
	summary.AddDetection(detection1)
	assert.Equal(t, 1, summary.GetDetectionCount())

	detection2 := &TaintInfo{Confidence: 0.8}
	summary.AddDetection(detection2)
	assert.Equal(t, 2, summary.GetDetectionCount())
}

// TestTaintSummaryComplexScenario exercises the summary API together: three
// tainted variables, one detection and one tainted parameter are recorded and
// then checked through the query methods and the stored detection's fields.
func TestTaintSummaryComplexScenario(t *testing.T) {
	// Simulate a real security finding scenario.
	summary := NewTaintSummary("app.views.process_payment")

	// Taint flows from user input.
	userInputTaint := &TaintInfo{
		SourceLine:      10,
		SourceVar:       "request.GET['amount']",
		SinkLine:        25,
		SinkVar:         "query",
		SinkCall:        "cursor.execute",
		PropagationPath: []string{"user_amount", "amount", "query"},
		Confidence:      1.0,
		Sanitized:       false,
	}

	// Track the variable propagation.
	summary.AddTaintedVar("user_amount", &TaintInfo{
		SourceLine: 10,
		SourceVar:  "request.GET['amount']",
		Confidence: 1.0,
	})

	summary.AddTaintedVar("amount", &TaintInfo{
		SourceLine: 15,
		SourceVar:  "user_amount",
		Confidence: 1.0,
	})

	summary.AddTaintedVar("query", &TaintInfo{
		SourceLine: 20,
		SourceVar:  "amount",
		Confidence: 1.0,
	})

	// Record the detection.
	summary.AddDetection(userInputTaint)

	// Mark the request parameter as tainted.
	summary.MarkTaintedParam("request")

	// Verify the summary.
	assert.True(t, summary.IsTainted("user_amount"))
	assert.True(t, summary.IsTainted("amount"))
	assert.True(t, summary.IsTainted("query"))
	assert.Equal(t, 3, summary.GetTaintedVarCount())
	assert.True(t, summary.HasDetections())
	assert.Equal(t, 1, summary.GetDetectionCount())
	assert.Len(t, summary.GetHighConfidenceDetections(), 1)
	assert.True(t, summary.IsParamTainted("request"))
	assert.True(t, summary.IsComplete())

	// Verify the detection details. The assertions above are non-fatal, so
	// stop here if nothing was recorded rather than panic on Detections[0].
	if !assert.NotEmpty(t, summary.Detections) {
		return
	}
	detection := summary.Detections[0]
	assert.Equal(t, uint32(10), detection.SourceLine)
	assert.Equal(t, uint32(25), detection.SinkLine)
	assert.Equal(t, "cursor.execute", detection.SinkCall)
	assert.Len(t, detection.PropagationPath, 3)
	assert.True(t, detection.IsHighConfidence())
	assert.False(t, detection.Sanitized)
}