Merged
58 commits
4432676
Add meta functions for ops to prevent graph breaks
bnellnm Jul 29, 2024
37adb20
format
bnellnm Jul 29, 2024
f0a93b7
add torch.compile to loader + symint support for gptq_gemm_meta + twe…
bnellnm Jul 29, 2024
3113671
pull out punica support test, move torch.compile to runner to avoid w…
bnellnm Aug 1, 2024
6462f77
tweaks
bnellnm Aug 1, 2024
1a1b0d9
change codebook_partition_sizes to List[int]
bnellnm Aug 2, 2024
6d703b3
use string schemas for all functions
bnellnm Aug 5, 2024
e6239f5
back out lora test hacks
bnellnm Aug 5, 2024
7d5d029
cleanups
bnellnm Aug 5, 2024
c86d222
fix flash_attn
bnellnm Aug 5, 2024
d85a2db
fix marlin schemas and meta funcs
bnellnm Aug 5, 2024
147e783
fix format
bnellnm Aug 6, 2024
518c6b5
add some opcheck tests
bnellnm Aug 6, 2024
128b617
fix registrations for non-Tensor ops
bnellnm Aug 6, 2024
6843f97
rebase + fix gguf registrations
bnellnm Aug 6, 2024
290d703
update PR template with info on pytorch registration
bnellnm Aug 6, 2024
e1e2b2a
try registering meta-function via python to handle symbolic shapes
bnellnm Aug 6, 2024
eb9753a
format
bnellnm Aug 6, 2024
82a1b0c
conditionally register gptq_marlin_24_gemm_fake
bnellnm Aug 6, 2024
4fb01e9
format stuff
bnellnm Aug 6, 2024
d4e2d82
try python meta functions
bnellnm Aug 7, 2024
ac31d31
temporarily add opchecks to almost all custom ops
bnellnm Aug 7, 2024
cc37741
comment out opchecks
bnellnm Aug 7, 2024
708a725
remove temporary opchecks in _custom_ops
bnellnm Aug 7, 2024
dc9fbd6
tweak copy_blocks schema
bnellnm Aug 8, 2024
9aae824
remove most C++ meta functions
bnellnm Aug 8, 2024
99ac24d
activation opcheck tests
bnellnm Aug 9, 2024
3149c45
add more opcheck tests
bnellnm Aug 9, 2024
2435313
add more opcheck tests
bnellnm Aug 9, 2024
2478a3c
run opchecks on fewer combinations to reduce memory use
bnellnm Aug 9, 2024
355250e
use @youkaichao's flash_attn registration
bnellnm Aug 9, 2024
8833436
fix format
bnellnm Aug 9, 2024
de5775d
fix cutlass test
bnellnm Aug 9, 2024
78aa723
add custom op for tensor_model_parallel_all_reduce
SageMoore Aug 9, 2024
d721798
format
SageMoore Aug 9, 2024
4cbdaa0
register lora triton ops to avoid dynamo problems
bnellnm Aug 9, 2024
64e7007
fix cpu support in tensor_model_parallel_all_reduce
SageMoore Aug 9, 2024
69fa794
format
SageMoore Aug 9, 2024
d9833c5
cleanups
bnellnm Aug 5, 2024
ac65b5f
fix flash_attn signatures
bnellnm Aug 12, 2024
c05db30
rebase + cleanups
bnellnm Aug 12, 2024
bbc173c
tweaks + add gc.collect() to fix memory profiling errors when dynamo …
bnellnm Aug 13, 2024
4a62992
fix broken env var
bnellnm Aug 13, 2024
167652d
add clones to all_reduce
SageMoore Aug 14, 2024
1cb8184
fix format
bnellnm Aug 13, 2024
771daa4
fix aqlm custom op type annotations
bnellnm Aug 14, 2024
bff7d64
fix gptq custom op registration
bnellnm Aug 14, 2024
d805265
add dynamo support for ScalarType
bnellnm Aug 14, 2024
c1184fd
add some pointers to PT2 custom class docs
bnellnm Aug 14, 2024
a5a8489
tweaks
bnellnm Aug 16, 2024
49953f2
fix merge
bnellnm Aug 16, 2024
950be6a
fix cpu schemas
bnellnm Aug 16, 2024
ec0f252
fix merge
bnellnm Aug 17, 2024
4699534
rebase + add meta functions for machete kernels
bnellnm Aug 20, 2024
fd6b7c9
tweak tests + custom ar bindings
bnellnm Aug 26, 2024
6cecf82
rebase + fix schema for selective_scan_fwd, add meta functions for ne…
bnellnm Aug 29, 2024
3791beb
modify copy_blocks opcheck test
bnellnm Aug 29, 2024
12c845a
remove some custom ar changes
bnellnm Sep 5, 2024
10 changes: 10 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.</li>
</ul>

<h3>Adding or changing kernels</h3>
<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
<ul>
<li>Make sure custom ops are registered following PyTorch guidelines: <a href="https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial">Custom C++ and CUDA Operators</a> and <a href="https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU">The Custom Operators Manual</a></li>
<li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in Python so that dynamic dimensions can be handled automatically. See the documents above for a description of meta-functions; a minimal sketch follows right after this list.</li>
<li>Use <a href="https://pytorch.org/docs/stable/library.html#torch.library.opcheck"><code>torch.library.opcheck()</code></a> to test the function registration and meta-function for any registered ops. See <code>tests/kernels</code> for examples.</li>
<li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
<li>If a new custom type is needed, see the following document: <a href="https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA">Custom Class Support in PT2</a>.</li>
</ul>
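As an illustration of the flow these guidelines describe — define a schema, provide an implementation, register a Python meta/fake function, and validate with <code>opcheck()</code> — here is a minimal, self-contained sketch. The <code>my_ext::toy_repack</code> op, its schema, and its shapes are hypothetical placeholders, not vLLM kernels; the real registrations live in <code>csrc/torch_bindings.cpp</code> and the Python op layer.

```python
# Minimal sketch of the guidelines above, using a hypothetical op:
# 1) define a schema, 2) register a backend impl, 3) register a Python
# meta/fake function so dynamic shapes work, 4) validate with opcheck().
import torch
from torch.library import Library, opcheck

my_lib = Library("my_ext", "DEF")  # hypothetical extension library
my_lib.define("toy_repack(Tensor w, int pack_factor) -> Tensor")

def toy_repack_cpu(w: torch.Tensor, pack_factor: int) -> torch.Tensor:
    # Stand-in "kernel": fold pack_factor columns into extra rows.
    return w.reshape(w.shape[0] * pack_factor,
                     w.shape[1] // pack_factor).contiguous()

my_lib.impl("toy_repack", toy_repack_cpu, "CPU")

@torch.library.register_fake("my_ext::toy_repack")
def _toy_repack_fake(w, pack_factor):
    # Meta/fake function: only the output shape/dtype/device, no computation,
    # so the dimensions can stay symbolic under torch.compile.
    return w.new_empty(w.shape[0] * pack_factor, w.shape[1] // pack_factor)

if __name__ == "__main__":
    # Exercises schema, fake-tensor, and dynamic-shape dispatch tests.
    opcheck(torch.ops.my_ext.toy_repack.default, (torch.randn(4, 8), 2))
```

The opcheck tests under <code>tests/kernels</code> apply this same call pattern to the real registered ops.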

<h3>Notes for Large Changes</h3>
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not review the PR.</p>

1 change: 1 addition & 0 deletions cmake/utils.cmake
@@ -350,6 +350,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES})

# TODO: is torch_python_LIBRARY needed?
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
${GPU_LIBRARIES})

8 changes: 4 additions & 4 deletions csrc/cpu/torch_bindings.cpp
@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
@@ -95,8 +95,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

// Reshape the key and value tensors and cache them.
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -123,9 +123,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);

torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);

torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);

12 changes: 12 additions & 0 deletions csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
}

#endif

torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
12 changes: 12 additions & 0 deletions csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
}

#endif

torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
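The two repack meta functions above compute only the output shape from the SymInt arguments and allocate an empty tensor with the input's dtype and device. Roughly the same shape logic expressed as a Python fake function (the style that later commits in this PR move toward for most ops) might look like the sketch below; the tile size of 16 and registering against the <code>_C</code> namespace are assumptions for illustration, not how this diff registers it.

```python
# Hedged sketch: the repack meta/shape logic written as a Python fake
# function, registered against an already-defined op (assumption: the op
# is exposed as torch.ops._C.awq_marlin_repack and has no C++ Meta impl).
import torch

MARLIN_TILE_SIZE = 16  # assumed value, mirroring marlin::tile_size

@torch.library.register_fake("_C::awq_marlin_repack")
def _awq_marlin_repack_fake(b_q_weight, size_k, size_n, num_bits):
    # Only describes the output; size_k/size_n may be SymInts when traced.
    pack_factor = 32 // num_bits
    return torch.empty(
        (size_k // MARLIN_TILE_SIZE,
         size_n * MARLIN_TILE_SIZE // pack_factor),
        dtype=b_q_weight.dtype,
        device=b_q_weight.device,
    )
```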
144 changes: 100 additions & 44 deletions csrc/torch_bindings.cpp
@@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
@@ -73,7 +73,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

// prepare_inputs advance_step
ops.def("advance_step", &advance_step);
ops.def(
"advance_step(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()");
ops.impl("advance_step", torch::kCUDA, &advance_step);

// Layernorm
@@ -110,27 +114,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
ops.def("aqlm_gemm", &aqlm_gemm);
ops.def(
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
"-> Tensor");
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

// Decompression method for AQLM.
ops.def("aqlm_dequant", &aqlm_dequant);
ops.def(
"aqlm_dequant(Tensor codes, Tensor codebooks, "
"int[] codebook_partition_sizes) -> Tensor");
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

// Quantized GEMM for AWQ.
ops.def("awq_gemm", &awq_gemm);
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters) -> Tensor");
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

// Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize);
ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

// Note about marlin kernel 'workspace' arguments:
// Technically these should be mutable since they are modified by the kernel.
// But since they are reset to zero once the kernel finishes, we can hand-wave
// and say that they have no net effect.
//
// The reason to mark 'workspace' as immutable is so that these arguments
// don't interfere with using ScalarType arguments in the ops. If they are
// marked as mutable, PyTorch throws an assert in
// 'torch._higher_order_ops._register_effectful_op' that prevents these
// kernels from being torch.compile'd.
// See the following document for more info on custom types and ops that use
// custom types:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA

// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops.def("marlin_gemm", &marlin_gemm);
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);

// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.def(
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor");
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);

// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -149,35 +182,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);

// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

// gptq_marlin repack from GPTQ.
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);

// awq_marlin repack from AWQ.
ops.def("awq_marlin_repack", &awq_marlin_repack);
ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);

// Dequantization for GGML.
ops.def("ggml_dequantize", &ggml_dequantize);
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

// mmvq kernel for GGML.
ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
ops.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
"-> Tensor");
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

// mmq kernel for GGML.
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);

// marlin_qqq_gemm for QQQ.
ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);

// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -199,16 +252,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
&cutlass_scaled_mm_supports_fp8);
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor? x) -> Tensor[]");
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
@@ -230,7 +283,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif

// Quantized GEMM for GPTQ.
ops.def("gptq_gemm", &gptq_gemm);
  // Note: even though the C++ inferred schema is correct for this op, relying
  // on it seems to prevent the meta function registration.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
"-> Tensor");
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

// Post processing for GPTQ.
@@ -250,8 +308,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
"scale, Tensor? scale_ub) -> "
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
@@ -288,8 +346,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

// Reshape the key and value tensors and cache them.
@@ -314,33 +372,37 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
"kv_cache_dtype) -> ()");
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils

// Gets the specified device attribute.
cuda_utils.def("get_device_attribute", &get_device_attribute);
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.impl("get_device_attribute", &get_device_attribute);

// Gets the maximum shared memory per block device attribute.
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
torch::kCUDA,
&get_max_shared_memory_per_block_device_attribute);
}

#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def("init_custom_ar", &init_custom_ar);
custom_ar.def(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

custom_ar.def("should_custom_ar", &should_custom_ar);
custom_ar.def(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
Expand All @@ -352,21 +414,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);

custom_ar.def("dispose", &dispose);
custom_ar.impl("dispose", torch::kCPU, &dispose);

custom_ar.def("meta_size", &meta_size);
custom_ar.impl("meta_size", torch::kCPU, &meta_size);

custom_ar.def("register_buffer", &register_buffer);
custom_ar.def(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()");
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);

custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
&get_graph_buffer_ipc_meta);

custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.impl("register_graph_buffers", torch::kCPU,
&register_graph_buffers);
}
#endif
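Since the point of these schemas and meta functions is to keep the ops traceable (see the note on the marlin 'workspace' arguments above), a quick smoke test is to compile a small wrapper with fullgraph=True, which raises on any graph break instead of silently falling back. The sketch below reuses the hypothetical toy_repack op from the earlier example; calling a real torch.ops._C op the same way would require the vLLM extension to be loaded and suitably shaped CUDA tensors.

```python
# Hedged sketch: smoke-test that a registered custom op traces cleanly.
# fullgraph=True makes torch.compile raise instead of splitting the graph;
# dynamic=True exercises the symbolic-shape path of the meta/fake function.
import torch

@torch.compile(fullgraph=True, dynamic=True)
def run(w: torch.Tensor) -> torch.Tensor:
    # Requires the my_ext::toy_repack registration from the earlier sketch
    # to have run in this process.
    return torch.ops.my_ext.toy_repack(w, 2)

print(run(torch.randn(4, 8)).shape)   # torch.Size([8, 4])
print(run(torch.randn(6, 10)).shape)  # torch.Size([12, 5])
```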
