@@ -129,12 +129,12 @@ def roundtrip(iree_devices: list[iree.runtime.HalDevice]):
129129
130130class SimpleModel (nn .Module ):
131131 def forward (self , x ):
132- return torch . relu ( x ) + 1.0
132+ return x + 1
133133
134134
135135class MultiOutputModel (nn .Module ):
136136 def forward (self , x ):
137- return torch . relu ( x ), torch . tanh ( x )
137+ return x + 1 , x * 2
138138
139139
140140class TestCompileTorchModule :
@@ -167,10 +167,10 @@ def test_compilation(self, tmp_path):
167167class TestAdaptTorchModuleToIree :
168168 """Tests for adapt_torch_module_to_iree."""
169169
170- def test_basic_loading_and_execution (self ):
170+ def test_basic_loading_and_execution (self , deterministic_random_seed ):
171171 """Test that loaded module executes and produces correct output shape."""
172172 model = SimpleModel ()
173- example_input = torch .randn ( 2 , 32 )
173+ example_input = torch .randint ( 0 , 100 , ( 2 , 32 ), dtype = torch . int64 )
174174
175175 iree_module = adapt_torch_module_to_iree (
176176 model ,
@@ -182,14 +182,12 @@ def test_basic_loading_and_execution(self):
182182 result = iree_module .forward (example_input )
183183 assert isinstance (result , torch .Tensor )
184184 assert result .shape == (2 , 32 )
185- assert not torch .isnan (result ).any ()
186185
187- def test_output_matches_torch (self ):
188- """Test that IREE output matches torch output."""
189- torch .manual_seed (42 )
186+ def test_output_matches_torch (self , deterministic_random_seed ):
187+ """Test that IREE output matches torch output"""
190188 model = SimpleModel ()
191189 model .eval ()
192- example_input = torch .randn ( 2 , 32 )
190+ example_input = torch .randint ( 0 , 100 , ( 2 , 32 ), dtype = torch . int64 )
193191
194192 torch_output = model (example_input )
195193 iree_module = adapt_torch_module_to_iree (
@@ -200,7 +198,7 @@ def test_output_matches_torch(self):
200198 )
201199 iree_output = iree_module .forward (example_input )
202200
203- torch .testing . assert_close (iree_output , torch_output , rtol = 1e-4 , atol = 1e-4 )
201+ assert torch .equal (iree_output , torch_output )
204202
205203 def test_multi_output_model (self ):
206204 """Test model with multiple outputs."""
@@ -224,10 +222,10 @@ def test_multi_output_model(self):
224222class TestOneshotCompileAndRun :
225223 """Tests for oneshot_iree_run."""
226224
227- def test_basic_oneshot (self ):
225+ def test_basic_oneshot (self , deterministic_random_seed ):
228226 """Test basic one-shot execution."""
229227 model = SimpleModel ()
230- example_input = torch .randn ( 2 , 32 )
228+ example_input = torch .randint ( 0 , 100 , ( 2 , 32 ), dtype = torch . int64 )
231229
232230 result = oneshot_iree_run (
233231 model ,
@@ -237,14 +235,12 @@ def test_basic_oneshot(self):
237235 )
238236 assert isinstance (result , torch .Tensor )
239237 assert result .shape == (2 , 32 )
240- assert not torch .isnan (result ).any ()
241238
242- def test_oneshot_matches_torch (self ):
239+ def test_oneshot_matches_torch (self , deterministic_random_seed ):
243240 """Test that one-shot execution matches torch."""
244- torch .manual_seed (42 )
245241 model = SimpleModel ()
246242 model .eval ()
247- example_input = torch .randn ( 2 , 32 )
243+ example_input = torch .randint ( 0 , 100 , ( 2 , 32 ), dtype = torch . int64 )
248244
249245 torch_output = model (example_input )
250246 iree_output = oneshot_iree_run (
@@ -254,4 +250,5 @@ def test_oneshot_matches_torch(self):
254250 compile_args = COMPILE_FLAGS ,
255251 )
256252
257- torch .testing .assert_close (iree_output , torch_output , rtol = 1e-4 , atol = 1e-4 )
253+ # Use exact comparison for integer arithmetic
254+ assert torch .equal (iree_output , torch_output )
0 commit comments