@@ -208,7 +208,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
model_override_args["image_token_index"] = 64002

if args.num_frames == 32:
model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2
elif args.num_frames < 32:
python/pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"outlines>=0.0.44", "modelscope"]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
srt = ["sglang[runtime_common]", "torch", "vllm==0.5.5"]
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
srt_xpu = ["sglang[runtime_common]"]

openai = ["openai>=1.0", "tiktoken"]
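The srt extra above now pins vllm==0.6.3.post1. A quick sanity check for an existing environment, purely illustrative and not part of this PR, is to compare the installed distribution version against the new pin:

from importlib.metadata import version

# Illustrative only: confirm the environment resolved the new vllm pin.
assert version("vllm") == "0.6.3.post1", f"unexpected vllm version: {version('vllm')}"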
python/sglang/launch_server_llavavid.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
model_override_args["num_frames"] = 16
model_override_args["model_type"] = "llavavid"
if model_override_args["num_frames"] == 32:
model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2
model_override_args["model_max_length"] = 4096 * 2
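Both scripts above build the same override dict by hand; the only change is the rope_scaling key, renamed from "type" to "rope_type" to match the key name that newer Transformers/vLLM rope-scaling configs read. A minimal sketch of the 32-frame override, using only the keys that appear in the diff:

# Minimal sketch of the 32-frame override shown above; nothing beyond the
# keys in the diff is assumed.
model_override_args = {
    "num_frames": 32,
    "model_type": "llavavid",
    # newer configs expect "rope_type" rather than the older "type" key
    "rope_scaling": {"factor": 2.0, "rope_type": "linear"},
    "max_sequence_length": 4096 * 2,
    "tokenizer_model_max_length": 4096 * 2,
    "model_max_length": 4096 * 2,
}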
python/sglang/srt/layers/linear.py (152 changes: 89 additions & 63 deletions)
@@ -20,8 +20,10 @@
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.parameter import (
BasevLLMParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter,
)

from sglang.srt.layers.quantization.base_config import (
@@ -39,6 +41,7 @@
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"GPTQLinearMethod",
]


@@ -50,7 +53,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitsandbytes_shard(
def adjust_bitsandbytes_4bit_shard(
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
@@ -207,7 +210,6 @@ def __init__(
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader,
prefix=prefix,
)

if bias:
@@ -315,7 +317,6 @@ def __init__(
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
prefix=prefix,
)
if bias:
self.bias = Parameter(
@@ -345,8 +346,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

param_data = param.data
if output_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
@@ -454,17 +459,22 @@ def weight_loader(
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return

if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size

ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return

param_data = param.data
output_dim = getattr(param, "output_dim", None)
@@ -526,26 +536,17 @@ def weight_loader(
param, shard_size, shard_offset
)

use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
@@ -595,7 +596,7 @@ def _load_fused_module_from_checkpoint(
# If quantized, we need to adjust the offset and size to account
# for the packing.
if (
isinstance(param, PackedvLLMParameter)
isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
and param.packed_dim == param.output_dim
):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -617,7 +618,7 @@ def weight_loader_v2(
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) is BasevLLMParameter:
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -760,7 +761,7 @@ def _load_fused_module_from_checkpoint(
# If quantized, we need to adjust the offset and size to account
# for the packing.
if (
isinstance(param, PackedvLLMParameter)
isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
and param.packed_dim == param.output_dim
):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -780,10 +781,10 @@ def weight_loader_v2(
):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
@@ -818,17 +819,22 @@ def weight_loader(
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return

if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size

ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return

param_data = param.data
output_dim = getattr(param, "output_dim", None)
@@ -863,6 +869,8 @@ def weight_loader(
self.total_num_kv_heads * self.head_size,
),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
@@ -877,6 +885,29 @@
param, shard_size, shard_offset
)

if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
"v": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size,
self.total_num_kv_heads * self.head_size,
),
"total": (
(self.total_num_heads + 2 * self.total_num_kv_heads)
* self.head_size,
0,
),
}

shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id
)

loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
@@ -910,8 +941,8 @@ def weight_loader(
param, shard_size, shard_offset
)

use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (
Expand All @@ -927,29 +958,22 @@ def weight_loader(
0,
),
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id
)

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset, shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
@@ -1037,7 +1061,6 @@ def __init__(
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
prefix=prefix,
)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError(
@@ -1061,6 +1084,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1076,7 +1100,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
if input_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
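Several hunks above repeat one pattern: when a parameter carries the use_bitsandbytes_4bit flag, the checkpoint already holds only this rank's portion of the weight, so the usual tensor-parallel narrow must be skipped. A self-contained sketch of that pattern, with param and loaded_weight mirroring the diff and everything else illustrative:

import torch

def load_row_parallel_weight(param: torch.nn.Parameter,
                             loaded_weight: torch.Tensor,
                             tp_rank: int) -> None:
    # Illustrative loader: shard along input_dim unless bitsandbytes-4bit
    # already stored a pre-sharded tensor for this rank.
    input_dim = getattr(param, "input_dim", None)
    use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

    param_data = param.data
    # bitsandbytes loads the weights of the specific portion; no need to narrow
    if input_dim is not None and not use_bitsandbytes_4bit:
        shard_size = param_data.shape[input_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)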
python/sglang/srt/lora/lora.py (4 changes: 3 additions & 1 deletion)
@@ -351,7 +351,9 @@ def initialize_weights(self):
loader = DefaultModelLoader(self.load_config)
revision = getattr(self.config.hf_config, "revision", None)
for name, loaded_weight in loader._get_weights_iterator(
model_path, revision=revision, fall_back_to_pt=True
DefaultModelLoader.Source(
model_path, revision=revision, fall_back_to_pt=True
)
):
match = re.search(r"layers\.(\d+)\.", name)
if match is not None:
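This lora.py hunk (and the model_runner.py hunk below) adapt to a vLLM 0.6.x API change: DefaultModelLoader._get_weights_iterator now takes a single Source descriptor instead of positional arguments. A hedged sketch of the new call shape, assuming the loader still lives at its vLLM 0.6.x import path and passing only the fields the diff itself uses:

from typing import Iterator, Optional, Tuple

import torch
from vllm.model_executor.model_loader.loader import DefaultModelLoader

def iter_weights(loader: DefaultModelLoader, model_path: str,
                 revision: Optional[str]) -> Iterator[Tuple[str, torch.Tensor]]:
    # Sketch only: wrap the arguments in the Source descriptor that
    # vLLM 0.6.x expects; fields not passed here keep their defaults.
    source = DefaultModelLoader.Source(
        model_path,
        revision=revision,
        fall_back_to_pt=True,
    )
    # Yields (name, tensor) pairs, same as the old positional-argument API.
    yield from loader._get_weights_iterator(source)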
python/sglang/srt/model_executor/model_runner.py (19 changes: 14 additions & 5 deletions)
@@ -59,8 +59,11 @@
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
is_attention_free_model,
is_embedding_model,
is_generation_model,
is_multimodal_model,
model_has_inner_state,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
)
@@ -316,11 +319,13 @@ def update_weights(self, model_path: str, load_format: str):

def get_weight_iter(config):
iter = loader._get_weights_iterator(
config.model,
config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
),
DefaultModelLoader.Source(
config.model,
revision=config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
),
)
)
return iter

@@ -662,3 +667,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:

# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
setattr(ModelRegistry, "is_embedding_model", is_embedding_model)