Skip to content

Commit cf01e3e

Browse files
committed
Address PR comments
1 parent d326109 commit cf01e3e

File tree

2 files changed

+21
-59
lines changed

2 files changed

+21
-59
lines changed

sharktank/sharktank/utils/iree.py

Lines changed: 7 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,6 @@
1414
overload,
1515
TYPE_CHECKING,
1616
Sequence,
17-
Protocol,
18-
runtime_checkable,
1917
)
2018
import os
2119
import sys
@@ -251,46 +249,13 @@ def adapt_torch_module_to_iree(
251249

252250
iree_devices = get_iree_devices(device=device, device_count=device_count)
253251

254-
def load_fn(devices: list[iree.runtime.HalDevice]) -> TorchLikeIreeModule:
255-
vm_module, vm_context, vm_instance = load_iree_module(
256-
module_buff=vmfb_bytes,
257-
devices=devices,
258-
parameters_path=parameters_path,
259-
tensor_parallel_size=len(devices),
260-
)
261-
return TorchLikeIreeModule(vm_module, vm_context, devices)
262-
263-
return with_iree_device_context(load_fn, iree_devices)
264-
265-
266-
@runtime_checkable
267-
class InferenceModule(Protocol):
268-
"""Protocol for inference modules (both torch and IREE).
269-
270-
This defines a common interface that both torch.nn.Module and
271-
TorchLikeIreeModule can satisfy, allowing them to be used
272-
interchangeably in inference code.
273-
274-
Example:
275-
>>> def run_inference(model: InferenceModule, inputs):
276-
... return model(inputs)
277-
>>>
278-
>>> # Works with torch modules
279-
>>> torch_model = MyTorchModel()
280-
>>> run_inference(torch_model, x)
281-
>>>
282-
>>> # Also works with IREE modules
283-
>>> iree_model = adapt_torch_module_to_iree(torch_model, ...)
284-
>>> run_inference(iree_model, x)
285-
"""
286-
287-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
288-
"""Execute the module's forward pass."""
289-
...
290-
291-
def forward(self, *args: Any, **kwargs: Any) -> Any:
292-
"""Execute the module's forward pass explicitly."""
293-
...
252+
vm_module, vm_context, vm_instance = load_iree_module(
253+
module_buff=vmfb_bytes,
254+
devices=iree_devices,
255+
parameters_path=parameters_path,
256+
tensor_parallel_size=len(iree_devices),
257+
)
258+
return TorchLikeIreeModule(vm_module, vm_context, iree_devices)
294259

295260

296261
class TorchLikeIreeModule:

sharktank/tests/utils/iree_test.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ def roundtrip(iree_devices: list[iree.runtime.HalDevice]):
129129

130130
class SimpleModel(nn.Module):
131131
def forward(self, x):
132-
return torch.relu(x) + 1.0
132+
return x + 1
133133

134134

135135
class MultiOutputModel(nn.Module):
136136
def forward(self, x):
137-
return torch.relu(x), torch.tanh(x)
137+
return x + 1, x * 2
138138

139139

140140
class TestCompileTorchModule:
@@ -167,10 +167,10 @@ def test_compilation(self, tmp_path):
167167
class TestAdaptTorchModuleToIree:
168168
"""Tests for adapt_torch_module_to_iree."""
169169

170-
def test_basic_loading_and_execution(self):
170+
def test_basic_loading_and_execution(self, deterministic_random_seed):
171171
"""Test that loaded module executes and produces correct output shape."""
172172
model = SimpleModel()
173-
example_input = torch.randn(2, 32)
173+
example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
174174

175175
iree_module = adapt_torch_module_to_iree(
176176
model,
@@ -182,14 +182,12 @@ def test_basic_loading_and_execution(self):
182182
result = iree_module.forward(example_input)
183183
assert isinstance(result, torch.Tensor)
184184
assert result.shape == (2, 32)
185-
assert not torch.isnan(result).any()
186185

187-
def test_output_matches_torch(self):
188-
"""Test that IREE output matches torch output."""
189-
torch.manual_seed(42)
186+
def test_output_matches_torch(self, deterministic_random_seed):
187+
"""Test that IREE output matches torch output"""
190188
model = SimpleModel()
191189
model.eval()
192-
example_input = torch.randn(2, 32)
190+
example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
193191

194192
torch_output = model(example_input)
195193
iree_module = adapt_torch_module_to_iree(
@@ -200,7 +198,7 @@ def test_output_matches_torch(self):
200198
)
201199
iree_output = iree_module.forward(example_input)
202200

203-
torch.testing.assert_close(iree_output, torch_output, rtol=1e-4, atol=1e-4)
201+
assert torch.equal(iree_output, torch_output)
204202

205203
def test_multi_output_model(self):
206204
"""Test model with multiple outputs."""
@@ -224,10 +222,10 @@ def test_multi_output_model(self):
224222
class TestOneshotCompileAndRun:
225223
"""Tests for oneshot_iree_run."""
226224

227-
def test_basic_oneshot(self):
225+
def test_basic_oneshot(self, deterministic_random_seed):
228226
"""Test basic one-shot execution."""
229227
model = SimpleModel()
230-
example_input = torch.randn(2, 32)
228+
example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
231229

232230
result = oneshot_iree_run(
233231
model,
@@ -237,14 +235,12 @@ def test_basic_oneshot(self):
237235
)
238236
assert isinstance(result, torch.Tensor)
239237
assert result.shape == (2, 32)
240-
assert not torch.isnan(result).any()
241238

242-
def test_oneshot_matches_torch(self):
239+
def test_oneshot_matches_torch(self, deterministic_random_seed):
243240
"""Test that one-shot execution matches torch."""
244-
torch.manual_seed(42)
245241
model = SimpleModel()
246242
model.eval()
247-
example_input = torch.randn(2, 32)
243+
example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
248244

249245
torch_output = model(example_input)
250246
iree_output = oneshot_iree_run(
@@ -254,4 +250,5 @@ def test_oneshot_matches_torch(self):
254250
compile_args=COMPILE_FLAGS,
255251
)
256252

257-
torch.testing.assert_close(iree_output, torch_output, rtol=1e-4, atol=1e-4)
253+
# Use exact comparison for integer arithmetic
254+
assert torch.equal(iree_output, torch_output)

0 commit comments

Comments
 (0)