Skip to content

Commit 206398a

Browse files
author
jax authors
committed
Merge pull request #19599 from ROCm:rocm-add-triton_command_buffer
PiperOrigin-RevId: 604400256
2 parents be99451 + f01c27f commit 206398a

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

jaxlib/gpu/triton_kernels.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -524,16 +524,16 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
524524

525525
gpustreamCaptureStatus_t capture_status;
526526
GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
527-
bool is_capturing = capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE;
527+
bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;
528528

529-
gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED;
529+
gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
530530
gpuStream_t autotune_stream = stream;
531531

532532
if (is_capturing) {
533+
533534
GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
534535
// Need a side stream so as not to interfere with graph capture.
535-
GPU_RETURN_IF_ERROR(
536-
gpuStreamCreate(&autotune_stream, CU_STREAM_NON_BLOCKING));
536+
GPU_RETURN_IF_ERROR(gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
537537
}
538538

539539
// If an input aliases with an output, it will get overwritten during the

jaxlib/gpu/vendor.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
254254
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
255255
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS
256256

257+
#define GPU_STREAM_CAPTURE_STATUS_ACTIVE CU_STREAM_CAPTURE_STATUS_ACTIVE
258+
#define GPU_STREAM_CAPTURE_MODE_RELAXED CU_STREAM_CAPTURE_MODE_RELAXED
259+
#define GPU_STREAM_NON_BLOCKING CU_STREAM_NON_BLOCKING
260+
257261
#define gpuCtxGetDevice cuCtxGetDevice
258262
#define gpuCtxPopCurrent cuCtxPopCurrent
259263
#define gpuCtxPushCurrent cuCtxPushCurrent
@@ -332,6 +336,8 @@ typedef hipsolverFillMode_t gpusolverFillMode_t;
332336
typedef hipblasHandle_t gpublasHandle_t;
333337
typedef hipblasStatus_t gpublasStatus_t;
334338
typedef hipCtx_t gpuContext_t;
339+
typedef hipStreamCaptureMode gpustreamCaptureMode_t;
340+
typedef hipStreamCaptureStatus gpustreamCaptureStatus_t;
335341
typedef hipDataType gpuDataType;
336342
typedef hipDevice_t gpuDevice_t;
337343
typedef hipDeviceptr_t gpuDevicePtr_t;
@@ -494,6 +500,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
494500
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
495501
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
496502

503+
#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive
504+
#define GPU_STREAM_CAPTURE_MODE_RELAXED hipStreamCaptureModeRelaxed
505+
#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking
506+
497507
#define gpuGetLastError hipGetLastError
498508
#define gpuGetErrorString hipGetErrorString
499509
#define gpuMemcpyAsync hipMemcpyAsync
@@ -526,6 +536,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
526536
#define gpuMemcpyDtoHAsync hipMemcpyDtoHAsync
527537
#define gpuMemcpyHtoDAsync hipMemcpyHtoDAsync
528538
#define gpuMemsetD8Async hipMemsetD8Async
539+
#define gpuThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode
540+
#define gpuStreamCreate hipStreamCreateWithFlags
541+
#define gpuStreamDestroy hipStreamDestroy
542+
#define gpuStreamIsCapturing hipStreamIsCapturing
529543

530544
#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
531545
hipDeviceAttributeComputeCapabilityMajor

0 commit comments

Comments
 (0)