Skip to content
88 changes: 57 additions & 31 deletions python/iron/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def decorator(*args, **kwargs):
return cached_kernel(*args, **kwargs)

# Clear any instances from previous runs to make sure if the user provided any broken code we don't try to recompile it
ExternalFunction._instances.clear()
ExternalFunction.clear()

# Find ExternalFunction instances in arguments and kwargs
external_kernels = []
Expand All @@ -226,10 +226,7 @@ def decorator(*args, **kwargs):

# Compile all ExternalFunction instances that were created during this JIT compilation
for func in ExternalFunction._instances:
if (
not hasattr(func, "_compiled") or not func._compiled
): # Don't compile if already compiled
external_kernels.append(func)
external_kernels.append(func)

# Determine target architecture based on device type
try:
Expand Down Expand Up @@ -312,9 +309,47 @@ def compile_external_kernel(func, kernel_dir, target_arch):
kernel_dir: Directory to place the compiled object file
target_arch: Target architecture (e.g., "aie2" or "aie2p")
"""
# Skip if already compiled
if hasattr(func, "_compiled") and func._compiled:
return
# Check if we can reuse a cached object file
if (
hasattr(func, "_cache_key")
and func._cache_key
and func._cache_key in func._cache
):
cached_info = func._cache[func._cache_key]
cached_object_file = cached_info["object_file_name"]

# Check if object file already exists in current kernel directory
current_output_path = os.path.join(kernel_dir, cached_object_file)
if os.path.exists(current_output_path):
# Object file already exists locally, just use it
func._object_file_name = cached_object_file
return

# Copy the cached object file to the current kernel directory
# This happens when we have a code object that is already cached from previous
# runs. The issue is that ExternalFunction objects get resolved during MLIR
# generation, but each JIT call creates different MLIR modules (different hashes),
# so they end up in different cache directories. However, the ExternalFunction
# cache is global and contains object files from previous directories. We can't
# just reference the old path because the linker runs in the new directory,
# so we must copy the cached object file to the current kernel directory.
# We also can't simply clear the code object cache between JIT compilations
# because ExternalFunctions get resolved when the worker gets resolved during
# MLIR generation. If we clear the instances, they no longer exist and we
# don't see the ExternalFunction anymore, breaking the compilation pipeline.
cached_source_dir = cached_info.get("source_dir", kernel_dir)
cached_source_path = os.path.join(cached_source_dir, cached_object_file)

if os.path.exists(cached_source_path):
# Copy object file to current kernel directory
shutil.copy2(cached_source_path, current_output_path)

# Update the function to use the local copy
func._object_file_name = cached_object_file
return
else:
# Cached object file doesn't exist, remove from cache and recompile
del func._cache[func._cache_key]

# Check if object file already exists in kernel directory
output_file = os.path.join(kernel_dir, func._object_file_name)
Expand All @@ -324,26 +359,13 @@ def compile_external_kernel(func, kernel_dir, target_arch):
# Create source file in kernel directory
source_file = os.path.join(kernel_dir, f"{func._name}.cc")

# Handle both source_string and source_file cases
if func._source_string is not None:
# Use source_string (write to file)
try:
with open(source_file, "w") as f:
f.write(func._source_string)
except Exception as e:
raise
elif func._source_file is not None:
# Use source_file (copy existing file)
# Check if source file exists before copying
if os.path.exists(func._source_file):
try:
shutil.copy2(func._source_file, source_file)
except Exception as e:
raise
else:
return
else:
raise ValueError("Neither source_string nor source_file is provided")
# Get source content
try:
source_content = func._get_source_content()
with open(source_file, "w") as f:
f.write(source_content)
except Exception as e:
raise

from .compile.compile import compile_cxx_core_function

Expand All @@ -357,12 +379,16 @@ def compile_external_kernel(func, kernel_dir, target_arch):
cwd=kernel_dir,
verbose=False,
)

# Only add to cache after successful compilation
if hasattr(func, "_cache_key") and func._cache_key:
# Store both object file name and source directory for future copying
func.add_to_cache(func._cache_key, func._object_file_name, kernel_dir)

except Exception as e:
# Don't add to cache if compilation failed
raise

# Mark the function as compiled
func._compiled = True


def hash_module(module, external_kernels=None, target_arch=None):
"""
Expand Down
76 changes: 71 additions & 5 deletions python/iron/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# (c) Copyright 2024 Advanced Micro Devices, Inc.
# (c) Copyright 2024-2025 Advanced Micro Devices, Inc.

import numpy as np
import hashlib

from .. import ir # type: ignore
from ..extras.dialects.ext.func import FuncOp # type: ignore
Expand Down Expand Up @@ -76,7 +77,8 @@ def resolve(


class ExternalFunction(BaseKernel):
_instances = set()
_instances = set() # A set of all instances of ExternalFunction
_cache = {} # Cache for compiled functions based on source hash

def __init__(
self,
Expand All @@ -93,7 +95,7 @@ def __init__(

Args:
name (str): The name of the function
object_file_name (str, optional): The name of the object file. If None, it will be name.o.
object_file_name (str, optional): The name of the object file. If None, it will be name.o. If provided, this bypasses all caching mechanisms and uses the exact filename specified.
source_file (str): Path to the C/C++ source file
source_string (str): C/C++ source code as a string
arg_types (list[type[np.ndarray] | np.dtype], optional): The type signature of the function. Defaults to [].
Expand All @@ -104,15 +106,49 @@ def __init__(
self._setup_source(source_file, source_string)
self._include_dirs = include_dirs
self._compile_flags = compile_flags

# If object_file_name is provided, bypass caching entirely
if object_file_name:
self._object_file_name = object_file_name
self._cache_key = None # No cache key needed when bypassing cache
else:
self._object_file_name = f"{self._name}.o"
self._compiled = False
# Generate a hash-based cache key for this function
self._cache_key = self._generate_cache_key()

# Check if we can reuse a cached object file
if self._cache_key in ExternalFunction._cache:
cached_info = ExternalFunction._cache[self._cache_key]
self._object_file_name = cached_info["object_file_name"]
else:
# We never compile the code object until we resolve the ExternalFunction
# so we use the cache key as the object file name to avoid having two object files or
# having to ask the user to use the same object file name for both ExternalFunctions.
self._object_file_name = self._cache_key

# Track this instance for JIT compilation
ExternalFunction._instances.add(self)

def _generate_cache_key(self) -> str:
"""Generate a unique cache key based on source content and compilation parameters."""
# Create a hash of the source content and compilation parameters
content_to_hash = []

# Get the source content as a string
try:
source_content = self._get_source_content()
content_to_hash.append(source_content)
except RuntimeError as e:
# If we can't read the source file, this is a critical error
raise RuntimeError(f"Failed to read source content for cache key: {e}")

# Include compilation parameters in the hash
content_to_hash.extend([str(self._include_dirs), str(self._compile_flags)])

# Create a hash of all the content
combined_content = "".join(content_to_hash)
computed_hash = hashlib.md5(combined_content.encode("utf-8")).hexdigest()
return computed_hash

def _setup_source(self, source_file: str | None, source_string: str | None) -> None:
"""Set up the source file for compilation."""
if source_file is not None:
Expand All @@ -124,6 +160,21 @@ def _setup_source(self, source_file: str | None, source_string: str | None) -> N
self._source_file = None
self._source_string = source_string

def _get_source_content(self) -> str:
"""Get the source content as a string."""
if self._source_string is not None:
return self._source_string
elif self._source_file is not None:
try:
with open(self._source_file, "r") as f:
return f.read()
except (IOError, OSError) as e:
raise RuntimeError(
f"Failed to read source file '{self._source_file}': {e}"
)
else:
raise RuntimeError("No source content available")

def __enter__(self):
"""Enter the context."""
return self
Expand All @@ -132,6 +183,11 @@ def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context."""
pass

@classmethod
def clear(cls):
"""Clear all instances."""
cls._instances.clear()

@property
def bin_name(self) -> str:
return self._object_file_name
Expand Down Expand Up @@ -218,3 +274,13 @@ def __call__(self, *args, **kwargs):
if not self._op:
raise ValueError("Need to resolve ExternalFunction before it can be called")
call(self._op, args, **kwargs)

@classmethod
def add_to_cache(
cls, cache_key: str, object_file_name: str, source_dir: str = None
):
"""Add a compiled function to the cache."""
cache_entry = {"object_file_name": object_file_name}
if source_dir:
cache_entry["source_dir"] = source_dir
cls._cache[cache_key] = cache_entry
Loading
Loading