-
Notifications
You must be signed in to change notification settings - Fork 5.1k
[Feature] CUDA Green Context Support #7649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
16e3e23
add greenctx stream to sgl-kernel
ykcombat c34f51b
Merge branch 'main' into greenctx_stream
ykcombat 6a9a7cf
remove resource partition by percent
ykcombat ed1bcb5
Merge branch 'greenctx_stream' of https://github.com/ykcombat/sglang …
ykcombat b8375f3
Merge branch 'main' into greenctx_stream
ykcombat File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #define CUDA_RT(call) \ | ||
| do { \ | ||
| cudaError_t _status = (call); \ | ||
| if (_status != cudaSuccess) { \ | ||
| std::cerr << "ERROR: CUDA RT call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__ \ | ||
| << " failed with " << cudaGetErrorString(_status) << std::endl; \ | ||
| exit(1); \ | ||
| } \ | ||
| } while (0) | ||
|
|
||
| #define CUDA_DRV(call) \ | ||
| do { \ | ||
| CUresult _status = (call); \ | ||
| if (_status != CUDA_SUCCESS) { \ | ||
| const char* err_str; \ | ||
| cuGetErrorString(_status, &err_str); \ | ||
| std::cerr << "ERROR: CUDA DRV call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__ \ | ||
| << " failed with " << err_str << std::endl; \ | ||
| exit(1); \ | ||
ykcombat marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } \ | ||
| } while (0) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| #include <torch/all.h> | ||
|
|
||
| #include <cstdlib> | ||
| #include <iomanip> | ||
| #include <iostream> | ||
|
|
||
| #include "cuda_utils.h" | ||
| #include "greenctx_stream.h" | ||
|
|
||
| std::vector<int64_t> create_greenctx_stream_by_percent(float smA, float smB, int device) { | ||
| CUgreenCtx gctx[3]; | ||
| CUdevResourceDesc desc[3]; | ||
| CUdevResource input; | ||
| CUdevResource resources[4]; | ||
| CUstream streamA; | ||
| CUstream streamB; | ||
|
|
||
| unsigned int nbGroups = 1; | ||
|
|
||
| if (smA + smB > 1.0) { | ||
| TORCH_CHECK(false, "Sum of SM percentages cannot exceed 1.0"); | ||
| } | ||
|
|
||
| if (smA <= 0.0 || smB <= 0.0) { | ||
| TORCH_CHECK(false, "SM percentages must be greater than 0.0"); | ||
| } | ||
|
|
||
| // Initialize device | ||
| CUDA_RT(cudaInitDevice(device, 0, 0)); | ||
|
|
||
| // Query input SMs | ||
| CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); | ||
| // We want 3/4 the device for our green context | ||
| unsigned int minCount = (unsigned int)((float)input.sm.smCount * (smA + smB)); | ||
| unsigned int minCountA = (unsigned int)((float)input.sm.smCount * smA); | ||
|
|
||
| // Split resources | ||
| CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount)); | ||
| CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); | ||
| CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||
| CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); | ||
| CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); | ||
|
|
||
| CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); | ||
| CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||
| CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); | ||
| CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||
|
|
||
| CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); | ||
| CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); | ||
|
|
||
| int smCountA = resources[0].sm.smCount; | ||
| int smCountB = resources[1].sm.smCount; | ||
|
|
||
| std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB}; | ||
| return vec; | ||
| } | ||
|
|
||
| std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { | ||
| CUgreenCtx gctx[3]; | ||
| CUdevResourceDesc desc[3]; | ||
ykcombat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| CUdevResource input; | ||
| CUdevResource resources[4]; | ||
| CUstream streamA; | ||
| CUstream streamB; | ||
|
|
||
| unsigned int nbGroups = 1; | ||
|
|
||
| if (smA <= 0 || smB <= 0) { | ||
| TORCH_CHECK(false, "SM counts must be positive"); | ||
| } | ||
|
|
||
| // Initialize device | ||
| CUDA_RT(cudaInitDevice(device, 0, 0)); | ||
|
|
||
| // Query input SMs | ||
| CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); | ||
| // We want 3/4 the device for our green context | ||
| unsigned int minCount = (unsigned int)(smA + smB); | ||
| unsigned int minCountA = (unsigned int)(smA); | ||
|
|
||
| TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); | ||
|
|
||
| // Split resources | ||
| CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount)); | ||
|
|
||
| CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); | ||
| CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||
| CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); | ||
| CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); | ||
|
|
||
| CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); | ||
| CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||
| CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); | ||
| CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||
|
|
||
| CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); | ||
| CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); | ||
|
|
||
| int smCountA = resources[0].sm.smCount; | ||
| int smCountB = resources[1].sm.smCount; | ||
|
|
||
| std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB}; | ||
| return vec; | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
#pragma once  // was missing: header had no include guard

#include <cstdint>  // int64_t — was relied on transitively
#include <vector>

// Create two green-context-backed CUDA streams whose SM budgets are given as
// fractions of the device's total SM count.
// Returns {streamA, streamB, smCountA, smCountB} (stream handles as int64_t).
std::vector<int64_t> create_greenctx_stream_by_percent(float smA, float smB, int device);

// Create two green-context-backed CUDA streams with explicit SM counts.
// Returns {streamA, streamB, smCountA, smCountB} (stream handles as int64_t).
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
from typing import Optional

import torch
from torch.cuda.streams import ExternalStream
|
|
||
|
|
||
def create_greenctx_stream_by_value(
    SM_a: int, SM_b: int, device_id: Optional[int] = None
) -> tuple[ExternalStream, ExternalStream]:
    """
    Create two CUDA streams backed by green contexts with exact SM counts.

    Args:
        SM_a (int): Number of SMs to reserve for stream A.
        SM_b (int): Number of SMs to reserve for stream B.
        device_id (Optional[int]): The device id; defaults to the current device.

    Returns:
        tuple[ExternalStream, ExternalStream]: The two streams (A, B).

    Raises:
        RuntimeError: If the kernel could not allocate exactly the requested
            SM counts (the driver rounds partitions to its SM-group
            granularity, so some counts are unsatisfiable exactly).
    """
    if device_id is None:
        device_id = torch.cuda.current_device()

    # res = [stream_a_ptr, stream_b_ptr, realized_sm_a, realized_sm_b]
    res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)

    if (res[2] != SM_a) or (res[3] != SM_b):
        raise RuntimeError(
            f"The SMs of the created streams are not equal to the input SMs, expected: {SM_a}, {SM_b}, got: {res[2]}, {res[3]}"
        )

    # Wrap the raw driver stream handles so PyTorch can schedule work on them.
    device = torch.device(f"cuda:{device_id}")
    stream_a = ExternalStream(stream_ptr=res[0], device=device)
    stream_b = ExternalStream(stream_ptr=res[1], device=device)

    return stream_a, stream_b
|
|
||
|
|
||
def get_sm_available(device_id: Optional[int] = None) -> int:
    """
    Get the number of streaming multiprocessors (SMs) on a CUDA device.

    Args:
        device_id (Optional[int]): The device id; defaults to the current device.

    Returns:
        int: The device's multiprocessor count.
    """
    if device_id is None:
        device_id = torch.cuda.current_device()

    device_props = torch.cuda.get_device_properties(device_id)

    # multi_processor_count is the total number of SMs on the device.
    return device_props.multi_processor_count
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from sgl_kernel import create_greenctx_stream_by_value, get_sm_available | ||
|
|
||
|
|
||
def test_green_ctx():
    # Reference result computed on the default stream.
    A = torch.randn(5120, 5120).cuda()
    B = torch.randn(5120, 5120).cuda()
    C = torch.matmul(A, B)
    # Split device 0's SMs evenly between two green-context streams.
    # NOTE(review): sm_counts // 2 assumes the halved count matches the
    # driver's SM-group granularity; the wrapper raises otherwise — confirm
    # on the target hardware.
    sm_counts = get_sm_available(0)
    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)
    # Run the same matmul workload on each partitioned stream.
    with torch.cuda.stream(stream_group[0]):
        for _ in range(100):
            result_0 = torch.matmul(A, B)
    with torch.cuda.stream(stream_group[1]):
        for _ in range(100):
            result_1 = torch.matmul(A, B)
    # Wait for both streams, then verify results match the reference.
    torch.cuda.synchronize()
    assert torch.allclose(result_0, C)
    assert torch.allclose(result_1, C)
|
|
||
|
|
||
# Allow running this test file directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__])
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.