From 2cbb21ff895b662624a5006cf080301d9ffa7924 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 16 Oct 2025 00:48:29 +0000 Subject: [PATCH 1/6] Modulify v0 --- sharktank/sharktank/utils/inference_module.py | 117 ++++++ sharktank/sharktank/utils/iree.py | 68 ++++ .../sharktank/utils/iree_module_builder.py | 226 +++++++++++ .../tests/utils/iree_module_builder_test.py | 375 ++++++++++++++++++ 4 files changed, 786 insertions(+) create mode 100644 sharktank/sharktank/utils/inference_module.py create mode 100644 sharktank/sharktank/utils/iree_module_builder.py create mode 100644 sharktank/tests/utils/iree_module_builder_test.py diff --git a/sharktank/sharktank/utils/inference_module.py b/sharktank/sharktank/utils/inference_module.py new file mode 100644 index 00000000000..4545756077f --- /dev/null +++ b/sharktank/sharktank/utils/inference_module.py @@ -0,0 +1,117 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Protocol-based interface for unified inference modules (torch and IREE).""" + +from typing import Any, Protocol, runtime_checkable, Union, List +import torch +import iree.runtime + + +@runtime_checkable +class InferenceModule(Protocol): + """Protocol for inference modules (both torch and IREE). + + This defines a common interface that both torch.nn.Module and + TorchLikeIreeModule can satisfy, allowing them to be used + interchangeably in inference code. + + Example: + >>> def run_inference(model: InferenceModule, inputs): + ... return model(inputs) + >>> + >>> # Works with torch modules + >>> torch_model = MyTorchModel() + >>> run_inference(torch_model, x) + >>> + >>> # Also works with IREE modules + >>> iree_model = load_torch_module_as_iree(torch_model, ...) + >>> run_inference(iree_model, x) + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute the module's forward pass.""" + ... + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Execute the module's forward pass explicitly.""" + ... + + +def is_inference_module(obj: Any) -> bool: + """Check if an object satisfies the InferenceModule protocol. + + Args: + obj: Object to check + + Returns: + True if the object has both __call__ and forward methods + """ + return isinstance(obj, InferenceModule) + + +def get_inference_device( + module: InferenceModule, +) -> Union[torch.device, iree.runtime.HalDevice]: + """Get the device a module is running on. + + Args: + module: An inference module (torch or IREE) + + Returns: + The device the module is on (torch.device for torch modules, + iree.runtime.HalDevice for IREE modules) + + Example: + >>> torch_model = MyTorchModel().to("cuda") + >>> get_inference_device(torch_model) # torch.device('cuda:0') + >>> + >>> iree_model = load_torch_module_as_iree(torch_model, device="local-task") + >>> get_inference_device(iree_model) # + """ + if isinstance(module, torch.nn.Module): + # For torch modules, get device from first parameter + try: + return next(module.parameters()).device + except StopIteration: + # No parameters, assume CPU + return torch.device("cpu") + elif hasattr(module, "devices"): + # For IREE modules (TorchLikeIreeModule) + return module.devices[0] + else: + raise TypeError(f"Unknown module type: {type(module)}") + + +def inference_module_call( + module: InferenceModule, + *args: Any, + method: str = "forward", + **kwargs: Any, +) -> Any: + """Call a method on an inference module in a unified way. + + This handles the slight differences between torch and IREE modules + when calling methods other than forward(). + + Args: + module: The inference module to call + *args: Positional arguments + method: Name of the method to call (default: "forward") + **kwargs: Keyword arguments + + Returns: + The module's output + + Example: + >>> # Call forward + >>> output = inference_module_call(model, input_tensor) + >>> + >>> # Call a custom method + >>> output = inference_module_call(model, x, method="generate") + """ + method_fn = getattr(module, method) + return method_fn(*args, **kwargs) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index daaa11124c6..9c22e9d0a1d 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -753,3 +753,71 @@ def run_model_with_iree_run_module( input_args = [f"--input={arg.strip()}" for arg in input_args] cmd += input_args subprocess.check_call(cmd, **subprocess_run_kwargs) + + +class TypePreservingIreeModule(TorchLikeIreeModule): + """Extension of TorchLikeIreeModule that preserves output types. + + This addresses the limitation where output type information (ShardedTensor, + InferenceTensor, single vs tuple, etc.) is lost during IREE execution. + + You provide an output_type_mapper function that transforms the flat tuple + of tensors returned by IREE back into the original structure. + + Example: + >>> # Simple case: torch module returns single tensor, but IREE returns tuple + >>> def unwrap_single(outputs): + ... return outputs[0] + >>> + >>> iree_module = TypePreservingIreeModule( + ... vm_module, vm_context, devices, + ... output_type_mapper=unwrap_single + ... ) + >>> result = iree_module.forward(x) # Returns single tensor, not tuple + + >>> # Complex case: reconstruct ShardedTensor + >>> def reconstruct_sharded(outputs): + ... return SplitPrimitiveTensor(ts=outputs, shard_dim=1) + >>> + >>> iree_module = TypePreservingIreeModule( + ... vm_module, vm_context, devices, + ... output_type_mapper=reconstruct_sharded + ... ) + >>> result = iree_module.forward(x) # Returns ShardedTensor + """ + + def __init__( + self, + module: iree.runtime.VmModule, + vm_context: iree.runtime.VmContext, + devices: List[iree.runtime.HalDevice], + output_type_mapper: Callable[[tuple[torch.Tensor, ...]], Any], + ): + """Initialize with an output type mapper. + + Args: + module: IREE VmModule + vm_context: IREE VmContext + devices: List of IREE HalDevices + output_type_mapper: Function that transforms flat tensor tuple + back to the original output structure + """ + super().__init__(module, vm_context, devices) + self.output_type_mapper = output_type_mapper + + def __getattr__(self, name: str) -> Any: + """Override to apply output type mapping.""" + # Avoid recursion on our own attributes + if name == "output_type_mapper": + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + # Get the base method from parent + base_method = super().__getattr__(name) + + def wrapped_method(*args, **kwargs): + result = base_method(*args, **kwargs) + return self.output_type_mapper(result) + + return wrapped_method diff --git a/sharktank/sharktank/utils/iree_module_builder.py b/sharktank/sharktank/utils/iree_module_builder.py new file mode 100644 index 00000000000..9e9bd820678 --- /dev/null +++ b/sharktank/sharktank/utils/iree_module_builder.py @@ -0,0 +1,226 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""High-level utilities for converting torch modules to IREE modules.""" + +from typing import Any, Optional, Sequence +from pathlib import Path +import tempfile + +import torch +from iree.turbine import aot +from iree.turbine.aot import FxProgramsBuilder +import iree.runtime + +from .iree import ( + TorchLikeIreeModule, + TypePreservingIreeModule, + get_iree_devices, + load_iree_module, + with_iree_device_context, +) + + +def compile_torch_module_to_iree( + module: torch.nn.Module, + example_args: tuple[Any, ...] = tuple(), + example_kwargs: dict[str, Any] = None, + function_name: str = "forward", + compile_args: Sequence[str] = None, + save_mlir_to: Optional[Path] = None, + save_vmfb_to: Optional[Path] = None, +) -> memoryview: + """Compile a torch module to IREE VMFB bytecode. + + Args: + module: The torch.nn.Module to compile + example_args: Example positional arguments for tracing + example_kwargs: Example keyword arguments for tracing + function_name: Name of the function to export (default: "forward") + compile_args: Additional IREE compiler flags + save_mlir_to: Optional path to save the exported MLIR + save_vmfb_to: Optional path to save the compiled VMFB + + Returns: + A memoryview of the compiled VMFB bytecode + + Example: + >>> model = MyTorchModel() + >>> example_input = torch.randn(1, 3, 224, 224) + >>> vmfb_bytes = compile_torch_module_to_iree( + ... model, + ... example_args=(example_input,), + ... compile_args=["--iree-hal-target-device=local-task"] + ... ) + """ + if example_kwargs is None: + example_kwargs = {} + + # Export to MLIR using turbine + fxb = FxProgramsBuilder(module) + + @fxb.export_program( + name=function_name, args=example_args, kwargs=example_kwargs, strict=False + ) + def _(module, *args, **kwargs): + return getattr(module, function_name)(*args, **kwargs) + + export_output = aot.export(fxb) + + # Save MLIR if requested + if save_mlir_to is not None: + export_output.save_mlir(save_mlir_to) + + # Set compiler flags + if compile_args is not None: + export_output.session.set_flags(*compile_args) + + # Compile to VMFB + if save_vmfb_to is not None: + export_output.compile(save_to=str(save_vmfb_to), target_backends=None) + with open(save_vmfb_to, "rb") as f: + return memoryview(f.read()) + else: + return export_output.compile(save_to=None, target_backends=None).map_memory() + + +def load_torch_module_as_iree( + module: torch.nn.Module, + example_args: tuple[Any, ...] = tuple(), + example_kwargs: dict[str, Any] = None, + function_name: str = "forward", + device: str | list[str] = "local-task", + device_count: int | None = None, + compile_args: Sequence[str] = None, + parameters_path: Optional[str] = None, + save_mlir_to: Optional[Path] = None, + save_vmfb_to: Optional[Path] = None, + output_type_mapper: Optional[Any] = None, +) -> TorchLikeIreeModule: + """Compile a torch module to IREE and load it as a TorchLikeIreeModule. + + This is a high-level convenience function that combines export, compilation, + and loading into a single call. + + Args: + module: The torch.nn.Module to compile + example_args: Example positional arguments for tracing + example_kwargs: Example keyword arguments for tracing + function_name: Name of the function to export (default: "forward") + device: IREE device(s) to load on (e.g., "local-task", "hip://0") + device_count: Number of devices to use (for multi-device scenarios) + compile_args: Additional IREE compiler flags + parameters_path: Optional path to external parameters (IRPA file) + save_mlir_to: Optional path to save the exported MLIR + save_vmfb_to: Optional path to save the compiled VMFB + output_type_mapper: Optional function to transform IREE's flat tensor tuple + back to the original output structure. If provided, returns a + TypePreservingIreeModule (subclass of TorchLikeIreeModule). + + Returns: + A TorchLikeIreeModule that can be called like the original torch module. + If output_type_mapper is provided, returns TypePreservingIreeModule. + + Example: + >>> model = MyTorchModel() + >>> example_input = torch.randn(1, 3, 224, 224) + >>> iree_model = load_torch_module_as_iree( + ... model, + ... example_args=(example_input,), + ... device="local-task" + ... ) + >>> output = iree_model.forward(example_input) + >>> + >>> # With type preservation (returns single tensor instead of tuple) + >>> def unwrap(outputs): return outputs[0] + >>> iree_model = load_torch_module_as_iree( + ... model, + ... example_args=(example_input,), + ... output_type_mapper=unwrap + ... ) + >>> output = iree_model.forward(example_input) # Single tensor, not tuple + """ + # Compile the module + vmfb_bytes = compile_torch_module_to_iree( + module=module, + example_args=example_args, + example_kwargs=example_kwargs, + function_name=function_name, + compile_args=compile_args, + save_mlir_to=save_mlir_to, + save_vmfb_to=save_vmfb_to, + ) + + # Get devices + iree_devices = get_iree_devices(device=device, device_count=device_count) + + # Load the module + def load_fn(devices: list[iree.runtime.HalDevice]) -> TorchLikeIreeModule: + vm_module, vm_context, vm_instance = load_iree_module( + module_buff=vmfb_bytes, + devices=devices, + parameters_path=parameters_path, + tensor_parallel_size=len(devices), + ) + if output_type_mapper is not None: + return TypePreservingIreeModule( + vm_module, vm_context, devices, output_type_mapper + ) + else: + return TorchLikeIreeModule(vm_module, vm_context, devices) + + return with_iree_device_context(load_fn, iree_devices) + + +def oneshot_compile_and_run( + module: torch.nn.Module, + args: tuple[Any, ...] = tuple(), + kwargs: dict[str, Any] = None, + function: str = "forward", + device: str | list[str] = "local-task", + device_count: int | None = None, + compile_args: Sequence[str] = None, +) -> tuple[torch.Tensor, ...]: + """One-shot function: export, compile, load, and run in one call. + + This is useful for quick testing and benchmarking. For production use, + prefer load_torch_module_as_iree() to reuse the compiled module. + + Args: + module: The torch.nn.Module to run + args: Positional arguments to pass to the function + kwargs: Keyword arguments to pass to the function + function: Name of the function to call (default: "forward") + device: IREE device(s) to run on + device_count: Number of devices to use + compile_args: Additional IREE compiler flags + + Returns: + Tuple of output tensors + + Example: + >>> model = MyTorchModel() + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> outputs = oneshot_compile_and_run( + ... model, + ... args=(input_tensor,), + ... device="local-task" + ... ) + """ + if kwargs is None: + kwargs = {} + + iree_module = load_torch_module_as_iree( + module=module, + example_args=args, + example_kwargs=kwargs, + function_name=function, + device=device, + device_count=device_count, + compile_args=compile_args, + ) + + return getattr(iree_module, function)(*args, **kwargs) diff --git a/sharktank/tests/utils/iree_module_builder_test.py b/sharktank/tests/utils/iree_module_builder_test.py new file mode 100644 index 00000000000..0d02e695467 --- /dev/null +++ b/sharktank/tests/utils/iree_module_builder_test.py @@ -0,0 +1,375 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Tests for high-level IREE module builder utilities.""" + +import pytest +import torch +import torch.nn as nn + +from sharktank.utils.iree_module_builder import ( + compile_torch_module_to_iree, + load_torch_module_as_iree, + oneshot_compile_and_run, +) +from sharktank.utils.iree import TypePreservingIreeModule, TorchLikeIreeModule +from sharktank.types import SplitPrimitiveTensor + + +class SimpleModel(nn.Module): + """Simple test model.""" + + def __init__(self, hidden_size=64): + super().__init__() + self.fc1 = nn.Linear(32, hidden_size) + self.fc2 = nn.Linear(hidden_size, 10) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + return self.fc2(x) + + +class MultiOutputModel(nn.Module): + """Model that returns multiple outputs.""" + + def __init__(self): + super().__init__() + self.fc = nn.Linear(32, 64) + + def forward(self, x): + h = self.fc(x) + return torch.relu(h), torch.tanh(h) + + +class TestCompileTorchModule: + """Tests for compile_torch_module_to_iree.""" + + def test_basic_compilation(self): + """Test basic compilation without saving artifacts.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + vmfb_bytes = compile_torch_module_to_iree( + model, + example_args=(example_input,), + compile_args=["--iree-hal-target-device=local-task"], + ) + + assert isinstance(vmfb_bytes, memoryview) + assert len(vmfb_bytes) > 0 + + def test_compilation_with_save_mlir(self, tmp_path): + """Test compilation with MLIR saving.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + mlir_path = tmp_path / "model.mlir" + + vmfb_bytes = compile_torch_module_to_iree( + model, + example_args=(example_input,), + compile_args=["--iree-hal-target-device=local-task"], + save_mlir_to=mlir_path, + ) + + assert mlir_path.exists() + assert mlir_path.stat().st_size > 0 + assert isinstance(vmfb_bytes, memoryview) + + def test_compilation_with_save_vmfb(self, tmp_path): + """Test compilation with VMFB saving.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + vmfb_path = tmp_path / "model.vmfb" + + vmfb_bytes = compile_torch_module_to_iree( + model, + example_args=(example_input,), + compile_args=["--iree-hal-target-device=local-task"], + save_vmfb_to=vmfb_path, + ) + + assert vmfb_path.exists() + assert vmfb_path.stat().st_size > 0 + assert isinstance(vmfb_bytes, memoryview) + + def test_compilation_with_kwargs(self): + """Test compilation with keyword arguments.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + vmfb_bytes = compile_torch_module_to_iree( + model, + example_args=tuple(), + example_kwargs={"x": example_input}, + compile_args=["--iree-hal-target-device=local-task"], + ) + + assert isinstance(vmfb_bytes, memoryview) + assert len(vmfb_bytes) > 0 + + +class TestLoadTorchModuleAsIree: + """Tests for load_torch_module_as_iree.""" + + def test_basic_loading_and_execution(self): + """Test that loaded module executes and produces correct output shape.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + ) + + result = iree_module.forward(example_input) + # Verify output structure + assert isinstance(result, tuple) + assert len(result) == 1 + assert result[0].shape == (2, 10) + # Verify it's actually a tensor with values + assert not torch.isnan(result[0]).any() + + def test_output_matches_torch(self): + """Test that IREE output matches torch output.""" + torch.manual_seed(42) + model = SimpleModel() + model.eval() + example_input = torch.randn(2, 32) + + # Get torch output + with torch.no_grad(): + torch_output = model(example_input) + + # Get IREE output + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + ) + iree_output = iree_module.forward(example_input) + + # Compare + torch.testing.assert_close(iree_output[0], torch_output, rtol=1e-4, atol=1e-4) + + def test_multi_output_model(self): + """Test model with multiple outputs.""" + model = MultiOutputModel() + example_input = torch.randn(2, 32) + + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + ) + + result = iree_module.forward(example_input) + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0].shape == (2, 64) + assert result[1].shape == (2, 64) + + +class TestTypePreservingIreeModule: + """Tests for TypePreservingIreeModule and output_type_mapper.""" + + def test_output_type_mapper_changes_return_structure(self): + """Test that output_type_mapper successfully transforms return type.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + # Without type mapper - returns tuple + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + ) + result_tuple = iree_module.forward(example_input) + + # With type mapper - returns unwrapped single tensor + def unwrap(outputs): + return outputs[0] + + iree_module_unwrapped = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + output_type_mapper=unwrap, + ) + result_single = iree_module_unwrapped.forward(example_input) + + # Verify the transformation worked + assert isinstance(result_tuple, tuple) + assert isinstance(result_single, torch.Tensor) + assert not isinstance(result_single, tuple) + torch.testing.assert_close(result_single, result_tuple[0]) + + def test_reconstruct_sharded_tensor(self): + """Test reconstructing a ShardedTensor-like output.""" + model = MultiOutputModel() + example_input = torch.randn(2, 32) + + # Simulate sharded output reconstruction + def reconstruct_sharded(outputs): + # Treat the two outputs as shards + return SplitPrimitiveTensor(ts=outputs, shard_dim=1) + + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + output_type_mapper=reconstruct_sharded, + ) + + result = iree_module.forward(example_input) + assert isinstance(result, SplitPrimitiveTensor) + assert result.shard_dim == 1 + assert len(result.shards) == 2 + + def test_custom_transformation(self): + """Test custom output transformation.""" + model = MultiOutputModel() + example_input = torch.randn(2, 32) + + # Custom transformer: return dict + def to_dict(outputs): + return {"relu": outputs[0], "tanh": outputs[1]} + + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + output_type_mapper=to_dict, + ) + + result = iree_module.forward(example_input) + assert isinstance(result, dict) + assert "relu" in result + assert "tanh" in result + assert result["relu"].shape == (2, 64) + assert result["tanh"].shape == (2, 64) + + +class TestOneshotCompileAndRun: + """Tests for oneshot_compile_and_run.""" + + def test_basic_oneshot(self): + """Test basic one-shot execution.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + result = oneshot_compile_and_run( + model, + args=(example_input,), + device="local-task", + compile_args=("--iree-hal-target-device=local-task",), + ) + + assert isinstance(result, tuple) + assert len(result) == 1 + assert result[0].shape == (2, 10) + + def test_oneshot_matches_torch(self): + """Test that one-shot execution matches torch.""" + torch.manual_seed(42) + model = SimpleModel() + model.eval() + example_input = torch.randn(2, 32) + + # Torch output + with torch.no_grad(): + torch_output = model(example_input) + + # IREE output + iree_output = oneshot_compile_and_run( + model, + args=(example_input,), + device="local-task", + compile_args=("--iree-hal-target-device=local-task",), + ) + + torch.testing.assert_close(iree_output[0], torch_output, rtol=1e-4, atol=1e-4) + + def test_oneshot_with_kwargs(self): + """Test one-shot with keyword arguments.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + result = oneshot_compile_and_run( + model, + args=tuple(), + kwargs={"x": example_input}, + device="local-task", + compile_args=("--iree-hal-target-device=local-task",), + ) + + assert isinstance(result, tuple) + assert len(result) == 1 + assert result[0].shape == (2, 10) + + +class TestInferenceModuleProtocol: + """Tests for Protocol-based usage.""" + + def test_protocol_compatibility(self): + """Test that both torch and IREE modules work with generic inference code.""" + from sharktank.utils.inference_module import InferenceModule + + def run_model(model: InferenceModule, input_data): + """Generic function that works with any InferenceModule.""" + return model.forward(input_data) + + torch_model = SimpleModel() + example_input = torch.randn(2, 32) + + # Works with torch module + torch_model.eval() + with torch.no_grad(): + torch_result = run_model(torch_model, example_input) + assert torch_result.shape == (2, 10) + + # Works with IREE module + iree_model = load_torch_module_as_iree( + torch_model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + output_type_mapper=lambda x: x[0], # Unwrap for compatibility + ) + iree_result = run_model(iree_model, example_input) + assert iree_result.shape == (2, 10) + + # Results should be close + torch.testing.assert_close(iree_result, torch_result, rtol=1e-4, atol=1e-4) + + def test_call_vs_forward_equivalence(self): + """Test that calling via __call__ produces same result as forward().""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + iree_model = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-task", + compile_args=["--iree-hal-target-device=local-task"], + output_type_mapper=lambda x: x[0], + ) + + result1 = iree_model(example_input) + result2 = iree_model.forward(example_input) + + torch.testing.assert_close(result1, result2) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From b654349cc725a9b538d1a678b697a8f3700dac9b Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 16 Oct 2025 01:36:18 +0000 Subject: [PATCH 2/6] Fix tests --- sharktank/sharktank/utils/iree.py | 4 + .../tests/utils/iree_module_builder_test.py | 132 +++++++++++++----- 2 files changed, 98 insertions(+), 38 deletions(-) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index 9c22e9d0a1d..ba9c01ab255 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -127,6 +127,10 @@ def __init__( self.vm_context = vm_context self.devices = devices + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute forward method (torch.nn.Module compatibility).""" + return self.forward(*args, **kwargs) + def __getattr__(self, name: str) -> Any: def f( *args: tuple[Any, ...], **kwargs: dict[str, Any] diff --git a/sharktank/tests/utils/iree_module_builder_test.py b/sharktank/tests/utils/iree_module_builder_test.py index 0d02e695467..165761fbdd8 100644 --- a/sharktank/tests/utils/iree_module_builder_test.py +++ b/sharktank/tests/utils/iree_module_builder_test.py @@ -4,7 +4,14 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Tests for high-level IREE module builder utilities.""" +"""Tests for high-level IREE module builder utilities. + +These are integration tests that compile and run IREE modules. +They can be slow, so they're marked with @pytest.mark.iree_integration. + +Run with: pytest -m iree_integration +Skip with: pytest -m "not iree_integration" +""" import pytest import torch @@ -18,6 +25,8 @@ from sharktank.utils.iree import TypePreservingIreeModule, TorchLikeIreeModule from sharktank.types import SplitPrimitiveTensor +pytestmark = pytest.mark.iree_integration + class SimpleModel(nn.Module): """Simple test model.""" @@ -55,10 +64,13 @@ def test_basic_compilation(self): vmfb_bytes = compile_torch_module_to_iree( model, example_args=(example_input,), - compile_args=["--iree-hal-target-device=local-task"], + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], ) - assert isinstance(vmfb_bytes, memoryview) + # Verify we got bytecode assert len(vmfb_bytes) > 0 def test_compilation_with_save_mlir(self, tmp_path): @@ -70,13 +82,16 @@ def test_compilation_with_save_mlir(self, tmp_path): vmfb_bytes = compile_torch_module_to_iree( model, example_args=(example_input,), - compile_args=["--iree-hal-target-device=local-task"], + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], save_mlir_to=mlir_path, ) assert mlir_path.exists() assert mlir_path.stat().st_size > 0 - assert isinstance(vmfb_bytes, memoryview) + assert len(vmfb_bytes) > 0 def test_compilation_with_save_vmfb(self, tmp_path): """Test compilation with VMFB saving.""" @@ -87,13 +102,16 @@ def test_compilation_with_save_vmfb(self, tmp_path): vmfb_bytes = compile_torch_module_to_iree( model, example_args=(example_input,), - compile_args=["--iree-hal-target-device=local-task"], + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], save_vmfb_to=vmfb_path, ) assert vmfb_path.exists() assert vmfb_path.stat().st_size > 0 - assert isinstance(vmfb_bytes, memoryview) + assert len(vmfb_bytes) > 0 def test_compilation_with_kwargs(self): """Test compilation with keyword arguments.""" @@ -104,10 +122,12 @@ def test_compilation_with_kwargs(self): model, example_args=tuple(), example_kwargs={"x": example_input}, - compile_args=["--iree-hal-target-device=local-task"], + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], ) - assert isinstance(vmfb_bytes, memoryview) assert len(vmfb_bytes) > 0 @@ -122,13 +142,16 @@ def test_basic_loading_and_execution(self): iree_module = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], ) result = iree_module.forward(example_input) # Verify output structure - assert isinstance(result, tuple) + assert isinstance(result, list) assert len(result) == 1 assert result[0].shape == (2, 10) # Verify it's actually a tensor with values @@ -149,8 +172,11 @@ def test_output_matches_torch(self): iree_module = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], ) iree_output = iree_module.forward(example_input) @@ -165,12 +191,15 @@ def test_multi_output_model(self): iree_module = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], ) result = iree_module.forward(example_input) - assert isinstance(result, tuple) + assert isinstance(result, list) assert len(result) == 2 assert result[0].shape == (2, 64) assert result[1].shape == (2, 64) @@ -188,8 +217,11 @@ def test_output_type_mapper_changes_return_structure(self): iree_module = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], ) result_tuple = iree_module.forward(example_input) @@ -200,14 +232,17 @@ def unwrap(outputs): iree_module_unwrapped = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], output_type_mapper=unwrap, ) result_single = iree_module_unwrapped.forward(example_input) # Verify the transformation worked - assert isinstance(result_tuple, tuple) + assert isinstance(result_tuple, list) assert isinstance(result_single, torch.Tensor) assert not isinstance(result_single, tuple) torch.testing.assert_close(result_single, result_tuple[0]) @@ -225,8 +260,11 @@ def reconstruct_sharded(outputs): iree_module = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], output_type_mapper=reconstruct_sharded, ) @@ -247,8 +285,11 @@ def to_dict(outputs): iree_module = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], output_type_mapper=to_dict, ) @@ -271,11 +312,14 @@ def test_basic_oneshot(self): result = oneshot_compile_and_run( model, args=(example_input,), - device="local-task", - compile_args=("--iree-hal-target-device=local-task",), + device="local-sync", + compile_args=( + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ), ) - assert isinstance(result, tuple) + assert isinstance(result, list) assert len(result) == 1 assert result[0].shape == (2, 10) @@ -294,8 +338,11 @@ def test_oneshot_matches_torch(self): iree_output = oneshot_compile_and_run( model, args=(example_input,), - device="local-task", - compile_args=("--iree-hal-target-device=local-task",), + device="local-sync", + compile_args=( + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ), ) torch.testing.assert_close(iree_output[0], torch_output, rtol=1e-4, atol=1e-4) @@ -309,11 +356,14 @@ def test_oneshot_with_kwargs(self): model, args=tuple(), kwargs={"x": example_input}, - device="local-task", - compile_args=("--iree-hal-target-device=local-task",), + device="local-sync", + compile_args=( + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ), ) - assert isinstance(result, tuple) + assert isinstance(result, list) assert len(result) == 1 assert result[0].shape == (2, 10) @@ -342,8 +392,11 @@ def run_model(model: InferenceModule, input_data): iree_model = load_torch_module_as_iree( torch_model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], output_type_mapper=lambda x: x[0], # Unwrap for compatibility ) iree_result = run_model(iree_model, example_input) @@ -360,8 +413,11 @@ def test_call_vs_forward_equivalence(self): iree_model = load_torch_module_as_iree( model, example_args=(example_input,), - device="local-task", - compile_args=["--iree-hal-target-device=local-task"], + device="local-sync", + compile_args=[ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + ], output_type_mapper=lambda x: x[0], ) From 5b6fccc59638484d6606a4ce0e38f774833d3d79 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 16 Oct 2025 06:43:07 +0000 Subject: [PATCH 3/6] Cleanup --- sharktank/sharktank/utils/inference_module.py | 117 ----- sharktank/sharktank/utils/iree.py | 317 +++++++++---- .../sharktank/utils/iree_module_builder.py | 226 --------- .../tests/utils/iree_module_builder_test.py | 431 ------------------ sharktank/tests/utils/iree_test.py | 140 ++++++ 5 files changed, 365 insertions(+), 866 deletions(-) delete mode 100644 sharktank/sharktank/utils/inference_module.py delete mode 100644 sharktank/sharktank/utils/iree_module_builder.py delete mode 100644 sharktank/tests/utils/iree_module_builder_test.py diff --git a/sharktank/sharktank/utils/inference_module.py b/sharktank/sharktank/utils/inference_module.py deleted file mode 100644 index 4545756077f..00000000000 --- a/sharktank/sharktank/utils/inference_module.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Protocol-based interface for unified inference modules (torch and IREE).""" - -from typing import Any, Protocol, runtime_checkable, Union, List -import torch -import iree.runtime - - -@runtime_checkable -class InferenceModule(Protocol): - """Protocol for inference modules (both torch and IREE). - - This defines a common interface that both torch.nn.Module and - TorchLikeIreeModule can satisfy, allowing them to be used - interchangeably in inference code. - - Example: - >>> def run_inference(model: InferenceModule, inputs): - ... return model(inputs) - >>> - >>> # Works with torch modules - >>> torch_model = MyTorchModel() - >>> run_inference(torch_model, x) - >>> - >>> # Also works with IREE modules - >>> iree_model = load_torch_module_as_iree(torch_model, ...) - >>> run_inference(iree_model, x) - """ - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - """Execute the module's forward pass.""" - ... - - def forward(self, *args: Any, **kwargs: Any) -> Any: - """Execute the module's forward pass explicitly.""" - ... - - -def is_inference_module(obj: Any) -> bool: - """Check if an object satisfies the InferenceModule protocol. - - Args: - obj: Object to check - - Returns: - True if the object has both __call__ and forward methods - """ - return isinstance(obj, InferenceModule) - - -def get_inference_device( - module: InferenceModule, -) -> Union[torch.device, iree.runtime.HalDevice]: - """Get the device a module is running on. - - Args: - module: An inference module (torch or IREE) - - Returns: - The device the module is on (torch.device for torch modules, - iree.runtime.HalDevice for IREE modules) - - Example: - >>> torch_model = MyTorchModel().to("cuda") - >>> get_inference_device(torch_model) # torch.device('cuda:0') - >>> - >>> iree_model = load_torch_module_as_iree(torch_model, device="local-task") - >>> get_inference_device(iree_model) # - """ - if isinstance(module, torch.nn.Module): - # For torch modules, get device from first parameter - try: - return next(module.parameters()).device - except StopIteration: - # No parameters, assume CPU - return torch.device("cpu") - elif hasattr(module, "devices"): - # For IREE modules (TorchLikeIreeModule) - return module.devices[0] - else: - raise TypeError(f"Unknown module type: {type(module)}") - - -def inference_module_call( - module: InferenceModule, - *args: Any, - method: str = "forward", - **kwargs: Any, -) -> Any: - """Call a method on an inference module in a unified way. - - This handles the slight differences between torch and IREE modules - when calling methods other than forward(). - - Args: - module: The inference module to call - *args: Positional arguments - method: Name of the method to call (default: "forward") - **kwargs: Keyword arguments - - Returns: - The module's output - - Example: - >>> # Call forward - >>> output = inference_module_call(model, input_tensor) - >>> - >>> # Call a custom method - >>> output = inference_module_call(model, x, method="generate") - """ - method_fn = getattr(module, method) - return method_fn(*args, **kwargs) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index ba9c01ab255..f98a8117577 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -4,7 +4,19 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Callable, List, Tuple, Optional, Union, overload, TYPE_CHECKING +from typing import ( + Any, + Callable, + List, + Tuple, + Optional, + Union, + overload, + TYPE_CHECKING, + Sequence, + Protocol, + runtime_checkable, +) import os import sys import json @@ -72,39 +84,221 @@ def oneshot_iree_run( device_count: int | None = None, compile_args: tuple[str, ...] = None, ) -> tuple[torch.Tensor, ...]: - """All in one: export, compile and run.""" + """One-shot function: export, compile, load, and run in one call. + This is useful for quick testing and benchmarking. For repeated use, + prefer load_torch_module_as_iree() to reuse the compiled module. + Args: + module: The torch.nn.Module to run + args: Positional arguments to pass to the function + kwargs: Keyword arguments to pass to the function + function: Name of the function to call (default: "forward") + device: IREE device(s) to run on + device_count: Number of devices to use + compile_args: Additional IREE compiler flags + Returns: + Tensor, or tuple of output tensors + Example: + >>> model = MyTorchModel() + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> outputs = oneshot_compile_and_run( + ... model, + ... args=(input_tensor,), + ... device="local-task" + ... ) + """ + vmfb_bytes = compile_torch_module_to_iree( + module=module, + example_args=args, + example_kwargs=kwargs, + function_name=function, + compile_args=compile_args, + ) + + # Get devices + iree_devices = get_iree_devices(device=device, device_count=device_count) + + def run(iree_devices: list[iree.runtime.HalDevice]): + vm_module, vm_context, vm_instance = load_iree_module( + module_buff=vmfb_bytes, devices=iree_devices + ) + torch_like_iree_module = TorchLikeIreeModule( + vm_module, vm_context, iree_devices + ) + results = getattr(torch_like_iree_module, function)(*args, **kwargs) + # Clone to avoid leaking IREE-backed torch tensors. + # Results can be a single tensor or tuple of tensors + if isinstance(results, torch.Tensor): + return results.clone() + else: + return tuple(t.clone() for t in results) + + return with_iree_device_context(run, iree_devices) + + +def compile_torch_module_to_iree( + module: torch.nn.Module, + example_args: tuple[Any, ...] = tuple(), + example_kwargs: dict[str, Any] = None, + function_name: str = "forward", + compile_args: Sequence[str] = None, + save_mlir_to: Optional[Path] = None, + save_vmfb_to: Optional[Path] = None, +) -> memoryview: + """Compile a torch module to IREE VMFB bytecode. + + Args: + module: The torch.nn.Module to compile + example_args: Example positional arguments for tracing + example_kwargs: Example keyword arguments for tracing + function_name: Name of the function to export (default: "forward") + compile_args: Additional IREE compiler flags + save_mlir_to: Optional path to save the exported MLIR + save_vmfb_to: Optional path to save the compiled VMFB + + Returns: + A memoryview of the compiled VMFB bytecode + + Example: + >>> model = MyTorchModel() + >>> example_input = torch.randn(1, 3, 224, 224) + >>> vmfb_bytes = compile_torch_module_to_iree( + ... model, + ... example_args=(example_input,), + ... compile_args=["--iree-hal-target-device=local-task"] + ... ) + """ from iree.turbine import aot from iree.turbine.aot import FxProgramsBuilder + if example_kwargs is None: + example_kwargs = {} + + # Export to MLIR using turbine fxb = FxProgramsBuilder(module) - @fxb.export_program(name=function, args=args, kwargs=kwargs, strict=False) + @fxb.export_program( + name=function_name, args=example_args, kwargs=example_kwargs, strict=False + ) def _(module, *args, **kwargs): - return getattr(module, function)(*args, **kwargs) + return getattr(module, function_name)(*args, **kwargs) - export_output = aot.export( - fxb, - ) + export_output = aot.export(fxb) + + # Save MLIR if requested + if save_mlir_to is not None: + export_output.save_mlir(save_mlir_to) + + # Set compiler flags if compile_args is not None: export_output.session.set_flags(*compile_args) - memory_view: memoryview = export_output.compile( - save_to=None, target_backends=None - ).map_memory() + + # Compile to VMFB + if save_vmfb_to is not None: + export_output.compile(save_to=str(save_vmfb_to), target_backends=None) + with open(save_vmfb_to, "rb") as f: + return memoryview(f.read()) + else: + return export_output.compile(save_to=None, target_backends=None).map_memory() + + +def load_torch_module_as_iree( + module: torch.nn.Module, + example_args: tuple[Any, ...] = tuple(), + example_kwargs: dict[str, Any] = None, + function_name: str = "forward", + device: str | list[str] = "local-task", + device_count: int | None = None, + compile_args: Sequence[str] = None, + parameters_path: Optional[str] = None, + save_mlir_to: Optional[Path] = None, + save_vmfb_to: Optional[Path] = None, +) -> "TorchLikeIreeModule": + """Compile a torch module to IREE and load it as a TorchLikeIreeModule. + + This is a high-level convenience function that combines export, compilation, + and loading into a single call. + + Args: + module: The torch.nn.Module to compile + example_args: Example positional arguments for tracing + example_kwargs: Example keyword arguments for tracing + function_name: Name of the function to export (default: "forward") + device: IREE device(s) to load on (e.g., "local-task", "hip://0") + device_count: Number of devices to use (for multi-device scenarios) + compile_args: Additional IREE compiler flags + parameters_path: Optional path to external parameters (IRPA file) + save_mlir_to: Optional path to save the exported MLIR + save_vmfb_to: Optional path to save the compiled VMFB + + Returns: + A TorchLikeIreeModule that can be called like the original torch module. + Single outputs are unwrapped, multiple outputs are returned as tuples. + + Example: + >>> model = MyTorchModel() + >>> example_input = torch.randn(1, 3, 224, 224) + >>> iree_model = load_torch_module_as_iree( + ... model, + ... example_args=(example_input,), + ... device="local-task" + ... ) + >>> output = iree_model.forward(example_input) # Single tensor, not list/tuple + """ + # Compile the module + vmfb_bytes = compile_torch_module_to_iree( + module=module, + example_args=example_args, + example_kwargs=example_kwargs, + function_name=function_name, + compile_args=compile_args, + save_mlir_to=save_mlir_to, + save_vmfb_to=save_vmfb_to, + ) + + # Get devices iree_devices = get_iree_devices(device=device, device_count=device_count) - def run(iree_devices: list[iree.runtime.HalDevice]): + # Load the module + def load_fn(devices: list[iree.runtime.HalDevice]) -> TorchLikeIreeModule: vm_module, vm_context, vm_instance = load_iree_module( - module_buff=memory_view, devices=iree_devices - ) - torch_like_iree_module = TorchLikeIreeModule( - vm_module, vm_context, iree_devices + module_buff=vmfb_bytes, + devices=devices, + parameters_path=parameters_path, + tensor_parallel_size=len(devices), ) - results = getattr(torch_like_iree_module, function)(*args, **kwargs) - # Clone to avoid leaking IREE-backed torch tensors. - results = tuple(t.clone() for t in results) - return results + return TorchLikeIreeModule(vm_module, vm_context, devices) - return with_iree_device_context(run, iree_devices) + return with_iree_device_context(load_fn, iree_devices) + + +@runtime_checkable +class InferenceModule(Protocol): + """Protocol for inference modules (both torch and IREE). + + This defines a common interface that both torch.nn.Module and + TorchLikeIreeModule can satisfy, allowing them to be used + interchangeably in inference code. + + Example: + >>> def run_inference(model: InferenceModule, inputs): + ... return model(inputs) + >>> + >>> # Works with torch modules + >>> torch_model = MyTorchModel() + >>> run_inference(torch_model, x) + >>> + >>> # Also works with IREE modules + >>> iree_model = load_torch_module_as_iree(torch_model, ...) + >>> run_inference(iree_model, x) + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute the module's forward pass.""" + ... + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Execute the module's forward pass explicitly.""" + ... class TorchLikeIreeModule: @@ -113,8 +307,11 @@ class TorchLikeIreeModule: This handles marshaling of torch tensor and sharktank.type.InferenceTensor arguments. Unfortunately, we can't marshall the output back to the correct tensor types as - some of the information is lost. E.g. the sharded tensor types. We return a flat - list of torch tensors. + some of the information is lost. E.g. the sharded tensor types. + + Returns: + - Single output: Returns the tensor directly + - Multiple outputs: Returns a tuple of tensors """ def __init__( @@ -134,7 +331,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: def __getattr__(self, name: str) -> Any: def f( *args: tuple[Any, ...], **kwargs: dict[str, Any] - ) -> tuple[torch.Tensor, ...]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: flat_args = flatten_for_iree_signature( ( args, @@ -156,7 +353,11 @@ def f( for arg, iree_arg in zip(flat_args, iree_args_post_call): arg[...] = iree_arg - return res + # Match torch.nn.Module behavior: single output unwrapped, multiple as tuple + if len(res) == 1: + return res[0] + else: + return tuple(res) return f @@ -757,71 +958,3 @@ def run_model_with_iree_run_module( input_args = [f"--input={arg.strip()}" for arg in input_args] cmd += input_args subprocess.check_call(cmd, **subprocess_run_kwargs) - - -class TypePreservingIreeModule(TorchLikeIreeModule): - """Extension of TorchLikeIreeModule that preserves output types. - - This addresses the limitation where output type information (ShardedTensor, - InferenceTensor, single vs tuple, etc.) is lost during IREE execution. - - You provide an output_type_mapper function that transforms the flat tuple - of tensors returned by IREE back into the original structure. - - Example: - >>> # Simple case: torch module returns single tensor, but IREE returns tuple - >>> def unwrap_single(outputs): - ... return outputs[0] - >>> - >>> iree_module = TypePreservingIreeModule( - ... vm_module, vm_context, devices, - ... output_type_mapper=unwrap_single - ... ) - >>> result = iree_module.forward(x) # Returns single tensor, not tuple - - >>> # Complex case: reconstruct ShardedTensor - >>> def reconstruct_sharded(outputs): - ... return SplitPrimitiveTensor(ts=outputs, shard_dim=1) - >>> - >>> iree_module = TypePreservingIreeModule( - ... vm_module, vm_context, devices, - ... output_type_mapper=reconstruct_sharded - ... ) - >>> result = iree_module.forward(x) # Returns ShardedTensor - """ - - def __init__( - self, - module: iree.runtime.VmModule, - vm_context: iree.runtime.VmContext, - devices: List[iree.runtime.HalDevice], - output_type_mapper: Callable[[tuple[torch.Tensor, ...]], Any], - ): - """Initialize with an output type mapper. - - Args: - module: IREE VmModule - vm_context: IREE VmContext - devices: List of IREE HalDevices - output_type_mapper: Function that transforms flat tensor tuple - back to the original output structure - """ - super().__init__(module, vm_context, devices) - self.output_type_mapper = output_type_mapper - - def __getattr__(self, name: str) -> Any: - """Override to apply output type mapping.""" - # Avoid recursion on our own attributes - if name == "output_type_mapper": - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) - - # Get the base method from parent - base_method = super().__getattr__(name) - - def wrapped_method(*args, **kwargs): - result = base_method(*args, **kwargs) - return self.output_type_mapper(result) - - return wrapped_method diff --git a/sharktank/sharktank/utils/iree_module_builder.py b/sharktank/sharktank/utils/iree_module_builder.py deleted file mode 100644 index 9e9bd820678..00000000000 --- a/sharktank/sharktank/utils/iree_module_builder.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""High-level utilities for converting torch modules to IREE modules.""" - -from typing import Any, Optional, Sequence -from pathlib import Path -import tempfile - -import torch -from iree.turbine import aot -from iree.turbine.aot import FxProgramsBuilder -import iree.runtime - -from .iree import ( - TorchLikeIreeModule, - TypePreservingIreeModule, - get_iree_devices, - load_iree_module, - with_iree_device_context, -) - - -def compile_torch_module_to_iree( - module: torch.nn.Module, - example_args: tuple[Any, ...] = tuple(), - example_kwargs: dict[str, Any] = None, - function_name: str = "forward", - compile_args: Sequence[str] = None, - save_mlir_to: Optional[Path] = None, - save_vmfb_to: Optional[Path] = None, -) -> memoryview: - """Compile a torch module to IREE VMFB bytecode. - - Args: - module: The torch.nn.Module to compile - example_args: Example positional arguments for tracing - example_kwargs: Example keyword arguments for tracing - function_name: Name of the function to export (default: "forward") - compile_args: Additional IREE compiler flags - save_mlir_to: Optional path to save the exported MLIR - save_vmfb_to: Optional path to save the compiled VMFB - - Returns: - A memoryview of the compiled VMFB bytecode - - Example: - >>> model = MyTorchModel() - >>> example_input = torch.randn(1, 3, 224, 224) - >>> vmfb_bytes = compile_torch_module_to_iree( - ... model, - ... example_args=(example_input,), - ... compile_args=["--iree-hal-target-device=local-task"] - ... ) - """ - if example_kwargs is None: - example_kwargs = {} - - # Export to MLIR using turbine - fxb = FxProgramsBuilder(module) - - @fxb.export_program( - name=function_name, args=example_args, kwargs=example_kwargs, strict=False - ) - def _(module, *args, **kwargs): - return getattr(module, function_name)(*args, **kwargs) - - export_output = aot.export(fxb) - - # Save MLIR if requested - if save_mlir_to is not None: - export_output.save_mlir(save_mlir_to) - - # Set compiler flags - if compile_args is not None: - export_output.session.set_flags(*compile_args) - - # Compile to VMFB - if save_vmfb_to is not None: - export_output.compile(save_to=str(save_vmfb_to), target_backends=None) - with open(save_vmfb_to, "rb") as f: - return memoryview(f.read()) - else: - return export_output.compile(save_to=None, target_backends=None).map_memory() - - -def load_torch_module_as_iree( - module: torch.nn.Module, - example_args: tuple[Any, ...] = tuple(), - example_kwargs: dict[str, Any] = None, - function_name: str = "forward", - device: str | list[str] = "local-task", - device_count: int | None = None, - compile_args: Sequence[str] = None, - parameters_path: Optional[str] = None, - save_mlir_to: Optional[Path] = None, - save_vmfb_to: Optional[Path] = None, - output_type_mapper: Optional[Any] = None, -) -> TorchLikeIreeModule: - """Compile a torch module to IREE and load it as a TorchLikeIreeModule. - - This is a high-level convenience function that combines export, compilation, - and loading into a single call. - - Args: - module: The torch.nn.Module to compile - example_args: Example positional arguments for tracing - example_kwargs: Example keyword arguments for tracing - function_name: Name of the function to export (default: "forward") - device: IREE device(s) to load on (e.g., "local-task", "hip://0") - device_count: Number of devices to use (for multi-device scenarios) - compile_args: Additional IREE compiler flags - parameters_path: Optional path to external parameters (IRPA file) - save_mlir_to: Optional path to save the exported MLIR - save_vmfb_to: Optional path to save the compiled VMFB - output_type_mapper: Optional function to transform IREE's flat tensor tuple - back to the original output structure. If provided, returns a - TypePreservingIreeModule (subclass of TorchLikeIreeModule). - - Returns: - A TorchLikeIreeModule that can be called like the original torch module. - If output_type_mapper is provided, returns TypePreservingIreeModule. - - Example: - >>> model = MyTorchModel() - >>> example_input = torch.randn(1, 3, 224, 224) - >>> iree_model = load_torch_module_as_iree( - ... model, - ... example_args=(example_input,), - ... device="local-task" - ... ) - >>> output = iree_model.forward(example_input) - >>> - >>> # With type preservation (returns single tensor instead of tuple) - >>> def unwrap(outputs): return outputs[0] - >>> iree_model = load_torch_module_as_iree( - ... model, - ... example_args=(example_input,), - ... output_type_mapper=unwrap - ... ) - >>> output = iree_model.forward(example_input) # Single tensor, not tuple - """ - # Compile the module - vmfb_bytes = compile_torch_module_to_iree( - module=module, - example_args=example_args, - example_kwargs=example_kwargs, - function_name=function_name, - compile_args=compile_args, - save_mlir_to=save_mlir_to, - save_vmfb_to=save_vmfb_to, - ) - - # Get devices - iree_devices = get_iree_devices(device=device, device_count=device_count) - - # Load the module - def load_fn(devices: list[iree.runtime.HalDevice]) -> TorchLikeIreeModule: - vm_module, vm_context, vm_instance = load_iree_module( - module_buff=vmfb_bytes, - devices=devices, - parameters_path=parameters_path, - tensor_parallel_size=len(devices), - ) - if output_type_mapper is not None: - return TypePreservingIreeModule( - vm_module, vm_context, devices, output_type_mapper - ) - else: - return TorchLikeIreeModule(vm_module, vm_context, devices) - - return with_iree_device_context(load_fn, iree_devices) - - -def oneshot_compile_and_run( - module: torch.nn.Module, - args: tuple[Any, ...] = tuple(), - kwargs: dict[str, Any] = None, - function: str = "forward", - device: str | list[str] = "local-task", - device_count: int | None = None, - compile_args: Sequence[str] = None, -) -> tuple[torch.Tensor, ...]: - """One-shot function: export, compile, load, and run in one call. - - This is useful for quick testing and benchmarking. For production use, - prefer load_torch_module_as_iree() to reuse the compiled module. - - Args: - module: The torch.nn.Module to run - args: Positional arguments to pass to the function - kwargs: Keyword arguments to pass to the function - function: Name of the function to call (default: "forward") - device: IREE device(s) to run on - device_count: Number of devices to use - compile_args: Additional IREE compiler flags - - Returns: - Tuple of output tensors - - Example: - >>> model = MyTorchModel() - >>> input_tensor = torch.randn(1, 3, 224, 224) - >>> outputs = oneshot_compile_and_run( - ... model, - ... args=(input_tensor,), - ... device="local-task" - ... ) - """ - if kwargs is None: - kwargs = {} - - iree_module = load_torch_module_as_iree( - module=module, - example_args=args, - example_kwargs=kwargs, - function_name=function, - device=device, - device_count=device_count, - compile_args=compile_args, - ) - - return getattr(iree_module, function)(*args, **kwargs) diff --git a/sharktank/tests/utils/iree_module_builder_test.py b/sharktank/tests/utils/iree_module_builder_test.py deleted file mode 100644 index 165761fbdd8..00000000000 --- a/sharktank/tests/utils/iree_module_builder_test.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Tests for high-level IREE module builder utilities. - -These are integration tests that compile and run IREE modules. -They can be slow, so they're marked with @pytest.mark.iree_integration. - -Run with: pytest -m iree_integration -Skip with: pytest -m "not iree_integration" -""" - -import pytest -import torch -import torch.nn as nn - -from sharktank.utils.iree_module_builder import ( - compile_torch_module_to_iree, - load_torch_module_as_iree, - oneshot_compile_and_run, -) -from sharktank.utils.iree import TypePreservingIreeModule, TorchLikeIreeModule -from sharktank.types import SplitPrimitiveTensor - -pytestmark = pytest.mark.iree_integration - - -class SimpleModel(nn.Module): - """Simple test model.""" - - def __init__(self, hidden_size=64): - super().__init__() - self.fc1 = nn.Linear(32, hidden_size) - self.fc2 = nn.Linear(hidden_size, 10) - - def forward(self, x): - x = torch.relu(self.fc1(x)) - return self.fc2(x) - - -class MultiOutputModel(nn.Module): - """Model that returns multiple outputs.""" - - def __init__(self): - super().__init__() - self.fc = nn.Linear(32, 64) - - def forward(self, x): - h = self.fc(x) - return torch.relu(h), torch.tanh(h) - - -class TestCompileTorchModule: - """Tests for compile_torch_module_to_iree.""" - - def test_basic_compilation(self): - """Test basic compilation without saving artifacts.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - vmfb_bytes = compile_torch_module_to_iree( - model, - example_args=(example_input,), - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - ) - - # Verify we got bytecode - assert len(vmfb_bytes) > 0 - - def test_compilation_with_save_mlir(self, tmp_path): - """Test compilation with MLIR saving.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - mlir_path = tmp_path / "model.mlir" - - vmfb_bytes = compile_torch_module_to_iree( - model, - example_args=(example_input,), - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - save_mlir_to=mlir_path, - ) - - assert mlir_path.exists() - assert mlir_path.stat().st_size > 0 - assert len(vmfb_bytes) > 0 - - def test_compilation_with_save_vmfb(self, tmp_path): - """Test compilation with VMFB saving.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - vmfb_path = tmp_path / "model.vmfb" - - vmfb_bytes = compile_torch_module_to_iree( - model, - example_args=(example_input,), - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - save_vmfb_to=vmfb_path, - ) - - assert vmfb_path.exists() - assert vmfb_path.stat().st_size > 0 - assert len(vmfb_bytes) > 0 - - def test_compilation_with_kwargs(self): - """Test compilation with keyword arguments.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - vmfb_bytes = compile_torch_module_to_iree( - model, - example_args=tuple(), - example_kwargs={"x": example_input}, - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - ) - - assert len(vmfb_bytes) > 0 - - -class TestLoadTorchModuleAsIree: - """Tests for load_torch_module_as_iree.""" - - def test_basic_loading_and_execution(self): - """Test that loaded module executes and produces correct output shape.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - iree_module = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - ) - - result = iree_module.forward(example_input) - # Verify output structure - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].shape == (2, 10) - # Verify it's actually a tensor with values - assert not torch.isnan(result[0]).any() - - def test_output_matches_torch(self): - """Test that IREE output matches torch output.""" - torch.manual_seed(42) - model = SimpleModel() - model.eval() - example_input = torch.randn(2, 32) - - # Get torch output - with torch.no_grad(): - torch_output = model(example_input) - - # Get IREE output - iree_module = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - ) - iree_output = iree_module.forward(example_input) - - # Compare - torch.testing.assert_close(iree_output[0], torch_output, rtol=1e-4, atol=1e-4) - - def test_multi_output_model(self): - """Test model with multiple outputs.""" - model = MultiOutputModel() - example_input = torch.randn(2, 32) - - iree_module = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - ) - - result = iree_module.forward(example_input) - assert isinstance(result, list) - assert len(result) == 2 - assert result[0].shape == (2, 64) - assert result[1].shape == (2, 64) - - -class TestTypePreservingIreeModule: - """Tests for TypePreservingIreeModule and output_type_mapper.""" - - def test_output_type_mapper_changes_return_structure(self): - """Test that output_type_mapper successfully transforms return type.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - # Without type mapper - returns tuple - iree_module = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - ) - result_tuple = iree_module.forward(example_input) - - # With type mapper - returns unwrapped single tensor - def unwrap(outputs): - return outputs[0] - - iree_module_unwrapped = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - output_type_mapper=unwrap, - ) - result_single = iree_module_unwrapped.forward(example_input) - - # Verify the transformation worked - assert isinstance(result_tuple, list) - assert isinstance(result_single, torch.Tensor) - assert not isinstance(result_single, tuple) - torch.testing.assert_close(result_single, result_tuple[0]) - - def test_reconstruct_sharded_tensor(self): - """Test reconstructing a ShardedTensor-like output.""" - model = MultiOutputModel() - example_input = torch.randn(2, 32) - - # Simulate sharded output reconstruction - def reconstruct_sharded(outputs): - # Treat the two outputs as shards - return SplitPrimitiveTensor(ts=outputs, shard_dim=1) - - iree_module = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - output_type_mapper=reconstruct_sharded, - ) - - result = iree_module.forward(example_input) - assert isinstance(result, SplitPrimitiveTensor) - assert result.shard_dim == 1 - assert len(result.shards) == 2 - - def test_custom_transformation(self): - """Test custom output transformation.""" - model = MultiOutputModel() - example_input = torch.randn(2, 32) - - # Custom transformer: return dict - def to_dict(outputs): - return {"relu": outputs[0], "tanh": outputs[1]} - - iree_module = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - output_type_mapper=to_dict, - ) - - result = iree_module.forward(example_input) - assert isinstance(result, dict) - assert "relu" in result - assert "tanh" in result - assert result["relu"].shape == (2, 64) - assert result["tanh"].shape == (2, 64) - - -class TestOneshotCompileAndRun: - """Tests for oneshot_compile_and_run.""" - - def test_basic_oneshot(self): - """Test basic one-shot execution.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - result = oneshot_compile_and_run( - model, - args=(example_input,), - device="local-sync", - compile_args=( - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ), - ) - - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].shape == (2, 10) - - def test_oneshot_matches_torch(self): - """Test that one-shot execution matches torch.""" - torch.manual_seed(42) - model = SimpleModel() - model.eval() - example_input = torch.randn(2, 32) - - # Torch output - with torch.no_grad(): - torch_output = model(example_input) - - # IREE output - iree_output = oneshot_compile_and_run( - model, - args=(example_input,), - device="local-sync", - compile_args=( - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ), - ) - - torch.testing.assert_close(iree_output[0], torch_output, rtol=1e-4, atol=1e-4) - - def test_oneshot_with_kwargs(self): - """Test one-shot with keyword arguments.""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - result = oneshot_compile_and_run( - model, - args=tuple(), - kwargs={"x": example_input}, - device="local-sync", - compile_args=( - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ), - ) - - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].shape == (2, 10) - - -class TestInferenceModuleProtocol: - """Tests for Protocol-based usage.""" - - def test_protocol_compatibility(self): - """Test that both torch and IREE modules work with generic inference code.""" - from sharktank.utils.inference_module import InferenceModule - - def run_model(model: InferenceModule, input_data): - """Generic function that works with any InferenceModule.""" - return model.forward(input_data) - - torch_model = SimpleModel() - example_input = torch.randn(2, 32) - - # Works with torch module - torch_model.eval() - with torch.no_grad(): - torch_result = run_model(torch_model, example_input) - assert torch_result.shape == (2, 10) - - # Works with IREE module - iree_model = load_torch_module_as_iree( - torch_model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - output_type_mapper=lambda x: x[0], # Unwrap for compatibility - ) - iree_result = run_model(iree_model, example_input) - assert iree_result.shape == (2, 10) - - # Results should be close - torch.testing.assert_close(iree_result, torch_result, rtol=1e-4, atol=1e-4) - - def test_call_vs_forward_equivalence(self): - """Test that calling via __call__ produces same result as forward().""" - model = SimpleModel() - example_input = torch.randn(2, 32) - - iree_model = load_torch_module_as_iree( - model, - example_args=(example_input,), - device="local-sync", - compile_args=[ - "--iree-hal-target-device=local", - "--iree-hal-local-target-device-backends=llvm-cpu", - ], - output_type_mapper=lambda x: x[0], - ) - - result1 = iree_model(example_input) - result2 = iree_model.forward(example_input) - - torch.testing.assert_close(result1, result2) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/sharktank/tests/utils/iree_test.py b/sharktank/tests/utils/iree_test.py index 77d4fe98f26..07e9b6b8d2a 100644 --- a/sharktank/tests/utils/iree_test.py +++ b/sharktank/tests/utils/iree_test.py @@ -9,6 +9,7 @@ import pytest import platform import torch +import torch.nn as nn from parameterized import parameterized from pathlib import Path @@ -16,8 +17,11 @@ from sharktank.types import DefaultPrimitiveTensor from sharktank.utils import chdir from sharktank.utils.iree import ( + compile_torch_module_to_iree, device_array_to_host, get_iree_devices, + load_torch_module_as_iree, + oneshot_iree_run, run_model_with_iree_run_module, tensor_to_device_array, trace_model_with_tracy, @@ -115,3 +119,139 @@ def roundtrip(iree_devices: list[iree.runtime.HalDevice]): assert ops.equal(tensor, tensor_roundtrip) with_iree_device_context(roundtrip, iree_devices) + + +COMPILE_FLAGS = [ + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", +] + + +class SimpleModel(nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +class MultiOutputModel(nn.Module): + def forward(self, x): + return torch.relu(x), torch.tanh(x) + + +class TestCompileTorchModule: + """Tests for compile_torch_module_to_iree.""" + + def test_compilation(self, tmp_path): + """Test compilation with optional artifact saving.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + mlir_path = tmp_path / "model.mlir" + vmfb_path = tmp_path / "model.vmfb" + + vmfb_bytes = compile_torch_module_to_iree( + model, + example_args=(example_input,), + compile_args=COMPILE_FLAGS, + save_mlir_to=mlir_path, + save_vmfb_to=vmfb_path, + ) + + assert len(vmfb_bytes) > 100 + + assert mlir_path.exists() + assert mlir_path.stat().st_size > 0 + + assert vmfb_path.exists() + assert vmfb_path.stat().st_size > 0 + + +class TestLoadTorchModuleAsIree: + """Tests for load_torch_module_as_iree.""" + + def test_basic_loading_and_execution(self): + """Test that loaded module executes and produces correct output shape.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-sync", + compile_args=COMPILE_FLAGS, + ) + + result = iree_module.forward(example_input) + assert isinstance(result, torch.Tensor) + assert result.shape == (2, 32) + assert not torch.isnan(result).any() + + def test_output_matches_torch(self): + """Test that IREE output matches torch output.""" + torch.manual_seed(42) + model = SimpleModel() + model.eval() + example_input = torch.randn(2, 32) + + torch_output = model(example_input) + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-sync", + compile_args=COMPILE_FLAGS, + ) + iree_output = iree_module.forward(example_input) + + torch.testing.assert_close(iree_output, torch_output, rtol=1e-4, atol=1e-4) + + def test_multi_output_model(self): + """Test model with multiple outputs.""" + model = MultiOutputModel() + example_input = torch.randn(2, 32) + + iree_module = load_torch_module_as_iree( + model, + example_args=(example_input,), + device="local-sync", + compile_args=COMPILE_FLAGS, + ) + + result = iree_module.forward(example_input) + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0].shape == (2, 32) + assert result[1].shape == (2, 32) + + +class TestOneshotCompileAndRun: + """Tests for oneshot_iree_run.""" + + def test_basic_oneshot(self): + """Test basic one-shot execution.""" + model = SimpleModel() + example_input = torch.randn(2, 32) + + result = oneshot_iree_run( + model, + args=(example_input,), + device="local-sync", + compile_args=COMPILE_FLAGS, + ) + assert isinstance(result, torch.Tensor) + assert result.shape == (2, 32) + assert not torch.isnan(result).any() + + def test_oneshot_matches_torch(self): + """Test that one-shot execution matches torch.""" + torch.manual_seed(42) + model = SimpleModel() + model.eval() + example_input = torch.randn(2, 32) + + torch_output = model(example_input) + iree_output = oneshot_iree_run( + model, + args=(example_input,), + device="local-sync", + compile_args=COMPILE_FLAGS, + ) + + torch.testing.assert_close(iree_output, torch_output, rtol=1e-4, atol=1e-4) From fccc0b2a7cf15c250f0336990865a2b302b57e81 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 16 Oct 2025 06:58:20 +0000 Subject: [PATCH 4/6] More clean up --- sharktank/sharktank/utils/iree.py | 22 +++++++--------------- sharktank/tests/utils/iree_test.py | 12 ++++++------ 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index f98a8117577..707ec166d4a 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -86,7 +86,7 @@ def oneshot_iree_run( ) -> tuple[torch.Tensor, ...]: """One-shot function: export, compile, load, and run in one call. This is useful for quick testing and benchmarking. For repeated use, - prefer load_torch_module_as_iree() to reuse the compiled module. + prefer adapt_torch_module_to_iree() to reuse the compiled module. Args: module: The torch.nn.Module to run args: Positional arguments to pass to the function @@ -114,7 +114,6 @@ def oneshot_iree_run( compile_args=compile_args, ) - # Get devices iree_devices = get_iree_devices(device=device, device_count=device_count) def run(iree_devices: list[iree.runtime.HalDevice]): @@ -144,7 +143,7 @@ def compile_torch_module_to_iree( save_mlir_to: Optional[Path] = None, save_vmfb_to: Optional[Path] = None, ) -> memoryview: - """Compile a torch module to IREE VMFB bytecode. + """Compile a torch module using IREE to VMFB bytecode. Args: module: The torch.nn.Module to compile @@ -173,7 +172,6 @@ def compile_torch_module_to_iree( if example_kwargs is None: example_kwargs = {} - # Export to MLIR using turbine fxb = FxProgramsBuilder(module) @fxb.export_program( @@ -184,15 +182,12 @@ def _(module, *args, **kwargs): export_output = aot.export(fxb) - # Save MLIR if requested if save_mlir_to is not None: export_output.save_mlir(save_mlir_to) - # Set compiler flags if compile_args is not None: export_output.session.set_flags(*compile_args) - # Compile to VMFB if save_vmfb_to is not None: export_output.compile(save_to=str(save_vmfb_to), target_backends=None) with open(save_vmfb_to, "rb") as f: @@ -201,7 +196,7 @@ def _(module, *args, **kwargs): return export_output.compile(save_to=None, target_backends=None).map_memory() -def load_torch_module_as_iree( +def adapt_torch_module_to_iree( module: torch.nn.Module, example_args: tuple[Any, ...] = tuple(), example_kwargs: dict[str, Any] = None, @@ -213,10 +208,10 @@ def load_torch_module_as_iree( save_mlir_to: Optional[Path] = None, save_vmfb_to: Optional[Path] = None, ) -> "TorchLikeIreeModule": - """Compile a torch module to IREE and load it as a TorchLikeIreeModule. + """Compile a torch module to IREE and adapt it as a TorchLikeIreeModule. This is a high-level convenience function that combines export, compilation, - and loading into a single call. + and adaption into a single call. Args: module: The torch.nn.Module to compile @@ -237,14 +232,13 @@ def load_torch_module_as_iree( Example: >>> model = MyTorchModel() >>> example_input = torch.randn(1, 3, 224, 224) - >>> iree_model = load_torch_module_as_iree( + >>> iree_model = adapt_torch_module_to_iree( ... model, ... example_args=(example_input,), ... device="local-task" ... ) >>> output = iree_model.forward(example_input) # Single tensor, not list/tuple """ - # Compile the module vmfb_bytes = compile_torch_module_to_iree( module=module, example_args=example_args, @@ -255,10 +249,8 @@ def load_torch_module_as_iree( save_vmfb_to=save_vmfb_to, ) - # Get devices iree_devices = get_iree_devices(device=device, device_count=device_count) - # Load the module def load_fn(devices: list[iree.runtime.HalDevice]) -> TorchLikeIreeModule: vm_module, vm_context, vm_instance = load_iree_module( module_buff=vmfb_bytes, @@ -288,7 +280,7 @@ class InferenceModule(Protocol): >>> run_inference(torch_model, x) >>> >>> # Also works with IREE modules - >>> iree_model = load_torch_module_as_iree(torch_model, ...) + >>> iree_model = adapt_torch_module_to_iree(torch_model, ...) >>> run_inference(iree_model, x) """ diff --git a/sharktank/tests/utils/iree_test.py b/sharktank/tests/utils/iree_test.py index 07e9b6b8d2a..1e6b688e9a2 100644 --- a/sharktank/tests/utils/iree_test.py +++ b/sharktank/tests/utils/iree_test.py @@ -17,10 +17,10 @@ from sharktank.types import DefaultPrimitiveTensor from sharktank.utils import chdir from sharktank.utils.iree import ( + adapt_torch_module_to_iree, compile_torch_module_to_iree, device_array_to_host, get_iree_devices, - load_torch_module_as_iree, oneshot_iree_run, run_model_with_iree_run_module, tensor_to_device_array, @@ -164,15 +164,15 @@ def test_compilation(self, tmp_path): assert vmfb_path.stat().st_size > 0 -class TestLoadTorchModuleAsIree: - """Tests for load_torch_module_as_iree.""" +class TestAdaptTorchModuleToIree: + """Tests for adapt_torch_module_to_iree.""" def test_basic_loading_and_execution(self): """Test that loaded module executes and produces correct output shape.""" model = SimpleModel() example_input = torch.randn(2, 32) - iree_module = load_torch_module_as_iree( + iree_module = adapt_torch_module_to_iree( model, example_args=(example_input,), device="local-sync", @@ -192,7 +192,7 @@ def test_output_matches_torch(self): example_input = torch.randn(2, 32) torch_output = model(example_input) - iree_module = load_torch_module_as_iree( + iree_module = adapt_torch_module_to_iree( model, example_args=(example_input,), device="local-sync", @@ -207,7 +207,7 @@ def test_multi_output_model(self): model = MultiOutputModel() example_input = torch.randn(2, 32) - iree_module = load_torch_module_as_iree( + iree_module = adapt_torch_module_to_iree( model, example_args=(example_input,), device="local-sync", From 52ccb5aca828fca38f7fc319c69df2fceb213db5 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 27 Oct 2025 13:55:33 -0700 Subject: [PATCH 5/6] Address PR comments --- sharktank/sharktank/types/module.py | 37 ++++++++++++++++++++++ sharktank/sharktank/utils/iree.py | 49 +++++------------------------ sharktank/tests/utils/iree_test.py | 31 +++++++++--------- 3 files changed, 58 insertions(+), 59 deletions(-) create mode 100644 sharktank/sharktank/types/module.py diff --git a/sharktank/sharktank/types/module.py b/sharktank/sharktank/types/module.py new file mode 100644 index 00000000000..914328d8c25 --- /dev/null +++ b/sharktank/sharktank/types/module.py @@ -0,0 +1,37 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class InferenceModule(Protocol): + """Protocol for inference modules (both torch and IREE). + + This defines a common interface that both torch.nn.Module and + TorchLikeIreeModule can satisfy, allowing them to be used + interchangeably in inference code. + + Example: + >>> def run_inference(model: InferenceModule, inputs): + ... return model(inputs) + >>> + >>> # Works with torch modules + >>> torch_model = MyTorchModel() + >>> run_inference(torch_model, x) + >>> + >>> # Also works with IREE modules + >>> iree_model = adapt_torch_module_to_iree(torch_model, ...) + >>> run_inference(iree_model, x) + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute the module's forward pass.""" + ... + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Execute the module's forward pass explicitly.""" + ... \ No newline at end of file diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index 707ec166d4a..4efa4cbdaf0 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -14,8 +14,6 @@ overload, TYPE_CHECKING, Sequence, - Protocol, - runtime_checkable, ) import os import sys @@ -251,46 +249,13 @@ def adapt_torch_module_to_iree( iree_devices = get_iree_devices(device=device, device_count=device_count) - def load_fn(devices: list[iree.runtime.HalDevice]) -> TorchLikeIreeModule: - vm_module, vm_context, vm_instance = load_iree_module( - module_buff=vmfb_bytes, - devices=devices, - parameters_path=parameters_path, - tensor_parallel_size=len(devices), - ) - return TorchLikeIreeModule(vm_module, vm_context, devices) - - return with_iree_device_context(load_fn, iree_devices) - - -@runtime_checkable -class InferenceModule(Protocol): - """Protocol for inference modules (both torch and IREE). - - This defines a common interface that both torch.nn.Module and - TorchLikeIreeModule can satisfy, allowing them to be used - interchangeably in inference code. - - Example: - >>> def run_inference(model: InferenceModule, inputs): - ... return model(inputs) - >>> - >>> # Works with torch modules - >>> torch_model = MyTorchModel() - >>> run_inference(torch_model, x) - >>> - >>> # Also works with IREE modules - >>> iree_model = adapt_torch_module_to_iree(torch_model, ...) - >>> run_inference(iree_model, x) - """ - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - """Execute the module's forward pass.""" - ... - - def forward(self, *args: Any, **kwargs: Any) -> Any: - """Execute the module's forward pass explicitly.""" - ... + vm_module, vm_context, vm_instance = load_iree_module( + module_buff=vmfb_bytes, + devices=iree_devices, + parameters_path=parameters_path, + tensor_parallel_size=len(iree_devices), + ) + return TorchLikeIreeModule(vm_module, vm_context, iree_devices) class TorchLikeIreeModule: diff --git a/sharktank/tests/utils/iree_test.py b/sharktank/tests/utils/iree_test.py index 1e6b688e9a2..6c2b93900ee 100644 --- a/sharktank/tests/utils/iree_test.py +++ b/sharktank/tests/utils/iree_test.py @@ -129,12 +129,12 @@ def roundtrip(iree_devices: list[iree.runtime.HalDevice]): class SimpleModel(nn.Module): def forward(self, x): - return torch.relu(x) + 1.0 + return x + 1 class MultiOutputModel(nn.Module): def forward(self, x): - return torch.relu(x), torch.tanh(x) + return x + 1, x * 2 class TestCompileTorchModule: @@ -167,10 +167,10 @@ def test_compilation(self, tmp_path): class TestAdaptTorchModuleToIree: """Tests for adapt_torch_module_to_iree.""" - def test_basic_loading_and_execution(self): + def test_basic_loading_and_execution(self, deterministic_random_seed): """Test that loaded module executes and produces correct output shape.""" model = SimpleModel() - example_input = torch.randn(2, 32) + example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64) iree_module = adapt_torch_module_to_iree( model, @@ -182,14 +182,12 @@ def test_basic_loading_and_execution(self): result = iree_module.forward(example_input) assert isinstance(result, torch.Tensor) assert result.shape == (2, 32) - assert not torch.isnan(result).any() - def test_output_matches_torch(self): - """Test that IREE output matches torch output.""" - torch.manual_seed(42) + def test_output_matches_torch(self, deterministic_random_seed): + """Test that IREE output matches torch output""" model = SimpleModel() model.eval() - example_input = torch.randn(2, 32) + example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64) torch_output = model(example_input) iree_module = adapt_torch_module_to_iree( @@ -200,7 +198,7 @@ def test_output_matches_torch(self): ) iree_output = iree_module.forward(example_input) - torch.testing.assert_close(iree_output, torch_output, rtol=1e-4, atol=1e-4) + assert torch.equal(iree_output, torch_output) def test_multi_output_model(self): """Test model with multiple outputs.""" @@ -224,10 +222,10 @@ def test_multi_output_model(self): class TestOneshotCompileAndRun: """Tests for oneshot_iree_run.""" - def test_basic_oneshot(self): + def test_basic_oneshot(self, deterministic_random_seed): """Test basic one-shot execution.""" model = SimpleModel() - example_input = torch.randn(2, 32) + example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64) result = oneshot_iree_run( model, @@ -237,14 +235,12 @@ def test_basic_oneshot(self): ) assert isinstance(result, torch.Tensor) assert result.shape == (2, 32) - assert not torch.isnan(result).any() - def test_oneshot_matches_torch(self): + def test_oneshot_matches_torch(self, deterministic_random_seed): """Test that one-shot execution matches torch.""" - torch.manual_seed(42) model = SimpleModel() model.eval() - example_input = torch.randn(2, 32) + example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64) torch_output = model(example_input) iree_output = oneshot_iree_run( @@ -254,4 +250,5 @@ def test_oneshot_matches_torch(self): compile_args=COMPILE_FLAGS, ) - torch.testing.assert_close(iree_output, torch_output, rtol=1e-4, atol=1e-4) + # Use exact comparison for integer arithmetic + assert torch.equal(iree_output, torch_output) From e5017183e959a954a3b620e26afc504407c99af7 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 27 Oct 2025 14:31:42 -0700 Subject: [PATCH 6/6] Fix end of file --- sharktank/sharktank/types/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/types/module.py b/sharktank/sharktank/types/module.py index 914328d8c25..5c80e2ff864 100644 --- a/sharktank/sharktank/types/module.py +++ b/sharktank/sharktank/types/module.py @@ -34,4 +34,4 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: def forward(self, *args: Any, **kwargs: Any) -> Any: """Execute the module's forward pass explicitly.""" - ... \ No newline at end of file + ...