[Snippets][CPU] Introduce MatMul tokenization config, do not tokenize Transpose after MatMul on ARM64 #32592
Conversation
```cpp
auto label = ov::pass::pattern::any_input([config](const ov::Output<ov::Node>& out) {
    const auto n = out.get_node_shared_ptr();
    // Config-aware gating: optionally reject specific Transpose cases around MatMul
    if (ov::is_type<ov::op::v1::Transpose>(n)) {
```
Let's move this check to the `is_supported_transpose` lambda in `is_supported_op`. We will need to change their signatures, but I think that is okay.
Good idea, done
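For illustration, a minimal sketch of what the requested refactor might look like. This is not the PR's actual diff: the `MatMulConfig` stand-in mirrors the fields quoted further below, and the `is_supported_op` signature extension is an assumption.

```cpp
#include <memory>

#include "openvino/core/type.hpp"
#include "openvino/op/transpose.hpp"

// Stand-in for the PR's config; the real fields are quoted further below.
struct MatMulConfig {
    bool is_supported_transpose_c = true;
};

// Sketch of the refactor: the Transpose gating lives in an
// is_supported_transpose lambda inside is_supported_op, whose signature is
// extended to accept the config (hypothetical shape).
bool is_supported_op(const std::shared_ptr<const ov::Node>& n, const MatMulConfig& config) {
    const auto is_supported_transpose = [&config](const std::shared_ptr<const ov::Node>&) {
        // Config-aware gating: a Transpose on the MatMul output is tokenized
        // only if the platform config allows it.
        return config.is_supported_transpose_c;
    };
    if (ov::is_type<ov::op::v1::Transpose>(n))
        return is_supported_transpose(n);
    // ... checks for other op types would follow here ...
    return true;
}
```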
```cpp
// Transpose before MatMul on input 0
bool is_supported_transpose_a = true;
// Transpose before MatMul on input 1
bool is_supported_transpose_b = true;
// Transpose after MatMul on its output
bool is_supported_transpose_c = true;
```
I suppose just boolean values are not enough here: not all types of transposes are actually supported even on X86. Ideally, we need lambdas here which would replace `is_supported_transpose` in `collapse_subgraph.cpp` and `is_valid_transpose` in `mha_tokenization.cpp`.
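A hedged sketch of the reviewer's suggestion follows; this is hypothetical, not merged code. The predicate type and default values are assumptions, while the field names come from the flags quoted above.

```cpp
#include <functional>
#include <memory>

#include "openvino/core/node.hpp"

// Per-position predicates instead of plain booleans, so each backend can
// express exactly which Transpose layouts it supports. These could absorb
// is_supported_transpose from collapse_subgraph.cpp and is_valid_transpose
// from mha_tokenization.cpp.
struct MatMulConfig {
    using TransposePredicate = std::function<bool(const std::shared_ptr<const ov::Node>&)>;
    // Each predicate decides whether a concrete Transpose (order, rank,
    // position relative to MatMul) may be tokenized on this platform.
    TransposePredicate is_supported_transpose_a = [](const std::shared_ptr<const ov::Node>&) { return true; };
    TransposePredicate is_supported_transpose_b = [](const std::shared_ptr<const ov::Node>&) { return true; };
    TransposePredicate is_supported_transpose_c = [](const std::shared_ptr<const ov::Node>&) { return true; };
};
```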
That's a good point to discuss. Right now these flags are only used in `is_supported_transpose`. It seemed to me that this way we get less code duplication, and the logic looks simpler. What do you think?
But the original ideas of this PR are to avoid tokenization of unsupported transposes and to make the supported-transpose configuration device specific. The current behavior doesn't fully solve both problems:
- We still tokenize transposes in MHATokenization and then have to move them out of the Subgraph
- The supported-transpose checks in MHATokenization, which are specific to X64, are performed for ARM as well
Disabled all transposes on ARM
```cpp
// Disable Transpose after MatMul in general tokenization on ARM64 by default.
// Keep Transpose before MatMul inputs allowed for flexibility.
ov::snippets::pass::MatMulConfig mm_cfg;
mm_cfg.is_supported_transpose_c = false;
```
Shouldn't `is_supported_transpose_a`/`is_supported_transpose_b` be false as well on ARM?
They are converted to `GemmCopyB`, actually. The output one is the one that is troublesome right now.
Not really:
- Transpose A is placed on input A, not input B, so we don't even insert `GemmCopyB` there
- Transpose B (`transpose_b = true`) is not supported by `GemmCopyB`. Please see the `ExplicitTransposeMatMulInputs` callback for the details
Fair enough, fixed
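A sketch of the resulting ARM64 defaults after this exchange (per "Disabled all transposes on ARM" above); the field names follow the earlier quotes, but the exact file and surrounding code are not shown in this thread.

```cpp
ov::snippets::pass::MatMulConfig mm_cfg;
mm_cfg.is_supported_transpose_a = false;  // Transpose on input A: no GemmCopyB is inserted there
mm_cfg.is_supported_transpose_b = false;  // transpose_b is not handled by GemmCopyB either
mm_cfg.is_supported_transpose_c = false;  // Transpose after MatMul stays disabled
```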
Decided to switch to an alternative approach. PR can be found here: #32676
### Details:
Introduce a callback function for the `ExtractUnsupportedTransposes` pass as a part of `CommonOptimizations::Config` to customize the pass behavior depending on Transpose support. For example, the ARM64 platform supports transpose decomposition, but MatMul with Transpose A/B is not supported so far. The rest of the (potential) platforms mark Transpose as not supported completely.

An alternative approach for #32592

### Tickets:
- 176061
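To make the description above concrete, here is a hedged sketch of what such a callback might look like. The member name, signature, and default are assumptions; the actual code lives in #32676 and may differ.

```cpp
#include <functional>
#include <memory>

#include "openvino/core/node.hpp"

// Hypothetical shape of a CommonOptimizations-style config carrying the
// callback; platforms without Transpose support keep the default.
struct CommonOptimizationsConfig {
    // Returns true if this Transpose is supported inside the Subgraph on the
    // target platform; ExtractUnsupportedTransposes would move it out otherwise.
    std::function<bool(const std::shared_ptr<const ov::Node>&)> is_transpose_supported =
        [](const std::shared_ptr<const ov::Node>&) { return false; };
};
```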