diff --git a/sourcecode-parser/graph/construct.go b/sourcecode-parser/graph/construct.go deleted file mode 100644 index cd7c5956..00000000 --- a/sourcecode-parser/graph/construct.go +++ /dev/null @@ -1,1203 +0,0 @@ -package graph - -import ( - "context" - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - javalang "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/java" - - "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" - "github.com/smacker/go-tree-sitter/java" - - sitter "github.com/smacker/go-tree-sitter" - //nolint:all -) - -type Node struct { - ID string - Type string - Name string - CodeSnippet string - LineNumber uint32 - OutgoingEdges []*Edge - IsExternal bool - Modifier string - ReturnType string - MethodArgumentsType []string - MethodArgumentsValue []string - PackageName string - ImportPackage []string - SuperClass string - Interface []string - DataType string - Scope string - VariableValue string - hasAccess bool - File string - isJavaSourceFile bool - ThrowsExceptions []string - Annotation []string - JavaDoc *model.Javadoc - BinaryExpr *model.BinaryExpr - ClassInstanceExpr *model.ClassInstanceExpr - IfStmt *model.IfStmt - WhileStmt *model.WhileStmt - DoStmt *model.DoStmt - ForStmt *model.ForStmt - BreakStmt *model.BreakStmt - ContinueStmt *model.ContinueStmt - YieldStmt *model.YieldStmt - AssertStmt *model.AssertStmt - ReturnStmt *model.ReturnStmt - BlockStmt *model.BlockStmt -} - -type Edge struct { - From *Node - To *Node -} - -type CodeGraph struct { - Nodes map[string]*Node - Edges []*Edge -} - -func NewCodeGraph() *CodeGraph { - return &CodeGraph{ - Nodes: make(map[string]*Node), - Edges: make([]*Edge, 0), - } -} - -func (g *CodeGraph) AddNode(node *Node) { - g.Nodes[node.ID] = node -} - -func (g *CodeGraph) AddEdge(from, to *Node) { - edge := &Edge{From: from, To: to} - g.Edges = append(g.Edges, edge) - from.OutgoingEdges = append(from.OutgoingEdges, edge) -} - -// Add to graph.go - -// FindNodesByType finds all nodes of a given type. -func (g *CodeGraph) FindNodesByType(nodeType string) []*Node { - var nodes []*Node - for _, node := range g.Nodes { - if node.Type == nodeType { - nodes = append(nodes, node) - } - } - return nodes -} - -func extractVisibilityModifier(modifiers string) string { - words := strings.Fields(modifiers) - for _, word := range words { - switch word { - case "public", "private", "protected": - return word - } - } - return "" // return an empty string if no visibility modifier is found -} - -func isJavaSourceFile(filename string) bool { - return filepath.Ext(filename) == ".java" -} - -//nolint:all -func hasAccess(node *sitter.Node, variableName string, sourceCode []byte) bool { - if node == nil { - return false - } - if node.Type() == "identifier" && node.Content(sourceCode) == variableName { - return true - } - - // Recursively check all children of the current node - for i := 0; i < int(node.ChildCount()); i++ { - childNode := node.Child(i) - if hasAccess(childNode, variableName, sourceCode) { - return true - } - } - - // Continue checking in the next sibling - return hasAccess(node.NextSibling(), variableName, sourceCode) -} - -func parseJavadocTags(commentContent string) *model.Javadoc { - javaDoc := &model.Javadoc{} - var javadocTags []*model.JavadocTag - - commentLines := strings.Split(commentContent, "\n") - for _, line := range commentLines { - line = strings.TrimSpace(line) - // line may start with /** or * - line = strings.TrimPrefix(line, "*") - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "@") { - parts := strings.SplitN(line, " ", 2) - if len(parts) == 2 { - tagName := strings.TrimPrefix(parts[0], "@") - tagText := strings.TrimSpace(parts[1]) - - var javadocTag *model.JavadocTag - switch tagName { - case "author": - javadocTag = model.NewJavadocTag(tagName, tagText, "author") - javaDoc.Author = tagText - case "param": - javadocTag = model.NewJavadocTag(tagName, tagText, "param") - case "see": - javadocTag = model.NewJavadocTag(tagName, tagText, "see") - case "throws": - javadocTag = model.NewJavadocTag(tagName, tagText, "throws") - case "version": - javadocTag = model.NewJavadocTag(tagName, tagText, "version") - javaDoc.Version = tagText - case "since": - javadocTag = model.NewJavadocTag(tagName, tagText, "since") - default: - javadocTag = model.NewJavadocTag(tagName, tagText, "unknown") - } - javadocTags = append(javadocTags, javadocTag) - } - } - } - - javaDoc.Tags = javadocTags - javaDoc.NumberOfCommentLines = len(commentLines) - javaDoc.CommentedCodeElements = commentContent - - return javaDoc -} - -func buildGraphFromAST(node *sitter.Node, sourceCode []byte, graph *CodeGraph, currentContext *Node, file string) { - isJavaSourceFile := isJavaSourceFile(file) - switch node.Type() { - case "block": - blockNode := javalang.ParseBlockStatement(node, sourceCode) - uniqueBlockID := fmt.Sprintf("block_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - blockStmtNode := &Node{ - ID: GenerateSha256(uniqueBlockID), - Type: "BlockStmt", - LineNumber: node.StartPoint().Row + 1, - Name: "BlockStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - File: file, - isJavaSourceFile: isJavaSourceFile, - BlockStmt: blockNode, - } - graph.AddNode(blockStmtNode) - case "return_statement": - returnNode := javalang.ParseReturnStatement(node, sourceCode) - uniqueReturnID := fmt.Sprintf("return_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - returnStmtNode := &Node{ - ID: GenerateSha256(uniqueReturnID), - Type: "ReturnStmt", - LineNumber: node.StartPoint().Row + 1, - Name: "ReturnStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - File: file, - isJavaSourceFile: isJavaSourceFile, - ReturnStmt: returnNode, - } - graph.AddNode(returnStmtNode) - case "assert_statement": - assertNode := javalang.ParseAssertStatement(node, sourceCode) - uniqueAssertID := fmt.Sprintf("assert_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - assertStmtNode := &Node{ - ID: GenerateSha256(uniqueAssertID), - Type: "AssertStmt", - LineNumber: node.StartPoint().Row + 1, - Name: "AssertStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - File: file, - isJavaSourceFile: isJavaSourceFile, - AssertStmt: assertNode, - } - graph.AddNode(assertStmtNode) - case "yield_statement": - yieldNode := javalang.ParseYieldStatement(node, sourceCode) - uniqueyieldID := fmt.Sprintf("yield_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - yieldStmtNode := &Node{ - ID: GenerateSha256(uniqueyieldID), - Type: "YieldStmt", - LineNumber: node.StartPoint().Row + 1, - Name: "YieldStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - File: file, - isJavaSourceFile: isJavaSourceFile, - YieldStmt: yieldNode, - } - graph.AddNode(yieldStmtNode) - case "break_statement": - breakNode := javalang.ParseBreakStatement(node, sourceCode) - uniquebreakstmtID := fmt.Sprintf("breakstmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - breakStmtNode := &Node{ - ID: GenerateSha256(uniquebreakstmtID), - Type: "BreakStmt", - LineNumber: node.StartPoint().Row + 1, - Name: "BreakStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - File: file, - isJavaSourceFile: isJavaSourceFile, - BreakStmt: breakNode, - } - graph.AddNode(breakStmtNode) - case "continue_statement": - continueNode := javalang.ParseContinueStatement(node, sourceCode) - uniquecontinueID := fmt.Sprintf("continuestmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - continueStmtNode := &Node{ - ID: GenerateSha256(uniquecontinueID), - Type: "ContinueStmt", - LineNumber: node.StartPoint().Row + 1, - Name: "ContinueStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - File: file, - isJavaSourceFile: isJavaSourceFile, - ContinueStmt: continueNode, - } - graph.AddNode(continueStmtNode) - case "if_statement": - ifNode := model.IfStmt{} - // get the condition of the if statement - conditionNode := node.Child(1) - if conditionNode != nil { - ifNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} - } - // get the then block of the if statement - thenNode := node.Child(2) - if thenNode != nil { - ifNode.Then = model.Stmt{NodeString: thenNode.Content(sourceCode)} - } - // get the else block of the if statement - elseNode := node.Child(4) - if elseNode != nil { - ifNode.Else = model.Stmt{NodeString: elseNode.Content(sourceCode)} - } - - methodID := fmt.Sprintf("ifstmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - // add node to graph - ifStmtNode := &Node{ - ID: GenerateSha256(methodID), - Type: "IfStmt", - Name: "IfStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - File: file, - isJavaSourceFile: isJavaSourceFile, - IfStmt: &ifNode, - } - graph.AddNode(ifStmtNode) - case "while_statement": - whileNode := model.WhileStmt{} - // get the condition of the while statement - conditionNode := node.Child(1) - if conditionNode != nil { - whileNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} - } - methodID := fmt.Sprintf("while_stmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - // add node to graph - whileStmtNode := &Node{ - ID: GenerateSha256(methodID), - Type: "WhileStmt", - Name: "WhileStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - File: file, - isJavaSourceFile: isJavaSourceFile, - WhileStmt: &whileNode, - } - graph.AddNode(whileStmtNode) - case "do_statement": - doWhileNode := model.DoStmt{} - // get the condition of the while statement - conditionNode := node.Child(2) - if conditionNode != nil { - doWhileNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} - } - methodID := fmt.Sprintf("dowhile_stmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - // add node to graph - doWhileStmtNode := &Node{ - ID: GenerateSha256(methodID), - Type: "DoStmt", - Name: "DoStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - File: file, - isJavaSourceFile: isJavaSourceFile, - DoStmt: &doWhileNode, - } - graph.AddNode(doWhileStmtNode) - case "for_statement": - forNode := model.ForStmt{} - // get the condition of the while statement - initNode := node.ChildByFieldName("init") - if initNode != nil { - forNode.Init = &model.Expr{Node: *initNode, NodeString: initNode.Content(sourceCode)} - } - conditionNode := node.ChildByFieldName("condition") - if conditionNode != nil { - forNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} - } - incrementNode := node.ChildByFieldName("increment") - if incrementNode != nil { - forNode.Increment = &model.Expr{Node: *incrementNode, NodeString: incrementNode.Content(sourceCode)} - } - - methodID := fmt.Sprintf("for_stmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) - // add node to graph - forStmtNode := &Node{ - ID: GenerateSha256(methodID), - Type: "ForStmt", - Name: "ForStmt", - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - File: file, - isJavaSourceFile: isJavaSourceFile, - ForStmt: &forNode, - } - graph.AddNode(forStmtNode) - case "binary_expression": - leftNode := node.ChildByFieldName("left") - rightNode := node.ChildByFieldName("right") - operator := node.ChildByFieldName("operator") - operatorType := operator.Type() - expressionNode := model.BinaryExpr{} - expressionNode.LeftOperand = &model.Expr{Node: *leftNode, NodeString: leftNode.Content(sourceCode)} - expressionNode.RightOperand = &model.Expr{Node: *rightNode, NodeString: rightNode.Content(sourceCode)} - expressionNode.Op = operatorType - switch operatorType { - case "+": - var addExpr model.AddExpr - addExpr.LeftOperand = expressionNode.LeftOperand - addExpr.RightOperand = expressionNode.RightOperand - addExpr.Op = expressionNode.Op - addExpr.BinaryExpr = expressionNode - addExpressionNode := &Node{ - ID: GenerateSha256("add_expression" + node.Content(sourceCode)), - Type: "add_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(addExpressionNode) - case "-": - var subExpr model.SubExpr - subExpr.LeftOperand = expressionNode.LeftOperand - subExpr.RightOperand = expressionNode.RightOperand - subExpr.Op = expressionNode.Op - subExpr.BinaryExpr = expressionNode - subExpressionNode := &Node{ - ID: GenerateSha256("sub_expression" + node.Content(sourceCode)), - Type: "sub_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(subExpressionNode) - case "*": - var mulExpr model.MulExpr - mulExpr.LeftOperand = expressionNode.LeftOperand - mulExpr.RightOperand = expressionNode.RightOperand - mulExpr.Op = expressionNode.Op - mulExpr.BinaryExpr = expressionNode - mulExpressionNode := &Node{ - ID: GenerateSha256("mul_expression" + node.Content(sourceCode)), - Type: "mul_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(mulExpressionNode) - case "/": - var divExpr model.DivExpr - divExpr.LeftOperand = expressionNode.LeftOperand - divExpr.RightOperand = expressionNode.RightOperand - divExpr.Op = expressionNode.Op - divExpr.BinaryExpr = expressionNode - divExpressionNode := &Node{ - ID: GenerateSha256("div_expression" + node.Content(sourceCode)), - Type: "div_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(divExpressionNode) - case ">", "<", ">=", "<=": - var compExpr model.ComparisonExpr - compExpr.LeftOperand = expressionNode.LeftOperand - compExpr.RightOperand = expressionNode.RightOperand - compExpr.Op = expressionNode.Op - compExpr.BinaryExpr = expressionNode - compExpressionNode := &Node{ - ID: GenerateSha256("comp_expression" + node.Content(sourceCode)), - Type: "comp_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(compExpressionNode) - case "%": - var RemExpr model.RemExpr - RemExpr.LeftOperand = expressionNode.LeftOperand - RemExpr.RightOperand = expressionNode.RightOperand - RemExpr.Op = expressionNode.Op - RemExpr.BinaryExpr = expressionNode - RemExpressionNode := &Node{ - ID: GenerateSha256("rem_expression" + node.Content(sourceCode)), - Type: "rem_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(RemExpressionNode) - case ">>": - var RightShiftExpr model.RightShiftExpr - RightShiftExpr.LeftOperand = expressionNode.LeftOperand - RightShiftExpr.RightOperand = expressionNode.RightOperand - RightShiftExpr.Op = expressionNode.Op - RightShiftExpr.BinaryExpr = expressionNode - RightShiftExpressionNode := &Node{ - ID: GenerateSha256("right_shift_expression" + node.Content(sourceCode)), - Type: "right_shift_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(RightShiftExpressionNode) - case "<<": - var LeftShiftExpr model.LeftShiftExpr - LeftShiftExpr.LeftOperand = expressionNode.LeftOperand - LeftShiftExpr.RightOperand = expressionNode.RightOperand - LeftShiftExpr.Op = expressionNode.Op - LeftShiftExpr.BinaryExpr = expressionNode - LeftShiftExpressionNode := &Node{ - ID: GenerateSha256("left_shift_expression" + node.Content(sourceCode)), - Type: "left_shift_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(LeftShiftExpressionNode) - case "!=": - var NEExpr model.NEExpr - NEExpr.LeftOperand = expressionNode.LeftOperand - NEExpr.RightOperand = expressionNode.RightOperand - NEExpr.Op = expressionNode.Op - NEExpr.BinaryExpr = expressionNode - NEExpressionNode := &Node{ - ID: GenerateSha256("ne_expression" + node.Content(sourceCode)), - Type: "ne_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(NEExpressionNode) - case "==": - var EQExpr model.EqExpr - EQExpr.LeftOperand = expressionNode.LeftOperand - EQExpr.RightOperand = expressionNode.RightOperand - EQExpr.Op = expressionNode.Op - EQExpr.BinaryExpr = expressionNode - EQExpressionNode := &Node{ - ID: GenerateSha256("eq_expression" + node.Content(sourceCode)), - Type: "eq_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(EQExpressionNode) - case "&": - var BitwiseAndExpr model.AndBitwiseExpr - BitwiseAndExpr.LeftOperand = expressionNode.LeftOperand - BitwiseAndExpr.RightOperand = expressionNode.RightOperand - BitwiseAndExpr.Op = expressionNode.Op - BitwiseAndExpr.BinaryExpr = expressionNode - BitwiseAndExpressionNode := &Node{ - ID: GenerateSha256("bitwise_and_expression" + node.Content(sourceCode)), - Type: "bitwise_and_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(BitwiseAndExpressionNode) - case "&&": - var AndExpr model.AndLogicalExpr - AndExpr.LeftOperand = expressionNode.LeftOperand - AndExpr.RightOperand = expressionNode.RightOperand - AndExpr.Op = expressionNode.Op - AndExpr.BinaryExpr = expressionNode - AndExpressionNode := &Node{ - ID: GenerateSha256("and_expression" + node.Content(sourceCode)), - Type: "and_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(AndExpressionNode) - case "||": - var OrExpr model.OrLogicalExpr - OrExpr.LeftOperand = expressionNode.LeftOperand - OrExpr.RightOperand = expressionNode.RightOperand - OrExpr.Op = expressionNode.Op - OrExpr.BinaryExpr = expressionNode - OrExpressionNode := &Node{ - ID: GenerateSha256("or_expression" + node.Content(sourceCode)), - Type: "or_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(OrExpressionNode) - case "|": - var BitwiseOrExpr model.OrBitwiseExpr - BitwiseOrExpr.LeftOperand = expressionNode.LeftOperand - BitwiseOrExpr.RightOperand = expressionNode.RightOperand - BitwiseOrExpr.Op = expressionNode.Op - BitwiseOrExpr.BinaryExpr = expressionNode - BitwiseOrExpressionNode := &Node{ - ID: GenerateSha256("bitwise_or_expression" + node.Content(sourceCode)), - Type: "bitwise_or_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(BitwiseOrExpressionNode) - case ">>>": - var BitwiseRightShiftExpr model.UnsignedRightShiftExpr - BitwiseRightShiftExpr.LeftOperand = expressionNode.LeftOperand - BitwiseRightShiftExpr.RightOperand = expressionNode.RightOperand - BitwiseRightShiftExpr.Op = expressionNode.Op - BitwiseRightShiftExpr.BinaryExpr = expressionNode - BitwiseRightShiftExpressionNode := &Node{ - ID: GenerateSha256("bitwise_right_shift_expression" + node.Content(sourceCode)), - Type: "bitwise_right_shift_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(BitwiseRightShiftExpressionNode) - case "^": - var BitwiseXorExpr model.XorBitwiseExpr - BitwiseXorExpr.LeftOperand = expressionNode.LeftOperand - BitwiseXorExpr.RightOperand = expressionNode.RightOperand - BitwiseXorExpr.Op = expressionNode.Op - BitwiseXorExpr.BinaryExpr = expressionNode - BitwiseXorExpressionNode := &Node{ - ID: GenerateSha256("bitwise_xor_expression" + node.Content(sourceCode)), - Type: "bitwise_xor_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(BitwiseXorExpressionNode) - } - - invokedNode := &Node{ - ID: GenerateSha256("binary_expression" + node.Content(sourceCode)), - Type: "binary_expression", - Name: node.Content(sourceCode), - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - File: file, - isJavaSourceFile: isJavaSourceFile, - BinaryExpr: &expressionNode, - } - graph.AddNode(invokedNode) - currentContext = invokedNode - case "method_declaration": - var javadoc *model.Javadoc - if node.PrevSibling() != nil && node.PrevSibling().Type() == "block_comment" { - commentContent := node.PrevSibling().Content(sourceCode) - if strings.HasPrefix(commentContent, "/*") { - javadoc = parseJavadocTags(commentContent) - } - } - methodName, methodID := extractMethodName(node, sourceCode, file) - modifiers := "" - returnType := "" - throws := []string{} - methodArgumentType := []string{} - methodArgumentValue := []string{} - annotationMarkers := []string{} - - for i := 0; i < int(node.ChildCount()); i++ { - childNode := node.Child(i) - childType := childNode.Type() - - switch childType { - case "throws": - // namedChild - for j := 0; j < int(childNode.NamedChildCount()); j++ { - namedChild := childNode.NamedChild(j) - if namedChild.Type() == "type_identifier" { - throws = append(throws, namedChild.Content(sourceCode)) - } - } - case "modifiers": - modifiers = childNode.Content(sourceCode) - for j := 0; j < int(childNode.ChildCount()); j++ { - if childNode.Child(j).Type() == "marker_annotation" { - annotationMarkers = append(annotationMarkers, childNode.Child(j).Content(sourceCode)) - } - } - case "void_type", "type_identifier": - // get return type of method - returnType = childNode.Content(sourceCode) - case "formal_parameters": - // get method arguments - for j := 0; j < int(childNode.NamedChildCount()); j++ { - param := childNode.NamedChild(j) - if param.Type() == "formal_parameter" { - // get type of argument and add to method arguments - paramType := param.Child(0).Content(sourceCode) - paramValue := param.Child(1).Content(sourceCode) - methodArgumentType = append(methodArgumentType, paramType) - methodArgumentValue = append(methodArgumentValue, paramValue) - } - } - } - } - - invokedNode := &Node{ - ID: methodID, // In a real scenario, you would construct a unique ID, possibly using the method signature - Type: "method_declaration", - Name: methodName, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - Modifier: extractVisibilityModifier(modifiers), - ReturnType: returnType, - MethodArgumentsType: methodArgumentType, - MethodArgumentsValue: methodArgumentValue, - File: file, - isJavaSourceFile: isJavaSourceFile, - ThrowsExceptions: throws, - Annotation: annotationMarkers, - JavaDoc: javadoc, - } - graph.AddNode(invokedNode) - currentContext = invokedNode // Update context to the new method - - case "method_invocation": - methodName, methodID := extractMethodName(node, sourceCode, file) - arguments := []string{} - // get argument list from arguments node iterate for child node - for i := 0; i < int(node.ChildCount()); i++ { - if node.Child(i).Type() == "argument_list" { - argumentsNode := node.Child(i) - for j := 0; j < int(argumentsNode.ChildCount()); j++ { - argument := argumentsNode.Child(j) - switch argument.Type() { - case "identifier": - arguments = append(arguments, argument.Content(sourceCode)) - case "string_literal": - stringliteral := argument.Content(sourceCode) - stringliteral = strings.TrimPrefix(stringliteral, "\"") - stringliteral = strings.TrimSuffix(stringliteral, "\"") - arguments = append(arguments, stringliteral) - default: - arguments = append(arguments, argument.Content(sourceCode)) - } - } - } - } - - invokedNode := &Node{ - ID: methodID, - Type: "method_invocation", - Name: methodName, - IsExternal: true, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, // Lines start from 0 in the AST - MethodArgumentsValue: arguments, - File: file, - isJavaSourceFile: isJavaSourceFile, - } - graph.AddNode(invokedNode) - - if currentContext != nil { - graph.AddEdge(currentContext, invokedNode) - } - case "class_declaration": - var javadoc *model.Javadoc - if node.PrevSibling() != nil && node.PrevSibling().Type() == "block_comment" { - commentContent := node.PrevSibling().Content(sourceCode) - if strings.HasPrefix(commentContent, "/*") { - javadoc = parseJavadocTags(commentContent) - } - } - className := node.ChildByFieldName("name").Content(sourceCode) - packageName := "" - accessModifier := "" - superClass := "" - annotationMarkers := []string{} - implementedInterface := []string{} - for i := 0; i < int(node.ChildCount()); i++ { - child := node.Child(i) - if child.Type() == "modifiers" { - accessModifier = child.Content(sourceCode) - for j := 0; j < int(child.ChildCount()); j++ { - if child.Child(j).Type() == "marker_annotation" { - annotationMarkers = append(annotationMarkers, child.Child(j).Content(sourceCode)) - } - } - } - if child.Type() == "superclass" { - for j := 0; j < int(child.ChildCount()); j++ { - if child.Child(j).Type() == "type_identifier" { - superClass = child.Child(j).Content(sourceCode) - } - } - } - if child.Type() == "super_interfaces" { - for j := 0; j < int(child.ChildCount()); j++ { - // typelist node and then iterate through type_identifier node - typeList := child.Child(j) - for k := 0; k < int(typeList.ChildCount()); k++ { - implementedInterface = append(implementedInterface, typeList.Child(k).Content(sourceCode)) - } - } - } - } - - classNode := &Node{ - ID: GenerateMethodID(className, []string{}, file), - Type: "class_declaration", - Name: className, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - PackageName: packageName, - Modifier: extractVisibilityModifier(accessModifier), - SuperClass: superClass, - Interface: implementedInterface, - File: file, - isJavaSourceFile: isJavaSourceFile, - JavaDoc: javadoc, - Annotation: annotationMarkers, - } - graph.AddNode(classNode) - case "block_comment": - // Parse block comments - if strings.HasPrefix(node.Content(sourceCode), "/*") { - commentContent := node.Content(sourceCode) - javadocTags := parseJavadocTags(commentContent) - - commentNode := &Node{ - ID: GenerateMethodID(node.Content(sourceCode), []string{}, file), - Type: "block_comment", - CodeSnippet: commentContent, - LineNumber: node.StartPoint().Row + 1, - File: file, - isJavaSourceFile: isJavaSourceFile, - JavaDoc: javadocTags, - } - graph.AddNode(commentNode) - } - case "local_variable_declaration", "field_declaration": - // Extract variable name, type, and modifiers - variableName := "" - variableType := "" - variableModifier := "" - variableValue := "" - hasAccessValue := false - var scope string - for i := 0; i < int(node.ChildCount()); i++ { - child := node.Child(i) - switch child.Type() { - case "variable_declarator": - variableName = child.Content(sourceCode) - for j := 0; j < int(child.ChildCount()); j++ { - if child.Child(j).Type() == "identifier" { - variableName = child.Child(j).Content(sourceCode) - } - // if child type contains =, iterate through and get remaining content - if child.Child(j).Type() == "=" { - for k := j + 1; k < int(child.ChildCount()); k++ { - variableValue += child.Child(k).Content(sourceCode) - } - } - - } - // remove spaces from variable value - variableValue = strings.ReplaceAll(variableValue, " ", "") - // remove new line from variable value - variableValue = strings.ReplaceAll(variableValue, "\n", "") - case "modifiers": - variableModifier = child.Content(sourceCode) - } - // if child type contains type, get the type of variable - if strings.Contains(child.Type(), "type") { - variableType = child.Content(sourceCode) - } - } - if node.Type() == "local_variable_declaration" { - scope = "local" - //nolint:all - // hasAccessValue = hasAccess(node.NextSibling(), variableName, sourceCode) - } else { - scope = "field" - } - // Create a new node for the variable - variableNode := &Node{ - ID: GenerateMethodID(variableName, []string{}, file), - Type: "variable_declaration", - Name: variableName, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - Modifier: extractVisibilityModifier(variableModifier), - DataType: variableType, - Scope: scope, - VariableValue: variableValue, - hasAccess: hasAccessValue, - File: file, - isJavaSourceFile: isJavaSourceFile, - } - graph.AddNode(variableNode) - case "object_creation_expression": - className := "" - classInstanceExpression := model.ClassInstanceExpr{ - ClassName: "", - Args: []*model.Expr{}, - } - for i := 0; i < int(node.ChildCount()); i++ { - child := node.Child(i) - if child.Type() == "type_identifier" || child.Type() == "scoped_type_identifier" { - className = child.Content(sourceCode) - classInstanceExpression.ClassName = className - } - if child.Type() == "argument_list" { - classInstanceExpression.Args = []*model.Expr{} - for j := 0; j < int(child.ChildCount()); j++ { - argType := child.Child(j).Type() - argumentStopWords := map[string]bool{ - "(": true, - ")": true, - "{": true, - "}": true, - "[": true, - "]": true, - ",": true, - } - if !argumentStopWords[argType] { - argument := &model.Expr{} - argument.Type = child.Child(j).Type() - argument.NodeString = child.Child(j).Content(sourceCode) - classInstanceExpression.Args = append(classInstanceExpression.Args, argument) - } - } - } - } - - objectNode := &Node{ - ID: GenerateMethodID(className, []string{strconv.Itoa(int(node.StartPoint().Row + 1))}, file), - Type: "ClassInstanceExpr", - Name: className, - CodeSnippet: node.Content(sourceCode), - LineNumber: node.StartPoint().Row + 1, - File: file, - isJavaSourceFile: isJavaSourceFile, - ClassInstanceExpr: &classInstanceExpression, - } - graph.AddNode(objectNode) - } - - // Recursively process child nodes - for i := 0; i < int(node.ChildCount()); i++ { - child := node.Child(i) - buildGraphFromAST(child, sourceCode, graph, currentContext, file) - } - - // iterate through method declaration from graph node - for _, node := range graph.Nodes { - if node.Type == "method_declaration" { - // iterate through method method_invocation from graph node - for _, invokedNode := range graph.Nodes { - if invokedNode.Type == "method_invocation" { - if invokedNode.Name == node.Name { - // check argument list count is same - if len(invokedNode.MethodArgumentsValue) == len(node.MethodArgumentsType) { - node.hasAccess = true - } - } - } - } - } - } -} - -//nolint:all -func extractMethodName(node *sitter.Node, sourceCode []byte, filepath string) (string, string) { - var methodID string - - // if the child node is method_declaration, extract method name, modifiers, parameters, and return type - var methodName string - var modifiers, parameters []string - - if node.Type() == "method_declaration" { - // Iterate over all children of the method_declaration node - for i := 0; i < int(node.ChildCount()); i++ { - child := node.Child(i) - switch child.Type() { - case "modifiers", "marker_annotation", "annotation": - // This child is a modifier or annotation, add its content to modifiers - modifiers = append(modifiers, child.Content(sourceCode)) //nolint:all - case "identifier": - // This child is the method name - methodName = child.Content(sourceCode) - case "formal_parameters": - // This child represents formal parameters; iterate through its children - for j := 0; j < int(child.NamedChildCount()); j++ { - param := child.NamedChild(j) - parameters = append(parameters, param.Content(sourceCode)) - } - } - } - } - - // check if type is method_invocation - // if the child node is method_invocation, extract method name - if node.Type() == "method_invocation" { - for j := 0; j < int(node.ChildCount()); j++ { - child := node.Child(j) - if child.Type() == "identifier" { - if methodName == "" { - methodName = child.Content(sourceCode) - } else { - methodName = methodName + "." + child.Content(sourceCode) - } - } - - argumentsNode := node.ChildByFieldName("argument_list") - // add data type of arguments list - if argumentsNode != nil { - for k := 0; k < int(argumentsNode.ChildCount()); k++ { - argument := argumentsNode.Child(k) - parameters = append(parameters, argument.Child(0).Content(sourceCode)) - } - } - - } - } - content := node.Content(sourceCode) - lineNumber := int(node.StartPoint().Row) + 1 - columnNumber := int(node.StartPoint().Column) + 1 - // convert to string and merge - content += " " + strconv.Itoa(lineNumber) + ":" + strconv.Itoa(columnNumber) - methodID = GenerateMethodID(methodName, parameters, filepath+"/"+content) - return methodName, methodID -} -func getFiles(directory string) ([]string, error) { - var files []string - err := filepath.Walk(directory, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if !info.IsDir() { - // append only java files - if filepath.Ext(path) == ".java" { - files = append(files, path) - } - } - return nil - }) - return files, err -} - -func readFile(path string) ([]byte, error) { - content, err := os.ReadFile(path) - if err != nil { - return nil, err - } - return content, nil -} - -func Initialize(directory string) *CodeGraph { - codeGraph := NewCodeGraph() - // record start time - start := time.Now() - - files, err := getFiles(directory) - if err != nil { - //nolint:all - Log("Directory not found:", err) - return codeGraph - } - - totalFiles := len(files) - numWorkers := 5 // Number of concurrent workers - fileChan := make(chan string, totalFiles) - resultChan := make(chan *CodeGraph, totalFiles) - statusChan := make(chan string, numWorkers) - progressChan := make(chan int, totalFiles) - var wg sync.WaitGroup - - // Worker function - worker := func(workerID int) { - // Initialize the parser for each worker - parser := sitter.NewParser() - defer parser.Close() - - // Set the language (Java in this case) - parser.SetLanguage(java.GetLanguage()) - - for file := range fileChan { - fileName := filepath.Base(file) - statusChan <- fmt.Sprintf("\033[32mWorker %d ....... Reading and parsing code %s\033[0m", workerID, fileName) - sourceCode, err := readFile(file) - if err != nil { - Log("File not found:", err) - continue - } - // Parse the source code - tree, err := parser.ParseCtx(context.TODO(), nil, sourceCode) - if err != nil { - Log("Error parsing file:", err) - continue - } - //nolint:all - defer tree.Close() - - rootNode := tree.RootNode() - localGraph := NewCodeGraph() - statusChan <- fmt.Sprintf("\033[32mWorker %d ....... Building graph and traversing code %s\033[0m", workerID, fileName) - buildGraphFromAST(rootNode, sourceCode, localGraph, nil, file) - statusChan <- fmt.Sprintf("\033[32mWorker %d ....... Done processing file %s\033[0m", workerID, fileName) - - resultChan <- localGraph - progressChan <- 1 - } - wg.Done() - } - - // Start workers - wg.Add(numWorkers) - for i := 0; i < numWorkers; i++ { - go worker(i + 1) - } - - // Send files to workers - for _, file := range files { - fileChan <- file - } - close(fileChan) - - // Status updater - go func() { - statusLines := make([]string, numWorkers) - progress := 0 - for { - select { - case status, ok := <-statusChan: - if !ok { - return - } - workerID := int(status[12] - '0') - statusLines[workerID-1] = status - case _, ok := <-progressChan: - if !ok { - return - } - progress++ - } - fmt.Print("\033[H\033[J") // Clear the screen - for _, line := range statusLines { - Log(line) - } - Fmt("Progress: %d%%\n", (progress*100)/totalFiles) - } - }() - - // Wait for all workers to finish - go func() { - wg.Wait() - close(resultChan) - close(statusChan) - close(progressChan) - }() - - // Collect results - for localGraph := range resultChan { - for _, node := range localGraph.Nodes { - codeGraph.AddNode(node) - } - for _, edge := range localGraph.Edges { - codeGraph.AddEdge(edge.From, edge.To) - } - } - - end := time.Now() - elapsed := end.Sub(start) - Log("Elapsed time: ", elapsed) - Log("Graph built successfully") - - return codeGraph -} diff --git a/sourcecode-parser/graph/graph.go b/sourcecode-parser/graph/graph.go new file mode 100644 index 00000000..cdb1e4a8 --- /dev/null +++ b/sourcecode-parser/graph/graph.go @@ -0,0 +1,32 @@ +package graph + +// NewCodeGraph creates and initializes a new CodeGraph instance. +func NewCodeGraph() *CodeGraph { + return &CodeGraph{ + Nodes: make(map[string]*Node), + Edges: make([]*Edge, 0), + } +} + +// AddNode adds a node to the code graph. +func (g *CodeGraph) AddNode(node *Node) { + g.Nodes[node.ID] = node +} + +// AddEdge adds an edge between two nodes in the code graph. +func (g *CodeGraph) AddEdge(from, to *Node) { + edge := &Edge{From: from, To: to} + g.Edges = append(g.Edges, edge) + from.OutgoingEdges = append(from.OutgoingEdges, edge) +} + +// FindNodesByType finds all nodes of a given type. +func (g *CodeGraph) FindNodesByType(nodeType string) []*Node { + var nodes []*Node + for _, node := range g.Nodes { + if node.Type == nodeType { + nodes = append(nodes, node) + } + } + return nodes +} diff --git a/sourcecode-parser/graph/construct_test.go b/sourcecode-parser/graph/graph_test.go similarity index 66% rename from sourcecode-parser/graph/construct_test.go rename to sourcecode-parser/graph/graph_test.go index 68c1bd26..b2834e8e 100644 --- a/sourcecode-parser/graph/construct_test.go +++ b/sourcecode-parser/graph/graph_test.go @@ -6,10 +6,10 @@ import ( "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" sitter "github.com/smacker/go-tree-sitter" "github.com/smacker/go-tree-sitter/java" + "github.com/smacker/go-tree-sitter/python" "os" "path/filepath" "reflect" - "strings" "testing" ) @@ -877,28 +877,24 @@ func TestBuildGraphFromAST(t *testing.T) { func TestExtractMethodName(t *testing.T) { tests := []struct { - name string - sourceCode string - expectedName string - expectedIDPart string + name string + sourceCode string + expectedName string }{ { - name: "Simple method", - sourceCode: "public void simpleMethod() {}", - expectedName: "simpleMethod", - expectedIDPart: "e4bf121a07daa7b5fc0821f04fe31f22689361aaa7604264034bf231640c0b94", + name: "Simple method", + sourceCode: "public void simpleMethod() {}", + expectedName: "simpleMethod", }, { - name: "Method with parameters", - sourceCode: "private int complexMethod(String a, int b) {}", - expectedName: "complexMethod", - expectedIDPart: "8fa7666614f2db09a92d83f0ec126328a0c0fc93ac0919ffce2be2ce65e5fed5", + name: "Method with parameters", + sourceCode: "private int complexMethod(String a, int b) {}", + expectedName: "complexMethod", }, { - name: "Generic method", - sourceCode: "public List genericMethod(T item) {}", - expectedName: "genericMethod", - expectedIDPart: "4072dc9bf8d115f9c73a0ff3ff2205ef2866845921ac3dd218530ffe85966d96", + name: "Generic method", + sourceCode: "public List genericMethod(T item) {}", + expectedName: "genericMethod", }, } @@ -919,9 +915,439 @@ func TestExtractMethodName(t *testing.T) { t.Errorf("Expected method name %s, but got %s", tt.expectedName, name) } - if !strings.Contains(id, tt.expectedIDPart) { - t.Errorf("Expected method ID to contain %s, but got %s", tt.expectedIDPart, id) + // Verify ID is non-empty and contains the method name (with prefix) + if id == "" { + t.Error("Expected non-empty method ID") + } + + // Method declarations should have IDs prefixed with method: + if methodNode.Type() == "method_declaration" { + // The ID is a hash, but we can verify it was generated (non-empty) + if len(id) != 64 { + t.Errorf("Expected method ID to be SHA256 hash (64 chars), got length %d", len(id)) + } + } + }) + } +} + +// Python-specific tests + +func TestIsPythonSourceFile(t *testing.T) { + tests := []struct { + name string + filename string + want bool + }{ + {"Valid Python file", "example.py", true}, + {"Python file with path", "/path/to/script.py", true}, + {"Python file with Windows path", "C:\\path\\to\\script.py", true}, + {"File with multiple dots", "my.test.script.py", true}, + {"Hidden Python file", ".hidden.py", true}, + {"Non-Python file", "example.txt", false}, + {"Java file", "Example.java", false}, + {"No extension", "pythonfile", false}, + {"Empty string", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isPythonSourceFile(tt.filename); got != tt.want { + t.Errorf("isPythonSourceFile(%q) = %v, want %v", tt.filename, got, tt.want) } }) } } + +func TestGetFilesMixedLanguages(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_get_files_mixed") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create test files + testFiles := []struct { + name string + content string + shouldFind bool + }{ + {"file1.py", "print('Hello')", true}, + {"file2.txt", "Text content", false}, + {"file3.py", "def func(): pass", true}, + {"subdir/file4.py", "class Test: pass", true}, + {"file5", "No extension file", false}, + {"test.java", "public class Test {}", true}, + } + + for _, tf := range testFiles { + path := filepath.Join(tempDir, tf.name) + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + err = os.WriteFile(path, []byte(tf.content), 0644) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + } + + // Run getFiles + files, err := getFiles(tempDir) + if err != nil { + t.Fatalf("getFiles returned an error: %v", err) + } + + // Check that both .py and .java files are found + expectedFiles := 4 // 3 Python + 1 Java + if len(files) != expectedFiles { + t.Errorf("Expected %d files, but got %d", expectedFiles, len(files)) + } + + // Verify file extensions + pythonCount := 0 + javaCount := 0 + for _, file := range files { + ext := filepath.Ext(file) + switch ext { + case ".py": + pythonCount++ + case ".java": + javaCount++ + default: + t.Errorf("Unexpected file extension: %s", ext) + } + } + + if pythonCount != 3 { + t.Errorf("Expected 3 Python files, got %d", pythonCount) + } + if javaCount != 1 { + t.Errorf("Expected 1 Java file, got %d", javaCount) + } +} + +func TestBuildGraphFromASTPythonFunctionDefinition(t *testing.T) { + tests := []struct { + name string + sourceCode string + expectedNodeCount int + expectedName string + expectedParams []string + }{ + { + name: "Simple function without parameters", + sourceCode: `def simple_func(): + pass`, + expectedNodeCount: 1, + expectedName: "simple_func", + expectedParams: []string{}, + }, + { + name: "Function with parameters", + sourceCode: `def add(x, y): + return x + y`, + expectedNodeCount: 2, // function + return + expectedName: "add", + expectedParams: []string{"x", "y"}, + }, + { + name: "Method with self parameter", + sourceCode: `def method(self, arg1, arg2): + self.value = arg1`, + expectedNodeCount: 2, // function + assignment + expectedName: "method", + expectedParams: []string{"self", "arg1", "arg2"}, + }, + { + name: "Function with default parameters", + sourceCode: `def func_with_defaults(x, y=10, z=20): + return x + y + z`, + expectedNodeCount: 2, // function + return + expectedName: "func_with_defaults", + expectedParams: []string{"x", "y=10", "z=20"}, // Parser captures default values + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.sourceCode)) + if err != nil { + t.Fatalf("Failed to parse Python source code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + graph := NewCodeGraph() + buildGraphFromAST(root, []byte(tt.sourceCode), graph, nil, "test.py") + + if len(graph.Nodes) < tt.expectedNodeCount { + t.Errorf("Expected at least %d nodes, but got %d", tt.expectedNodeCount, len(graph.Nodes)) + } + + // Find the function_definition node + var funcNode *Node + for _, node := range graph.Nodes { + if node.Type == "function_definition" { + funcNode = node + break + } + } + + if funcNode == nil { + t.Fatal("No function_definition node found") + } + + if funcNode.Name != tt.expectedName { + t.Errorf("Expected function name %s, got %s", tt.expectedName, funcNode.Name) + } + + if !funcNode.isPythonSourceFile { + t.Error("Expected isPythonSourceFile to be true") + } + + if len(tt.expectedParams) > 0 { + if len(funcNode.MethodArgumentsValue) != len(tt.expectedParams) { + t.Errorf("Expected %d parameters, got %d", len(tt.expectedParams), len(funcNode.MethodArgumentsValue)) + } + for i, param := range tt.expectedParams { + if i < len(funcNode.MethodArgumentsValue) && funcNode.MethodArgumentsValue[i] != param { + t.Errorf("Expected parameter %d to be %s, got %s", i, param, funcNode.MethodArgumentsValue[i]) + } + } + } + }) + } +} + +func TestBuildGraphFromASTPythonClassDefinition(t *testing.T) { + tests := []struct { + name string + sourceCode string + expectedNodeCount int + expectedClassName string + expectedBases []string + }{ + { + name: "Simple class without base", + sourceCode: `class SimpleClass: + pass`, + expectedNodeCount: 1, + expectedClassName: "SimpleClass", + expectedBases: []string{}, + }, + { + name: "Class with single base", + sourceCode: `class Derived(Base): + pass`, + expectedNodeCount: 1, + expectedClassName: "Derived", + expectedBases: []string{"Base"}, + }, + { + name: "Class with multiple bases", + sourceCode: `class MultiDerived(Base1, Base2, Base3): + pass`, + expectedNodeCount: 1, + expectedClassName: "MultiDerived", + expectedBases: []string{"Base1", "Base2", "Base3"}, + }, + { + name: "Class with method", + sourceCode: `class MyClass: + def my_method(self): + return 42`, + expectedNodeCount: 3, // class + method + return + expectedClassName: "MyClass", + expectedBases: []string{}, + }, + { + name: "Class with __init__ method", + sourceCode: `class Person: + def __init__(self, name): + self.name = name`, + expectedNodeCount: 3, // class + __init__ + assignment + expectedClassName: "Person", + expectedBases: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.sourceCode)) + if err != nil { + t.Fatalf("Failed to parse Python source code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + graph := NewCodeGraph() + buildGraphFromAST(root, []byte(tt.sourceCode), graph, nil, "test.py") + + if len(graph.Nodes) < tt.expectedNodeCount { + t.Errorf("Expected at least %d nodes, but got %d", tt.expectedNodeCount, len(graph.Nodes)) + } + + // Find the class_definition node + var classNode *Node + for _, node := range graph.Nodes { + if node.Type == "class_definition" { + classNode = node + break + } + } + + if classNode == nil { + t.Fatal("No class_definition node found") + } + + if classNode.Name != tt.expectedClassName { + t.Errorf("Expected class name %s, got %s", tt.expectedClassName, classNode.Name) + } + + if !classNode.isPythonSourceFile { + t.Error("Expected isPythonSourceFile to be true") + } + + if len(tt.expectedBases) > 0 { + if len(classNode.Interface) != len(tt.expectedBases) { + t.Errorf("Expected %d base classes, got %d", len(tt.expectedBases), len(classNode.Interface)) + } + for i, base := range tt.expectedBases { + if i < len(classNode.Interface) && classNode.Interface[i] != base { + t.Errorf("Expected base class %d to be %s, got %s", i, base, classNode.Interface[i]) + } + } + } else if len(classNode.Interface) != 0 { + t.Errorf("Expected no base classes, got %d", len(classNode.Interface)) + } + }) + } +} + +func TestBuildGraphFromASTPythonStatements(t *testing.T) { + tests := []struct { + name string + sourceCode string + expectedTypes []string + minNodeCount int + }{ + { + name: "Function with return statement", + sourceCode: `def get_value(): + return 42`, + expectedTypes: []string{"function_definition", "ReturnStmt"}, + minNodeCount: 2, + }, + { + name: "Function with assert statement", + sourceCode: `def validate(x): + assert x > 0, "must be positive"`, + expectedTypes: []string{"function_definition", "AssertStmt"}, + minNodeCount: 2, + }, + { + name: "Function with break and continue", + sourceCode: `def loop(): + for i in range(10): + if i == 5: + break + if i == 3: + continue`, + expectedTypes: []string{"function_definition", "BreakStmt", "ContinueStmt"}, + minNodeCount: 3, + }, + { + name: "Generator function with yield", + sourceCode: `def gen(): + yield 1 + yield 2`, + expectedTypes: []string{"function_definition", "YieldStmt"}, + minNodeCount: 2, + }, + { + name: "Function with variable assignment", + sourceCode: `def compute(): + result = 10 + 20 + return result`, + expectedTypes: []string{"function_definition", "variable_assignment", "ReturnStmt"}, + minNodeCount: 3, + }, + { + name: "Function with function calls", + sourceCode: `def caller(): + print("hello") + other_func()`, + expectedTypes: []string{"function_definition", "call"}, + minNodeCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.sourceCode)) + if err != nil { + t.Fatalf("Failed to parse Python source code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + graph := NewCodeGraph() + buildGraphFromAST(root, []byte(tt.sourceCode), graph, nil, "test.py") + + if len(graph.Nodes) < tt.minNodeCount { + t.Errorf("Expected at least %d nodes, but got %d", tt.minNodeCount, len(graph.Nodes)) + } + + // Verify expected node types exist + nodeTypes := make(map[string]bool) + pythonSpecificTypes := map[string]bool{ + "function_definition": true, + "class_definition": true, + "call": true, + "variable_assignment": true, + "ReturnStmt": true, + "AssertStmt": true, + "BreakStmt": true, + "ContinueStmt": true, + "YieldStmt": true, + } + + for _, node := range graph.Nodes { + nodeTypes[node.Type] = true + + // Python-specific nodes should be marked as Python + if pythonSpecificTypes[node.Type] && !node.isPythonSourceFile { + t.Errorf("Node %s (type: %s) should have isPythonSourceFile=true", node.ID, node.Type) + } + } + + for _, expectedType := range tt.expectedTypes { + if !nodeTypes[expectedType] { + t.Errorf("Expected node type %s not found. Found types: %v", expectedType, getKeys(nodeTypes)) + } + } + }) + } +} + +// Helper function to get map keys. +func getKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/sourcecode-parser/graph/initialize.go b/sourcecode-parser/graph/initialize.go new file mode 100644 index 00000000..d2fd6c65 --- /dev/null +++ b/sourcecode-parser/graph/initialize.go @@ -0,0 +1,144 @@ +package graph + +import ( + "context" + "fmt" + "path/filepath" + "sync" + "time" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/java" + "github.com/smacker/go-tree-sitter/python" +) + +// Initialize initializes the code graph by parsing all source files in a directory. +func Initialize(directory string) *CodeGraph { + codeGraph := NewCodeGraph() + start := time.Now() + + files, err := getFiles(directory) + if err != nil { + //nolint:all + Log("Directory not found:", err) + return codeGraph + } + + totalFiles := len(files) + numWorkers := 5 + fileChan := make(chan string, totalFiles) + resultChan := make(chan *CodeGraph, totalFiles) + statusChan := make(chan string, numWorkers) + progressChan := make(chan int, totalFiles) + var wg sync.WaitGroup + + // Worker function + worker := func(workerID int) { + parser := sitter.NewParser() + defer parser.Close() + + for file := range fileChan { + fileName := filepath.Base(file) + fileExt := filepath.Ext(file) + + // Set the language based on file extension + switch fileExt { + case ".java": + parser.SetLanguage(java.GetLanguage()) + case ".py": + parser.SetLanguage(python.GetLanguage()) + default: + Log("Unsupported file type:", file) + continue + } + + statusChan <- fmt.Sprintf("\033[32mWorker %d ....... Reading and parsing code %s\033[0m", workerID, fileName) + sourceCode, err := readFile(file) + if err != nil { + Log("File not found:", err) + continue + } + + tree, err := parser.ParseCtx(context.TODO(), nil, sourceCode) + if err != nil { + Log("Error parsing file:", err) + continue + } + //nolint:all + defer tree.Close() + + rootNode := tree.RootNode() + localGraph := NewCodeGraph() + statusChan <- fmt.Sprintf("\033[32mWorker %d ....... Building graph and traversing code %s\033[0m", workerID, fileName) + buildGraphFromAST(rootNode, sourceCode, localGraph, nil, file) + statusChan <- fmt.Sprintf("\033[32mWorker %d ....... Done processing file %s\033[0m", workerID, fileName) + + resultChan <- localGraph + progressChan <- 1 + } + wg.Done() + } + + // Start workers + wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + go worker(i + 1) + } + + // Send files to workers + for _, file := range files { + fileChan <- file + } + close(fileChan) + + // Status updater + go func() { + statusLines := make([]string, numWorkers) + progress := 0 + for { + select { + case status, ok := <-statusChan: + if !ok { + return + } + workerID := int(status[12] - '0') + statusLines[workerID-1] = status + case _, ok := <-progressChan: + if !ok { + return + } + progress++ + } + fmt.Print("\033[H\033[J") // Clear the screen + for _, line := range statusLines { + Log(line) + } + Fmt("Progress: %d%%\n", (progress*100)/totalFiles) + } + }() + + // Wait for all workers to finish + go func() { + wg.Wait() + close(resultChan) + close(statusChan) + close(progressChan) + }() + + // Collect results + for localGraph := range resultChan { + for _, node := range localGraph.Nodes { + codeGraph.AddNode(node) + } + for _, edge := range localGraph.Edges { + codeGraph.AddEdge(edge.From, edge.To) + } + } + + end := time.Now() + elapsed := end.Sub(start) + Log("Elapsed time: ", elapsed) + Log("Graph built successfully") + + return codeGraph +} diff --git a/sourcecode-parser/graph/initialize_test.go b/sourcecode-parser/graph/initialize_test.go new file mode 100644 index 00000000..9d78fd1c --- /dev/null +++ b/sourcecode-parser/graph/initialize_test.go @@ -0,0 +1,224 @@ +package graph + +import ( + "os" + "path/filepath" + "testing" +) + +func TestInitializeWithEmptyDirectory(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "test_empty_dir") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + graph := Initialize(tmpDir) + + if graph == nil { + t.Fatal("Initialize should return a non-nil graph") + } + if len(graph.Nodes) != 0 { + t.Errorf("Expected 0 nodes for empty directory, got %d", len(graph.Nodes)) + } +} + +func TestInitializeWithJavaFile(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "test_java_dir") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a simple Java file + javaCode := ` +public class HelloWorld { + public static void main(String[] args) { + System.out.println("Hello, World!"); + } +} +` + javaFile := filepath.Join(tmpDir, "HelloWorld.java") + if err := os.WriteFile(javaFile, []byte(javaCode), 0644); err != nil { + t.Fatalf("Failed to write Java file: %v", err) + } + + graph := Initialize(tmpDir) + + if graph == nil { + t.Fatal("Initialize should return a non-nil graph") + } + if len(graph.Nodes) == 0 { + t.Error("Expected nodes to be created from Java file") + } + + // Check for class node + hasClassNode := false + for _, node := range graph.Nodes { + if node.Type == "class_declaration" && node.Name == "HelloWorld" { + hasClassNode = true + break + } + } + if !hasClassNode { + t.Error("Expected to find HelloWorld class node") + } +} + +func TestInitializeWithPythonFile(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "test_python_dir") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a simple Python file + pythonCode := ` +def greet(name): + return f"Hello, {name}!" + +class Greeter: + def __init__(self, greeting): + self.greeting = greeting +` + pythonFile := filepath.Join(tmpDir, "greet.py") + if err := os.WriteFile(pythonFile, []byte(pythonCode), 0644); err != nil { + t.Fatalf("Failed to write Python file: %v", err) + } + + graph := Initialize(tmpDir) + + if graph == nil { + t.Fatal("Initialize should return a non-nil graph") + } + if len(graph.Nodes) == 0 { + t.Error("Expected nodes to be created from Python file") + } + + // Check for function and class nodes + hasFunctionNode := false + hasClassNode := false + for _, node := range graph.Nodes { + if node.Type == "function_definition" && node.Name == "greet" { + hasFunctionNode = true + } + if node.Type == "class_definition" && node.Name == "Greeter" { + hasClassNode = true + } + } + if !hasFunctionNode { + t.Error("Expected to find greet function node") + } + if !hasClassNode { + t.Error("Expected to find Greeter class node") + } +} + +func TestInitializeWithMixedFiles(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "test_mixed_dir") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a Java file + javaCode := `public class Test { }` + javaFile := filepath.Join(tmpDir, "Test.java") + if err := os.WriteFile(javaFile, []byte(javaCode), 0644); err != nil { + t.Fatalf("Failed to write Java file: %v", err) + } + + // Create a Python file + pythonCode := `def test(): pass` + pythonFile := filepath.Join(tmpDir, "test.py") + if err := os.WriteFile(pythonFile, []byte(pythonCode), 0644); err != nil { + t.Fatalf("Failed to write Python file: %v", err) + } + + // Create a non-source file (should be ignored) + txtFile := filepath.Join(tmpDir, "readme.txt") + if err := os.WriteFile(txtFile, []byte("This is a readme"), 0644); err != nil { + t.Fatalf("Failed to write txt file: %v", err) + } + + graph := Initialize(tmpDir) + + if graph == nil { + t.Fatal("Initialize should return a non-nil graph") + } + if len(graph.Nodes) == 0 { + t.Error("Expected nodes to be created from source files") + } + + // Check that both Java and Python nodes exist + hasJavaNode := false + hasPythonNode := false + for _, node := range graph.Nodes { + if node.isJavaSourceFile { + hasJavaNode = true + } + if node.isPythonSourceFile { + hasPythonNode = true + } + } + if !hasJavaNode { + t.Error("Expected to find Java nodes") + } + if !hasPythonNode { + t.Error("Expected to find Python nodes") + } +} + +func TestInitializeWithNonExistentDirectory(t *testing.T) { + graph := Initialize("/path/that/does/not/exist") + + if graph == nil { + t.Fatal("Initialize should return a non-nil graph even for non-existent directory") + } + if len(graph.Nodes) != 0 { + t.Errorf("Expected 0 nodes for non-existent directory, got %d", len(graph.Nodes)) + } +} + +func TestInitializeWithNestedDirectories(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "test_nested_dir") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create nested structure + subDir := filepath.Join(tmpDir, "src", "main") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatalf("Failed to create subdirectory: %v", err) + } + + // Create Java file in subdirectory + javaCode := `public class Nested { }` + javaFile := filepath.Join(subDir, "Nested.java") + if err := os.WriteFile(javaFile, []byte(javaCode), 0644); err != nil { + t.Fatalf("Failed to write Java file: %v", err) + } + + graph := Initialize(tmpDir) + + if graph == nil { + t.Fatal("Initialize should return a non-nil graph") + } + if len(graph.Nodes) == 0 { + t.Error("Expected nodes to be created from nested file") + } + + // Check that nested file was processed + hasNestedClass := false + for _, node := range graph.Nodes { + if node.Type == "class_declaration" && node.Name == "Nested" { + hasNestedClass = true + break + } + } + if !hasNestedClass { + t.Error("Expected to find Nested class from subdirectory") + } +} diff --git a/sourcecode-parser/graph/parser.go b/sourcecode-parser/graph/parser.go new file mode 100644 index 00000000..bfdd2e6c --- /dev/null +++ b/sourcecode-parser/graph/parser.go @@ -0,0 +1,110 @@ +package graph + +import sitter "github.com/smacker/go-tree-sitter" + +// buildGraphFromAST builds a code graph from an Abstract Syntax Tree. +func buildGraphFromAST(node *sitter.Node, sourceCode []byte, graph *CodeGraph, currentContext *Node, file string) { + isJavaSourceFile := isJavaSourceFile(file) + isPythonSourceFile := isPythonSourceFile(file) + + switch node.Type() { + // Python-specific node types + case "function_definition": + if isPythonSourceFile { + currentContext = parsePythonFunctionDefinition(node, sourceCode, graph, file) + } + + case "class_definition": + if isPythonSourceFile { + parsePythonClassDefinition(node, sourceCode, graph, file) + } + + case "call": + if isPythonSourceFile { + parsePythonCall(node, sourceCode, graph, currentContext, file) + } + + case "return_statement": + parseReturnStatement(node, sourceCode, graph, file, isJavaSourceFile, isPythonSourceFile) + + case "break_statement": + parseBreakStatement(node, sourceCode, graph, file, isJavaSourceFile, isPythonSourceFile) + + case "continue_statement": + parseContinueStatement(node, sourceCode, graph, file, isJavaSourceFile, isPythonSourceFile) + + case "assert_statement": + parseAssertStatement(node, sourceCode, graph, file, isJavaSourceFile, isPythonSourceFile) + + case "expression_statement": + if isPythonSourceFile { + parsePythonYieldExpression(node, sourceCode, graph, file) + } + + case "assignment": + if isPythonSourceFile { + parsePythonAssignment(node, sourceCode, graph, file) + } + + // Java-specific node types + case "block": + parseBlockStatement(node, sourceCode, graph, file, isJavaSourceFile) + + case "yield_statement": + parseYieldStatement(node, sourceCode, graph, file, isJavaSourceFile) + + case "if_statement": + parseIfStatement(node, sourceCode, graph, file, isJavaSourceFile) + + case "while_statement": + parseWhileStatement(node, sourceCode, graph, file, isJavaSourceFile) + + case "do_statement": + parseDoStatement(node, sourceCode, graph, file, isJavaSourceFile) + + case "for_statement": + parseForStatement(node, sourceCode, graph, file, isJavaSourceFile) + + case "binary_expression": + currentContext = parseJavaBinaryExpression(node, sourceCode, graph, file, isJavaSourceFile) + + case "method_declaration": + currentContext = parseJavaMethodDeclaration(node, sourceCode, graph, file) + + case "method_invocation": + parseJavaMethodInvocation(node, sourceCode, graph, currentContext, file) + + case "class_declaration": + parseJavaClassDeclaration(node, sourceCode, graph, file) + + case "block_comment": + parseJavaBlockComment(node, sourceCode, graph, file) + + case "local_variable_declaration", "field_declaration": + parseJavaVariableDeclaration(node, sourceCode, graph, file) + + case "object_creation_expression": + parseJavaObjectCreation(node, sourceCode, graph, file) + } + + // Recursively process child nodes + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + buildGraphFromAST(child, sourceCode, graph, currentContext, file) + } + + // Post-processing: Link method invocations to declarations + for _, node := range graph.Nodes { + if node.Type == "method_declaration" { + for _, invokedNode := range graph.Nodes { + if invokedNode.Type == "method_invocation" { + if invokedNode.Name == node.Name { + if len(invokedNode.MethodArgumentsValue) == len(node.MethodArgumentsType) { + node.hasAccess = true + } + } + } + } + } + } +} diff --git a/sourcecode-parser/graph/parser_java.go b/sourcecode-parser/graph/parser_java.go new file mode 100644 index 00000000..b8994be5 --- /dev/null +++ b/sourcecode-parser/graph/parser_java.go @@ -0,0 +1,374 @@ +package graph + +import ( + "strconv" + "strings" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" + sitter "github.com/smacker/go-tree-sitter" +) + +// parseJavaBinaryExpression parses Java binary expressions. +func parseJavaBinaryExpression(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) *Node { + leftNode := node.ChildByFieldName("left") + rightNode := node.ChildByFieldName("right") + operator := node.ChildByFieldName("operator") + operatorType := operator.Type() + expressionNode := model.BinaryExpr{} + expressionNode.LeftOperand = &model.Expr{Node: *leftNode, NodeString: leftNode.Content(sourceCode)} + expressionNode.RightOperand = &model.Expr{Node: *rightNode, NodeString: rightNode.Content(sourceCode)} + expressionNode.Op = operatorType + + var exprType string + switch operatorType { + case "+": + exprType = "add_expression" + case "-": + exprType = "sub_expression" + case "*": + exprType = "mul_expression" + case "/": + exprType = "div_expression" + case ">", "<", ">=", "<=": + exprType = "comp_expression" + case "%": + exprType = "rem_expression" + case ">>": + exprType = "right_shift_expression" + case "<<": + exprType = "left_shift_expression" + case "!=": + exprType = "ne_expression" + case "==": + exprType = "eq_expression" + case "&": + exprType = "bitwise_and_expression" + case "&&": + exprType = "and_expression" + case "||": + exprType = "or_expression" + case "|": + exprType = "bitwise_or_expression" + case ">>>": + exprType = "bitwise_right_shift_expression" + case "^": + exprType = "bitwise_xor_expression" + default: + exprType = "binary_expression" + } + + exprNode := &Node{ + ID: GenerateSha256(exprType + node.Content(sourceCode)), + Type: exprType, + Name: node.Content(sourceCode), + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: isJavaSourceFile, + BinaryExpr: &expressionNode, + } + graph.AddNode(exprNode) + + invokedNode := &Node{ + ID: GenerateSha256("binary_expression" + node.Content(sourceCode)), + Type: "binary_expression", + Name: node.Content(sourceCode), + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: isJavaSourceFile, + BinaryExpr: &expressionNode, + } + graph.AddNode(invokedNode) + return invokedNode +} + +// parseJavaMethodDeclaration parses Java method declarations. +func parseJavaMethodDeclaration(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) *Node { + var javadoc *model.Javadoc + if node.PrevSibling() != nil && node.PrevSibling().Type() == "block_comment" { + commentContent := node.PrevSibling().Content(sourceCode) + if strings.HasPrefix(commentContent, "/*") { + javadoc = parseJavadocTags(commentContent) + } + } + methodName, methodID := extractMethodName(node, sourceCode, file) + modifiers := "" + returnType := "" + throws := []string{} + methodArgumentType := []string{} + methodArgumentValue := []string{} + annotationMarkers := []string{} + + for i := 0; i < int(node.ChildCount()); i++ { + childNode := node.Child(i) + childType := childNode.Type() + + switch childType { + case "throws": + for j := 0; j < int(childNode.NamedChildCount()); j++ { + namedChild := childNode.NamedChild(j) + if namedChild.Type() == "type_identifier" { + throws = append(throws, namedChild.Content(sourceCode)) + } + } + case "modifiers": + modifiers = childNode.Content(sourceCode) + for j := 0; j < int(childNode.ChildCount()); j++ { + if childNode.Child(j).Type() == "marker_annotation" { + annotationMarkers = append(annotationMarkers, childNode.Child(j).Content(sourceCode)) + } + } + case "void_type", "type_identifier": + returnType = childNode.Content(sourceCode) + case "formal_parameters": + for j := 0; j < int(childNode.NamedChildCount()); j++ { + param := childNode.NamedChild(j) + if param.Type() == "formal_parameter" { + paramType := param.Child(0).Content(sourceCode) + paramValue := param.Child(1).Content(sourceCode) + methodArgumentType = append(methodArgumentType, paramType) + methodArgumentValue = append(methodArgumentValue, paramValue) + } + } + } + } + + invokedNode := &Node{ + ID: methodID, + Type: "method_declaration", + Name: methodName, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + Modifier: extractVisibilityModifier(modifiers), + ReturnType: returnType, + MethodArgumentsType: methodArgumentType, + MethodArgumentsValue: methodArgumentValue, + File: file, + isJavaSourceFile: true, + ThrowsExceptions: throws, + Annotation: annotationMarkers, + JavaDoc: javadoc, + } + graph.AddNode(invokedNode) + return invokedNode +} + +// parseJavaMethodInvocation parses Java method invocations. +func parseJavaMethodInvocation(node *sitter.Node, sourceCode []byte, graph *CodeGraph, currentContext *Node, file string) { + methodName, methodID := extractMethodName(node, sourceCode, file) + arguments := []string{} + for i := 0; i < int(node.ChildCount()); i++ { + if node.Child(i).Type() == "argument_list" { + argumentsNode := node.Child(i) + for j := 0; j < int(argumentsNode.ChildCount()); j++ { + argument := argumentsNode.Child(j) + switch argument.Type() { + case "identifier": + arguments = append(arguments, argument.Content(sourceCode)) + case "string_literal": + stringliteral := argument.Content(sourceCode) + stringliteral = strings.TrimPrefix(stringliteral, "\"") + stringliteral = strings.TrimSuffix(stringliteral, "\"") + arguments = append(arguments, stringliteral) + default: + arguments = append(arguments, argument.Content(sourceCode)) + } + } + } + } + + invokedNode := &Node{ + ID: methodID, + Type: "method_invocation", + Name: methodName, + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + MethodArgumentsValue: arguments, + File: file, + isJavaSourceFile: true, + } + graph.AddNode(invokedNode) + + if currentContext != nil { + graph.AddEdge(currentContext, invokedNode) + } +} + +// parseJavaClassDeclaration parses Java class declarations. +func parseJavaClassDeclaration(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + var javadoc *model.Javadoc + if node.PrevSibling() != nil && node.PrevSibling().Type() == "block_comment" { + commentContent := node.PrevSibling().Content(sourceCode) + if strings.HasPrefix(commentContent, "/*") { + javadoc = parseJavadocTags(commentContent) + } + } + className := node.ChildByFieldName("name").Content(sourceCode) + packageName := "" + accessModifier := "" + superClass := "" + annotationMarkers := []string{} + implementedInterface := []string{} + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + switch child.Type() { + case "modifiers": + accessModifier = child.Content(sourceCode) + for j := 0; j < int(child.ChildCount()); j++ { + if child.Child(j).Type() == "marker_annotation" { + annotationMarkers = append(annotationMarkers, child.Child(j).Content(sourceCode)) + } + } + case "superclass": + for j := 0; j < int(child.ChildCount()); j++ { + if child.Child(j).Type() == "type_identifier" { + superClass = child.Child(j).Content(sourceCode) + } + } + case "super_interfaces": + for j := 0; j < int(child.ChildCount()); j++ { + typeList := child.Child(j) + for k := 0; k < int(typeList.ChildCount()); k++ { + implementedInterface = append(implementedInterface, typeList.Child(k).Content(sourceCode)) + } + } + } + } + + classNode := &Node{ + ID: GenerateMethodID("class:"+className, []string{}, file), + Type: "class_declaration", + Name: className, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + PackageName: packageName, + Modifier: extractVisibilityModifier(accessModifier), + SuperClass: superClass, + Interface: implementedInterface, + File: file, + isJavaSourceFile: true, + JavaDoc: javadoc, + Annotation: annotationMarkers, + } + graph.AddNode(classNode) +} + +// parseJavaBlockComment parses Java block comments. +func parseJavaBlockComment(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + if strings.HasPrefix(node.Content(sourceCode), "/*") { + commentContent := node.Content(sourceCode) + javadocTags := parseJavadocTags(commentContent) + + commentNode := &Node{ + ID: GenerateMethodID(node.Content(sourceCode), []string{}, file), + Type: "block_comment", + CodeSnippet: commentContent, + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: true, + JavaDoc: javadocTags, + } + graph.AddNode(commentNode) + } +} + +// parseJavaVariableDeclaration parses Java variable declarations (local and field). +func parseJavaVariableDeclaration(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + variableName := "" + variableType := "" + variableModifier := "" + variableValue := "" + hasAccessValue := false + var scope string + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + switch child.Type() { + case "variable_declarator": + variableName = child.Content(sourceCode) + for j := 0; j < int(child.ChildCount()); j++ { + if child.Child(j).Type() == "identifier" { + variableName = child.Child(j).Content(sourceCode) + } + if child.Child(j).Type() == "=" { + for k := j + 1; k < int(child.ChildCount()); k++ { + variableValue += child.Child(k).Content(sourceCode) + } + } + + } + variableValue = strings.ReplaceAll(variableValue, " ", "") + variableValue = strings.ReplaceAll(variableValue, "\n", "") + case "modifiers": + variableModifier = child.Content(sourceCode) + } + if strings.Contains(child.Type(), "type") { + variableType = child.Content(sourceCode) + } + } + if node.Type() == "local_variable_declaration" { + scope = "local" + } else { + scope = "field" + } + variableNode := &Node{ + ID: GenerateMethodID(variableName, []string{}, file), + Type: "variable_declaration", + Name: variableName, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + Modifier: extractVisibilityModifier(variableModifier), + DataType: variableType, + Scope: scope, + VariableValue: variableValue, + hasAccess: hasAccessValue, + File: file, + isJavaSourceFile: true, + } + graph.AddNode(variableNode) +} + +// parseJavaObjectCreation parses Java object creation expressions. +func parseJavaObjectCreation(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + className := "" + classInstanceExpression := model.ClassInstanceExpr{ + ClassName: "", + Args: []*model.Expr{}, + } + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "type_identifier" || child.Type() == "scoped_type_identifier" { + className = child.Content(sourceCode) + classInstanceExpression.ClassName = className + } + if child.Type() == "argument_list" { + classInstanceExpression.Args = []*model.Expr{} + for j := 0; j < int(child.ChildCount()); j++ { + argType := child.Child(j).Type() + argumentStopWords := map[string]bool{ + "(": true, ")": true, "{": true, "}": true, + "[": true, "]": true, ",": true, + } + if !argumentStopWords[argType] { + argument := &model.Expr{} + argument.Type = child.Child(j).Type() + argument.NodeString = child.Child(j).Content(sourceCode) + classInstanceExpression.Args = append(classInstanceExpression.Args, argument) + } + } + } + } + + objectNode := &Node{ + ID: GenerateMethodID(className, []string{strconv.Itoa(int(node.StartPoint().Row + 1))}, file), + Type: "ClassInstanceExpr", + Name: className, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: true, + ClassInstanceExpr: &classInstanceExpression, + } + graph.AddNode(objectNode) +} diff --git a/sourcecode-parser/graph/parser_python.go b/sourcecode-parser/graph/parser_python.go new file mode 100644 index 00000000..6f87906a --- /dev/null +++ b/sourcecode-parser/graph/parser_python.go @@ -0,0 +1,242 @@ +package graph + +import ( + "fmt" + + pythonlang "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/python" + sitter "github.com/smacker/go-tree-sitter" +) + +// parsePythonFunctionDefinition parses Python function definitions. +func parsePythonFunctionDefinition(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) *Node { + // Extract function name and parameters + functionName := "" + parameters := []string{} + + nameNode := node.ChildByFieldName("name") + if nameNode != nil { + functionName = nameNode.Content(sourceCode) + } + + parametersNode := node.ChildByFieldName("parameters") + if parametersNode != nil { + for i := 0; i < int(parametersNode.NamedChildCount()); i++ { + param := parametersNode.NamedChild(i) + if param.Type() == "identifier" || param.Type() == "typed_parameter" || param.Type() == "default_parameter" { + parameters = append(parameters, param.Content(sourceCode)) + } + } + } + + methodID := GenerateMethodID("function:"+functionName, parameters, file) + functionNode := &Node{ + ID: methodID, + Type: "function_definition", + Name: functionName, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + MethodArgumentsValue: parameters, + File: file, + isPythonSourceFile: true, + } + graph.AddNode(functionNode) + return functionNode +} + +// parsePythonClassDefinition parses Python class definitions. +func parsePythonClassDefinition(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + // Extract class name and bases + className := "" + superClasses := []string{} + + nameNode := node.ChildByFieldName("name") + if nameNode != nil { + className = nameNode.Content(sourceCode) + } + + superclassNode := node.ChildByFieldName("superclasses") + if superclassNode != nil { + for i := 0; i < int(superclassNode.NamedChildCount()); i++ { + superClass := superclassNode.NamedChild(i) + if superClass.Type() == "identifier" || superClass.Type() == "attribute" { + superClasses = append(superClasses, superClass.Content(sourceCode)) + } + } + } + + classNode := &Node{ + ID: GenerateMethodID("class:"+className, []string{}, file), + Type: "class_definition", + Name: className, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + Interface: superClasses, + File: file, + isPythonSourceFile: true, + } + graph.AddNode(classNode) +} + +// parsePythonCall parses Python function calls. +func parsePythonCall(node *sitter.Node, sourceCode []byte, graph *CodeGraph, currentContext *Node, file string) { + // Python function calls + callName := "" + arguments := []string{} + + functionNode := node.ChildByFieldName("function") + if functionNode != nil { + callName = functionNode.Content(sourceCode) + } + + argumentsNode := node.ChildByFieldName("arguments") + if argumentsNode != nil { + for i := 0; i < int(argumentsNode.NamedChildCount()); i++ { + arg := argumentsNode.NamedChild(i) + arguments = append(arguments, arg.Content(sourceCode)) + } + } + + callID := GenerateMethodID(callName, arguments, file) + callNode := &Node{ + ID: callID, + Type: "call", + Name: callName, + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + MethodArgumentsValue: arguments, + File: file, + isPythonSourceFile: true, + } + graph.AddNode(callNode) + if currentContext != nil { + graph.AddEdge(currentContext, callNode) + } +} + +// parsePythonReturnStatement parses Python return statements. +func parsePythonReturnStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + returnNode := pythonlang.ParseReturnStatement(node, sourceCode) + uniqueReturnID := fmt.Sprintf("return_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + returnStmtNode := &Node{ + ID: GenerateSha256(uniqueReturnID), + Type: "ReturnStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "ReturnStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isPythonSourceFile: true, + ReturnStmt: returnNode, + } + graph.AddNode(returnStmtNode) +} + +// parsePythonBreakStatement parses Python break statements. +func parsePythonBreakStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + breakNode := pythonlang.ParseBreakStatement(node, sourceCode) + uniquebreakstmtID := fmt.Sprintf("breakstmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + breakStmtNode := &Node{ + ID: GenerateSha256(uniquebreakstmtID), + Type: "BreakStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "BreakStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isPythonSourceFile: true, + BreakStmt: breakNode, + } + graph.AddNode(breakStmtNode) +} + +// parsePythonContinueStatement parses Python continue statements. +func parsePythonContinueStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + continueNode := pythonlang.ParseContinueStatement(node, sourceCode) + uniquecontinueID := fmt.Sprintf("continuestmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + continueStmtNode := &Node{ + ID: GenerateSha256(uniquecontinueID), + Type: "ContinueStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "ContinueStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isPythonSourceFile: true, + ContinueStmt: continueNode, + } + graph.AddNode(continueStmtNode) +} + +// parsePythonAssertStatement parses Python assert statements. +func parsePythonAssertStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + assertNode := pythonlang.ParseAssertStatement(node, sourceCode) + uniqueAssertID := fmt.Sprintf("assert_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + assertStmtNode := &Node{ + ID: GenerateSha256(uniqueAssertID), + Type: "AssertStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "AssertStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isPythonSourceFile: true, + AssertStmt: assertNode, + } + graph.AddNode(assertStmtNode) +} + +// parsePythonYieldExpression parses Python yield expressions. +func parsePythonYieldExpression(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + // Handle yield expressions in Python + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "yield" { + yieldNode := pythonlang.ParseYieldStatement(child, sourceCode) + uniqueyieldID := fmt.Sprintf("yield_%d_%d_%s", child.StartPoint().Row+1, child.StartPoint().Column+1, file) + yieldStmtNode := &Node{ + ID: GenerateSha256(uniqueyieldID), + Type: "YieldStmt", + LineNumber: child.StartPoint().Row + 1, + Name: "YieldStmt", + IsExternal: true, + CodeSnippet: child.Content(sourceCode), + File: file, + isPythonSourceFile: true, + YieldStmt: yieldNode, + } + graph.AddNode(yieldStmtNode) + break + } + } +} + +// parsePythonAssignment parses Python variable assignments. +func parsePythonAssignment(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string) { + // Python variable assignments + variableName := "" + variableValue := "" + + leftNode := node.ChildByFieldName("left") + if leftNode != nil { + variableName = leftNode.Content(sourceCode) + } + + rightNode := node.ChildByFieldName("right") + if rightNode != nil { + variableValue = rightNode.Content(sourceCode) + } + + variableNode := &Node{ + ID: GenerateMethodID(variableName, []string{}, file), + Type: "variable_assignment", + Name: variableName, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + VariableValue: variableValue, + Scope: "local", + File: file, + isPythonSourceFile: true, + } + graph.AddNode(variableNode) +} diff --git a/sourcecode-parser/graph/parser_python_test.go b/sourcecode-parser/graph/parser_python_test.go new file mode 100644 index 00000000..ae98406b --- /dev/null +++ b/sourcecode-parser/graph/parser_python_test.go @@ -0,0 +1,223 @@ +package graph + +import ( + "context" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" +) + +func TestParsePythonFunctionDefinition(t *testing.T) { + tests := []struct { + name string + code string + expectedName string + expectedParams int + }{ + { + name: "Simple function", + code: "def hello():\n pass", + expectedName: "hello", + expectedParams: 0, + }, + { + name: "Function with parameters", + code: "def add(x, y):\n return x + y", + expectedName: "add", + expectedParams: 2, + }, + { + name: "Function with default parameters", + code: "def greet(name, msg='Hello'):\n print(msg, name)", + expectedName: "greet", + expectedParams: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer tree.Close() + + graph := NewCodeGraph() + root := tree.RootNode() + + // Find function_definition node + var funcNode *sitter.Node + for i := 0; i < int(root.NamedChildCount()); i++ { + child := root.NamedChild(i) + if child.Type() == "function_definition" { + funcNode = child + break + } + } + + if funcNode == nil { + t.Fatal("No function_definition node found") + } + + node := parsePythonFunctionDefinition(funcNode, []byte(tt.code), graph, "test.py") + + if node.Name != tt.expectedName { + t.Errorf("Expected name %s, got %s", tt.expectedName, node.Name) + } + if len(node.MethodArgumentsValue) != tt.expectedParams { + t.Errorf("Expected %d params, got %d", tt.expectedParams, len(node.MethodArgumentsValue)) + } + if !node.isPythonSourceFile { + t.Error("Expected isPythonSourceFile to be true") + } + }) + } +} + +func TestParsePythonClassDefinition(t *testing.T) { + tests := []struct { + name string + code string + expectedName string + expectedBases int + }{ + { + name: "Simple class", + code: "class MyClass:\n pass", + expectedName: "MyClass", + expectedBases: 0, + }, + { + name: "Class with inheritance", + code: "class Child(Parent):\n pass", + expectedName: "Child", + expectedBases: 1, + }, + { + name: "Class with multiple inheritance", + code: "class Multi(Base1, Base2):\n pass", + expectedName: "Multi", + expectedBases: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer tree.Close() + + graph := NewCodeGraph() + root := tree.RootNode() + + var classNode *sitter.Node + for i := 0; i < int(root.NamedChildCount()); i++ { + child := root.NamedChild(i) + if child.Type() == "class_definition" { + classNode = child + break + } + } + + if classNode == nil { + t.Fatal("No class_definition node found") + } + + parsePythonClassDefinition(classNode, []byte(tt.code), graph, "test.py") + + if len(graph.Nodes) == 0 { + t.Fatal("No nodes added to graph") + } + + var node *Node + for _, n := range graph.Nodes { + if n.Type == "class_definition" { + node = n + break + } + } + + if node == nil { + t.Fatal("No class node found in graph") + } + + if node.Name != tt.expectedName { + t.Errorf("Expected name %s, got %s", tt.expectedName, node.Name) + } + if len(node.Interface) != tt.expectedBases { + t.Errorf("Expected %d bases, got %d", tt.expectedBases, len(node.Interface)) + } + }) + } +} + +func TestParsePythonAssignment(t *testing.T) { + code := "x = 42" + + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(code)) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer tree.Close() + + graph := NewCodeGraph() + root := tree.RootNode() + + var assignNode *sitter.Node + for i := 0; i < int(root.NamedChildCount()); i++ { + child := root.NamedChild(i) + if child.Type() == "expression_statement" { + for j := 0; j < int(child.NamedChildCount()); j++ { + subchild := child.NamedChild(j) + if subchild.Type() == "assignment" { + assignNode = subchild + break + } + } + } + } + + if assignNode == nil { + t.Fatal("No assignment node found") + } + + parsePythonAssignment(assignNode, []byte(code), graph, "test.py") + + if len(graph.Nodes) == 0 { + t.Fatal("No nodes added to graph") + } + + var node *Node + for _, n := range graph.Nodes { + if n.Type == "variable_assignment" { + node = n + break + } + } + + if node == nil { + t.Fatal("No variable assignment node found") + } + + if node.Name != "x" { + t.Errorf("Expected variable name 'x', got %s", node.Name) + } + if node.VariableValue != "42" { + t.Errorf("Expected value '42', got %s", node.VariableValue) + } +} diff --git a/sourcecode-parser/graph/parser_statements.go b/sourcecode-parser/graph/parser_statements.go new file mode 100644 index 00000000..467a14d5 --- /dev/null +++ b/sourcecode-parser/graph/parser_statements.go @@ -0,0 +1,239 @@ +package graph + +import ( + "fmt" + + javalang "github.com/shivasurya/code-pathfinder/sourcecode-parser/graph/java" + "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" + sitter "github.com/smacker/go-tree-sitter" +) + +// parseBlockStatement parses block statements. +func parseBlockStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) { + blockNode := javalang.ParseBlockStatement(node, sourceCode) + uniqueBlockID := fmt.Sprintf("block_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + blockStmtNode := &Node{ + ID: GenerateSha256(uniqueBlockID), + Type: "BlockStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "BlockStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isJavaSourceFile: isJavaSourceFile, + BlockStmt: blockNode, + } + graph.AddNode(blockStmtNode) +} + +// parseReturnStatement parses return statements (Java or Python). +func parseReturnStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJava, isPython bool) { + if isPython { + parsePythonReturnStatement(node, sourceCode, graph, file) + } else if isJava { + returnNode := javalang.ParseReturnStatement(node, sourceCode) + uniqueReturnID := fmt.Sprintf("return_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + returnStmtNode := &Node{ + ID: GenerateSha256(uniqueReturnID), + Type: "ReturnStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "ReturnStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isJavaSourceFile: isJava, + ReturnStmt: returnNode, + } + graph.AddNode(returnStmtNode) + } +} + +// parseBreakStatement parses break statements (Java or Python). +func parseBreakStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJava, isPython bool) { + if isPython { + parsePythonBreakStatement(node, sourceCode, graph, file) + } else if isJava { + breakNode := javalang.ParseBreakStatement(node, sourceCode) + uniquebreakstmtID := fmt.Sprintf("breakstmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + breakStmtNode := &Node{ + ID: GenerateSha256(uniquebreakstmtID), + Type: "BreakStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "BreakStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isJavaSourceFile: isJava, + BreakStmt: breakNode, + } + graph.AddNode(breakStmtNode) + } +} + +// parseContinueStatement parses continue statements (Java or Python). +func parseContinueStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJava, isPython bool) { + if isPython { + parsePythonContinueStatement(node, sourceCode, graph, file) + } else if isJava { + continueNode := javalang.ParseContinueStatement(node, sourceCode) + uniquecontinueID := fmt.Sprintf("continuestmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + continueStmtNode := &Node{ + ID: GenerateSha256(uniquecontinueID), + Type: "ContinueStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "ContinueStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isJavaSourceFile: isJava, + ContinueStmt: continueNode, + } + graph.AddNode(continueStmtNode) + } +} + +// parseAssertStatement parses assert statements (Java or Python). +func parseAssertStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJava, isPython bool) { + if isPython { + parsePythonAssertStatement(node, sourceCode, graph, file) + } else if isJava { + assertNode := javalang.ParseAssertStatement(node, sourceCode) + uniqueAssertID := fmt.Sprintf("assert_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + assertStmtNode := &Node{ + ID: GenerateSha256(uniqueAssertID), + Type: "AssertStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "AssertStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isJavaSourceFile: isJava, + AssertStmt: assertNode, + } + graph.AddNode(assertStmtNode) + } +} + +// parseYieldStatement parses yield statements (Java only). +func parseYieldStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) { + yieldNode := javalang.ParseYieldStatement(node, sourceCode) + uniqueyieldID := fmt.Sprintf("yield_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + yieldStmtNode := &Node{ + ID: GenerateSha256(uniqueyieldID), + Type: "YieldStmt", + LineNumber: node.StartPoint().Row + 1, + Name: "YieldStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + File: file, + isJavaSourceFile: isJavaSourceFile, + YieldStmt: yieldNode, + } + graph.AddNode(yieldStmtNode) +} + +// parseIfStatement parses if statements. +func parseIfStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) { + ifNode := model.IfStmt{} + conditionNode := node.Child(1) + if conditionNode != nil { + ifNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} + } + thenNode := node.Child(2) + if thenNode != nil { + ifNode.Then = model.Stmt{NodeString: thenNode.Content(sourceCode)} + } + elseNode := node.Child(4) + if elseNode != nil { + ifNode.Else = model.Stmt{NodeString: elseNode.Content(sourceCode)} + } + + methodID := fmt.Sprintf("ifstmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + ifStmtNode := &Node{ + ID: GenerateSha256(methodID), + Type: "IfStmt", + Name: "IfStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: isJavaSourceFile, + IfStmt: &ifNode, + } + graph.AddNode(ifStmtNode) +} + +// parseWhileStatement parses while statements. +func parseWhileStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) { + whileNode := model.WhileStmt{} + conditionNode := node.Child(1) + if conditionNode != nil { + whileNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} + } + methodID := fmt.Sprintf("while_stmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + whileStmtNode := &Node{ + ID: GenerateSha256(methodID), + Type: "WhileStmt", + Name: "WhileStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: isJavaSourceFile, + WhileStmt: &whileNode, + } + graph.AddNode(whileStmtNode) +} + +// parseDoStatement parses do-while statements. +func parseDoStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) { + doWhileNode := model.DoStmt{} + conditionNode := node.Child(2) + if conditionNode != nil { + doWhileNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} + } + methodID := fmt.Sprintf("dowhile_stmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + doWhileStmtNode := &Node{ + ID: GenerateSha256(methodID), + Type: "DoStmt", + Name: "DoStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: isJavaSourceFile, + DoStmt: &doWhileNode, + } + graph.AddNode(doWhileStmtNode) +} + +// parseForStatement parses for statements. +func parseForStatement(node *sitter.Node, sourceCode []byte, graph *CodeGraph, file string, isJavaSourceFile bool) { + forNode := model.ForStmt{} + initNode := node.ChildByFieldName("init") + if initNode != nil { + forNode.Init = &model.Expr{Node: *initNode, NodeString: initNode.Content(sourceCode)} + } + conditionNode := node.ChildByFieldName("condition") + if conditionNode != nil { + forNode.Condition = &model.Expr{Node: *conditionNode, NodeString: conditionNode.Content(sourceCode)} + } + incrementNode := node.ChildByFieldName("increment") + if incrementNode != nil { + forNode.Increment = &model.Expr{Node: *incrementNode, NodeString: incrementNode.Content(sourceCode)} + } + + methodID := fmt.Sprintf("for_stmt_%d_%d_%s", node.StartPoint().Row+1, node.StartPoint().Column+1, file) + forStmtNode := &Node{ + ID: GenerateSha256(methodID), + Type: "ForStmt", + Name: "ForStmt", + IsExternal: true, + CodeSnippet: node.Content(sourceCode), + LineNumber: node.StartPoint().Row + 1, + File: file, + isJavaSourceFile: isJavaSourceFile, + ForStmt: &forNode, + } + graph.AddNode(forStmtNode) +} diff --git a/sourcecode-parser/graph/python/parse_statement.go b/sourcecode-parser/graph/python/parse_statement.go new file mode 100644 index 00000000..e440014e --- /dev/null +++ b/sourcecode-parser/graph/python/parse_statement.go @@ -0,0 +1,71 @@ +package python + +import ( + "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" + sitter "github.com/smacker/go-tree-sitter" +) + +func ParseReturnStatement(node *sitter.Node, sourcecode []byte) *model.ReturnStmt { + returnStmt := &model.ReturnStmt{} + // Python return statements can have 0 or more return values + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "return" { + returnStmt.Result = &model.Expr{NodeString: child.Content(sourcecode)} + break + } + } + return returnStmt +} + +func ParseBreakStatement(node *sitter.Node, sourcecode []byte) *model.BreakStmt { + breakStmt := &model.BreakStmt{} + // Python break statements don't have labels + return breakStmt +} + +func ParseContinueStatement(node *sitter.Node, sourcecode []byte) *model.ContinueStmt { + continueStmt := &model.ContinueStmt{} + // Python continue statements don't have labels + return continueStmt +} + +func ParseAssertStatement(node *sitter.Node, sourcecode []byte) *model.AssertStmt { + assertStmt := &model.AssertStmt{} + // Python assert has condition and optional message + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "assert" && child.Type() != "," { + if assertStmt.Expr == nil { + assertStmt.Expr = &model.Expr{NodeString: child.Content(sourcecode)} + } else if assertStmt.Message == nil { + assertStmt.Message = &model.Expr{NodeString: child.Content(sourcecode)} + } + } + } + return assertStmt +} + +func ParseBlockStatement(node *sitter.Node, sourcecode []byte) *model.BlockStmt { + blockStmt := &model.BlockStmt{} + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + singleBlockStmt := &model.Stmt{} + singleBlockStmt.NodeString = child.Content(sourcecode) + blockStmt.Stmts = append(blockStmt.Stmts, *singleBlockStmt) + } + return blockStmt +} + +func ParseYieldStatement(node *sitter.Node, sourcecode []byte) *model.YieldStmt { + yieldStmt := &model.YieldStmt{} + // Python yield can be in yield_expression + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "yield" && child.Type() != "from" { + yieldStmt.Value = &model.Expr{NodeString: child.Content(sourcecode)} + break + } + } + return yieldStmt +} diff --git a/sourcecode-parser/graph/python/parse_statement_test.go b/sourcecode-parser/graph/python/parse_statement_test.go new file mode 100644 index 00000000..aa1a64ce --- /dev/null +++ b/sourcecode-parser/graph/python/parse_statement_test.go @@ -0,0 +1,277 @@ +package python + +import ( + "context" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" +) + +func TestParseReturnStatement(t *testing.T) { + tests := []struct { + name string + code string + expected string + }{ + { + name: "return with value", + code: "return 42", + expected: "42", + }, + { + name: "return without value", + code: "return", + expected: "", + }, + { + name: "return with expression", + code: "return x + y", + expected: "x + y", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + returnNode := root.Child(0) + + result := ParseReturnStatement(returnNode, []byte(tt.code)) + + if result == nil { + t.Fatal("ParseReturnStatement returned nil") + } + + switch { + case tt.expected == "" && result.Result != nil: + t.Errorf("Expected nil result, got %v", result.Result) + case tt.expected != "" && result.Result == nil: + t.Errorf("Expected result %s, got nil", tt.expected) + case tt.expected != "" && result.Result != nil && result.Result.NodeString != tt.expected: + t.Errorf("Expected result %s, got %s", tt.expected, result.Result.NodeString) + } + }) + } +} + +func TestParseBreakStatement(t *testing.T) { + code := "break" + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(code)) + if err != nil { + t.Fatalf("Failed to parse code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + breakNode := root.Child(0) + + result := ParseBreakStatement(breakNode, []byte(code)) + + if result == nil { + t.Fatal("ParseBreakStatement returned nil") + } + + // Python break statements don't have labels + if result.Label != "" { + t.Errorf("Expected empty label, got %s", result.Label) + } +} + +func TestParseContinueStatement(t *testing.T) { + code := "continue" + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(code)) + if err != nil { + t.Fatalf("Failed to parse code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + continueNode := root.Child(0) + + result := ParseContinueStatement(continueNode, []byte(code)) + + if result == nil { + t.Fatal("ParseContinueStatement returned nil") + } + + // Python continue statements don't have labels + if result.Label != "" { + t.Errorf("Expected empty label, got %s", result.Label) + } +} + +func TestParseAssertStatement(t *testing.T) { + tests := []struct { + name string + code string + expectedExpr string + expectedMessage string + }{ + { + name: "assert without message", + code: "assert x > 0", + expectedExpr: "x > 0", + expectedMessage: "", + }, + { + name: "assert with message", + code: "assert x > 0, \"x must be positive\"", + expectedExpr: "x > 0", + expectedMessage: "\"x must be positive\"", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + assertNode := root.Child(0) + + result := ParseAssertStatement(assertNode, []byte(tt.code)) + + if result == nil { + t.Fatal("ParseAssertStatement returned nil") + } + + if result.Expr == nil { + t.Fatal("Expected non-nil Expr") + } + + if result.Expr.NodeString != tt.expectedExpr { + t.Errorf("Expected expr %s, got %s", tt.expectedExpr, result.Expr.NodeString) + } + + if tt.expectedMessage == "" && result.Message != nil { + t.Errorf("Expected nil message, got %v", result.Message) + } else if tt.expectedMessage != "" && result.Message == nil { + t.Errorf("Expected message %s, got nil", tt.expectedMessage) + } + }) + } +} + +func TestParseYieldStatement(t *testing.T) { + tests := []struct { + name string + code string + expected string + }{ + { + name: "yield with value", + code: "yield 42", + expected: "42", + }, + { + name: "yield with expression", + code: "yield x + y", + expected: "x + y", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + // Navigate to the actual yield node: module -> expression_statement -> yield + exprStmt := root.Child(0) + yieldNode := exprStmt.Child(0) + + result := ParseYieldStatement(yieldNode, []byte(tt.code)) + + if result == nil { + t.Fatal("ParseYieldStatement returned nil") + } + + if result.Value == nil { + t.Fatal("Expected non-nil Value") + } + + if result.Value.NodeString != tt.expected { + t.Errorf("Expected value %s, got %s", tt.expected, result.Value.NodeString) + } + }) + } +} + +func TestParseBlockStatement(t *testing.T) { + code := `if True: + x = 1 + y = 2` + + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(code)) + if err != nil { + t.Fatalf("Failed to parse code: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + ifNode := root.Child(0) + blockNode := ifNode.ChildByFieldName("consequence") + + result := ParseBlockStatement(blockNode, []byte(code)) + + if result == nil { + t.Fatal("ParseBlockStatement returned nil") + } + + if len(result.Stmts) == 0 { + t.Error("Expected non-empty statements list") + } +} + +func BenchmarkParseReturnStatement(b *testing.B) { + code := "return x + y" + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(python.GetLanguage()) + + tree, _ := parser.ParseCtx(context.Background(), nil, []byte(code)) + defer tree.Close() + root := tree.RootNode() + returnNode := root.Child(0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ParseReturnStatement(returnNode, []byte(code)) + } +} diff --git a/sourcecode-parser/graph/types.go b/sourcecode-parser/graph/types.go new file mode 100644 index 00000000..39d0e606 --- /dev/null +++ b/sourcecode-parser/graph/types.go @@ -0,0 +1,57 @@ +package graph + +import "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" + +// Node represents a node in the code graph with various properties +// describing code elements like classes, methods, variables, etc. +type Node struct { + ID string + Type string + Name string + CodeSnippet string + LineNumber uint32 + OutgoingEdges []*Edge + IsExternal bool + Modifier string + ReturnType string + MethodArgumentsType []string + MethodArgumentsValue []string + PackageName string + ImportPackage []string + SuperClass string + Interface []string + DataType string + Scope string + VariableValue string + hasAccess bool + File string + isJavaSourceFile bool + isPythonSourceFile bool + ThrowsExceptions []string + Annotation []string + JavaDoc *model.Javadoc + BinaryExpr *model.BinaryExpr + ClassInstanceExpr *model.ClassInstanceExpr + IfStmt *model.IfStmt + WhileStmt *model.WhileStmt + DoStmt *model.DoStmt + ForStmt *model.ForStmt + BreakStmt *model.BreakStmt + ContinueStmt *model.ContinueStmt + YieldStmt *model.YieldStmt + AssertStmt *model.AssertStmt + ReturnStmt *model.ReturnStmt + BlockStmt *model.BlockStmt +} + +// Edge represents a directed edge between two nodes in the code graph. +type Edge struct { + From *Node + To *Node +} + +// CodeGraph represents the entire code graph with nodes and edges. +type CodeGraph struct { + Nodes map[string]*Node + Edges []*Edge +} diff --git a/sourcecode-parser/graph/types_test.go b/sourcecode-parser/graph/types_test.go new file mode 100644 index 00000000..a0b6b495 --- /dev/null +++ b/sourcecode-parser/graph/types_test.go @@ -0,0 +1,186 @@ +package graph + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" +) + +func TestNodeCreation(t *testing.T) { + node := &Node{ + ID: "test_id", + Type: "method_declaration", + Name: "testMethod", + LineNumber: 10, + IsExternal: false, + Modifier: "public", + ReturnType: "void", + PackageName: "com.test", + File: "Test.java", + } + + if node.ID != "test_id" { + t.Errorf("Expected ID 'test_id', got %s", node.ID) + } + if node.Type != "method_declaration" { + t.Errorf("Expected Type 'method_declaration', got %s", node.Type) + } + if node.Name != "testMethod" { + t.Errorf("Expected Name 'testMethod', got %s", node.Name) + } + if node.LineNumber != 10 { + t.Errorf("Expected LineNumber 10, got %d", node.LineNumber) + } +} + +func TestNodeWithJavaDoc(t *testing.T) { + javadoc := &model.Javadoc{ + Author: "Test Author", + Version: "1.0", + NumberOfCommentLines: 5, + CommentedCodeElements: "/** Test javadoc */", + } + + node := &Node{ + ID: "test_id", + Name: "testMethod", + JavaDoc: javadoc, + } + + if node.JavaDoc == nil { + t.Fatal("JavaDoc should not be nil") + } + if node.JavaDoc.Author != "Test Author" { + t.Errorf("Expected Author 'Test Author', got %s", node.JavaDoc.Author) + } +} + +func TestNodeWithStatements(t *testing.T) { + ifStmt := &model.IfStmt{ + ConditionalStmt: model.ConditionalStmt{ + Condition: &model.Expr{NodeString: "x > 0"}, + }, + } + + node := &Node{ + ID: "test_id", + Name: "ifNode", + Type: "IfStmt", + IfStmt: ifStmt, + } + + if node.IfStmt == nil { + t.Fatal("IfStmt should not be nil") + } + if node.IfStmt.Condition.NodeString != "x > 0" { + t.Errorf("Expected condition 'x > 0', got %s", node.IfStmt.Condition.NodeString) + } +} + +func TestEdgeCreation(t *testing.T) { + fromNode := &Node{ID: "from", Name: "FromNode"} + toNode := &Node{ID: "to", Name: "ToNode"} + + edge := &Edge{ + From: fromNode, + To: toNode, + } + + if edge.From.ID != "from" { + t.Errorf("Expected From.ID 'from', got %s", edge.From.ID) + } + if edge.To.ID != "to" { + t.Errorf("Expected To.ID 'to', got %s", edge.To.ID) + } +} + +func TestCodeGraphCreation(t *testing.T) { + graph := &CodeGraph{ + Nodes: make(map[string]*Node), + Edges: make([]*Edge, 0), + } + + if graph.Nodes == nil { + t.Error("Nodes map should not be nil") + } + if graph.Edges == nil { + t.Error("Edges slice should not be nil") + } + if len(graph.Nodes) != 0 { + t.Errorf("Expected 0 nodes, got %d", len(graph.Nodes)) + } + if len(graph.Edges) != 0 { + t.Errorf("Expected 0 edges, got %d", len(graph.Edges)) + } +} + +func TestNodeLanguageFlags(t *testing.T) { + javaNode := &Node{ + ID: "java_node", + isJavaSourceFile: true, + } + + pythonNode := &Node{ + ID: "python_node", + isPythonSourceFile: true, + } + + if !javaNode.isJavaSourceFile { + t.Error("Java node should have isJavaSourceFile=true") + } + if pythonNode.isJavaSourceFile { + t.Error("Python node should have isJavaSourceFile=false") + } + if !pythonNode.isPythonSourceFile { + t.Error("Python node should have isPythonSourceFile=true") + } +} + +func TestNodeMethodArguments(t *testing.T) { + node := &Node{ + ID: "method_id", + MethodArgumentsType: []string{"int", "String"}, + MethodArgumentsValue: []string{"count", "name"}, + } + + if len(node.MethodArgumentsType) != 2 { + t.Errorf("Expected 2 argument types, got %d", len(node.MethodArgumentsType)) + } + if len(node.MethodArgumentsValue) != 2 { + t.Errorf("Expected 2 argument values, got %d", len(node.MethodArgumentsValue)) + } + if node.MethodArgumentsType[0] != "int" { + t.Errorf("Expected first type 'int', got %s", node.MethodArgumentsType[0]) + } + if node.MethodArgumentsValue[1] != "name" { + t.Errorf("Expected second value 'name', got %s", node.MethodArgumentsValue[1]) + } +} + +func TestNodeAnnotations(t *testing.T) { + node := &Node{ + ID: "annotated_method", + Annotation: []string{"@Override", "@Deprecated"}, + } + + if len(node.Annotation) != 2 { + t.Errorf("Expected 2 annotations, got %d", len(node.Annotation)) + } + if node.Annotation[0] != "@Override" { + t.Errorf("Expected first annotation '@Override', got %s", node.Annotation[0]) + } +} + +func TestNodeExceptions(t *testing.T) { + node := &Node{ + ID: "throwing_method", + ThrowsExceptions: []string{"IOException", "SQLException"}, + } + + if len(node.ThrowsExceptions) != 2 { + t.Errorf("Expected 2 exceptions, got %d", len(node.ThrowsExceptions)) + } + if node.ThrowsExceptions[0] != "IOException" { + t.Errorf("Expected first exception 'IOException', got %s", node.ThrowsExceptions[0]) + } +} diff --git a/sourcecode-parser/graph/util.go b/sourcecode-parser/graph/util.go deleted file mode 100644 index 1cb41aaf..00000000 --- a/sourcecode-parser/graph/util.go +++ /dev/null @@ -1,70 +0,0 @@ -package graph - -import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "log" - "os" -) - -var verboseFlag bool - -func GenerateMethodID(methodName string, parameters []string, sourceFile string) string { - hashInput := fmt.Sprintf("%s-%s-%s", methodName, parameters, sourceFile) - hash := sha256.Sum256([]byte(hashInput)) - return hex.EncodeToString(hash[:]) -} - -func GenerateSha256(input string) string { - hash := sha256.Sum256([]byte(input)) - return hex.EncodeToString(hash[:]) -} - -// Helper function to append a node to a slice only if it's not already present. -func appendUnique(slice []*Node, node *Node) []*Node { - for _, n := range slice { - if n == node { - return slice - } - } - return append(slice, node) -} - -func FormatType(v interface{}) string { - switch val := v.(type) { - case string: - return val - case int, int64: - return fmt.Sprintf("%d", val) - case float32, float64: - return fmt.Sprintf("%.2f", val) - case []interface{}: - //nolint:all - jsonBytes, _ := json.Marshal(val) - return string(jsonBytes) - default: - return fmt.Sprintf("%v", val) - } -} - -func EnableVerboseLogging() { - verboseFlag = true -} - -func Log(message string, args ...interface{}) { - if verboseFlag { - log.Println(message, args) - } -} - -func Fmt(format string, args ...interface{}) { - if verboseFlag { - fmt.Printf(format, args...) - } -} - -func IsGitHubActions() bool { - return os.Getenv("GITHUB_ACTIONS") == "true" -} diff --git a/sourcecode-parser/graph/util_test.go b/sourcecode-parser/graph/util_test.go deleted file mode 100644 index b2993a95..00000000 --- a/sourcecode-parser/graph/util_test.go +++ /dev/null @@ -1,298 +0,0 @@ -package graph - -import ( - "bytes" - "encoding/hex" - "fmt" - "io" - "log" - "os" - "strings" - "testing" -) - -func TestGenerateMethodID(t *testing.T) { - tests := []struct { - name string - methodName string - parameters []string - sourceFile string - want int - }{ - { - name: "Simple method", - methodName: "testMethod", - parameters: []string{"int", "string"}, - sourceFile: "Test.java", - want: 64, - }, - { - name: "Empty parameters", - methodName: "emptyParams", - parameters: []string{}, - sourceFile: "Empty.java", - want: 64, - }, - { - name: "Long method name", - methodName: "thisIsAVeryLongMethodNameThatExceedsTwentyCharacters", - parameters: []string{"long"}, - sourceFile: "LongName.java", - want: 64, - }, - { - name: "Special characters", - methodName: "special$Method#Name", - parameters: []string{"char[]", "int[]"}, - sourceFile: "Special!File@Name.java", - want: 64, - }, - { - name: "Unicode characters", - methodName: "unicodeMethod你好", - parameters: []string{"String"}, - sourceFile: "Unicode文件.java", - want: 64, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := GenerateMethodID(tt.methodName, tt.parameters, tt.sourceFile) - if len(got) != tt.want { - t.Errorf("GenerateMethodID() returned ID with incorrect length, got %d, want %d", len(got), tt.want) - } - if !isValidHexString(got) { - t.Errorf("GenerateMethodID() returned invalid hex string: %s", got) - } - }) - } -} - -func TestGenerateSha256(t *testing.T) { - tests := []struct { - name string - input string - want int - }{ - { - name: "Empty string", - input: "", - want: 64, - }, - { - name: "Simple string", - input: "Hello, World!", - want: 64, - }, - { - name: "Long string", - input: "This is a very long string that exceeds sixty-four characters in length", - want: 64, - }, - { - name: "Special characters", - input: "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/", - want: 64, - }, - { - name: "Unicode characters", - input: "こんにちは世界", - want: 64, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := GenerateSha256(tt.input) - if len(got) != tt.want { - t.Errorf("GenerateSha256() returned hash with incorrect length, got %d, want %d", len(got), tt.want) - } - if !isValidHexString(got) { - t.Errorf("GenerateSha256() returned invalid hex string: %s", got) - } - }) - } -} - -func isValidHexString(s string) bool { - _, err := hex.DecodeString(s) - return err == nil -} - -func TestFormatType(t *testing.T) { - tests := []struct { - name string - input interface{} - want string - }{ - { - name: "String input", - input: "test string", - want: "test string", - }, - { - name: "Integer input", - input: 42, - want: "42", - }, - { - name: "Int64 input", - input: int64(9223372036854775807), - want: "9223372036854775807", - }, - { - name: "Float32 input", - input: float32(3.14), - want: "3.14", - }, - { - name: "Float64 input", - input: 2.71828, - want: "2.72", - }, - { - name: "Slice of integers", - input: []interface{}{1, 2, 3}, - want: "[1,2,3]", - }, - { - name: "Slice of mixed types", - input: []interface{}{"a", 1, true}, - want: `["a",1,true]`, - }, - { - name: "Boolean input", - input: true, - want: "true", - }, - { - name: "Nil input", - input: nil, - want: "", - }, - { - name: "Struct input", - input: struct{ Name string }{"John"}, - want: "{John}", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := FormatType(tt.input) - if got != tt.want { - t.Errorf("FormatType() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestEnableVerboseLogging(t *testing.T) { - // Reset verboseFlag before test - verboseFlag = false - - EnableVerboseLogging() - - if !verboseFlag { - t.Error("EnableVerboseLogging() did not set verboseFlag to true") - } -} - -func TestLog(t *testing.T) { - tests := []struct { - name string - message string - args []interface{} - verbose bool - }{ - { - name: "Verbose logging enabled", - message: "Test message", - args: []interface{}{1, "two", true}, - verbose: true, - }, - { - name: "Verbose logging disabled", - message: "Another test message", - args: []interface{}{3.14, []int{1, 2, 3}}, - verbose: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - verboseFlag = tt.verbose - - // Redirect log output - var buf bytes.Buffer - log.SetOutput(&buf) - defer log.SetOutput(os.Stderr) - - Log(tt.message, tt.args...) - - logOutput := buf.String() - if tt.verbose { - if !strings.Contains(logOutput, tt.message) { - t.Errorf("Log() output does not contain expected message: %s", tt.message) - } - for _, arg := range tt.args { - if !strings.Contains(logOutput, fmt.Sprint(arg)) { - t.Errorf("Log() output does not contain expected argument: %v", arg) - } - } - } else if logOutput != "" { - t.Errorf("Log() produced output when verbose logging was disabled") - } - }) - } -} - -func TestFmt(t *testing.T) { - tests := []struct { - name string - format string - args []interface{} - verbose bool - want string - }{ - { - name: "Verbose formatting enabled", - format: "Number: %d, String: %s, Float: %.2f", - args: []interface{}{42, "test", 3.14159}, - verbose: true, - want: "Number: 42, String: test, Float: 3.14", - }, - { - name: "Verbose formatting disabled", - format: "This should not be printed: %v", - args: []interface{}{"ignored"}, - verbose: false, - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - verboseFlag = tt.verbose - - // Redirect stdout - oldStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - Fmt(tt.format, tt.args...) - - w.Close() - os.Stdout = oldStdout - - var buf bytes.Buffer - io.Copy(&buf, r) - got := buf.String() - - if got != tt.want { - t.Errorf("Fmt() output = %q, want %q", got, tt.want) - } - }) - } -} diff --git a/sourcecode-parser/graph/utils.go b/sourcecode-parser/graph/utils.go new file mode 100644 index 00000000..f7a94727 --- /dev/null +++ b/sourcecode-parser/graph/utils.go @@ -0,0 +1,271 @@ +package graph + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/shivasurya/code-pathfinder/sourcecode-parser/model" + sitter "github.com/smacker/go-tree-sitter" +) + +var verboseFlag bool + +// GenerateMethodID generates a unique SHA256 hash ID for a method. +func GenerateMethodID(methodName string, parameters []string, sourceFile string) string { + hashInput := fmt.Sprintf("%s-%s-%s", methodName, parameters, sourceFile) + hash := sha256.Sum256([]byte(hashInput)) + return hex.EncodeToString(hash[:]) +} + +// GenerateSha256 generates a SHA256 hash from an input string. +func GenerateSha256(input string) string { + hash := sha256.Sum256([]byte(input)) + return hex.EncodeToString(hash[:]) +} + +// appendUnique appends a node to a slice only if it's not already present. +func appendUnique(slice []*Node, node *Node) []*Node { + for _, n := range slice { + if n == node { + return slice + } + } + return append(slice, node) +} + +// FormatType formats various types to string representation. +func FormatType(v interface{}) string { + switch val := v.(type) { + case string: + return val + case int, int64: + return fmt.Sprintf("%d", val) + case float32, float64: + return fmt.Sprintf("%.2f", val) + case []interface{}: + //nolint:all + jsonBytes, _ := json.Marshal(val) + return string(jsonBytes) + default: + return fmt.Sprintf("%v", val) + } +} + +// EnableVerboseLogging enables verbose logging mode. +func EnableVerboseLogging() { + verboseFlag = true +} + +// Log logs a message if verbose logging is enabled. +func Log(message string, args ...interface{}) { + if verboseFlag { + log.Println(message, args) + } +} + +// Fmt prints formatted output if verbose logging is enabled. +func Fmt(format string, args ...interface{}) { + if verboseFlag { + fmt.Printf(format, args...) + } +} + +// IsGitHubActions checks if running in GitHub Actions environment. +func IsGitHubActions() bool { + return os.Getenv("GITHUB_ACTIONS") == "true" +} + +// extractVisibilityModifier extracts visibility modifier from a string of modifiers. +func extractVisibilityModifier(modifiers string) string { + words := strings.Fields(modifiers) + for _, word := range words { + switch word { + case "public", "private", "protected": + return word + } + } + return "" // return an empty string if no visibility modifier is found +} + +// isJavaSourceFile checks if a file is a Java source file. +func isJavaSourceFile(filename string) bool { + return filepath.Ext(filename) == ".java" +} + +// isPythonSourceFile checks if a file is a Python source file. +func isPythonSourceFile(filename string) bool { + return filepath.Ext(filename) == ".py" +} + +//nolint:all +func hasAccess(node *sitter.Node, variableName string, sourceCode []byte) bool { + if node == nil { + return false + } + if node.Type() == "identifier" && node.Content(sourceCode) == variableName { + return true + } + + // Recursively check all children of the current node + for i := 0; i < int(node.ChildCount()); i++ { + childNode := node.Child(i) + if hasAccess(childNode, variableName, sourceCode) { + return true + } + } + + // Continue checking in the next sibling + return hasAccess(node.NextSibling(), variableName, sourceCode) +} + +// parseJavadocTags parses Javadoc tags from comment content. +func parseJavadocTags(commentContent string) *model.Javadoc { + javaDoc := &model.Javadoc{} + var javadocTags []*model.JavadocTag + + commentLines := strings.Split(commentContent, "\n") + for _, line := range commentLines { + line = strings.TrimSpace(line) + // line may start with /** or * + line = strings.TrimPrefix(line, "*") + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "@") { + parts := strings.SplitN(line, " ", 2) + if len(parts) == 2 { + tagName := strings.TrimPrefix(parts[0], "@") + tagText := strings.TrimSpace(parts[1]) + + var javadocTag *model.JavadocTag + switch tagName { + case "author": + javadocTag = model.NewJavadocTag(tagName, tagText, "author") + javaDoc.Author = tagText + case "param": + javadocTag = model.NewJavadocTag(tagName, tagText, "param") + case "see": + javadocTag = model.NewJavadocTag(tagName, tagText, "see") + case "throws": + javadocTag = model.NewJavadocTag(tagName, tagText, "throws") + case "version": + javadocTag = model.NewJavadocTag(tagName, tagText, "version") + javaDoc.Version = tagText + case "since": + javadocTag = model.NewJavadocTag(tagName, tagText, "since") + default: + javadocTag = model.NewJavadocTag(tagName, tagText, "unknown") + } + javadocTags = append(javadocTags, javadocTag) + } + } + } + + javaDoc.Tags = javadocTags + javaDoc.NumberOfCommentLines = len(commentLines) + javaDoc.CommentedCodeElements = commentContent + + return javaDoc +} + +//nolint:all +func extractMethodName(node *sitter.Node, sourceCode []byte, filepath string) (string, string) { + var methodID string + + // if the child node is method_declaration, extract method name, modifiers, parameters, and return type + var methodName string + var modifiers, parameters []string + + if node.Type() == "method_declaration" { + // Iterate over all children of the method_declaration node + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + switch child.Type() { + case "modifiers", "marker_annotation", "annotation": + // This child is a modifier or annotation, add its content to modifiers + modifiers = append(modifiers, child.Content(sourceCode)) //nolint:all + case "identifier": + // This child is the method name + methodName = child.Content(sourceCode) + case "formal_parameters": + // This child represents formal parameters; iterate through its children + for j := 0; j < int(child.NamedChildCount()); j++ { + param := child.NamedChild(j) + parameters = append(parameters, param.Content(sourceCode)) + } + } + } + } + + // check if type is method_invocation + // if the child node is method_invocation, extract method name + if node.Type() == "method_invocation" { + for j := 0; j < int(node.ChildCount()); j++ { + child := node.Child(j) + if child.Type() == "identifier" { + if methodName == "" { + methodName = child.Content(sourceCode) + } else { + methodName = methodName + "." + child.Content(sourceCode) + } + } + + argumentsNode := node.ChildByFieldName("argument_list") + // add data type of arguments list + if argumentsNode != nil { + for k := 0; k < int(argumentsNode.ChildCount()); k++ { + argument := argumentsNode.Child(k) + parameters = append(parameters, argument.Child(0).Content(sourceCode)) + } + } + + } + } + content := node.Content(sourceCode) + lineNumber := int(node.StartPoint().Row) + 1 + columnNumber := int(node.StartPoint().Column) + 1 + // convert to string and merge + content += " " + strconv.Itoa(lineNumber) + ":" + strconv.Itoa(columnNumber) + + // Prefix method declarations to avoid ID collisions with invocations + prefix := "" + if node.Type() == "method_declaration" { + prefix = "method:" + } + + methodID = GenerateMethodID(prefix+methodName, parameters, filepath+"/"+content) + return methodName, methodID +} + +// getFiles walks through a directory and returns all Java and Python source files. +func getFiles(directory string) ([]string, error) { + var files []string + err := filepath.Walk(directory, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + // append only java and python files + ext := filepath.Ext(path) + if ext == ".java" || ext == ".py" { + files = append(files, path) + } + } + return nil + }) + return files, err +} + +// readFile reads the contents of a file. +func readFile(path string) ([]byte, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return content, nil +} diff --git a/sourcecode-parser/graph/utils_test.go b/sourcecode-parser/graph/utils_test.go new file mode 100644 index 00000000..1fac638f --- /dev/null +++ b/sourcecode-parser/graph/utils_test.go @@ -0,0 +1,683 @@ +package graph + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/java" +) + +func TestParseJavadocTagsComprehensive(t *testing.T) { + tests := []struct { + name string + comment string + expectedAuthor string + expectedVersion string + expectedTagCount int + expectedCommentLines int + }{ + { + name: "Complete Javadoc", + comment: `/** + * This is a test class + * @author John Doe + * @version 1.0.0 + * @param name the parameter + * @return the result + * @throws IOException if error + * @see OtherClass + * @since 1.0 + */`, + expectedAuthor: "John Doe", + expectedVersion: "1.0.0", + expectedTagCount: 7, + expectedCommentLines: 10, + }, + { + name: "Minimal Javadoc", + comment: `/** + * Simple comment + */`, + expectedAuthor: "", + expectedVersion: "", + expectedTagCount: 0, + expectedCommentLines: 3, + }, + { + name: "Multiple params", + comment: `/** + * @param x first param + * @param y second param + * @param z third param + */`, + expectedAuthor: "", + expectedVersion: "", + expectedTagCount: 3, + expectedCommentLines: 5, + }, + { + name: "Unknown tags", + comment: `/** + * @deprecated use new method + * @custom custom tag + */`, + expectedTagCount: 2, + expectedCommentLines: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseJavadocTags(tt.comment) + + if result == nil { + t.Fatal("parseJavadocTags returned nil") + } + + if result.Author != tt.expectedAuthor { + t.Errorf("Expected author '%s', got '%s'", tt.expectedAuthor, result.Author) + } + + if result.Version != tt.expectedVersion { + t.Errorf("Expected version '%s', got '%s'", tt.expectedVersion, result.Version) + } + + if len(result.Tags) != tt.expectedTagCount { + t.Errorf("Expected %d tags, got %d", tt.expectedTagCount, len(result.Tags)) + } + + if result.NumberOfCommentLines != tt.expectedCommentLines { + t.Errorf("Expected %d comment lines, got %d", tt.expectedCommentLines, result.NumberOfCommentLines) + } + + if result.CommentedCodeElements != tt.comment { + t.Error("CommentedCodeElements should match original comment") + } + }) + } +} + +func TestExtractMethodNameComprehensive(t *testing.T) { + tests := []struct { + name string + code string + expectedMethod string + shouldHaveID bool + }{ + { + name: "Simple method declaration", + code: `public void testMethod() { + System.out.println("test"); +}`, + expectedMethod: "testMethod", + shouldHaveID: true, + }, + { + name: "Method with parameters", + code: `public String calculate(int x, String y) { + return "result"; +}`, + expectedMethod: "calculate", + shouldHaveID: true, + }, + { + name: "Method with annotations", + code: `@Override +public void toString() { + return "test"; +}`, + expectedMethod: "toString", + shouldHaveID: true, + }, + { + name: "Method invocation", + code: `System.out.println("hello")`, + expectedMethod: "println", + shouldHaveID: true, + }, + { + name: "Chained method invocation", + code: `object.getX().getY()`, + expectedMethod: "getY", + shouldHaveID: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(java.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + + // Find the relevant node + var targetNode *sitter.Node + var findNode func(*sitter.Node) + findNode = func(node *sitter.Node) { + if node.Type() == "method_declaration" || node.Type() == "method_invocation" { + targetNode = node + return + } + for i := 0; i < int(node.ChildCount()); i++ { + if targetNode == nil { + findNode(node.Child(i)) + } + } + } + findNode(root) + + if targetNode == nil { + t.Skip("Could not find method node in parsed tree") + } + + methodName, methodID := extractMethodName(targetNode, []byte(tt.code), "Test.java") + + if methodName != tt.expectedMethod { + t.Errorf("Expected method name '%s', got '%s'", tt.expectedMethod, methodName) + } + + if tt.shouldHaveID && methodID == "" { + t.Error("Expected non-empty method ID") + } + + if methodID != "" && len(methodID) != 64 { + t.Errorf("Expected SHA256 hash length 64, got %d", len(methodID)) + } + }) + } +} + +func TestGetFilesComprehensive(t *testing.T) { + // Create temp directory structure + tmpDir, err := ioutil.TempDir("", "test_getfiles") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test files + testFiles := []struct { + path string + shouldMatch bool + }{ + {"Test1.java", true}, + {"Test2.java", true}, + {"script.py", true}, + {"README.md", false}, + {"config.json", false}, + {"app.js", false}, + {"subdir/Test3.java", true}, + {"subdir/script2.py", true}, + {"subdir/other.txt", false}, + } + + for _, tf := range testFiles { + fullPath := filepath.Join(tmpDir, tf.path) + dir := filepath.Dir(fullPath) + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + if err := ioutil.WriteFile(fullPath, []byte("test content"), 0644); err != nil { + t.Fatalf("Failed to create file: %v", err) + } + } + + // Test getFiles + files, err := getFiles(tmpDir) + if err != nil { + t.Fatalf("getFiles failed: %v", err) + } + + // Count expected files + expectedCount := 0 + for _, tf := range testFiles { + if tf.shouldMatch { + expectedCount++ + } + } + + if len(files) != expectedCount { + t.Errorf("Expected %d files, got %d", expectedCount, len(files)) + } + + // Verify only Java and Python files + for _, file := range files { + ext := filepath.Ext(file) + if ext != ".java" && ext != ".py" { + t.Errorf("Unexpected file extension: %s", ext) + } + } +} + +func TestGetFilesErrors(t *testing.T) { + tests := []struct { + name string + directory string + wantError bool + }{ + { + name: "Non-existent directory", + directory: "/path/that/does/not/exist/xyz123", + wantError: true, + }, + { + name: "Empty path", + directory: "", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + files, err := getFiles(tt.directory) + + if tt.wantError && err == nil { + t.Error("Expected error but got none") + } + + if err == nil && len(files) > 0 { + t.Errorf("Expected empty files list for invalid directory, got %d files", len(files)) + } + }) + } +} + +func TestReadFileComprehensive(t *testing.T) { + // Create temp file + tmpFile, err := ioutil.TempFile("", "test_readfile_*.java") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + testContent := []byte("public class Test {\n public static void main(String[] args) {}\n}") + if _, err := tmpFile.Write(testContent); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + tmpFile.Close() + + t.Run("Read existing file", func(t *testing.T) { + content, err := readFile(tmpFile.Name()) + if err != nil { + t.Fatalf("readFile failed: %v", err) + } + + if string(content) != string(testContent) { + t.Error("Content mismatch") + } + }) + + t.Run("Read non-existent file", func(t *testing.T) { + _, err := readFile("/non/existent/file.java") + if err == nil { + t.Error("Expected error for non-existent file") + } + }) + + t.Run("Read empty file", func(t *testing.T) { + emptyFile, err := ioutil.TempFile("", "test_empty_*.java") + if err != nil { + t.Fatalf("Failed to create empty file: %v", err) + } + defer os.Remove(emptyFile.Name()) + emptyFile.Close() + + content, err := readFile(emptyFile.Name()) + if err != nil { + t.Fatalf("readFile failed: %v", err) + } + + if len(content) != 0 { + t.Errorf("Expected empty content, got %d bytes", len(content)) + } + }) + + t.Run("Read large file", func(t *testing.T) { + largeFile, err := ioutil.TempFile("", "test_large_*.java") + if err != nil { + t.Fatalf("Failed to create large file: %v", err) + } + defer os.Remove(largeFile.Name()) + + // Write 1MB of data + largeContent := make([]byte, 1024*1024) + for i := range largeContent { + largeContent[i] = byte(i % 256) + } + largeFile.Write(largeContent) + largeFile.Close() + + content, err := readFile(largeFile.Name()) + if err != nil { + t.Fatalf("readFile failed: %v", err) + } + + if len(content) != len(largeContent) { + t.Errorf("Expected %d bytes, got %d", len(largeContent), len(content)) + } + }) +} + +func TestHasAccessComprehensive(t *testing.T) { + tests := []struct { + name string + code string + variableName string + expected bool + }{ + { + name: "Variable exists in simple code", + code: `public class Test { + void method() { + int x = 10; + System.out.println(x); + } +}`, + variableName: "x", + expected: true, + }, + { + name: "Variable does not exist", + code: `public class Test { + void method() { + int x = 10; + } +}`, + variableName: "y", + expected: false, + }, + { + name: "Variable in nested scope", + code: `public class Test { + void method() { + if (true) { + int nested = 5; + System.out.println(nested); + } + } +}`, + variableName: "nested", + expected: true, + }, + { + name: "Class name as variable", + code: `public class Test { + void method() { + Test obj = new Test(); + } +}`, + variableName: "Test", + expected: true, + }, + { + name: "Null node", + code: "public class Test {}", + variableName: "anything", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := sitter.NewParser() + parser.SetLanguage(java.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(tt.code)) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer tree.Close() + + root := tree.RootNode() + result := hasAccess(root, tt.variableName, []byte(tt.code)) + + if result != tt.expected { + t.Errorf("Expected %v for variable '%s', got %v", tt.expected, tt.variableName, result) + } + }) + } + + // Test nil node explicitly + t.Run("Nil node returns false", func(t *testing.T) { + result := hasAccess(nil, "test", []byte("anything")) + if result { + t.Error("hasAccess with nil node should return false") + } + }) +} + +func TestAppendUniqueComprehensive(t *testing.T) { + t.Run("Append to empty slice", func(t *testing.T) { + var slice []*Node + node := &Node{ID: "test1", Name: "Node1"} + + result := appendUnique(slice, node) + + if len(result) != 1 { + t.Errorf("Expected length 1, got %d", len(result)) + } + if result[0] != node { + t.Error("Node not appended correctly") + } + }) + + t.Run("Append unique nodes", func(t *testing.T) { + node1 := &Node{ID: "test1", Name: "Node1"} + node2 := &Node{ID: "test2", Name: "Node2"} + node3 := &Node{ID: "test3", Name: "Node3"} + + slice := []*Node{node1} + slice = appendUnique(slice, node2) + slice = appendUnique(slice, node3) + + if len(slice) != 3 { + t.Errorf("Expected length 3, got %d", len(slice)) + } + }) + + t.Run("Append duplicate node", func(t *testing.T) { + node := &Node{ID: "test1", Name: "Node1"} + slice := []*Node{node} + + result := appendUnique(slice, node) + + if len(result) != 1 { + t.Errorf("Expected length 1 after duplicate, got %d", len(result)) + } + if result[0] != node { + t.Error("Node reference changed") + } + }) + + t.Run("Multiple duplicates", func(t *testing.T) { + node1 := &Node{ID: "test1"} + node2 := &Node{ID: "test2"} + + slice := []*Node{node1, node2} + slice = appendUnique(slice, node1) + slice = appendUnique(slice, node2) + slice = appendUnique(slice, node1) + + if len(slice) != 2 { + t.Errorf("Expected length 2, got %d", len(slice)) + } + }) + + t.Run("Nil node", func(t *testing.T) { + slice := []*Node{&Node{ID: "test1"}} + result := appendUnique(slice, nil) + + if len(result) != 2 { + t.Errorf("Expected length 2, got %d", len(result)) + } + }) +} + +func TestFormatTypeComprehensive(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + }{ + {"String", "hello world", "hello world"}, + {"Empty string", "", ""}, + {"Int", 42, "42"}, + {"Int64", int64(9223372036854775807), "9223372036854775807"}, + {"Negative int", -100, "-100"}, + {"Float32", float32(3.14), "3.14"}, + {"Float64", 2.71828, "2.72"}, + {"Zero float", 0.0, "0.00"}, + {"Bool true", true, "true"}, + {"Bool false", false, "false"}, + {"Nil", nil, ""}, + {"Empty slice", []interface{}{}, "[]"}, + {"Int slice", []interface{}{1, 2, 3}, "[1,2,3]"}, + {"Mixed slice", []interface{}{1, "two", 3.0}, "[1,\"two\",3]"}, + {"Nested slice", []interface{}{[]interface{}{1, 2}, []interface{}{3, 4}}, "[[1,2],[3,4]]"}, + {"Struct", struct{ Name string }{"test"}, "{test}"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FormatType(tt.input) + if result != tt.expected { + t.Errorf("Expected '%s', got '%s'", tt.expected, result) + } + }) + } +} + +func TestVerboseLoggingComprehensive(t *testing.T) { + // Save original state + originalVerbose := verboseFlag + defer func() { verboseFlag = originalVerbose }() + + t.Run("EnableVerboseLogging sets flag", func(t *testing.T) { + verboseFlag = false + EnableVerboseLogging() + if !verboseFlag { + t.Error("verboseFlag should be true after EnableVerboseLogging") + } + }) + + t.Run("Log when verbose enabled", func(t *testing.T) { + verboseFlag = true + // Should not panic + Log("test message") + Log("test with args: %s %d", "hello", 42) + }) + + t.Run("Log when verbose disabled", func(t *testing.T) { + verboseFlag = false + // Should not panic + Log("this should not print") + }) + + t.Run("Fmt when verbose enabled", func(t *testing.T) { + verboseFlag = true + // Should not panic + Fmt("test: %s\n", "hello") + Fmt("numbers: %d %d\n", 1, 2) + }) + + t.Run("Fmt when verbose disabled", func(t *testing.T) { + verboseFlag = false + // Should not panic + Fmt("this should not print\n") + }) +} + +func TestIsGitHubActionsComprehensive(t *testing.T) { + // Save original environment + original := os.Getenv("GITHUB_ACTIONS") + defer os.Setenv("GITHUB_ACTIONS", original) + + tests := []struct { + name string + envValue string + unset bool + expected bool + }{ + {"Environment is 'true'", "true", false, true}, + {"Environment is 'false'", "false", false, false}, + {"Environment is '1'", "1", false, false}, + {"Environment is empty string", "", false, false}, + {"Environment is unset", "", true, false}, + {"Environment is 'True' (capitalized)", "True", false, false}, + {"Environment is 'TRUE'", "TRUE", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.unset { + os.Unsetenv("GITHUB_ACTIONS") + } else { + os.Setenv("GITHUB_ACTIONS", tt.envValue) + } + + result := IsGitHubActions() + if result != tt.expected { + t.Errorf("Expected %v, got %v (env=%q)", tt.expected, result, tt.envValue) + } + }) + } +} + +func BenchmarkGenerateMethodID(b *testing.B) { + params := []string{"int", "String", "Object"} + for i := 0; i < b.N; i++ { + GenerateMethodID("testMethod", params, "Test.java") + } +} + +func BenchmarkGenerateSha256(b *testing.B) { + input := "test input for sha256 hashing" + for i := 0; i < b.N; i++ { + GenerateSha256(input) + } +} + +func BenchmarkParseJavadocTags(b *testing.B) { + comment := `/** + * Test method + * @param x first parameter + * @param y second parameter + * @return result + */` + for i := 0; i < b.N; i++ { + parseJavadocTags(comment) + } +} + +func BenchmarkHasAccess(b *testing.B) { + code := `public class Test { + void method() { + int x = 10; + System.out.println(x); + } +}` + parser := sitter.NewParser() + parser.SetLanguage(java.GetLanguage()) + defer parser.Close() + + tree, _ := parser.ParseCtx(context.Background(), nil, []byte(code)) + defer tree.Close() + root := tree.RootNode() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + hasAccess(root, "x", []byte(code)) + } +} diff --git a/test-src/python/sample.py b/test-src/python/sample.py new file mode 100644 index 00000000..70f6197f --- /dev/null +++ b/test-src/python/sample.py @@ -0,0 +1,62 @@ +"""Sample Python file for testing parser""" + + +class Calculator: + """A simple calculator class""" + + def __init__(self): + self.result = 0 + + def add(self, x, y): + """Add two numbers""" + result = x + y + return result + + def subtract(self, x, y): + """Subtract two numbers""" + assert y != 0, "Cannot divide by zero" + return x - y + + def multiply(self, x, y): + """Multiply two numbers""" + for i in range(y): + if i == 10: + break + self.result += x + return self.result + + +def fibonacci(n): + """Generate fibonacci sequence""" + a, b = 0, 1 + for _ in range(n): + yield a + a, b = b, a + b + + +def process_data(data): + """Process data with error handling""" + if not data: + return None + + processed = [] + for item in data: + if item < 0: + continue + processed.append(item * 2) + + return processed + + +def main(): + """Main function""" + calc = Calculator() + result = calc.add(10, 20) + print(f"Result: {result}") + + for num in fibonacci(10): + print(num) + + +if __name__ == "__main__": + main()