@@ -78,10 +78,6 @@ def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
7878 ...], scale : float , ** kwargs ):
7979 """
8080 Performs GEMM for multiple slices of lora_a.
81- When `is_prefill is` true, it indicates that it is currently the
82- prefill stage, and the `_shrink_prefill` function should be called.
83- Otherwise, it is the decode stage, and the _shrink_decode function
84- should be called.
8581
8682 Semantics:
8783 for i in range(len(lora_a_stacked)):
@@ -129,7 +125,7 @@ def add_expand(self,
129125 lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
130126 bias's weight
131127 output_slices (Tuple[int, ...]): Every slice's size
132- add_inputs (bool): Defaults to True.
128+ add_inputs (bool): Defaults to True.
133129 """
134130 y_org = y
135131 y = y .view (- 1 , y .shape [- 1 ])
@@ -226,7 +222,7 @@ def add_lora_linear(self,
226222
227223 if buffer is None :
228224 r = lora_b_stacked [0 ].size (- 1 )
229- # We set the buffer to be float32 by default , refer to:
225+ # We set the buffer to be float32 by default, refer to:
230226 # https://github.com/triton-lang/triton/issues/1387
231227 buffer = torch .zeros ( # type: ignore
232228 (len (output_slices ), x .size (0 ), r ),
@@ -268,16 +264,16 @@ def add_lora_logits(self,
268264 y (torch.Tensor): Output tensor.
269265 x (torch.Tensor): Input tensor.
270266 lora_a_stacked (torch.Tensor): lora_a's weights.
271- lora_b_stacked (torch.Tensor):lora_b's weights.
267+ lora_b_stacked (torch.Tensor): lora_b's weights.
272268 scale (float): Scaling factor.
273- buffer (Optional[torch.Tensor]):Default to None.
269+ buffer (Optional[torch.Tensor]): Default to None.
274270 """
275271 y_org = y
276272 y = y .view (- 1 , y .shape [- 1 ])
277273 x = x .view (- 1 , x .shape [- 1 ])
278274 r = lora_b_stacked .size (- 1 )
279275 if buffer is None :
280- # We set the buffer to be float32 by default , refer to:
276+ # We set the buffer to be float32 by default, refer to:
281277 # https://github.com/triton-lang/triton/issues/1387
282278 buffer = torch .zeros ((x .size (0 ), r ),
283279 dtype = torch .float32 ,
0 commit comments