Skip to content

Commit 7cb229d

Browse files
Chenguang-Zhu and zsol authored
Implement lazy loading mechanism for QualifiedNameProvider (#720)
* Implement lazy loading mechanism for expensive metadata providers
* Add support for lazy values in metadata matchers
* Fix type issues and implement lazy value support in base metadata provider too
* Add unit tests for BaseMetadataProvider

Co-authored-by: Zsolt Dollenstein <[email protected]>
1 parent b3eda50 commit 7cb229d

6 files changed

Lines changed: 135 additions & 21 deletions

File tree

libcst/_metadata_dependent.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77
from abc import ABC
88
from contextlib import contextmanager
99
from typing import (
10+
Callable,
1011
cast,
1112
ClassVar,
1213
Collection,
14+
Generic,
1315
Iterator,
1416
Mapping,
1517
Type,
1618
TYPE_CHECKING,
1719
TypeVar,
20+
Union,
1821
)
1922

2023
if TYPE_CHECKING:
@@ -29,7 +32,28 @@
2932

3033
_T = TypeVar("_T")
3134

32-
_UNDEFINED_DEFAULT = object()
35+
36+
class _UNDEFINED_DEFAULT:
37+
pass
38+
39+
40+
class LazyValue(Generic[_T]):
41+
"""
42+
The class for implementing a lazy metadata loading mechanism that improves the
43+
performance when retrieving expensive metadata (e.g., qualified names). Providers
44+
including :class:`~libcst.metadata.QualifiedNameProvider` use this class to load
45+
the metadata of a certain node lazily when calling
46+
:func:`~libcst.MetadataDependent.get_metadata`.
47+
"""
48+
49+
def __init__(self, callable: Callable[[], _T]) -> None:
50+
self.callable = callable
51+
self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT
52+
53+
def __call__(self) -> _T:
54+
if self.return_value is _UNDEFINED_DEFAULT:
55+
self.return_value = self.callable()
56+
return cast(_T, self.return_value)
3357

3458

3559
class MetadataDependent(ABC):
@@ -107,6 +131,9 @@ def get_metadata(
107131
)
108132

109133
if default is not _UNDEFINED_DEFAULT:
110-
return cast(_T, self.metadata[key].get(node, default))
134+
value = self.metadata[key].get(node, default)
111135
else:
112-
return cast(_T, self.metadata[key][node])
136+
value = self.metadata[key][node]
137+
if isinstance(value, LazyValue):
138+
value = value()
139+
return cast(_T, value)

libcst/matchers/_matcher_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import libcst
3232
import libcst.metadata as meta
3333
from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel
34+
from libcst._metadata_dependent import LazyValue
3435

3536

3637
class DoNotCareSentinel(Enum):
@@ -1544,7 +1545,11 @@ def _fetch(provider: meta.ProviderT, node: libcst.CSTNode) -> object:
15441545
if provider not in metadata:
15451546
metadata[provider] = wrapper.resolve(provider)
15461547

1547-
return metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL)
1548+
node_metadata = metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL)
1549+
if isinstance(node_metadata, LazyValue):
1550+
node_metadata = node_metadata()
1551+
1552+
return node_metadata
15481553

15491554
return _fetch
15501555

libcst/metadata/base_provider.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from types import MappingProxyType
88
from typing import (
99
Callable,
10-
cast,
1110
Generic,
1211
List,
1312
Mapping,
@@ -16,12 +15,14 @@
1615
Type,
1716
TYPE_CHECKING,
1817
TypeVar,
18+
Union,
1919
)
2020

2121
from libcst._batched_visitor import BatchableCSTVisitor
2222
from libcst._metadata_dependent import (
2323
_T as _MetadataT,
2424
_UNDEFINED_DEFAULT,
25+
LazyValue,
2526
MetadataDependent,
2627
)
2728
from libcst._visitors import CSTVisitor
@@ -36,6 +37,7 @@
3637
# BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the
3738
# typevar is covariant.
3839
_ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True)
40+
MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT]
3941

4042

4143
# We can't use an ABCMeta here, because of metaclass conflicts
@@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
5254
#
5355
# N.B. This has some typing variance problems. See `set_metadata` for an
5456
# explanation.
55-
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
57+
_computed: MutableMapping["CSTNode", MaybeLazyMetadataT]
5658

57-
#: Implement gen_cache to indicate the matadata provider depends on cache from external
59+
#: Implement gen_cache to indicate the metadata provider depends on cache from external
5860
#: system. This function will be called by :class:`~libcst.metadata.FullRepoManager`
5961
#: to compute required cache object per file path.
6062
gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None
6163

6264
def __init__(self, cache: object = None) -> None:
6365
super().__init__()
64-
self._computed = {}
66+
self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {}
6567
if self.gen_cache and cache is None:
6668
# The metadata provider implementation is responsible to store and use cache.
6769
raise Exception(
@@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None:
7173

7274
def _gen(
7375
self, wrapper: "MetadataWrapper"
74-
) -> Mapping["CSTNode", _ProvidedMetadataT]:
76+
) -> Mapping["CSTNode", MaybeLazyMetadataT]:
7577
"""
7678
Resolves and returns metadata mapping for the module in ``wrapper``.
7779
@@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None:
9395
"""
9496
...
9597

96-
# pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to
97-
# pyre: `self._computed`, however we assume that only one subclass in the MRO chain
98-
# pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no
99-
# pyre: sane way to redesign this API so that it doesn't have this problem.
100-
def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None:
98+
def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None:
10199
"""
102100
Record a metadata value ``value`` for ``node``.
103101
"""
@@ -107,7 +105,9 @@ def get_metadata(
107105
self,
108106
key: Type["BaseMetadataProvider[_MetadataT]"],
109107
node: "CSTNode",
110-
default: _MetadataT = _UNDEFINED_DEFAULT,
108+
default: Union[
109+
MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT]
110+
] = _UNDEFINED_DEFAULT,
111111
) -> _MetadataT:
112112
"""
113113
The same method as :func:`~libcst.MetadataDependent.get_metadata` except
@@ -116,9 +116,12 @@ def get_metadata(
116116
"""
117117
if key is type(self):
118118
if default is not _UNDEFINED_DEFAULT:
119-
return cast(_MetadataT, self._computed.get(node, default))
119+
ret = self._computed.get(node, default)
120120
else:
121-
return cast(_MetadataT, self._computed[node])
121+
ret = self._computed[node]
122+
if isinstance(ret, LazyValue):
123+
return ret()
124+
return ret
122125

123126
return super().get_metadata(key, node, default)
124127

libcst/metadata/name_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Collection, List, Mapping, Optional, Union
99

1010
import libcst as cst
11-
from libcst._metadata_dependent import MetadataDependent
11+
from libcst._metadata_dependent import LazyValue, MetadataDependent
1212
from libcst.helpers.module import calculate_module_and_package, ModuleNameAndPackage
1313
from libcst.metadata.base_provider import BatchableMetadataProvider
1414
from libcst.metadata.scope_provider import (
@@ -78,7 +78,9 @@ def __init__(self, provider: "QualifiedNameProvider") -> None:
7878
def on_visit(self, node: cst.CSTNode) -> bool:
7979
scope = self.provider.get_metadata(ScopeProvider, node, None)
8080
if scope:
81-
self.provider.set_metadata(node, scope.get_qualified_names_for(node))
81+
self.provider.set_metadata(
82+
node, LazyValue(lambda: scope.get_qualified_names_for(node))
83+
)
8284
else:
8385
self.provider.set_metadata(node, set())
8486
super().on_visit(node)

libcst/metadata/tests/test_base_provider.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import libcst as cst
99
from libcst import parse_module
10+
from libcst._metadata_dependent import LazyValue
1011
from libcst.metadata import (
1112
BatchableMetadataProvider,
1213
MetadataWrapper,
@@ -75,3 +76,63 @@ def visit_Return(self, node: cst.Return) -> None:
7576
self.assertEqual(metadata[SimpleProvider][pass_], 1)
7677
self.assertEqual(metadata[SimpleProvider][return_], 2)
7778
self.assertEqual(metadata[SimpleProvider][pass_2], 1)
79+
80+
def test_lazy_visitor_provider(self) -> None:
81+
class SimpleLazyProvider(VisitorMetadataProvider[int]):
82+
"""
83+
Sets metadata on every node to a callable that returns 1.
84+
"""
85+
86+
def on_visit(self, node: cst.CSTNode) -> bool:
87+
self.set_metadata(node, LazyValue(lambda: 1))
88+
return True
89+
90+
wrapper = MetadataWrapper(parse_module("pass; return"))
91+
module = wrapper.module
92+
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
93+
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]
94+
95+
provider = SimpleLazyProvider()
96+
metadata = provider._gen(wrapper)
97+
98+
# Check access on provider
99+
self.assertEqual(provider.get_metadata(SimpleLazyProvider, module), 1)
100+
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1)
101+
self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 1)
102+
103+
# Check returned mapping
104+
self.assertTrue(isinstance(metadata[module], LazyValue))
105+
self.assertTrue(isinstance(metadata[pass_], LazyValue))
106+
self.assertTrue(isinstance(metadata[return_], LazyValue))
107+
108+
def testlazy_batchable_provider(self) -> None:
109+
class SimpleLazyProvider(BatchableMetadataProvider[int]):
110+
"""
111+
Sets metadata on every pass node to a callable that returns 1,
112+
and every return node to a callable that returns 2.
113+
"""
114+
115+
def visit_Pass(self, node: cst.Pass) -> None:
116+
self.set_metadata(node, LazyValue(lambda: 1))
117+
118+
def visit_Return(self, node: cst.Return) -> None:
119+
self.set_metadata(node, LazyValue(lambda: 2))
120+
121+
wrapper = MetadataWrapper(parse_module("pass; return; pass"))
122+
module = wrapper.module
123+
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
124+
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]
125+
pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2]
126+
127+
provider = SimpleLazyProvider()
128+
metadata = _gen_batchable(wrapper, [provider])
129+
130+
# Check access on provider
131+
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1)
132+
self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 2)
133+
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_2), 1)
134+
135+
# Check returned mapping
136+
self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_], LazyValue))
137+
self.assertTrue(isinstance(metadata[SimpleLazyProvider][return_], LazyValue))
138+
self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_2], LazyValue))

libcst/metadata/tests/test_name_provider.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import libcst as cst
1212
from libcst import ensure_type
13+
from libcst._nodes.base import CSTNode
1314
from libcst.metadata import (
1415
FullyQualifiedNameProvider,
1516
MetadataWrapper,
@@ -22,11 +23,26 @@
2223
from libcst.testing.utils import data_provider, UnitTest
2324

2425

26+
class QNameVisitor(cst.CSTVisitor):
27+
28+
METADATA_DEPENDENCIES = (QualifiedNameProvider,)
29+
30+
def __init__(self) -> None:
31+
self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {}
32+
33+
def on_visit(self, node: cst.CSTNode) -> bool:
34+
qname = self.get_metadata(QualifiedNameProvider, node)
35+
self.qnames[node] = qname
36+
return True
37+
38+
2539
def get_qualified_name_metadata_provider(
2640
module_str: str,
2741
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
2842
wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
29-
return wrapper.module, wrapper.resolve(QualifiedNameProvider)
43+
visitor = QNameVisitor()
44+
wrapper.visit(visitor)
45+
return wrapper.module, visitor.qnames
3046

3147

3248
def get_qualified_names(module_str: str) -> Set[QualifiedName]:
@@ -358,7 +374,7 @@ def f(): pass
358374
else:
359375
import f
360376
import a.b as f
361-
377+
362378
f()
363379
"""
364380
)

0 commit comments

Comments
 (0)