Skip to content

Modify fusion_emitter_large_test to work on ROCm.#568

Merged
i-chaochen merged 2 commits into
rocm-jaxlib-v0.8.0from
rocm-jaxlib-v0.8.0-fusion_emitter_large_test
Jan 27, 2026
Merged

Modify fusion_emitter_large_test to work on ROCm.#568
i-chaochen merged 2 commits into
rocm-jaxlib-v0.8.0from
rocm-jaxlib-v0.8.0-fusion_emitter_large_test

Conversation

@zoranjovanovic-ns
Copy link
Copy Markdown

Motivation

Failing fusion_emitter_large_test

Technical Details

Skipping only the tests that require more than 4gb of memory.

Test Plan

Run fusion_emitter_large_test

Test Result

Test pass

Submission Checklist

Copy link
Copy Markdown
Collaborator

@i-chaochen i-chaochen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@i-chaochen i-chaochen added cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. rocm-jaxlib-v0.8.0 labels Jan 26, 2026
Copy link
Copy Markdown
Collaborator

@i-chaochen i-chaochen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please address the build warning like a841f7a#diff-44b6e626114acd84457b53d80d1a06fd6484b01db078382653921feb7c3dfee8R346-R351

[2026-01-23T19:42:59.621Z] ERROR: /tf/xla/xla/backends/gpu/codegen/triton/BUILD:1037:11: Compiling xla/backends/gpu/codegen/triton/support_legacy.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing CppCompile command (from target //xla/backends/gpu/codegen/triton:support) external/local_config_rocm/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer ... (remaining 320 arguments skipped)

[2026-01-23T19:42:59.621Z] xla/backends/gpu/codegen/triton/support_legacy.cc:250:5: error: unannotated fall-through between switch labels [-Werror,-Wimplicit-fallthrough]

[2026-01-23T19:42:59.621Z]   250 |     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:

[2026-01-23T19:42:59.621Z]       |     ^

[2026-01-23T19:42:59.621Z] xla/backends/gpu/codegen/triton/support_legacy.cc:250:5: note: insert 'ABSL_FALLTHROUGH_INTENDED;' to silence this warning

[2026-01-23T19:42:59.621Z]   250 |     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:

[2026-01-23T19:42:59.621Z]       |     ^

[2026-01-23T19:42:59.621Z]       |     ABSL_FALLTHROUGH_INTENDED; 

[2026-01-23T19:42:59.621Z] xla/backends/gpu/codegen/triton/support_legacy.cc:250:5: note: insert 'break;' to avoid fall-through

[2026-01-23T19:42:59.621Z]   250 |     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:

[2026-01-23T19:42:59.621Z]       |     ^

[2026-01-23T19:42:59.621Z]       |     break; 

[2026-01-23T19:42:59.621Z] xla/backends/gpu/codegen/triton/support_legacy.cc:263:5: error: unannotated fall-through between switch labels [-Werror,-Wimplicit-fallthrough]

[2026-01-23T19:42:59.621Z]   263 |     default:

[2026-01-23T19:42:59.621Z]       |     ^

[2026-01-23T19:42:59.621Z] xla/backends/gpu/codegen/triton/support_legacy.cc:263:5: note: insert 'ABSL_FALLTHROUGH_INTENDED;' to silence this warning

[2026-01-23T19:42:59.621Z]   263 |     default:

[2026-01-23T19:42:59.621Z]       |     ^

[2026-01-23T19:42:59.621Z]       |     ABSL_FALLTHROUGH_INTENDED; 

[2026-01-23T19:42:59.621Z] xla/backends/gpu/codegen/triton/support_legacy.cc:263:5: note: insert 'break;' to avoid fall-through

[2026-01-23T19:42:59.621Z]   263 |     default:

[2026-01-23T19:42:59.621Z]       |     ^

[2026-01-23T19:42:59.621Z]       |     break; 

[2026-01-23T19:42:59.621Z] 2 errors generated.


@zoranjovanovic-ns
Copy link
Copy Markdown
Author

please address the build warning like a841f7a#diff-44b6e626114acd84457b53d80d1a06fd6484b01db078382653921feb7c3dfee8R346-R351

Already addressed here:
#562

If needed I can add that here.

@zoranjovanovic-ns
Copy link
Copy Markdown
Author

Updated this PR to contain:
#562

@i-chaochen i-chaochen merged commit e988c42 into rocm-jaxlib-v0.8.0 Jan 27, 2026
6 of 8 checks passed
nurmukhametov pushed a commit that referenced this pull request Jan 27, 2026
* Modify fusion_emitter_large_test to work on ROCm.

* Fix fall-through warning in support_legacy.cc
i-chaochen added a commit that referenced this pull request Feb 9, 2026
* [ROCm] Build infrastructure and CI scripts

* Fix infinite recursion in HloInstruction::Accept/Visit const wrappers (#470)

The const wrapper methods for Accept() and Visit() were calling themselves
instead of the template versions, causing infinite recursion and stack overflow.

* Mark nvshmem tests as cuda-only (#458)

* Skipped CanNotEmitTritonCustomCallOnPreAmpereGpu test for ROCM.

* Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

* [ROCm] Enable embeded bitcode libs and inprocess lld (#507)

Added TF_ROCM_INPROCESS_LLD  and TF_ROCM_EMBEDDED_DEVICE_LIB form 0.6.0
otherwise identical to openxla#32439.
Env vars only needed for 0.8.0.

* [ROCm] Pass warp size to Triton compilation pipeline

* [ROCm] Add FNUZ FP8 type support in Triton

* [ROCm] Temporary workaround for column reduction warp size

* PR openxla#36046: [ROCm] Fix failing unit tests on ROCm platform

Imported from GitHub PR openxla#36046

📝 Summary of Changes

- layout_assignment tests are marked cuda-only.
- sample_file_test needs higher autotuner level for MIOpen to return conv algorithm. Earlier this was coming from GetDebugOptionsForTest.
- buffer_debug_log test is made gpu agnostic by using cannonical gpu name.
- cublas_gemm_rewriter_test_amdgpu_any fix unit test to remove padding for ROCm as introduced in openxla#33854
- gpu_kernel_tiling_test_amdgpu_any is updated to respect higher launch dimensions now supported by hipruntime
- Mark dynamic_shared_memory_test as cuda-only
- Add arch specific checks for barriers to sorting.hlo

🎯 Justification
Fixes failing unit tests on ROCm platform

* Fix build break in tfrt_gpu_buffer_test using absl_testing::StatusIs (#534)

* Port transpose changes from v0.8.0 to v0.8.2 (#526)

It should be dropped after the rebase on top of
330a305

* [ROCm] Fix failing test TritonEmitterTest/RocmWarpSizeIsSetCorrectly (#545)

* [ROCm] Fix failing test TritonEmitterTest/RocmWarpSizeIsSetCorrectly

Define valid tile parameters and non-zero shared memory.

* Update xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc

Co-authored-by: Maxime France-Pillois <[email protected]>

* Update xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc

Co-authored-by: Maxime France-Pillois <[email protected]>

---------

Co-authored-by: Maxime France-Pillois <[email protected]>

* Fix MIOpen linking for RNN kernels

Add explicit linkopts to miopen cc_library target to ensure
libMIOpen.so is properly linked at runtime. This fixes
AttributeError: module 'jaxlib.gpu_rnn' has no attribute
'compute_rnn_workspace_reserve_space_sizes' in experimental_rnn_test
in JAX.

Without this change, the _rnn.so shared library fails to load
MIOpen symbols properly, causing RNN test failures.

* Force rbe incompatible tests to be executed locally (#485)

* [ROCm] Add missing cuda-only tag

* enable mx datatype for rocm (#462)

* enable mx datatype for rocm

* add // TF_ROCM_VERSION >= 70000

* fix unit test build

* Add rocprofiler-sdk (v3) integration with roctracer fallback

Integrate rocprofiler-sdk for ROCm profiling with fallback to
roctracer (v1) when rocprofiler-sdk is not available.

* [ROCm] Always process convolutions through MIOpen backend for decomposition

Override AddConvAndGemmAutotuningPass in AMDGPUCompiler to ensure
convolutions are always sent to MIOpen for processing, regardless of
xla_gpu_autotune_level.  This is required because MIOpen handles
decomposition of unsupported fused convolutions back to regular convs,
which must happen even when autotuning is disabled.

Fixes cudnn_fused_conv_rewriter_autotune_disabled_test failures on ROCm.

* Changed error value for SplitK test in fusion_emitter_device_legacy_port_test.cc (#538)

* [ROCm] Add PJRT_Triton_Extension support (#548)

This change is PJRT_Triton_Extension support for ROCm as counterpart of
that for CUDA. Pallas Triton calls are lowered to HSACO directly rather
than PTX on ROCm platform.

* Fix expected output in fusion_emitter_int4_device_test for ROCm.

* skip conditional graph tests

* Fixed missing rtne in Triton to pass support_test.

* [ROCm] Add rocm-only tag to triton_rocm target

Fix dependency validation by tagging triton_rocm as rocm-only
since it depends on the rocm-only amdgpu_backend target.

* Avoid upcast of lib func operands to F32 for F16 type.

* Modify fusion_emitter_large_test to work on ROCm. (#568)

* Modify fusion_emitter_large_test to work on ROCm.

* Fix fall-through warning in support_legacy.cc

* Fixed dot_algorithms_test. Updated support_legacy and test itself.

* Modified triton_fusion_numerics_verifier_test to work on ROCm.

* [ROCm] Use shared AsBlasLtEpilogue in GemmWorkspaceRewriter

Replace the duplicate with the shared function to fix the issue and
prevent future divergence. The duplicate AsBlasLtEpilogue in
gemm_workspace_rewriter.cc was missing SILU epilogue support, breaking
ROCm Swish fusion tests. This duplicate was introduced in PR openxla#35132.

* Sync mgpu tests with xla_mgpu config

* [ROCm] Fix RocmWarpSizeIsSetCorrectly test to use new dump file naming

After commit 4ce9326, Triton pass dumps use the naming pattern
{module}.{kernel}.{pass_manager_name}.txt instead of
*.triton-passes.log. Update the test to match the new convention.

* Enable hlo_runner_main_gpu for rocm

* enable hipblaslt as a default choice and disable nccl comm split to avoid hanging

* Add flag to control swish activation fusion. (#577)

Add flag to control swish activation fusion.

* Improve test strategy for swish fusion flag (#585)

Move tests to a more suitable file.

* Revert "Fix infinite recursion in HloInstruction::Accept/Visit const wrappers (#470)"

This reverts commit 21a2d57.

* Disable hipblaslt as default choice

* Execute test directly if running on system without GPU (#608)

* Execute test directly if running on system without GPU

* Address review comments

* Address review comments

* Remove non-existent test targets from ROCm CI exclusion list

The following targets no longer exist in their respective BUILD files
and were causing Bazel target pattern parsing failures.

* Bundle librocm_smi64.so for MI200 lit tests

MI200 lit tests use hlo-opt which links against ROCm libraries. When
running on remote workers without ROCm installed, hlo-opt fails with:
  "error while loading shared libraries: librocm_smi64.so.1"

The _tools_on_path rule bundles libraries into lit_lib/ by extracting
them from CcInfo.linking_context.linker_inputs[].dynamic_library.
However, ROCm's cc_library targets with .so files in srcs don't populate
dynamic_library (unlike CUDA which uses cc_import).

Add a new rocm_smi_import target using cc_import, which properly exposes
the shared library via CcInfo. Use this target in lit.bzl so
librocm_smi64.so.1 gets bundled into lit_lib/ and is available at runtime
via hlo-opt's rpath.

---------

Co-authored-by: Pham Binh <[email protected]>
Co-authored-by: Alex <[email protected]>
Co-authored-by: Zoran Jovanovic <[email protected]>
Co-authored-by: Dragan Mladjenovic <[email protected]>
Co-authored-by: Harsha H S <[email protected]>
Co-authored-by: Maxime France-Pillois <[email protected]>
Co-authored-by: magaonka-amd <[email protected]>
Co-authored-by: Xuefei Jiang <[email protected]>
Co-authored-by: cj401-amd <[email protected]>
Co-authored-by: zoranjovanovic-ns <[email protected]>
Co-authored-by: Jian Li <[email protected]>
Co-authored-by: Chao Chen <[email protected]>
Co-authored-by: Alexandros Theodoridis <[email protected]>
Co-authored-by: Milica Makevic <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. rocm-jaxlib-v0.8.0

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants