diff --git a/sharktank/sharktank/types/module.py b/sharktank/sharktank/types/module.py
new file mode 100644
index 00000000000..5c80e2ff864
--- /dev/null
+++ b/sharktank/sharktank/types/module.py
@@ -0,0 +1,37 @@
+# Copyright 2025 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class InferenceModule(Protocol):
+    """Protocol for inference modules (both torch and IREE).
+
+    This defines a common interface that both torch.nn.Module and
+    TorchLikeIreeModule can satisfy, allowing them to be used
+    interchangeably in inference code.
+
+    Example:
+        >>> def run_inference(model: InferenceModule, inputs):
+        ...     return model(inputs)
+        >>>
+        >>> # Works with torch modules
+        >>> torch_model = MyTorchModel()
+        >>> run_inference(torch_model, x)
+        >>>
+        >>> # Also works with IREE modules
+        >>> iree_model = adapt_torch_module_to_iree(torch_model, ...)
+        >>> run_inference(iree_model, x)
+    """
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        """Execute the module's forward pass."""
+        ...
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """Execute the module's forward pass explicitly."""
+        ...
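An aside on the protocol above: because InferenceModule is @runtime_checkable, isinstance() checks pass structurally for any object that exposes __call__ and forward. A minimal sketch, assuming sharktank is importable; TinyModel is hypothetical and not part of this change:

# Sketch only (hypothetical TinyModel): runtime_checkable makes isinstance()
# verify that __call__ and forward exist; it does not validate signatures.
import torch.nn as nn

from sharktank.types.module import InferenceModule


class TinyModel(nn.Module):
    def forward(self, x):
        return x + 1


assert isinstance(TinyModel(), InferenceModule)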
) + """ + vmfb_bytes = compile_torch_module_to_iree( + module=module, + example_args=args, + example_kwargs=kwargs, + function_name=function, + compile_args=compile_args, ) - if compile_args is not None: - export_output.session.set_flags(*compile_args) - memory_view: memoryview = export_output.compile( - save_to=None, target_backends=None - ).map_memory() + iree_devices = get_iree_devices(device=device, device_count=device_count) def run(iree_devices: list[iree.runtime.HalDevice]): vm_module, vm_context, vm_instance = load_iree_module( - module_buff=memory_view, devices=iree_devices + module_buff=vmfb_bytes, devices=iree_devices ) torch_like_iree_module = TorchLikeIreeModule( vm_module, vm_context, iree_devices ) results = getattr(torch_like_iree_module, function)(*args, **kwargs) # Clone to avoid leaking IREE-backed torch tensors. - results = tuple(t.clone() for t in results) - return results + # Results can be a single tensor or tuple of tensors + if isinstance(results, torch.Tensor): + return results.clone() + else: + return tuple(t.clone() for t in results) return with_iree_device_context(run, iree_devices) +def compile_torch_module_to_iree( + module: torch.nn.Module, + example_args: tuple[Any, ...] = tuple(), + example_kwargs: dict[str, Any] = None, + function_name: str = "forward", + compile_args: Sequence[str] = None, + save_mlir_to: Optional[Path] = None, + save_vmfb_to: Optional[Path] = None, +) -> memoryview: + """Compile a torch module using IREE to VMFB bytecode. + + Args: + module: The torch.nn.Module to compile + example_args: Example positional arguments for tracing + example_kwargs: Example keyword arguments for tracing + function_name: Name of the function to export (default: "forward") + compile_args: Additional IREE compiler flags + save_mlir_to: Optional path to save the exported MLIR + save_vmfb_to: Optional path to save the compiled VMFB + + Returns: + A memoryview of the compiled VMFB bytecode + + Example: + >>> model = MyTorchModel() + >>> example_input = torch.randn(1, 3, 224, 224) + >>> vmfb_bytes = compile_torch_module_to_iree( + ... model, + ... example_args=(example_input,), + ... compile_args=["--iree-hal-target-device=local-task"] + ... ) + """ + from iree.turbine import aot + from iree.turbine.aot import FxProgramsBuilder + + if example_kwargs is None: + example_kwargs = {} + + fxb = FxProgramsBuilder(module) + + @fxb.export_program( + name=function_name, args=example_args, kwargs=example_kwargs, strict=False + ) + def _(module, *args, **kwargs): + return getattr(module, function_name)(*args, **kwargs) + + export_output = aot.export(fxb) + + if save_mlir_to is not None: + export_output.save_mlir(save_mlir_to) + + if compile_args is not None: + export_output.session.set_flags(*compile_args) + + if save_vmfb_to is not None: + export_output.compile(save_to=str(save_vmfb_to), target_backends=None) + with open(save_vmfb_to, "rb") as f: + return memoryview(f.read()) + else: + return export_output.compile(save_to=None, target_backends=None).map_memory() + + +def adapt_torch_module_to_iree( + module: torch.nn.Module, + example_args: tuple[Any, ...] = tuple(), + example_kwargs: dict[str, Any] = None, + function_name: str = "forward", + device: str | list[str] = "local-task", + device_count: int | None = None, + compile_args: Sequence[str] = None, + parameters_path: Optional[str] = None, + save_mlir_to: Optional[Path] = None, + save_vmfb_to: Optional[Path] = None, +) -> "TorchLikeIreeModule": + """Compile a torch module to IREE and adapt it as a TorchLikeIreeModule. 
+def adapt_torch_module_to_iree(
+    module: torch.nn.Module,
+    example_args: tuple[Any, ...] = tuple(),
+    example_kwargs: Optional[dict[str, Any]] = None,
+    function_name: str = "forward",
+    device: str | list[str] = "local-task",
+    device_count: int | None = None,
+    compile_args: Optional[Sequence[str]] = None,
+    parameters_path: Optional[str] = None,
+    save_mlir_to: Optional[Path] = None,
+    save_vmfb_to: Optional[Path] = None,
+) -> "TorchLikeIreeModule":
+    """Compile a torch module to IREE and adapt it as a TorchLikeIreeModule.
+
+    This is a high-level convenience function that combines export, compilation,
+    and adaptation into a single call.
+
+    Args:
+        module: The torch.nn.Module to compile.
+        example_args: Example positional arguments for tracing.
+        example_kwargs: Example keyword arguments for tracing.
+        function_name: Name of the function to export (default: "forward").
+        device: IREE device(s) to load on (e.g., "local-task", "hip://0").
+        device_count: Number of devices to use (for multi-device scenarios).
+        compile_args: Additional IREE compiler flags.
+        parameters_path: Optional path to external parameters (IRPA file).
+        save_mlir_to: Optional path to save the exported MLIR.
+        save_vmfb_to: Optional path to save the compiled VMFB.
+
+    Returns:
+        A TorchLikeIreeModule that can be called like the original torch module.
+        Single outputs are unwrapped; multiple outputs are returned as tuples.
+
+    Example:
+        >>> model = MyTorchModel()
+        >>> example_input = torch.randn(1, 3, 224, 224)
+        >>> iree_model = adapt_torch_module_to_iree(
+        ...     model,
+        ...     example_args=(example_input,),
+        ...     device="local-task"
+        ... )
+        >>> output = iree_model.forward(example_input)  # Single tensor, not a list/tuple
+    """
+    vmfb_bytes = compile_torch_module_to_iree(
+        module=module,
+        example_args=example_args,
+        example_kwargs=example_kwargs,
+        function_name=function_name,
+        compile_args=compile_args,
+        save_mlir_to=save_mlir_to,
+        save_vmfb_to=save_vmfb_to,
+    )
+
+    iree_devices = get_iree_devices(device=device, device_count=device_count)
+
+    vm_module, vm_context, vm_instance = load_iree_module(
+        module_buff=vmfb_bytes,
+        devices=iree_devices,
+        parameters_path=parameters_path,
+        tensor_parallel_size=len(iree_devices),
+    )
+    return TorchLikeIreeModule(vm_module, vm_context, iree_devices)
+
+
 class TorchLikeIreeModule:
     """Makes an IREE module look like a torch module. Where it can be called with
     Sharktank and Torch tensors.
     This handles marshaling of torch tensor and sharktank.type.InferenceTensor
     arguments.
     Unfortunately, we can't marshall the output back to the correct tensor types as
-    some of the information is lost. E.g. the sharded tensor types. We return a flat
-    list of torch tensors.
+    some of the information is lost. E.g. the sharded tensor types.
+
+    Returns:
+        - Single output: Returns the tensor directly
+        - Multiple outputs: Returns a tuple of tensors
     """
 
     def __init__(
@@ -127,10 +281,14 @@ def __init__(
         self.vm_context = vm_context
         self.devices = devices
 
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        """Execute the forward method (torch.nn.Module compatibility)."""
+        return self.forward(*args, **kwargs)
+
     def __getattr__(self, name: str) -> Any:
         def f(
             *args: tuple[Any, ...], **kwargs: dict[str, Any]
-        ) -> tuple[torch.Tensor, ...]:
+        ) -> torch.Tensor | tuple[torch.Tensor, ...]:
             flat_args = flatten_for_iree_signature(
                 (
                     args,
@@ -152,7 +310,11 @@ def f(
             for arg, iree_arg in zip(flat_args, iree_args_post_call):
                 arg[...] = iree_arg
 
-            return res
+            # Match torch.nn.Module behavior: single output unwrapped, multiple as a tuple.
+            if len(res) == 1:
+                return res[0]
+            else:
+                return tuple(res)
 
         return f
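For orientation, the pieces above compose as follows. A rough sketch of the manual path that adapt_torch_module_to_iree() wraps; model, example_input, the device string, and device_count are illustrative assumptions, while the helper signatures come from this file:

# Sketch only: compile once, then load and wrap by hand. This mirrors the body
# of adapt_torch_module_to_iree() above. `model` and `example_input` are assumed
# to be defined by the caller; the device string is illustrative.
vmfb = compile_torch_module_to_iree(
    model,
    example_args=(example_input,),
    compile_args=[
        "--iree-hal-target-device=local",
        "--iree-hal-local-target-device-backends=llvm-cpu",
    ],
)
devices = get_iree_devices(device="local-sync", device_count=1)
vm_module, vm_context, vm_instance = load_iree_module(
    module_buff=vmfb, devices=devices
)
iree_model = TorchLikeIreeModule(vm_module, vm_context, devices)
output = iree_model(example_input)  # __call__ dispatches to forward()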
diff --git a/sharktank/tests/utils/iree_test.py b/sharktank/tests/utils/iree_test.py
index 77d4fe98f26..6c2b93900ee 100644
--- a/sharktank/tests/utils/iree_test.py
+++ b/sharktank/tests/utils/iree_test.py
@@ -9,6 +9,7 @@
 import pytest
 import platform
 import torch
+import torch.nn as nn
 from parameterized import parameterized
 from pathlib import Path
 
@@ -16,8 +17,11 @@
 from sharktank.types import DefaultPrimitiveTensor
 from sharktank.utils import chdir
 from sharktank.utils.iree import (
+    adapt_torch_module_to_iree,
+    compile_torch_module_to_iree,
     device_array_to_host,
     get_iree_devices,
+    oneshot_iree_run,
     run_model_with_iree_run_module,
     tensor_to_device_array,
     trace_model_with_tracy,
@@ -115,3 +119,136 @@ def roundtrip(iree_devices: list[iree.runtime.HalDevice]):
         assert ops.equal(tensor, tensor_roundtrip)
 
     with_iree_device_context(roundtrip, iree_devices)
+
+
+COMPILE_FLAGS = [
+    "--iree-hal-target-device=local",
+    "--iree-hal-local-target-device-backends=llvm-cpu",
+]
+
+
+class SimpleModel(nn.Module):
+    def forward(self, x):
+        return x + 1
+
+
+class MultiOutputModel(nn.Module):
+    def forward(self, x):
+        return x + 1, x * 2
+
+
+class TestCompileTorchModule:
+    """Tests for compile_torch_module_to_iree."""
+
+    def test_compilation(self, tmp_path):
+        """Test compilation with optional artifact saving."""
+        model = SimpleModel()
+        example_input = torch.randn(2, 32)
+        mlir_path = tmp_path / "model.mlir"
+        vmfb_path = tmp_path / "model.vmfb"
+
+        vmfb_bytes = compile_torch_module_to_iree(
+            model,
+            example_args=(example_input,),
+            compile_args=COMPILE_FLAGS,
+            save_mlir_to=mlir_path,
+            save_vmfb_to=vmfb_path,
+        )
+
+        assert len(vmfb_bytes) > 100
+
+        assert mlir_path.exists()
+        assert mlir_path.stat().st_size > 0
+
+        assert vmfb_path.exists()
+        assert vmfb_path.stat().st_size > 0
+
+
+class TestAdaptTorchModuleToIree:
+    """Tests for adapt_torch_module_to_iree."""
+
+    def test_basic_loading_and_execution(self, deterministic_random_seed):
+        """Test that the loaded module executes and produces the correct output shape."""
+        model = SimpleModel()
+        example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
+
+        iree_module = adapt_torch_module_to_iree(
+            model,
+            example_args=(example_input,),
+            device="local-sync",
+            compile_args=COMPILE_FLAGS,
+        )
+
+        result = iree_module.forward(example_input)
+        assert isinstance(result, torch.Tensor)
+        assert result.shape == (2, 32)
+
+    def test_output_matches_torch(self, deterministic_random_seed):
+        """Test that IREE output matches torch output."""
+        model = SimpleModel()
+        model.eval()
+        example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
+
+        torch_output = model(example_input)
+        iree_module = adapt_torch_module_to_iree(
+            model,
+            example_args=(example_input,),
+            device="local-sync",
+            compile_args=COMPILE_FLAGS,
+        )
+        iree_output = iree_module.forward(example_input)
+
+        assert torch.equal(iree_output, torch_output)
+
+    def test_multi_output_model(self):
+        """Test a model with multiple outputs."""
+        model = MultiOutputModel()
+        example_input = torch.randn(2, 32)
+
+        iree_module = adapt_torch_module_to_iree(
+            model,
+            example_args=(example_input,),
+            device="local-sync",
+            compile_args=COMPILE_FLAGS,
+        )
+
+        result = iree_module.forward(example_input)
+        assert isinstance(result, tuple)
+        assert len(result) == 2
+        assert result[0].shape == (2, 32)
+        assert result[1].shape == (2, 32)
+
+
+class TestOneshotIreeRun:
+    """Tests for oneshot_iree_run."""
+
+    def test_basic_oneshot(self, deterministic_random_seed):
+        """Test basic one-shot execution."""
+        model = SimpleModel()
+        example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
+
+        result = oneshot_iree_run(
+            model,
+            args=(example_input,),
+            device="local-sync",
+            compile_args=COMPILE_FLAGS,
+        )
+        assert isinstance(result, torch.Tensor)
+        assert result.shape == (2, 32)
+
+    def test_oneshot_matches_torch(self, deterministic_random_seed):
+        """Test that one-shot execution matches torch."""
+        model = SimpleModel()
+        model.eval()
+        example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
+
+        torch_output = model(example_input)
+        iree_output = oneshot_iree_run(
+            model,
+            args=(example_input,),
+            device="local-sync",
+            compile_args=COMPILE_FLAGS,
+        )
+
+        # Use exact comparison for integer arithmetic.
+        assert torch.equal(iree_output, torch_output)
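To close the loop with the new protocol, a hypothetical follow-on test (not part of this diff) showing that one typed helper can drive both backends; it reuses SimpleModel and COMPILE_FLAGS from the tests above:

# Hypothetical test sketch, not in this diff: InferenceModule lets a single
# helper accept torch modules and IREE-adapted modules interchangeably.
from sharktank.types.module import InferenceModule


def _run(model: InferenceModule, x: torch.Tensor):
    return model(x)


def test_protocol_accepts_both_backends():
    model = SimpleModel()
    example_input = torch.randint(0, 100, (2, 32), dtype=torch.int64)
    iree_module = adapt_torch_module_to_iree(
        model,
        example_args=(example_input,),
        device="local-sync",
        compile_args=COMPILE_FLAGS,
    )
    assert isinstance(model, InferenceModule)
    assert isinstance(iree_module, InferenceModule)
    assert torch.equal(_run(iree_module, example_input), _run(model, example_input))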