11from abc import ABC , abstractmethod
2- from typing import Any , Dict , List , Optional
2+ from typing import List , Optional
33
44import torch
55import torch .nn .functional as F
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
2828 """Base class for different (maybe quantized) linear methods."""
2929
3030 @abstractmethod
31- def create_weights (self , input_size_per_partition : int ,
31+ def create_weights (self , layer : torch .nn .Module ,
32+ input_size_per_partition : int ,
3233 output_size_per_partition : int , input_size : int ,
33- output_size : int ,
34- params_dtype : torch .dtype ) -> Dict [str , Any ]:
35- """Create weights for a linear layer."""
34+ output_size : int , params_dtype : torch .dtype ,
35+ ** extra_weight_attrs ):
36+ """Create weights for a linear layer.
37+
38+ The weights will be set as attributes of the layer."""
3639 raise NotImplementedError
3740
3841 @abstractmethod
3942 def apply_weights (self ,
40- weights : Dict [ str , torch .Tensor ] ,
43+ layer : torch .nn . Module ,
4144 x : torch .Tensor ,
4245 bias : Optional [torch .Tensor ] = None ) -> torch .Tensor :
43- """Apply the weights to the input tensor."""
46+ """Apply the weights in layer to the input tensor.
47+
48+ Expects create_weights to have been called before on the layer."""
4449 raise NotImplementedError
4550
4651
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
5560 def __init__ (self , separate_bias_add : bool = False ):
5661 self .separate_bias_add = separate_bias_add
5762
58- def create_weights (self , input_size_per_partition : int ,
63+ def create_weights (self , layer : torch .nn .Module ,
64+ input_size_per_partition : int ,
5965 output_size_per_partition : int , input_size : int ,
60- output_size : int ,
61- params_dtype : torch . dtype ) -> Dict [ str , Any ] :
66+ output_size : int , params_dtype : torch . dtype ,
67+ ** extra_weight_attrs ) :
6268 weight = Parameter (torch .empty (output_size_per_partition ,
6369 input_size_per_partition ,
6470 dtype = params_dtype ),
6571 requires_grad = False )
6672 set_weight_attrs (weight , {"input_dim" : 1 , "output_dim" : 0 })
67- return {"weight" : weight }
73+ layer .register_parameter ("weight" , weight )
74+ set_weight_attrs (weight , extra_weight_attrs )
6875
6976 def apply_weights (self ,
70- weights : Dict [ str , torch .Tensor ] ,
77+ layer : torch .nn . Module ,
7178 x : torch .Tensor ,
7279 bias : Optional [torch .Tensor ] = None ) -> torch .Tensor :
73- weight = weights [ " weight" ]
80+ weight = layer . weight
7481 if self .separate_bias_add :
7582 if bias is not None :
7683 return F .linear (x , weight ) + bias
@@ -111,12 +118,9 @@ def __init__(
111118 if linear_method is None :
112119 linear_method = UnquantizedLinearMethod ()
113120 self .linear_method = linear_method
114- self .linear_weights = self .linear_method .create_weights (
115- self .input_size , self .output_size , self .input_size ,
116- self .output_size , self .params_dtype )
117- for name , weight in self .linear_weights .items ():
118- if isinstance (weight , torch .Tensor ):
119- self .register_parameter (name , weight )
121+ self .linear_method .create_weights (self , self .input_size ,
122+ self .output_size , self .input_size ,
123+ self .output_size , self .params_dtype )
120124 if bias :
121125 self .bias = Parameter (
122126 torch .empty (self .output_size , dtype = self .params_dtype ))
@@ -126,7 +130,7 @@ def __init__(
126130
127131 def forward (self , x : torch .Tensor ) -> torch .Tensor :
128132 bias = self .bias if not self .skip_bias_add else None
129- output = self .linear_method .apply_weights (self . linear_weights , x , bias )
133+ output = self .linear_method .apply_weights (self , x , bias )
130134 output_bias = self .bias if self .skip_bias_add else None
131135 return output , output_bias
132136
@@ -177,13 +181,13 @@ def __init__(
177181 if linear_method is None :
178182 linear_method = UnquantizedLinearMethod ()
179183 self .linear_method = linear_method
180- self .linear_weights = self . linear_method .create_weights (
181- self . input_size , self . output_size_per_partition , self .input_size ,
182- self . output_size , self .params_dtype )
183- for name , weight in self .linear_weights . items ():
184- if isinstance ( weight , torch . Tensor ):
185- self .register_parameter ( name , weight )
186- set_weight_attrs ( weight , { "weight_loader" : self .weight_loader } )
184+ self .linear_method .create_weights (self ,
185+ self .input_size ,
186+ self .output_size_per_partition ,
187+ self .input_size ,
188+ self . output_size ,
189+ self .params_dtype ,
190+ weight_loader = self .weight_loader )
187191 if bias :
188192 self .bias = Parameter (
189193 torch .empty (self .output_size_per_partition ,
@@ -211,8 +215,7 @@ def forward(self, input_):
211215 bias = self .bias if not self .skip_bias_add else None
212216
213217 # Matrix multiply.
214- output_parallel = self .linear_method .apply_weights (
215- self .linear_weights , input_ , bias )
218+ output_parallel = self .linear_method .apply_weights (self , input_ , bias )
216219 if self .gather_output :
217220 # All-gather across the partitions.
218221 output = tensor_model_parallel_all_gather (output_parallel )
@@ -523,13 +526,13 @@ def __init__(
523526 if linear_method is None :
524527 linear_method = UnquantizedLinearMethod ()
525528 self .linear_method = linear_method
526- self .linear_weights = self . linear_method .create_weights (
527- self . input_size_per_partition , self . output_size , self .input_size ,
528- self .output_size , self . params_dtype )
529- for name , weight in self .linear_weights . items ():
530- if isinstance ( weight , torch . Tensor ):
531- self .register_parameter ( name , weight )
532- set_weight_attrs ( weight , { "weight_loader" : self .weight_loader } )
529+ self .linear_method .create_weights (self ,
530+ self .input_size_per_partition ,
531+ self .output_size ,
532+ self .input_size ,
533+ self . output_size ,
534+ self .params_dtype ,
535+ weight_loader = self .weight_loader )
533536
534537 if not reduce_results and (bias and not skip_bias_add ):
535538 raise ValueError ("When not reduce the results, adding bias to the "
@@ -569,7 +572,7 @@ def forward(self, input_):
569572
570573 # Matrix multiply.
571574 output_parallel = self .linear_method .apply_weights (
572- self . linear_weights , input_parallel )
575+ self , input_parallel )
573576 if self .reduce_results and self .tp_size > 1 :
574577 output_ = tensor_model_parallel_all_reduce (output_parallel )
575578 else :
0 commit comments