 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -48,21 +47,29 @@ def _apply_rotary_emb(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> torch.Tensor:
     """
     Args:
         x: [num_tokens, num_heads, head_size]
         cos: [num_tokens, head_size // 2]
         sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
     """
-    orig_dtype = x.dtype
-    x = x.float()
-    x1, x2 = torch.chunk(x, 2, dim=-1)
-    cos = cos.unsqueeze(-2)
-    sin = sin.unsqueeze(-2)
+    cos = cos.unsqueeze(-2).to(x.dtype)
+    sin = sin.unsqueeze(-2).to(x.dtype)
+    if is_neox_style:
+        x1, x2 = torch.chunk(x, 2, dim=-1)
+    else:
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
     o1 = x1 * cos - x2 * sin
     o2 = x2 * cos + x1 * sin
-    return torch.cat((o1, o2), dim=-1).to(orig_dtype)
+    if is_neox_style:
+        return torch.cat((o1, o2), dim=-1)
+    else:
+        return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
 class RotaryEmbedding(CustomOp):
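To make the two branches in the rewritten `_apply_rotary_emb()` concrete, here is a small illustrative sketch (an editorial addition, not part of the diff): Neox-style pairs the first half of `head_size` with the second half, while GPT-J-style pairs the interleaved even and odd lanes.

```python
import torch

head_size = 8
x = torch.arange(head_size)  # stand-in for one head of one token

# Neox-style pairing: first half rotated against second half
x1_neox, x2_neox = torch.chunk(x, 2, dim=-1)
print(x1_neox.tolist(), x2_neox.tolist())  # [0, 1, 2, 3] [4, 5, 6, 7]

# GPT-J-style pairing: even lanes rotated against odd lanes
x1_gptj, x2_gptj = x[..., ::2], x[..., 1::2]
print(x1_gptj.tolist(), x2_gptj.tolist())  # [0, 2, 4, 6] [1, 3, 5, 7]
```

The final `torch.stack(...).flatten(-2)` in the GPT-J branch re-interleaves `o1` and `o2`, so the output lane order matches the interleaved input layout.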
@@ -87,10 +94,9 @@ def __init__(
 
         cache = self._compute_cos_sin_cache()
         cache = cache.to(dtype)
+        self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-        self.use_native2 = current_platform.is_tpu() and is_neox_style
-
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
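As a side note on the buffer pattern added above (a minimal sketch, not from this diff, using a hypothetical module name): the bare `self.cos_sin_cache: torch.Tensor` annotation gives type checkers a concrete type for the attribute that `register_buffer()` creates dynamically, while `persistent=False` keeps the cache out of the module's `state_dict`.

```python
import torch
import torch.nn as nn


class CacheHolder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self) -> None:
        super().__init__()
        cache = torch.arange(8, dtype=torch.float32)
        # Annotation only; the actual attribute is created by register_buffer().
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)


m = CacheHolder()
print(m.cos_sin_cache.shape)              # torch.Size([8])
print("cos_sin_cache" in m.state_dict())  # False: non-persistent buffers are not saved
```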
@@ -119,59 +125,7 @@ def forward_native(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """A PyTorch-native implementation equivalent to forward().
-
-        This method mimics the implementation of the custom CUDA kernel
-        used in `forward_cuda()`.
-        """
-        query = query.view(*query.shape[:-1], -1, self.head_size)
-        key = key.view(*key.shape[:-1], -1, self.head_size)
-
-        query_rot = query[..., :self.rotary_dim]
-        key_rot = key[..., :self.rotary_dim]
-        if self.rotary_dim < self.head_size:
-            query_pass = query[..., self.rotary_dim:]
-            key_pass = key[..., self.rotary_dim:]
-
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device, dtype=query.dtype)
-        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
-                                     if offsets is not None else positions]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        if self.is_neox_style:
-            # NOTE(woosuk): Here we assume that the positions tensor has the
-            # shape [batch_size, seq_len].
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
-        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
-
-        if self.rotary_dim < self.head_size:
-            query = torch.cat((query_rot, query_pass), dim=-1)
-            key = torch.cat((key_rot, key_pass), dim=-1)
-        else:
-            query = query_rot
-            key = key_rot
-        query = query.flatten(-2)
-        key = key.flatten(-2)
-        return query, key
-
-    def forward_native2(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Another PyTorch-native implementation of forward().
-
-        This method might perform better than `forward_native()` when compiled.
-        """
+        """A PyTorch-native implementation of forward()."""
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -183,14 +137,14 @@ def forward_native2(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
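For orientation, a hedged usage sketch of the consolidated `forward_native()` path (an editorial addition, not part of the diff; the constructor keyword names below are assumed from the surrounding file): positions are flattened, the cos/sin cache is indexed per token, and `_apply_rotary_emb()` is applied to the rotary slice of both query and key.

```python
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

head_size, num_heads, num_tokens = 64, 4, 16
rope = RotaryEmbedding(head_size=head_size,
                       rotary_dim=head_size,
                       max_position_embeddings=2048,
                       base=10000,
                       is_neox_style=True,
                       dtype=torch.float32)

positions = torch.arange(num_tokens)
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)

# Same call shape as forward_cuda()/forward_xpu(); offsets defaults to None.
out_q, out_k = rope.forward_native(positions, query, key)
print(out_q.shape, out_k.shape)  # torch.Size([16, 256]) torch.Size([16, 256])
```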
@@ -203,7 +157,7 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
@@ -240,17 +194,6 @@ def forward_xpu(
                                  self.cos_sin_cache, self.is_neox_style)
         return query, key
 
-    def forward_tpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        forward_fn = (self.forward_native2
-                      if self.use_native2 else self.forward_native)
-        return forward_fn(positions, query, key, offsets)
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"