1414import functools
1515import hashlib
1616import importlib
17+ import inspect
1718import logging
1819import os
1920import pickle
@@ -96,6 +97,14 @@ class _GuardedCodeCacheEntry:
9697_FunctionId = NewType ("_FunctionId" , str ) # __resume_at
9798
9899
@dataclasses.dataclass(frozen=True)
class InlinedSource:
    """
    Fingerprint of a contiguous span of a module's source code.

    The dataclass is frozen, so instances are immutable and hashable and can
    be stored in a ``set``.
    """

    # Importable name of the module that contains the span.
    module: str
    # 1-based line number of the first line of the span.
    firstlineno: int
    # 1-based line number one past the final line of the span (exclusive end).
    lastlineno: int
    # Hex digest of the span's text, compared later to detect source changes.
    checksum: str
106+
107+
99108@dataclasses .dataclass
100109class _DynamoCodeCacheEntry :
101110 """
@@ -124,6 +133,7 @@ class _DynamoCodeCacheEntry:
124133@dataclasses .dataclass
125134class _DynamoCacheEntry :
126135 codes : list [_DynamoCodeCacheEntry ]
136+ inlined_sources : set [InlinedSource ]
127137 python_version : str = platform .python_version ()
128138 torch_version : str = torch .__version__
129139
@@ -142,6 +152,22 @@ def after_deserialization(self) -> _DynamoCacheEntry:
142152 return pickle .loads (self .content )
143153
144154
155+ def _hash_source (source : str ) -> str :
156+ sha256_hash = hashlib .sha256 ()
157+ sha256_hash .update (source .encode ())
158+ return sha256_hash .hexdigest ()
159+
160+
161+ def _get_sourcelines (
162+ m : types .ModuleType , firstlineno : int , lastlineno : int
163+ ) -> list [str ]:
164+ return inspect .getsourcelines (m )[0 ][firstlineno - 1 : lastlineno - 1 ]
165+
166+
def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str:
    """
    Return the checksum of a span of module ``m``'s source.

    The span is whatever ``_get_sourcelines(m, firstlineno, lastlineno)``
    returns; its lines are concatenated and passed to ``_hash_source``.
    """
    span = _get_sourcelines(m, firstlineno, lastlineno)
    return _hash_source("".join(span))
169+
170+
145171class CompilePackage :
146172 """
147173 CompilePackage is considered a low level component and should not be directly exposed to
@@ -155,7 +181,12 @@ class CompilePackage:
155181 updates with compiled functions and resume functions.
156182 """
157183
158- def __init__ (self , fn : Any , dynamo : Optional [_DynamoCacheEntry ] = None ) -> None :
184+ def __init__ (
185+ self ,
186+ fn : Any ,
187+ dynamo : Optional [_DynamoCacheEntry ] = None ,
188+ ignore_inlined_sources : bool = False ,
189+ ) -> None :
159190 self ._innermost_fn = None
160191 self ._codes : dict [types .CodeType , _DynamoCodeCacheEntry ] = {}
161192
@@ -164,14 +195,22 @@ def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
164195
165196 # For debugging/testing purpose only.
166197 self ._cached_backends : dict [_BackendId , Any ] = {}
198+ self ._inlined_sources : set [InlinedSource ] = set ()
199+ self ._resume_codes : set [types .CodeType ] = set ()
167200
168- self ._initialize (fn , dynamo )
201+ self ._initialize (fn , dynamo , ignore_inlined_sources )
169202 self .uninstall ()
170203 self .validate ()
171204
172- def _initialize (self , fn : Any , dynamo : Optional [_DynamoCacheEntry ] = None ) -> None :
205+ def _initialize (
206+ self ,
207+ fn : Any ,
208+ dynamo : Optional [_DynamoCacheEntry ] = None ,
209+ ignore_inlined_sources : bool = False ,
210+ ) -> None :
173211 from .eval_frame import innermost_fn
174212
213+ self ._inlined_sources = set ()
175214 self ._innermost_fn = innermost_fn (fn )
176215 assert self ._innermost_fn is not None
177216 if dynamo is not None :
@@ -184,6 +223,16 @@ def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> No
184223 raise RuntimeError (
185224 f"Compile package was created with a different PyTorch version: { dynamo .torch_version } "
186225 )
226+ if not ignore_inlined_sources :
227+ for code in dynamo .inlined_sources :
228+ m = importlib .import_module (code .module )
229+ checksum = _hash_sourcelines (m , code .firstlineno , code .lastlineno )
230+ if checksum != code .checksum :
231+ raise RuntimeError (
232+ f"Source code changes detected for { code .module } (line { code .firstlineno } - line { code .lastlineno } )"
233+ )
234+
235+ self ._inlined_sources = dynamo .inlined_sources
187236
188237 main , * codes = dynamo .codes
189238 self ._codes = {self ._innermost_fn .__code__ : main }
@@ -252,6 +301,27 @@ def add_guarded_code(
252301 )
253302 self ._current_entry .guarded_codes .append (guarded_code_entry )
254303
def add_inlined_source(self, sources: list[types.CodeType]) -> None:
    """
    Record a source fingerprint for every code object in ``sources``.

    A code object is skipped when it was registered through
    ``add_resume_function`` or when ``inspect.getmodule`` cannot find a
    module for it. For every other code object an ``InlinedSource`` is added
    to ``self._inlined_sources``, holding the module name, the line span
    and a checksum of the source text.

    Args:
        sources: Code objects whose source should be fingerprinted.
    """
    for code in sources:
        if code in self._resume_codes:
            continue
        module = inspect.getmodule(code)
        if module is None:
            continue
        # Fetch the lines once; inspect.getsource(code) is exactly the
        # concatenation of these same lines.
        code_lines, _ = inspect.getsourcelines(code)
        source = "".join(code_lines)
        firstlineno = code.co_firstlineno
        # Exclusive end, matching the slicing done by _get_sourcelines.
        lastlineno = firstlineno + len(code_lines)
        # Sanity check: re-reading the same span through the module must
        # reproduce the text being hashed, otherwise a later checksum
        # comparison against the module source could never match.
        assert source == "".join(
            _get_sourcelines(module, firstlineno, lastlineno)
        )
        self._inlined_sources.add(
            InlinedSource(
                module=module.__name__,
                firstlineno=firstlineno,
                lastlineno=lastlineno,
                checksum=_hash_source(source),
            )
        )
324+
255325 def add_resume_function (
256326 self ,
257327 python_code : types .CodeType ,
@@ -261,6 +331,7 @@ def add_resume_function(
261331 self ._add_function (
262332 python_code , python_module , _FunctionId (name ) if name else None
263333 )
334+ self ._resume_codes .add (python_code )
264335
265336 def add_import_source (self , alias : str , module_name : str ) -> None :
266337 assert self ._current_entry is not None
@@ -345,7 +416,9 @@ def install(self, backends: dict[_BackendId, Any]) -> None:
345416
def cache_entry(self) -> _DynamoCacheEntry:
    """
    Build a ``_DynamoCacheEntry`` from the package's current state.

    ``self.validate()`` runs first. The returned entry receives fresh
    containers, so adding inlined sources to this package afterwards does
    not alter an entry that was already handed out.
    """
    self.validate()
    return _DynamoCacheEntry(
        codes=list(self._codes.values()),
        # Snapshot the set, mirroring the list copy made for ``codes``.
        inlined_sources=set(self._inlined_sources),
    )
349422
350423
351424@CacheArtifactFactory .register
0 commit comments