37 changes: 37 additions & 0 deletions sharktank/sharktank/types/module.py
@@ -0,0 +1,37 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class InferenceModule(Protocol):
"""Protocol for inference modules (both torch and IREE).

This defines a common interface that both torch.nn.Module and
TorchLikeIreeModule can satisfy, allowing them to be used
interchangeably in inference code.

Example:
>>> def run_inference(model: InferenceModule, inputs):
... return model(inputs)
>>>
>>> # Works with torch modules
>>> torch_model = MyTorchModel()
>>> run_inference(torch_model, x)
>>>
>>> # Also works with IREE modules
>>> iree_model = adapt_torch_module_to_iree(torch_model, ...)
>>> run_inference(iree_model, x)
"""

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Execute the module's forward pass."""
...

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Execute the module's forward pass explicitly."""
...
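
Because InferenceModule is decorated with @runtime_checkable, callers can verify a backend structurally before use. A minimal sketch of that check, assuming a hypothetical TinyModel and run_inference helper that are not part of this change:

import torch
from sharktank.types.module import InferenceModule

def run_inference(model: InferenceModule, inputs: torch.Tensor):
    # runtime_checkable only verifies that __call__ and forward exist, not their signatures.
    assert isinstance(model, InferenceModule)
    return model(inputs)

class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

run_inference(TinyModel(), torch.ones(4))  # a torch module satisfies the protocol structurally
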
212 changes: 187 additions & 25 deletions sharktank/sharktank/utils/iree.py
@@ -4,7 +4,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, Callable, List, Tuple, Optional, Union, overload, TYPE_CHECKING
from typing import (
Any,
Callable,
List,
Tuple,
Optional,
Union,
overload,
TYPE_CHECKING,
Sequence,
)
import os
import sys
import json
@@ -72,49 +82,193 @@ def oneshot_iree_run(
device_count: int | None = None,
compile_args: tuple[str, ...] | None = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""All in one: export, compile and run."""
from iree.turbine import aot
from iree.turbine.aot import FxProgramsBuilder

fxb = FxProgramsBuilder(module)

@fxb.export_program(name=function, args=args, kwargs=kwargs, strict=False)
def _(module, *args, **kwargs):
return getattr(module, function)(*args, **kwargs)

export_output = aot.export(
fxb,
"""One-shot function: export, compile, load, and run in one call.
This is useful for quick testing and benchmarking. For repeated use,
prefer adapt_torch_module_to_iree() to reuse the compiled module.
Args:
module: The torch.nn.Module to run
args: Positional arguments to pass to the function
kwargs: Keyword arguments to pass to the function
function: Name of the function to call (default: "forward")
device: IREE device(s) to run on
device_count: Number of devices to use
compile_args: Additional IREE compiler flags
Returns:
Tensor, or tuple of output tensors
Example:
>>> model = MyTorchModel()
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> outputs = oneshot_compile_and_run(
... model,
... args=(input_tensor,),
... device="local-task"
... )
"""
vmfb_bytes = compile_torch_module_to_iree(
module=module,
example_args=args,
example_kwargs=kwargs,
function_name=function,
compile_args=compile_args,
)
if compile_args is not None:
export_output.session.set_flags(*compile_args)
memory_view: memoryview = export_output.compile(
save_to=None, target_backends=None
).map_memory()

iree_devices = get_iree_devices(device=device, device_count=device_count)

def run(iree_devices: list[iree.runtime.HalDevice]):
vm_module, vm_context, vm_instance = load_iree_module(
module_buff=memory_view, devices=iree_devices
module_buff=vmfb_bytes, devices=iree_devices
)
torch_like_iree_module = TorchLikeIreeModule(
vm_module, vm_context, iree_devices
)
results = getattr(torch_like_iree_module, function)(*args, **kwargs)
# Clone to avoid leaking IREE-backed torch tensors.
results = tuple(t.clone() for t in results)
return results
# Results can be a single tensor or tuple of tensors
if isinstance(results, torch.Tensor):
return results.clone()
else:
return tuple(t.clone() for t in results)

return with_iree_device_context(run, iree_devices)
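
Since the docstring frames oneshot_iree_run as a quick-testing helper, one natural use, sketched here with MyTorchModel as a placeholder, is checking the IREE result against eager execution:

model = MyTorchModel().eval()  # placeholder torch module
x = torch.randn(1, 3, 224, 224)

expected = model(x)
actual = oneshot_iree_run(model, args=(x,), device="local-task")
if isinstance(actual, tuple):  # single outputs come back unwrapped
    actual = actual[0]
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
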


def compile_torch_module_to_iree(
module: torch.nn.Module,
example_args: tuple[Any, ...] = tuple(),
example_kwargs: Optional[dict[str, Any]] = None,
function_name: str = "forward",
compile_args: Optional[Sequence[str]] = None,
save_mlir_to: Optional[Path] = None,
save_vmfb_to: Optional[Path] = None,
) -> memoryview:
"""Compile a torch module using IREE to VMFB bytecode.

Args:
module: The torch.nn.Module to compile
example_args: Example positional arguments for tracing
example_kwargs: Example keyword arguments for tracing
function_name: Name of the function to export (default: "forward")
compile_args: Additional IREE compiler flags
save_mlir_to: Optional path to save the exported MLIR
save_vmfb_to: Optional path to save the compiled VMFB

Returns:
A memoryview of the compiled VMFB bytecode

Example:
>>> model = MyTorchModel()
>>> example_input = torch.randn(1, 3, 224, 224)
>>> vmfb_bytes = compile_torch_module_to_iree(
... model,
... example_args=(example_input,),
... compile_args=["--iree-hal-target-device=local-task"]
... )
"""
from iree.turbine import aot
from iree.turbine.aot import FxProgramsBuilder

if example_kwargs is None:
example_kwargs = {}

fxb = FxProgramsBuilder(module)

@fxb.export_program(
name=function_name, args=example_args, kwargs=example_kwargs, strict=False
)
def _(module, *args, **kwargs):
return getattr(module, function_name)(*args, **kwargs)

export_output = aot.export(fxb)

if save_mlir_to is not None:
export_output.save_mlir(save_mlir_to)

if compile_args is not None:
export_output.session.set_flags(*compile_args)

if save_vmfb_to is not None:
export_output.compile(save_to=str(save_vmfb_to), target_backends=None)
with open(save_vmfb_to, "rb") as f:
return memoryview(f.read())
else:
return export_output.compile(save_to=None, target_backends=None).map_memory()
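
The save_mlir_to / save_vmfb_to hooks let the compiled artifacts be cached and reloaded without recompiling. A rough sketch, where the paths and model are placeholders and the loaders are the helpers defined in this file:

from pathlib import Path

vmfb_path = Path("/tmp/model.vmfb")  # illustrative cache location
compile_torch_module_to_iree(
    MyTorchModel().eval(),  # placeholder torch module
    example_args=(torch.randn(1, 3, 224, 224),),
    save_mlir_to=Path("/tmp/model.mlir"),  # keep the exported MLIR for inspection
    save_vmfb_to=vmfb_path,
)

# Later, or in another process: load the cached VMFB instead of recompiling.
devices = get_iree_devices(device="local-task", device_count=1)
vm_module, vm_context, vm_instance = load_iree_module(
    module_buff=memoryview(vmfb_path.read_bytes()), devices=devices
)
iree_model = TorchLikeIreeModule(vm_module, vm_context, devices)
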


def adapt_torch_module_to_iree(
module: torch.nn.Module,
example_args: tuple[Any, ...] = tuple(),
example_kwargs: Optional[dict[str, Any]] = None,
function_name: str = "forward",
device: str | list[str] = "local-task",
device_count: int | None = None,
compile_args: Optional[Sequence[str]] = None,
parameters_path: Optional[str] = None,
save_mlir_to: Optional[Path] = None,
save_vmfb_to: Optional[Path] = None,
) -> "TorchLikeIreeModule":
"""Compile a torch module to IREE and adapt it as a TorchLikeIreeModule.

This is a high-level convenience function that combines export, compilation,
and adaptation into a single call.

Args:
module: The torch.nn.Module to compile
example_args: Example positional arguments for tracing
example_kwargs: Example keyword arguments for tracing
function_name: Name of the function to export (default: "forward")
device: IREE device(s) to load on (e.g., "local-task", "hip://0")
device_count: Number of devices to use (for multi-device scenarios)
compile_args: Additional IREE compiler flags
parameters_path: Optional path to external parameters (IRPA file)
save_mlir_to: Optional path to save the exported MLIR
save_vmfb_to: Optional path to save the compiled VMFB

Returns:
A TorchLikeIreeModule that can be called like the original torch module.
Single outputs are unwrapped, multiple outputs are returned as tuples.

Example:
>>> model = MyTorchModel()
>>> example_input = torch.randn(1, 3, 224, 224)
>>> iree_model = adapt_torch_module_to_iree(
... model,
... example_args=(example_input,),
... device="local-task"
... )
>>> output = iree_model.forward(example_input) # Single tensor, not list/tuple
"""
vmfb_bytes = compile_torch_module_to_iree(
module=module,
example_args=example_args,
example_kwargs=example_kwargs,
function_name=function_name,
compile_args=compile_args,
save_mlir_to=save_mlir_to,
save_vmfb_to=save_vmfb_to,
)

iree_devices = get_iree_devices(device=device, device_count=device_count)

vm_module, vm_context, vm_instance = load_iree_module(
module_buff=vmfb_bytes,
devices=iree_devices,
parameters_path=parameters_path,
tensor_parallel_size=len(iree_devices),
)
return TorchLikeIreeModule(vm_module, vm_context, iree_devices)
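
The adapted module can sit behind the InferenceModule protocol from sharktank.types.module, so downstream code does not need to know which backend it is driving. A small sketch, with MyTorchModel as a placeholder:

from sharktank.types.module import InferenceModule

def run_once(model: InferenceModule, x: torch.Tensor):
    # Identical call site for eager torch and compiled IREE modules.
    return model(x)

torch_model = MyTorchModel().eval()  # placeholder torch module
x = torch.randn(1, 3, 224, 224)
iree_model = adapt_torch_module_to_iree(
    torch_model, example_args=(x,), device="local-task"
)

eager_out = run_once(torch_model, x)
iree_out = run_once(iree_model, x)
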


class TorchLikeIreeModule:
"""Makes an IREE module look like a torch module. Where it can be called with
Sharktank and Torch tensors.

This handles marshaling of torch tensor and sharktank.types.InferenceTensor arguments.
Unfortunately, we can't marshal the output back to the correct tensor types as
some of the information is lost. E.g. the sharded tensor types. We return a flat
list of torch tensors.
some of the information is lost. E.g. the sharded tensor types.

Returns:
- Single output: Returns the tensor directly
- Multiple outputs: Returns a tuple of tensors
"""

def __init__(
@@ -127,10 +281,14 @@ def __init__(
self.vm_context = vm_context
self.devices = devices

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Execute forward method (torch.nn.Module compatibility)."""
return self.forward(*args, **kwargs)

def __getattr__(self, name: str) -> Any:
def f(
*args: tuple[Any, ...], **kwargs: dict[str, Any]
) -> tuple[torch.Tensor, ...]:
) -> torch.Tensor | tuple[torch.Tensor, ...]:
flat_args = flatten_for_iree_signature(
(
args,
Expand All @@ -152,7 +310,11 @@ def f(
for arg, iree_arg in zip(flat_args, iree_args_post_call):
arg[...] = iree_arg

return res
# Match torch.nn.Module behavior: single output unwrapped, multiple as tuple
if len(res) == 1:
return res[0]
else:
return tuple(res)

return f
