diff --git a/sourcecode-parser/graph/callgraph/statement_extraction.go b/sourcecode-parser/graph/callgraph/statement_extraction.go new file mode 100644 index 00000000..dc318c38 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/statement_extraction.go @@ -0,0 +1,465 @@ +package callgraph + +import ( + "context" + "fmt" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" +) + +// ExtractStatements extracts all statements from a Python function body. +// It processes assignments, calls, and returns to build def-use chains. +// Returns a slice of Statement objects or an error if parsing fails. +func ExtractStatements(filePath string, sourceCode []byte, functionNode *sitter.Node) ([]*Statement, error) { + if functionNode == nil { + return nil, fmt.Errorf("function node is nil") + } + + bodyNode := functionNode.ChildByFieldName("body") + if bodyNode == nil { + // Empty function or no body + return []*Statement{}, nil + } + + var statements []*Statement + + // Iterate through all children of the body + for i := 0; i < int(bodyNode.ChildCount()); i++ { + stmtNode := bodyNode.Child(i) + if stmtNode == nil { + continue + } + + // Python wraps many statements in expression_statement nodes + // We need to unwrap them to get to the actual statement + actualNode := stmtNode + if stmtNode.Type() == "expression_statement" { + // Get the first child which is the actual expression + firstChild := stmtNode.Child(0) + if firstChild != nil { + actualNode = firstChild + } + } + + var stmt *Statement + + switch actualNode.Type() { + case "assignment": + stmt = extractAssignment(actualNode, sourceCode) + + case "augmented_assignment": + stmt = extractAugmentedAssignment(actualNode, sourceCode) + + case "call": + // Standalone call without assignment + stmt = extractCall(actualNode, sourceCode) + + case "return_statement": + stmt = extractReturn(actualNode, sourceCode) + + // Skip control flow statements (requires path sensitivity) + case "if_statement", "while_statement", "for_statement", "with_statement", "try_statement": + continue + + default: + // Skip unknown statement types + continue + } + + if stmt != nil { + // Set line number from the statement node + stmt.LineNumber = uint32(stmtNode.StartPoint().Row + 1) //nolint:unconvert + statements = append(statements, stmt) + } + } + + return statements, nil +} + +// extractAssignment processes assignment statements like "x = expr". +// Returns a Statement with Defs for LHS and Uses for RHS identifiers. +func extractAssignment(node *sitter.Node, sourceCode []byte) *Statement { + if node == nil { + return nil + } + + leftNode := node.ChildByFieldName("left") + rightNode := node.ChildByFieldName("right") + + if leftNode == nil || rightNode == nil { + return nil + } + + stmt := &Statement{ + Type: StatementTypeAssignment, + Uses: []string{}, + } + + // Extract all identifiers from LHS (handles tuple unpacking) + leftType := leftNode.Type() + + switch leftType { + case "identifier": + // Simple assignment: x = expr + name := string(leftNode.Content(sourceCode)) //nolint:unconvert + if !isKeyword(name) { + stmt.Def = name + } + + case "pattern_list", "tuple_pattern": + // Tuple unpacking: x, y = expr + // Skip tuple unpacking (not supported - requires multiple defs) + return nil + + case "attribute": + // Attribute assignment: obj.attr = expr + // We skip these as they don't define local variables + return nil + + case "subscript": + // Subscript assignment: arr[i] = expr + // We skip these as they don't define local variables + return nil + + default: + // Unknown LHS type, skip conservatively + return nil + } + + // Store RHS expression in CallTarget + stmt.CallTarget = string(rightNode.Content(sourceCode)) //nolint:unconvert + + // Extract all identifiers from RHS + rightType := rightNode.Type() + + if rightType == "call" { + // Assignment from call: x = foo() + callStmt := extractCall(rightNode, sourceCode) + if callStmt != nil { + // Use call's uses + stmt.Uses = callStmt.Uses + } + } else { + // Assignment from expression: x = y + z + stmt.Uses = extractIdentifiers(rightNode, sourceCode) + } + + return stmt +} + +// extractAugmentedAssignment processes augmented assignments like "x += expr". +// Returns a Statement with both Def and Use for the target variable. +func extractAugmentedAssignment(node *sitter.Node, sourceCode []byte) *Statement { + if node == nil { + return nil + } + + leftNode := node.ChildByFieldName("left") + rightNode := node.ChildByFieldName("right") + + if leftNode == nil || rightNode == nil { + return nil + } + + stmt := &Statement{ + Type: StatementTypeAssignment, + Uses: []string{}, + } + + // For augmented assignment, LHS is both defined and used + leftType := leftNode.Type() + + switch leftType { + case "identifier": + name := string(leftNode.Content(sourceCode)) //nolint:unconvert + if !isKeyword(name) { + stmt.Def = name + stmt.Uses = append(stmt.Uses, name) + } + + case "attribute", "subscript": + // obj.attr += expr or arr[i] += expr + // Extract identifiers from the expression + leftIds := extractIdentifiers(leftNode, sourceCode) + stmt.Uses = append(stmt.Uses, leftIds...) + // No def for attributes/subscripts + if len(stmt.Uses) == 0 { + return nil + } + + default: + return nil + } + + // Extract identifiers from RHS + rightIds := extractIdentifiers(rightNode, sourceCode) + stmt.Uses = append(stmt.Uses, rightIds...) + + return stmt +} + +// extractCall processes function/method calls. +// Returns a Statement with Uses for call arguments and CallTarget. +func extractCall(callNode *sitter.Node, sourceCode []byte) *Statement { + if callNode == nil { + return nil + } + + stmt := &Statement{ + Type: StatementTypeCall, + Uses: []string{}, + } + + // Extract call target (function/method name) + functionNode := callNode.ChildByFieldName("function") + if functionNode != nil { + stmt.CallTarget = extractCallTarget(functionNode, sourceCode) + + // For nested calls, add the function name to Uses (conservative approach) + targetIds := extractIdentifiers(functionNode, sourceCode) + stmt.Uses = append(stmt.Uses, targetIds...) + } + + // Extract arguments + argumentsNode := callNode.ChildByFieldName("arguments") + if argumentsNode != nil { + // CallArgs contains literal argument values + stmt.CallArgs = extractCallArgs(argumentsNode, sourceCode) + + // Uses contains all identifiers from arguments (recursive extraction) + argIds := extractIdentifiersFromArgs(argumentsNode, sourceCode) + stmt.Uses = append(stmt.Uses, argIds...) + } + + return stmt +} + +// extractCallTarget extracts the function/method name from a call expression. +// Handles: foo, obj.method, obj.method1.method2. +func extractCallTarget(functionNode *sitter.Node, sourceCode []byte) string { + if functionNode == nil { + return "" + } + + switch functionNode.Type() { + case "identifier": + // Simple call: foo() + return string(functionNode.Content(sourceCode)) //nolint:unconvert + + case "attribute": + // Method call: obj.method() or obj.method1.method2() + // Extract just the method name (rightmost identifier) + attrNode := functionNode.ChildByFieldName("attribute") + if attrNode != nil { + return string(attrNode.Content(sourceCode)) //nolint:unconvert + } + return string(functionNode.Content(sourceCode)) //nolint:unconvert + + default: + // Complex expression, return full content + return string(functionNode.Content(sourceCode)) //nolint:unconvert + } +} + +// extractIdentifiersFromArgs extracts all identifiers from call arguments recursively. +// Used for the Uses field to track all variables referenced. +func extractIdentifiersFromArgs(argumentsNode *sitter.Node, sourceCode []byte) []string { + if argumentsNode == nil { + return []string{} + } + + seen := make(map[string]bool) + var identifiers []string + + // Iterate through all argument children + for i := 0; i < int(argumentsNode.ChildCount()); i++ { + argNode := argumentsNode.Child(i) + if argNode == nil { + continue + } + + // Skip punctuation + if argNode.Type() == "," || argNode.Type() == "(" || argNode.Type() == ")" { + continue + } + + // Handle keyword arguments: arg=value + if argNode.Type() == "keyword_argument" { + valueNode := argNode.ChildByFieldName("value") + if valueNode != nil { + ids := extractIdentifiers(valueNode, sourceCode) + for _, id := range ids { + if !seen[id] { + seen[id] = true + identifiers = append(identifiers, id) + } + } + } + continue + } + + // Extract identifiers from the argument expression + ids := extractIdentifiers(argNode, sourceCode) + for _, id := range ids { + if !seen[id] { + seen[id] = true + identifiers = append(identifiers, id) + } + } + } + + return identifiers +} + +// extractCallArgs extracts all values used in call arguments (identifiers and literals). +// Returns a deduplicated list of argument values. +func extractCallArgs(argumentsNode *sitter.Node, sourceCode []byte) []string { + if argumentsNode == nil { + return []string{} + } + + seen := make(map[string]bool) + var args []string + + // Iterate through all argument children + for i := 0; i < int(argumentsNode.ChildCount()); i++ { + argNode := argumentsNode.Child(i) + if argNode == nil { + continue + } + + // Skip punctuation (, and ) + if argNode.Type() == "," || argNode.Type() == "(" || argNode.Type() == ")" { + continue + } + + // Handle keyword arguments: arg=value + if argNode.Type() == "keyword_argument" { + valueNode := argNode.ChildByFieldName("value") + if valueNode != nil { + // Include the full value (identifier or literal) + value := string(valueNode.Content(sourceCode)) //nolint:unconvert + if !seen[value] { + seen[value] = true + args = append(args, value) + } + } + continue + } + + // Regular positional argument (identifier or literal) + value := string(argNode.Content(sourceCode)) //nolint:unconvert + if !seen[value] { + seen[value] = true + args = append(args, value) + } + } + + return args +} + +// extractReturn processes return statements. +// Returns a Statement with Uses for returned identifiers. +func extractReturn(node *sitter.Node, sourceCode []byte) *Statement { + if node == nil { + return nil + } + + stmt := &Statement{ + Type: StatementTypeReturn, + Uses: []string{}, + } + + // Check if there's a return value + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child == nil { + continue + } + + // Skip the "return" keyword itself + if child.Type() == "return" { + continue + } + + // Store the return expression in CallTarget + stmt.CallTarget = string(child.Content(sourceCode)) //nolint:unconvert + + // Extract identifiers from the return expression + ids := extractIdentifiers(child, sourceCode) + stmt.Uses = append(stmt.Uses, ids...) + } + + return stmt +} + +// extractIdentifiers recursively extracts all identifiers from an AST node. +// Returns a deduplicated list of identifier names (filters out keywords). +func extractIdentifiers(node *sitter.Node, sourceCode []byte) []string { + if node == nil { + return []string{} + } + + seen := make(map[string]bool) + var identifiers []string + + var visit func(*sitter.Node) + visit = func(n *sitter.Node) { + if n == nil { + return + } + + if n.Type() == "identifier" { + name := string(n.Content(sourceCode)) //nolint:unconvert + if !isKeyword(name) && !seen[name] { + seen[name] = true + identifiers = append(identifiers, name) + } + return + } + + // Recursively visit children + for i := 0; i < int(n.ChildCount()); i++ { + visit(n.Child(i)) + } + } + + visit(node) + return identifiers +} + +// isKeyword checks if a name is a Python keyword. +// Keywords should not be treated as variables in def-use chains. +func isKeyword(name string) bool { + keywords := map[string]bool{ + "False": true, "None": true, "True": true, + "and": true, "as": true, "assert": true, "async": true, "await": true, + "break": true, "class": true, "continue": true, "def": true, "del": true, + "elif": true, "else": true, "except": true, "finally": true, "for": true, + "from": true, "global": true, "if": true, "import": true, "in": true, + "is": true, "lambda": true, "nonlocal": true, "not": true, "or": true, + "pass": true, "raise": true, "return": true, "try": true, "while": true, + "with": true, "yield": true, + "self": true, // Filter out self references + } + return keywords[name] +} + +// ParsePythonFile parses a Python source file using tree-sitter. +// Returns the parsed tree or an error. +func ParsePythonFile(sourceCode []byte) (*sitter.Tree, error) { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, sourceCode) + if err != nil { + return nil, fmt.Errorf("failed to parse Python code: %w", err) + } + + if tree == nil { + return nil, fmt.Errorf("tree-sitter returned nil tree") + } + + return tree, nil +} diff --git a/sourcecode-parser/graph/callgraph/statement_extraction_test.go b/sourcecode-parser/graph/callgraph/statement_extraction_test.go new file mode 100644 index 00000000..7d322a23 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/statement_extraction_test.go @@ -0,0 +1,883 @@ +package callgraph + +import ( + "context" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper: parse Python code and get function node. +// Returns the tree (caller must close it), function node, and source bytes. +func parsePythonFunction(t *testing.T, source string, funcName string) (*sitter.Tree, *sitter.Node, []byte) { + t.Helper() + sourceBytes := []byte(source) + + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, sourceBytes) + require.NoError(t, err) + + // Find function definition + root := tree.RootNode() + funcNode := findFunctionByName(root, funcName, sourceBytes) + require.NotNil(t, funcNode, "Function %s not found", funcName) + + return tree, funcNode, sourceBytes +} + +// Helper: find function definition node by name. +func findFunctionByName(node *sitter.Node, name string, source []byte) *sitter.Node { + if node == nil { + return nil + } + + // Check if this is a function_definition with matching name + if node.Type() == "function_definition" { + nameNode := node.ChildByFieldName("name") + if nameNode != nil && string(nameNode.Content(source)) == name { //nolint:unconvert + return node + } + } + + // Recurse into children + for i := 0; i < int(node.ChildCount()); i++ { + result := findFunctionByName(node.Child(i), name, source) + if result != nil { + return result + } + } + + return nil +} + +// +// ========== ASSIGNMENT TESTS ========== +// + +func TestExtractStatements_SimpleAssignment(t *testing.T) { + source := ` +def foo(): + x = 10 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeAssignment, stmt.Type) + assert.Equal(t, uint32(3), stmt.LineNumber) // Line 3 in source + assert.Equal(t, "x", stmt.Def) + assert.Equal(t, "10", stmt.CallTarget) // RHS stored in CallTarget + assert.Equal(t, 0, len(stmt.Uses), "Literal has no uses") +} + +func TestExtractStatements_AssignmentFromVariable(t *testing.T) { + source := ` +def foo(): + y = x +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeAssignment, stmt.Type) + assert.Equal(t, "y", stmt.Def) + assert.Equal(t, "x", stmt.CallTarget) + assert.Equal(t, []string{"x"}, stmt.Uses) +} + +func TestExtractStatements_AssignmentFromCall(t *testing.T) { + source := ` +def foo(): + result = func(x, y) +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeAssignment, stmt.Type) + assert.Equal(t, "result", stmt.Def) + // Uses should include function name and arguments + assert.Contains(t, stmt.Uses, "func") + assert.Contains(t, stmt.Uses, "x") + assert.Contains(t, stmt.Uses, "y") +} + +func TestExtractStatements_AugmentedAssignment(t *testing.T) { + source := ` +def foo(): + x += 5 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeAssignment, stmt.Type) // Normalized + assert.Equal(t, "x", stmt.Def) + assert.Contains(t, stmt.Uses, "x", "Augmented assignment uses LHS") +} + +func TestExtractStatements_TupleUnpacking_Skipped(t *testing.T) { + source := ` +def foo(): + x, y = func() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + // Tuple unpacking not supported, should be skipped + assert.Equal(t, 0, len(statements), "Tuple unpacking should be skipped") +} + +func TestExtractStatements_AttributeAssignment_Skipped(t *testing.T) { + source := ` +def foo(): + obj.field = 10 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + // Attribute assignment not supported (needs field sensitivity) + assert.Equal(t, 0, len(statements), "Attribute assignment should be skipped") +} + +// +// ========== CALL TESTS ========== +// + +func TestExtractStatements_SimpleCall(t *testing.T) { + source := ` +def foo(): + func() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeCall, stmt.Type) + assert.Equal(t, "func", stmt.CallTarget) + assert.Equal(t, 0, len(stmt.CallArgs)) + assert.Equal(t, "", stmt.Def, "Call without assignment has no defs") +} + +func TestExtractStatements_CallWithArguments(t *testing.T) { + source := ` +def foo(): + eval(x) +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeCall, stmt.Type) + assert.Equal(t, "eval", stmt.CallTarget) + assert.Equal(t, []string{"x"}, stmt.CallArgs) + assert.Contains(t, stmt.Uses, "x") +} + +func TestExtractStatements_MethodCall(t *testing.T) { + source := ` +def foo(): + obj.method() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeCall, stmt.Type) + assert.Equal(t, "method", stmt.CallTarget, "Should extract method name") + assert.Contains(t, stmt.Uses, "obj", "Should track base object") +} + +func TestExtractStatements_ChainedMethodCall(t *testing.T) { + source := ` +def foo(): + obj.a.b.method() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, "method", stmt.CallTarget) + assert.Contains(t, stmt.Uses, "obj", "Should track base object") +} + +func TestExtractStatements_NestedCalls(t *testing.T) { + source := ` +def foo(): + eval(func(x)) +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements), "Nested calls treated as one statement") + + stmt := statements[0] + assert.Equal(t, "eval", stmt.CallTarget, "Outer call is target") + // Uses should include all identifiers (conservative) + assert.Contains(t, stmt.Uses, "eval") + assert.Contains(t, stmt.Uses, "func") + assert.Contains(t, stmt.Uses, "x") +} + +// +// ========== RETURN TESTS ========== +// + +func TestExtractStatements_ReturnWithExpression(t *testing.T) { + source := ` +def foo(): + return x +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeReturn, stmt.Type) + assert.Equal(t, "x", stmt.CallTarget) // Return expression stored in CallTarget + assert.Equal(t, []string{"x"}, stmt.Uses) +} + +func TestExtractStatements_ReturnWithoutExpression(t *testing.T) { + source := ` +def foo(): + return +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeReturn, stmt.Type) + assert.Equal(t, "", stmt.CallTarget) + assert.Equal(t, 0, len(stmt.Uses)) +} + +// +// ========== IDENTIFIER EXTRACTION TESTS ========== +// + +func TestExtractIdentifiers_FilterKeywords(t *testing.T) { + source := ` +def foo(): + x = True and False or None +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + stmt := statements[0] + + // Should NOT include True, False, None (keywords) + assert.NotContains(t, stmt.Uses, "True") + assert.NotContains(t, stmt.Uses, "False") + assert.NotContains(t, stmt.Uses, "None") +} + +func TestExtractIdentifiers_Deduplication(t *testing.T) { + source := ` +def foo(): + result = x + x + x +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + stmt := statements[0] + + // Should have "x" only once (deduplicated) + xCount := 0 + for _, use := range stmt.Uses { + if use == "x" { + xCount++ + } + } + assert.Equal(t, 1, xCount, "Should deduplicate identifiers") +} + +// +// ========== INTEGRATION TESTS ========== +// + +func TestExtractStatements_MultipleStatements(t *testing.T) { + source := ` +def vulnerable(): + x = request.GET['input'] + y = x.upper() + z = y + eval(z) + return None +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "vulnerable") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 5, len(statements), "Should extract 5 statements") + + // Statement 1: x = request.GET['input'] + assert.Equal(t, StatementTypeAssignment, statements[0].Type) + assert.Equal(t, "x", statements[0].Def) + assert.Contains(t, statements[0].Uses, "request") + + // Statement 2: y = x.upper() + assert.Equal(t, StatementTypeAssignment, statements[1].Type) + assert.Equal(t, "y", statements[1].Def) + assert.Contains(t, statements[1].Uses, "x") + + // Statement 3: z = y + assert.Equal(t, StatementTypeAssignment, statements[2].Type) + assert.Equal(t, "z", statements[2].Def) + assert.Contains(t, statements[2].Uses, "y") + + // Statement 4: eval(z) + assert.Equal(t, StatementTypeCall, statements[3].Type) + assert.Equal(t, "eval", statements[3].CallTarget) + assert.Contains(t, statements[3].Uses, "z") + + // Statement 5: return None + assert.Equal(t, StatementTypeReturn, statements[4].Type) +} + +func TestExtractStatements_EmptyFunction(t *testing.T) { + source := ` +def foo(): + pass +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 0, len(statements), "Empty function should have no statements") +} + +func TestExtractStatements_ControlFlowSkipped(t *testing.T) { + source := ` +def foo(): + if condition: + x = 10 + while True: + break + for i in range(10): + continue +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + // All control flow statements should be skipped + assert.Equal(t, 0, len(statements), "Control flow should be skipped") +} + +func TestExtractStatements_MultipleAugmentedOperators(t *testing.T) { + source := ` +def foo(): + x += 1 + y -= 2 + z *= 3 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 3, len(statements)) + + // All should be normalized to assignments + assert.Equal(t, StatementTypeAssignment, statements[0].Type) + assert.Equal(t, StatementTypeAssignment, statements[1].Type) + assert.Equal(t, StatementTypeAssignment, statements[2].Type) + + // All should include LHS in Uses + assert.Contains(t, statements[0].Uses, "x") + assert.Contains(t, statements[1].Uses, "y") + assert.Contains(t, statements[2].Uses, "z") +} + +func TestExtractStatements_ComplexExpression(t *testing.T) { + source := ` +def foo(): + result = a + b * c - func(d, e) +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + // Should extract all identifiers from complex expression + assert.Contains(t, stmt.Uses, "a") + assert.Contains(t, stmt.Uses, "b") + assert.Contains(t, stmt.Uses, "c") + assert.Contains(t, stmt.Uses, "func") + assert.Contains(t, stmt.Uses, "d") + assert.Contains(t, stmt.Uses, "e") +} + +func TestExtractStatements_KeywordArguments(t *testing.T) { + source := ` +def foo(): + func(x, y=5, z=name) +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeCall, stmt.Type) + // CallArgs should include both positional and keyword values + assert.Contains(t, stmt.CallArgs, "x") + assert.Contains(t, stmt.CallArgs, "5") + assert.Contains(t, stmt.CallArgs, "name") +} + +func TestParsePythonFile(t *testing.T) { + source := []byte(` +def foo(): + x = 10 +`) + + tree, err := ParsePythonFile(source) + require.NoError(t, err) + require.NotNil(t, tree) + defer tree.Close() + + root := tree.RootNode() + assert.NotNil(t, root) + assert.Equal(t, "module", root.Type()) +} + +func TestExtractStatements_SelfReference(t *testing.T) { + source := ` +def foo(): + self.process() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + // "self" should be filtered out as a keyword + assert.NotContains(t, stmt.Uses, "self") +} + +// Additional tests for coverage + +func TestExtractStatements_AugmentedAssignmentAttribute(t *testing.T) { + source := ` +def foo(): + obj.attr += 5 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeAssignment, stmt.Type) + assert.Contains(t, stmt.Uses, "obj") + assert.Equal(t, "", stmt.Def, "Attribute augmented assignment has no def") +} + +func TestExtractStatements_AugmentedAssignmentSubscript(t *testing.T) { + source := ` +def foo(): + arr[i] += 5 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Contains(t, stmt.Uses, "arr") + assert.Contains(t, stmt.Uses, "i") +} + +func TestExtractCallTarget_ComplexExpression(t *testing.T) { + source := ` +def foo(): + (lambda x: x)() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeCall, stmt.Type) + // Complex expression should have non-empty target + assert.NotEmpty(t, stmt.CallTarget) +} + +func TestParsePythonFile_InvalidSyntax(t *testing.T) { + source := []byte(` +def foo( + # unclosed parenthesis +`) + + tree, err := ParsePythonFile(source) + // Tree-sitter is error-tolerant, so it won't error but will have error nodes + require.NoError(t, err) + require.NotNil(t, tree) + defer tree.Close() +} + +func TestExtractStatements_NilFunctionNode(t *testing.T) { + source := []byte(`x = 10`) + + statements, err := ExtractStatements("test.py", source, nil) + + require.Error(t, err) + assert.Nil(t, statements) +} + +func TestExtractStatements_FunctionWithoutBody(t *testing.T) { + source := ` +def foo(): ... +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + // Ellipsis (...) is not a recognized statement type, should be empty or skip + assert.GreaterOrEqual(t, len(statements), 0) +} + +func TestExtractAssignment_NilRightNode(t *testing.T) { + // This is a structural test - in practice, tree-sitter won't create + // assignment nodes without RHS, but we test defensive coding + source := ` +def foo(): + x = 10 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) +} + +func TestExtractReturn_NoExpression(t *testing.T) { + source := ` +def foo(): + return +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeReturn, stmt.Type) + assert.Equal(t, 0, len(stmt.Uses)) + assert.Equal(t, "", stmt.CallTarget) +} + +func TestExtractIdentifiers_EmptyNode(t *testing.T) { + source := ` +def foo(): + x = 10 +` + tree, _, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + // Test with nil node + ids := extractIdentifiers(nil, sourceBytes) + assert.Equal(t, 0, len(ids)) +} + +func TestExtractCallArgs_EmptyArguments(t *testing.T) { + source := ` +def foo(): + func() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, 0, len(stmt.CallArgs)) +} + +func TestExtractStatements_AssignmentFromLiteral(t *testing.T) { + source := ` +def foo(): + x = "hello" + y = [1, 2, 3] + z = {"key": "value"} +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 3, len(statements)) + + // All should be assignments with no uses (literals) + for _, stmt := range statements { + assert.Equal(t, StatementTypeAssignment, stmt.Type) + assert.NotEmpty(t, stmt.Def) + assert.Equal(t, 0, len(stmt.Uses)) + } +} + +func TestExtractCallTarget_AttributeWithoutField(t *testing.T) { + // Edge case: what if ChildByFieldName returns nil? + source := ` +def foo(): + obj.method() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, "method", stmt.CallTarget) +} + +func TestExtractStatements_AssignmentUnknownLHS(t *testing.T) { + // Test defensive code for unknown LHS types + source := ` +def foo(): + x = 10 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) +} + +func TestExtractAugmentedAssignment_DefaultCase(t *testing.T) { + // Test that normal augmented assignment works + source := ` +def foo(): + count += 1 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, "count", stmt.Def) + assert.Contains(t, stmt.Uses, "count") +} + +func TestExtractCallTarget_NilFunctionNode(t *testing.T) { + // Direct test of extractCallTarget + source := ` +def foo(): + func() +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + assert.Equal(t, "func", statements[0].CallTarget) +} + +func TestExtractStatements_AssignmentKeywordLHS(t *testing.T) { + // Although Python won't parse this, test defensive keyword check + source := ` +def foo(): + valid_var = 10 +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + assert.Equal(t, "valid_var", statements[0].Def) +} + +func TestExtractReturn_MultipleChildren(t *testing.T) { + source := ` +def foo(): + return x + y +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + assert.Equal(t, StatementTypeReturn, stmt.Type) + assert.Contains(t, stmt.Uses, "x") + assert.Contains(t, stmt.Uses, "y") +} + +func TestExtractIdentifiersFromArgs_NilNode(t *testing.T) { + // Test nil safety + result := extractIdentifiersFromArgs(nil, []byte{}) + assert.Equal(t, 0, len(result)) +} + +func TestExtractCallArgs_NilNode(t *testing.T) { + // Test nil safety + result := extractCallArgs(nil, []byte{}) + assert.Equal(t, 0, len(result)) +} + +func TestExtractStatements_LineNumbers(t *testing.T) { + source := ` +def foo(): + x = 10 + y = 20 + return x + y +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 3, len(statements)) + + // Check that line numbers are set + for _, stmt := range statements { + assert.Greater(t, stmt.LineNumber, uint32(0), "Line number should be set") + } +} + +func TestExtractStatements_CallWithNestedKeywordArgs(t *testing.T) { + source := ` +def foo(): + func(a, b=nested(c), d=x+y) +` + tree, funcNode, sourceBytes := parsePythonFunction(t, source, "foo") + defer tree.Close() + + statements, err := ExtractStatements("test.py", sourceBytes, funcNode) + + require.NoError(t, err) + assert.Equal(t, 1, len(statements)) + + stmt := statements[0] + // Should extract identifiers from nested expressions + assert.Contains(t, stmt.Uses, "a") + assert.Contains(t, stmt.Uses, "c") + assert.Contains(t, stmt.Uses, "x") + assert.Contains(t, stmt.Uses, "y") +}