Skip to content

Commit f096820

Browse files
zhxchen17 authored and
pytorchmergebot committed
[precompile] Detect source code changes for save/load. (pytorch#156432)
Go through all Dynamo-traced functions and compute a checksum for each of them. When loading a precompiled package back into memory, we always verify the checksums and refuse to load if source code changes are detected. Differential Revision: [D76987123](https://our.internmc.facebook.com/intern/diff/D76987123/) Pull Request resolved: pytorch#156432 Approved by: https://github.com/jansel, https://github.com/jamesjwu
1 parent d3efd73 commit f096820

File tree

3 files changed

+151
-4
lines changed

3 files changed

+151
-4
lines changed

test/dynamo/test_package.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Owner(s): ["module: dynamo"]
22

3+
import importlib
34
import os
5+
import sys
6+
import tempfile
47
import unittest
58

69
import torch
@@ -185,6 +188,76 @@ def fn(x):
185188
):
186189
compiled_fn(*args2)
187190

191+
def test_file_change(self):
    """Check that a saved package refuses to load after an inlined source changes.

    A helper module providing ``add`` is imported from a temporary file, a
    function that calls it is compiled and saved, and the helper is then
    re-imported under the same module name from a file with a different body.
    Loading must fail with "Source code changes detected" while the modified
    helper is installed, and succeed again once the original is restored.
    """
    helper_name = "torch.test_package_helper"
    ctx = DiskDynamoStore()

    def import_from_path(module_name, file_path):
        # `import importlib` alone does not guarantee that the `util`
        # submodule has been loaded, so import it explicitly here.
        import importlib.util

        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    # The helper is registered in sys.modules under a name inside the `torch`
    # namespace; remove it when the test ends so that no entry pointing at a
    # deleted temporary file leaks into later tests.
    self.addCleanup(sys.modules.pop, helper_name, None)

    # Written with explicit escapes so the file contents do not depend on how
    # this test body is indented.
    mock_module_add_original = "\ndef add(x, y):\n    return x + y\n"
    mock_module_add_modified = "\ndef add(x, y):\n    return x - y\n"

    with tempfile.TemporaryDirectory() as tmp_dir:
        mock_module_add_original_path = os.path.join(
            tmp_dir, "mock_module_add_original.py"
        )
        mock_module_add_modified_path = os.path.join(
            tmp_dir, "mock_module_add_modified.py"
        )
        with open(mock_module_add_original_path, "w", encoding="utf-8") as f:
            f.write(mock_module_add_original)
        with open(mock_module_add_modified_path, "w", encoding="utf-8") as f:
            f.write(mock_module_add_modified)

        module = import_from_path(helper_name, mock_module_add_original_path)

        def fn(x):
            # `module` is looked up at call time, so rebinding it below
            # changes which helper implementation `fn` reaches.
            return module.add(x, 1)

        args = (torch.randn(3, 2),)

        def guard_filter_fn(guards):
            # One boolean per guard: False for CLOSURE_MATCH / FUNCTION_MATCH.
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(
            backend="eager", package=package, guard_filter_fn=guard_filter_fn
        )(fn)
        compiled_fn(*args)
        for backend_id, backend in package.cached_backends.items():
            ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Same module name, different source: loading must be rejected.
        module = import_from_path(helper_name, mock_module_add_modified_path)
        with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
            ctx.load_package(fn, self.path())

        # Original source restored: loading must succeed again.
        module = import_from_path(helper_name, mock_module_add_original_path)
        ctx.load_package(fn, self.path())
260+
188261

189262
if __name__ == "__main__":
190263
from torch._dynamo.test_case import run_tests

torch/_dynamo/convert_frame.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,7 @@ def count_args(code: CodeType) -> int:
976976
if package is not None:
977977
assert check_fn.guards_state is not None
978978
package.add_guarded_code(check_fn.guards_state, out_code)
979+
package.add_inlined_source(output.tracing_context.traced_code)
979980

980981
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
981982
annotation_str = "Torch-Compiled Region: " + compile_id_str

torch/_dynamo/package.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import functools
1515
import hashlib
1616
import importlib
17+
import inspect
1718
import logging
1819
import os
1920
import pickle
@@ -96,6 +97,14 @@ class _GuardedCodeCacheEntry:
9697
_FunctionId = NewType("_FunctionId", str) # __resume_at
9798

9899

100+
@dataclasses.dataclass(frozen=True)
class InlinedSource:
    """Fingerprint of one region of a module's source text.

    Instances are frozen, so they are hashable and can be kept in a set.
    """

    # Name of the module that contains the region.
    module: str
    # 1-based line number of the first line of the region.
    firstlineno: int
    # 1-based line number one past the final line of the region (exclusive end).
    lastlineno: int
    # Checksum of the region's source text.
    checksum: str
106+
107+
99108
@dataclasses.dataclass
100109
class _DynamoCodeCacheEntry:
101110
"""
@@ -124,6 +133,7 @@ class _DynamoCodeCacheEntry:
124133
@dataclasses.dataclass
125134
class _DynamoCacheEntry:
126135
codes: list[_DynamoCodeCacheEntry]
136+
inlined_sources: set[InlinedSource]
127137
python_version: str = platform.python_version()
128138
torch_version: str = torch.__version__
129139

@@ -142,6 +152,22 @@ def after_deserialization(self) -> _DynamoCacheEntry:
142152
return pickle.loads(self.content)
143153

144154

155+
def _hash_source(source: str) -> str:
156+
sha256_hash = hashlib.sha256()
157+
sha256_hash.update(source.encode())
158+
return sha256_hash.hexdigest()
159+
160+
161+
def _get_sourcelines(
162+
m: types.ModuleType, firstlineno: int, lastlineno: int
163+
) -> list[str]:
164+
return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1]
165+
166+
167+
def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str:
    """Return the checksum of module ``m``'s source lines in the given range.

    The lines selected by ``_get_sourcelines(m, firstlineno, lastlineno)`` are
    concatenated and passed to ``_hash_source``.
    """
    region = _get_sourcelines(m, firstlineno, lastlineno)
    return _hash_source("".join(region))
169+
170+
145171
class CompilePackage:
146172
"""
147173
CompilePackage is considered a low level component and should not be directly exposed to
@@ -155,7 +181,12 @@ class CompilePackage:
155181
updates with compiled functions and resume functions.
156182
"""
157183

158-
def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
184+
def __init__(
185+
self,
186+
fn: Any,
187+
dynamo: Optional[_DynamoCacheEntry] = None,
188+
ignore_inlined_sources: bool = False,
189+
) -> None:
159190
self._innermost_fn = None
160191
self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}
161192

@@ -164,14 +195,22 @@ def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
164195

165196
# For debugging/testing purpose only.
166197
self._cached_backends: dict[_BackendId, Any] = {}
198+
self._inlined_sources: set[InlinedSource] = set()
199+
self._resume_codes: set[types.CodeType] = set()
167200

168-
self._initialize(fn, dynamo)
201+
self._initialize(fn, dynamo, ignore_inlined_sources)
169202
self.uninstall()
170203
self.validate()
171204

172-
def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
205+
def _initialize(
206+
self,
207+
fn: Any,
208+
dynamo: Optional[_DynamoCacheEntry] = None,
209+
ignore_inlined_sources: bool = False,
210+
) -> None:
173211
from .eval_frame import innermost_fn
174212

213+
self._inlined_sources = set()
175214
self._innermost_fn = innermost_fn(fn)
176215
assert self._innermost_fn is not None
177216
if dynamo is not None:
@@ -184,6 +223,16 @@ def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> No
184223
raise RuntimeError(
185224
f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
186225
)
226+
if not ignore_inlined_sources:
227+
for code in dynamo.inlined_sources:
228+
m = importlib.import_module(code.module)
229+
checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno)
230+
if checksum != code.checksum:
231+
raise RuntimeError(
232+
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
233+
)
234+
235+
self._inlined_sources = dynamo.inlined_sources
187236

188237
main, *codes = dynamo.codes
189238
self._codes = {self._innermost_fn.__code__: main}
@@ -252,6 +301,27 @@ def add_guarded_code(
252301
)
253302
self._current_entry.guarded_codes.append(guarded_code_entry)
254303

304+
def add_inlined_source(self, sources: list[types.CodeType]) -> None:
    """Record a source checksum for each code object in ``sources``.

    Code objects previously registered through ``add_resume_function`` are
    skipped, as are code objects for which ``inspect.getmodule`` cannot find
    an owning module. For every remaining code object an ``InlinedSource``
    (module name, line range, checksum of the source text) is added to
    ``self._inlined_sources``.

    Args:
        sources: Code objects whose source should be fingerprinted.
    """
    for code in sources:
        if code in self._resume_codes:
            continue
        module = inspect.getmodule(code)
        if module is None:
            continue
        # Fetch the source once: inspect.getsource() is the joined result of
        # inspect.getsourcelines(), so calling both reads the same text twice.
        sourcelines, _ = inspect.getsourcelines(code)
        source = "".join(sourcelines)
        firstlineno = code.co_firstlineno
        # Exclusive end: one past the last line of the code object's source.
        lastlineno = firstlineno + len(sourcelines)
        # The load-time check re-reads this range from the module, so the
        # range must reproduce exactly the text that is hashed here.
        assert source == "".join(
            _get_sourcelines(module, firstlineno, lastlineno)
        ), (
            f"Source of {code.co_name} does not match {module.__name__} "
            f"(line {firstlineno} - line {lastlineno})"
        )
        self._inlined_sources.add(
            InlinedSource(
                module=module.__name__,
                firstlineno=firstlineno,
                lastlineno=lastlineno,
                checksum=_hash_source(source),
            )
        )
324+
255325
def add_resume_function(
256326
self,
257327
python_code: types.CodeType,
@@ -261,6 +331,7 @@ def add_resume_function(
261331
self._add_function(
262332
python_code, python_module, _FunctionId(name) if name else None
263333
)
334+
self._resume_codes.add(python_code)
264335

265336
def add_import_source(self, alias: str, module_name: str) -> None:
266337
assert self._current_entry is not None
@@ -345,7 +416,9 @@ def install(self, backends: dict[_BackendId, Any]) -> None:
345416

346417
def cache_entry(self) -> _DynamoCacheEntry:
347418
self.validate()
348-
return _DynamoCacheEntry(codes=list(self._codes.values()))
419+
return _DynamoCacheEntry(
420+
codes=list(self._codes.values()), inlined_sources=self._inlined_sources
421+
)
349422

350423

351424
@CacheArtifactFactory.register

0 commit comments

Comments
 (0)