-
Notifications
You must be signed in to change notification settings - Fork 30
【Hackathon 9th No.113】torch._C._fft.fft_irfft API转换 torch.fft.irfft #323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for your contribution! |
| # TODO | ||
| def fft_irfft_to_irfft(self, gm): | ||
| def replace_in_graph(graph_mod): | ||
| # 在 GraphModule 上注册稳定实现,codegen 可以使用 self.irfft(有时) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码里的注释还是写英文。用大模型翻译一下吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件就不放到代码库里了,放到本pr对话框里就行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些原始日志不需要上传
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原始日志不需要上传
|
|
||
| return gm | ||
|
|
||
| # def check_unstable_api(self, gm): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_unstable_api 这里的逻辑复原,放到下一次pr来更改。
check_unstable_api是裁判员,而unstable_to_stable是运动员。我们不在同一个pr里同时改裁判员和运动员
lixinqi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
本Pr得改一下这里
https://github.com/PaddlePaddle/GraphNet/blob/develop/graph_net/torch/fx_graph_serialize_util.py#L22 这样新抽取的计算图就不会有torch._C._fft.fft_irfft了
| self.unstable_api = unstable_api | ||
|
|
||
| def my_backend(gm, sample_inputs): | ||
| gm = self.unstable_to_stable(gm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里暂时不要改。
|
|
||
| def my_backend(gm, sample_inputs): | ||
| gm = self.unstable_to_stable(gm) | ||
| gm = self.fft_irfft_to_irfft(gm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的逻辑写到unstable_to_stable里。
|
@G2uge 可以扫码进群,研发大哥在群里等着你 |
|
hi, @G2uge
|



Description
我对接口的检测逻辑进行了修改,原因是 PyTorch 2.5.1 及以上版本的 FX 代码生成机制会在生成代码字符串时保留原始 C API 路径(如 torch._C._fft.fft_irfft),即使节点 target 已经被替换为稳定 API,基于字符串的检测也会出现误报,无法准确反映模型实际运行时是否还会调用不稳定 API。为此,我将检测逻辑改为遍历 FX Graph 的所有节点 target,只要节点 target 包含不允许的 API 就报错。这样修改后,检测逻辑直接检查 FX Graph 中每个节点实际要调用的 target,而不是依赖代码字符串。即使 gm.code或其他生成的代码文本中还存在 torch._C._fft.fft_irfft这样的不稳定 API 名称,只要节点的 target 已经被替换为稳定的 torch.fft.irfft,模型实际运行时就会调用稳定接口,而不会再调用不稳定的实现。因此,这种检测方式能够准确反映模型实际运行时的安全性,避免了因代码生成机制导致的误判。换句话说,即使代码字符串中出现了不稳定 API 的名字,只要节点 target 已经替换,实际运行时调用的就是稳定接口,模型的安全性和正确性都能得到保障。
在 fft_irfft_to_irfft的实现中,我递归遍历主 GraphModule 及其所有子模块,对每个节点进行检查,如果发现调用了不稳定的 fft_irfft 接口,就将其 target 替换为稳定的 torch.fft.irfft。替换后重新编译图,确保后续模型运行时只会调用稳定 API。