Skip to content

[Codegen] Fix ArgCompare vectorization#23775

Open
bangtianliu wants to merge 1 commit intoiree-org:mainfrom
bangtianliu:fix-argcompare-batched-vectorization
Open

[Codegen] Fix ArgCompare vectorization#23775
bangtianliu wants to merge 1 commit intoiree-org:mainfrom
bangtianliu:fix-argcompare-batched-vectorization

Conversation

@bangtianliu
Copy link
Contributor

@bangtianliu bangtianliu commented Mar 13, 2026

This PR fixes ArgCompare vectorization to use input shape instead of the output shape for inferring vector sizes.

The existing implementation uses inferSizesFromIR(argCompareOp.getDpsInits()[0]), which fails for rank-reducing operations. Inferring from the output returns incorrect sizes, especially when outputs come from rank-reducing tensor.extract_slice operations.

Example failure case:

 %extracted_slice_2 = tensor.extract_slice %extracted_slice[0] [1] [1] : tensor<1xf32> to tensor<f32>
 %32:2 = iree_linalg_ext.arg_compare dimension(0)
     ins(%28, %31 : tensor<1024xf32>, tensor<1024xi64>)
     outs(%extracted_slice_2, %extracted_slice_4 : tensor<f32>, tensor<i64>)

For reduction ops, inferSizesFromIR(linalgOp, ...) uses linalgOp.getNumLoops() to map ALL loop dimensions (parallel + reduction) to operands.

Assisted-by: Claude Code

@bangtianliu bangtianliu force-pushed the fix-argcompare-batched-vectorization branch 2 times, most recently from c7d3608 to c36f9af Compare March 13, 2026 05:54
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
@bangtianliu bangtianliu force-pushed the fix-argcompare-batched-vectorization branch from c36f9af to 04423dc Compare March 13, 2026 06:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant