From 994e4faddde943861ea6106f889d9785d8c6520f Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 6 Aug 2024 19:11:05 +0000 Subject: [PATCH] add emtpy_cache() after each padding --- vllm/model_executor/models/mixtral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c34077fa2bfa..ee9db7048f1f 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -187,9 +187,11 @@ def process_weights_after_loading(self): self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data, (0, 128), "constant", 0), requires_grad=False) + torch.cuda.empty_cache() self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data, (0, 128), "constant", 0), requires_grad=False) + torch.cuda.empty_cache() return # If checkpoint is fp16, quantize here.