Add INT4 compressed-tensors + LoRA support (including MoE) #28791
Changes from all commits: 8ef757a, 2a0f94e, 8fd7c16, 038244d, 22bf730, 57faaea
@@ -614,22 +614,45 @@ def create_dummy_lora(
         if module_name not in self.packed_modules:
             assert embedding_modules is not None
             if parts[-1] in embedding_modules:
-                input_dim = (
-                    module.base_layer.org_vocab_size
-                    + self.lora_config.lora_extra_vocab_size
-                    if hasattr(module.base_layer, "org_vocab_size")
-                    else module.base_layer.weight.shape[1]
-                )
-                output_dim = (
-                    module.base_layer.embedding_dim
-                    if hasattr(module.base_layer, "embedding_dim")
-                    else module.base_layer.weight.shape[0]
-                )
-                embeddings_tensor_dim = (
-                    module.base_layer.embedding_dim
-                    if hasattr(module.base_layer, "embedding_dim")
-                    else module.base_layer.weight.shape[1]
-                )
+                # Try to get dimensions from layer attributes first
+                if hasattr(module.base_layer, "org_vocab_size"):
+                    input_dim = (
+                        module.base_layer.org_vocab_size
+                        + self.lora_config.lora_extra_vocab_size
+                    )
+                elif hasattr(module.base_layer, "input_size"):
+                    input_dim = module.base_layer.input_size
+                elif hasattr(module.base_layer, "weight_shape"):
+                    # Compressed tensors: weight_shape stores [output, input]
+                    # For embeddings: [vocab_size, embedding_dim]
+                    input_dim = module.base_layer.weight_shape[0].item()
+                else:
+                    # For embeddings: weight.shape = [vocab_size, embedding_dim]
+                    input_dim = module.weight.shape[0]
Contributor: Shouldn't it be worrying that tests passed with an issue like this?

Author: Good catch. Not sure the tests cover this branch.
+                if hasattr(module.base_layer, "embedding_dim"):
+                    output_dim = module.base_layer.embedding_dim
+                elif hasattr(module.base_layer, "output_size"):
+                    output_dim = module.base_layer.output_size
+                elif hasattr(module.base_layer, "weight_shape"):
+                    # Compressed tensors: weight_shape stores [output, input]
+                    # For embeddings: [vocab_size, embedding_dim]
+                    output_dim = module.base_layer.weight_shape[1].item()
+                else:
+                    # For embeddings: weight.shape = [vocab_size, embedding_dim]
+                    output_dim = module.weight.shape[1]
Contributor: Also backward, should be shape[0].
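For context on the shape[0] vs shape[1] question: a plain PyTorch embedding stores its weight as [num_embeddings, embedding_dim]. A standalone check (this is stock torch, not vLLM's VocabParallelEmbedding, which shards the vocab dimension but, as far as I can tell, keeps the same ordering):

```python
import torch

# torch.nn.Embedding.weight has shape [num_embeddings, embedding_dim]:
# shape[0] is the vocab size, shape[1] is the embedding dimension.
emb = torch.nn.Embedding(num_embeddings=32000, embedding_dim=4096)
assert emb.weight.shape == (32000, 4096)
```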
+                if hasattr(module.base_layer, "embedding_dim"):
+                    embeddings_tensor_dim = module.base_layer.embedding_dim
+                elif hasattr(module.base_layer, "output_size"):
+                    embeddings_tensor_dim = module.base_layer.output_size
+                elif hasattr(module.base_layer, "weight_shape"):
+                    # Compressed tensors: weight_shape stores [output, input]
+                    # For embeddings: [vocab_size, embedding_dim]
+                    embeddings_tensor_dim = module.base_layer.weight_shape[1].item()
+                else:
+                    # For embeddings: weight.shape = [vocab_size, embedding_dim]
+                    embeddings_tensor_dim = module.weight.shape[1]
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name,
                     input_dim,
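To make the compressed-tensors fallback above concrete, here is a minimal sketch using a made-up stand-in for a quantized base layer whose logical shape lives only in a `weight_shape` tensor. The stub class is hypothetical; only the `weight_shape` attribute and the `.item()` calls mirror the diff:

```python
import torch

class FakeInt4Embedding:
    """Hypothetical stand-in: the packed INT4 weight no longer carries the
    logical shape, so it is kept in a 2-element weight_shape tensor."""

    def __init__(self, vocab_size: int, embedding_dim: int) -> None:
        self.weight_shape = torch.tensor([vocab_size, embedding_dim])

base_layer = FakeInt4Embedding(vocab_size=32000, embedding_dim=4096)

# .item() turns the 0-dim tensor elements into plain Python ints,
# which is the form the dummy-LoRA dimensions are used in.
input_dim = base_layer.weight_shape[0].item()   # vocab_size
output_dim = base_layer.weight_shape[1].item()  # embedding_dim
assert (input_dim, output_dim) == (32000, 4096)
```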
@@ -203,6 +203,10 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        # Set layer attributes needed for LoRA compatibility
Contributor: Isn't this only needed for W4A16?
+        layer.hidden_size = hidden_size
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.local_num_experts = num_experts
         layer.num_experts = num_experts
         layer.params_dtype = params_dtype
@@ -1367,6 +1371,11 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        # Set layer attributes needed for LoRA compatibility
+        layer.hidden_size = hidden_size
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.local_num_experts = num_experts

         intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")

         # Will transpose the loaded weight along the
@@ -1738,6 +1747,11 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        # Set layer attributes needed for LoRA compatibility
+        layer.hidden_size = hidden_size
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.local_num_experts = num_experts

         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -2013,6 +2027,11 @@ def create_weights(
         **extra_weight_attrs,
     ):
         # Shapes per local rank (TP/EP):
+        # Set layer attributes needed for LoRA compatibility
Contributor: Isn't this only needed for W4A16 as of now?
+        layer.hidden_size = hidden_size
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.local_num_experts = num_experts

         # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
         # w2 : [E, H, I_local] int8
         # Scales:
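For context on why these attributes get attached to the layer: LoRA support for a fused MoE layer has to size its adapter tensors from the layer geometry. Below is a rough sketch of a consumer; the wrapper class and its shapes are illustrative assumptions, and only `layer.hidden_size`, `layer.intermediate_size_per_partition`, and `layer.local_num_experts` come from the diff:

```python
from types import SimpleNamespace

import torch

class FusedMoELoRASketch:
    """Illustrative only: allocate per-expert LoRA A/B stacks for the fused
    w13 (gate + up) projection using the attributes set in create_weights."""

    def __init__(self, layer, max_loras: int, rank: int) -> None:
        num_experts = layer.local_num_experts
        hidden_size = layer.hidden_size
        intermediate = layer.intermediate_size_per_partition

        # A: hidden -> rank, B: rank -> 2 * intermediate (gate and up fused),
        # one slice per LoRA slot and per local expert.
        self.w13_lora_a = torch.zeros(max_loras, num_experts, rank, hidden_size)
        self.w13_lora_b = torch.zeros(max_loras, num_experts, 2 * intermediate, rank)

# Stand-in layer with the attributes the diff sets in create_weights.
layer = SimpleNamespace(
    hidden_size=4096, intermediate_size_per_partition=1024, local_num_experts=8
)
moe_lora = FusedMoELoRASketch(layer, max_loras=4, rank=16)
print(moe_lora.w13_lora_a.shape)  # torch.Size([4, 8, 16, 4096])
```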
Contributor: Better to detect if we're doing LoRA and write that in one if branch, with the normal logic in the other. Easier to read than the current fallback chain.
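One possible reading of this suggestion, applied to the embedding-dimension code above: branch once on whether the base layer is a compressed-tensors style layer (detected here via `weight_shape`, purely as an illustration) instead of chaining per-dimension `hasattr` fallbacks. A sketch with hypothetical names, not the PR's implementation:

```python
def resolve_embedding_dims(module, lora_extra_vocab_size: int) -> tuple[int, int, int]:
    """Return (input_dim, output_dim, embeddings_tensor_dim) for a dummy LoRA."""
    base = module.base_layer
    if hasattr(base, "weight_shape"):
        # Quantized path: the packed weight no longer carries the logical
        # shape, so read [vocab_size, embedding_dim] from weight_shape.
        vocab_size, embedding_dim = (int(d) for d in base.weight_shape)
        return vocab_size, embedding_dim, embedding_dim
    # Normal path: the usual embedding attributes exist on the base layer.
    input_dim = base.org_vocab_size + lora_extra_vocab_size
    return input_dim, base.embedding_dim, base.embedding_dim
```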