
Commit 8b656a9

Fixes for UQFF + distributed layers (#1250)
* Fixes for uqff + distributed layers
* Typo
1 parent d7cb787 commit 8b656a9
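
In short: this commit threads the collective-communication handle (`Comm`) through UQFF deserialization so tensor-parallel layers are rebuilt as their distributed wrapper types instead of as plain quantized layers. Every `QuantizedSerde::deserialize` implementation gains a `comm: &Arc<crate::Comm>` parameter, and the UQFF load path in `isq.rs` now dispatches on a new `is_distributed` hook first. A condensed sketch of the new dispatch (paraphrased from the `isq.rs` diff below, not a verbatim excerpt; `deserialize_by_quant_type` is a hypothetical stand-in for the inline `match` on `QuantizedSerdeType`):

    let comm = comms[i].clone();
    let deserialized = match tensor.is_distributed() {
        Some(DistributedKind::ColumnParallel) => {
            ColumnParallelLayer::deserialize(Cow::from(artifact), &devices[i], &comm)?
        }
        Some(DistributedKind::RowParallel) => {
            RowParallelLayer::deserialize(Cow::from(artifact), &devices[i], &comm)?
        }
        Some(DistributedKind::Replicated) => {
            ReplicatedLayer::deserialize(Cow::from(artifact), &devices[i], &comm)?
        }
        // Non-distributed tensors keep the old per-quant-type path
        // (GgufMatMul / UnquantLinear / HqqLayer / FP8Linear).
        None => deserialize_by_quant_type(Cow::from(artifact), &devices[i], &comm)?,
    };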

File tree

8 files changed: +474 −47 lines changed

mistralrs-core/src/pipeline/isq.rs

Lines changed: 93 additions & 28 deletions
@@ -14,8 +14,9 @@ use candle_core::{quantized, Context, Device, Tensor};
 use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
 use itertools::Itertools;
 use mistralrs_quant::{
-    CollectedImatrixData, FP8Linear, GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard,
-    QuantizedSerde, QuantizedSerdeType, UnquantLinear,
+    CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul, HqqLayer,
+    IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType, ReplicatedLayer,
+    RowParallelLayer, UnquantLinear,
 };
 use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
 use regex::Regex;
@@ -690,6 +691,7 @@ pub trait IsqModel {
         });

         let mut devices = Vec::new();
+        let mut comms = Vec::new();
         for (_, layer_num) in &tensors {
             let device = if let Some(ref layers) = layers {
                 if let Some(layer) = layer_num {
@@ -711,6 +713,7 @@ pub trait IsqModel {
                 device.clone()
            };
            devices.push(device);
+           comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
        }

        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::new(artifacts)? };
@@ -751,20 +754,51 @@ pub trait IsqModel {
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();
-                        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
-                        let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
-                        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
-                            QuantizedSerdeType::Gguf => {
-                                GgufMatMul::deserialize(Cow::from(artifact), &devices[i])?
-                            }
-                            QuantizedSerdeType::Unquant => {
-                                UnquantLinear::deserialize(Cow::from(artifact), &devices[i])?
-                            }
-                            QuantizedSerdeType::Hqq => {
-                                HqqLayer::deserialize(Cow::from(artifact), &devices[i])?
+
+                        let comm = comms[i].clone();
+                        let deserialized = match tensor.is_distributed() {
+                            Some(DistributedKind::ColumnParallel) => {
+                                ColumnParallelLayer::deserialize(
+                                    Cow::from(artifact),
+                                    &devices[i],
+                                    &comm,
+                                )?
                             }
-                            QuantizedSerdeType::Fp8 => {
-                                FP8Linear::deserialize(Cow::from(artifact), &devices[i])?
+                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
+                                Cow::from(artifact),
+                                &devices[i],
+                                &comm,
+                            )?,
+                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
+                                Cow::from(artifact),
+                                &devices[i],
+                                &comm,
+                            )?,
+                            None => {
+                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
+                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
+                                match QuantizedSerdeType::try_from(isq_type as usize)? {
+                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                }
                             }
                         };
                         *tensor = deserialized;
@@ -780,20 +814,51 @@ pub trait IsqModel {
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();
-                        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
-                        let isq_type = artifact[4];
-                        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
-                            QuantizedSerdeType::Gguf => {
-                                GgufMatMul::deserialize(Cow::from(artifact), &devices[i])?
-                            }
-                            QuantizedSerdeType::Unquant => {
-                                UnquantLinear::deserialize(Cow::from(artifact), &devices[i])?
-                            }
-                            QuantizedSerdeType::Hqq => {
-                                HqqLayer::deserialize(Cow::from(artifact), &devices[i])?
+
+                        let comm = comms[i].clone();
+                        let deserialized = match tensor.is_distributed() {
+                            Some(DistributedKind::ColumnParallel) => {
+                                ColumnParallelLayer::deserialize(
+                                    Cow::from(artifact),
+                                    &devices[i],
+                                    &comm,
+                                )?
                             }
-                            QuantizedSerdeType::Fp8 => {
-                                FP8Linear::deserialize(Cow::from(artifact), &devices[i])?
+                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
+                                Cow::from(artifact),
+                                &devices[i],
+                                &comm,
+                            )?,
+                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
+                                Cow::from(artifact),
+                                &devices[i],
+                                &comm,
+                            )?,
+                            None => {
+                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
+                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
+                                match QuantizedSerdeType::try_from(isq_type as usize)? {
+                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
+                                        Cow::from(artifact),
+                                        &devices[i],
+                                        &comm,
+                                    )?,
+                                }
                             }
                         };
                         *tensor = deserialized;
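
The dispatch above relies on a `DistributedKind` marker and an `is_distributed` hook, neither of which is defined in this diff. A minimal sketch of their plausible shape, inferred from the variants and call sites in this commit (the real definitions live elsewhere in `mistralrs-quant`):

    // Sketch only: variant set read off the match arms above.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum DistributedKind {
        ColumnParallel,
        RowParallel,
        Replicated,
    }

The hook presumably defaults to `None` on `QuantMethod` (an assumption), so ordinary quantized layers fall through to the quant-type path while only the three wrappers in `distributed/layers.rs` opt in:

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }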

mistralrs-quant/src/bitsandbytes/mod.rs

Lines changed: 5 additions & 1 deletion
@@ -282,7 +282,11 @@ impl QuantizedSerde for BnbLinear {
         todo!()
     }

-    fn deserialize(_data: Cow<[u8]>, _device: &Device) -> Result<Arc<dyn QuantMethod>>
+    fn deserialize(
+        _data: Cow<[u8]>,
+        _device: &Device,
+        _comm: &Arc<crate::Comm>,
+    ) -> Result<Arc<dyn QuantMethod>>
     where
         Self: Sized,
     {

mistralrs-quant/src/distributed/layers.rs

Lines changed: 75 additions & 4 deletions
@@ -4,12 +4,13 @@ use candle_core::{Context, Result, Tensor};
 use candle_nn::Linear;

 use crate::{
-    blockwise_fp8::blockwise_fp8_linear_b, distributed, gptq::gptq_linear, BnbLinear, DummyLayer,
-    QuantMethod, QuantMethodConfig, QuantMethodType, QuantizeOntoGuard, QuantizedConfig,
-    QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
+    blockwise_fp8::blockwise_fp8_linear_b, distributed, gptq::gptq_linear, BnbLinear,
+    DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod, QuantMethodConfig,
+    QuantMethodType, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType, Shard,
+    ShardedVarBuilder, UnquantLinear,
 };

-use super::{Comm, DistributedOperation};
+use super::{Comm, DistributedOperation, SumAllReduce};

 fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
     Shard::Simple {
@@ -174,6 +175,10 @@ impl QuantMethod for RowParallelLayer {
             all_reduce: self.all_reduce.clone(),
         }))
     }
+
+    fn is_distributed(&self) -> Option<DistributedKind> {
+        Some(DistributedKind::RowParallel)
+    }
 }

 impl QuantizedSerde for RowParallelLayer {
@@ -186,6 +191,28 @@ impl QuantizedSerde for RowParallelLayer {
     fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
         self.weight.serialize_with_bias(self.bias.clone())
     }
+    fn deserialize(
+        data: std::borrow::Cow<[u8]>,
+        device: &candle_core::Device,
+        comm: &Arc<crate::Comm>,
+    ) -> Result<Arc<dyn QuantMethod>>
+    where
+        Self: Sized,
+    {
+        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
+        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
+        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
+            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device)?,
+            QuantizedSerdeType::Unquant => UnquantLinear::deserialize_ext_bias(data, device)?,
+            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device)?,
+            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device)?,
+        };
+        Ok(Arc::new(Self {
+            weight,
+            bias,
+            all_reduce: SumAllReduce::new(comm),
+        }))
+    }
 }

 #[derive(Debug)]
@@ -352,6 +379,10 @@ impl QuantMethod for ColumnParallelLayer {
         };
         Ok(Arc::new(Self { weight, bias }))
     }
+
+    fn is_distributed(&self) -> Option<DistributedKind> {
+        Some(DistributedKind::ColumnParallel)
+    }
 }

 impl QuantizedSerde for ColumnParallelLayer {
@@ -364,6 +395,24 @@ impl QuantizedSerde for ColumnParallelLayer {
     fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
         self.weight.serialize_with_bias(self.bias.clone())
     }
+    fn deserialize(
+        data: std::borrow::Cow<[u8]>,
+        device: &candle_core::Device,
+        _comm: &Arc<crate::Comm>,
+    ) -> Result<Arc<dyn QuantMethod>>
+    where
+        Self: Sized,
+    {
+        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
+        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
+        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
+            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device)?,
+            QuantizedSerdeType::Unquant => UnquantLinear::deserialize_ext_bias(data, device)?,
+            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device)?,
+            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device)?,
+        };
+        Ok(Arc::new(Self { weight, bias }))
+    }
 }

 #[derive(Debug)]
@@ -483,6 +532,10 @@ impl QuantMethod for ReplicatedLayer {
             .clone()
             .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
     }
+
+    fn is_distributed(&self) -> Option<DistributedKind> {
+        Some(DistributedKind::Replicated)
+    }
 }

 impl QuantizedSerde for ReplicatedLayer {
@@ -495,6 +548,24 @@ impl QuantizedSerde for ReplicatedLayer {
     fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
         self.0.serialize()
     }
+    fn deserialize(
+        data: std::borrow::Cow<[u8]>,
+        device: &candle_core::Device,
+        comm: &Arc<crate::Comm>,
+    ) -> Result<Arc<dyn QuantMethod>>
+    where
+        Self: Sized,
+    {
+        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
+        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
+        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
+            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm)?,
+            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm)?,
+            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm)?,
+            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm)?,
+        };
+        Ok(Arc::new(Self(deserialized)))
+    }
 }

 /// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
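
Taken together, these impls imply two additions to the `QuantizedSerde` surface beyond the new `comm` parameter: `serialize_with_bias`, which lets a parallel wrapper serialize its inner weight together with its own bias shard, and `deserialize_ext_bias`, which hands the bias back alongside the rebuilt inner layer so the wrapper can own it. The implied signatures (a sketch; the trait definition itself is outside this diff):

    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>>;

    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized;

Note the asymmetry in how `comm` is consumed: `RowParallelLayer::deserialize` uses it to rebuild its `SumAllReduce`, while `ColumnParallelLayer` ignores it (`_comm`), since a column-parallel layer performs no reduction of its own.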

mistralrs-quant/src/fp8/mod.rs

Lines changed: 60 additions & 3 deletions
@@ -209,6 +209,9 @@ impl QuantizedSerde for FP8Linear {
         "fp8-linear"
     }
     fn serialize(&self) -> Result<Cow<[u8]>> {
+        self.serialize_with_bias(self.lin.bias().cloned())
+    }
+    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
         let mut buffer = Vec::new();

         // Version is always first!
@@ -218,7 +221,7 @@ impl QuantizedSerde for FP8Linear {
         buffer.push(QuantizedSerdeType::Fp8 as u8);

         // Has bias
-        buffer.push(self.lin.bias().is_some() as u8);
+        buffer.push(bias.is_some() as u8);

         // Weight
         serialize_tensor(&mut buffer, self.lin.weight())?;
@@ -233,15 +236,19 @@ impl QuantizedSerde for FP8Linear {
         // DType
         write_dtype(self.dtype, &mut buffer);

-        if let Some(bias) = self.lin.bias() {
+        if let Some(bias) = &bias {
             // Bias
             serialize_tensor(&mut buffer, bias)?;
         }

         Ok(Cow::from(buffer))
     }

-    fn deserialize(data: Cow<[u8]>, device: &Device) -> Result<Arc<dyn QuantMethod>>
+    fn deserialize(
+        data: Cow<[u8]>,
+        device: &Device,
+        _comm: &Arc<crate::Comm>,
+    ) -> Result<Arc<dyn QuantMethod>>
     where
         Self: Sized,
     {
@@ -285,4 +292,54 @@ impl QuantizedSerde for FP8Linear {
             dtype,
         }))
     }
+    fn deserialize_ext_bias(
+        data: Cow<[u8]>,
+        device: &Device,
+    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
+    where
+        Self: Sized,
+    {
+        let mut buffer = Cursor::new(data.to_vec());
+
+        let version = buffer.read_u32::<LittleEndian>()?;
+        if let Err(e) = version_is_compatible(version) {
+            return Err(candle_core::Error::wrap(e));
+        }
+
+        let isq_type = buffer.read_u8()? as usize;
+        if isq_type != QuantizedSerdeType::Fp8 as usize {
+            candle_core::bail!(
+                "ISQ type ({isq_type}) doesn't match expected type {}",
+                QuantizedSerdeType::Fp8 as usize
+            );
+        }
+
+        let has_bias = buffer.read_u8()? != 0;
+
+        let w = deserialize_tensor(&mut buffer, device)?;
+
+        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
+        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
+        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
+
+        // DType
+        let dtype = read_dtype(&mut buffer)?;
+
+        let b = if has_bias {
+            Some(deserialize_tensor(&mut buffer, device)?)
+        } else {
+            None
+        };
+
+        Ok((
+            Arc::new(Self {
+                lin: Linear::new(w, None),
+                dequant_w_scale,
+                dequant_x_scale,
+                quant_scale,
+                dtype,
+            }),
+            b,
+        ))
+    }
 }
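
Reading `serialize_with_bias` and `deserialize_ext_bias` together gives the FP8 UQFF artifact layout: UQFF version (`u32`, little-endian), serde type tag (`u8`, at byte offset 4, i.e. `UQFF_QUANT_TYPE_OFFSET`), has-bias flag (`u8`), weight tensor, three `f32` scales (`dequant_w_scale`, `dequant_x_scale`, `quant_scale`), dtype, then the bias tensor if the flag is set. A self-contained sketch of peeking the type tag the way the loaders in this commit do (the discriminant values are an assumption read off the `match` order: Gguf = 0, Unquant = 1, Hqq = 2, Fp8 = 3):

    // The u32 version occupies bytes 0..4, so the quant-type tag is byte 4.
    const UQFF_QUANT_TYPE_OFFSET: usize = 4;

    fn quant_type_tag(artifact: &[u8]) -> Option<u8> {
        artifact.get(UQFF_QUANT_TYPE_OFFSET).copied()
    }

    fn main() {
        // version 1 (LE) | Fp8 tag (assumed 3) | has_bias = 1
        let artifact = [1u8, 0, 0, 0, 3, 1];
        assert_eq!(quant_type_tag(&artifact), Some(3));
    }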
