Conversation

@KyleHerndon (Contributor)

No description provided.

github-actions bot (Contributor) commented Oct 16, 2025

Coverage report

| File | Lines missing coverage |
| --- | --- |
| sharktank/sharktank/utils/iree.py | 130, 286 |
| sharktank/tests/utils/iree_test.py | |
| Project Total | |

This report was generated by python-coverage-comment-action

@KyleHerndon changed the title from "[WIP] Modulify v0" to "Add utilities for automatically converting a torch module to a IREE-backed module-type" on Oct 16, 2025
@sogartar (Contributor) left a comment

Thank you for adding good documentation.

```python
        return results

    return TorchLikeIreeModule(vm_module, vm_context, devices)

return with_iree_device_context(load_fn, iree_devices)
```
Contributor:

I think you could remove the usage of with_iree_device_context and just return the TorchLikeIreeModule.

The purpose of with_iree_device_context is to signal and prevent leaking IREE-backed torch tensors.
The problem I observed is that the Torch runtime seems to have its own mind about how to manage the lifetime of the tensor resource.
It is all speculative, but I think it does not immediately free the underlying buffers even when no Python variables reference the tensors and the garbage collector is called explicitly. Then, if your IREE device gets destroyed first, you get a crash.
But with_iree_device_context makes no effort to enforce that no leaked IREE-backed Torch tensors are returned. The author needs to take care of this; it has no teeth.

To be honest, in the future we probably need a more thorough investigation into this buffer-lifetime nonsense to be 100% sure that we need this complication. I don't remember if I did it correctly. Something like:

  1. Make an IREE device.
  2. Make an IREE tensor.
  3. Convert it to a torch tensor such that it is backed by the IREE device buffer.
  4. Delete all Python variables referencing the IREE buffer directly or indirectly (except the IREE device).
  5. Call the Python GC.
  6. Observe, through logging or a debugger, whether the IREE buffer gets destroyed. This is what needs to happen in a sane world.
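The observation mechanics of steps 4–6 can be sketched with plain Python stand-ins (the `Fake*` classes here are hypothetical placeholders, not real IREE or torch types; a real experiment would swap in an actual IREE buffer and an IREE-backed torch tensor):

```python
import gc
import weakref

class FakeIreeBuffer:
    """Stand-in for an IREE device buffer (hypothetical; a real test would use iree.runtime)."""

class FakeTorchTensor:
    """Stand-in for a torch tensor backed by that buffer."""
    def __init__(self, buffer):
        self._buffer = buffer

destroyed = []
buffer = FakeIreeBuffer()
# Step 6's observation hook: fires only when the buffer object is actually collected.
weakref.finalize(buffer, destroyed.append, True)

tensor = FakeTorchTensor(buffer)  # steps 2-3, schematically

# Steps 4-5: drop every reference and force a collection.
del buffer
del tensor
gc.collect()

# In a sane world the buffer is destroyed once nothing references it.
print(destroyed)  # -> [True]
```

If the torch runtime keeps a hidden reference (the speculation above), the finalizer would not fire here, which is exactly the signal the experiment is after.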

Contributor Author (@KyleHerndon), Oct 27, 2025:

> I think you could remove the usage of with_iree_device_context and just return the TorchLikeIreeModule.

This part, done.

Contributor Author:

I appreciate the explanation, but it's unclear to me whether the rest of this is actionable, so I'm leaving this thread open for now.

Contributor:

It is not actionable on this PR. I guess it was just a rant.

Comment on lines 287 to 293
```python
def __call__(self, *args: Any, **kwargs: Any) -> Any:
    """Execute the module's forward pass."""
    ...

def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Execute the module's forward pass explicitly."""
    ...
```
Contributor:

We could have other functions, not just forward. Also, there may be no forward at all.
The @runtime_checkable decorator would allow us to use isinstance, but if a module lacks a forward method, isinstance would return False.

I am not sure how to structurally say that something is a module; it can have anything, really.
The only implicit assumption coming from torch.nn.Module that I can think of is that

module(...)

should be equivalent to

module.forward(...)

They should maybe always accept/return tensors as sharktank.AnyTensor.
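The structural-typing point above can be sketched as follows (`InferenceModule` is a hypothetical name for illustration, not an existing sharktank type): a `@runtime_checkable` Protocol makes `isinstance` pass for anything with a `forward` and fail for anything without one.

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class InferenceModule(Protocol):
    """Hypothetical structural type for forward-style modules."""
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

class WithForward:
    def forward(self, x):
        return x

class WithoutForward:
    """A module exposing other entry points instead of forward."""
    def generate(self, x):
        return x

print(isinstance(WithForward(), InferenceModule))    # True
print(isinstance(WithoutForward(), InferenceModule)) # False
```

Note that `runtime_checkable` only checks for the presence of the named members, not their signatures, so this is a fairly weak guarantee in any case.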

Contributor Author:

I agree that TorchLikeIreeModule will not always have a forward (though maybe it should, i.e. you declare one, or it picks a random/first one, etc.). I also agree that there could be more functions.

I'm not sure how to make this perfect, but I think having it is better than not. In particular, I'd like to have some basic type safety for a project attempting a prototype of eager mode on GPU that can compile some/most/all of the classes in layers/ to move most of the workload to the GPU, even if there are a lot of pieces that don't compile or require moving a lot of data back and forth between CPU and GPU.

Modules can accept parameters and return values that are not Tensors or AnyTensors, so I'm hesitant to create such a restriction. It might be true that those are the only cases we care about, though.

Contributor:

> I agree that TorchLikeIreeModule will not always have a forward (though maybe it should, i.e. you declare one or it picks a random/first one, etc). I also agree that there could be more functions.

For example, our PagedLlmModelV1 does not have a forward method, so you will not be able to use isinstance(llm_module, InferenceModule) on it. We would need another mechanism.
Maybe we can use ABC instead of Protocol and use ABC.register to "tag" classes as such a type.
If we can't accurately express what InferenceModule is with structural subtyping, then we probably should not use a Protocol.
If we are going to use this only as a type hint, then we can just use Union[torch.nn.Module, TorchLikeIreeModule]. I think we can expand on it later.
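A minimal sketch of the ABC.register idea (the `InferenceModule` name is hypothetical, and this `PagedLlmModelV1` is just a stub mimicking the real class's lack of a forward method):

```python
from abc import ABC

class InferenceModule(ABC):
    """Hypothetical nominal tag; classes opt in via register, not inheritance."""

class PagedLlmModelV1:
    """Stub for a module with no forward method."""
    def prefill(self, tokens):
        return tokens

# register makes PagedLlmModelV1 a virtual subclass: isinstance now passes
# even though there is no forward method and no inheritance relationship.
InferenceModule.register(PagedLlmModelV1)

print(isinstance(PagedLlmModelV1(), InferenceModule))  # True
```

The trade-off is that registration is an explicit, opt-in tag rather than a structural check, so it says nothing about what methods the class actually has.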

> Modules can accept parameters and return values that are not Tensors or AnyTensors, so I'm hesitant to create such a restriction. It might be true that those are the only cases we care about, though.

I meant that if they have tensor args/results, those should be sharktank.AnyTensor.
