Skip to content

Conversation

@zoranjovanovic-ns
Copy link

@zoranjovanovic-ns zoranjovanovic-ns commented May 9, 2025

Implement passing waves_per_eu function attribute from JAX trough XLA to Triton
For passing correct attribute value existing serialized_metadata parameter is used on JAX side.

syntax:
serialized_metadata="waves_per_eu:2"

example:

    params = plgpu.TritonCompilerParams(serialized_metadata="waves_per_eu:2")

    @jax.jit
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
        compiler_params=params)
    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

Copy link
Collaborator

@i-chaochen i-chaochen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@i-chaochen i-chaochen merged commit bc1d816 into rocm-jaxlib-v0.5.0 May 12, 2025
6 of 7 checks passed
if (attr_smd) {
TritonCallArgs triton_call_args_proto;
auto sermetadata = attr_smd.getValue().str();
if (tsl::protobuf::TextFormat::ParseFromString(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this work if user passes extra keys?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking into this.
AllowUnknownField(true) is not supported for text format.
Also, tried with google.protobuf.Any, but parser still reports the error if user passes extra key.

gulsumgudukbay pushed a commit that referenced this pull request May 16, 2025
gulsumgudukbay pushed a commit that referenced this pull request May 20, 2025
gulsumgudukbay pushed a commit that referenced this pull request May 21, 2025
@alekstheod alekstheod deleted the rocm-jaxlib-v0.5.0-waves_per_eu-2 branch July 3, 2025 08:11
@alekstheod alekstheod restored the rocm-jaxlib-v0.5.0-waves_per_eu-2 branch July 15, 2025 08:21
zahiqbal pushed a commit that referenced this pull request Oct 2, 2025
---------

Co-authored-by: Pavel Emeliyanenko <[email protected]>
(cherry picked from commit f013645)
(cherry picked from commit b03cd94)

Added support for waves_per_eu function attribute. (#181)

(cherry picked from commit bc1d816)
(cherry picked from commit d3f94e9)

removed two line change (revert of half of the openxla#25959 commit

(cherry picked from commit 109e138)

Fixes for jax 0.6.0 (#207)

* Add fixes for jax plugin 0.6.0

Drop NEEDED linking to unnecessary libs.
These are loaded by amdhipruntime and not us.

Fix missing NEEDED on MIOpen shared object.

* Minor rocblas related changes for rocm 70

(cherry picked from commit 0de7d49)

---------

Co-authored-by: Zoran Jovanovic <[email protected]>
(cherry picked from commit 28f10a0)

Add hipBLASLt support for gfx11. (#301)

(cherry picked from commit f814bff)

Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303)

* Bugfix and improve device_description.h::RocmComputeCompatibility

* Enable ALG_DOT_BF16* on rocm with HW support

(cherry picked from commit 510ea06)

[ROCm] Use bundled bitcode files (#196)

Also trim bitcode file list to ockl.bc and ocml.bc only.

(cherry picked from commit fc9e3c3)

Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312)

* Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms

* Exclude failing CollectiveOpsE2E tests

(cherry picked from commit fb6ddfb)

Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333)

At least gfx11_rx7900() is still needed for TF build.

(cherry picked from commit 13c3de1)

Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

(cherry picked from commit 8513f2d)

fix hardcoded max registers (#345)

(cherry picked from commit f3e170a)

fix hardcoded ecc enabled (#348)

(cherry picked from commit 9cfa74a)

remove reserved memory (#349)

(cherry picked from commit 0015d0e)

Add rocm_dev config for remote caching (#353)

(cherry picked from commit c815420)

added rocm7 support to EnablePeerAccess (#347)

* added rocm7 support to EnablePeerAccess

* use wrap namespace, clang-format and add comments

(cherry picked from commit 85548a7)

[ROCm] Disable Cudnn fusions (#358)

(cherry picked from commit edab8b2)
zahiqbal pushed a commit that referenced this pull request Oct 2, 2025
---------

Co-authored-by: Pavel Emeliyanenko <[email protected]>
(cherry picked from commit f013645)
(cherry picked from commit b03cd94)

Added support for waves_per_eu function attribute. (#181)

(cherry picked from commit bc1d816)
(cherry picked from commit d3f94e9)

removed two line change (revert of half of the openxla#25959 commit

(cherry picked from commit 109e138)

Fixes for jax 0.6.0 (#207)

* Add fixes for jax plugin 0.6.0

Drop NEEDED linking to unnecessary libs.
These are loaded by amdhipruntime and not us.

Fix missing NEEDED on MIOpen shared object.

* Minor rocblas related changes for rocm 70

(cherry picked from commit 0de7d49)

---------

Co-authored-by: Zoran Jovanovic <[email protected]>
(cherry picked from commit 28f10a0)

Add hipBLASLt support for gfx11. (#301)

(cherry picked from commit f814bff)

Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303)

* Bugfix and improve device_description.h::RocmComputeCompatibility

* Enable ALG_DOT_BF16* on rocm with HW support

(cherry picked from commit 510ea06)

[ROCm] Use bundled bitcode files (#196)

Also trim bitcode file list to ockl.bc and ocml.bc only.

(cherry picked from commit fc9e3c3)

Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312)

* Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms

* Exclude failing CollectiveOpsE2E tests

(cherry picked from commit fb6ddfb)

Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333)

At least gfx11_rx7900() is still needed for TF build.

(cherry picked from commit 13c3de1)

Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

(cherry picked from commit 8513f2d)

fix hardcoded max registers (#345)

(cherry picked from commit f3e170a)

fix hardcoded ecc enabled (#348)

(cherry picked from commit 9cfa74a)

remove reserved memory (#349)

(cherry picked from commit 0015d0e)

Add rocm_dev config for remote caching (#353)

(cherry picked from commit c815420)

added rocm7 support to EnablePeerAccess (#347)

* added rocm7 support to EnablePeerAccess

* use wrap namespace, clang-format and add comments

(cherry picked from commit 85548a7)

[ROCm] Disable Cudnn fusions (#358)

(cherry picked from commit edab8b2)
@zahiqbal zahiqbal mentioned this pull request Oct 2, 2025
1 task
zahiqbal pushed a commit that referenced this pull request Oct 3, 2025
---------

Co-authored-by: Pavel Emeliyanenko <[email protected]>
(cherry picked from commit f013645)
(cherry picked from commit b03cd94)

Added support for waves_per_eu function attribute. (#181)

(cherry picked from commit bc1d816)
(cherry picked from commit d3f94e9)

removed two line change (revert of half of the openxla#25959 commit

(cherry picked from commit 109e138)

Fixes for jax 0.6.0 (#207)

* Add fixes for jax plugin 0.6.0

Drop NEEDED linking to unnecessary libs.
These are loaded by amdhipruntime and not us.

Fix missing NEEDED on MIOpen shared object.

* Minor rocblas related changes for rocm 70

(cherry picked from commit 0de7d49)

---------

Co-authored-by: Zoran Jovanovic <[email protected]>
(cherry picked from commit 28f10a0)

Add hipBLASLt support for gfx11. (#301)

(cherry picked from commit f814bff)

Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303)

* Bugfix and improve device_description.h::RocmComputeCompatibility

* Enable ALG_DOT_BF16* on rocm with HW support

(cherry picked from commit 510ea06)

[ROCm] Use bundled bitcode files (#196)

Also trim bitcode file list to ockl.bc and ocml.bc only.

(cherry picked from commit fc9e3c3)

Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312)

* Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms

* Exclude failing CollectiveOpsE2E tests

(cherry picked from commit fb6ddfb)

Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333)

At least gfx11_rx7900() is still needed for TF build.

(cherry picked from commit 13c3de1)

Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

(cherry picked from commit 8513f2d)

fix hardcoded max registers (#345)

(cherry picked from commit f3e170a)

fix hardcoded ecc enabled (#348)

(cherry picked from commit 9cfa74a)

remove reserved memory (#349)

(cherry picked from commit 0015d0e)

Add rocm_dev config for remote caching (#353)

(cherry picked from commit c815420)

added rocm7 support to EnablePeerAccess (#347)

* added rocm7 support to EnablePeerAccess

* use wrap namespace, clang-format and add comments

(cherry picked from commit 85548a7)

[ROCm] Disable Cudnn fusions (#358)

(cherry picked from commit edab8b2)
zahiqbal pushed a commit that referenced this pull request Oct 5, 2025
---------

Co-authored-by: Pavel Emeliyanenko <[email protected]>
(cherry picked from commit f013645)
(cherry picked from commit b03cd94)

Added support for waves_per_eu function attribute. (#181)

(cherry picked from commit bc1d816)
(cherry picked from commit d3f94e9)

removed two line change (revert of half of the openxla#25959 commit

(cherry picked from commit 109e138)

Fixes for jax 0.6.0 (#207)

* Add fixes for jax plugin 0.6.0

Drop NEEDED linking to unnecessary libs.
These are loaded by amdhipruntime and not us.

Fix missing NEEDED on MIOpen shared object.

* Minor rocblas related changes for rocm 70

(cherry picked from commit 0de7d49)

---------

Co-authored-by: Zoran Jovanovic <[email protected]>
(cherry picked from commit 28f10a0)

Add hipBLASLt support for gfx11. (#301)

(cherry picked from commit f814bff)

Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303)

* Bugfix and improve device_description.h::RocmComputeCompatibility

* Enable ALG_DOT_BF16* on rocm with HW support

(cherry picked from commit 510ea06)

[ROCm] Use bundled bitcode files (#196)

Also trim bitcode file list to ockl.bc and ocml.bc only.

(cherry picked from commit fc9e3c3)

Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312)

* Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms

* Exclude failing CollectiveOpsE2E tests

(cherry picked from commit fb6ddfb)

Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333)

At least gfx11_rx7900() is still needed for TF build.

(cherry picked from commit 13c3de1)

Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

(cherry picked from commit 8513f2d)

fix hardcoded max registers (#345)

(cherry picked from commit f3e170a)

fix hardcoded ecc enabled (#348)

(cherry picked from commit 9cfa74a)

remove reserved memory (#349)

(cherry picked from commit 0015d0e)

Add rocm_dev config for remote caching (#353)

(cherry picked from commit c815420)

added rocm7 support to EnablePeerAccess (#347)

* added rocm7 support to EnablePeerAccess

* use wrap namespace, clang-format and add comments

(cherry picked from commit 85548a7)

[ROCm] Disable Cudnn fusions (#358)

(cherry picked from commit edab8b2)
hsharsha pushed a commit that referenced this pull request Oct 6, 2025
* rocprof-sdk addition,
upstream PR: openxla/pull/29769

Squash following commits..
Update rocprofiler-sdk (v3) along with roctracer (v1) for rocm-jaxlib-v0.6.0 (#302)

* update for integration of rocprofiler-sdk (along with roctracer as a backup based on bazel_options from CLI)

(cherry picked from commit 7775dd0)

use VLOG(2) to replace LOG(INFO), so PGLE has no verbose info (#357)

(cherry picked from commit 5950125)

update with kernel details for rocm-7.x (#364)

* update with kernel details for rocm-7.x

(cherry picked from commit 5597c0d)

update to remove previously hard-coded rocprofiler-sdk path (#369)

* update to remove previously hard-coded rocprofiler-sdk path and add skip_rocprofiler_sdk to avoid loading `rocprofiler-sdk`

(cherry picked from commit ff74b5f)

* fixed buffer comparator test

* misc fixes ported from rocm-jaxlib-v0.6.0

---------

Co-authored-by: Pavel Emeliyanenko <[email protected]>
(cherry picked from commit f013645)
(cherry picked from commit b03cd94)

Added support for waves_per_eu function attribute. (#181)

(cherry picked from commit bc1d816)
(cherry picked from commit d3f94e9)

removed two line change (revert of half of the openxla#25959 commit

(cherry picked from commit 109e138)

Fixes for jax 0.6.0 (#207)

* Add fixes for jax plugin 0.6.0

Drop NEEDED linking to unnecessary libs.
These are loaded by amdhipruntime and not us.

Fix missing NEEDED on MIOpen shared object.

* Minor rocblas related changes for rocm 70

(cherry picked from commit 0de7d49)

---------

Co-authored-by: Zoran Jovanovic <[email protected]>
(cherry picked from commit 28f10a0)

Add hipBLASLt support for gfx11. (#301)

(cherry picked from commit f814bff)

Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303)

* Bugfix and improve device_description.h::RocmComputeCompatibility

* Enable ALG_DOT_BF16* on rocm with HW support

(cherry picked from commit 510ea06)

[ROCm] Use bundled bitcode files (#196)

Also trim bitcode file list to ockl.bc and ocml.bc only.

(cherry picked from commit fc9e3c3)

Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312)

* Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms

* Exclude failing CollectiveOpsE2E tests

(cherry picked from commit fb6ddfb)

Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333)

At least gfx11_rx7900() is still needed for TF build.

(cherry picked from commit 13c3de1)

Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

(cherry picked from commit 8513f2d)

fix hardcoded max registers (#345)

(cherry picked from commit f3e170a)

fix hardcoded ecc enabled (#348)

(cherry picked from commit 9cfa74a)

remove reserved memory (#349)

(cherry picked from commit 0015d0e)

Add rocm_dev config for remote caching (#353)

(cherry picked from commit c815420)

added rocm7 support to EnablePeerAccess (#347)

* added rocm7 support to EnablePeerAccess

* use wrap namespace, clang-format and add comments

(cherry picked from commit 85548a7)

[ROCm] Disable Cudnn fusions (#358)

(cherry picked from commit edab8b2)

---------

Co-authored-by: Chunyu Jin <[email protected]>
Co-authored-by: zoranjovanovic-ns <[email protected]>
hsharsha pushed a commit that referenced this pull request Oct 6, 2025
* rocprof-sdk addition,
upstream PR: openxla/pull/29769

Squash following commits..
Update rocprofiler-sdk (v3) along with roctracer (v1) for rocm-jaxlib-v0.6.0 (#302)

* update for integration of rocprofiler-sdk (along with roctracer as a backup based on bazel_options from CLI)

(cherry picked from commit 7775dd0)

use VLOG(2) to replace LOG(INFO), so PGLE has no verbose info (#357)

(cherry picked from commit 5950125)

update with kernel details for rocm-7.x (#364)

* update with kernel details for rocm-7.x

(cherry picked from commit 5597c0d)

update to remove previously hard-coded rocprofiler-sdk path (#369)

* update to remove previously hard-coded rocprofiler-sdk path and add skip_rocprofiler_sdk to avoid loading `rocprofiler-sdk`

(cherry picked from commit ff74b5f)

* fixed buffer comparator test

* misc fixes ported from rocm-jaxlib-v0.6.0

---------

Co-authored-by: Pavel Emeliyanenko <[email protected]>
(cherry picked from commit f013645)
(cherry picked from commit b03cd94)

Added support for waves_per_eu function attribute. (#181)

(cherry picked from commit bc1d816)
(cherry picked from commit d3f94e9)

removed two line change (revert of half of the openxla#25959 commit

(cherry picked from commit 109e138)

Fixes for jax 0.6.0 (#207)

* Add fixes for jax plugin 0.6.0

Drop NEEDED linking to unnecessary libs.
These are loaded by amdhipruntime and not us.

Fix missing NEEDED on MIOpen shared object.

* Minor rocblas related changes for rocm 70

(cherry picked from commit 0de7d49)

---------

Co-authored-by: Zoran Jovanovic <[email protected]>
(cherry picked from commit 28f10a0)

Add hipBLASLt support for gfx11. (#301)

(cherry picked from commit f814bff)

Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303)

* Bugfix and improve device_description.h::RocmComputeCompatibility

* Enable ALG_DOT_BF16* on rocm with HW support

(cherry picked from commit 510ea06)

[ROCm] Use bundled bitcode files (#196)

Also trim bitcode file list to ockl.bc and ocml.bc only.

(cherry picked from commit fc9e3c3)

Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312)

* Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms

* Exclude failing CollectiveOpsE2E tests

(cherry picked from commit fb6ddfb)

Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333)

At least gfx11_rx7900() is still needed for TF build.

(cherry picked from commit 13c3de1)

Make device_count_ atomic (#343)

* Make device_count_ atomic

* Use relaxed memory order

* Fix build error

(cherry picked from commit 8513f2d)

fix hardcoded max registers (#345)

(cherry picked from commit f3e170a)

fix hardcoded ecc enabled (#348)

(cherry picked from commit 9cfa74a)

remove reserved memory (#349)

(cherry picked from commit 0015d0e)

Add rocm_dev config for remote caching (#353)

(cherry picked from commit c815420)

added rocm7 support to EnablePeerAccess (#347)

* added rocm7 support to EnablePeerAccess

* use wrap namespace, clang-format and add comments

(cherry picked from commit 85548a7)

[ROCm] Disable Cudnn fusions (#358)

(cherry picked from commit edab8b2)

* Ported all triton related changes from v0.6.0 to v0.7.1

(cherry picked from commit 1851bcc)

Disable softmax triton fusion if triton gemm is off (#281)

* Disable softmax rewriter triton if triton gemm is disabled

* Add specific flag to enable triton softmax fusion

* Address review comments

(cherry picked from commit 51a7f4b)

[ROCm][Triton] Disable transposed load in certain conditions

(cherry picked from commit 50860e9)

Enable unit tests that pass after fixing some Triton related issues. (#285)

* Enable unit tests that pass after fixing some Triton related issues.

* fusion_emitter_device_legacy_test still fails on MI200

(cherry picked from commit 97dd565)

Rocm jaxlib v0.6.0 triton support ut (#279)

* Fixed triton/support_test - no fmfa.

* Fix issue with rounding mode in accelerate amd matmul.

* Fixed issues with usage of mfma in support_test.

(cherry picked from commit 44f7d87)

Restore gpu_triton_custom_call_test (#262)

(cherry picked from commit 32eafa4)

Skipped CanNotEmitTritonCustomCallOnPreAmpereGpu test for ROCM.

(cherry picked from commit 56ec7ec)
(cherry picked from commit b1f3e9f)

fixed createTritonAMDGPULowerInstructionSchedHintsPass (#179)

(cherry picked from commit 8517a3a)
(cherry picked from commit c62e47d)

fixed bazel build issue

---------

Co-authored-by: Chunyu Jin <[email protected]>
Co-authored-by: zoranjovanovic-ns <[email protected]>
Co-authored-by: Alex <[email protected]>
hsharsha pushed a commit that referenced this pull request Nov 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants