Skip to content
49 changes: 29 additions & 20 deletions vllm/model_executor/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant
from .interfaces import SupportsLoRA, SupportsQuant
from .utils import maybe_prefix

logger = init_logger(__name__)
Expand Down Expand Up @@ -108,23 +108,6 @@ def replace_linear_class(
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)

lora_linear_cls = {
ColumnParallelLinear: {
True: ColumnParallelLinearWithShardedLoRA, # fully sharded
False: ColumnParallelLinearWithLoRA # not fully sharded
},
RowParallelLinear: {
True: RowParallelLinearWithShardedLoRA,
False: RowParallelLinearWithLoRA
},
# ReplicatedLinear doesn't support fully sharded LoRA yet,
# so we use the same class for both cases.
ReplicatedLinear: {
True: ReplicatedLinearWithLoRA,
False: ReplicatedLinearWithLoRA
}
}

class HFCompatibleLinear(vllm_linear_cls):
"""
Wrapper class that removes `output_bias` from returned output.
Expand All @@ -144,7 +127,33 @@ def get_lora_class(cls, fully_sharded: bool = False):
that supports fully sharded LoRA. Defaults to False.

"""
return lora_linear_cls[vllm_linear_cls][fully_sharded]

lora_linear_cls = {
ColumnParallelLinear: {
True: ColumnParallelLinearWithShardedLoRA, # fully sharded
False: ColumnParallelLinearWithLoRA # not fully sharded
},
RowParallelLinear: {
True: RowParallelLinearWithShardedLoRA,
False: RowParallelLinearWithLoRA
},
# ReplicatedLinear doesn't support fully sharded LoRA yet,
# so we use the same class for both cases.
ReplicatedLinear: {
True: ReplicatedLinearWithLoRA,
False: ReplicatedLinearWithLoRA
}
}

lora_cls = lora_linear_cls[vllm_linear_cls][fully_sharded]

# LoRA-enabled subclass built on `lora_cls`, the class looked up from
# `lora_linear_cls` for the current linear class and `fully_sharded` flag.
# It only overrides `forward` to adapt the input/output around the parent.
class HFCompatibleLinearWithLoRA(lora_cls):

# Squeezes the input, delegates to the parent LoRA layer, and returns only
# the first element of the parent's result.
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Drop dim 0 when it has size 1; `squeeze(0)` leaves the tensor unchanged
# otherwise. NOTE(review): presumably this removes a leading batch dim
# added on the Transformers side -- confirm against the caller.
input = input.squeeze(0)
# Index [0] keeps only the first item of what the parent returns.
# NOTE(review): presumably the parent yields (output, output_bias) and
# this discards the bias, mirroring HFCompatibleLinear -- verify.
return super().forward(input)[0]

return HFCompatibleLinearWithLoRA

return HFCompatibleLinear(
input_size=linear.in_features,
Expand All @@ -154,7 +163,7 @@ def get_lora_class(cls, fully_sharded: bool = False):
)


class TransformersModel(nn.Module, SupportsQuant):
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
Expand Down
Loading