Add utilities for automatically converting a torch module to an IREE-backed module-type #2526
Conversation
sogartar
left a comment
Thank you for adding good documentation.
sharktank/sharktank/utils/iree.py (Outdated)

```python
        return results
    return TorchLikeIreeModule(vm_module, vm_context, devices)

return with_iree_device_context(load_fn, iree_devices)
```
I think you could remove the usage of `with_iree_device_context` and just return the `TorchLikeIreeModule`.
The purpose of `with_iree_device_context` is to signal and prevent leaking IREE-backed torch tensors.
The problem I observed is that the Torch runtime seems to have its own ideas about how to manage the lifetime of the tensor resource.
It is all speculative, but I think it does not immediately free the underlying buffers, even when no Python variables reference the tensors and the garbage collector is invoked explicitly. Then, if your IREE device gets destroyed first, you will get a crash.
But `with_iree_device_context` does not make any effort to enforce that no leaked IREE-backed Torch tensors are returned. The author needs to take care of this. It does not have teeth.
To be honest, in the future we probably need a more thorough investigation into this buffer-lifetime nonsense to be 100% sure that we need this complication. I don't remember if I did it correctly. Something like:
- Make an IREE device.
- Make an IREE tensor.
- Convert it to a torch tensor such that it is backed by the IREE device buffer.
- Delete all Python variables referencing the IREE buffer directly or indirectly (except the IREE device).
- Call the Python GC.
- Observe whether the IREE buffer gets destroyed, through logging or running with a debugger. This is what needs to happen in a sane world.
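The steps above can be sketched generically. This is a hedged illustration of the *observation technique* only: the `FakeDevice`/`FakeBuffer`/`FakeTorchTensor` classes are placeholder stand-ins I invented for this sketch, not real IREE or torch types; the real experiment would use `iree.runtime` objects in their place. `weakref.finalize` substitutes for "observe through logging or a debugger".

```python
import gc
import weakref

# Placeholder stand-ins for the real IREE objects (hypothetical classes,
# not part of any real API).
class FakeDevice:
    pass

class FakeBuffer:
    def __init__(self, device):
        self.device = device  # the buffer keeps its device alive

class FakeTorchTensor:
    """Stands in for a torch tensor backed by an external device buffer."""
    def __init__(self, buffer):
        self._buffer = buffer

destroyed = []

device = FakeDevice()                                  # 1. make a device
buffer = FakeBuffer(device)                            # 2. make a buffer
weakref.finalize(buffer, destroyed.append, "buffer freed")
tensor = FakeTorchTensor(buffer)                       # 3. tensor backed by it

# 4. Delete all variables referencing the buffer directly or indirectly
#    (except the device), then 5. run the GC.
del buffer, tensor
gc.collect()

# 6. In a sane world the buffer is gone now; if the runtime secretly held
#    on to it, `destroyed` would still be empty here.
print(destroyed)  # → ['buffer freed']
```

With the stand-in classes the buffer is freed as expected; the open question in the thread is whether the real torch runtime behaves the same way.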
> I think you could remove the usage of `with_iree_device_context` and just return the `TorchLikeIreeModule`.
This part, done.
I appreciate the explanation, but it's unclear to me if the rest of this is actionable, so I'm currently leaving this thread open.
It is not actionable on this PR. I guess it was just a rant.
sharktank/sharktank/utils/iree.py (Outdated)

```python
def __call__(self, *args: Any, **kwargs: Any) -> Any:
    """Execute the module's forward pass."""
    ...

def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Execute the module's forward pass explicitly."""
    ...
```
We could have other functions, not just `forward`. Also, there may be no `forward` at all.
The @runtime_checkable decorator would allow us to use isinstance, but if a module lacks a forward method, isinstance would return False.
I am not sure how to structurally say that something is a module. It can have anything really.
The only implicit assumption coming from `torch.nn.Module` that I can think of is that `module(...)` should be equivalent to `module.forward(...)`.
They should maybe always accept/return tensors as `sharktank.AnyTensor`.
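The `@runtime_checkable` concern above can be demonstrated with a minimal sketch. The protocol body mirrors the `__call__`/`forward` pair from the diff under review; the two concrete classes (`WithForward`, `WithoutForward`) are hypothetical examples I made up to show how a `forward`-less module fails the structural check.

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class InferenceModule(Protocol):
    """Structural type mirroring the protocol proposed in the PR."""
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

# Hypothetical module that follows the torch.nn.Module convention:
# calling it dispatches to forward().
class WithForward:
    def forward(self, x):
        return x
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

# Hypothetical module that is callable but has no forward() at all.
class WithoutForward:
    def __call__(self, x):
        return x

# runtime_checkable isinstance only checks that the methods exist.
print(isinstance(WithForward(), InferenceModule))     # → True
print(isinstance(WithoutForward(), InferenceModule))  # → False
```

As the comment notes, the second result is the problem: a perfectly valid module without a `forward` method is rejected by the structural check.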
I agree that `TorchLikeIreeModule` will not always have a `forward` (though maybe it should, i.e. you declare one, or it picks a random/first one, etc.). I also agree that there could be more functions.
I'm not sure how to make this perfect, but I think having it is better than not having it. In particular, I'd like some basic type safety for a project attempting a prototype of eager execution on GPU: one that can compile some/most/all of the classes in layers/ to move most of the workload to the GPU, even if there are a lot of pieces that don't compile or require moving a lot of data back and forth between CPU and GPU.
Modules can accept parameters and return values that are not Tensors or AnyTensors, so I'm hesitant to create such a restriction. It might be true that those are the only cases we care about, though.
> I agree that `TorchLikeIreeModule` will not always have a `forward` (though maybe it should, i.e. you declare one or it picks a random/first one, etc). I also agree that there could be more functions.
For example, our `PagedLlmModelV1` does not have a `forward` method. You will not be able to use `isinstance(llm_module, InferenceModule)` on it. We would need another mechanism.
Maybe we can use `ABC` instead of `Protocol` and use `ABC.register` to "tag" classes as such a type.
If we can't accurately express what `InferenceModule` is with structural subtyping, then we probably should not use a `Protocol`.
If we are going to use this only as a type hint, then we can just make it `Union[torch.nn.Module, TorchLikeIreeModule]`. I think we can expand on it later.
> Modules can accept parameters and return values that are not Tensors or AnyTensors, so I'm hesitant to create such a restriction. It might be true that those are the only cases we care about, though.
I meant that if they have tensor args/results, those should be `sharktank.AnyTensor`.
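The `ABC.register` alternative suggested above can be sketched as follows. The `prefill` method on the stand-in `PagedLlmModelV1` is a hypothetical placeholder (the real sharktank class has its own interface); the point is only that a class with no `forward` can still be nominally tagged and pass `isinstance`.

```python
from abc import ABC

class InferenceModule(ABC):
    """Nominal tag type: classes are registered, not subclassed."""
    pass

# Stand-in for the real class; it deliberately has no forward() method.
class PagedLlmModelV1:
    def prefill(self, tokens):  # hypothetical method for illustration
        return tokens

# Tag the existing class as an InferenceModule without modifying it.
InferenceModule.register(PagedLlmModelV1)

# ABCMeta's __instancecheck__ honors virtual subclasses, so the
# isinstance check now succeeds despite the missing forward().
print(isinstance(PagedLlmModelV1(), InferenceModule))  # → True
```

Unlike the `@runtime_checkable` `Protocol`, this trades structural checking for an explicit opt-in: nothing is verified about the class's methods, which matches the comment's point that a module "can have anything really".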