
Commit 63cd4c7

Llama Guard updates (#37872)
* Unhardcode use_chunked_attention, fix no_rope_layers
* Go back to exhaustive list of bools
* Conversion and modeling updates
* Fix rope
* Unhardcode rope
* Fix context length
* style
* Minor updates to conversion
* Use StaticCache
* Minor simplification
* DynamicCache 🤦
* Style
* Style
1 parent 34f26e2 commit 63cd4c7

File tree: 3 files changed (+72, -49 lines)

src/transformers/models/llama4/configuration_llama4.py

Lines changed: 12 additions & 3 deletions
@@ -224,8 +224,13 @@ class Llama4TextConfig(PretrainedConfig):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
             <TODO>
             <TODO>
-        no_rope_layers (`int`, *optional*): TODO
-        no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
+        no_rope_layers (`List[int]`, *optional*):
+            List with at least the same length as the number of layers in the model.
+            A `1` at an index position indicates that the corresponding layer will use RoPE,
+            while a `0` indicates that it's a NoPE layer.
+        no_rope_layer_interval (`int`, *optional*, defaults to 4):
+            If `no_rope_layers` is `None`, it will be created using a NoPE layer every
+            `no_rope_layer_interval` layers.
         attention_chunk_size (`int`, *optional*, defaults to 8192):
             <TODO>
         attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
@@ -339,11 +344,15 @@ def __init__(
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
         self.router_jitter_noise = router_jitter_noise
+
+        # Backwards compatibility
+        if no_rope_layers == []:
+            no_rope_layers = None
+
         default_no_rope_layers = [
             int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
         ]
 
-        # no_rope_layers == [] is invalid as we cannot have 0 layers
         self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
 
         self.interleave_moe_layer_step = interleave_moe_layer_step
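To make the new default concrete, here is a minimal sketch of the pattern the config builds when `no_rope_layers` is not provided (or is passed as an empty list); the layer count is illustrative:

num_hidden_layers = 8        # illustrative; a real model reads this from the config
no_rope_layer_interval = 4   # default value from the docstring above
default_no_rope_layers = [
    int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
]
print(default_no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0] -> every 4th layer is a NoPE layer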

src/transformers/models/llama4/convert_llama4_weights_to_hf.py

Lines changed: 31 additions & 23 deletions
@@ -65,6 +65,7 @@
     r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight",  # might need to be fused for efficiency?
     # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
     r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+    r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
     r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
 
     # Vision encoder mapping
@@ -166,8 +167,8 @@ def get_concat_dim(key):
     return 0
 
 
-def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3):
-    hidden_dim = 4 * int(2 * hidden_dim / 3)
+def compute_intermediate_size(hidden_dim, ffn_exp=4, multiple_of=1024, ffn_dim_multiplier=1.2):
+    hidden_dim = ffn_exp * int(2 * hidden_dim / 3)
     hidden_dim = int(ffn_dim_multiplier * hidden_dim)
     hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
     return hidden_dim
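As a quick sanity check of the updated helper, plugging an illustrative hidden size of 5120 into the new defaults reproduces the 16384 value that was previously hardcoded as intermediate_size_mlp further down (real conversions pass ffn_exp, multiple_of and ffn_dim_multiplier from params.json):

def compute_intermediate_size(hidden_dim, ffn_exp=4, multiple_of=1024, ffn_dim_multiplier=1.2):
    hidden_dim = ffn_exp * int(2 * hidden_dim / 3)      # 4 * int(2 * 5120 / 3) = 13652
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)   # int(1.2 * 13652) = 16382
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # rounds up to 16384
    return hidden_dim

print(compute_intermediate_size(5120))  # 16384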
@@ -203,6 +204,8 @@ def max_context_length(model_path, instruct=False):
     with open(os.path.join(model_path, "params.json"), "r") as f:
         params = json.load(f)
     params = params.get("model", params)
+    if params.get("moe_args") is None:
+        return 8192
     num_experts = params["moe_args"]["num_experts"]
     return 10485760 if num_experts == 16 else 1048576
 
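For orientation, a sketch of the resulting context-length decision, using a hypothetical params.json for a dense checkpoint (all keys other than moe_args are illustrative):

params = {"dim": 5120, "n_layers": 48}  # hypothetical dense params.json contents, no "moe_args"
if params.get("moe_args") is None:
    max_position_embeddings = 8192      # dense checkpoints now fall back to an 8192 context
elif params["moe_args"]["num_experts"] == 16:
    max_position_embeddings = 10485760
else:
    max_position_embeddings = 1048576
print(max_position_embeddings)  # 8192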
@@ -242,24 +245,40 @@ def write_model(
     # some constants from original code
     rope_scaling = {
         "rope_type": "llama3",
-        "factor": 8.0,
+        "factor": params.get("rope_scaling_factor", 8.0),
         "low_freq_factor": 1.0,
-        "high_freq_factor": 4.0,
+        "high_freq_factor": params.get("rope_high_freq_factor", 4.0),
         "original_max_position_embeddings": 8192,
     }
     config_kwargs.update({"rope_scaling": rope_scaling})
 
+    if attention_chunk_size is None:
+        config_kwargs.update({"cache_implementation": "static"})
+
     # compute additional params for weight conversion
     num_heads_per_shard = num_heads // num_shards
     dim_per_head = dim // num_heads
-    # intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
+    intermediate_size_mlp = compute_intermediate_size(
+        dim,
+        ffn_exp=params["ffn_exp"],
+        multiple_of=params["multiple_of"],
+        ffn_dim_multiplier=params["ffn_dim_multiplier"],
+    )
 
     num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
 
-    num_experts = params["moe_args"]["num_experts"]
-    interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+    if hasattr(params, "moe_args"):
+        num_experts = params["moe_args"]["num_experts"]
+        interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+    else:
+        # Dense model (possibly Llama Guard) - disable all moe layers
+        num_experts = 0
+        interleave_moe_layer_step = 0
+        config_kwargs.update({"moe_layers": []})
 
+    # Ensure all layers are rope if `nope_layer_interval` is None
     no_rope_layer_interval = params["nope_layer_interval"]
+    no_rope_layer_interval = num_heads * 2 if no_rope_layer_interval is None else no_rope_layer_interval
 
     bos_token_id = 200000
     eos_token_id = [200001, 200007, 200008] if instruct else 200001
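For orientation, here is a minimal sketch of the dense-vs-MoE branching introduced above, written against a plain params dict; the helper name and the dict-based check are illustrative, not the exact conversion code:

def moe_settings(params, config_kwargs):
    # Dense checkpoints (Llama Guard style) ship no "moe_args" block in params.json,
    # so every MoE-related knob is disabled and moe_layers is forced to an empty list.
    moe_args = params.get("moe_args")
    if moe_args is None:
        config_kwargs["moe_layers"] = []
        return 0, 0
    return moe_args["num_experts"], moe_args.get("interleave_moe_layer_step", 1)

config_kwargs = {}
num_experts, interleave_moe_layer_step = moe_settings({"dim": 5120}, config_kwargs)
print(num_experts, interleave_moe_layer_step, config_kwargs)  # 0 0 {'moe_layers': []}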
@@ -273,7 +292,7 @@ def write_model(
         rope_theta=rope_theta,
         num_hidden_layers=num_layers,
         intermediate_size=8192,
-        intermediate_size_mlp=16384,
+        intermediate_size_mlp=intermediate_size_mlp,
         max_position_embeddings=max_context_length(input_base_path, instruct),
         num_local_experts=num_experts,
         interleave_moe_layer_step=interleave_moe_layer_step,
@@ -336,7 +355,7 @@ def write_model(
     sharded_keys = []
     for _key in all_keys_raw:
         try:
-            if (loaded[0][_key] == loaded[1][_key]).all():
+            if num_shards == 1 or (loaded[0][_key] == loaded[1][_key]).all():
                 repeated_keys.append(_key)
             else:
                 sharded_keys.append(_key)
@@ -354,7 +373,7 @@ def write_model(
     for key in tqdm(all_keys, desc="Renaming and processing all keys", unit="key"):
         new_key = new_keys[key]
         print(key, new_key)
-        if not is_param_same_across_shards(new_key):
+        if num_shards > 1 and not is_param_same_across_shards(new_key):
             current_parameter = [chunk.pop(key) for chunk in loaded if not isinstance(chunk[key], io.BytesIO)]
         else:
             print(f"{key} (now {new_key}) is the same across all shards.")
@@ -565,8 +584,8 @@ def get_reserved_special_tokens(name, count, start_index=0):
     "<|python_end|>",
     "<|finetune_right_pad|>",
 ] + get_reserved_special_tokens(
-    "text_post_train", 61, 6
-)  # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
+    "text_post_train", 61, 8
+)  # <|text_post_train_reserved_special_token_8|>, ..., <|text_post_train_reserved_special_token_68|>
 
 # 200080, ..., 201133
 LLAMA4_VISION_SPECIAL_TOKENS = [
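A hedged re-implementation sketch of get_reserved_special_tokens, assuming it simply formats count placeholder names starting at start_index (this matches the inline comment above, not necessarily the exact helper):

def get_reserved_special_tokens(name, count, start_index=0):
    # Produces <|{name}_reserved_special_token_{start_index}|> ... up to start_index + count - 1.
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]

tokens = get_reserved_special_tokens("text_post_train", 61, 8)
print(tokens[0], tokens[-1])
# <|text_post_train_reserved_special_token_8|> <|text_post_train_reserved_special_token_68|>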
@@ -621,15 +640,6 @@ def __init__(
             **kwargs,
         )
 
-        # to check
-        # import tiktoken
-        # model = tiktoken.Encoding(
-        #     name=Path(model_path).name,
-        #     pat_str=self.O200K_PATTERN,
-        #     mergeable_ranks=mergeable_ranks,
-        #     special_tokens=self.special_tokens,
-        # )
-
         instruct = chat_template is not None
         self.update_post_processor(self.converted_tokenizer)
         # finer special_tokens_map.json
@@ -687,12 +697,10 @@ def write_tokenizer(args):
     parser.add_argument(
         "--input_dir",
         type=str,
-        default="/fsx/arthur/Llama-4-17B-Omni-Instruct-Original",
         help="Location of the local folder copied from the Hub.",
     )
     parser.add_argument(
         "--output_dir",
-        default="llama4_hf_vision",
         type=str,
         help="Location to write HF model and tokenizer",
     )

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 29 additions & 23 deletions
@@ -20,12 +20,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 
 from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridChunkedCache
+from ...cache_utils import Cache, DynamicCache, HybridChunkedCache
 from ...generation import GenerationMixin
 from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
@@ -287,7 +286,7 @@ def __init__(self, config: Llama4TextConfig, layer_idx):
         self.attn_temperature_tuning = config.attn_temperature_tuning
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
-        self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers
+        self.use_rope = config.no_rope_layers[layer_idx]
         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
         )
@@ -374,7 +373,7 @@ def __init__(self, config, layer_idx):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Llama4TextAttention(config, layer_idx)
-        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
+        self.use_chunked_attention = config.attention_chunk_size is not None and bool(config.no_rope_layers[layer_idx])
         self.is_moe_layer = layer_idx in config.moe_layers
         if self.is_moe_layer:  # the 128E model interleaves dense / sparse
             self.feed_forward = Llama4TextMoe(config)
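To illustrate how the per-layer flags now follow the config instead of a hardcoded modulo, a small sketch with illustrative values (a real model reads no_rope_layers and attention_chunk_size from Llama4TextConfig):

no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # 1 = RoPE (and chunked attention), 0 = NoPE layer
attention_chunk_size = 8192                # None for checkpoints without chunked attention

for layer_idx in range(len(no_rope_layers)):
    use_rope = no_rope_layers[layer_idx]
    use_chunked_attention = attention_chunk_size is not None and bool(no_rope_layers[layer_idx])
    print(layer_idx, bool(use_rope), use_chunked_attention)
# With attention_chunk_size=None every layer falls back to full (non-chunked) attention.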
@@ -643,7 +642,10 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
 
         if use_cache and past_key_values is None:
-            past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
+            if self.config.get_text_config().get("attention_chunk_size") is not None:
+                past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
+            else:
+                past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
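The cache selection above boils down to one check on the text config; this is a sketch of just that decision (constructor arguments are intentionally omitted, since only the class choice is the point):

from transformers.cache_utils import DynamicCache, HybridChunkedCache

def select_cache_class(text_config):
    # Chunked-attention checkpoints need the hybrid chunked cache; configs without an
    # attention_chunk_size (such as converted dense checkpoints) use a regular dynamic cache.
    if getattr(text_config, "attention_chunk_size", None) is not None:
        return HybridChunkedCache
    return DynamicCache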
@@ -740,6 +742,7 @@ def _update_causal_mask(
         sequence_length = input_tensor.shape[1]
         cache_position = cache_position.to(self.device)
         attention_chunk_size = self.config.attention_chunk_size
+        using_chunked_attention = attention_chunk_size is not None
 
         first_cache_position = cache_position[0]
 
@@ -748,26 +751,28 @@ def _update_causal_mask(
         else:
             full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
 
-        cond1 = first_cache_position >= attention_chunk_size
-        cond2 = (first_cache_position < attention_chunk_size) & (
-            first_cache_position + sequence_length > attention_chunk_size
-        )
-        key_length = (
-            torch.where(
-                cond1,
-                attention_chunk_size + sequence_length - 1,
-                torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
+        if using_chunked_attention:
+            cond1 = first_cache_position >= attention_chunk_size
+            cond2 = (first_cache_position < attention_chunk_size) & (
+                first_cache_position + sequence_length > attention_chunk_size
+            )
+            key_length = (
+                torch.where(
+                    cond1,
+                    attention_chunk_size + sequence_length - 1,
+                    torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
+                )
+                if use_cache
+                else full_cache_length
             )
-            if use_cache
-            else full_cache_length
-        )
 
         if self.config._attn_implementation == "flex_attention":
             if isinstance(attention_mask, torch.Tensor):
-                offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
-                chunked_attention_mask = make_flex_block_causal_mask(
-                    attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
-                )
+                if using_chunked_attention:
+                    offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
+                    chunked_attention_mask = make_flex_block_causal_mask(
+                        attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets
+                    )
                 attention_mask = make_flex_block_causal_mask(
                     attention_mask,
                     query_length=sequence_length,
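To make the key-length arithmetic above concrete, a worked example with illustrative numbers (single-token decode well past the first chunk), written with plain Python ints rather than tensors:

attention_chunk_size = 8192
sequence_length = 1           # decoding one new token
first_cache_position = 10000  # already past the first chunk

cond1 = first_cache_position >= attention_chunk_size
cond2 = first_cache_position < attention_chunk_size and first_cache_position + sequence_length > attention_chunk_size
if cond1:
    key_length = attention_chunk_size + sequence_length - 1
elif cond2:
    key_length = first_cache_position + sequence_length
else:
    key_length = attention_chunk_size
print(key_length)  # 8192: the key window is capped at one attention chunk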
@@ -780,15 +785,16 @@ def _update_causal_mask(
 
         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
         dtype, device = input_tensor.dtype, input_tensor.device
+        target_length = max(full_cache_length, attention_chunk_size) if using_chunked_attention else full_cache_length
         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
             sequence_length=sequence_length,
-            target_length=max(full_cache_length, attention_chunk_size),
+            target_length=target_length,
             dtype=dtype,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
-        if full_cache_length > self.config.attention_chunk_size:
+        if using_chunked_attention and full_cache_length > attention_chunk_size:
             start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
             end_idx = start_idx + key_length
             chunked_attention_mask = self.create_chunked_attention_mask(
