Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ set(SOURCES
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/speculative/packbit.cu"
"csrc/spatial/greenctx_stream.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/common_extension.cc"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
Expand Down
6 changes: 6 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);

/*
* From csrc/spatial
*/
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
}

REGISTER_EXTENSION(common_ops)
24 changes: 24 additions & 0 deletions sgl-kernel/csrc/spatial/cuda_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdlib>   // exit
#include <iostream>  // std::cerr

// Check a CUDA runtime API call. On failure, prints the failing expression,
// file/line, and the runtime error string, then terminates the process.
// NOTE(review): exit(1) inside a torch extension kills the whole host process;
// acceptable for unrecoverable setup errors, but consider TORCH_CHECK if a
// recoverable error path is ever needed.
#define CUDA_RT(call)                                                                                           \
  do {                                                                                                          \
    cudaError_t _status = (call);                                                                               \
    if (_status != cudaSuccess) {                                                                               \
      std::cerr << "ERROR: CUDA RT call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__    \
                << " failed with " << cudaGetErrorString(_status) << std::endl;                                 \
      exit(1);                                                                                                  \
    }                                                                                                           \
  } while (0)

// Check a CUDA driver API call (CUresult). Same failure behavior as CUDA_RT,
// using cuGetErrorString to translate the driver status code.
#define CUDA_DRV(call)                                                                                          \
  do {                                                                                                          \
    CUresult _status = (call);                                                                                  \
    if (_status != CUDA_SUCCESS) {                                                                              \
      const char* err_str;                                                                                      \
      cuGetErrorString(_status, &err_str);                                                                      \
      std::cerr << "ERROR: CUDA DRV call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__   \
                << " failed with " << err_str << std::endl;                                                     \
      exit(1);                                                                                                  \
    }                                                                                                           \
  } while (0)
105 changes: 105 additions & 0 deletions sgl-kernel/csrc/spatial/greenctx_stream.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include <torch/all.h>

#include <cstdlib>
#include <iomanip>
#include <iostream>

#include "cuda_utils.h"
#include "greenctx_stream.h"

// Partitions the device's SMs into two disjoint green-context partitions sized
// as fractions (smA, smB) of the total SM count, and creates one non-blocking
// stream on each partition.
// Returns {streamA, streamB, smCountA, smCountB}, where the stream handles are
// CUstream pointers cast to int64_t for transport through the torch op layer.
// NOTE(review): the green contexts and streams are deliberately never
// destroyed here — ownership passes to the caller and they live for the
// process lifetime.
std::vector<int64_t> create_greenctx_stream_by_percent(float smA, float smB, int device) {
  // gctx[0] = context for partition A, gctx[1] = partition B,
  // gctx[2] = combined (A+B) context that A and B are carved out of.
  CUgreenCtx gctx[3];
  CUdevResourceDesc desc[3];
  CUdevResource input;
  // resources[0] = partition A, resources[1] = remainder of combined split (B),
  // resources[2] = combined A+B partition, resources[3] = unused device remainder.
  CUdevResource resources[4];
  CUstream streamA;
  CUstream streamB;

  // Request a single result group from each split call.
  // NOTE(review): the split API writes back the produced group count here;
  // the value is reused for the second split without being reset — confirm
  // the first split always yields exactly 1 group.
  unsigned int nbGroups = 1;

  if (smA + smB > 1.0) {
    TORCH_CHECK(false, "Sum of SM percentages cannot exceed 1.0");
  }

  if (smA <= 0.0 || smB <= 0.0) {
    TORCH_CHECK(false, "SM percentages must be greater than 0.0");
  }

  // Initialize device (runtime-side init; safe if already initialized)
  CUDA_RT(cudaInitDevice(device, 0, 0));

  // Query the device's total SM resource
  CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
  // Convert the requested fractions into SM counts: minCount SMs for the
  // combined A+B partition, of which minCountA go to partition A.
  // NOTE(review): truncating cast — minCountA can round to 0 for small
  // fractions; the driver may also round counts to its partition granularity.
  unsigned int minCount = (unsigned int)((float)input.sm.smCount * (smA + smB));
  unsigned int minCountA = (unsigned int)((float)input.sm.smCount * smA);

  // Split resources: first carve the combined A+B partition off the device...
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
  // ...then re-query the combined context and split it into A and remainder (B).
  CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));

  // Build a green context and a non-blocking stream on each partition.
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));

  CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
  CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));

  // Report the SM counts actually granted (may differ from the request due to
  // hardware partition granularity).
  int smCountA = resources[0].sm.smCount;
  int smCountB = resources[1].sm.smCount;

  std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB};
  return vec;
}

// Partitions the device's SMs into two disjoint green-context partitions with
// exactly smA and smB SMs requested, and creates one non-blocking stream on
// each partition. Exposed to Python as
// torch.ops.sgl_kernel.create_greenctx_stream_by_value.
// Returns {streamA, streamB, smCountA, smCountB}, where the stream handles are
// CUstream pointers cast to int64_t for transport through the torch op layer.
// NOTE(review): the green contexts and streams are deliberately never
// destroyed here — ownership passes to the caller and they live for the
// process lifetime.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  // gctx[0] = context for partition A, gctx[1] = partition B,
  // gctx[2] = combined (A+B) context that A and B are carved out of.
  CUgreenCtx gctx[3];
  CUdevResourceDesc desc[3];
  CUdevResource input;
  // resources[0] = partition A, resources[1] = remainder of combined split (B),
  // resources[2] = combined A+B partition, resources[3] = unused device remainder.
  CUdevResource resources[4];
  CUstream streamA;
  CUstream streamB;

  // Request a single result group from each split call.
  // NOTE(review): the split API writes back the produced group count here;
  // the value is reused for the second split without being reset — confirm
  // the first split always yields exactly 1 group.
  unsigned int nbGroups = 1;

  if (smA <= 0 || smB <= 0) {
    TORCH_CHECK(false, "SM counts must be positive");
  }

  // Initialize device (runtime-side init; safe if already initialized)
  CUDA_RT(cudaInitDevice(device, 0, 0));

  // Query the device's total SM resource
  CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
  // minCount SMs for the combined A+B partition, of which minCountA go to A.
  unsigned int minCount = (unsigned int)(smA + smB);
  unsigned int minCountA = (unsigned int)(smA);

  TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");

  // Split resources: first carve the combined A+B partition off the device...
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));

  CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
  // ...then re-query the combined context and split it into A and remainder (B).
  CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));

  // Build a green context and a non-blocking stream on each partition.
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));

  CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
  CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));

  // Report the SM counts actually granted (may differ from the request due to
  // hardware partition granularity; the Python wrapper validates this).
  int smCountA = resources[0].sm.smCount;
  int smCountB = resources[1].sm.smCount;

  std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB};
  return vec;
}
4 changes: 4 additions & 0 deletions sgl-kernel/csrc/spatial/greenctx_stream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once

#include <cstdint>  // int64_t
#include <vector>

// Each function returns {streamA, streamB, smCountA, smCountB}: two CUstream
// handles cast to int64_t plus the SM counts actually granted to each
// green-context partition. Implemented in csrc/spatial/greenctx_stream.cu.
std::vector<int64_t> create_greenctx_stream_by_percent(float smA, float smB, int device);
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
5 changes: 5 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,8 @@ void qserve_w4a8_per_group_gemm(
const torch::Tensor& _wscales,
const torch::Tensor& _ascales,
torch::Tensor& _out_feats);

/*
* From csrc/spatial
*/
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
1 change: 1 addition & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
segment_packbits,
Expand Down
53 changes: 53 additions & 0 deletions sgl-kernel/python/sgl_kernel/spatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Optional

import torch
from torch.cuda.streams import ExternalStream


def create_greenctx_stream_by_value(
    SM_a: int, SM_b: int, device_id: Optional[int] = None
) -> tuple[ExternalStream, ExternalStream]:
    """
    Create two CUDA streams backed by green contexts with dedicated SM partitions.

    Args:
        SM_a (int): Number of SMs to reserve for stream A.
        SM_b (int): Number of SMs to reserve for stream B.
        device_id (int, optional): CUDA device id. Defaults to the current device.

    Returns:
        tuple[ExternalStream, ExternalStream]: The streams for partitions A and B.

    Raises:
        RuntimeError: If the kernel could not grant exactly the requested SM
            counts (the hardware may round partition sizes).
    """
    if device_id is None:
        device_id = torch.cuda.current_device()

    # Returns [stream_a_ptr, stream_b_ptr, granted_sm_a, granted_sm_b].
    res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)

    if (res[2] != SM_a) or (res[3] != SM_b):
        raise RuntimeError(
            f"The SMs of the created streams are not equal to the input SMs, expected: {SM_a}, {SM_b}, got: {res[2]}, {res[3]}"
        )

    # Wrap the raw CUstream pointers so torch can schedule work on them.
    device = torch.device(f"cuda:{device_id}")
    stream_a = ExternalStream(stream_ptr=res[0], device=device)
    stream_b = ExternalStream(stream_ptr=res[1], device=device)

    return stream_a, stream_b


def get_sm_available(device_id: Optional[int] = None) -> int:
    """
    Get the number of streaming multiprocessors (SMs) available on a device.

    Args:
        device_id (int, optional): The device id. Defaults to the current device.

    Returns:
        int: The number of SMs on the device.
    """
    if device_id is None:
        device_id = torch.cuda.current_device()

    # multi_processor_count is the total number of SMs reported by the driver.
    return torch.cuda.get_device_properties(device_id).multi_processor_count
25 changes: 25 additions & 0 deletions sgl-kernel/tests/spatial/test_greenctx_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available


def test_green_ctx():
    """Matmuls issued on two green-context streams must match the default-stream result."""
    lhs = torch.randn(5120, 5120).cuda()
    rhs = torch.randn(5120, 5120).cuda()
    expected = torch.matmul(lhs, rhs)

    # Split the device's SMs evenly between the two streams.
    sm_total = get_sm_available(0)
    stream_a, stream_b = create_greenctx_stream_by_value(sm_total // 2, sm_total // 2, 0)

    with torch.cuda.stream(stream_a):
        for _ in range(100):
            out_a = torch.matmul(lhs, rhs)
    with torch.cuda.stream(stream_b):
        for _ in range(100):
            out_b = torch.matmul(lhs, rhs)
    torch.cuda.synchronize()

    assert torch.allclose(out_a, expected)
    assert torch.allclose(out_b, expected)


if __name__ == "__main__":
pytest.main([__file__])