diff --git a/sourcecode-parser/graph/callgraph/attribute_registry.go b/sourcecode-parser/graph/callgraph/attribute_registry.go index 851cb139..f26fbd9d 100644 --- a/sourcecode-parser/graph/callgraph/attribute_registry.go +++ b/sourcecode-parser/graph/callgraph/attribute_registry.go @@ -3,25 +3,16 @@ package callgraph import ( "sync" - "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" ) -// ClassAttribute represents a single attribute of a class. -type ClassAttribute struct { - Name string // Attribute name (e.g., "value", "user") - Type *TypeInfo // Inferred type of the attribute - AssignedIn string // Method where assigned (e.g., "__init__", "setup") - Location *graph.SourceLocation - Confidence float64 // Confidence in type inference (0.0-1.0) -} +// Deprecated: Use core.ClassAttribute instead. +// This alias will be removed in a future version. +type ClassAttribute = core.ClassAttribute -// ClassAttributes holds all attributes for a single class. -type ClassAttributes struct { - ClassFQN string // Fully qualified class name (e.g., "myapp.models.User") - Attributes map[string]*ClassAttribute // Map from attribute name to attribute info - Methods []string // List of method FQNs in this class - FilePath string // Source file path where class is defined -} +// Deprecated: Use core.ClassAttributes instead. +// This alias will be removed in a future version. +type ClassAttributes = core.ClassAttributes // AttributeRegistry is the global registry of class attributes // It provides thread-safe access to class attribute information. 
diff --git a/sourcecode-parser/graph/callgraph/core/attribute_types.go b/sourcecode-parser/graph/callgraph/core/attribute_types.go new file mode 100644 index 00000000..02edadaa --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/attribute_types.go @@ -0,0 +1,30 @@ +package core + +import ( + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" +) + +// TypeInfo represents inferred type information for a variable or expression. +// It tracks the fully qualified type name, confidence level, and how the type was inferred. +type TypeInfo struct { + TypeFQN string // Fully qualified type name (e.g., "builtins.str", "myapp.models.User") + Confidence float32 // Confidence level from 0.0 to 1.0 (1.0 = certain, 0.5 = heuristic, 0.0 = unknown) + Source string // How the type was inferred (e.g., "literal", "assignment", "annotation") +} + +// ClassAttribute represents a single attribute of a class. +type ClassAttribute struct { + Name string // Attribute name (e.g., "value", "user") + Type *TypeInfo // Inferred type of the attribute + AssignedIn string // Method where assigned (e.g., "__init__", "setup") + Location *graph.SourceLocation // Source location of the attribute + Confidence float64 // Confidence in type inference (0.0-1.0) +} + +// ClassAttributes holds all attributes for a single class. +type ClassAttributes struct { + ClassFQN string // Fully qualified class name (e.g., "myapp.models.User") + Attributes map[string]*ClassAttribute // Map from attribute name to attribute info + Methods []string // List of method FQNs in this class + FilePath string // Source file path where class is defined +} diff --git a/sourcecode-parser/graph/callgraph/core/doc.go b/sourcecode-parser/graph/callgraph/core/doc.go new file mode 100644 index 00000000..149243af --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/doc.go @@ -0,0 +1,24 @@ +// Package core provides foundational type definitions for the callgraph analyzer. 
+// +// This package contains pure data structures with minimal dependencies that form +// the contract for all other callgraph packages. Types in this package should: +// +// - Have zero circular dependencies +// - Contain minimal business logic +// - Be stable and rarely change +// +// # Core Types +// +// CallGraph represents the complete call graph with edges between functions. +// +// Statement represents individual program statements for def-use analysis. +// +// TaintSummary stores results of taint analysis for a function. +// +// # Usage +// +// import "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" +// +// cg := core.NewCallGraph() +// cg.AddEdge("main.foo", "main.bar") +package core diff --git a/sourcecode-parser/graph/callgraph/core/frameworks.go b/sourcecode-parser/graph/callgraph/core/frameworks.go new file mode 100644 index 00000000..a9cf7774 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/frameworks.go @@ -0,0 +1,413 @@ +package core + +import ( + "strings" +) + +// FrameworkDefinition represents a known external framework or library. +// This is used to mark calls to external code as resolved, even though +// we don't have the source code for these frameworks. +type FrameworkDefinition struct { + Name string // Display name (e.g., "Django") + Prefixes []string // Module prefixes to match (e.g., ["django.", "django"]) + Description string // Human-readable description + Category string // Category: "web", "orm", "testing", "stdlib", etc. +} + +// builtinFrameworks contains the list of known Python frameworks and libraries. +// This list focuses on the most common frameworks found in Python projects. 
+var builtinFrameworks = []FrameworkDefinition{ + // Web Frameworks + { + Name: "Django", + Prefixes: []string{"django."}, + Description: "Django web framework", + Category: "web", + }, + { + Name: "Django REST Framework", + Prefixes: []string{"rest_framework."}, + Description: "Django REST framework for building Web APIs", + Category: "web", + }, + { + Name: "Flask", + Prefixes: []string{"flask."}, + Description: "Flask web framework", + Category: "web", + }, + { + Name: "FastAPI", + Prefixes: []string{"fastapi."}, + Description: "FastAPI web framework", + Category: "web", + }, + { + Name: "Starlette", + Prefixes: []string{"starlette."}, + Description: "Starlette ASGI framework", + Category: "web", + }, + { + Name: "Tornado", + Prefixes: []string{"tornado."}, + Description: "Tornado web framework", + Category: "web", + }, + { + Name: "Pyramid", + Prefixes: []string{"pyramid."}, + Description: "Pyramid web framework", + Category: "web", + }, + { + Name: "Bottle", + Prefixes: []string{"bottle."}, + Description: "Bottle web framework", + Category: "web", + }, + + // ORM and Database + { + Name: "SQLAlchemy", + Prefixes: []string{"sqlalchemy."}, + Description: "SQLAlchemy ORM", + Category: "orm", + }, + { + Name: "Peewee", + Prefixes: []string{"peewee."}, + Description: "Peewee ORM", + Category: "orm", + }, + { + Name: "Tortoise ORM", + Prefixes: []string{"tortoise."}, + Description: "Tortoise ORM", + Category: "orm", + }, + { + Name: "Pony ORM", + Prefixes: []string{"pony."}, + Description: "Pony ORM", + Category: "orm", + }, + + // Testing Frameworks + { + Name: "pytest", + Prefixes: []string{"pytest.", "_pytest."}, + Description: "pytest testing framework", + Category: "testing", + }, + { + Name: "unittest", + Prefixes: []string{"unittest."}, + Description: "Python unittest framework", + Category: "testing", + }, + { + Name: "nose", + Prefixes: []string{"nose."}, + Description: "nose testing framework", + Category: "testing", + }, + { + Name: "mock", + Prefixes: 
[]string{"mock.", "unittest.mock."}, + Description: "Python mock library", + Category: "testing", + }, + + // HTTP and Requests + { + Name: "requests", + Prefixes: []string{"requests."}, + Description: "HTTP library for Python", + Category: "http", + }, + { + Name: "httpx", + Prefixes: []string{"httpx."}, + Description: "Async HTTP client", + Category: "http", + }, + { + Name: "urllib3", + Prefixes: []string{"urllib3."}, + Description: "HTTP client library", + Category: "http", + }, + { + Name: "aiohttp", + Prefixes: []string{"aiohttp."}, + Description: "Async HTTP client/server", + Category: "http", + }, + + // Data Science and ML + { + Name: "numpy", + Prefixes: []string{"numpy.", "np."}, + Description: "Numerical computing library", + Category: "data_science", + }, + { + Name: "pandas", + Prefixes: []string{"pandas.", "pd."}, + Description: "Data analysis library", + Category: "data_science", + }, + { + Name: "scikit-learn", + Prefixes: []string{"sklearn.", "scikit_learn."}, + Description: "Machine learning library", + Category: "data_science", + }, + { + Name: "tensorflow", + Prefixes: []string{"tensorflow.", "tf."}, + Description: "TensorFlow ML framework", + Category: "data_science", + }, + { + Name: "pytorch", + Prefixes: []string{"torch."}, + Description: "PyTorch ML framework", + Category: "data_science", + }, + + // Async and Concurrency + { + Name: "asyncio", + Prefixes: []string{"asyncio."}, + Description: "Async I/O library", + Category: "async", + }, + { + Name: "celery", + Prefixes: []string{"celery."}, + Description: "Distributed task queue", + Category: "async", + }, + + // Serialization and Data Formats + { + Name: "json", + Prefixes: []string{"json."}, + Description: "JSON encoder/decoder", + Category: "serialization", + }, + { + Name: "pickle", + Prefixes: []string{"pickle.", "_pickle."}, + Description: "Python object serialization", + Category: "serialization", + }, + { + Name: "yaml", + Prefixes: []string{"yaml.", "pyyaml."}, + Description: 
"YAML parser", + Category: "serialization", + }, + { + Name: "xml", + Prefixes: []string{"xml."}, + Description: "XML processing", + Category: "serialization", + }, + + // Logging and Monitoring + { + Name: "logging", + Prefixes: []string{"logging."}, + Description: "Python logging library", + Category: "logging", + }, + { + Name: "sentry", + Prefixes: []string{"sentry_sdk."}, + Description: "Sentry error tracking", + Category: "logging", + }, + + // Utilities + { + Name: "datetime", + Prefixes: []string{"datetime."}, + Description: "Date and time types", + Category: "stdlib", + }, + { + Name: "collections", + Prefixes: []string{"collections."}, + Description: "Container datatypes", + Category: "stdlib", + }, + { + Name: "itertools", + Prefixes: []string{"itertools."}, + Description: "Iterator functions", + Category: "stdlib", + }, + { + Name: "functools", + Prefixes: []string{"functools."}, + Description: "Higher-order functions", + Category: "stdlib", + }, + { + Name: "os", + Prefixes: []string{"os."}, + Description: "Operating system interfaces", + Category: "stdlib", + }, + { + Name: "sys", + Prefixes: []string{"sys."}, + Description: "System-specific parameters", + Category: "stdlib", + }, + { + Name: "pathlib", + Prefixes: []string{"pathlib."}, + Description: "Object-oriented filesystem paths", + Category: "stdlib", + }, + { + Name: "re", + Prefixes: []string{"re."}, + Description: "Regular expressions", + Category: "stdlib", + }, + { + Name: "subprocess", + Prefixes: []string{"subprocess."}, + Description: "Subprocess management", + Category: "stdlib", + }, + { + Name: "threading", + Prefixes: []string{"threading."}, + Description: "Thread-based parallelism", + Category: "stdlib", + }, + { + Name: "multiprocessing", + Prefixes: []string{"multiprocessing."}, + Description: "Process-based parallelism", + Category: "stdlib", + }, + { + Name: "socket", + Prefixes: []string{"socket."}, + Description: "Low-level networking", + Category: "stdlib", + }, + { + Name: 
"http", + Prefixes: []string{"http."}, + Description: "HTTP modules", + Category: "stdlib", + }, + { + Name: "urllib", + Prefixes: []string{"urllib."}, + Description: "URL handling modules", + Category: "stdlib", + }, + { + Name: "email", + Prefixes: []string{"email."}, + Description: "Email and MIME handling", + Category: "stdlib", + }, + { + Name: "hashlib", + Prefixes: []string{"hashlib."}, + Description: "Secure hash and message digest", + Category: "stdlib", + }, + { + Name: "hmac", + Prefixes: []string{"hmac."}, + Description: "Keyed-hashing for message authentication", + Category: "stdlib", + }, + { + Name: "secrets", + Prefixes: []string{"secrets."}, + Description: "Generate secure random numbers", + Category: "stdlib", + }, + { + Name: "typing", + Prefixes: []string{"typing."}, + Description: "Type hints support", + Category: "stdlib", + }, + { + Name: "dataclasses", + Prefixes: []string{"dataclasses."}, + Description: "Data classes", + Category: "stdlib", + }, + { + Name: "abc", + Prefixes: []string{"abc."}, + Description: "Abstract base classes", + Category: "stdlib", + }, +} + +// LoadFrameworks returns the list of known frameworks. +// This function provides an extensibility hook for future enhancements +// where frameworks might be loaded from a configuration file. +func LoadFrameworks() []FrameworkDefinition { + return builtinFrameworks +} + +// IsKnownFramework checks if the given fully qualified name (FQN) +// belongs to a known external framework or standard library. 
+// +// Parameters: +// - fqn: fully qualified name (e.g., "django.db.models.ForeignKey") +// +// Returns: +// - true if the FQN matches any known framework +// - the matching framework definition +func IsKnownFramework(fqn string) (bool, *FrameworkDefinition) { + frameworks := LoadFrameworks() + + for i := range frameworks { + framework := &frameworks[i] + for _, prefix := range framework.Prefixes { + // Check for exact match or prefix match + if fqn == prefix || strings.HasPrefix(fqn, prefix) { + return true, framework + } + } + } + + return false, nil +} + +// GetFrameworkCategory returns the category of a framework given its FQN. +// Returns empty string if not a known framework. +func GetFrameworkCategory(fqn string) string { + isKnown, framework := IsKnownFramework(fqn) + if isKnown { + return framework.Category + } + return "" +} + +// GetFrameworkName returns the name of a framework given its FQN. +// Returns empty string if not a known framework. +func GetFrameworkName(fqn string) string { + isKnown, framework := IsKnownFramework(fqn) + if isKnown { + return framework.Name + } + return "" +} diff --git a/sourcecode-parser/graph/callgraph/core/frameworks_test.go b/sourcecode-parser/graph/callgraph/core/frameworks_test.go new file mode 100644 index 00000000..73d3ae6a --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/frameworks_test.go @@ -0,0 +1,377 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsKnownFramework_Django(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "Django core", + fqn: "django.db.models.Model", + expected: true, + category: "web", + }, + { + name: "Django REST framework", + fqn: "rest_framework.serializers.ModelSerializer", + expected: true, + category: "web", + }, + { + name: "Django forms", + fqn: "django.forms.Form", + expected: true, + category: "web", + }, + { + name: "Django ORM", + fqn: 
"django.db.models.ForeignKey", + expected: true, + category: "web", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_Testing(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "pytest", + fqn: "pytest.fixture", + expected: true, + category: "testing", + }, + { + name: "unittest", + fqn: "unittest.TestCase", + expected: true, + category: "testing", + }, + { + name: "mock", + fqn: "unittest.mock.patch", + expected: true, + category: "testing", + }, + { + name: "_pytest internal", + fqn: "_pytest.fixtures", + expected: true, + category: "testing", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_HTTP(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "requests library", + fqn: "requests.get", + expected: true, + category: "http", + }, + { + name: "httpx", + fqn: "httpx.AsyncClient", + expected: true, + category: "http", + }, + { + name: "urllib3", + fqn: "urllib3.PoolManager", + expected: true, + category: "http", + }, + { + name: "aiohttp", + fqn: "aiohttp.ClientSession", + expected: true, + category: "http", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_DataScience(t 
*testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "numpy", + fqn: "numpy.array", + expected: true, + category: "data_science", + }, + { + name: "pandas", + fqn: "pandas.DataFrame", + expected: true, + category: "data_science", + }, + { + name: "sklearn", + fqn: "sklearn.ensemble.RandomForestClassifier", + expected: true, + category: "data_science", + }, + { + name: "tensorflow", + fqn: "tensorflow.keras.Model", + expected: true, + category: "data_science", + }, + { + name: "pytorch", + fqn: "torch.nn.Module", + expected: true, + category: "data_science", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_Stdlib(t *testing.T) { + tests := []struct { + name string + fqn string + expected bool + category string + }{ + { + name: "json", + fqn: "json.loads", + expected: true, + category: "serialization", + }, + { + name: "pickle", + fqn: "pickle.dumps", + expected: true, + category: "serialization", + }, + { + name: "logging", + fqn: "logging.getLogger", + expected: true, + category: "logging", + }, + { + name: "datetime", + fqn: "datetime.datetime", + expected: true, + category: "stdlib", + }, + { + name: "collections", + fqn: "collections.defaultdict", + expected: true, + category: "stdlib", + }, + { + name: "os", + fqn: "os.path.join", + expected: true, + category: "stdlib", + }, + { + name: "subprocess", + fqn: "subprocess.run", + expected: true, + category: "stdlib", + }, + { + name: "hashlib", + fqn: "hashlib.sha256", + expected: true, + category: "stdlib", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.Equal(t, tt.expected, isKnown) + if isKnown { + 
assert.NotNil(t, framework) + assert.Equal(t, tt.category, framework.Category) + } + }) + } +} + +func TestIsKnownFramework_NotFound(t *testing.T) { + tests := []struct { + name string + fqn string + }{ + { + name: "Custom application module", + fqn: "myapp.utils.helpers", + }, + { + name: "Custom package", + fqn: "internal.services.auth", + }, + { + name: "Unknown framework", + fqn: "unknownframework.module", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isKnown, framework := IsKnownFramework(tt.fqn) + assert.False(t, isKnown) + assert.Nil(t, framework) + }) + } +} + +func TestGetFrameworkCategory(t *testing.T) { + tests := []struct { + name string + fqn string + expected string + }{ + { + name: "Django web framework", + fqn: "django.http.HttpResponse", + expected: "web", + }, + { + name: "pytest testing", + fqn: "pytest.mark.parametrize", + expected: "testing", + }, + { + name: "requests http", + fqn: "requests.post", + expected: "http", + }, + { + name: "Unknown framework", + fqn: "myapp.custom", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + category := GetFrameworkCategory(tt.fqn) + assert.Equal(t, tt.expected, category) + }) + } +} + +func TestGetFrameworkName(t *testing.T) { + tests := []struct { + name string + fqn string + expected string + }{ + { + name: "Django", + fqn: "django.contrib.auth", + expected: "Django", + }, + { + name: "Flask", + fqn: "flask.Flask", + expected: "Flask", + }, + { + name: "FastAPI", + fqn: "fastapi.FastAPI", + expected: "FastAPI", + }, + { + name: "Unknown", + fqn: "myapp.unknown", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := GetFrameworkName(tt.fqn) + assert.Equal(t, tt.expected, name) + }) + } +} + +func TestLoadFrameworks(t *testing.T) { + frameworks := LoadFrameworks() + + // Should have at least 50 frameworks defined + assert.GreaterOrEqual(t, len(frameworks), 50) + + // Check that 
all frameworks have required fields + for _, fw := range frameworks { + assert.NotEmpty(t, fw.Name, "Framework should have a name") + assert.NotEmpty(t, fw.Prefixes, "Framework should have at least one prefix") + assert.NotEmpty(t, fw.Category, "Framework should have a category") + assert.NotEmpty(t, fw.Description, "Framework should have a description") + } +} diff --git a/sourcecode-parser/graph/callgraph/core/statement.go b/sourcecode-parser/graph/callgraph/core/statement.go new file mode 100644 index 00000000..cc1c2150 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/statement.go @@ -0,0 +1,334 @@ +package core + +// StatementType represents the type of statement in the code. +type StatementType string + +const ( + // Assignment represents variable assignments: x = expr. + StatementTypeAssignment StatementType = "assignment" + + // Call represents function/method calls: foo(), obj.method(). + StatementTypeCall StatementType = "call" + + // Return represents return statements: return expr. + StatementTypeReturn StatementType = "return" + + // If represents conditional statements: if condition: ... + StatementTypeIf StatementType = "if" + + // For represents loop statements: for x in iterable: ... + StatementTypeFor StatementType = "for" + + // While represents while loop statements: while condition: ... + StatementTypeWhile StatementType = "while" + + // With represents context manager statements: with expr as var: ... + StatementTypeWith StatementType = "with" + + // Try represents exception handling: try: ... except: ... + StatementTypeTry StatementType = "try" + + // Raise represents exception raising: raise Exception(). + StatementTypeRaise StatementType = "raise" + + // Import represents import statements: import module, from module import name. + StatementTypeImport StatementType = "import" + + // Expression represents expression statements (calls, attribute access, etc.). 
+ StatementTypeExpression StatementType = "expression" +) + +// Statement represents a single statement in the code with def-use information. +type Statement struct { + // Type is the kind of statement (assignment, call, return, etc.) + Type StatementType + + // LineNumber is the source line number for this statement (1-indexed) + LineNumber uint32 + + // Def is the variable being defined by this statement (if any) + // For assignments: the left-hand side variable + // For for loops: the loop variable + // For with statements: the as variable + // Empty string if no definition + Def string + + // Uses is the list of variables used/read by this statement + // For assignments: variables in the right-hand side expression + // For calls: variables used in arguments + // For conditions: variables in the condition expression + Uses []string + + // CallTarget is the function/method being called (if Type == StatementTypeCall) + // Format: "function_name" for direct calls, "obj.method" for method calls + // Empty string for non-call statements + CallTarget string + + // CallArgs are the argument variables passed to the call (if Type == StatementTypeCall) + // Only includes variable names, not literals + CallArgs []string + + // NestedStatements contains statements inside this statement's body + // Used for if/for/while/with/try blocks + // Empty for simple statements like assignments + NestedStatements []*Statement + + // ElseBranch contains statements in the else branch (if applicable) + // Used for if/try statements + ElseBranch []*Statement +} + +// GetDef returns the variable defined by this statement, or empty string if none. +func (s *Statement) GetDef() string { + return s.Def +} + +// GetUses returns the list of variables used by this statement. +func (s *Statement) GetUses() []string { + return s.Uses +} + +// IsCall returns true if this statement is a function/method call. 
+func (s *Statement) IsCall() bool { + return s.Type == StatementTypeCall || s.Type == StatementTypeExpression +} + +// IsAssignment returns true if this statement is a variable assignment. +func (s *Statement) IsAssignment() bool { + return s.Type == StatementTypeAssignment +} + +// IsControlFlow returns true if this statement is a control flow construct. +func (s *Statement) IsControlFlow() bool { + switch s.Type { + case StatementTypeIf, StatementTypeFor, StatementTypeWhile, StatementTypeWith, StatementTypeTry: + return true + default: + return false + } +} + +// HasNestedStatements returns true if this statement contains nested statements. +func (s *Statement) HasNestedStatements() bool { + return len(s.NestedStatements) > 0 || len(s.ElseBranch) > 0 +} + +// AllStatements returns a flattened list of this statement and all nested statements. +// Performs depth-first traversal. +func (s *Statement) AllStatements() []*Statement { + result := []*Statement{s} + + for _, nested := range s.NestedStatements { + result = append(result, nested.AllStatements()...) + } + + for _, elseBranch := range s.ElseBranch { + result = append(result, elseBranch.AllStatements()...) + } + + return result +} + +// DefUseChain represents the def-use relationships for all variables in a function. +type DefUseChain struct { + // Defs maps variable names to all statements that define them. + // A variable can have multiple definitions across different code paths. + Defs map[string][]*Statement + + // Uses maps variable names to all statements that use them. + // A variable can be used in multiple places. + Uses map[string][]*Statement +} + +// NewDefUseChain creates an empty def-use chain. +func NewDefUseChain() *DefUseChain { + return &DefUseChain{ + Defs: make(map[string][]*Statement), + Uses: make(map[string][]*Statement), + } +} + +// AddDef registers a statement as defining a variable. 
+func (chain *DefUseChain) AddDef(varName string, stmt *Statement) { + if varName == "" { + return + } + chain.Defs[varName] = append(chain.Defs[varName], stmt) +} + +// AddUse registers a statement as using a variable. +func (chain *DefUseChain) AddUse(varName string, stmt *Statement) { + if varName == "" { + return + } + chain.Uses[varName] = append(chain.Uses[varName], stmt) +} + +// GetDefs returns all statements that define a given variable. +// Returns empty slice if variable is never defined. +func (chain *DefUseChain) GetDefs(varName string) []*Statement { + if defs, ok := chain.Defs[varName]; ok { + return defs + } + return []*Statement{} +} + +// GetUses returns all statements that use a given variable. +// Returns empty slice if variable is never used. +func (chain *DefUseChain) GetUses(varName string) []*Statement { + if uses, ok := chain.Uses[varName]; ok { + return uses + } + return []*Statement{} +} + +// IsDefined returns true if the variable has at least one definition. +func (chain *DefUseChain) IsDefined(varName string) bool { + return len(chain.Defs[varName]) > 0 +} + +// IsUsed returns true if the variable has at least one use. +func (chain *DefUseChain) IsUsed(varName string) bool { + return len(chain.Uses[varName]) > 0 +} + +// AllVariables returns a list of all variable names in the def-use chain. +func (chain *DefUseChain) AllVariables() []string { + varSet := make(map[string]bool) + + for varName := range chain.Defs { + varSet[varName] = true + } + + for varName := range chain.Uses { + varSet[varName] = true + } + + result := make([]string, 0, len(varSet)) + for varName := range varSet { + result = append(result, varName) + } + + return result +} + +// BuildDefUseChains constructs a def-use chain from a list of statements. +// This is a single-pass algorithm that builds an inverted index. +// +// Algorithm: +// 1. Initialize empty Defs and Uses maps +// 2. 
For each statement: +// - If stmt.Def is not empty: add stmt to Defs[stmt.Def] +// - For each variable in stmt.Uses: add stmt to Uses[variable] +// 3. Return DefUseChain +// +// Time complexity: O(n × m) +// +// where n = number of statements +// m = average number of uses per statement +// Typical: 50 statements × 3 variables = 150 operations (~1 microsecond) +// +// Space complexity: O(v × k) +// +// where v = number of unique variables +// k = average number of defs + uses per variable +// Typical: 20 variables × 5 references = 100 pointers = 800 bytes +// +// Example: +// +// statements := []*Statement{ +// {LineNumber: 1, Def: "x", Uses: []string{}}, +// {LineNumber: 2, Def: "y", Uses: []string{"x"}}, +// {LineNumber: 3, Def: "", Uses: []string{"y"}}, +// } +// +// chain := BuildDefUseChains(statements) +// +// // Query: where is x defined? +// xDefs := chain.Defs["x"] // [stmt1] +// +// // Query: where is x used? +// xUses := chain.Uses["x"] // [stmt2] +func BuildDefUseChains(statements []*Statement) *DefUseChain { + chain := NewDefUseChain() + + // Single pass: build inverted index + for _, stmt := range statements { + // Track definition (single variable per statement) + if stmt.Def != "" { + chain.AddDef(stmt.Def, stmt) + } + + // Track all uses in this statement + for _, varName := range stmt.Uses { + chain.AddUse(varName, stmt) + } + } + + return chain +} + +// DefUseStats contains statistics about the def-use chain (for debugging/diagnostics). +type DefUseStats struct { + NumVariables int // Total unique variables + NumDefs int // Total definition sites + NumUses int // Total use sites + MaxDefsPerVariable int // Most definitions for a single variable + MaxUsesPerVariable int // Most uses for a single variable + UndefinedVariables int // Variables used but never defined (parameters) + DeadVariables int // Variables defined but never used +} + +// ComputeStats computes statistics about this def-use chain. 
+// Useful for performance analysis and debugging. +// +// Example: +// +// stats := chain.ComputeStats() +// fmt.Printf("Function has %d variables, %d defs, %d uses\n", +// stats.NumVariables, stats.NumDefs, stats.NumUses) +func (chain *DefUseChain) ComputeStats() DefUseStats { + stats := DefUseStats{} + + // Count unique variables + varSet := make(map[string]bool) + for varName := range chain.Defs { + varSet[varName] = true + } + for varName := range chain.Uses { + varSet[varName] = true + } + stats.NumVariables = len(varSet) + + // Count total defs and max defs per variable + for _, defs := range chain.Defs { + stats.NumDefs += len(defs) + if len(defs) > stats.MaxDefsPerVariable { + stats.MaxDefsPerVariable = len(defs) + } + } + + // Count total uses and max uses per variable + for _, uses := range chain.Uses { + stats.NumUses += len(uses) + if len(uses) > stats.MaxUsesPerVariable { + stats.MaxUsesPerVariable = len(uses) + } + } + + // Count undefined variables (used but not defined) + for varName := range chain.Uses { + if len(chain.Defs[varName]) == 0 { + stats.UndefinedVariables++ + } + } + + // Count dead variables (defined but not used) + for varName := range chain.Defs { + if len(chain.Uses[varName]) == 0 { + stats.DeadVariables++ + } + } + + return stats +} diff --git a/sourcecode-parser/graph/callgraph/core/statement_test.go b/sourcecode-parser/graph/callgraph/core/statement_test.go new file mode 100644 index 00000000..c1945c77 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/statement_test.go @@ -0,0 +1,696 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStatementGetDef(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected string + }{ + { + name: "assignment with def", + stmt: &Statement{ + Type: StatementTypeAssignment, + Def: "x", + }, + expected: "x", + }, + { + name: "call without def", + stmt: &Statement{ + Type: StatementTypeCall, + Def: "", + }, + expected: "", 
+ }, + { + name: "for loop with def", + stmt: &Statement{ + Type: StatementTypeFor, + Def: "item", + }, + expected: "item", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.GetDef() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementGetUses(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected []string + }{ + { + name: "assignment with uses", + stmt: &Statement{ + Type: StatementTypeAssignment, + Uses: []string{"a", "b"}, + }, + expected: []string{"a", "b"}, + }, + { + name: "call with no uses", + stmt: &Statement{ + Type: StatementTypeCall, + Uses: []string{}, + }, + expected: []string{}, + }, + { + name: "if statement with condition uses", + stmt: &Statement{ + Type: StatementTypeIf, + Uses: []string{"flag", "count"}, + }, + expected: []string{"flag", "count"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.GetUses() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementIsCall(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "call statement", + stmt: &Statement{Type: StatementTypeCall}, + expected: true, + }, + { + name: "expression statement", + stmt: &Statement{Type: StatementTypeExpression}, + expected: true, + }, + { + name: "assignment statement", + stmt: &Statement{Type: StatementTypeAssignment}, + expected: false, + }, + { + name: "return statement", + stmt: &Statement{Type: StatementTypeReturn}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.IsCall() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementIsAssignment(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "assignment statement", + stmt: &Statement{Type: StatementTypeAssignment}, + expected: true, + }, + { + name: "call statement", + stmt: 
&Statement{Type: StatementTypeCall}, + expected: false, + }, + { + name: "for statement", + stmt: &Statement{Type: StatementTypeFor}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.IsAssignment() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementIsControlFlow(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "if statement", + stmt: &Statement{Type: StatementTypeIf}, + expected: true, + }, + { + name: "for statement", + stmt: &Statement{Type: StatementTypeFor}, + expected: true, + }, + { + name: "while statement", + stmt: &Statement{Type: StatementTypeWhile}, + expected: true, + }, + { + name: "with statement", + stmt: &Statement{Type: StatementTypeWith}, + expected: true, + }, + { + name: "try statement", + stmt: &Statement{Type: StatementTypeTry}, + expected: true, + }, + { + name: "assignment statement", + stmt: &Statement{Type: StatementTypeAssignment}, + expected: false, + }, + { + name: "call statement", + stmt: &Statement{Type: StatementTypeCall}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.IsControlFlow() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementHasNestedStatements(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expected bool + }{ + { + name: "if with nested statements", + stmt: &Statement{ + Type: StatementTypeIf, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment}, + }, + }, + expected: true, + }, + { + name: "if with else branch", + stmt: &Statement{ + Type: StatementTypeIf, + ElseBranch: []*Statement{ + {Type: StatementTypeReturn}, + }, + }, + expected: true, + }, + { + name: "simple assignment", + stmt: &Statement{ + Type: StatementTypeAssignment, + }, + expected: false, + }, + { + name: "if with both nested and else", + stmt: &Statement{ + Type: StatementTypeIf, + 
NestedStatements: []*Statement{ + {Type: StatementTypeAssignment}, + }, + ElseBranch: []*Statement{ + {Type: StatementTypeReturn}, + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.HasNestedStatements() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStatementAllStatements(t *testing.T) { + tests := []struct { + name string + stmt *Statement + expectedCount int + }{ + { + name: "simple statement", + stmt: &Statement{ + Type: StatementTypeAssignment, + LineNumber: 1, + }, + expectedCount: 1, + }, + { + name: "if with one nested statement", + stmt: &Statement{ + Type: StatementTypeIf, + LineNumber: 1, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment, LineNumber: 2}, + }, + }, + expectedCount: 2, + }, + { + name: "if with nested and else", + stmt: &Statement{ + Type: StatementTypeIf, + LineNumber: 1, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment, LineNumber: 2}, + {Type: StatementTypeCall, LineNumber: 3}, + }, + ElseBranch: []*Statement{ + {Type: StatementTypeReturn, LineNumber: 5}, + }, + }, + expectedCount: 4, + }, + { + name: "deeply nested statements", + stmt: &Statement{ + Type: StatementTypeIf, + LineNumber: 1, + NestedStatements: []*Statement{ + { + Type: StatementTypeFor, + LineNumber: 2, + NestedStatements: []*Statement{ + {Type: StatementTypeAssignment, LineNumber: 3}, + {Type: StatementTypeCall, LineNumber: 4}, + }, + }, + }, + }, + expectedCount: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stmt.AllStatements() + assert.Equal(t, tt.expectedCount, len(result)) + + // Verify first statement is always the root + assert.Equal(t, tt.stmt, result[0]) + }) + } +} + +func TestNewDefUseChain(t *testing.T) { + chain := NewDefUseChain() + + assert.NotNil(t, chain) + assert.NotNil(t, chain.Defs) + assert.NotNil(t, chain.Uses) + assert.Equal(t, 0, len(chain.Defs)) + assert.Equal(t, 0, 
len(chain.Uses)) +} + +func TestDefUseChainAddDef(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Def: "x"} + + // Add first definition + chain.AddDef("x", stmt1) + assert.Equal(t, 1, len(chain.Defs["x"])) + assert.Equal(t, stmt1, chain.Defs["x"][0]) + + // Add second definition of same variable + chain.AddDef("x", stmt2) + assert.Equal(t, 2, len(chain.Defs["x"])) + assert.Equal(t, stmt2, chain.Defs["x"][1]) + + // Add definition for different variable + stmt3 := &Statement{Type: StatementTypeAssignment, LineNumber: 3, Def: "y"} + chain.AddDef("y", stmt3) + assert.Equal(t, 1, len(chain.Defs["y"])) + assert.Equal(t, stmt3, chain.Defs["y"][0]) + + // Test empty variable name (should be ignored) + chain.AddDef("", stmt1) + _, exists := chain.Defs[""] + assert.False(t, exists) +} + +func TestDefUseChainAddUse(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeCall, LineNumber: 1, Uses: []string{"x"}} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Uses: []string{"x"}} + + // Add first use + chain.AddUse("x", stmt1) + assert.Equal(t, 1, len(chain.Uses["x"])) + assert.Equal(t, stmt1, chain.Uses["x"][0]) + + // Add second use of same variable + chain.AddUse("x", stmt2) + assert.Equal(t, 2, len(chain.Uses["x"])) + assert.Equal(t, stmt2, chain.Uses["x"][1]) + + // Add use for different variable + stmt3 := &Statement{Type: StatementTypeReturn, LineNumber: 3, Uses: []string{"y"}} + chain.AddUse("y", stmt3) + assert.Equal(t, 1, len(chain.Uses["y"])) + assert.Equal(t, stmt3, chain.Uses["y"][0]) + + // Test empty variable name (should be ignored) + chain.AddUse("", stmt1) + _, exists := chain.Uses[""] + assert.False(t, exists) +} + +func TestDefUseChainGetDefs(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + 
stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Def: "x"} + + chain.AddDef("x", stmt1) + chain.AddDef("x", stmt2) + + defs := chain.GetDefs("x") + assert.Equal(t, 2, len(defs)) + assert.Equal(t, stmt1, defs[0]) + assert.Equal(t, stmt2, defs[1]) + + // Test non-existent variable (should return empty slice, not nil) + nonExistent := chain.GetDefs("nonexistent") + assert.NotNil(t, nonExistent) + assert.Equal(t, 0, len(nonExistent)) +} + +func TestDefUseChainGetUses(t *testing.T) { + chain := NewDefUseChain() + stmt1 := &Statement{Type: StatementTypeCall, LineNumber: 1, Uses: []string{"x"}} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Uses: []string{"x"}} + + chain.AddUse("x", stmt1) + chain.AddUse("x", stmt2) + + uses := chain.GetUses("x") + assert.Equal(t, 2, len(uses)) + assert.Equal(t, stmt1, uses[0]) + assert.Equal(t, stmt2, uses[1]) + + // Test non-existent variable (should return empty slice, not nil) + nonExistent := chain.GetUses("nonexistent") + assert.NotNil(t, nonExistent) + assert.Equal(t, 0, len(nonExistent)) +} + +func TestDefUseChainIsDefined(t *testing.T) { + chain := NewDefUseChain() + stmt := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + + assert.False(t, chain.IsDefined("x")) + + chain.AddDef("x", stmt) + assert.True(t, chain.IsDefined("x")) + assert.False(t, chain.IsDefined("y")) +} + +func TestDefUseChainIsUsed(t *testing.T) { + chain := NewDefUseChain() + stmt := &Statement{Type: StatementTypeCall, LineNumber: 1, Uses: []string{"x"}} + + assert.False(t, chain.IsUsed("x")) + + chain.AddUse("x", stmt) + assert.True(t, chain.IsUsed("x")) + assert.False(t, chain.IsUsed("y")) +} + +func TestDefUseChainAllVariables(t *testing.T) { + chain := NewDefUseChain() + + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeCall, LineNumber: 2, Uses: []string{"y"}} + stmt3 := &Statement{Type: StatementTypeAssignment, LineNumber: 3, Def: 
"z", Uses: []string{"x"}} + + chain.AddDef("x", stmt1) + chain.AddUse("y", stmt2) + chain.AddDef("z", stmt3) + chain.AddUse("x", stmt3) + + vars := chain.AllVariables() + assert.Equal(t, 3, len(vars)) + + // Create a set to check presence + varSet := make(map[string]bool) + for _, v := range vars { + varSet[v] = true + } + + assert.True(t, varSet["x"]) + assert.True(t, varSet["y"]) + assert.True(t, varSet["z"]) +} + +func TestDefUseChainComplexScenario(t *testing.T) { + // Simulate a real code scenario: + // 1: x = 5 + // 2: y = x + 10 + // 3: if y > 15: + // 4: z = x * 2 + // 5: print(z) + + chain := NewDefUseChain() + + stmt1 := &Statement{Type: StatementTypeAssignment, LineNumber: 1, Def: "x"} + stmt2 := &Statement{Type: StatementTypeAssignment, LineNumber: 2, Def: "y", Uses: []string{"x"}} + stmt3 := &Statement{Type: StatementTypeIf, LineNumber: 3, Uses: []string{"y"}} + stmt4 := &Statement{Type: StatementTypeAssignment, LineNumber: 4, Def: "z", Uses: []string{"x"}} + stmt5 := &Statement{Type: StatementTypeCall, LineNumber: 5, Uses: []string{"z"}} + + chain.AddDef("x", stmt1) + + chain.AddDef("y", stmt2) + chain.AddUse("x", stmt2) + + chain.AddUse("y", stmt3) + + chain.AddDef("z", stmt4) + chain.AddUse("x", stmt4) + + chain.AddUse("z", stmt5) + + // Verify x: 1 def, 2 uses + assert.Equal(t, 1, len(chain.GetDefs("x"))) + assert.Equal(t, 2, len(chain.GetUses("x"))) + + // Verify y: 1 def, 1 use + assert.Equal(t, 1, len(chain.GetDefs("y"))) + assert.Equal(t, 1, len(chain.GetUses("y"))) + + // Verify z: 1 def, 1 use + assert.Equal(t, 1, len(chain.GetDefs("z"))) + assert.Equal(t, 1, len(chain.GetUses("z"))) + + // All variables + vars := chain.AllVariables() + assert.Equal(t, 3, len(vars)) +} + +func TestBuildDefUseChains(t *testing.T) { + tests := []struct { + name string + statements []*Statement + checkFn func(*testing.T, *DefUseChain) + }{ + { + name: "empty statements", + statements: []*Statement{}, + checkFn: func(t *testing.T, chain *DefUseChain) { + 
t.Helper() + assert.NotNil(t, chain) + assert.Equal(t, 0, len(chain.Defs)) + assert.Equal(t, 0, len(chain.Uses)) + }, + }, + { + name: "single assignment", + statements: []*Statement{ + {LineNumber: 1, Def: "x", Uses: []string{}}, + }, + checkFn: func(t *testing.T, chain *DefUseChain) { + t.Helper() + assert.Equal(t, 1, len(chain.Defs)) + assert.Equal(t, 1, len(chain.Defs["x"])) + assert.Equal(t, 0, len(chain.Uses)) + }, + }, + { + name: "def-use chain", + statements: []*Statement{ + {LineNumber: 1, Def: "x", Uses: []string{}}, + {LineNumber: 2, Def: "y", Uses: []string{"x"}}, + {LineNumber: 3, Def: "", Uses: []string{"y"}}, + }, + checkFn: func(t *testing.T, chain *DefUseChain) { + t.Helper() + // Check defs + assert.Equal(t, 2, len(chain.Defs)) + assert.Equal(t, 1, len(chain.Defs["x"])) + assert.Equal(t, 1, len(chain.Defs["y"])) + + // Check uses + assert.Equal(t, 2, len(chain.Uses)) + assert.Equal(t, 1, len(chain.Uses["x"])) + assert.Equal(t, 1, len(chain.Uses["y"])) + }, + }, + { + name: "multiple defs and uses", + statements: []*Statement{ + {LineNumber: 1, Def: "x", Uses: []string{}}, + {LineNumber: 2, Def: "x", Uses: []string{"x"}}, + {LineNumber: 3, Def: "y", Uses: []string{"x", "x"}}, + }, + checkFn: func(t *testing.T, chain *DefUseChain) { + t.Helper() + // Variable x has 2 definitions + assert.Equal(t, 2, len(chain.Defs["x"])) + + // Variable x is used in 2 statements (line 2 and 3) + assert.Equal(t, 3, len(chain.Uses["x"])) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := BuildDefUseChains(tt.statements) + tt.checkFn(t, chain) + }) + } +} + +func TestDefUseChainComputeStats(t *testing.T) { + tests := []struct { + name string + setupFn func() *DefUseChain + expectedStats DefUseStats + }{ + { + name: "empty chain", + setupFn: NewDefUseChain, + expectedStats: DefUseStats{ + NumVariables: 0, + NumDefs: 0, + NumUses: 0, + MaxDefsPerVariable: 0, + MaxUsesPerVariable: 0, + UndefinedVariables: 0, + DeadVariables: 0, 
+ }, + }, + { + name: "simple def-use", + setupFn: func() *DefUseChain { + statements := []*Statement{ + {LineNumber: 1, Def: "x", Uses: []string{}}, + {LineNumber: 2, Def: "y", Uses: []string{"x"}}, + } + return BuildDefUseChains(statements) + }, + expectedStats: DefUseStats{ + NumVariables: 2, + NumDefs: 2, + NumUses: 1, + MaxDefsPerVariable: 1, + MaxUsesPerVariable: 1, + UndefinedVariables: 0, + DeadVariables: 1, // y is defined but not used + }, + }, + { + name: "undefined variable", + setupFn: func() *DefUseChain { + statements := []*Statement{ + {LineNumber: 1, Def: "x", Uses: []string{"y"}}, // y is used but never defined + } + return BuildDefUseChains(statements) + }, + expectedStats: DefUseStats{ + NumVariables: 2, + NumDefs: 1, + NumUses: 1, + MaxDefsPerVariable: 1, + MaxUsesPerVariable: 1, + UndefinedVariables: 1, // y is undefined + DeadVariables: 1, // x is never used + }, + }, + { + name: "multiple defs per variable", + setupFn: func() *DefUseChain { + statements := []*Statement{ + {LineNumber: 1, Def: "x", Uses: []string{}}, + {LineNumber: 2, Def: "x", Uses: []string{}}, + {LineNumber: 3, Def: "x", Uses: []string{}}, + {LineNumber: 4, Def: "", Uses: []string{"x", "x"}}, + } + return BuildDefUseChains(statements) + }, + expectedStats: DefUseStats{ + NumVariables: 1, + NumDefs: 3, + NumUses: 2, + MaxDefsPerVariable: 3, + MaxUsesPerVariable: 2, + UndefinedVariables: 0, + DeadVariables: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := tt.setupFn() + stats := chain.ComputeStats() + + assert.Equal(t, tt.expectedStats.NumVariables, stats.NumVariables, "NumVariables mismatch") + assert.Equal(t, tt.expectedStats.NumDefs, stats.NumDefs, "NumDefs mismatch") + assert.Equal(t, tt.expectedStats.NumUses, stats.NumUses, "NumUses mismatch") + assert.Equal(t, tt.expectedStats.MaxDefsPerVariable, stats.MaxDefsPerVariable, "MaxDefsPerVariable mismatch") + assert.Equal(t, tt.expectedStats.MaxUsesPerVariable, 
stats.MaxUsesPerVariable, "MaxUsesPerVariable mismatch") + assert.Equal(t, tt.expectedStats.UndefinedVariables, stats.UndefinedVariables, "UndefinedVariables mismatch") + assert.Equal(t, tt.expectedStats.DeadVariables, stats.DeadVariables, "DeadVariables mismatch") + }) + } +} diff --git a/sourcecode-parser/graph/callgraph/core/stdlib_types.go b/sourcecode-parser/graph/callgraph/core/stdlib_types.go new file mode 100644 index 00000000..91e8058e --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/stdlib_types.go @@ -0,0 +1,168 @@ +package core + +// StdlibRegistry holds all Python stdlib module registries. +type StdlibRegistry struct { + Modules map[string]*StdlibModule + Manifest *Manifest +} + +// Manifest contains metadata about the stdlib registry. +// +//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). +type Manifest struct { + SchemaVersion string `json:"schema_version"` + RegistryVersion string `json:"registry_version"` + PythonVersion PythonVersionInfo `json:"python_version"` + GeneratedAt string `json:"generated_at"` + GeneratorVersion string `json:"generator_version"` + BaseURL string `json:"base_url"` + Modules []*ModuleEntry `json:"modules"` + Statistics *RegistryStats `json:"statistics"` +} + +// PythonVersionInfo contains Python version details. +type PythonVersionInfo struct { + Major int `json:"major"` + Minor int `json:"minor"` + Patch int `json:"patch"` + Full string `json:"full"` +} + +// ModuleEntry represents a single module in the manifest. +// +//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). +type ModuleEntry struct { + Name string `json:"name"` + File string `json:"file"` + URL string `json:"url"` + SizeBytes int64 `json:"size_bytes"` + Checksum string `json:"checksum"` +} + +// RegistryStats contains aggregate statistics. +// +//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). 
type RegistryStats struct {
	// Each counter is decoded from the snake_case JSON key named in its tag.
	TotalModules    int `json:"total_modules"`
	TotalFunctions  int `json:"total_functions"`
	TotalClasses    int `json:"total_classes"`
	TotalConstants  int `json:"total_constants"`
	TotalAttributes int `json:"total_attributes"`
}

// StdlibModule represents a single stdlib module registry.
//
// Functions, Classes, Constants and Attributes are the maps that the
// StdlibRegistry lookup helpers (GetFunction, GetClass, GetConstant,
// GetAttribute) index by the member's name. Module, PythonVersion and
// GeneratedAt are plain strings decoded from the registry JSON.
// NOTE(review): the maps are nil when the corresponding JSON key is absent;
// reading a nil map is safe in Go, writing to it is not.
//
//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case).
type StdlibModule struct {
	Module        string                      `json:"module"`
	PythonVersion string                      `json:"python_version"`
	GeneratedAt   string                      `json:"generated_at"`
	Functions     map[string]*StdlibFunction  `json:"functions"`
	Classes       map[string]*StdlibClass     `json:"classes"`
	Constants     map[string]*StdlibConstant  `json:"constants"`
	Attributes    map[string]*StdlibAttribute `json:"attributes"`
}

// StdlibFunction represents a function in a stdlib module.
//
// ReturnType holds a type name as a string (the package tests use values such
// as "builtins.str"). Docstring is dropped from the encoded JSON when empty
// (omitempty). NOTE(review): Confidence is presumably a 0.0-1.0 score and
// Source presumably records how the signature was obtained; confirm against
// the registry generator.
//
//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case).
type StdlibFunction struct {
	ReturnType string           `json:"return_type"`
	Confidence float32          `json:"confidence"`
	Params     []*FunctionParam `json:"params"`
	Source     string           `json:"source"`
	Docstring  string           `json:"docstring,omitempty"`
}

// FunctionParam represents a function parameter.
// It carries the parameter's name, its type as a string, and whether the
// parameter is marked required in the registry data.
type FunctionParam struct {
	Name     string `json:"name"`
	Type     string `json:"type"`
	Required bool   `json:"required"`
}

// StdlibClass represents a class in a stdlib module.
// Methods reuses StdlibFunction for each entry; Docstring is dropped from the
// encoded JSON when empty (omitempty).
type StdlibClass struct {
	Type      string                     `json:"type"`
	Methods   map[string]*StdlibFunction `json:"methods"`
	Docstring string                     `json:"docstring,omitempty"`
}

// StdlibConstant represents a module-level constant.
//
//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case).
+type StdlibConstant struct { + Type string `json:"type"` + Value string `json:"value"` + Confidence float32 `json:"confidence"` + PlatformSpecific bool `json:"platform_specific,omitempty"` +} + +// StdlibAttribute represents a module-level attribute (os.environ, sys.modules, etc.). +// +//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). +type StdlibAttribute struct { + Type string `json:"type"` + BehavesLike string `json:"behaves_like,omitempty"` + Confidence float32 `json:"confidence"` + Docstring string `json:"docstring,omitempty"` +} + +// NewStdlibRegistry creates a new stdlib registry. +func NewStdlibRegistry() *StdlibRegistry { + return &StdlibRegistry{ + Modules: make(map[string]*StdlibModule), + } +} + +// GetModule returns the registry for a specific module. +func (r *StdlibRegistry) GetModule(moduleName string) *StdlibModule { + return r.Modules[moduleName] +} + +// HasModule checks if a module exists in the registry. +func (r *StdlibRegistry) HasModule(moduleName string) bool { + _, exists := r.Modules[moduleName] + return exists +} + +// GetFunction returns a function from a module. +func (r *StdlibRegistry) GetFunction(moduleName, functionName string) *StdlibFunction { + module := r.GetModule(moduleName) + if module == nil { + return nil + } + return module.Functions[functionName] +} + +// GetClass returns a class from a module. +func (r *StdlibRegistry) GetClass(moduleName, className string) *StdlibClass { + module := r.GetModule(moduleName) + if module == nil { + return nil + } + return module.Classes[className] +} + +// GetConstant returns a constant from a module. +func (r *StdlibRegistry) GetConstant(moduleName, constantName string) *StdlibConstant { + module := r.GetModule(moduleName) + if module == nil { + return nil + } + return module.Constants[constantName] +} + +// GetAttribute returns an attribute from a module. 
+func (r *StdlibRegistry) GetAttribute(moduleName, attributeName string) *StdlibAttribute { + module := r.GetModule(moduleName) + if module == nil { + return nil + } + return module.Attributes[attributeName] +} + +// ModuleCount returns the number of loaded modules. +func (r *StdlibRegistry) ModuleCount() int { + return len(r.Modules) +} diff --git a/sourcecode-parser/graph/callgraph/core/stdlib_types_test.go b/sourcecode-parser/graph/callgraph/core/stdlib_types_test.go new file mode 100644 index 00000000..83ce8b26 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/stdlib_types_test.go @@ -0,0 +1,227 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewStdlibRegistry(t *testing.T) { + registry := NewStdlibRegistry() + + assert.NotNil(t, registry) + assert.NotNil(t, registry.Modules) + assert.Equal(t, 0, len(registry.Modules)) +} + +func TestStdlibRegistry_GetModule(t *testing.T) { + registry := NewStdlibRegistry() + + // Add a module + module := &StdlibModule{ + Module: "os", + Functions: make(map[string]*StdlibFunction), + } + registry.Modules["os"] = module + + // Test getting existing module + result := registry.GetModule("os") + assert.NotNil(t, result) + assert.Equal(t, "os", result.Module) + + // Test getting non-existent module + result = registry.GetModule("nonexistent") + assert.Nil(t, result) +} + +func TestStdlibRegistry_HasModule(t *testing.T) { + registry := NewStdlibRegistry() + + // Test non-existent module + assert.False(t, registry.HasModule("os")) + + // Add a module + module := &StdlibModule{Module: "os"} + registry.Modules["os"] = module + + // Test existing module + assert.True(t, registry.HasModule("os")) +} + +func TestStdlibRegistry_GetFunction(t *testing.T) { + registry := NewStdlibRegistry() + + // Add a module with a function + module := &StdlibModule{ + Module: "os", + Functions: map[string]*StdlibFunction{ + "getcwd": { + ReturnType: "builtins.str", + Confidence: 1.0, + }, + }, + } 
+ registry.Modules["os"] = module + + // Test getting existing function + fn := registry.GetFunction("os", "getcwd") + assert.NotNil(t, fn) + assert.Equal(t, "builtins.str", fn.ReturnType) + assert.Equal(t, float32(1.0), fn.Confidence) + + // Test getting non-existent function + fn = registry.GetFunction("os", "nonexistent") + assert.Nil(t, fn) + + // Test getting function from non-existent module + fn = registry.GetFunction("nonexistent", "getcwd") + assert.Nil(t, fn) +} + +func TestStdlibRegistry_GetClass(t *testing.T) { + registry := NewStdlibRegistry() + + // Add a module with a class + module := &StdlibModule{ + Module: "pathlib", + Classes: map[string]*StdlibClass{ + "Path": { + Type: "builtins.type", + Methods: make(map[string]*StdlibFunction), + }, + }, + } + registry.Modules["pathlib"] = module + + // Test getting existing class + cls := registry.GetClass("pathlib", "Path") + assert.NotNil(t, cls) + assert.Equal(t, "builtins.type", cls.Type) + + // Test getting non-existent class + cls = registry.GetClass("pathlib", "NonExistent") + assert.Nil(t, cls) + + // Test getting class from non-existent module + cls = registry.GetClass("nonexistent", "Path") + assert.Nil(t, cls) +} + +func TestStdlibRegistry_GetConstant(t *testing.T) { + registry := NewStdlibRegistry() + + // Add a module with a constant + module := &StdlibModule{ + Module: "os", + Constants: map[string]*StdlibConstant{ + "O_RDONLY": { + Type: "builtins.int", + Value: "0", + Confidence: 1.0, + }, + }, + } + registry.Modules["os"] = module + + // Test getting existing constant + constant := registry.GetConstant("os", "O_RDONLY") + assert.NotNil(t, constant) + assert.Equal(t, "builtins.int", constant.Type) + assert.Equal(t, "0", constant.Value) + + // Test getting non-existent constant + constant = registry.GetConstant("os", "nonexistent") + assert.Nil(t, constant) + + // Test getting constant from non-existent module + constant = registry.GetConstant("nonexistent", "O_RDONLY") + assert.Nil(t, 
constant) +} + +func TestStdlibRegistry_GetAttribute(t *testing.T) { + registry := NewStdlibRegistry() + + // Add a module with an attribute + module := &StdlibModule{ + Module: "os", + Attributes: map[string]*StdlibAttribute{ + "environ": { + Type: "os._Environ", + BehavesLike: "builtins.dict", + Confidence: 0.9, + }, + }, + } + registry.Modules["os"] = module + + // Test getting existing attribute + attr := registry.GetAttribute("os", "environ") + assert.NotNil(t, attr) + assert.Equal(t, "os._Environ", attr.Type) + assert.Equal(t, "builtins.dict", attr.BehavesLike) + + // Test getting non-existent attribute + attr = registry.GetAttribute("os", "nonexistent") + assert.Nil(t, attr) + + // Test getting attribute from non-existent module + attr = registry.GetAttribute("nonexistent", "environ") + assert.Nil(t, attr) +} + +func TestStdlibRegistry_ModuleCount(t *testing.T) { + registry := NewStdlibRegistry() + + // Initially empty + assert.Equal(t, 0, registry.ModuleCount()) + + // Add modules + registry.Modules["os"] = &StdlibModule{Module: "os"} + assert.Equal(t, 1, registry.ModuleCount()) + + registry.Modules["sys"] = &StdlibModule{Module: "sys"} + assert.Equal(t, 2, registry.ModuleCount()) + + registry.Modules["pathlib"] = &StdlibModule{Module: "pathlib"} + assert.Equal(t, 3, registry.ModuleCount()) +} + +func TestStdlibRegistry_Integration(t *testing.T) { + // Test a complete workflow + registry := NewStdlibRegistry() + + // Add os module with various components + osModule := &StdlibModule{ + Module: "os", + Functions: map[string]*StdlibFunction{ + "getcwd": {ReturnType: "builtins.str", Confidence: 1.0}, + "chdir": {ReturnType: "builtins.NoneType", Confidence: 1.0}, + }, + Constants: map[string]*StdlibConstant{ + "O_RDONLY": {Type: "builtins.int", Value: "0", Confidence: 1.0}, + }, + Attributes: map[string]*StdlibAttribute{ + "environ": {Type: "os._Environ", Confidence: 0.9}, + }, + } + registry.Modules["os"] = osModule + + // Verify module exists + assert.True(t, 
registry.HasModule("os")) + assert.Equal(t, 1, registry.ModuleCount()) + + // Verify function + getcwd := registry.GetFunction("os", "getcwd") + assert.NotNil(t, getcwd) + assert.Equal(t, "builtins.str", getcwd.ReturnType) + + // Verify constant + oRdonly := registry.GetConstant("os", "O_RDONLY") + assert.NotNil(t, oRdonly) + assert.Equal(t, "0", oRdonly.Value) + + // Verify attribute + environ := registry.GetAttribute("os", "environ") + assert.NotNil(t, environ) + assert.Equal(t, "os._Environ", environ.Type) +} diff --git a/sourcecode-parser/graph/callgraph/core/taint_summary.go b/sourcecode-parser/graph/callgraph/core/taint_summary.go new file mode 100644 index 00000000..f4942bc8 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/taint_summary.go @@ -0,0 +1,238 @@ +package core + +// TaintInfo represents detailed taint tracking information for a single detection. +type TaintInfo struct { + // SourceLine is the line number where taint originated (1-indexed) + SourceLine uint32 + + // SourceVar is the variable name at the taint source + SourceVar string + + // SinkLine is the line number where tainted data reaches a dangerous sink (1-indexed) + SinkLine uint32 + + // SinkVar is the variable name at the sink + SinkVar string + + // SinkCall is the dangerous function/method call at the sink + // Examples: "execute", "eval", "os.system" + SinkCall string + + // PropagationPath is the list of variables through which taint propagated + // Example: ["user_input", "data", "query"] shows user_input -> data -> query + PropagationPath []string + + // Confidence is a score from 0.0 to 1.0 indicating detection confidence + // 1.0 = high confidence (direct flow) + // 0.7 = medium confidence (through stdlib function) + // 0.5 = low confidence (through third-party library) + // 0.0 = no taint detected + Confidence float64 + + // Sanitized indicates if a sanitizer was detected in the propagation path + // If true, the taint was neutralized and should not trigger a finding 
+ Sanitized bool + + // SanitizerLine is the line number where sanitization occurred (if Sanitized == true) + SanitizerLine uint32 + + // SanitizerCall is the sanitizer function that was called + // Examples: "escape_html", "quote_sql", "validate_email" + SanitizerCall string +} + +// IsTainted returns true if this TaintInfo represents actual taint (confidence > 0). +func (ti *TaintInfo) IsTainted() bool { + return ti.Confidence > 0.0 && !ti.Sanitized +} + +// IsHighConfidence returns true if confidence >= 0.8. +func (ti *TaintInfo) IsHighConfidence() bool { + return ti.Confidence >= 0.8 +} + +// IsMediumConfidence returns true if 0.5 <= confidence < 0.8. +func (ti *TaintInfo) IsMediumConfidence() bool { + return ti.Confidence >= 0.5 && ti.Confidence < 0.8 +} + +// IsLowConfidence returns true if 0.0 < confidence < 0.5. +func (ti *TaintInfo) IsLowConfidence() bool { + return ti.Confidence > 0.0 && ti.Confidence < 0.5 +} + +// TaintSummary represents the complete taint analysis results for a function. 
+type TaintSummary struct { + // FunctionFQN is the fully qualified name of the analyzed function + // Format: "module.Class.method" or "module.function" + FunctionFQN string + + // TaintedVars maps variable names to their taint information + // If a variable is not in this map, it is considered untainted + // Multiple TaintInfo entries indicate multiple taint paths to the same variable + TaintedVars map[string][]*TaintInfo + + // Detections is a list of all taint flows that reached a dangerous sink + // These represent potential security vulnerabilities + Detections []*TaintInfo + + // TaintedParams tracks which function parameters are tainted (by parameter name) + // Used for inter-procedural analysis + TaintedParams []string + + // TaintedReturn indicates if the function's return value is tainted + TaintedReturn bool + + // ReturnTaintInfo provides details if TaintedReturn is true + ReturnTaintInfo *TaintInfo + + // AnalysisError indicates if the analysis failed for this function + // If true, the summary is incomplete + AnalysisError bool + + // ErrorMessage contains the error description if AnalysisError is true + ErrorMessage string +} + +// NewTaintSummary creates an empty taint summary for a function. +func NewTaintSummary(functionFQN string) *TaintSummary { + return &TaintSummary{ + FunctionFQN: functionFQN, + TaintedVars: make(map[string][]*TaintInfo), + Detections: make([]*TaintInfo, 0), + TaintedParams: make([]string, 0), + } +} + +// AddTaintedVar records taint information for a variable. +func (ts *TaintSummary) AddTaintedVar(varName string, taintInfo *TaintInfo) { + if varName == "" || taintInfo == nil { + return + } + ts.TaintedVars[varName] = append(ts.TaintedVars[varName], taintInfo) +} + +// GetTaintInfo retrieves all taint information for a variable. +// Returns nil if variable is not tainted. 
+func (ts *TaintSummary) GetTaintInfo(varName string) []*TaintInfo { + return ts.TaintedVars[varName] +} + +// IsTainted checks if a variable is tainted (has at least one unsanitized taint path). +func (ts *TaintSummary) IsTainted(varName string) bool { + taintInfos := ts.TaintedVars[varName] + for _, info := range taintInfos { + if info.IsTainted() { + return true + } + } + return false +} + +// AddDetection records a taint flow that reached a dangerous sink. +func (ts *TaintSummary) AddDetection(detection *TaintInfo) { + if detection == nil { + return + } + ts.Detections = append(ts.Detections, detection) +} + +// HasDetections returns true if any taint flows reached dangerous sinks. +func (ts *TaintSummary) HasDetections() bool { + return len(ts.Detections) > 0 +} + +// GetHighConfidenceDetections returns detections with confidence >= 0.8. +func (ts *TaintSummary) GetHighConfidenceDetections() []*TaintInfo { + result := make([]*TaintInfo, 0) + for _, detection := range ts.Detections { + if detection.IsHighConfidence() { + result = append(result, detection) + } + } + return result +} + +// GetMediumConfidenceDetections returns detections with 0.5 <= confidence < 0.8. +func (ts *TaintSummary) GetMediumConfidenceDetections() []*TaintInfo { + result := make([]*TaintInfo, 0) + for _, detection := range ts.Detections { + if detection.IsMediumConfidence() { + result = append(result, detection) + } + } + return result +} + +// GetLowConfidenceDetections returns detections with 0.0 < confidence < 0.5. +func (ts *TaintSummary) GetLowConfidenceDetections() []*TaintInfo { + result := make([]*TaintInfo, 0) + for _, detection := range ts.Detections { + if detection.IsLowConfidence() { + result = append(result, detection) + } + } + return result +} + +// MarkTaintedParam marks a function parameter as tainted. 
+func (ts *TaintSummary) MarkTaintedParam(paramName string) { + if paramName == "" { + return + } + + // Check if already marked + for _, p := range ts.TaintedParams { + if p == paramName { + return + } + } + + ts.TaintedParams = append(ts.TaintedParams, paramName) +} + +// IsParamTainted checks if a function parameter is tainted. +func (ts *TaintSummary) IsParamTainted(paramName string) bool { + for _, p := range ts.TaintedParams { + if p == paramName { + return true + } + } + return false +} + +// MarkReturnTainted marks the function's return value as tainted. +func (ts *TaintSummary) MarkReturnTainted(taintInfo *TaintInfo) { + ts.TaintedReturn = true + ts.ReturnTaintInfo = taintInfo +} + +// SetError marks the analysis as failed with an error message. +func (ts *TaintSummary) SetError(errorMsg string) { + ts.AnalysisError = true + ts.ErrorMessage = errorMsg +} + +// IsComplete returns true if analysis completed without errors. +func (ts *TaintSummary) IsComplete() bool { + return !ts.AnalysisError +} + +// GetTaintedVarCount returns the number of distinct tainted variables. +func (ts *TaintSummary) GetTaintedVarCount() int { + count := 0 + for _, taintInfos := range ts.TaintedVars { + for _, info := range taintInfos { + if info.IsTainted() { + count++ + break // Count each variable only once + } + } + } + return count +} + +// GetDetectionCount returns the total number of detections. 
+func (ts *TaintSummary) GetDetectionCount() int { + return len(ts.Detections) +} diff --git a/sourcecode-parser/graph/callgraph/core/taint_summary_test.go b/sourcecode-parser/graph/callgraph/core/taint_summary_test.go new file mode 100644 index 00000000..92e99f0c --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/taint_summary_test.go @@ -0,0 +1,507 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTaintInfoIsTainted(t *testing.T) { + tests := []struct { + name string + info *TaintInfo + expected bool + }{ + { + name: "high confidence taint", + info: &TaintInfo{ + Confidence: 1.0, + Sanitized: false, + }, + expected: true, + }, + { + name: "medium confidence taint", + info: &TaintInfo{ + Confidence: 0.7, + Sanitized: false, + }, + expected: true, + }, + { + name: "sanitized taint", + info: &TaintInfo{ + Confidence: 1.0, + Sanitized: true, + }, + expected: false, + }, + { + name: "zero confidence", + info: &TaintInfo{ + Confidence: 0.0, + Sanitized: false, + }, + expected: false, + }, + { + name: "low confidence but not sanitized", + info: &TaintInfo{ + Confidence: 0.3, + Sanitized: false, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.info.IsTainted() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTaintInfoIsHighConfidence(t *testing.T) { + tests := []struct { + name string + confidence float64 + expected bool + }{ + {name: "perfect confidence", confidence: 1.0, expected: true}, + {name: "high confidence", confidence: 0.9, expected: true}, + {name: "exactly 0.8", confidence: 0.8, expected: true}, + {name: "just below threshold", confidence: 0.79, expected: false}, + {name: "medium confidence", confidence: 0.6, expected: false}, + {name: "low confidence", confidence: 0.3, expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &TaintInfo{Confidence: tt.confidence} + result := 
info.IsHighConfidence() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTaintInfoIsMediumConfidence(t *testing.T) { + tests := []struct { + name string + confidence float64 + expected bool + }{ + {name: "high confidence", confidence: 1.0, expected: false}, + {name: "just below high", confidence: 0.79, expected: true}, + {name: "mid range", confidence: 0.6, expected: true}, + {name: "exactly 0.5", confidence: 0.5, expected: true}, + {name: "just below medium", confidence: 0.49, expected: false}, + {name: "low confidence", confidence: 0.3, expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &TaintInfo{Confidence: tt.confidence} + result := info.IsMediumConfidence() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTaintInfoIsLowConfidence(t *testing.T) { + tests := []struct { + name string + confidence float64 + expected bool + }{ + {name: "medium confidence", confidence: 0.6, expected: false}, + {name: "just below medium", confidence: 0.49, expected: true}, + {name: "low confidence", confidence: 0.3, expected: true}, + {name: "very low", confidence: 0.1, expected: true}, + {name: "zero confidence", confidence: 0.0, expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &TaintInfo{Confidence: tt.confidence} + result := info.IsLowConfidence() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestNewTaintSummary(t *testing.T) { + summary := NewTaintSummary("module.Class.method") + + assert.Equal(t, "module.Class.method", summary.FunctionFQN) + assert.NotNil(t, summary.TaintedVars) + assert.Equal(t, 0, len(summary.TaintedVars)) + assert.NotNil(t, summary.Detections) + assert.Equal(t, 0, len(summary.Detections)) + assert.NotNil(t, summary.TaintedParams) + assert.Equal(t, 0, len(summary.TaintedParams)) + assert.False(t, summary.TaintedReturn) + assert.Nil(t, summary.ReturnTaintInfo) + assert.False(t, summary.AnalysisError) + assert.Equal(t, 
"", summary.ErrorMessage) +} + +func TestTaintSummaryAddTaintedVar(t *testing.T) { + summary := NewTaintSummary("test.function") + + taint1 := &TaintInfo{ + SourceLine: 1, + SourceVar: "input", + Confidence: 1.0, + } + + taint2 := &TaintInfo{ + SourceLine: 2, + SourceVar: "input2", + Confidence: 0.7, + } + + // Add first taint + summary.AddTaintedVar("x", taint1) + assert.Equal(t, 1, len(summary.TaintedVars["x"])) + assert.Equal(t, taint1, summary.TaintedVars["x"][0]) + + // Add second taint to same variable + summary.AddTaintedVar("x", taint2) + assert.Equal(t, 2, len(summary.TaintedVars["x"])) + assert.Equal(t, taint2, summary.TaintedVars["x"][1]) + + // Test empty variable name (should be ignored) + summary.AddTaintedVar("", taint1) + _, exists := summary.TaintedVars[""] + assert.False(t, exists) + + // Test nil taint info (should be ignored) + summary.AddTaintedVar("y", nil) + _, exists = summary.TaintedVars["y"] + assert.False(t, exists) +} + +func TestTaintSummaryGetTaintInfo(t *testing.T) { + summary := NewTaintSummary("test.function") + + taint := &TaintInfo{ + SourceLine: 1, + SourceVar: "input", + Confidence: 1.0, + } + + summary.AddTaintedVar("x", taint) + + // Get existing taint + result := summary.GetTaintInfo("x") + assert.NotNil(t, result) + assert.Equal(t, 1, len(result)) + assert.Equal(t, taint, result[0]) + + // Get non-existent variable + nonExistent := summary.GetTaintInfo("nonexistent") + assert.Nil(t, nonExistent) +} + +func TestTaintSummaryIsTainted(t *testing.T) { + summary := NewTaintSummary("test.function") + + // Add tainted variable + taint1 := &TaintInfo{ + Confidence: 1.0, + Sanitized: false, + } + summary.AddTaintedVar("x", taint1) + assert.True(t, summary.IsTainted("x")) + + // Add sanitized taint + taint2 := &TaintInfo{ + Confidence: 1.0, + Sanitized: true, + } + summary.AddTaintedVar("y", taint2) + assert.False(t, summary.IsTainted("y")) + + // Add variable with both tainted and sanitized paths + summary.AddTaintedVar("z", taint1) 
// tainted + summary.AddTaintedVar("z", taint2) // sanitized + assert.True(t, summary.IsTainted("z")) // Should return true if ANY path is tainted + + // Check non-existent variable + assert.False(t, summary.IsTainted("nonexistent")) +} + +func TestTaintSummaryAddDetection(t *testing.T) { + summary := NewTaintSummary("test.function") + + detection1 := &TaintInfo{ + SourceLine: 1, + SinkLine: 5, + SinkCall: "execute", + Confidence: 1.0, + } + + detection2 := &TaintInfo{ + SourceLine: 2, + SinkLine: 6, + SinkCall: "eval", + Confidence: 0.8, + } + + summary.AddDetection(detection1) + assert.Equal(t, 1, len(summary.Detections)) + assert.Equal(t, detection1, summary.Detections[0]) + + summary.AddDetection(detection2) + assert.Equal(t, 2, len(summary.Detections)) + assert.Equal(t, detection2, summary.Detections[1]) + + // Test nil detection (should be ignored) + summary.AddDetection(nil) + assert.Equal(t, 2, len(summary.Detections)) +} + +func TestTaintSummaryHasDetections(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.False(t, summary.HasDetections()) + + detection := &TaintInfo{ + SourceLine: 1, + SinkLine: 5, + Confidence: 1.0, + } + summary.AddDetection(detection) + + assert.True(t, summary.HasDetections()) +} + +func TestTaintSummaryGetHighConfidenceDetections(t *testing.T) { + summary := NewTaintSummary("test.function") + + high1 := &TaintInfo{Confidence: 1.0} + high2 := &TaintInfo{Confidence: 0.9} + medium := &TaintInfo{Confidence: 0.6} + low := &TaintInfo{Confidence: 0.3} + + summary.AddDetection(high1) + summary.AddDetection(medium) + summary.AddDetection(high2) + summary.AddDetection(low) + + highConf := summary.GetHighConfidenceDetections() + assert.Equal(t, 2, len(highConf)) + assert.Equal(t, high1, highConf[0]) + assert.Equal(t, high2, highConf[1]) +} + +func TestTaintSummaryGetMediumConfidenceDetections(t *testing.T) { + summary := NewTaintSummary("test.function") + + high := &TaintInfo{Confidence: 1.0} + medium1 := 
&TaintInfo{Confidence: 0.7} + medium2 := &TaintInfo{Confidence: 0.5} + low := &TaintInfo{Confidence: 0.3} + + summary.AddDetection(high) + summary.AddDetection(medium1) + summary.AddDetection(low) + summary.AddDetection(medium2) + + mediumConf := summary.GetMediumConfidenceDetections() + assert.Equal(t, 2, len(mediumConf)) + assert.Equal(t, medium1, mediumConf[0]) + assert.Equal(t, medium2, mediumConf[1]) +} + +func TestTaintSummaryGetLowConfidenceDetections(t *testing.T) { + summary := NewTaintSummary("test.function") + + high := &TaintInfo{Confidence: 1.0} + medium := &TaintInfo{Confidence: 0.6} + low1 := &TaintInfo{Confidence: 0.4} + low2 := &TaintInfo{Confidence: 0.1} + + summary.AddDetection(high) + summary.AddDetection(low1) + summary.AddDetection(medium) + summary.AddDetection(low2) + + lowConf := summary.GetLowConfidenceDetections() + assert.Equal(t, 2, len(lowConf)) + assert.Equal(t, low1, lowConf[0]) + assert.Equal(t, low2, lowConf[1]) +} + +func TestTaintSummaryMarkTaintedParam(t *testing.T) { + summary := NewTaintSummary("test.function") + + // Mark first param + summary.MarkTaintedParam("param1") + assert.Equal(t, 1, len(summary.TaintedParams)) + assert.Equal(t, "param1", summary.TaintedParams[0]) + + // Mark second param + summary.MarkTaintedParam("param2") + assert.Equal(t, 2, len(summary.TaintedParams)) + assert.Equal(t, "param2", summary.TaintedParams[1]) + + // Try to mark same param again (should not duplicate) + summary.MarkTaintedParam("param1") + assert.Equal(t, 2, len(summary.TaintedParams)) + + // Test empty param name (should be ignored) + summary.MarkTaintedParam("") + assert.Equal(t, 2, len(summary.TaintedParams)) +} + +func TestTaintSummaryIsParamTainted(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.False(t, summary.IsParamTainted("param1")) + + summary.MarkTaintedParam("param1") + assert.True(t, summary.IsParamTainted("param1")) + assert.False(t, summary.IsParamTainted("param2")) + + 
summary.MarkTaintedParam("param2") + assert.True(t, summary.IsParamTainted("param2")) +} + +func TestTaintSummaryMarkReturnTainted(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.False(t, summary.TaintedReturn) + assert.Nil(t, summary.ReturnTaintInfo) + + taint := &TaintInfo{ + SourceLine: 1, + Confidence: 1.0, + } + + summary.MarkReturnTainted(taint) + assert.True(t, summary.TaintedReturn) + assert.Equal(t, taint, summary.ReturnTaintInfo) +} + +func TestTaintSummarySetError(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.False(t, summary.AnalysisError) + assert.Equal(t, "", summary.ErrorMessage) + + summary.SetError("parse error") + assert.True(t, summary.AnalysisError) + assert.Equal(t, "parse error", summary.ErrorMessage) +} + +func TestTaintSummaryIsComplete(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.True(t, summary.IsComplete()) + + summary.SetError("error") + assert.False(t, summary.IsComplete()) +} + +func TestTaintSummaryGetTaintedVarCount(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.Equal(t, 0, summary.GetTaintedVarCount()) + + // Add tainted variable + taint1 := &TaintInfo{Confidence: 1.0, Sanitized: false} + summary.AddTaintedVar("x", taint1) + assert.Equal(t, 1, summary.GetTaintedVarCount()) + + // Add another taint to same variable (should still count as 1) + taint2 := &TaintInfo{Confidence: 0.7, Sanitized: false} + summary.AddTaintedVar("x", taint2) + assert.Equal(t, 1, summary.GetTaintedVarCount()) + + // Add tainted second variable + summary.AddTaintedVar("y", taint1) + assert.Equal(t, 2, summary.GetTaintedVarCount()) + + // Add sanitized variable (should not count) + sanitized := &TaintInfo{Confidence: 1.0, Sanitized: true} + summary.AddTaintedVar("z", sanitized) + assert.Equal(t, 2, summary.GetTaintedVarCount()) +} + +func TestTaintSummaryGetDetectionCount(t *testing.T) { + summary := NewTaintSummary("test.function") + + assert.Equal(t, 0, 
summary.GetDetectionCount()) + + detection1 := &TaintInfo{Confidence: 1.0} + summary.AddDetection(detection1) + assert.Equal(t, 1, summary.GetDetectionCount()) + + detection2 := &TaintInfo{Confidence: 0.8} + summary.AddDetection(detection2) + assert.Equal(t, 2, summary.GetDetectionCount()) +} + +func TestTaintSummaryComplexScenario(t *testing.T) { + // Simulate a real security finding scenario + summary := NewTaintSummary("app.views.process_payment") + + // Taint flows from user input + userInputTaint := &TaintInfo{ + SourceLine: 10, + SourceVar: "request.GET['amount']", + SinkLine: 25, + SinkVar: "query", + SinkCall: "cursor.execute", + PropagationPath: []string{"user_amount", "amount", "query"}, + Confidence: 1.0, + Sanitized: false, + } + + // Track the variable propagation + summary.AddTaintedVar("user_amount", &TaintInfo{ + SourceLine: 10, + SourceVar: "request.GET['amount']", + Confidence: 1.0, + }) + + summary.AddTaintedVar("amount", &TaintInfo{ + SourceLine: 15, + SourceVar: "user_amount", + Confidence: 1.0, + }) + + summary.AddTaintedVar("query", &TaintInfo{ + SourceLine: 20, + SourceVar: "amount", + Confidence: 1.0, + }) + + // Record the detection + summary.AddDetection(userInputTaint) + + // Mark the request parameter as tainted + summary.MarkTaintedParam("request") + + // Verify the summary + assert.True(t, summary.IsTainted("user_amount")) + assert.True(t, summary.IsTainted("amount")) + assert.True(t, summary.IsTainted("query")) + assert.Equal(t, 3, summary.GetTaintedVarCount()) + assert.True(t, summary.HasDetections()) + assert.Equal(t, 1, summary.GetDetectionCount()) + assert.Equal(t, 1, len(summary.GetHighConfidenceDetections())) + assert.True(t, summary.IsParamTainted("request")) + assert.True(t, summary.IsComplete()) + + // Verify the detection details + detection := summary.Detections[0] + assert.Equal(t, uint32(10), detection.SourceLine) + assert.Equal(t, uint32(25), detection.SinkLine) + assert.Equal(t, "cursor.execute", detection.SinkCall) + 
assert.Equal(t, 3, len(detection.PropagationPath)) + assert.True(t, detection.IsHighConfidence()) + assert.False(t, detection.Sanitized) +} diff --git a/sourcecode-parser/graph/callgraph/core/types.go b/sourcecode-parser/graph/callgraph/core/types.go new file mode 100644 index 00000000..769c3001 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/types.go @@ -0,0 +1,292 @@ +package core + +import ( + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" +) + +// Location represents a source code location for tracking call sites. +// This enables precise mapping of where calls occur in the source code. +type Location struct { + File string // Absolute path to the source file + Line int // Line number (1-indexed) + Column int // Column number (1-indexed) +} + +// CallSite represents a function/method call location in the source code. +// It captures both the syntactic information (where the call is) and +// semantic information (what is being called and with what arguments). +type CallSite struct { + Target string // The name of the function being called (e.g., "eval", "utils.sanitize") + Location Location // Where this call occurs in the source code + Arguments []Argument // Arguments passed to the call + Resolved bool // Whether we successfully resolved this call to a definition + TargetFQN string // Fully qualified name after resolution (e.g., "myapp.utils.sanitize") + FailureReason string // Why resolution failed (empty if Resolved=true) + + // Phase 2: Type inference metadata + ResolvedViaTypeInference bool // Was this resolved using type inference? 
+ InferredType string // The inferred type FQN (e.g., "builtins.str", "test.User") + TypeConfidence float32 // Confidence score of the type inference (0.0-1.0) + TypeSource string // How type was inferred (e.g., "literal", "return_type", "class_instantiation") +} + +// Resolution failure reason categories for diagnostics: +// - "external_framework" - Call to Django, REST framework, pytest, stdlib, etc. +// - "orm_pattern" - Django ORM patterns like Model.objects.filter() +// - "attribute_chain" - Method calls on return values like response.json() +// - "variable_method" - Method calls on variables like value.split() +// - "super_call" - Calls via super() to parent class methods +// - "not_in_imports" - Simple function call not found in imports +// - "unknown" - Unresolved for other reasons + +// Argument represents a single argument passed to a function call. +// Tracks both the value/expression and metadata about the argument. +type Argument struct { + Value string // The argument expression as a string + IsVariable bool // Whether this argument is a variable reference + Position int // Position in the argument list (0-indexed) +} + +// CallGraph represents the complete call graph of a program. +// It maps function definitions to their call sites and provides +// both forward (callers → callees) and reverse (callees → callers) edges. 
+// +// Example: +// +// Function A calls B and C +// edges: {"A": ["B", "C"]} +// reverseEdges: {"B": ["A"], "C": ["A"]} +type CallGraph struct { + // Forward edges: maps fully qualified function name to list of functions it calls + // Key: caller FQN (e.g., "myapp.views.get_user") + // Value: list of callee FQNs (e.g., ["myapp.db.query", "myapp.utils.sanitize"]) + Edges map[string][]string + + // Reverse edges: maps fully qualified function name to list of functions that call it + // Useful for backward slicing and finding all callers of a function + // Key: callee FQN + // Value: list of caller FQNs + ReverseEdges map[string][]string + + // Detailed call site information for each function + // Key: caller FQN + // Value: list of all call sites within that function + CallSites map[string][]CallSite + + // Map from fully qualified name to the actual function node in the graph + // This allows quick lookup of function metadata (line number, file, etc.) + Functions map[string]*graph.Node + + // Taint summaries for each function (intra-procedural analysis results) + // Key: function FQN + // Value: TaintSummary with taint flow information + Summaries map[string]*TaintSummary +} + +// NewCallGraph creates and initializes a new CallGraph instance. +// All maps are pre-allocated to avoid nil pointer issues. +func NewCallGraph() *CallGraph { + return &CallGraph{ + Edges: make(map[string][]string), + ReverseEdges: make(map[string][]string), + CallSites: make(map[string][]CallSite), + Functions: make(map[string]*graph.Node), + Summaries: make(map[string]*TaintSummary), + } +} + +// AddEdge adds a directed edge from caller to callee in the call graph. +// Automatically updates both forward and reverse edges. 
+// +// Parameters: +// - caller: fully qualified name of the calling function +// - callee: fully qualified name of the called function +func (cg *CallGraph) AddEdge(caller, callee string) { + // Add forward edge + if !contains(cg.Edges[caller], callee) { + cg.Edges[caller] = append(cg.Edges[caller], callee) + } + + // Add reverse edge + if !contains(cg.ReverseEdges[callee], caller) { + cg.ReverseEdges[callee] = append(cg.ReverseEdges[callee], caller) + } +} + +// AddCallSite adds a call site to the call graph. +// This stores detailed information about where and how a function is called. +// +// Parameters: +// - caller: fully qualified name of the calling function +// - callSite: detailed information about the call +func (cg *CallGraph) AddCallSite(caller string, callSite CallSite) { + cg.CallSites[caller] = append(cg.CallSites[caller], callSite) +} + +// GetCallers returns all functions that call the specified function. +// Uses the reverse edges for efficient lookup. +// +// Parameters: +// - callee: fully qualified name of the function +// +// Returns: +// - list of caller FQNs, or empty slice if no callers found +func (cg *CallGraph) GetCallers(callee string) []string { + if callers, ok := cg.ReverseEdges[callee]; ok { + return callers + } + return []string{} +} + +// GetCallees returns all functions called by the specified function. +// Uses the forward edges for efficient lookup. +// +// Parameters: +// - caller: fully qualified name of the function +// +// Returns: +// - list of callee FQNs, or empty slice if no callees found +func (cg *CallGraph) GetCallees(caller string) []string { + if callees, ok := cg.Edges[caller]; ok { + return callees + } + return []string{} +} + +// ModuleRegistry maintains the mapping between Python file paths and module paths. +// This is essential for resolving imports and building fully qualified names. 
//
// Example:
//
//	File: /project/myapp/utils/helpers.py
//	Module: myapp.utils.helpers
type ModuleRegistry struct {
	// Modules maps a fully qualified module path to its absolute file path.
	// Key: "myapp.utils.helpers"
	// Value: "/absolute/path/to/myapp/utils/helpers.py"
	Modules map[string]string

	// FileToModule is the reverse of Modules: absolute file path to fully
	// qualified module path. Used for resolving relative imports.
	// Key: "/absolute/path/to/myapp/utils/helpers.py"
	// Value: "myapp.utils.helpers"
	FileToModule map[string]string

	// ShortNames maps a short module name to every matching file path, so
	// ambiguous names keep all their candidates.
	// Key: "helpers"
	// Value: ["/path/to/myapp/utils/helpers.py", "/path/to/lib/helpers.py"]
	ShortNames map[string][]string

	// ResolvedImports caches resolved imports to avoid redundant lookups.
	// Key: import string (e.g., "utils.helpers")
	// Value: fully qualified module path
	ResolvedImports map[string]string
}

// NewModuleRegistry creates and initializes a new ModuleRegistry instance
// with all four maps allocated and empty.
func NewModuleRegistry() *ModuleRegistry {
	mr := new(ModuleRegistry)
	mr.Modules = map[string]string{}
	mr.FileToModule = map[string]string{}
	mr.ShortNames = map[string][]string{}
	mr.ResolvedImports = map[string]string{}
	return mr
}

// AddModule registers a module in the registry.
// Automatically indexes both the full module path and the short name.
+// +// Parameters: +// - modulePath: fully qualified module path (e.g., "myapp.utils.helpers") +// - filePath: absolute file path (e.g., "/project/myapp/utils/helpers.py") +func (mr *ModuleRegistry) AddModule(modulePath, filePath string) { + mr.Modules[modulePath] = filePath + mr.FileToModule[filePath] = modulePath + + // Extract short name (last component) + // "myapp.utils.helpers" → "helpers" + shortName := extractShortName(modulePath) + if !containsString(mr.ShortNames[shortName], filePath) { + mr.ShortNames[shortName] = append(mr.ShortNames[shortName], filePath) + } +} + +// GetModulePath returns the file path for a given module, if it exists. +// +// Parameters: +// - modulePath: fully qualified module path +// +// Returns: +// - file path and true if found, empty string and false otherwise +func (mr *ModuleRegistry) GetModulePath(modulePath string) (string, bool) { + filePath, ok := mr.Modules[modulePath] + return filePath, ok +} + +// ImportMap represents the import statements in a single Python file. +// Maps local aliases to fully qualified module paths. +// +// Example: +// +// File contains: from myapp.utils import sanitize as clean +// Imports: {"clean": "myapp.utils.sanitize"} +type ImportMap struct { + FilePath string // Absolute path to the file containing these imports + Imports map[string]string // Maps alias/name to fully qualified module path +} + +// NewImportMap creates and initializes a new ImportMap instance. +func NewImportMap(filePath string) *ImportMap { + return &ImportMap{ + FilePath: filePath, + Imports: make(map[string]string), + } +} + +// AddImport adds an import mapping to the import map. +// +// Parameters: +// - alias: the local name used in the file (e.g., "clean", "sanitize", "utils") +// - fqn: the fully qualified name (e.g., "myapp.utils.sanitize") +func (im *ImportMap) AddImport(alias, fqn string) { + im.Imports[alias] = fqn +} + +// Resolve looks up the fully qualified name for a local alias. 
+// +// Parameters: +// - alias: the local name to resolve +// +// Returns: +// - fully qualified name and true if found, empty string and false otherwise +func (im *ImportMap) Resolve(alias string) (string, bool) { + fqn, ok := im.Imports[alias] + return fqn, ok +} + +// Helper function to check if a string slice contains a specific string. +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// Helper function alias for consistency. +func containsString(slice []string, item string) bool { + return contains(slice, item) +} + +// Helper function to extract the last component of a dotted path. +// Example: "myapp.utils.helpers" → "helpers". +func extractShortName(modulePath string) string { + // Find last dot + for i := len(modulePath) - 1; i >= 0; i-- { + if modulePath[i] == '.' { + return modulePath[i+1:] + } + } + return modulePath +} diff --git a/sourcecode-parser/graph/callgraph/core/types_test.go b/sourcecode-parser/graph/callgraph/core/types_test.go new file mode 100644 index 00000000..9e6c8bdc --- /dev/null +++ b/sourcecode-parser/graph/callgraph/core/types_test.go @@ -0,0 +1,576 @@ +package core + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" + "github.com/stretchr/testify/assert" +) + +func TestNewCallGraph(t *testing.T) { + cg := NewCallGraph() + + assert.NotNil(t, cg) + assert.NotNil(t, cg.Edges) + assert.NotNil(t, cg.ReverseEdges) + assert.NotNil(t, cg.CallSites) + assert.NotNil(t, cg.Functions) + assert.Equal(t, 0, len(cg.Edges)) + assert.Equal(t, 0, len(cg.ReverseEdges)) +} + +func TestCallGraph_AddEdge(t *testing.T) { + tests := []struct { + name string + caller string + callee string + }{ + { + name: "Add single edge", + caller: "myapp.views.get_user", + callee: "myapp.db.query", + }, + { + name: "Add edge with qualified names", + caller: "myapp.utils.helpers.sanitize_input", + callee: 
"myapp.utils.validators.validate_string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cg := NewCallGraph() + cg.AddEdge(tt.caller, tt.callee) + + // Check forward edge + assert.Contains(t, cg.Edges[tt.caller], tt.callee) + assert.Equal(t, 1, len(cg.Edges[tt.caller])) + + // Check reverse edge + assert.Contains(t, cg.ReverseEdges[tt.callee], tt.caller) + assert.Equal(t, 1, len(cg.ReverseEdges[tt.callee])) + }) + } +} + +func TestCallGraph_AddEdge_MultipleCalls(t *testing.T) { + cg := NewCallGraph() + caller := "myapp.views.process" + callees := []string{ + "myapp.db.query", + "myapp.utils.sanitize", + "myapp.logging.log", + } + + for _, callee := range callees { + cg.AddEdge(caller, callee) + } + + // Verify all forward edges + assert.Equal(t, 3, len(cg.Edges[caller])) + for _, callee := range callees { + assert.Contains(t, cg.Edges[caller], callee) + } + + // Verify all reverse edges + for _, callee := range callees { + assert.Contains(t, cg.ReverseEdges[callee], caller) + assert.Equal(t, 1, len(cg.ReverseEdges[callee])) + } +} + +func TestCallGraph_AddEdge_Duplicate(t *testing.T) { + cg := NewCallGraph() + caller := "myapp.views.get_user" + callee := "myapp.db.query" + + // Add same edge twice + cg.AddEdge(caller, callee) + cg.AddEdge(caller, callee) + + // Should only appear once + assert.Equal(t, 1, len(cg.Edges[caller])) + assert.Contains(t, cg.Edges[caller], callee) +} + +func TestCallGraph_AddCallSite(t *testing.T) { + cg := NewCallGraph() + caller := "myapp.views.get_user" + callSite := CallSite{ + Target: "query", + Location: Location{ + File: "/path/to/views.py", + Line: 42, + Column: 10, + }, + Arguments: []Argument{ + {Value: "user_id", IsVariable: true, Position: 0}, + }, + Resolved: true, + TargetFQN: "myapp.db.query", + } + + cg.AddCallSite(caller, callSite) + + assert.Equal(t, 1, len(cg.CallSites[caller])) + assert.Equal(t, callSite.Target, cg.CallSites[caller][0].Target) + assert.Equal(t, callSite.Location.Line, 
cg.CallSites[caller][0].Location.Line) +} + +func TestCallGraph_AddCallSite_Multiple(t *testing.T) { + cg := NewCallGraph() + caller := "myapp.views.process" + + callSites := []CallSite{ + { + Target: "query", + Location: Location{File: "/path/to/views.py", Line: 10, Column: 5}, + Resolved: true, + TargetFQN: "myapp.db.query", + }, + { + Target: "sanitize", + Location: Location{File: "/path/to/views.py", Line: 15, Column: 8}, + Resolved: true, + TargetFQN: "myapp.utils.sanitize", + }, + } + + for _, cs := range callSites { + cg.AddCallSite(caller, cs) + } + + assert.Equal(t, 2, len(cg.CallSites[caller])) +} + +func TestCallGraph_GetCallers(t *testing.T) { + cg := NewCallGraph() + + // Set up call graph: + // main → helper + // main → util + // process → helper + cg.AddEdge("myapp.main", "myapp.helper") + cg.AddEdge("myapp.main", "myapp.util") + cg.AddEdge("myapp.process", "myapp.helper") + + tests := []struct { + name string + callee string + expectedCount int + expectedCallers []string + }{ + { + name: "Function with multiple callers", + callee: "myapp.helper", + expectedCount: 2, + expectedCallers: []string{"myapp.main", "myapp.process"}, + }, + { + name: "Function with single caller", + callee: "myapp.util", + expectedCount: 1, + expectedCallers: []string{"myapp.main"}, + }, + { + name: "Function with no callers", + callee: "myapp.main", + expectedCount: 0, + expectedCallers: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + callers := cg.GetCallers(tt.callee) + assert.Equal(t, tt.expectedCount, len(callers)) + for _, expectedCaller := range tt.expectedCallers { + assert.Contains(t, callers, expectedCaller) + } + }) + } +} + +func TestCallGraph_GetCallees(t *testing.T) { + cg := NewCallGraph() + + // Set up call graph: + // main → helper, util, logger + // process → db + cg.AddEdge("myapp.main", "myapp.helper") + cg.AddEdge("myapp.main", "myapp.util") + cg.AddEdge("myapp.main", "myapp.logger") + 
cg.AddEdge("myapp.process", "myapp.db") + + tests := []struct { + name string + caller string + expectedCount int + expectedCallees []string + }{ + { + name: "Function with multiple callees", + caller: "myapp.main", + expectedCount: 3, + expectedCallees: []string{"myapp.helper", "myapp.util", "myapp.logger"}, + }, + { + name: "Function with single callee", + caller: "myapp.process", + expectedCount: 1, + expectedCallees: []string{"myapp.db"}, + }, + { + name: "Function with no callees", + caller: "myapp.helper", + expectedCount: 0, + expectedCallees: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + callees := cg.GetCallees(tt.caller) + assert.Equal(t, tt.expectedCount, len(callees)) + for _, expectedCallee := range tt.expectedCallees { + assert.Contains(t, callees, expectedCallee) + } + }) + } +} + +func TestNewModuleRegistry(t *testing.T) { + mr := NewModuleRegistry() + + assert.NotNil(t, mr) + assert.NotNil(t, mr.Modules) + assert.NotNil(t, mr.ShortNames) + assert.NotNil(t, mr.ResolvedImports) + assert.Equal(t, 0, len(mr.Modules)) +} + +func TestModuleRegistry_AddModule(t *testing.T) { + tests := []struct { + name string + modulePath string + filePath string + shortName string + }{ + { + name: "Simple module", + modulePath: "myapp.views", + filePath: "/path/to/myapp/views.py", + shortName: "views", + }, + { + name: "Nested module", + modulePath: "myapp.utils.helpers", + filePath: "/path/to/myapp/utils/helpers.py", + shortName: "helpers", + }, + { + name: "Package init", + modulePath: "myapp.utils", + filePath: "/path/to/myapp/utils/__init__.py", + shortName: "utils", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mr := NewModuleRegistry() + mr.AddModule(tt.modulePath, tt.filePath) + + // Check module is registered + path, ok := mr.GetModulePath(tt.modulePath) + assert.True(t, ok) + assert.Equal(t, tt.filePath, path) + + // Check short name is indexed + assert.Contains(t, 
mr.ShortNames[tt.shortName], tt.filePath) + }) + } +} + +func TestModuleRegistry_AddModule_AmbiguousShortNames(t *testing.T) { + mr := NewModuleRegistry() + + // Add two modules with same short name + mr.AddModule("myapp.utils.helpers", "/path/to/myapp/utils/helpers.py") + mr.AddModule("lib.helpers", "/path/to/lib/helpers.py") + + // Both should be indexed under short name "helpers" + assert.Equal(t, 2, len(mr.ShortNames["helpers"])) + assert.Contains(t, mr.ShortNames["helpers"], "/path/to/myapp/utils/helpers.py") + assert.Contains(t, mr.ShortNames["helpers"], "/path/to/lib/helpers.py") + + // But each should be accessible by full module path + path1, ok1 := mr.GetModulePath("myapp.utils.helpers") + assert.True(t, ok1) + assert.Equal(t, "/path/to/myapp/utils/helpers.py", path1) + + path2, ok2 := mr.GetModulePath("lib.helpers") + assert.True(t, ok2) + assert.Equal(t, "/path/to/lib/helpers.py", path2) +} + +func TestModuleRegistry_GetModulePath_NotFound(t *testing.T) { + mr := NewModuleRegistry() + + path, ok := mr.GetModulePath("nonexistent.module") + assert.False(t, ok) + assert.Equal(t, "", path) +} + +func TestNewImportMap(t *testing.T) { + filePath := "/path/to/file.py" + im := NewImportMap(filePath) + + assert.NotNil(t, im) + assert.Equal(t, filePath, im.FilePath) + assert.NotNil(t, im.Imports) + assert.Equal(t, 0, len(im.Imports)) +} + +func TestImportMap_AddImport(t *testing.T) { + tests := []struct { + name string + alias string + fqn string + }{ + { + name: "Simple import", + alias: "utils", + fqn: "myapp.utils", + }, + { + name: "Aliased import", + alias: "clean", + fqn: "myapp.utils.sanitize", + }, + { + name: "Full module import", + alias: "myapp.db.models", + fqn: "myapp.db.models", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + im := NewImportMap("/path/to/file.py") + im.AddImport(tt.alias, tt.fqn) + + fqn, ok := im.Resolve(tt.alias) + assert.True(t, ok) + assert.Equal(t, tt.fqn, fqn) + }) + } +} + +func 
TestImportMap_Resolve_NotFound(t *testing.T) { + im := NewImportMap("/path/to/file.py") + + fqn, ok := im.Resolve("nonexistent") + assert.False(t, ok) + assert.Equal(t, "", fqn) +} + +func TestImportMap_Multiple(t *testing.T) { + im := NewImportMap("/path/to/file.py") + + imports := map[string]string{ + "utils": "myapp.utils", + "sanitize": "myapp.utils.sanitize", + "clean": "myapp.utils.clean", + "db": "myapp.db", + } + + for alias, fqn := range imports { + im.AddImport(alias, fqn) + } + + // Verify all imports are resolvable + for alias, expectedFqn := range imports { + fqn, ok := im.Resolve(alias) + assert.True(t, ok) + assert.Equal(t, expectedFqn, fqn) + } +} + +func TestLocation(t *testing.T) { + loc := Location{ + File: "/path/to/file.py", + Line: 42, + Column: 10, + } + + assert.Equal(t, "/path/to/file.py", loc.File) + assert.Equal(t, 42, loc.Line) + assert.Equal(t, 10, loc.Column) +} + +func TestCallSite(t *testing.T) { + cs := CallSite{ + Target: "sanitize", + Location: Location{ + File: "/path/to/views.py", + Line: 15, + Column: 8, + }, + Arguments: []Argument{ + {Value: "user_input", IsVariable: true, Position: 0}, + {Value: "\"html\"", IsVariable: false, Position: 1}, + }, + Resolved: true, + TargetFQN: "myapp.utils.sanitize", + } + + assert.Equal(t, "sanitize", cs.Target) + assert.Equal(t, 15, cs.Location.Line) + assert.Equal(t, 2, len(cs.Arguments)) + assert.True(t, cs.Resolved) + assert.Equal(t, "myapp.utils.sanitize", cs.TargetFQN) +} + +func TestArgument(t *testing.T) { + tests := []struct { + name string + value string + isVariable bool + position int + }{ + { + name: "Variable argument", + value: "user_input", + isVariable: true, + position: 0, + }, + { + name: "String literal argument", + value: "\"hello\"", + isVariable: false, + position: 1, + }, + { + name: "Number literal argument", + value: "42", + isVariable: false, + position: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arg := Argument{ + Value: 
tt.value, + IsVariable: tt.isVariable, + Position: tt.position, + } + + assert.Equal(t, tt.value, arg.Value) + assert.Equal(t, tt.isVariable, arg.IsVariable) + assert.Equal(t, tt.position, arg.Position) + }) + } +} + +func TestCallGraph_WithFunctions(t *testing.T) { + cg := NewCallGraph() + + // Create mock function nodes + funcMain := &graph.Node{ + ID: "main_id", + Type: "function_definition", + Name: "main", + File: "/path/to/main.py", + } + + funcHelper := &graph.Node{ + ID: "helper_id", + Type: "function_definition", + Name: "helper", + File: "/path/to/utils.py", + } + + // Add functions to call graph + cg.Functions["myapp.main"] = funcMain + cg.Functions["myapp.utils.helper"] = funcHelper + + // Add edge + cg.AddEdge("myapp.main", "myapp.utils.helper") + + // Verify we can access function metadata + assert.Equal(t, "main", cg.Functions["myapp.main"].Name) + assert.Equal(t, "helper", cg.Functions["myapp.utils.helper"].Name) +} + +func TestExtractShortName(t *testing.T) { + tests := []struct { + name string + modulePath string + expected string + }{ + { + name: "Simple module", + modulePath: "views", + expected: "views", + }, + { + name: "Two components", + modulePath: "myapp.views", + expected: "views", + }, + { + name: "Three components", + modulePath: "myapp.utils.helpers", + expected: "helpers", + }, + { + name: "Deep nesting", + modulePath: "myapp.api.v1.endpoints.users", + expected: "users", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractShortName(tt.modulePath) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestContains(t *testing.T) { + tests := []struct { + name string + slice []string + item string + expected bool + }{ + { + name: "Item exists", + slice: []string{"a", "b", "c"}, + item: "b", + expected: true, + }, + { + name: "Item does not exist", + slice: []string{"a", "b", "c"}, + item: "d", + expected: false, + }, + { + name: "Empty slice", + slice: []string{}, + item: "a", + expected: 
false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := contains(tt.slice, tt.item) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/sourcecode-parser/graph/callgraph/frameworks.go b/sourcecode-parser/graph/callgraph/frameworks.go index c4cb7007..738bd30b 100644 --- a/sourcecode-parser/graph/callgraph/frameworks.go +++ b/sourcecode-parser/graph/callgraph/frameworks.go @@ -1,413 +1,33 @@ package callgraph import ( - "strings" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" ) -// FrameworkDefinition represents a known external framework or library. -// This is used to mark calls to external code as resolved, even though -// we don't have the source code for these frameworks. -type FrameworkDefinition struct { - Name string // Display name (e.g., "Django") - Prefixes []string // Module prefixes to match (e.g., ["django.", "django"]) - Description string // Human-readable description - Category string // Category: "web", "orm", "testing", "stdlib", etc. -} - -// builtinFrameworks contains the list of known Python frameworks and libraries. -// This list focuses on the most common frameworks found in Python projects. 
-var builtinFrameworks = []FrameworkDefinition{ - // Web Frameworks - { - Name: "Django", - Prefixes: []string{"django."}, - Description: "Django web framework", - Category: "web", - }, - { - Name: "Django REST Framework", - Prefixes: []string{"rest_framework."}, - Description: "Django REST framework for building Web APIs", - Category: "web", - }, - { - Name: "Flask", - Prefixes: []string{"flask."}, - Description: "Flask web framework", - Category: "web", - }, - { - Name: "FastAPI", - Prefixes: []string{"fastapi."}, - Description: "FastAPI web framework", - Category: "web", - }, - { - Name: "Starlette", - Prefixes: []string{"starlette."}, - Description: "Starlette ASGI framework", - Category: "web", - }, - { - Name: "Tornado", - Prefixes: []string{"tornado."}, - Description: "Tornado web framework", - Category: "web", - }, - { - Name: "Pyramid", - Prefixes: []string{"pyramid."}, - Description: "Pyramid web framework", - Category: "web", - }, - { - Name: "Bottle", - Prefixes: []string{"bottle."}, - Description: "Bottle web framework", - Category: "web", - }, - - // ORM and Database - { - Name: "SQLAlchemy", - Prefixes: []string{"sqlalchemy."}, - Description: "SQLAlchemy ORM", - Category: "orm", - }, - { - Name: "Peewee", - Prefixes: []string{"peewee."}, - Description: "Peewee ORM", - Category: "orm", - }, - { - Name: "Tortoise ORM", - Prefixes: []string{"tortoise."}, - Description: "Tortoise ORM", - Category: "orm", - }, - { - Name: "Pony ORM", - Prefixes: []string{"pony."}, - Description: "Pony ORM", - Category: "orm", - }, - - // Testing Frameworks - { - Name: "pytest", - Prefixes: []string{"pytest.", "_pytest."}, - Description: "pytest testing framework", - Category: "testing", - }, - { - Name: "unittest", - Prefixes: []string{"unittest."}, - Description: "Python unittest framework", - Category: "testing", - }, - { - Name: "nose", - Prefixes: []string{"nose."}, - Description: "nose testing framework", - Category: "testing", - }, - { - Name: "mock", - Prefixes: 
[]string{"mock.", "unittest.mock."}, - Description: "Python mock library", - Category: "testing", - }, - - // HTTP and Requests - { - Name: "requests", - Prefixes: []string{"requests."}, - Description: "HTTP library for Python", - Category: "http", - }, - { - Name: "httpx", - Prefixes: []string{"httpx."}, - Description: "Async HTTP client", - Category: "http", - }, - { - Name: "urllib3", - Prefixes: []string{"urllib3."}, - Description: "HTTP client library", - Category: "http", - }, - { - Name: "aiohttp", - Prefixes: []string{"aiohttp."}, - Description: "Async HTTP client/server", - Category: "http", - }, - - // Data Science and ML - { - Name: "numpy", - Prefixes: []string{"numpy.", "np."}, - Description: "Numerical computing library", - Category: "data_science", - }, - { - Name: "pandas", - Prefixes: []string{"pandas.", "pd."}, - Description: "Data analysis library", - Category: "data_science", - }, - { - Name: "scikit-learn", - Prefixes: []string{"sklearn.", "scikit_learn."}, - Description: "Machine learning library", - Category: "data_science", - }, - { - Name: "tensorflow", - Prefixes: []string{"tensorflow.", "tf."}, - Description: "TensorFlow ML framework", - Category: "data_science", - }, - { - Name: "pytorch", - Prefixes: []string{"torch."}, - Description: "PyTorch ML framework", - Category: "data_science", - }, - - // Async and Concurrency - { - Name: "asyncio", - Prefixes: []string{"asyncio."}, - Description: "Async I/O library", - Category: "async", - }, - { - Name: "celery", - Prefixes: []string{"celery."}, - Description: "Distributed task queue", - Category: "async", - }, +// Deprecated: Use core.FrameworkDefinition instead. +// This alias will be removed in a future version. 
+type FrameworkDefinition = core.FrameworkDefinition - // Serialization and Data Formats - { - Name: "json", - Prefixes: []string{"json."}, - Description: "JSON encoder/decoder", - Category: "serialization", - }, - { - Name: "pickle", - Prefixes: []string{"pickle.", "_pickle."}, - Description: "Python object serialization", - Category: "serialization", - }, - { - Name: "yaml", - Prefixes: []string{"yaml.", "pyyaml."}, - Description: "YAML parser", - Category: "serialization", - }, - { - Name: "xml", - Prefixes: []string{"xml."}, - Description: "XML processing", - Category: "serialization", - }, - - // Logging and Monitoring - { - Name: "logging", - Prefixes: []string{"logging."}, - Description: "Python logging library", - Category: "logging", - }, - { - Name: "sentry", - Prefixes: []string{"sentry_sdk."}, - Description: "Sentry error tracking", - Category: "logging", - }, - - // Utilities - { - Name: "datetime", - Prefixes: []string{"datetime."}, - Description: "Date and time types", - Category: "stdlib", - }, - { - Name: "collections", - Prefixes: []string{"collections."}, - Description: "Container datatypes", - Category: "stdlib", - }, - { - Name: "itertools", - Prefixes: []string{"itertools."}, - Description: "Iterator functions", - Category: "stdlib", - }, - { - Name: "functools", - Prefixes: []string{"functools."}, - Description: "Higher-order functions", - Category: "stdlib", - }, - { - Name: "os", - Prefixes: []string{"os."}, - Description: "Operating system interfaces", - Category: "stdlib", - }, - { - Name: "sys", - Prefixes: []string{"sys."}, - Description: "System-specific parameters", - Category: "stdlib", - }, - { - Name: "pathlib", - Prefixes: []string{"pathlib."}, - Description: "Object-oriented filesystem paths", - Category: "stdlib", - }, - { - Name: "re", - Prefixes: []string{"re."}, - Description: "Regular expressions", - Category: "stdlib", - }, - { - Name: "subprocess", - Prefixes: []string{"subprocess."}, - Description: "Subprocess 
management", - Category: "stdlib", - }, - { - Name: "threading", - Prefixes: []string{"threading."}, - Description: "Thread-based parallelism", - Category: "stdlib", - }, - { - Name: "multiprocessing", - Prefixes: []string{"multiprocessing."}, - Description: "Process-based parallelism", - Category: "stdlib", - }, - { - Name: "socket", - Prefixes: []string{"socket."}, - Description: "Low-level networking", - Category: "stdlib", - }, - { - Name: "http", - Prefixes: []string{"http."}, - Description: "HTTP modules", - Category: "stdlib", - }, - { - Name: "urllib", - Prefixes: []string{"urllib."}, - Description: "URL handling modules", - Category: "stdlib", - }, - { - Name: "email", - Prefixes: []string{"email."}, - Description: "Email and MIME handling", - Category: "stdlib", - }, - { - Name: "hashlib", - Prefixes: []string{"hashlib."}, - Description: "Secure hash and message digest", - Category: "stdlib", - }, - { - Name: "hmac", - Prefixes: []string{"hmac."}, - Description: "Keyed-hashing for message authentication", - Category: "stdlib", - }, - { - Name: "secrets", - Prefixes: []string{"secrets."}, - Description: "Generate secure random numbers", - Category: "stdlib", - }, - { - Name: "typing", - Prefixes: []string{"typing."}, - Description: "Type hints support", - Category: "stdlib", - }, - { - Name: "dataclasses", - Prefixes: []string{"dataclasses."}, - Description: "Data classes", - Category: "stdlib", - }, - { - Name: "abc", - Prefixes: []string{"abc."}, - Description: "Abstract base classes", - Category: "stdlib", - }, -} - -// LoadFrameworks returns the list of known frameworks. -// This function provides an extensibility hook for future enhancements -// where frameworks might be loaded from a configuration file. +// LoadFrameworks is a convenience wrapper. +// Deprecated: Use core.LoadFrameworks instead. 
func LoadFrameworks() []FrameworkDefinition { - return builtinFrameworks + return core.LoadFrameworks() } -// IsKnownFramework checks if the given fully qualified name (FQN) -// belongs to a known external framework or standard library. -// -// Parameters: -// - fqn: fully qualified name (e.g., "django.db.models.ForeignKey") -// -// Returns: -// - true if the FQN matches any known framework -// - the matching framework definition +// IsKnownFramework is a convenience wrapper. +// Deprecated: Use core.IsKnownFramework instead. func IsKnownFramework(fqn string) (bool, *FrameworkDefinition) { - frameworks := LoadFrameworks() - - for i := range frameworks { - framework := &frameworks[i] - for _, prefix := range framework.Prefixes { - // Check for exact match or prefix match - if fqn == prefix || strings.HasPrefix(fqn, prefix) { - return true, framework - } - } - } - - return false, nil + return core.IsKnownFramework(fqn) } -// GetFrameworkCategory returns the category of a framework given its FQN. -// Returns empty string if not a known framework. +// GetFrameworkCategory is a convenience wrapper. +// Deprecated: Use core.GetFrameworkCategory instead. func GetFrameworkCategory(fqn string) string { - isKnown, framework := IsKnownFramework(fqn) - if isKnown { - return framework.Category - } - return "" + return core.GetFrameworkCategory(fqn) } -// GetFrameworkName returns the name of a framework given its FQN. -// Returns empty string if not a known framework. +// GetFrameworkName is a convenience wrapper. +// Deprecated: Use core.GetFrameworkName instead. 
func GetFrameworkName(fqn string) string { - isKnown, framework := IsKnownFramework(fqn) - if isKnown { - return framework.Name - } - return "" + return core.GetFrameworkName(fqn) } diff --git a/sourcecode-parser/graph/callgraph/statement.go b/sourcecode-parser/graph/callgraph/statement.go index 72c8e61e..cd2f468d 100644 --- a/sourcecode-parser/graph/callgraph/statement.go +++ b/sourcecode-parser/graph/callgraph/statement.go @@ -1,334 +1,68 @@ package callgraph -// StatementType represents the type of statement in the code. -type StatementType string - -const ( - // Assignment represents variable assignments: x = expr. - StatementTypeAssignment StatementType = "assignment" - - // Call represents function/method calls: foo(), obj.method(). - StatementTypeCall StatementType = "call" - - // Return represents return statements: return expr. - StatementTypeReturn StatementType = "return" - - // If represents conditional statements: if condition: ... - StatementTypeIf StatementType = "if" - - // For represents loop statements: for x in iterable: ... - StatementTypeFor StatementType = "for" - - // While represents while loop statements: while condition: ... - StatementTypeWhile StatementType = "while" - - // With represents context manager statements: with expr as var: ... - StatementTypeWith StatementType = "with" - - // Try represents exception handling: try: ... except: ... - StatementTypeTry StatementType = "try" - - // Raise represents exception raising: raise Exception(). - StatementTypeRaise StatementType = "raise" - - // Import represents import statements: import module, from module import name. - StatementTypeImport StatementType = "import" - - // Expression represents expression statements (calls, attribute access, etc.). - StatementTypeExpression StatementType = "expression" +import ( + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" ) -// Statement represents a single statement in the code with def-use information. 
-type Statement struct { - // Type is the kind of statement (assignment, call, return, etc.) - Type StatementType - - // LineNumber is the source line number for this statement (1-indexed) - LineNumber uint32 - - // Def is the variable being defined by this statement (if any) - // For assignments: the left-hand side variable - // For for loops: the loop variable - // For with statements: the as variable - // Empty string if no definition - Def string - - // Uses is the list of variables used/read by this statement - // For assignments: variables in the right-hand side expression - // For calls: variables used in arguments - // For conditions: variables in the condition expression - Uses []string +// Deprecated: Use core.StatementType instead. +// This alias will be removed in a future version. +type StatementType = core.StatementType - // CallTarget is the function/method being called (if Type == StatementTypeCall) - // Format: "function_name" for direct calls, "obj.method" for method calls - // Empty string for non-call statements - CallTarget string +const ( + // Deprecated: Use core.StatementTypeAssignment instead. + StatementTypeAssignment = core.StatementTypeAssignment - // CallArgs are the argument variables passed to the call (if Type == StatementTypeCall) - // Only includes variable names, not literals - CallArgs []string + // Deprecated: Use core.StatementTypeCall instead. + StatementTypeCall = core.StatementTypeCall - // NestedStatements contains statements inside this statement's body - // Used for if/for/while/with/try blocks - // Empty for simple statements like assignments - NestedStatements []*Statement + // Deprecated: Use core.StatementTypeReturn instead. + StatementTypeReturn = core.StatementTypeReturn - // ElseBranch contains statements in the else branch (if applicable) - // Used for if/try statements - ElseBranch []*Statement -} + // Deprecated: Use core.StatementTypeIf instead. 
+ StatementTypeIf = core.StatementTypeIf -// GetDef returns the variable defined by this statement, or empty string if none. -func (s *Statement) GetDef() string { - return s.Def -} + // Deprecated: Use core.StatementTypeFor instead. + StatementTypeFor = core.StatementTypeFor -// GetUses returns the list of variables used by this statement. -func (s *Statement) GetUses() []string { - return s.Uses -} + // Deprecated: Use core.StatementTypeWhile instead. + StatementTypeWhile = core.StatementTypeWhile -// IsCall returns true if this statement is a function/method call. -func (s *Statement) IsCall() bool { - return s.Type == StatementTypeCall || s.Type == StatementTypeExpression -} + // Deprecated: Use core.StatementTypeWith instead. + StatementTypeWith = core.StatementTypeWith -// IsAssignment returns true if this statement is a variable assignment. -func (s *Statement) IsAssignment() bool { - return s.Type == StatementTypeAssignment -} + // Deprecated: Use core.StatementTypeTry instead. + StatementTypeTry = core.StatementTypeTry -// IsControlFlow returns true if this statement is a control flow construct. -func (s *Statement) IsControlFlow() bool { - switch s.Type { - case StatementTypeIf, StatementTypeFor, StatementTypeWhile, StatementTypeWith, StatementTypeTry: - return true - default: - return false - } -} + // Deprecated: Use core.StatementTypeRaise instead. + StatementTypeRaise = core.StatementTypeRaise -// HasNestedStatements returns true if this statement contains nested statements. -func (s *Statement) HasNestedStatements() bool { - return len(s.NestedStatements) > 0 || len(s.ElseBranch) > 0 -} + // Deprecated: Use core.StatementTypeImport instead. + StatementTypeImport = core.StatementTypeImport -// AllStatements returns a flattened list of this statement and all nested statements. -// Performs depth-first traversal. 
-func (s *Statement) AllStatements() []*Statement { - result := []*Statement{s} - - for _, nested := range s.NestedStatements { - result = append(result, nested.AllStatements()...) - } + // Deprecated: Use core.StatementTypeExpression instead. + StatementTypeExpression = core.StatementTypeExpression +) - for _, elseBranch := range s.ElseBranch { - result = append(result, elseBranch.AllStatements()...) - } +// Deprecated: Use core.Statement instead. +// This alias will be removed in a future version. +type Statement = core.Statement - return result -} +// Deprecated: Use core.DefUseChain instead. +// This alias will be removed in a future version. +type DefUseChain = core.DefUseChain -// DefUseChain represents the def-use relationships for all variables in a function. -type DefUseChain struct { - // Defs maps variable names to all statements that define them. - // A variable can have multiple definitions across different code paths. - Defs map[string][]*Statement +// Deprecated: Use core.DefUseStats instead. +// This alias will be removed in a future version. +type DefUseStats = core.DefUseStats - // Uses maps variable names to all statements that use them. - // A variable can be used in multiple places. - Uses map[string][]*Statement -} - -// NewDefUseChain creates an empty def-use chain. +// NewDefUseChain is a convenience wrapper. +// Deprecated: Use core.NewDefUseChain instead. func NewDefUseChain() *DefUseChain { - return &DefUseChain{ - Defs: make(map[string][]*Statement), - Uses: make(map[string][]*Statement), - } -} - -// AddDef registers a statement as defining a variable. -func (chain *DefUseChain) AddDef(varName string, stmt *Statement) { - if varName == "" { - return - } - chain.Defs[varName] = append(chain.Defs[varName], stmt) + return core.NewDefUseChain() } -// AddUse registers a statement as using a variable. 
-func (chain *DefUseChain) AddUse(varName string, stmt *Statement) { - if varName == "" { - return - } - chain.Uses[varName] = append(chain.Uses[varName], stmt) -} - -// GetDefs returns all statements that define a given variable. -// Returns empty slice if variable is never defined. -func (chain *DefUseChain) GetDefs(varName string) []*Statement { - if defs, ok := chain.Defs[varName]; ok { - return defs - } - return []*Statement{} -} - -// GetUses returns all statements that use a given variable. -// Returns empty slice if variable is never used. -func (chain *DefUseChain) GetUses(varName string) []*Statement { - if uses, ok := chain.Uses[varName]; ok { - return uses - } - return []*Statement{} -} - -// IsDefined returns true if the variable has at least one definition. -func (chain *DefUseChain) IsDefined(varName string) bool { - return len(chain.Defs[varName]) > 0 -} - -// IsUsed returns true if the variable has at least one use. -func (chain *DefUseChain) IsUsed(varName string) bool { - return len(chain.Uses[varName]) > 0 -} - -// AllVariables returns a list of all variable names in the def-use chain. -func (chain *DefUseChain) AllVariables() []string { - varSet := make(map[string]bool) - - for varName := range chain.Defs { - varSet[varName] = true - } - - for varName := range chain.Uses { - varSet[varName] = true - } - - result := make([]string, 0, len(varSet)) - for varName := range varSet { - result = append(result, varName) - } - - return result -} - -// BuildDefUseChains constructs a def-use chain from a list of statements. -// This is a single-pass algorithm that builds an inverted index. -// -// Algorithm: -// 1. Initialize empty Defs and Uses maps -// 2. For each statement: -// - If stmt.Def is not empty: add stmt to Defs[stmt.Def] -// - For each variable in stmt.Uses: add stmt to Uses[variable] -// 3. 
Return DefUseChain -// -// Time complexity: O(n × m) -// -// where n = number of statements -// m = average number of uses per statement -// Typical: 50 statements × 3 variables = 150 operations (~1 microsecond) -// -// Space complexity: O(v × k) -// -// where v = number of unique variables -// k = average number of defs + uses per variable -// Typical: 20 variables × 5 references = 100 pointers = 800 bytes -// -// Example: -// -// statements := []*Statement{ -// {LineNumber: 1, Def: "x", Uses: []string{}}, -// {LineNumber: 2, Def: "y", Uses: []string{"x"}}, -// {LineNumber: 3, Def: "", Uses: []string{"y"}}, -// } -// -// chain := BuildDefUseChains(statements) -// -// // Query: where is x defined? -// xDefs := chain.Defs["x"] // [stmt1] -// -// // Query: where is x used? -// xUses := chain.Uses["x"] // [stmt2] +// BuildDefUseChains is a convenience wrapper. +// Deprecated: Use core.BuildDefUseChains instead. func BuildDefUseChains(statements []*Statement) *DefUseChain { - chain := NewDefUseChain() - - // Single pass: build inverted index - for _, stmt := range statements { - // Track definition (single variable per statement) - if stmt.Def != "" { - chain.AddDef(stmt.Def, stmt) - } - - // Track all uses in this statement - for _, varName := range stmt.Uses { - chain.AddUse(varName, stmt) - } - } - - return chain -} - -// DefUseStats contains statistics about the def-use chain (for debugging/diagnostics). -type DefUseStats struct { - NumVariables int // Total unique variables - NumDefs int // Total definition sites - NumUses int // Total use sites - MaxDefsPerVariable int // Most definitions for a single variable - MaxUsesPerVariable int // Most uses for a single variable - UndefinedVariables int // Variables used but never defined (parameters) - DeadVariables int // Variables defined but never used -} - -// ComputeStats computes statistics about this def-use chain. -// Useful for performance analysis and debugging. 
-// -// Example: -// -// stats := chain.ComputeStats() -// fmt.Printf("Function has %d variables, %d defs, %d uses\n", -// stats.NumVariables, stats.NumDefs, stats.NumUses) -func (chain *DefUseChain) ComputeStats() DefUseStats { - stats := DefUseStats{} - - // Count unique variables - varSet := make(map[string]bool) - for varName := range chain.Defs { - varSet[varName] = true - } - for varName := range chain.Uses { - varSet[varName] = true - } - stats.NumVariables = len(varSet) - - // Count total defs and max defs per variable - for _, defs := range chain.Defs { - stats.NumDefs += len(defs) - if len(defs) > stats.MaxDefsPerVariable { - stats.MaxDefsPerVariable = len(defs) - } - } - - // Count total uses and max uses per variable - for _, uses := range chain.Uses { - stats.NumUses += len(uses) - if len(uses) > stats.MaxUsesPerVariable { - stats.MaxUsesPerVariable = len(uses) - } - } - - // Count undefined variables (used but not defined) - for varName := range chain.Uses { - if len(chain.Defs[varName]) == 0 { - stats.UndefinedVariables++ - } - } - - // Count dead variables (defined but not used) - for varName := range chain.Defs { - if len(chain.Uses[varName]) == 0 { - stats.DeadVariables++ - } - } - - return stats + return core.BuildDefUseChains(statements) } diff --git a/sourcecode-parser/graph/callgraph/stdlib_registry.go b/sourcecode-parser/graph/callgraph/stdlib_registry.go index ad687221..fb349bf2 100644 --- a/sourcecode-parser/graph/callgraph/stdlib_registry.go +++ b/sourcecode-parser/graph/callgraph/stdlib_registry.go @@ -1,168 +1,55 @@ package callgraph -// StdlibRegistry holds all Python stdlib module registries. -type StdlibRegistry struct { - Modules map[string]*StdlibModule - Manifest *Manifest -} - -// Manifest contains metadata about the stdlib registry. -// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). 
-type Manifest struct { - SchemaVersion string `json:"schema_version"` - RegistryVersion string `json:"registry_version"` - PythonVersion PythonVersionInfo `json:"python_version"` - GeneratedAt string `json:"generated_at"` - GeneratorVersion string `json:"generator_version"` - BaseURL string `json:"base_url"` - Modules []*ModuleEntry `json:"modules"` - Statistics *RegistryStats `json:"statistics"` -} - -// PythonVersionInfo contains Python version details. -type PythonVersionInfo struct { - Major int `json:"major"` - Minor int `json:"minor"` - Patch int `json:"patch"` - Full string `json:"full"` -} - -// ModuleEntry represents a single module in the manifest. -// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). -type ModuleEntry struct { - Name string `json:"name"` - File string `json:"file"` - URL string `json:"url"` - SizeBytes int64 `json:"size_bytes"` - Checksum string `json:"checksum"` -} - -// RegistryStats contains aggregate statistics. -// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). -type RegistryStats struct { - TotalModules int `json:"total_modules"` - TotalFunctions int `json:"total_functions"` - TotalClasses int `json:"total_classes"` - TotalConstants int `json:"total_constants"` - TotalAttributes int `json:"total_attributes"` -} - -// StdlibModule represents a single stdlib module registry. -// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). -type StdlibModule struct { - Module string `json:"module"` - PythonVersion string `json:"python_version"` - GeneratedAt string `json:"generated_at"` - Functions map[string]*StdlibFunction `json:"functions"` - Classes map[string]*StdlibClass `json:"classes"` - Constants map[string]*StdlibConstant `json:"constants"` - Attributes map[string]*StdlibAttribute `json:"attributes"` -} - -// StdlibFunction represents a function in a stdlib module. 
-// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). -type StdlibFunction struct { - ReturnType string `json:"return_type"` - Confidence float32 `json:"confidence"` - Params []*FunctionParam `json:"params"` - Source string `json:"source"` - Docstring string `json:"docstring,omitempty"` -} +import ( + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" +) -// FunctionParam represents a function parameter. -type FunctionParam struct { - Name string `json:"name"` - Type string `json:"type"` - Required bool `json:"required"` -} +// Deprecated: Use core.StdlibRegistry instead. +// This alias will be removed in a future version. +type StdlibRegistry = core.StdlibRegistry -// StdlibClass represents a class in a stdlib module. -type StdlibClass struct { - Type string `json:"type"` - Methods map[string]*StdlibFunction `json:"methods"` - Docstring string `json:"docstring,omitempty"` -} +// Deprecated: Use core.Manifest instead. +// This alias will be removed in a future version. +type Manifest = core.Manifest -// StdlibConstant represents a module-level constant. -// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). -type StdlibConstant struct { - Type string `json:"type"` - Value string `json:"value"` - Confidence float32 `json:"confidence"` - PlatformSpecific bool `json:"platform_specific,omitempty"` -} +// Deprecated: Use core.PythonVersionInfo instead. +// This alias will be removed in a future version. +type PythonVersionInfo = core.PythonVersionInfo -// StdlibAttribute represents a module-level attribute (os.environ, sys.modules, etc.). -// -//nolint:tagliatelle // JSON tags match Python-generated registry format (snake_case). 
-type StdlibAttribute struct { - Type string `json:"type"` - BehavesLike string `json:"behaves_like,omitempty"` - Confidence float32 `json:"confidence"` - Docstring string `json:"docstring,omitempty"` -} +// Deprecated: Use core.ModuleEntry instead. +// This alias will be removed in a future version. +type ModuleEntry = core.ModuleEntry -// NewStdlibRegistry creates a new stdlib registry. -func NewStdlibRegistry() *StdlibRegistry { - return &StdlibRegistry{ - Modules: make(map[string]*StdlibModule), - } -} +// Deprecated: Use core.RegistryStats instead. +// This alias will be removed in a future version. +type RegistryStats = core.RegistryStats -// GetModule returns the registry for a specific module. -func (r *StdlibRegistry) GetModule(moduleName string) *StdlibModule { - return r.Modules[moduleName] -} +// Deprecated: Use core.StdlibModule instead. +// This alias will be removed in a future version. +type StdlibModule = core.StdlibModule -// HasModule checks if a module exists in the registry. -func (r *StdlibRegistry) HasModule(moduleName string) bool { - _, exists := r.Modules[moduleName] - return exists -} +// Deprecated: Use core.StdlibFunction instead. +// This alias will be removed in a future version. +type StdlibFunction = core.StdlibFunction -// GetFunction returns a function from a module. -func (r *StdlibRegistry) GetFunction(moduleName, functionName string) *StdlibFunction { - module := r.GetModule(moduleName) - if module == nil { - return nil - } - return module.Functions[functionName] -} +// Deprecated: Use core.FunctionParam instead. +// This alias will be removed in a future version. +type FunctionParam = core.FunctionParam -// GetClass returns a class from a module. -func (r *StdlibRegistry) GetClass(moduleName, className string) *StdlibClass { - module := r.GetModule(moduleName) - if module == nil { - return nil - } - return module.Classes[className] -} +// Deprecated: Use core.StdlibClass instead. 
+// This alias will be removed in a future version. +type StdlibClass = core.StdlibClass -// GetConstant returns a constant from a module. -func (r *StdlibRegistry) GetConstant(moduleName, constantName string) *StdlibConstant { - module := r.GetModule(moduleName) - if module == nil { - return nil - } - return module.Constants[constantName] -} +// Deprecated: Use core.StdlibConstant instead. +// This alias will be removed in a future version. +type StdlibConstant = core.StdlibConstant -// GetAttribute returns an attribute from a module. -func (r *StdlibRegistry) GetAttribute(moduleName, attributeName string) *StdlibAttribute { - module := r.GetModule(moduleName) - if module == nil { - return nil - } - return module.Attributes[attributeName] -} +// Deprecated: Use core.StdlibAttribute instead. +// This alias will be removed in a future version. +type StdlibAttribute = core.StdlibAttribute -// ModuleCount returns the number of loaded modules. -func (r *StdlibRegistry) ModuleCount() int { - return len(r.Modules) +// NewStdlibRegistry is a convenience wrapper. +// Deprecated: Use core.NewStdlibRegistry instead. +func NewStdlibRegistry() *StdlibRegistry { + return core.NewStdlibRegistry() } diff --git a/sourcecode-parser/graph/callgraph/taint_summary.go b/sourcecode-parser/graph/callgraph/taint_summary.go index b277891e..9b44ae6a 100644 --- a/sourcecode-parser/graph/callgraph/taint_summary.go +++ b/sourcecode-parser/graph/callgraph/taint_summary.go @@ -1,238 +1,19 @@ package callgraph -// TaintInfo represents detailed taint tracking information for a single detection. -type TaintInfo struct { - // SourceLine is the line number where taint originated (1-indexed) - SourceLine uint32 +import ( + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" +) - // SourceVar is the variable name at the taint source - SourceVar string +// Deprecated: Use core.TaintInfo instead. +// This alias will be removed in a future version. 
+type TaintInfo = core.TaintInfo - // SinkLine is the line number where tainted data reaches a dangerous sink (1-indexed) - SinkLine uint32 +// Deprecated: Use core.TaintSummary instead. +// This alias will be removed in a future version. +type TaintSummary = core.TaintSummary - // SinkVar is the variable name at the sink - SinkVar string - - // SinkCall is the dangerous function/method call at the sink - // Examples: "execute", "eval", "os.system" - SinkCall string - - // PropagationPath is the list of variables through which taint propagated - // Example: ["user_input", "data", "query"] shows user_input -> data -> query - PropagationPath []string - - // Confidence is a score from 0.0 to 1.0 indicating detection confidence - // 1.0 = high confidence (direct flow) - // 0.7 = medium confidence (through stdlib function) - // 0.5 = low confidence (through third-party library) - // 0.0 = no taint detected - Confidence float64 - - // Sanitized indicates if a sanitizer was detected in the propagation path - // If true, the taint was neutralized and should not trigger a finding - Sanitized bool - - // SanitizerLine is the line number where sanitization occurred (if Sanitized == true) - SanitizerLine uint32 - - // SanitizerCall is the sanitizer function that was called - // Examples: "escape_html", "quote_sql", "validate_email" - SanitizerCall string -} - -// IsTainted returns true if this TaintInfo represents actual taint (confidence > 0). -func (ti *TaintInfo) IsTainted() bool { - return ti.Confidence > 0.0 && !ti.Sanitized -} - -// IsHighConfidence returns true if confidence >= 0.8. -func (ti *TaintInfo) IsHighConfidence() bool { - return ti.Confidence >= 0.8 -} - -// IsMediumConfidence returns true if 0.5 <= confidence < 0.8. -func (ti *TaintInfo) IsMediumConfidence() bool { - return ti.Confidence >= 0.5 && ti.Confidence < 0.8 -} - -// IsLowConfidence returns true if 0.0 < confidence < 0.5. 
-func (ti *TaintInfo) IsLowConfidence() bool { - return ti.Confidence > 0.0 && ti.Confidence < 0.5 -} - -// TaintSummary represents the complete taint analysis results for a function. -type TaintSummary struct { - // FunctionFQN is the fully qualified name of the analyzed function - // Format: "module.Class.method" or "module.function" - FunctionFQN string - - // TaintedVars maps variable names to their taint information - // If a variable is not in this map, it is considered untainted - // Multiple TaintInfo entries indicate multiple taint paths to the same variable - TaintedVars map[string][]*TaintInfo - - // Detections is a list of all taint flows that reached a dangerous sink - // These represent potential security vulnerabilities - Detections []*TaintInfo - - // TaintedParams tracks which function parameters are tainted (by parameter name) - // Used for inter-procedural analysis - TaintedParams []string - - // TaintedReturn indicates if the function's return value is tainted - TaintedReturn bool - - // ReturnTaintInfo provides details if TaintedReturn is true - ReturnTaintInfo *TaintInfo - - // AnalysisError indicates if the analysis failed for this function - // If true, the summary is incomplete - AnalysisError bool - - // ErrorMessage contains the error description if AnalysisError is true - ErrorMessage string -} - -// NewTaintSummary creates an empty taint summary for a function. +// NewTaintSummary is a convenience wrapper. +// Deprecated: Use core.NewTaintSummary instead. func NewTaintSummary(functionFQN string) *TaintSummary { - return &TaintSummary{ - FunctionFQN: functionFQN, - TaintedVars: make(map[string][]*TaintInfo), - Detections: make([]*TaintInfo, 0), - TaintedParams: make([]string, 0), - } -} - -// AddTaintedVar records taint information for a variable. 
-func (ts *TaintSummary) AddTaintedVar(varName string, taintInfo *TaintInfo) { - if varName == "" || taintInfo == nil { - return - } - ts.TaintedVars[varName] = append(ts.TaintedVars[varName], taintInfo) -} - -// GetTaintInfo retrieves all taint information for a variable. -// Returns nil if variable is not tainted. -func (ts *TaintSummary) GetTaintInfo(varName string) []*TaintInfo { - return ts.TaintedVars[varName] -} - -// IsTainted checks if a variable is tainted (has at least one unsanitized taint path). -func (ts *TaintSummary) IsTainted(varName string) bool { - taintInfos := ts.TaintedVars[varName] - for _, info := range taintInfos { - if info.IsTainted() { - return true - } - } - return false -} - -// AddDetection records a taint flow that reached a dangerous sink. -func (ts *TaintSummary) AddDetection(detection *TaintInfo) { - if detection == nil { - return - } - ts.Detections = append(ts.Detections, detection) -} - -// HasDetections returns true if any taint flows reached dangerous sinks. -func (ts *TaintSummary) HasDetections() bool { - return len(ts.Detections) > 0 -} - -// GetHighConfidenceDetections returns detections with confidence >= 0.8. -func (ts *TaintSummary) GetHighConfidenceDetections() []*TaintInfo { - result := make([]*TaintInfo, 0) - for _, detection := range ts.Detections { - if detection.IsHighConfidence() { - result = append(result, detection) - } - } - return result -} - -// GetMediumConfidenceDetections returns detections with 0.5 <= confidence < 0.8. -func (ts *TaintSummary) GetMediumConfidenceDetections() []*TaintInfo { - result := make([]*TaintInfo, 0) - for _, detection := range ts.Detections { - if detection.IsMediumConfidence() { - result = append(result, detection) - } - } - return result -} - -// GetLowConfidenceDetections returns detections with 0.0 < confidence < 0.5. 
-func (ts *TaintSummary) GetLowConfidenceDetections() []*TaintInfo { - result := make([]*TaintInfo, 0) - for _, detection := range ts.Detections { - if detection.IsLowConfidence() { - result = append(result, detection) - } - } - return result -} - -// MarkTaintedParam marks a function parameter as tainted. -func (ts *TaintSummary) MarkTaintedParam(paramName string) { - if paramName == "" { - return - } - - // Check if already marked - for _, p := range ts.TaintedParams { - if p == paramName { - return - } - } - - ts.TaintedParams = append(ts.TaintedParams, paramName) -} - -// IsParamTainted checks if a function parameter is tainted. -func (ts *TaintSummary) IsParamTainted(paramName string) bool { - for _, p := range ts.TaintedParams { - if p == paramName { - return true - } - } - return false -} - -// MarkReturnTainted marks the function's return value as tainted. -func (ts *TaintSummary) MarkReturnTainted(taintInfo *TaintInfo) { - ts.TaintedReturn = true - ts.ReturnTaintInfo = taintInfo -} - -// SetError marks the analysis as failed with an error message. -func (ts *TaintSummary) SetError(errorMsg string) { - ts.AnalysisError = true - ts.ErrorMessage = errorMsg -} - -// IsComplete returns true if analysis completed without errors. -func (ts *TaintSummary) IsComplete() bool { - return !ts.AnalysisError -} - -// GetTaintedVarCount returns the number of distinct tainted variables. -func (ts *TaintSummary) GetTaintedVarCount() int { - count := 0 - for _, taintInfos := range ts.TaintedVars { - for _, info := range taintInfos { - if info.IsTainted() { - count++ - break // Count each variable only once - } - } - } - return count -} - -// GetDetectionCount returns the total number of detections. 
-func (ts *TaintSummary) GetDetectionCount() int { - return len(ts.Detections) + return core.NewTaintSummary(functionFQN) } diff --git a/sourcecode-parser/graph/callgraph/type_inference.go b/sourcecode-parser/graph/callgraph/type_inference.go index c3c4e96b..5c7e2db5 100644 --- a/sourcecode-parser/graph/callgraph/type_inference.go +++ b/sourcecode-parser/graph/callgraph/type_inference.go @@ -1,14 +1,14 @@ package callgraph -import "strings" - -// TypeInfo represents inferred type information for a variable or expression. -// It tracks the fully qualified type name, confidence level, and how the type was inferred. -type TypeInfo struct { - TypeFQN string // Fully qualified type name (e.g., "builtins.str", "myapp.models.User") - Confidence float32 // Confidence level from 0.0 to 1.0 (1.0 = certain, 0.5 = heuristic, 0.0 = unknown) - Source string // How the type was inferred (e.g., "literal", "assignment", "annotation") -} +import ( + "strings" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" +) + +// Deprecated: Use core.TypeInfo instead. +// This alias will be removed in a future version. +type TypeInfo = core.TypeInfo // VariableBinding tracks a variable's type within a scope. // It captures the variable name, its inferred type, and source location. diff --git a/sourcecode-parser/graph/callgraph/types.go b/sourcecode-parser/graph/callgraph/types.go index 03252121..fb34b2fd 100644 --- a/sourcecode-parser/graph/callgraph/types.go +++ b/sourcecode-parser/graph/callgraph/types.go @@ -1,267 +1,55 @@ package callgraph import ( - "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/callgraph/core" ) -// Location represents a source code location for tracking call sites. -// This enables precise mapping of where calls occur in the source code. 
-type Location struct { - File string // Absolute path to the source file - Line int // Line number (1-indexed) - Column int // Column number (1-indexed) -} - -// CallSite represents a function/method call location in the source code. -// It captures both the syntactic information (where the call is) and -// semantic information (what is being called and with what arguments). -type CallSite struct { - Target string // The name of the function being called (e.g., "eval", "utils.sanitize") - Location Location // Where this call occurs in the source code - Arguments []Argument // Arguments passed to the call - Resolved bool // Whether we successfully resolved this call to a definition - TargetFQN string // Fully qualified name after resolution (e.g., "myapp.utils.sanitize") - FailureReason string // Why resolution failed (empty if Resolved=true) - - // Phase 2: Type inference metadata - ResolvedViaTypeInference bool // Was this resolved using type inference? - InferredType string // The inferred type FQN (e.g., "builtins.str", "test.User") - TypeConfidence float32 // Confidence score of the type inference (0.0-1.0) - TypeSource string // How type was inferred (e.g., "literal", "return_type", "class_instantiation") -} - -// Resolution failure reason categories for diagnostics: -// - "external_framework" - Call to Django, REST framework, pytest, stdlib, etc. -// - "orm_pattern" - Django ORM patterns like Model.objects.filter() -// - "attribute_chain" - Method calls on return values like response.json() -// - "variable_method" - Method calls on variables like value.split() -// - "super_call" - Calls via super() to parent class methods -// - "not_in_imports" - Simple function call not found in imports -// - "unknown" - Unresolved for other reasons +// Deprecated: Use core.Location instead. +// This alias will be removed in a future version. +type Location = core.Location -// Argument represents a single argument passed to a function call. 
-// Tracks both the value/expression and metadata about the argument. -type Argument struct { - Value string // The argument expression as a string - IsVariable bool // Whether this argument is a variable reference - Position int // Position in the argument list (0-indexed) -} - -// CallGraph represents the complete call graph of a program. -// It maps function definitions to their call sites and provides -// both forward (callers → callees) and reverse (callees → callers) edges. -// -// Example: -// Function A calls B and C -// edges: {"A": ["B", "C"]} -// reverseEdges: {"B": ["A"], "C": ["A"]} -type CallGraph struct { - // Forward edges: maps fully qualified function name to list of functions it calls - // Key: caller FQN (e.g., "myapp.views.get_user") - // Value: list of callee FQNs (e.g., ["myapp.db.query", "myapp.utils.sanitize"]) - Edges map[string][]string +// Deprecated: Use core.CallSite instead. +// This alias will be removed in a future version. +type CallSite = core.CallSite - // Reverse edges: maps fully qualified function name to list of functions that call it - // Useful for backward slicing and finding all callers of a function - // Key: callee FQN - // Value: list of caller FQNs - ReverseEdges map[string][]string +// Deprecated: Use core.Argument instead. +// This alias will be removed in a future version. +type Argument = core.Argument - // Detailed call site information for each function - // Key: caller FQN - // Value: list of all call sites within that function - CallSites map[string][]CallSite +// Deprecated: Use core.CallGraph instead. +// This alias will be removed in a future version. +type CallGraph = core.CallGraph - // Map from fully qualified name to the actual function node in the graph - // This allows quick lookup of function metadata (line number, file, etc.) - Functions map[string]*graph.Node +// Deprecated: Use core.ModuleRegistry instead. +// This alias will be removed in a future version. 
+type ModuleRegistry = core.ModuleRegistry - // Taint summaries for each function (intra-procedural analysis results) - // Key: function FQN - // Value: TaintSummary with taint flow information - Summaries map[string]*TaintSummary -} +// Deprecated: Use core.ImportMap instead. +// This alias will be removed in a future version. +type ImportMap = core.ImportMap -// NewCallGraph creates and initializes a new CallGraph instance. -// All maps are pre-allocated to avoid nil pointer issues. +// NewCallGraph is a convenience wrapper. +// Deprecated: Use core.NewCallGraph instead. func NewCallGraph() *CallGraph { - return &CallGraph{ - Edges: make(map[string][]string), - ReverseEdges: make(map[string][]string), - CallSites: make(map[string][]CallSite), - Functions: make(map[string]*graph.Node), - Summaries: make(map[string]*TaintSummary), - } -} - -// AddEdge adds a directed edge from caller to callee in the call graph. -// Automatically updates both forward and reverse edges. -// -// Parameters: -// - caller: fully qualified name of the calling function -// - callee: fully qualified name of the called function -func (cg *CallGraph) AddEdge(caller, callee string) { - // Add forward edge - if !contains(cg.Edges[caller], callee) { - cg.Edges[caller] = append(cg.Edges[caller], callee) - } - - // Add reverse edge - if !contains(cg.ReverseEdges[callee], caller) { - cg.ReverseEdges[callee] = append(cg.ReverseEdges[callee], caller) - } -} - -// AddCallSite adds a call site to the call graph. -// This stores detailed information about where and how a function is called. -// -// Parameters: -// - caller: fully qualified name of the calling function -// - callSite: detailed information about the call -func (cg *CallGraph) AddCallSite(caller string, callSite CallSite) { - cg.CallSites[caller] = append(cg.CallSites[caller], callSite) -} - -// GetCallers returns all functions that call the specified function. -// Uses the reverse edges for efficient lookup. 
-// -// Parameters: -// - callee: fully qualified name of the function -// -// Returns: -// - list of caller FQNs, or empty slice if no callers found -func (cg *CallGraph) GetCallers(callee string) []string { - if callers, ok := cg.ReverseEdges[callee]; ok { - return callers - } - return []string{} -} - -// GetCallees returns all functions called by the specified function. -// Uses the forward edges for efficient lookup. -// -// Parameters: -// - caller: fully qualified name of the function -// -// Returns: -// - list of callee FQNs, or empty slice if no callees found -func (cg *CallGraph) GetCallees(caller string) []string { - if callees, ok := cg.Edges[caller]; ok { - return callees - } - return []string{} -} - -// ModuleRegistry maintains the mapping between Python file paths and module paths. -// This is essential for resolving imports and building fully qualified names. -// -// Example: -// File: /project/myapp/utils/helpers.py -// Module: myapp.utils.helpers -type ModuleRegistry struct { - // Maps fully qualified module path to absolute file path - // Key: "myapp.utils.helpers" - // Value: "/absolute/path/to/myapp/utils/helpers.py" - Modules map[string]string - - // Maps absolute file path to fully qualified module path (reverse of Modules) - // Key: "/absolute/path/to/myapp/utils/helpers.py" - // Value: "myapp.utils.helpers" - // Used for resolving relative imports - FileToModule map[string]string - - // Maps short module names to all matching file paths (handles ambiguity) - // Key: "helpers" - // Value: ["/path/to/myapp/utils/helpers.py", "/path/to/lib/helpers.py"] - ShortNames map[string][]string - - // Cache for resolved imports to avoid redundant lookups - // Key: import string (e.g., "utils.helpers") - // Value: fully qualified module path - ResolvedImports map[string]string + return core.NewCallGraph() } -// NewModuleRegistry creates and initializes a new ModuleRegistry instance. +// NewModuleRegistry is a convenience wrapper. 
+// Deprecated: Use core.NewModuleRegistry instead. func NewModuleRegistry() *ModuleRegistry { - return &ModuleRegistry{ - Modules: make(map[string]string), - FileToModule: make(map[string]string), - ShortNames: make(map[string][]string), - ResolvedImports: make(map[string]string), - } -} - -// AddModule registers a module in the registry. -// Automatically indexes both the full module path and the short name. -// -// Parameters: -// - modulePath: fully qualified module path (e.g., "myapp.utils.helpers") -// - filePath: absolute file path (e.g., "/project/myapp/utils/helpers.py") -func (mr *ModuleRegistry) AddModule(modulePath, filePath string) { - mr.Modules[modulePath] = filePath - mr.FileToModule[filePath] = modulePath - - // Extract short name (last component) - // "myapp.utils.helpers" → "helpers" - shortName := extractShortName(modulePath) - if !containsString(mr.ShortNames[shortName], filePath) { - mr.ShortNames[shortName] = append(mr.ShortNames[shortName], filePath) - } -} - -// GetModulePath returns the file path for a given module, if it exists. -// -// Parameters: -// - modulePath: fully qualified module path -// -// Returns: -// - file path and true if found, empty string and false otherwise -func (mr *ModuleRegistry) GetModulePath(modulePath string) (string, bool) { - filePath, ok := mr.Modules[modulePath] - return filePath, ok -} - -// ImportMap represents the import statements in a single Python file. -// Maps local aliases to fully qualified module paths. -// -// Example: -// File contains: from myapp.utils import sanitize as clean -// Imports: {"clean": "myapp.utils.sanitize"} -type ImportMap struct { - FilePath string // Absolute path to the file containing these imports - Imports map[string]string // Maps alias/name to fully qualified module path + return core.NewModuleRegistry() } -// NewImportMap creates and initializes a new ImportMap instance. +// NewImportMap is a convenience wrapper. +// Deprecated: Use core.NewImportMap instead. 
func NewImportMap(filePath string) *ImportMap { - return &ImportMap{ - FilePath: filePath, - Imports: make(map[string]string), - } + return core.NewImportMap(filePath) } -// AddImport adds an import mapping to the import map. -// -// Parameters: -// - alias: the local name used in the file (e.g., "clean", "sanitize", "utils") -// - fqn: the fully qualified name (e.g., "myapp.utils.sanitize") -func (im *ImportMap) AddImport(alias, fqn string) { - im.Imports[alias] = fqn -} - -// Resolve looks up the fully qualified name for a local alias. -// -// Parameters: -// - alias: the local name to resolve -// -// Returns: -// - fully qualified name and true if found, empty string and false otherwise -func (im *ImportMap) Resolve(alias string) (string, bool) { - fqn, ok := im.Imports[alias] - return fqn, ok -} +// Helper functions for internal use within callgraph package +// These are kept here for backward compatibility with other files in the package -// Helper function to check if a string slice contains a specific string. +// contains checks if a string slice contains a specific string. func contains(slice []string, item string) bool { for _, s := range slice { if s == item { @@ -271,12 +59,12 @@ func contains(slice []string, item string) bool { return false } -// Helper function alias for consistency. +// containsString is an alias for contains for consistency. func containsString(slice []string, item string) bool { return contains(slice, item) } -// Helper function to extract the last component of a dotted path. +// extractShortName extracts the last component of a dotted path. // Example: "myapp.utils.helpers" → "helpers". func extractShortName(modulePath string) string { // Find last dot