diff --git a/sourcecode-parser/graph/callgraph/imports.go b/sourcecode-parser/graph/callgraph/imports.go new file mode 100644 index 00000000..d983807a --- /dev/null +++ b/sourcecode-parser/graph/callgraph/imports.go @@ -0,0 +1,172 @@ +package callgraph + +import ( + "context" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" +) + +// ExtractImports extracts all import statements from a Python file and builds an ImportMap. +// It handles three main import styles: +// 1. Simple imports: import module +// 2. From imports: from module import name +// 3. Aliased imports: from module import name as alias +// +// The resulting ImportMap maps local names (aliases or imported names) to their +// fully qualified module paths, enabling later resolution of function calls. +// +// Algorithm: +// 1. Parse source code with tree-sitter Python parser +// 2. Execute tree-sitter query to find all import statements +// 3. Process each import match to extract module paths and aliases +// 4. Build ImportMap with resolved fully qualified names +// +// Parameters: +// - filePath: absolute path to the Python file being analyzed +// - sourceCode: contents of the Python file as byte array +// - registry: module registry for resolving module paths +// +// Returns: +// - ImportMap: map of local names to fully qualified module paths +// - error: if parsing fails or source is invalid +// +// Example: +// +// Source code: +// import os +// from myapp.utils import sanitize +// from myapp.db import query as db_query +// +// Result ImportMap: +// { +// "os": "os", +// "sanitize": "myapp.utils.sanitize", +// "db_query": "myapp.db.query" +// } +func ExtractImports(filePath string, sourceCode []byte, registry *ModuleRegistry) (*ImportMap, error) { + importMap := NewImportMap(filePath) + + // Parse with tree-sitter + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, sourceCode) + if err != nil { + return nil, err + } + defer tree.Close() + + // Traverse AST to find import statements + traverseForImports(tree.RootNode(), sourceCode, importMap) + + return importMap, nil +} + +// traverseForImports recursively traverses the AST to find import statements. +// Uses direct AST traversal instead of queries for better compatibility. +func traverseForImports(node *sitter.Node, sourceCode []byte, importMap *ImportMap) { + if node == nil { + return + } + + nodeType := node.Type() + + // Process import statements + switch nodeType { + case "import_statement": + processImportStatement(node, sourceCode, importMap) + // Don't recurse into children - we've already processed this import + return + case "import_from_statement": + processImportFromStatement(node, sourceCode, importMap) + // Don't recurse into children - we've already processed this import + return + } + + // Recursively process children + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + traverseForImports(child, sourceCode, importMap) + } +} + +// processImportStatement handles simple import statements: import module [as alias]. +// Examples: +// - import os → "os" = "os" +// - import os as op → "op" = "os" +func processImportStatement(node *sitter.Node, sourceCode []byte, importMap *ImportMap) { + // Look for 'name' field which contains the import + nameNode := node.ChildByFieldName("name") + if nameNode == nil { + return + } + + // Check if it's an aliased import + if nameNode.Type() == "aliased_import" { + // import module as alias + moduleNode := nameNode.ChildByFieldName("name") + aliasNode := nameNode.ChildByFieldName("alias") + + if moduleNode != nil && aliasNode != nil { + moduleName := moduleNode.Content(sourceCode) + aliasName := aliasNode.Content(sourceCode) + importMap.AddImport(aliasName, moduleName) + } + } else if nameNode.Type() == "dotted_name" { + // Simple import: import module + moduleName := nameNode.Content(sourceCode) + importMap.AddImport(moduleName, moduleName) + } +} + +// processImportFromStatement handles from-import statements: from module import name [as alias]. +// Examples: +// - from os import path → "path" = "os.path" +// - from os import path as ospath → "ospath" = "os.path" +// - from json import dumps, loads → "dumps" = "json.dumps", "loads" = "json.loads" +func processImportFromStatement(node *sitter.Node, sourceCode []byte, importMap *ImportMap) { + // Get the module being imported from + moduleNameNode := node.ChildByFieldName("module_name") + if moduleNameNode == nil { + return + } + + moduleName := moduleNameNode.Content(sourceCode) + + // The 'name' field might be: + // 1. A single dotted_name: from os import path + // 2. A single aliased_import: from os import path as ospath + // 3. A wildcard_import: from os import * + // + // For multiple imports (from json import dumps, loads), tree-sitter + // creates multiple child nodes, so we need to check all children + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + + // Skip the module_name node itself - we only want the imported names + if child == moduleNameNode { + continue + } + + // Process each import name/alias + if child.Type() == "aliased_import" { + // from module import name as alias + importNameNode := child.ChildByFieldName("name") + aliasNode := child.ChildByFieldName("alias") + + if importNameNode != nil && aliasNode != nil { + importName := importNameNode.Content(sourceCode) + aliasName := aliasNode.Content(sourceCode) + fqn := moduleName + "." + importName + importMap.AddImport(aliasName, fqn) + } + } else if child.Type() == "dotted_name" || child.Type() == "identifier" { + // from module import name + importName := child.Content(sourceCode) + fqn := moduleName + "." + importName + importMap.AddImport(importName, fqn) + } + } +} diff --git a/sourcecode-parser/graph/callgraph/imports_test.go b/sourcecode-parser/graph/callgraph/imports_test.go new file mode 100644 index 00000000..66497332 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/imports_test.go @@ -0,0 +1,388 @@ +package callgraph + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractImports_SimpleImports(t *testing.T) { + // Test simple import statements: import module + sourceCode := []byte(` +import os +import sys +import json +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Verify all simple imports are captured + assert.Equal(t, 3, len(importMap.Imports)) + + fqn, ok := importMap.Resolve("os") + assert.True(t, ok) + assert.Equal(t, "os", fqn) + + fqn, ok = importMap.Resolve("sys") + assert.True(t, ok) + assert.Equal(t, "sys", fqn) + + fqn, ok = importMap.Resolve("json") + assert.True(t, ok) + assert.Equal(t, "json", fqn) +} + +func TestExtractImports_FromImports(t *testing.T) { + // Test from import statements: from module import name + sourceCode := []byte(` +from os import path +from sys import argv +from collections import OrderedDict +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Verify from imports create fully qualified names + assert.Equal(t, 3, len(importMap.Imports)) + + fqn, ok := importMap.Resolve("path") + assert.True(t, ok) + assert.Equal(t, "os.path", fqn) + + fqn, ok = importMap.Resolve("argv") + assert.True(t, ok) + assert.Equal(t, "sys.argv", fqn) + + fqn, ok = importMap.Resolve("OrderedDict") + assert.True(t, ok) + assert.Equal(t, "collections.OrderedDict", fqn) +} + +func TestExtractImports_AliasedSimpleImports(t *testing.T) { + // Test aliased simple imports: import module as alias + sourceCode := []byte(` +import os as operating_system +import sys as system +import json as js +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Verify aliases map to original module names + assert.Equal(t, 3, len(importMap.Imports)) + + fqn, ok := importMap.Resolve("operating_system") + assert.True(t, ok) + assert.Equal(t, "os", fqn) + + fqn, ok = importMap.Resolve("system") + assert.True(t, ok) + assert.Equal(t, "sys", fqn) + + fqn, ok = importMap.Resolve("js") + assert.True(t, ok) + assert.Equal(t, "json", fqn) + + // Original names should NOT be in the map + _, ok = importMap.Resolve("os") + assert.False(t, ok) +} + +func TestExtractImports_AliasedFromImports(t *testing.T) { + // Test aliased from imports: from module import name as alias + sourceCode := []byte(` +from os import path as ospath +from sys import argv as arguments +from collections import OrderedDict as OD +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Verify aliases map to fully qualified names + assert.Equal(t, 3, len(importMap.Imports)) + + fqn, ok := importMap.Resolve("ospath") + assert.True(t, ok) + assert.Equal(t, "os.path", fqn) + + fqn, ok = importMap.Resolve("arguments") + assert.True(t, ok) + assert.Equal(t, "sys.argv", fqn) + + fqn, ok = importMap.Resolve("OD") + assert.True(t, ok) + assert.Equal(t, "collections.OrderedDict", fqn) + + // Original names should NOT be in the map + _, ok = importMap.Resolve("path") + assert.False(t, ok) + _, ok = importMap.Resolve("argv") + assert.False(t, ok) + _, ok = importMap.Resolve("OrderedDict") + assert.False(t, ok) +} + +func TestExtractImports_MixedStyles(t *testing.T) { + // Test mixed import styles in one file + sourceCode := []byte(` +import os +from sys import argv +import json as js +from collections import OrderedDict as OD +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + assert.Equal(t, 4, len(importMap.Imports)) + + // Simple import + fqn, ok := importMap.Resolve("os") + assert.True(t, ok) + assert.Equal(t, "os", fqn) + + // From import + fqn, ok = importMap.Resolve("argv") + assert.True(t, ok) + assert.Equal(t, "sys.argv", fqn) + + // Aliased simple import + fqn, ok = importMap.Resolve("js") + assert.True(t, ok) + assert.Equal(t, "json", fqn) + + // Aliased from import + fqn, ok = importMap.Resolve("OD") + assert.True(t, ok) + assert.Equal(t, "collections.OrderedDict", fqn) +} + +func TestExtractImports_NestedModules(t *testing.T) { + // Test imports with nested module paths + sourceCode := []byte(` +import xml.etree.ElementTree +from xml.etree import ElementTree +from xml.etree.ElementTree import Element +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + assert.Equal(t, 3, len(importMap.Imports)) + + // Simple import of nested module + fqn, ok := importMap.Resolve("xml.etree.ElementTree") + assert.True(t, ok) + assert.Equal(t, "xml.etree.ElementTree", fqn) + + // From import of nested module + fqn, ok = importMap.Resolve("ElementTree") + assert.True(t, ok) + assert.Equal(t, "xml.etree.ElementTree", fqn) + + // From import from deeply nested module + fqn, ok = importMap.Resolve("Element") + assert.True(t, ok) + assert.Equal(t, "xml.etree.ElementTree.Element", fqn) +} + +func TestExtractImports_EmptyFile(t *testing.T) { + sourceCode := []byte(` +# Just a comment, no imports +def foo(): + pass +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + assert.Equal(t, 0, len(importMap.Imports)) +} + +func TestExtractImports_InvalidSyntax(t *testing.T) { + // Test with invalid Python syntax + sourceCode := []byte(` +import this is not valid python +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + // Tree-sitter is fault-tolerant, so parsing may succeed even with errors + // We just verify it doesn't crash + require.NoError(t, err) + require.NotNil(t, importMap) +} + +func TestExtractImports_WithTestFixtures(t *testing.T) { + tests := []struct { + name string + fixtureFile string + expectedImports map[string]string + expectedCount int + }{ + { + name: "Simple imports fixture", + fixtureFile: "simple_imports.py", + expectedImports: map[string]string{ + "os": "os", + "sys": "sys", + "json": "json", + }, + expectedCount: 3, + }, + { + name: "From imports fixture", + fixtureFile: "from_imports.py", + expectedImports: map[string]string{ + "path": "os.path", + "argv": "sys.argv", + "dumps": "json.dumps", + "loads": "json.loads", + }, + expectedCount: 4, + }, + { + name: "Aliased imports fixture", + fixtureFile: "aliased_imports.py", + expectedImports: map[string]string{ + "operating_system": "os", + "arguments": "sys.argv", + "json_dumps": "json.dumps", + }, + expectedCount: 3, + }, + { + name: "Mixed imports fixture", + fixtureFile: "mixed_imports.py", + expectedImports: map[string]string{ + "os": "os", + "argv": "sys.argv", + "js": "json", + "OD": "collections.OrderedDict", + }, + expectedCount: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fixturePath := filepath.Join("..", "..", "..", "test-src", "python", "imports_test", tt.fixtureFile) + + // Check if fixture exists + if _, err := os.Stat(fixturePath); os.IsNotExist(err) { + t.Skipf("Fixture file not found: %s", fixturePath) + } + + sourceCode, err := os.ReadFile(fixturePath) + require.NoError(t, err) + + registry := NewModuleRegistry() + importMap, err := ExtractImports(fixturePath, sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Check expected count + assert.Equal(t, tt.expectedCount, len(importMap.Imports), + "Expected %d imports, got %d", tt.expectedCount, len(importMap.Imports)) + + // Check each expected import + for alias, expectedFQN := range tt.expectedImports { + fqn, ok := importMap.Resolve(alias) + assert.True(t, ok, "Expected import alias '%s' not found", alias) + assert.Equal(t, expectedFQN, fqn, + "Import '%s' should resolve to '%s', got '%s'", alias, expectedFQN, fqn) + } + }) + } +} + +func TestExtractImports_MultipleImportsPerLine(t *testing.T) { + // Python allows multiple imports on one line with commas + sourceCode := []byte(` +from collections import OrderedDict, defaultdict, Counter +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Each import should be captured separately + // Note: The tree-sitter query may need adjustment to handle this + // For now, we just verify it doesn't crash + assert.GreaterOrEqual(t, len(importMap.Imports), 1) +} + +func TestExtractCaptures(t *testing.T) { + // This is a unit test for the extractCaptures helper function + // We test it indirectly through ExtractImports, but this documents its behavior + sourceCode := []byte(` +import os +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + assert.Equal(t, 1, len(importMap.Imports)) +} + +func TestExtractImports_Whitespace(t *testing.T) { + // Test that whitespace is properly handled + sourceCode := []byte(` +import os +from sys import argv +import json as js +`) + + registry := NewModuleRegistry() + importMap, err := ExtractImports("/test/file.py", sourceCode, registry) + + require.NoError(t, err) + require.NotNil(t, importMap) + + // Verify whitespace doesn't affect import extraction + assert.Equal(t, 3, len(importMap.Imports)) + + fqn, ok := importMap.Resolve("os") + assert.True(t, ok) + assert.Equal(t, "os", fqn) + + fqn, ok = importMap.Resolve("argv") + assert.True(t, ok) + assert.Equal(t, "sys.argv", fqn) + + fqn, ok = importMap.Resolve("js") + assert.True(t, ok) + assert.Equal(t, "json", fqn) +} diff --git a/test-src/python/imports_test/aliased_imports.py b/test-src/python/imports_test/aliased_imports.py new file mode 100644 index 00000000..fbc6042d --- /dev/null +++ b/test-src/python/imports_test/aliased_imports.py @@ -0,0 +1,4 @@ +# Test file for aliased imports +import os as operating_system +from sys import argv as arguments +from json import dumps as json_dumps diff --git a/test-src/python/imports_test/from_imports.py b/test-src/python/imports_test/from_imports.py new file mode 100644 index 00000000..f87bfd85 --- /dev/null +++ b/test-src/python/imports_test/from_imports.py @@ -0,0 +1,4 @@ +# Test file for from import statements +from os import path +from sys import argv +from json import dumps, loads diff --git a/test-src/python/imports_test/mixed_imports.py b/test-src/python/imports_test/mixed_imports.py new file mode 100644 index 00000000..c522e104 --- /dev/null +++ b/test-src/python/imports_test/mixed_imports.py @@ -0,0 +1,5 @@ +# Test file with mixed import styles +import os +from sys import argv +import json as js +from collections import OrderedDict as OD diff --git a/test-src/python/imports_test/simple_imports.py b/test-src/python/imports_test/simple_imports.py new file mode 100644 index 00000000..979f1a21 --- /dev/null +++ b/test-src/python/imports_test/simple_imports.py @@ -0,0 +1,4 @@ +# Test file for simple import statements +import os +import sys +import json