@@ -4,12 +4,13 @@ use candle_core::{Context, Result, Tensor};
44use candle_nn:: Linear ;
55
66use crate :: {
7- blockwise_fp8:: blockwise_fp8_linear_b, distributed, gptq:: gptq_linear, BnbLinear , DummyLayer ,
8- QuantMethod , QuantMethodConfig , QuantMethodType , QuantizeOntoGuard , QuantizedConfig ,
9- QuantizedSerde , Shard , ShardedVarBuilder , UnquantLinear ,
7+ blockwise_fp8:: blockwise_fp8_linear_b, distributed, gptq:: gptq_linear, BnbLinear ,
8+ DistributedKind , DummyLayer , FP8Linear , GgufMatMul , HqqLayer , QuantMethod , QuantMethodConfig ,
9+ QuantMethodType , QuantizeOntoGuard , QuantizedConfig , QuantizedSerde , QuantizedSerdeType , Shard ,
10+ ShardedVarBuilder , UnquantLinear ,
1011} ;
1112
12- use super :: { Comm , DistributedOperation } ;
13+ use super :: { Comm , DistributedOperation , SumAllReduce } ;
1314
1415fn shard ( dim : usize , rank : usize , world_size : usize ) -> Shard {
1516 Shard :: Simple {
@@ -174,6 +175,10 @@ impl QuantMethod for RowParallelLayer {
174175 all_reduce : self . all_reduce . clone ( ) ,
175176 } ) )
176177 }
178+
179+ fn is_distributed ( & self ) -> Option < DistributedKind > {
180+ Some ( DistributedKind :: RowParallel )
181+ }
177182}
178183
179184impl QuantizedSerde for RowParallelLayer {
@@ -186,6 +191,28 @@ impl QuantizedSerde for RowParallelLayer {
186191 fn serialize ( & self ) -> Result < std:: borrow:: Cow < [ u8 ] > > {
187192 self . weight . serialize_with_bias ( self . bias . clone ( ) )
188193 }
194+ fn deserialize (
195+ data : std:: borrow:: Cow < [ u8 ] > ,
196+ device : & candle_core:: Device ,
197+ comm : & Arc < crate :: Comm > ,
198+ ) -> Result < Arc < dyn QuantMethod > >
199+ where
200+ Self : Sized ,
201+ {
202+ // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
203+ let isq_type = data[ crate :: UQFF_QUANT_TYPE_OFFSET ] ;
204+ let ( weight, bias) = match QuantizedSerdeType :: try_from ( isq_type as usize ) ? {
205+ QuantizedSerdeType :: Gguf => GgufMatMul :: deserialize_ext_bias ( data, device) ?,
206+ QuantizedSerdeType :: Unquant => UnquantLinear :: deserialize_ext_bias ( data, device) ?,
207+ QuantizedSerdeType :: Hqq => HqqLayer :: deserialize_ext_bias ( data, device) ?,
208+ QuantizedSerdeType :: Fp8 => FP8Linear :: deserialize_ext_bias ( data, device) ?,
209+ } ;
210+ Ok ( Arc :: new ( Self {
211+ weight,
212+ bias,
213+ all_reduce : SumAllReduce :: new ( comm) ,
214+ } ) )
215+ }
189216}
190217
191218#[ derive( Debug ) ]
@@ -352,6 +379,10 @@ impl QuantMethod for ColumnParallelLayer {
352379 } ;
353380 Ok ( Arc :: new ( Self { weight, bias } ) )
354381 }
382+
383+ fn is_distributed ( & self ) -> Option < DistributedKind > {
384+ Some ( DistributedKind :: ColumnParallel )
385+ }
355386}
356387
357388impl QuantizedSerde for ColumnParallelLayer {
@@ -364,6 +395,24 @@ impl QuantizedSerde for ColumnParallelLayer {
364395 fn serialize ( & self ) -> Result < std:: borrow:: Cow < [ u8 ] > > {
365396 self . weight . serialize_with_bias ( self . bias . clone ( ) )
366397 }
398+ fn deserialize (
399+ data : std:: borrow:: Cow < [ u8 ] > ,
400+ device : & candle_core:: Device ,
401+ _comm : & Arc < crate :: Comm > ,
402+ ) -> Result < Arc < dyn QuantMethod > >
403+ where
404+ Self : Sized ,
405+ {
406+ // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
407+ let isq_type = data[ crate :: UQFF_QUANT_TYPE_OFFSET ] ;
408+ let ( weight, bias) = match QuantizedSerdeType :: try_from ( isq_type as usize ) ? {
409+ QuantizedSerdeType :: Gguf => GgufMatMul :: deserialize_ext_bias ( data, device) ?,
410+ QuantizedSerdeType :: Unquant => UnquantLinear :: deserialize_ext_bias ( data, device) ?,
411+ QuantizedSerdeType :: Hqq => HqqLayer :: deserialize_ext_bias ( data, device) ?,
412+ QuantizedSerdeType :: Fp8 => FP8Linear :: deserialize_ext_bias ( data, device) ?,
413+ } ;
414+ Ok ( Arc :: new ( Self { weight, bias } ) )
415+ }
367416}
368417
369418#[ derive( Debug ) ]
@@ -483,6 +532,10 @@ impl QuantMethod for ReplicatedLayer {
483532 . clone ( )
484533 . apply_isq ( dtype, device, n_quantized, imatrix_weight, guard)
485534 }
535+
536+ fn is_distributed ( & self ) -> Option < DistributedKind > {
537+ Some ( DistributedKind :: Replicated )
538+ }
486539}
487540
488541impl QuantizedSerde for ReplicatedLayer {
@@ -495,6 +548,24 @@ impl QuantizedSerde for ReplicatedLayer {
495548 fn serialize ( & self ) -> Result < std:: borrow:: Cow < [ u8 ] > > {
496549 self . 0 . serialize ( )
497550 }
551+ fn deserialize (
552+ data : std:: borrow:: Cow < [ u8 ] > ,
553+ device : & candle_core:: Device ,
554+ comm : & Arc < crate :: Comm > ,
555+ ) -> Result < Arc < dyn QuantMethod > >
556+ where
557+ Self : Sized ,
558+ {
559+ // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
560+ let isq_type = data[ crate :: UQFF_QUANT_TYPE_OFFSET ] ;
561+ let deserialized = match QuantizedSerdeType :: try_from ( isq_type as usize ) ? {
562+ QuantizedSerdeType :: Gguf => GgufMatMul :: deserialize ( data, device, comm) ?,
563+ QuantizedSerdeType :: Unquant => UnquantLinear :: deserialize ( data, device, comm) ?,
564+ QuantizedSerdeType :: Hqq => HqqLayer :: deserialize ( data, device, comm) ?,
565+ QuantizedSerdeType :: Fp8 => FP8Linear :: deserialize ( data, device, comm) ?,
566+ } ;
567+ Ok ( Arc :: new ( Self ( deserialized) ) )
568+ }
498569}
499570
500571/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
0 commit comments