diff --git a/libcst/_metadata_dependent.py b/libcst/_metadata_dependent.py index 6a768270c..4faf74727 100644 --- a/libcst/_metadata_dependent.py +++ b/libcst/_metadata_dependent.py @@ -7,14 +7,17 @@ from abc import ABC from contextlib import contextmanager from typing import ( + Callable, cast, ClassVar, Collection, + Generic, Iterator, Mapping, Type, TYPE_CHECKING, TypeVar, + Union, ) if TYPE_CHECKING: @@ -29,7 +32,28 @@ _T = TypeVar("_T") -_UNDEFINED_DEFAULT = object() + +class _UNDEFINED_DEFAULT: + pass + + +class LazyValue(Generic[_T]): + """ + The class for implementing a lazy metadata loading mechanism that improves the + performance when retriving expensive metadata (e.g., qualified names). Providers + including :class:`~libcst.metadata.QualifiedNameProvider` use this class to load + the metadata of a certain node lazily when calling + :func:`~libcst.MetadataDependent.get_metadata`. + """ + + def __init__(self, callable: Callable[[], _T]) -> None: + self.callable = callable + self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT + + def __call__(self) -> _T: + if self.return_value is _UNDEFINED_DEFAULT: + self.return_value = self.callable() + return cast(_T, self.return_value) class MetadataDependent(ABC): @@ -107,6 +131,9 @@ def get_metadata( ) if default is not _UNDEFINED_DEFAULT: - return cast(_T, self.metadata[key].get(node, default)) + value = self.metadata[key].get(node, default) else: - return cast(_T, self.metadata[key][node]) + value = self.metadata[key][node] + if isinstance(value, LazyValue): + value = value() + return cast(_T, value) diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 64670be42..d8f69ec63 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -31,6 +31,7 @@ import libcst import libcst.metadata as meta from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel +from libcst._metadata_dependent import LazyValue class DoNotCareSentinel(Enum): @@ -1544,7 +1545,11 @@ def _fetch(provider: meta.ProviderT, node: libcst.CSTNode) -> object: if provider not in metadata: metadata[provider] = wrapper.resolve(provider) - return metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL) + node_metadata = metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL) + if isinstance(node_metadata, LazyValue): + node_metadata = node_metadata() + + return node_metadata return _fetch diff --git a/libcst/metadata/base_provider.py b/libcst/metadata/base_provider.py index 69af2dcea..1c113f57a 100644 --- a/libcst/metadata/base_provider.py +++ b/libcst/metadata/base_provider.py @@ -7,7 +7,6 @@ from types import MappingProxyType from typing import ( Callable, - cast, Generic, List, Mapping, @@ -16,12 +15,14 @@ Type, TYPE_CHECKING, TypeVar, + Union, ) from libcst._batched_visitor import BatchableCSTVisitor from libcst._metadata_dependent import ( _T as _MetadataT, _UNDEFINED_DEFAULT, + LazyValue, MetadataDependent, ) from libcst._visitors import CSTVisitor @@ -36,6 +37,7 @@ # BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the # typevar is covariant. _ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True) +MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT] # We can't use an ABCMeta here, because of metaclass conflicts @@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]): # # N.B. This has some typing variance problems. See `set_metadata` for an # explanation. - _computed: MutableMapping["CSTNode", _ProvidedMetadataT] + _computed: MutableMapping["CSTNode", MaybeLazyMetadataT] - #: Implement gen_cache to indicate the matadata provider depends on cache from external + #: Implement gen_cache to indicate the metadata provider depends on cache from external #: system. This function will be called by :class:`~libcst.metadata.FullRepoManager` #: to compute required cache object per file path. gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None def __init__(self, cache: object = None) -> None: super().__init__() - self._computed = {} + self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {} if self.gen_cache and cache is None: # The metadata provider implementation is responsible to store and use cache. raise Exception( @@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None: def _gen( self, wrapper: "MetadataWrapper" - ) -> Mapping["CSTNode", _ProvidedMetadataT]: + ) -> Mapping["CSTNode", MaybeLazyMetadataT]: """ Resolves and returns metadata mapping for the module in ``wrapper``. @@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None: """ ... - # pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to - # pyre: `self._computed`, however we assume that only one subclass in the MRO chain - # pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no - # pyre: sane way to redesign this API so that it doesn't have this problem. - def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None: + def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None: """ Record a metadata value ``value`` for ``node``. """ @@ -107,7 +105,9 @@ def get_metadata( self, key: Type["BaseMetadataProvider[_MetadataT]"], node: "CSTNode", - default: _MetadataT = _UNDEFINED_DEFAULT, + default: Union[ + MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT] + ] = _UNDEFINED_DEFAULT, ) -> _MetadataT: """ The same method as :func:`~libcst.MetadataDependent.get_metadata` except @@ -116,9 +116,12 @@ def get_metadata( """ if key is type(self): if default is not _UNDEFINED_DEFAULT: - return cast(_MetadataT, self._computed.get(node, default)) + ret = self._computed.get(node, default) else: - return cast(_MetadataT, self._computed[node]) + ret = self._computed[node] + if isinstance(ret, LazyValue): + return ret() + return ret return super().get_metadata(key, node, default) diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index 007535043..60d8763e7 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -8,7 +8,7 @@ from typing import Collection, List, Mapping, Optional, Union import libcst as cst -from libcst._metadata_dependent import MetadataDependent +from libcst._metadata_dependent import LazyValue, MetadataDependent from libcst.helpers.module import calculate_module_and_package, ModuleNameAndPackage from libcst.metadata.base_provider import BatchableMetadataProvider from libcst.metadata.scope_provider import ( @@ -78,7 +78,9 @@ def __init__(self, provider: "QualifiedNameProvider") -> None: def on_visit(self, node: cst.CSTNode) -> bool: scope = self.provider.get_metadata(ScopeProvider, node, None) if scope: - self.provider.set_metadata(node, scope.get_qualified_names_for(node)) + self.provider.set_metadata( + node, LazyValue(lambda: scope.get_qualified_names_for(node)) + ) else: self.provider.set_metadata(node, set()) super().on_visit(node) diff --git a/libcst/metadata/tests/test_base_provider.py b/libcst/metadata/tests/test_base_provider.py index 0bf4ca512..26ebde701 100644 --- a/libcst/metadata/tests/test_base_provider.py +++ b/libcst/metadata/tests/test_base_provider.py @@ -7,6 +7,7 @@ import libcst as cst from libcst import parse_module +from libcst._metadata_dependent import LazyValue from libcst.metadata import ( BatchableMetadataProvider, MetadataWrapper, @@ -75,3 +76,63 @@ def visit_Return(self, node: cst.Return) -> None: self.assertEqual(metadata[SimpleProvider][pass_], 1) self.assertEqual(metadata[SimpleProvider][return_], 2) self.assertEqual(metadata[SimpleProvider][pass_2], 1) + + def test_lazy_visitor_provider(self) -> None: + class SimpleLazyProvider(VisitorMetadataProvider[int]): + """ + Sets metadata on every node to a callable that returns 1. + """ + + def on_visit(self, node: cst.CSTNode) -> bool: + self.set_metadata(node, LazyValue(lambda: 1)) + return True + + wrapper = MetadataWrapper(parse_module("pass; return")) + module = wrapper.module + pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] + return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1] + + provider = SimpleLazyProvider() + metadata = provider._gen(wrapper) + + # Check access on provider + self.assertEqual(provider.get_metadata(SimpleLazyProvider, module), 1) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 1) + + # Check returned mapping + self.assertTrue(isinstance(metadata[module], LazyValue)) + self.assertTrue(isinstance(metadata[pass_], LazyValue)) + self.assertTrue(isinstance(metadata[return_], LazyValue)) + + def testlazy_batchable_provider(self) -> None: + class SimpleLazyProvider(BatchableMetadataProvider[int]): + """ + Sets metadata on every pass node to a callable that returns 1, + and every return node to a callable that returns 2. + """ + + def visit_Pass(self, node: cst.Pass) -> None: + self.set_metadata(node, LazyValue(lambda: 1)) + + def visit_Return(self, node: cst.Return) -> None: + self.set_metadata(node, LazyValue(lambda: 2)) + + wrapper = MetadataWrapper(parse_module("pass; return; pass")) + module = wrapper.module + pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] + return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1] + pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2] + + provider = SimpleLazyProvider() + metadata = _gen_batchable(wrapper, [provider]) + + # Check access on provider + self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 2) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_2), 1) + + # Check returned mapping + self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_], LazyValue)) + self.assertTrue(isinstance(metadata[SimpleLazyProvider][return_], LazyValue)) + self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_2], LazyValue)) diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index 9b0b409fc..9f3813687 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -10,6 +10,7 @@ import libcst as cst from libcst import ensure_type +from libcst._nodes.base import CSTNode from libcst.metadata import ( FullyQualifiedNameProvider, MetadataWrapper, @@ -22,11 +23,26 @@ from libcst.testing.utils import data_provider, UnitTest +class QNameVisitor(cst.CSTVisitor): + + METADATA_DEPENDENCIES = (QualifiedNameProvider,) + + def __init__(self) -> None: + self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {} + + def on_visit(self, node: cst.CSTNode) -> bool: + qname = self.get_metadata(QualifiedNameProvider, node) + self.qnames[node] = qname + return True + + def get_qualified_name_metadata_provider( module_str: str, ) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]: wrapper = MetadataWrapper(cst.parse_module(dedent(module_str))) - return wrapper.module, wrapper.resolve(QualifiedNameProvider) + visitor = QNameVisitor() + wrapper.visit(visitor) + return wrapper.module, visitor.qnames def get_qualified_names(module_str: str) -> Set[QualifiedName]: @@ -358,7 +374,7 @@ def f(): pass else: import f import a.b as f - + f() """ )