11#pragma once
22
3- #include < torch/custom_class.h>
3+ // For TORCH_CHECK
4+ #include < torch/library.h>
45
56namespace vllm {
67
@@ -9,12 +10,7 @@ namespace vllm {
910// in particular it can be used to represent sub-byte data types (something
1011// that torch.dtype currently does not support).
1112//
12- // ScalarTypeTorch is a subclass of ScalarType that is compatible with
13- // TORCH_LIBRARY, making it accessible from Python as well meaning this class
14- // can be used as a argument for custom operators, helping to simplify these
15- // interfaces.
16- //
17- // The type definitions on the Python side can be found in: vllm/_core_ext.pyi
13+ // The type definitions on the Python side can be found in: vllm/scalar_type.py
1814// these type definitions should be kept up to date with any Python API changes
1915// here.
2016//
@@ -308,204 +304,7 @@ class ScalarType {
308304 }
309305};
310306
311- // Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
312- // torch::CustomClassHolder), we use multiple inheritance here since we cannot
313- // have ScalarType inherit from torch::CustomClassHolder and have a constexpr
314- // constructor at the same time (torch::CustomClassHolder does not have a
315- // constexpr destructor)
316- // See also:
317- // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
318- class ScalarTypeTorch : public torch ::CustomClassHolder, public ScalarType {
319- public:
320- ScalarTypeTorch (int64_t exponent, int64_t mantissa, int64_t bias,
321- bool _signed)
322- : ScalarType(exponent, mantissa, bias, _signed){};
323-
324- ScalarTypeTorch (ScalarType type) : ScalarType(type){};
325-
326- using Base = ScalarType;
327- using Self = ScalarTypeTorch;
328- using SelfPtr = c10::intrusive_ptr<Self>;
329-
330- static void check_size_bits (int64_t size_bits, bool signed_) {
331- TORCH_CHECK (
332- size_bits <=
333- std::numeric_limits<decltype (std::declval<Self>().mantissa )>::max (),
334- " size_bits bit width is too large to be represented" );
335- }
336-
337- static void check_bias (int64_t bias) {
338- using Bias = decltype (std::declval<Self>().bias );
339- TORCH_CHECK (bias <= std::numeric_limits<Bias>::max () &&
340- bias >= std::numeric_limits<Bias>::min (),
341- " bias too large or small to be represented" );
342- }
343-
344- static void check_exponent (int64_t exponent) {
345- TORCH_CHECK (
346- exponent <=
347- std::numeric_limits<decltype (std::declval<Self>().exponent )>::max (),
348- " exponent bit width is too large to be represented" );
349- }
350-
351- static void check_mantissa (int64_t mantissa) {
352- TORCH_CHECK (
353- mantissa <=
354- std::numeric_limits<decltype (std::declval<Self>().mantissa )>::max (),
355- " mantissa bit width is too large to be represented" );
356- }
357-
358- static SelfPtr int_ (int64_t size_bits, c10::optional<int64_t > bias) {
359- check_size_bits (size_bits, true );
360- check_bias (bias.value_or (0 ));
361- return c10::make_intrusive<Self>(
362- ScalarType::int_ (size_bits, bias.value_or (0 )));
363- }
364-
365- static SelfPtr uint (int64_t size_bits, c10::optional<int64_t > bias) {
366- check_size_bits (size_bits, true );
367- check_bias (bias.value_or (0 ));
368- return c10::make_intrusive<Self>(
369- ScalarType::uint (size_bits, bias.value_or (0 )));
370- }
371-
372- static SelfPtr float_IEEE754 (int64_t exponent, int64_t mantissa) {
373- check_mantissa (mantissa);
374- check_exponent (exponent);
375- return c10::make_intrusive<Self>(
376- ScalarType::float_IEEE754 (exponent, mantissa));
377- }
378-
379- static SelfPtr float_ (int64_t exponent, int64_t mantissa,
380- bool finite_values_only, int64_t nan_repr) {
381- check_mantissa (mantissa);
382- check_exponent (exponent);
383- return c10::make_intrusive<Self>(ScalarType::float_ (
384- exponent, mantissa, finite_values_only, NanRepr (nan_repr)));
385- }
386-
387- // This needs to be implemented and throw a TypeError in order for
388- // PyTorch's opcheck to work on ops that use ScalarTypes.
389- int64_t len () const {
390- throw c10::TypeError ({__func__, __FILE__, static_cast <uint32_t >(__LINE__)},
391- " __len__ not implemented" );
392- return 0 ;
393- }
394-
395- // Serialize a ScalarType into a tuple of pairs. Where each pair
396- // is a (fieldname, value).
397- // For simplicity, we are just going to convert to a ScalarTypeId.
398- std::tuple<std::tuple<std::string, int64_t >> obj_flatten () const {
399- return {{" ScalarType" , id ()}};
400- }
401-
402- // Deserialize a scalar type that has been serialized by obj_flatten,
403- // ostensibly from a tuple of (member name, value) pairs, but in reality
404- // just a ScalarTypeId.
405- static SelfPtr obj_unflatten (
406- std::tuple<std::tuple<std::string, int64_t >> const & flat_type) {
407- return c10::make_intrusive<Self>(
408- from_id (std::get<1 >(std::get<0 >(flat_type))));
409- }
410-
411- template <typename T>
412- static void bind_readonly_property (torch::class_<Self>& cls,
413- std::string const & name, T Base::*field) {
414- auto getter_func_helper = [field = std::move (field)](SelfPtr const & self) {
415- if constexpr (std::is_member_function_pointer_v<decltype (field)>) {
416- return (self.get ()->*field)();
417- } else {
418- return self.get ()->*field;
419- }
420- };
421-
422- auto getter_func = [field = std::move (field),
423- getter_func_helper = std::move (getter_func_helper)](
424- SelfPtr const & self) {
425- auto val = getter_func_helper (self);
426- // upconvert uint8_t, int32_t etc. to int64_t for python
427- if constexpr (std::is_integral_v<T>) {
428- return static_cast <int64_t >(val);
429- } else {
430- return val;
431- }
432- };
433-
434- cls.def_property (name, getter_func);
435- }
436-
437- template <typename MemberFunc, typename Cls>
438- static void bind_function (torch::class_<Self>& cls, const std::string& name,
439- MemberFunc Cls::*member) {
440- cls.def (name, [member = std::move (member)](SelfPtr const & self) {
441- return (self.get ()->*member)();
442- });
443- }
444-
445- template <typename Func>
446- static void bind_function (torch::class_<Self>& cls, const std::string& name,
447- Func func) {
448- cls.def (name, func);
449- }
450-
451- template <typename Func>
452- static void bind_static_function (torch::class_<Self>& cls,
453- const std::string& name, Func func) {
454- cls.def_static (name, func);
455- }
456-
457- static void bind_class (torch::Library& lib) {
458- auto cls = lib.class_ <ScalarTypeTorch>(" ScalarType" )
459- .def (torch::init<int64_t , int64_t , int64_t , bool >());
460-
461- // Bind Properties
462- bind_readonly_property (cls, " mantissa" , &Base::mantissa);
463- bind_readonly_property (cls, " exponent" , &Base::exponent);
464- bind_readonly_property (cls, " bias" , &Base::bias);
465- bind_readonly_property (cls, " signed" , &Base::is_signed);
466- bind_readonly_property (cls, " size_bits" , &Base::size_bits);
467-
468- // Bind member functions
469- bind_function (cls, " is_signed" , &Base::is_signed);
470- bind_function (cls, " is_integer" , &Base::is_integer);
471- bind_function (cls, " is_floating_point" , &Base::is_floating_point);
472- bind_function (cls, " is_ieee_754" , &Base::is_ieee_754);
473- bind_function (cls, " has_nans" , &Base::has_nans);
474- bind_function (cls, " has_infs" , &Base::has_infs);
475- bind_function (cls, " has_bias" , &Base::has_bias);
476-
477- bind_function (cls, " max" , [](SelfPtr const & self) {
478- return std::visit ([](auto arg) { return c10::IValue (arg); },
479- self.get ()->max ());
480- });
481- bind_function (cls, " min" , [](SelfPtr const & self) {
482- return std::visit ([](auto arg) { return c10::IValue (arg); },
483- self.get ()->min ());
484- });
485-
486- bind_function (cls, " __len__" , &ScalarTypeTorch::len);
487- bind_function (cls, " __str__" , &Base::str);
488- bind_function (cls, " __eq__" , [](SelfPtr const & self, SelfPtr const & other) {
489- return *self == *other;
490- });
491- bind_function (cls, " __repr__" , [](SelfPtr const & self) {
492- return " ScalarType." + self.get ()->str ();
493- });
494-
495- bind_function (cls, " __obj_flatten__" , &ScalarTypeTorch::obj_flatten);
496- bind_static_function (cls, " __obj_unflatten__" ,
497- &ScalarTypeTorch::obj_unflatten);
498-
499- // Bind static functions (convenience constructors)
500- bind_static_function (cls, " int_" , &ScalarTypeTorch::int_);
501- bind_static_function (cls, " uint" , &ScalarTypeTorch::uint);
502- bind_static_function (cls, " float_IEEE754" , &ScalarTypeTorch::float_IEEE754);
503- bind_static_function (cls, " float_" , &ScalarTypeTorch::float_);
504- }
505- };
506-
507- using ScalarTypeId = int64_t ;
508- using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
307+ using ScalarTypeId = ScalarType::Id;
509308
510309// "rust style" names generally following:
511310// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70