Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/serena/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from serena.util.dataclass import get_dataclass_default
from serena.util.logging import MemoryLogHandler
from solidlsp.ls_config import Language
from solidlsp.ls_types import SymbolKind
from solidlsp.util.subprocess_util import subprocess_kwargs

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -811,9 +812,7 @@ def health_check(project: str) -> None:
return

# Extract suitable symbol (prefer class or function over variables)
# LSP symbol kinds: 5=class, 12=function, 6=method, 9=constructor
preferred_kinds = [5, 12, 6, 9] # class, function, method, constructor

preferred_kinds = {SymbolKind.Class.name, SymbolKind.Function.name, SymbolKind.Method.name, SymbolKind.Constructor.name}
selected_symbol = None
for symbol in overview_data:
if symbol.get("kind") in preferred_kinds:
Expand Down
2 changes: 1 addition & 1 deletion src/serena/code_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast

from serena.jetbrains.jetbrains_plugin_client import JetBrainsPluginClient
from serena.symbol import JetBrainsSymbol, LanguageServerSymbol, LanguageServerSymbolRetriever, PositionInFile, Symbol
from solidlsp import SolidLanguageServer, ls_types
from solidlsp.ls import LSPFileBuffer
from solidlsp.ls_utils import PathUtils, TextUtils

from .constants import DEFAULT_SOURCE_FILE_ENCODING
from .project import Project
from .tools.jetbrains_plugin_client import JetBrainsPluginClient

if TYPE_CHECKING:
from .agent import SerenaAgent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from requests import Response
from sensai.util.string import ToStringMixin

import serena.tools.jetbrains_types as jb
import serena.jetbrains.jetbrains_types as jb
from serena.jetbrains.jetbrains_types import PluginStatusDTO
from serena.project import Project
from serena.text_utils import render_html
from serena.tools.jetbrains_types import PluginStatusDTO
from serena.util.version import Version

T = TypeVar("T")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import NotRequired, TypedDict
from typing import Literal, NotRequired, TypedDict


class PluginStatusDTO(TypedDict):
Expand Down Expand Up @@ -30,6 +30,9 @@ class SymbolDTO(TypedDict):
num_usages: NotRequired[int]


SymbolDTOKey = Literal["name_path", "relative_path", "type", "body", "quick_info", "documentation", "text_range", "children", "num_usages"]


class SymbolCollectionResponse(TypedDict):
symbols: list[SymbolDTO]

Expand Down
241 changes: 190 additions & 51 deletions src/serena/symbol.py

Large diffs are not rendered by default.

48 changes: 11 additions & 37 deletions src/serena/tools/jetbrains_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from collections import defaultdict
from typing import Any, Literal

import serena.tools.jetbrains_types as jb
import serena.jetbrains.jetbrains_types as jb
from serena.jetbrains.jetbrains_plugin_client import JetBrainsPluginClient
from serena.symbol import JetBrainsSymbolDictGrouper
from serena.tools import Tool, ToolMarkerOptional, ToolMarkerSymbolicRead
from serena.tools.jetbrains_plugin_client import JetBrainsPluginClient

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -91,6 +91,8 @@ class JetBrainsFindReferencingSymbolsTool(Tool, ToolMarkerSymbolicRead, ToolMark
Finds symbols that reference the given symbol using the JetBrains backend
"""

symbol_dict_grouper = JetBrainsSymbolDictGrouper(["relative_path", "type"], ["type"], collapse_singleton=True)

# TODO: (maybe) - add content snippets showing the references like in LS based version?
def apply(
self,
Expand Down Expand Up @@ -118,8 +120,10 @@ def apply(
relative_path=relative_path,
include_quick_info=False, # TODO: Hotfix for serena-jetbrains-plugin/issues/13; revert once fixed
)
result = self._to_json(response_dict)
return self._limit_length(result, max_answer_chars)
symbol_dicts = response_dict["symbols"]
result = self.symbol_dict_grouper.group(symbol_dicts)
result_json = self._to_json(result)
return self._limit_length(result_json, max_answer_chars)


class JetBrainsGetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional):
Expand All @@ -128,37 +132,7 @@ class JetBrainsGetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOp
"""

USE_COMPACT_FORMAT = True

@staticmethod
def _transform_symbols_to_compact_format(symbols: list[jb.SymbolDTO]) -> dict[str, list]:
    """
    Convert a verbose symbol overview into a compact representation grouped by symbol type.

    Each symbol is represented by the last component of its name path only. A symbol without
    children is added as a plain name; a symbol with children is added as a single-entry
    dictionary mapping its name to the (recursively transformed) children.

    The full name path remains inferable from the nesting: a top-level symbol's name path is
    its name, and a nested symbol's name path is the parent's name path + "/" + its name,
    e.g. "convert" nested under class "ProjectType" has name path "ProjectType/convert".

    :param symbols: the symbol DTOs to transform
    :return: a mapping from symbol type ("Unknown" if a symbol has no type) to the list of entries
    """
    grouped = defaultdict(list)

    for sym in symbols:
        symbol_type = sym.get("type", "Unknown")
        short_name = sym["name_path"].split("/")[-1]
        nested = sym.get("children", [])

        if not nested:
            # leaf symbol: the bare name is sufficient
            grouped[symbol_type].append(short_name)  # type: ignore
            continue

        # symbol with children: nest the transformed children under the symbol's name
        nested_compact = JetBrainsGetSymbolsOverviewTool._transform_symbols_to_compact_format(nested)
        grouped[symbol_type].append({short_name: nested_compact})

    return grouped
symbol_dict_grouper = JetBrainsSymbolDictGrouper(["type"], ["type"], collapse_singleton=True, map_name_path_to_name=True)

def apply(
self,
Expand Down Expand Up @@ -187,7 +161,7 @@ def apply(
)
if self.USE_COMPACT_FORMAT:
symbols = symbol_overview["symbols"]
result: dict[str, Any] = {"symbols": self._transform_symbols_to_compact_format(symbols)}
result: dict[str, Any] = {"symbols": self.symbol_dict_grouper.group(symbols)}
documentation = symbol_overview.pop("documentation", None)
if documentation:
result["docstring"] = documentation
Expand Down
87 changes: 33 additions & 54 deletions src/serena/tools/symbol_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
"""

import os
from collections import defaultdict
from collections.abc import Sequence
from copy import copy
from typing import Any

from serena.symbol import LanguageServerSymbol, LanguageServerSymbolDictGrouper
from serena.tools import (
SUCCESS_RESULT,
Tool,
Expand All @@ -18,22 +16,6 @@
from solidlsp.ls_types import SymbolKind


def _sanitize_symbol_dict(symbol_dict: dict[str, Any]) -> dict[str, Any]:
"""
Sanitize a symbol dictionary inplace by removing unnecessary information.
"""
# We replace the location entry, which repeats line information already included in body_location
# and has unnecessary information on column, by just the relative path.
symbol_dict = copy(symbol_dict)
s_relative_path = symbol_dict.get("location", {}).get("relative_path")
if s_relative_path is not None:
symbol_dict["relative_path"] = s_relative_path
symbol_dict.pop("location", None)
# also remove name, name_path should be enough
symbol_dict.pop("name")
return symbol_dict


class RestartLanguageServerTool(Tool, ToolMarkerOptional):
"""Restarts the language server, may be necessary when edits not through Serena happen."""

Expand All @@ -50,6 +32,8 @@ class GetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead):
Gets an overview of the top-level symbols defined in a given file.
"""

symbol_dict_grouper = LanguageServerSymbolDictGrouper(["kind"], ["kind"], collapse_singleton=True)

def apply(self, relative_path: str, depth: int = 0, max_answer_chars: int = -1) -> str:
"""
Use this tool to get a high-level understanding of the code symbols in a file.
Expand All @@ -65,11 +49,11 @@ def apply(self, relative_path: str, depth: int = 0, max_answer_chars: int = -1)
:return: a JSON object containing symbols grouped by kind in a compact format.
"""
result = self.get_symbol_overview(relative_path, depth=depth)
compact_result = self._transform_symbols_to_compact_format(result)
compact_result = self.symbol_dict_grouper.group(result)
result_json_str = self._to_json(compact_result)
return self._limit_length(result_json_str, max_answer_chars)

def get_symbol_overview(self, relative_path: str, depth: int = 0) -> list[dict]:
def get_symbol_overview(self, relative_path: str, depth: int = 0) -> list[LanguageServerSymbol.OutputDict]:
"""
:param relative_path: relative path to a source file
:param depth: the depth up to which descendants shall be retrieved
Expand All @@ -85,37 +69,25 @@ def get_symbol_overview(self, relative_path: str, depth: int = 0) -> list[dict]:
if os.path.isdir(file_path):
raise ValueError(f"Expected a file path, but got a directory path: {relative_path}. ")

return symbol_retriever.get_symbol_overview(relative_path, depth=depth)[relative_path]
symbols = symbol_retriever.get_symbol_overview(relative_path)[relative_path]

@staticmethod
def _transform_symbols_to_compact_format(symbols: list[dict[str, Any]]) -> dict[str, list]:
"""
Transform symbol overview from verbose format to compact grouped format.

Groups symbols by kind and uses names instead of full symbol objects.
For symbols with children, creates nested dictionaries.

The name_path can be inferred from the hierarchical structure:
- Top-level symbols: name_path = name
- Nested symbols: name_path = parent_name + "/" + name
For example, "convert" under class "ProjectType" has name_path "ProjectType/convert".
"""
result = defaultdict(list)
def child_inclusion_predicate(s: LanguageServerSymbol) -> bool:
return not s.is_low_level()

symbol_dicts = []
for symbol in symbols:
kind = symbol.get("kind", "Unknown")
name = symbol.get("name", "unknown")
children = symbol.get("children", [])

if children:
# Symbol has children: create nested dict {name: children_dict}
children_dict = GetSymbolsOverviewTool._transform_symbols_to_compact_format(children)
result[kind].append({name: children_dict})
else:
# Symbol has no children: just add the name
result[kind].append(name)

return result
symbol_dicts.append(
symbol.to_dict(
name_path=False,
name=True,
depth=depth,
kind=True,
relative_path=False,
location=False,
child_inclusion_predicate=child_inclusion_predicate,
)
)
return symbol_dicts


class FindSymbolTool(Tool, ToolMarkerSymbolicRead):
Expand Down Expand Up @@ -185,7 +157,7 @@ def apply(
substring_matching=substring_matching,
within_relative_path=relative_path,
)
symbol_dicts = [_sanitize_symbol_dict(s.to_dict(kind=True, location=True, depth=depth, include_body=include_body)) for s in symbols]
symbol_dicts = [dict(s.to_dict(kind=True, relative_path=True, body_location=True, depth=depth, body=include_body)) for s in symbols]
if not include_body and include_info:
# we add an info field to the symbol dicts if requested
for s, s_dict in zip(symbols, symbol_dicts, strict=True):
Expand All @@ -201,6 +173,8 @@ class FindReferencingSymbolsTool(Tool, ToolMarkerSymbolicRead):
Finds symbols that reference the given symbol using the language server backend
"""

symbol_dict_grouper = LanguageServerSymbolDictGrouper(["relative_path", "kind"], ["kind"], collapse_singleton=True)

# noinspection PyDefaultArgument
def apply(
self,
Expand Down Expand Up @@ -229,17 +203,19 @@ def apply(
parsed_include_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in include_kinds] if include_kinds else None
parsed_exclude_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in exclude_kinds] if exclude_kinds else None
symbol_retriever = self.create_language_server_symbol_retriever()

references_in_symbols = symbol_retriever.find_referencing_symbols(
name_path,
relative_file_path=relative_path,
include_body=include_body,
include_kinds=parsed_include_kinds,
exclude_kinds=parsed_exclude_kinds,
)

reference_dicts = []
for ref in references_in_symbols:
ref_dict = ref.symbol.to_dict(kind=True, location=True, depth=0, include_body=include_body)
ref_dict = _sanitize_symbol_dict(ref_dict)
ref_dict_orig = ref.symbol.to_dict(kind=True, relative_path=True, depth=0, body=include_body, body_location=True)
ref_dict = dict(ref_dict_orig)
if not include_body:
ref_relative_path = ref.symbol.location.relative_path
assert ref_relative_path is not None, f"Referencing symbol {ref.symbol.name} has no relative path, this is likely a bug."
Expand All @@ -250,8 +226,11 @@ def apply(
)
ref_dict["content_around_reference"] = content_around_ref.to_display_string()
reference_dicts.append(ref_dict)
result = self._to_json(reference_dicts)
return self._limit_length(result, max_answer_chars)

result = self.symbol_dict_grouper.group(reference_dicts) # type: ignore

result_json = self._to_json(result)
return self._limit_length(result_json, max_answer_chars)


class ReplaceSymbolBodyTool(Tool, ToolMarkerSymbolicEdit):
Expand Down
24 changes: 19 additions & 5 deletions test/serena/test_serena_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def test_find_symbol_references(self, serena_agent: SerenaAgent, symbol_name: st

# Find the symbol location first
find_symbol_tool = agent.get_tool(FindSymbolTool)
result = find_symbol_tool.apply_ex(name_path_pattern=symbol_name, relative_path=def_file)
result = find_symbol_tool.apply(name_path_pattern=symbol_name, relative_path=def_file)

time.sleep(1)
symbols = json.loads(result)
Expand All @@ -218,12 +218,26 @@ def test_find_symbol_references(self, serena_agent: SerenaAgent, symbol_name: st

# Now find references
find_refs_tool = agent.get_tool(FindReferencingSymbolsTool)
result = find_refs_tool.apply_ex(name_path=def_symbol["name_path"], relative_path=def_symbol["relative_path"])
result = find_refs_tool.apply(name_path=def_symbol["name_path"], relative_path=def_symbol["relative_path"])

def contains_ref_with_relative_path(refs, relative_path):
    """
    Checks for a reference to the given relative path, regardless of output format (grouped and ungrouped).

    In the grouped format, the relative path appears as a dictionary key; in the ungrouped format,
    it appears as the value of a dictionary's "relative_path" entry. Nested lists and dictionaries
    are searched recursively.

    :param refs: the parsed JSON result (arbitrarily nested lists/dicts)
    :param relative_path: the relative path to look for
    :return: True if a reference to the relative path was found, False otherwise
    """
    if isinstance(refs, list):
        return any(contains_ref_with_relative_path(ref, relative_path) for ref in refs)
    if isinstance(refs, dict):
        # grouped format: path is a key; ungrouped format: path is the "relative_path" value
        if relative_path in refs or refs.get("relative_path") == relative_path:
            return True
        return any(contains_ref_with_relative_path(value, relative_path) for value in refs.values())
    return False

refs = json.loads(result)
assert any(
ref["relative_path"] == ref_file for ref in refs
), f"Expected to find reference to {symbol_name} in {ref_file}. refs={refs}"
assert contains_ref_with_relative_path(refs, ref_file), f"Expected to find reference to {symbol_name} in {ref_file}. refs={refs}"

@pytest.mark.parametrize(
"serena_agent,name_path,substring_matching,expected_symbol_name,expected_kind,expected_file",
Expand Down
24 changes: 23 additions & 1 deletion test/serena/test_symbol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from serena.symbol import LanguageServerSymbolRetriever, NamePathComponent, NamePathMatcher
from serena.jetbrains.jetbrains_types import SymbolDTO, SymbolDTOKey
from serena.symbol import LanguageServerSymbol, LanguageServerSymbolRetriever, NamePathComponent, NamePathMatcher
from solidlsp import SolidLanguageServer
from solidlsp.ls_config import Language

Expand Down Expand Up @@ -238,3 +239,24 @@ def test_request_info(self, language_server: SolidLanguageServer):
create_user_method_symbol = symbol_retriever.find("UserService/create_user", within_relative_path="test_repo/services.py")[0]
create_user_method_symbol_info = symbol_retriever.request_info_for_symbol(create_user_method_symbol)
assert "Create a new user and store it" in create_user_method_symbol_info


class TestSymbolDictTypes:
    """
    Consistency checks between the symbol TypedDict types and their companion Literal key types.
    """

    @staticmethod
    def check_key_type(dict_type: type, key_type: type):
        """
        Asserts that the keys declared by a TypedDict are exactly the values of the given Literal type.

        :param dict_type: a TypedDict type
        :param key_type: the corresponding key type (Literal[...]) that the dict should have for keys
        """
        from typing import get_args

        dict_type_keys = set(dict_type.__annotations__)
        expected_keys = set(get_args(key_type))
        # report the concrete differences rather than just the counts, so a failure is actionable
        missing_keys = sorted(expected_keys - dict_type_keys, key=repr)
        unexpected_keys = sorted(dict_type_keys - expected_keys, key=repr)
        assert not missing_keys, f"Expected keys {missing_keys} not found in {dict_type}"
        assert not unexpected_keys, f"Keys {unexpected_keys} of {dict_type} are not declared in {key_type}"

    def test_ls_symbol_dict_type(self):
        self.check_key_type(LanguageServerSymbol.OutputDict, LanguageServerSymbol.OutputDictKey)

    def test_jb_symbol_dict_type(self):
        self.check_key_type(SymbolDTO, SymbolDTOKey)
1 change: 1 addition & 0 deletions test/solidlsp/fsharp/test_fsharp_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def test_go_to_definition(self, language_server: SolidLanguageServer) -> None:
# We should get at least some definitions
assert len(definitions) >= 0, "Should get definitions (even if empty for complex cases)"

@pytest.mark.skipif(is_ci, reason="Test is flaky") # TODO: Re-enable if the LS can be made more reliable #1039
@pytest.mark.parametrize("language_server", [Language.FSHARP], indirect=True)
def test_hover_information(self, language_server: SolidLanguageServer) -> None:
"""Test hover information functionality."""
Expand Down
Loading