diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index a8ae7a44bee38..96f1e1a317af1 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -1931,11 +1931,27 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name, } case PI_EXT_INTEL_DEVICE_INFO_FREE_MEMORY: { + // Check the device of the currently set context uses the same device. + // CUDA_ERROR_INVALID_CONTEXT signifies the absence of an active context. + CUdevice current_ctx_device; + CUresult current_ctx_device_ret = cuCtxGetDevice(¤t_ctx_device); + if (current_ctx_device_ret != CUDA_ERROR_INVALID_CONTEXT) + PI_CHECK_ERROR(current_ctx_device_ret); + bool need_primary_ctx = current_ctx_device_ret == CUDA_ERROR_INVALID_CONTEXT || + current_ctx_device != device->get(); + if (need_primary_ctx) { + // Use the primary context for the device if no context with the device is set. + CUcontext primary_context; + PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&primary_context, device->get())); + PI_CHECK_ERROR(cuCtxSetCurrent(primary_context)); + } size_t FreeMemory = 0; size_t TotalMemory = 0; sycl::detail::pi::assertion(cuMemGetInfo(&FreeMemory, &TotalMemory) == CUDA_SUCCESS, "failed cuMemGetInfo() API."); + if (need_primary_ctx) + PI_CHECK_ERROR(cuDevicePrimaryCtxRelease(device->get())); return getInfo(param_value_size, param_value, param_value_size_ret, FreeMemory); }