
Commit 7e7ab8b

feat(rope): loosen RopeParameters typing
1 parent: 90d1b67

2 files changed: +33 additions, -21 deletions

src/transformers/modeling_rope_utils.py

Lines changed: 14 additions & 5 deletions
@@ -13,13 +13,20 @@
 # limitations under the License.

 import math
+import sys
 from functools import wraps
 from typing import Optional, TypedDict

 from .configuration_utils import PreTrainedConfig
 from .utils import is_torch_available, logging


+if sys.version_info >= (3, 11):
+    from typing import Required
+else:
+    from typing_extensions import Required
+
+
 logger = logging.get_logger(__name__)


@@ -885,14 +892,16 @@ def rope_config_validation(config: PreTrainedConfig, ignore_keys: Optional[set]
     )


-class RopeParameters(TypedDict):
+class RopeParameters(TypedDict, total=False):
     """
     Args:
         rope_theta (`float`):
             The base period of the RoPE embeddings.
-        rope_type (`str`, *optional*, defaults to "default"):
+        rope_type (`str`):
             The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
-            'llama3'], with 'default' being the original RoPE implementation.
+            'llama3'], with 'default' being the original RoPE implementation. This value will be "default" if
+            constructed by `standardize_rope_params()` from a legacy config without a `rope_parameters` or
+            `rope_scaling` field that specifies this value.
         factor (`float`, *optional*):
             Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
             most scaling types, a `factor` of x will enable the model to handle sequences of length x *
@@ -924,8 +933,8 @@ class RopeParameters(TypedDict):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
     """

-    rope_theta: float
-    rope_type: Optional[str]
+    rope_theta: Required[float]
+    rope_type: Required[str]
     factor: Optional[float]
     original_max_position_embeddings: Optional[int]
     attention_factor: Optional[float]
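
For context on the annotation above (not part of the commit): with `TypedDict(..., total=False)`, every key becomes optional for static type checkers, and `Required[...]` opts individual keys back in. Below is a minimal sketch of the resulting contract, assuming `RopeParameters` is importable from the top-level `transformers` package as the conversion script in the next file implies; at runtime a `TypedDict` is an ordinary dict and enforces none of this.

# Minimal sketch: only `rope_theta` and `rope_type` must be present; the
# scaling-specific keys may be omitted entirely instead of being set to None.
from transformers import RopeParameters

default_rope = RopeParameters(rope_type="default", rope_theta=10_000.0)
yarn_rope = RopeParameters(rope_type="yarn", rope_theta=10_000.0, factor=4.0)

# A static type checker (e.g. mypy) would flag the call below because the
# Required key `rope_theta` is missing; at runtime it still builds a plain dict.
# broken = RopeParameters(rope_type="default")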

src/transformers/models/gemma3/convert_gemma3_weights.py

Lines changed: 19 additions & 16 deletions
@@ -44,6 +44,7 @@
     Gemma3TextModel,
     GemmaTokenizerFast,
     GenerationConfig,
+    RopeParameters,
     SiglipVisionConfig,
 )
 from transformers.image_utils import PILImageResampling
@@ -142,7 +143,10 @@
             max_position_embeddings=1024,
             query_pre_attn_scalar=256,
             sliding_window=512,
-            rope_parameters=None,
+            rope_parameters={
+                "full_attention": RopeParameters(rope_type="default", rope_theta=1_000_000.0),
+                "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
+            },
             use_bidirectional_attention=True,
         ),
         vision_config=None,
@@ -159,7 +163,10 @@
             max_position_embeddings=32768,
             query_pre_attn_scalar=256,
             sliding_window=512,
-            rope_parameters=None,
+            rope_parameters={
+                "full_attention": RopeParameters(rope_type="default", rope_theta=1_000_000.0),
+                "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
+            },
         ),
         vision_config=None,
     ),
@@ -173,8 +180,10 @@
             num_key_value_heads=1,
             head_dim=256,
             sliding_window=512,
-            rope_theta=1_000_000,  # used for global RoPE only
-            rope_local_base_freq=10_000,
+            rope_parameters={
+                "full_attention": RopeParameters(rope_type="default", rope_theta=1_000_000.0),
+                "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
+            },
             attn_logit_softcapping=None,
             query_pre_attn_scalar=256,
             max_position_embeddings=32_768,
@@ -192,11 +201,9 @@
             num_key_value_heads=4,
             sliding_window=1024,
             rope_parameters={
-                "full_attention": {"rope_type": "linear", "factor": 8.0},
-                "sliding_attention": {"rope_type": "default"},
+                "full_attention": RopeParameters(rope_type="linear", rope_theta=1_000_000.0, factor=8.0),
+                "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
             },
-            rope_theta=1_000_000,
-            rope_local_base_freq=10_000,
             attn_logit_softcapping=None,
             query_pre_attn_scalar=256,
         ),
@@ -213,11 +220,9 @@
             num_key_value_heads=8,
             sliding_window=1024,
             rope_parameters={
-                "full_attention": {"rope_type": "linear", "factor": 8.0},
-                "sliding_attention": {"rope_type": "default"},
+                "full_attention": RopeParameters(rope_type="linear", rope_theta=1_000_000.0, factor=8.0),
+                "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
             },
-            rope_theta=1_000_000,
-            rope_local_base_freq=10_000,
             attn_logit_softcapping=None,
             query_pre_attn_scalar=256,
         ),
@@ -234,11 +239,9 @@
             head_dim=128,
             sliding_window=1024,
             rope_parameters={
-                "full_attention": {"rope_type": "linear", "factor": 8.0},
-                "sliding_attention": {"rope_type": "default"},
+                "full_attention": RopeParameters(rope_type="linear", rope_theta=1_000_000.0, factor=8.0),
+                "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
             },
-            rope_theta=1_000_000,
-            rope_local_base_freq=10_000,
             attn_logit_softcapping=None,
             query_pre_attn_scalar=(42 * 128 // 32),  # 1 / sqrt(hidden_size // num_attention_heads)
         ),
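
As a usage sketch (illustrative only, not part of the commit): the converted Gemma3 text configs now describe RoPE per attention type through a single `rope_parameters` mapping, replacing the separate `rope_theta` / `rope_local_base_freq` kwargs, so a downstream consumer can read both base frequencies from one place.

from transformers import RopeParameters

# Same literals as the hunks above: linear scaling for global attention,
# plain RoPE with a lower base frequency for sliding-window attention.
rope_parameters = {
    "full_attention": RopeParameters(rope_type="linear", rope_theta=1_000_000.0, factor=8.0),
    "sliding_attention": RopeParameters(rope_type="default", rope_theta=10_000.0),
}

# Hypothetical consumer loop, not from the commit: look up settings by layer type.
for layer_type, params in rope_parameters.items():
    print(layer_type, params["rope_type"], params["rope_theta"], params.get("factor"))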
