
Commit b13877a

bnellnm authored and sumitd2 committed
[Bugfix] Fix support for dimension like integers and ScalarType (vllm-project#9299)
Signed-off-by: Sumit Dubey <[email protected]>
1 parent e45c956 commit b13877a

22 files changed: +427 -677 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 6 additions & 8 deletions
@@ -230,14 +230,12 @@ steps:
   commands:
   - pytest -v -s compile/test_basic_correctness.py
 
-# TODO: re-write in comparison tests, and fix symbolic shape
-# for quantization ops.
-# - label: "PyTorch Fullgraph Test" # 18min
-#   source_file_dependencies:
-#   - vllm/
-#   - tests/compile
-#   commands:
-#   - pytest -v -s compile/test_full_graph.py
+- label: "PyTorch Fullgraph Test" # 18min
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_full_graph.py
 
 - label: Kernels Test %N # 1h each
   mirror_hardwares: [amd]

CMakeLists.txt

Lines changed: 0 additions & 18 deletions
@@ -83,24 +83,6 @@ endif()
 #
 find_package(Torch REQUIRED)
 
-#
-message(STATUS "Enabling core extension.")
-
-# Define _core_C extension
-# built for (almost) every target platform, (excludes TPU and Neuron)
-
-set(VLLM_EXT_SRC
-  "csrc/core/torch_bindings.cpp")
-
-define_gpu_extension_target(
-  _core_C
-  DESTINATION vllm
-  LANGUAGE CXX
-  SOURCES ${VLLM_EXT_SRC}
-  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
-  USE_SABI 3
-  WITH_SOABI)
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #

csrc/core/scalar_type.hpp

Lines changed: 4 additions & 205 deletions
@@ -1,6 +1,7 @@
 #pragma once
 
-#include <torch/custom_class.h>
+// For TORCH_CHECK
+#include <torch/library.h>
 
 namespace vllm {
 
@@ -9,12 +10,7 @@ namespace vllm {
 // in particular it can be used to represent sub-byte data types (something
 // that torch.dtype currently does not support).
 //
-// ScalarTypeTorch is a subclass of ScalarType that is compatible with
-// TORCH_LIBRARY, making it accessible from Python as well meaning this class
-// can be used as a argument for custom operators, helping to simplify these
-// interfaces.
-//
-// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
+// The type definitions on the Python side can be found in: vllm/scalar_type.py
 // these type definitions should be kept up to date with any Python API changes
 // here.
 //
@@ -308,204 +304,7 @@ class ScalarType {
   }
 };
 
-// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
-// torch::CustomClassHolder), we use multiple inheritance here since we cannot
-// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
-// constructor at the same time (torch::CustomClassHolder does not have a
-// constexpr destructor)
-// See also:
-// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
-class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
- public:
-  ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
-                  bool _signed)
-      : ScalarType(exponent, mantissa, bias, _signed){};
-
-  ScalarTypeTorch(ScalarType type) : ScalarType(type){};
-
-  using Base = ScalarType;
-  using Self = ScalarTypeTorch;
-  using SelfPtr = c10::intrusive_ptr<Self>;
-
-  static void check_size_bits(int64_t size_bits, bool signed_) {
-    TORCH_CHECK(
-        size_bits <=
-            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
-        "size_bits bit width is too large to be represented");
-  }
-
-  static void check_bias(int64_t bias) {
-    using Bias = decltype(std::declval<Self>().bias);
-    TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
-                    bias >= std::numeric_limits<Bias>::min(),
-                "bias too large or small to be represented");
-  }
-
-  static void check_exponent(int64_t exponent) {
-    TORCH_CHECK(
-        exponent <=
-            std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
-        "exponent bit width is too large to be represented");
-  }
-
-  static void check_mantissa(int64_t mantissa) {
-    TORCH_CHECK(
-        mantissa <=
-            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
-        "mantissa bit width is too large to be represented");
-  }
-
-  static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
-    check_size_bits(size_bits, true);
-    check_bias(bias.value_or(0));
-    return c10::make_intrusive<Self>(
-        ScalarType::int_(size_bits, bias.value_or(0)));
-  }
-
-  static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
-    check_size_bits(size_bits, true);
-    check_bias(bias.value_or(0));
-    return c10::make_intrusive<Self>(
-        ScalarType::uint(size_bits, bias.value_or(0)));
-  }
-
-  static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
-    check_mantissa(mantissa);
-    check_exponent(exponent);
-    return c10::make_intrusive<Self>(
-        ScalarType::float_IEEE754(exponent, mantissa));
-  }
-
-  static SelfPtr float_(int64_t exponent, int64_t mantissa,
-                        bool finite_values_only, int64_t nan_repr) {
-    check_mantissa(mantissa);
-    check_exponent(exponent);
-    return c10::make_intrusive<Self>(ScalarType::float_(
-        exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
-  }
-
-  // This needs to be implemented and throw a TypeError in order for
-  // PyTorch's opcheck to work on ops that use ScalarTypes.
-  int64_t len() const {
-    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
-                         "__len__ not implemented");
-    return 0;
-  }
-
-  // Serialize a ScalarType into a tuple of pairs. Where each pair
-  // is a (fieldname, value).
-  // For simplicity, we are just going to convert to a ScalarTypeId.
-  std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
-    return {{"ScalarType", id()}};
-  }
-
-  // Deserialize a scalar type that has been serialized by obj_flatten,
-  // ostensibly from a tuple of (member name, value) pairs, but in reality
-  // just a ScalarTypeId.
-  static SelfPtr obj_unflatten(
-      std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
-    return c10::make_intrusive<Self>(
-        from_id(std::get<1>(std::get<0>(flat_type))));
-  }
-
-  template <typename T>
-  static void bind_readonly_property(torch::class_<Self>& cls,
-                                     std::string const& name, T Base::*field) {
-    auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
-      if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
-        return (self.get()->*field)();
-      } else {
-        return self.get()->*field;
-      }
-    };
-
-    auto getter_func = [field = std::move(field),
-                        getter_func_helper = std::move(getter_func_helper)](
-                           SelfPtr const& self) {
-      auto val = getter_func_helper(self);
-      // upconvert uint8_t, int32_t etc. to int64_t for python
-      if constexpr (std::is_integral_v<T>) {
-        return static_cast<int64_t>(val);
-      } else {
-        return val;
-      }
-    };
-
-    cls.def_property(name, getter_func);
-  }
-
-  template <typename MemberFunc, typename Cls>
-  static void bind_function(torch::class_<Self>& cls, const std::string& name,
-                            MemberFunc Cls::*member) {
-    cls.def(name, [member = std::move(member)](SelfPtr const& self) {
-      return (self.get()->*member)();
-    });
-  }
-
-  template <typename Func>
-  static void bind_function(torch::class_<Self>& cls, const std::string& name,
-                            Func func) {
-    cls.def(name, func);
-  }
-
-  template <typename Func>
-  static void bind_static_function(torch::class_<Self>& cls,
-                                   const std::string& name, Func func) {
-    cls.def_static(name, func);
-  }
-
-  static void bind_class(torch::Library& lib) {
-    auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
-                   .def(torch::init<int64_t, int64_t, int64_t, bool>());
-
-    // Bind Properties
-    bind_readonly_property(cls, "mantissa", &Base::mantissa);
-    bind_readonly_property(cls, "exponent", &Base::exponent);
-    bind_readonly_property(cls, "bias", &Base::bias);
-    bind_readonly_property(cls, "signed", &Base::is_signed);
-    bind_readonly_property(cls, "size_bits", &Base::size_bits);
-
-    // Bind member functions
-    bind_function(cls, "is_signed", &Base::is_signed);
-    bind_function(cls, "is_integer", &Base::is_integer);
-    bind_function(cls, "is_floating_point", &Base::is_floating_point);
-    bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
-    bind_function(cls, "has_nans", &Base::has_nans);
-    bind_function(cls, "has_infs", &Base::has_infs);
-    bind_function(cls, "has_bias", &Base::has_bias);
-
-    bind_function(cls, "max", [](SelfPtr const& self) {
-      return std::visit([](auto arg) { return c10::IValue(arg); },
-                        self.get()->max());
-    });
-    bind_function(cls, "min", [](SelfPtr const& self) {
-      return std::visit([](auto arg) { return c10::IValue(arg); },
-                        self.get()->min());
-    });
-
-    bind_function(cls, "__len__", &ScalarTypeTorch::len);
-    bind_function(cls, "__str__", &Base::str);
-    bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
-      return *self == *other;
-    });
-    bind_function(cls, "__repr__", [](SelfPtr const& self) {
-      return "ScalarType." + self.get()->str();
-    });
-
-    bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
-    bind_static_function(cls, "__obj_unflatten__",
-                         &ScalarTypeTorch::obj_unflatten);
-
-    // Bind static functions (convenience constructors)
-    bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
-    bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
-    bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
-    bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
-  }
-};
-
-using ScalarTypeId = int64_t;
-using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
+using ScalarTypeId = ScalarType::Id;
 
 // "rust style" names generally following:
 // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
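
Note: the replacement for the deleted bindings relies on ScalarType::id() and ScalarType::from_id() being lossless inverses, so a plain integer can carry the full type across an op boundary. A minimal sketch of that invariant, assuming the header above is on the include path (the function name below is illustrative, not part of this commit):

#include "core/scalar_type.hpp"

// Check the id round-trip the new op signatures depend on: serialize a
// ScalarType to its ScalarTypeId, rebuild it with from_id, compare by value.
void scalar_type_roundtrip_check() {
  vllm::ScalarTypeId const id = vllm::kU4B8.id();
  vllm::ScalarType const t = vllm::ScalarType::from_id(id);
  TORCH_CHECK(t == vllm::kU4B8, "id round-trip mismatch: ", t.str());
}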

csrc/core/torch_bindings.cpp

Lines changed: 0 additions & 16 deletions
This file was deleted.

csrc/moe/marlin_moe_ops.cu

Lines changed: 8 additions & 7 deletions
@@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe(
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
-    vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
+    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
     int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
     int64_t moe_block_size, bool replicate_input, bool apply_weights) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   bool has_zp = b_zeros.size(1) != 0;
   if (has_zp) {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4,
-        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
+        b_q_type == vllm::kU4,
+        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
   } else {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
-        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
+        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
+        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
   }
 
-  int pack_factor = 32 / b_q_type->size_bits();
+  int pack_factor = 32 / b_q_type.size_bits();
 
   int max_par = 4;

@@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe(
       topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
       b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
       expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
-      *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
+      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
       num_experts, topk, moe_block_size, dev,
       at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
       replicate_input, apply_weights);
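
The same pattern generalizes to any kernel that used to take a vllm::ScalarTypeTorchPtr: accept a vllm::ScalarTypeId, rebuild the ScalarType once at the top of the function, then validate and dispatch by value. A hedged sketch with a hypothetical kernel (my_gemm and its argument list are illustrative, not vLLM API):

#include <torch/all.h>

#include "core/scalar_type.hpp"

// Hypothetical op entry point following the convention adopted above.
torch::Tensor my_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight,
                      vllm::ScalarTypeId const b_q_type_id) {
  // Rebuild the full ScalarType from its integer id once, up front.
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
              "unsupported quant type: ", b_q_type.str());
  int const pack_factor = 32 / b_q_type.size_bits();
  (void)pack_factor;  // a real kernel would pack/dispatch based on this
  return a;           // placeholder: no actual GEMM performed here
}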

csrc/moe/torch_bindings.cpp

Lines changed: 3 additions & 2 deletions
@@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
       "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
-      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
-      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
+      "int b_q_type, SymInt size_m, "
+      "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
+      "topk, "
       "int moe_block_size, bool replicate_input, bool apply_weights)"
       " -> Tensor");
   // conditionally compiled so impl registration is in source file
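
Two schema details are worth calling out: dimension-like arguments become SymInt so torch.compile can trace them symbolically instead of burning in concrete sizes, and the scalar type travels as a plain int id instead of the removed _core_C custom class. A minimal registration sketch under an assumed library name _sketch_C (not a real vLLM extension):

#include <torch/library.h>

TORCH_LIBRARY(_sketch_C, m) {
  // SymInt dims stay symbolic under torch.compile; on the C++ impl side a
  // SymInt argument still arrives as a concrete int64_t.
  m.def(
      "my_gemm(Tensor a, Tensor b_q_weight, int b_q_type, "
      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
  // As in this file, the kernel impl is registered where it is compiled.
}

This is why marlin_gemm_moe above keeps int64_t parameters: a SymInt in the schema is materialized as a concrete integer by the time the CUDA implementation runs.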
