121 changes: 93 additions & 28 deletions mistralrs-core/src/pipeline/isq.rs
@@ -14,8 +14,9 @@ use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
CollectedImatrixData, FP8Linear, GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard,
QuantizedSerde, QuantizedSerdeType, UnquantLinear,
CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul, HqqLayer,
IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType, ReplicatedLayer,
RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
@@ -690,6 +691,7 @@ pub trait IsqModel {
});

let mut devices = Vec::new();
let mut comms = Vec::new();
for (_, layer_num) in &tensors {
let device = if let Some(ref layers) = layers {
if let Some(layer) = layer_num {
@@ -711,6 +713,7 @@
device.clone()
};
devices.push(device);
comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
}

let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::new(artifacts)? };
@@ -751,20 +754,51 @@
.map(|(i, (tensor, _))| {
if let Some(artifact) = artifact_isqs.get(&i) {
let artifact = artifact.data();
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => {
GgufMatMul::deserialize(Cow::from(artifact), &devices[i])?
}
QuantizedSerdeType::Unquant => {
UnquantLinear::deserialize(Cow::from(artifact), &devices[i])?
}
QuantizedSerdeType::Hqq => {
HqqLayer::deserialize(Cow::from(artifact), &devices[i])?

let comm = comms[i].clone();
let deserialized = match tensor.is_distributed() {
Some(DistributedKind::ColumnParallel) => {
ColumnParallelLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?
}
QuantizedSerdeType::Fp8 => {
FP8Linear::deserialize(Cow::from(artifact), &devices[i])?
Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
None => {
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
QuantizedSerdeType::Hqq => HqqLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
}
}
};
*tensor = deserialized;
@@ -780,20 +814,51 @@
.map(|(i, (tensor, _))| {
if let Some(artifact) = artifact_isqs.get(&i) {
let artifact = artifact.data();
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = artifact[4];
let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => {
GgufMatMul::deserialize(Cow::from(artifact), &devices[i])?
}
QuantizedSerdeType::Unquant => {
UnquantLinear::deserialize(Cow::from(artifact), &devices[i])?
}
QuantizedSerdeType::Hqq => {
HqqLayer::deserialize(Cow::from(artifact), &devices[i])?

let comm = comms[i].clone();
let deserialized = match tensor.is_distributed() {
Some(DistributedKind::ColumnParallel) => {
ColumnParallelLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?
}
QuantizedSerdeType::Fp8 => {
FP8Linear::deserialize(Cow::from(artifact), &devices[i])?
Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
None => {
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
QuantizedSerdeType::Hqq => HqqLayer::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
Cow::from(artifact),
&devices[i],
&comm,
)?,
}
}
};
*tensor = deserialized;
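Taken together, the two hunks above replace a flat quant-type dispatch with a two-level one: distributed layers are matched first via the new `is_distributed()`, and only non-distributed tensors fall through to the old byte-4 quant-type path. A minimal sketch of that control flow, with simplified signatures (the real code runs inside a rayon iterator over `&mut Arc<dyn QuantMethod>`):

use std::borrow::Cow;
use std::sync::Arc;
use candle_core::{Device, Result};
use mistralrs_quant::{
    ColumnParallelLayer, Comm, DistributedKind, FP8Linear, GgufMatMul, HqqLayer,
    QuantMethod, QuantizedSerde, QuantizedSerdeType, ReplicatedLayer,
    RowParallelLayer, UnquantLinear,
};

// A sketch, not the PR's exact code: the same dispatch as above with the
// surrounding rayon machinery stripped away.
fn deserialize_artifact(
    tensor: &dyn QuantMethod,
    artifact: &[u8],
    device: &Device,
    comm: &Arc<Comm>,
) -> Result<Arc<dyn QuantMethod>> {
    match tensor.is_distributed() {
        // Distributed layers rebuild themselves around this layer's communicator.
        Some(DistributedKind::ColumnParallel) => {
            ColumnParallelLayer::deserialize(Cow::from(artifact), device, comm)
        }
        Some(DistributedKind::RowParallel) => {
            RowParallelLayer::deserialize(Cow::from(artifact), device, comm)
        }
        Some(DistributedKind::Replicated) => {
            ReplicatedLayer::deserialize(Cow::from(artifact), device, comm)
        }
        // Non-distributed tensors keep the pre-existing path: dispatch on the
        // quant-type tag, which is always byte 4 (the 5th byte) of the tensor.
        None => {
            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
            match QuantizedSerdeType::try_from(isq_type as usize)? {
                QuantizedSerdeType::Gguf => {
                    GgufMatMul::deserialize(Cow::from(artifact), device, comm)
                }
                QuantizedSerdeType::Unquant => {
                    UnquantLinear::deserialize(Cow::from(artifact), device, comm)
                }
                QuantizedSerdeType::Hqq => {
                    HqqLayer::deserialize(Cow::from(artifact), device, comm)
                }
                QuantizedSerdeType::Fp8 => {
                    FP8Linear::deserialize(Cow::from(artifact), device, comm)
                }
            }
        }
    }
}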
6 changes: 5 additions & 1 deletion mistralrs-quant/src/bitsandbytes/mod.rs
@@ -282,7 +282,11 @@ impl QuantizedSerde for BnbLinear {
todo!()
}

fn deserialize(_data: Cow<[u8]>, _device: &Device) -> Result<Arc<dyn QuantMethod>>
fn deserialize(
_data: Cow<[u8]>,
_device: &Device,
_comm: &Arc<crate::Comm>,
) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
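The only change here is the widened trait signature: `deserialize` now threads a communicator through every `QuantizedSerde` implementer, whether or not it uses one (`BnbLinear` ignores it, hence the `_comm` name). A sketch of just the changed method's shape — the real `QuantizedSerde` trait in mistralrs_quant has more methods, and the trait name below is illustrative:

use std::borrow::Cow;
use std::sync::Arc;
use candle_core::{Device, Result};
use mistralrs_quant::{Comm, QuantMethod};

// Sketch of the changed method only; implementers that have no collective
// state (like BnbLinear) simply ignore the `comm` argument.
trait QuantizedSerdeShape {
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        comm: &Arc<Comm>,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized;
}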
79 changes: 75 additions & 4 deletions mistralrs-quant/src/distributed/layers.rs
@@ -4,12 +4,13 @@ use candle_core::{Context, Result, Tensor};
use candle_nn::Linear;

use crate::{
blockwise_fp8::blockwise_fp8_linear_b, distributed, gptq::gptq_linear, BnbLinear, DummyLayer,
QuantMethod, QuantMethodConfig, QuantMethodType, QuantizeOntoGuard, QuantizedConfig,
QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
blockwise_fp8::blockwise_fp8_linear_b, distributed, gptq::gptq_linear, BnbLinear,
DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod, QuantMethodConfig,
QuantMethodType, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType, Shard,
ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, DistributedOperation};
use super::{Comm, DistributedOperation, SumAllReduce};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
Shard::Simple {
@@ -174,6 +175,10 @@ impl QuantMethod for RowParallelLayer {
all_reduce: self.all_reduce.clone(),
}))
}

fn is_distributed(&self) -> Option<DistributedKind> {
Some(DistributedKind::RowParallel)
}
}

impl QuantizedSerde for RowParallelLayer {
@@ -186,6 +191,28 @@
fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
self.weight.serialize_with_bias(self.bias.clone())
}
fn deserialize(
data: std::borrow::Cow<[u8]>,
device: &candle_core::Device,
comm: &Arc<crate::Comm>,
) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device)?,
QuantizedSerdeType::Unquant => UnquantLinear::deserialize_ext_bias(data, device)?,
QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device)?,
QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device)?,
};
Ok(Arc::new(Self {
weight,
bias,
all_reduce: SumAllReduce::new(comm),
}))
}
}

#[derive(Debug)]
@@ -352,6 +379,10 @@ impl QuantMethod for ColumnParallelLayer {
};
Ok(Arc::new(Self { weight, bias }))
}

fn is_distributed(&self) -> Option<DistributedKind> {
Some(DistributedKind::ColumnParallel)
}
}

impl QuantizedSerde for ColumnParallelLayer {
@@ -364,6 +395,24 @@
fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
self.weight.serialize_with_bias(self.bias.clone())
}
fn deserialize(
data: std::borrow::Cow<[u8]>,
device: &candle_core::Device,
_comm: &Arc<crate::Comm>,
) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device)?,
QuantizedSerdeType::Unquant => UnquantLinear::deserialize_ext_bias(data, device)?,
QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device)?,
QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device)?,
};
Ok(Arc::new(Self { weight, bias }))
}
}

#[derive(Debug)]
@@ -483,6 +532,10 @@ impl QuantMethod for ReplicatedLayer {
.clone()
.apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
}

fn is_distributed(&self) -> Option<DistributedKind> {
Some(DistributedKind::Replicated)
}
}

impl QuantizedSerde for ReplicatedLayer {
@@ -495,6 +548,24 @@
fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
self.0.serialize()
}
fn deserialize(
data: std::borrow::Cow<[u8]>,
device: &candle_core::Device,
comm: &Arc<crate::Comm>,
) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
// NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm)?,
QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm)?,
QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm)?,
QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm)?,
};
Ok(Arc::new(Self(deserialized)))
}
}

/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
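Worth noting across the three impls above: `RowParallelLayer` rebuilds its `SumAllReduce` from the communicator, `ColumnParallelLayer` ignores `_comm` (it has no collective to reconstruct), and `ReplicatedLayer` just forwards `comm` to the inner method's `deserialize`. A hypothetical round trip under those assumptions (the helper name is illustrative):

use std::sync::Arc;
use candle_core::{Device, Result};
use mistralrs_quant::{Comm, QuantMethod, QuantizedSerde, RowParallelLayer};

// Hypothetical helper: serialize a row-parallel layer (bias is packed in via
// serialize_with_bias) and rebuild it bound to a possibly different
// communicator, e.g. after reloading a UQFF file on another process group.
fn roundtrip(
    layer: &RowParallelLayer,
    device: &Device,
    comm: &Arc<Comm>,
) -> Result<Arc<dyn QuantMethod>> {
    let blob = layer.serialize()?; // Cow<[u8]> in the UQFF layout
    RowParallelLayer::deserialize(blob, device, comm) // SumAllReduce::new(comm) inside
}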
63 changes: 60 additions & 3 deletions mistralrs-quant/src/fp8/mod.rs
@@ -209,6 +209,9 @@ impl QuantizedSerde for FP8Linear {
"fp8-linear"
}
fn serialize(&self) -> Result<Cow<[u8]>> {
self.serialize_with_bias(self.lin.bias().cloned())
}
fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
let mut buffer = Vec::new();

// Version is always first!
Expand All @@ -218,7 +221,7 @@ impl QuantizedSerde for FP8Linear {
buffer.push(QuantizedSerdeType::Fp8 as u8);

// Has bias
buffer.push(self.lin.bias().is_some() as u8);
buffer.push(bias.is_some() as u8);

// Weight
serialize_tensor(&mut buffer, self.lin.weight())?;
@@ -233,15 +236,19 @@
// DType
write_dtype(self.dtype, &mut buffer);

if let Some(bias) = self.lin.bias() {
if let Some(bias) = &bias {
// Bias
serialize_tensor(&mut buffer, bias)?;
}

Ok(Cow::from(buffer))
}

fn deserialize(data: Cow<[u8]>, device: &Device) -> Result<Arc<dyn QuantMethod>>
fn deserialize(
data: Cow<[u8]>,
device: &Device,
_comm: &Arc<crate::Comm>,
) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
@@ -285,4 +292,54 @@
dtype,
}))
}
fn deserialize_ext_bias(
data: Cow<[u8]>,
device: &Device,
) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
where
Self: Sized,
{
let mut buffer = Cursor::new(data.to_vec());

let version = buffer.read_u32::<LittleEndian>()?;
if let Err(e) = version_is_compatible(version) {
return Err(candle_core::Error::wrap(e));
}

let isq_type = buffer.read_u8()? as usize;
if isq_type != QuantizedSerdeType::Fp8 as usize {
candle_core::bail!(
"ISQ type ({isq_type}) doesn't match expected type {}",
QuantizedSerdeType::Fp8 as usize
);
}

let has_bias = buffer.read_u8()? != 0;

let w = deserialize_tensor(&mut buffer, device)?;

let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

// DType
let dtype = read_dtype(&mut buffer)?;

let b = if has_bias {
Some(deserialize_tensor(&mut buffer, device)?)
} else {
None
};

Ok((
Arc::new(Self {
lin: Linear::new(w, None),
dequant_w_scale,
dequant_x_scale,
quant_scale,
dtype,
}),
b,
))
}
}
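For reference, the record layout that `serialize_with_bias` writes and both deserializers above read back, as inferred from the code (a sketch, not a normative UQFF spec): a little-endian u32 version, the quant-type tag at byte 4, a has-bias flag at byte 5, then the weight tensor, three f32 scales, a dtype tag, and the optional bias. A minimal header probe mirroring the reader:

use std::io::Cursor;
use byteorder::{LittleEndian, ReadBytesExt};

// Minimal header probe for a UQFF FP8 record, mirroring the reader above
// (a sketch; the full record continues with the weight tensor, the
// dequant_w_scale / dequant_x_scale / quant_scale f32s, a dtype tag, and a
// bias tensor when the flag is set).
fn probe_fp8_header(data: &[u8]) -> std::io::Result<(u32, u8, bool)> {
    let mut buf = Cursor::new(data);
    let version = buf.read_u32::<LittleEndian>()?; // bytes 0..4: UQFF version
    let isq_type = buf.read_u8()?;                 // byte 4: quant-type tag
    let has_bias = buf.read_u8()? != 0;            // byte 5: bias flag
    Ok((version, isq_type, has_bias))
}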