
Commit 7d14788

Gemma3n audio model (#18)
* testing utilities for numerics comparisons
* Implement CumulativeGroupNorm and add to SubSampleConvProjection and SSCPConvBlock
* Add audio version of forward script based on RyanMullins' implementation
* Updating to match encoder tests. WIP: config question needs resolving
* Updates to audio classes to enable end-to-end running
* Removing vestigial classes, cleaning up print statements
* Adding SiLU / Swish to audio conformer feed forward block
* Shifted Gemma3p5Audio naming prefix to Gemma3NanoAudio
* Adding outputs to audio test
* Fixes to padding in SSCP and 1D convolution, align RMS Norm with wider model
* Update forward test to load from local weights
* Update conversion to process / output audio layers
* Update __all__ to export audio encoder
* AutoModel registration for Gemma 3n Audio
* Use AutoModel for ConditionalGeneration.audio_tower
* Fixing input_proj_linear transpose
* Fixing Gemma3NanoAudioConformerAttention.post conversion
* Fixing Gemma3NanoAudioSSCPConvBlock.conv weights conversion
* Correcting indentation issue on Gemma3p5RMSNorm

---------

Co-authored-by: Ryan Mullins <[email protected]>
1 parent eda9d33 commit 7d14788
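
One item in the message above worth unpacking is the SiLU / Swish activation added to the conformer feed-forward block. Below is a minimal sketch of that pattern, reusing layer names that appear in the conversion code later in this diff (pre_layer_norm, ffw_layer_1, ffw_layer_2); the sizes, the half-step residual, and the use of nn.LayerNorm as a stand-in for the model's RMSNorm are illustrative assumptions, not the modeling code from this commit.

import torch
import torch.nn as nn

class ConformerFeedForward(nn.Module):
    def __init__(self, hidden_size: int = 1536, expansion: int = 4):
        super().__init__()
        self.pre_layer_norm = nn.LayerNorm(hidden_size)  # stand-in; the commit aligns the real norm with the model's RMSNorm
        self.ffw_layer_1 = nn.Linear(hidden_size, hidden_size * expansion, bias=False)
        self.act = nn.SiLU()  # SiLU == Swish, the activation named in the commit message
        self.ffw_layer_2 = nn.Linear(hidden_size * expansion, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conformer-style half-step residual around the feed-forward block (assumed here).
        return x + 0.5 * self.ffw_layer_2(self.act(self.ffw_layer_1(self.pre_layer_norm(x))))

# e.g. ConformerFeedForward()(torch.randn(1, 80, 1536)).shape -> torch.Size([1, 80, 1536])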

8 files changed (+1887, -652 lines)

gemma3n_audio_forward_test.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import numpy as np
+import torch
+import transformers
+from transformers import (
+    GemmaTokenizer,
+    Gemma3p5ForCausalLM,
+    model_addition_debugger_context,
+    Gemma3NanoAudioConfig,
+    Gemma3NanoAudioEncoder,
+)
+
+from transformers.models.gemma3p5.modeling_gemma3p5 import Gemma3NanoAudioEncoder
+
+model_id = "gg-hf-gm/gemma-3p5-audio-encoder"
+# model = Gemma3NanoAudioEncoder.from_pretrained(model_id, input_feat_size=128)
+model = Gemma3NanoAudioEncoder.from_pretrained("/usr/local/google/home/philculliton/gemma3p5/checkpoints/4b_it_safetensors/")
+audio_config = model.config
+
+print(audio_config)
+
+batch_size = 1
+seq_len = 80  # Example input sequence length (make it odd to test padding)
+pad_len = 40
+
+print("audio_config.input_feat_size", audio_config.input_feat_size)
+
+rng = np.random.default_rng(seed=42)
+audio_mel = rng.normal(size=(batch_size, seq_len, audio_config.input_feat_size)).astype(np.float32)
+print("audio_mel", audio_mel.shape)
+audio_mel_mask_np = np.zeros((batch_size, seq_len), dtype=bool)
+if seq_len >= pad_len:  # Ensure pad_len is not out of bounds
+    audio_mel_mask_np[:, -pad_len:] = True  # Pad the end
+
+with model_addition_debugger_context(
+    model=model,
+    debug_path="/usr/local/google/home/philculliton/nano3/gemma3n_audio_encoder_debug",
+    do_prune_layers=False,
+    use_repr=False,
+):
+    print(audio_mel, audio_mel_mask_np)
+
+    outputs = model.forward(torch.from_numpy(audio_mel), torch.from_numpy(audio_mel_mask_np))
+
+print(outputs)
+print("Sum: ", np.sum(outputs[0].numpy()))

gemma3n_forward_test.py

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@
     model_addition_debugger_context
 )
 
-from transformers.models.gemma3p5.modeling_gemma3p5 import Gemma3p5AudioEncoder
-
 model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g251_safetensors"
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -138,6 +138,7 @@
         ("gemma3", "Gemma3Config"),
         ("gemma3p5", "Gemma3p5Config"),
         ("gemma3_text", "Gemma3TextConfig"),
+        ("gemma3p5_audio", "Gemma3NanoAudioConfig"),
         ("gemma3p5_text", "Gemma3p5TextConfig"),
         ("gemma3p5_vision", "Gemma3p5VisionConfig"),
         ("git", "GitConfig"),
@@ -511,6 +512,7 @@
         ("gemma3", "Gemma3ForConditionalGeneration"),
         ("gemma3p5", "Gemma3p5ForConditionalGeneration"),
         ("gemma3_text", "Gemma3ForCausalLM"),
+        ("gemma3p5_audio", "Gemma3NanoAudioEncoder"),
         ("gemma3p5_text", "Gemma3p5ForCausalLM"),
         ("gemma3p5_vision", "TimmWrapperModel"),
         ("git", "GIT"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@
         ("gemma2", "Gemma2Model"),
         ("gemma3", "Gemma3Model"),
         ("gemma3_text", "Gemma3TextModel"),
+        ("gemma3p5_audio", "Gemma3NanoAudioEncoder"),
         ("gemma3p5_text", "Gemma3p5TextModel"),
         ("gemma3p5_vision", "TimmWrapperModel"),
         ("git", "GitModel"),

src/transformers/models/gemma3p5/configuration_gemma3p5.py

Lines changed: 9 additions & 9 deletions
@@ -271,12 +271,12 @@ def __init__(
         self.activation_sparsity_pattern = activation_sparsity_pattern
 
 
-class Gemma3p5AudioConfig(PretrainedConfig):
+class Gemma3NanoAudioConfig(PretrainedConfig):
     model_type = "gemma3p5_audio"
 
     def __init__(
         self,
-        input_feat_size: int = 80,
+        input_feat_size: int = 128,
         hidden_size: int = 1536,
         embedding_norm_eps: float = 1e-6,
         vocab_size: int = 128,
@@ -458,7 +458,7 @@ class Gemma3p5Config(PretrainedConfig):
     >>> vision_config = AutoConfig.from_pretrained(checkpoint)
 
     >>> # Initializing a Gemma3p5 Audio config
-    >>> audio_config = Gemma3p5AudioConfig()
+    >>> audio_config = Gemma3NanoAudioConfig()
 
     >>> # Initializing a Gemma3p5 Text config
     >>> text_config = Gemma3p5TextConfig()
@@ -477,14 +477,14 @@ class Gemma3p5Config(PretrainedConfig):
     sub_configs = {
         "text_config": Gemma3p5TextConfig,
         "vision_config": Gemma3p5VisionConfig,
-        "audio_config": Gemma3p5AudioConfig,
+        "audio_config": Gemma3NanoAudioConfig,
     }
 
     def __init__(
         self,
         text_config: Optional[Union[Gemma3p5TextConfig, dict[str, Any]]] = None,
         vision_config: Optional[Union[Gemma3p5VisionConfig, dict[str, Any]]] = None,
-        audio_config: Optional[Union[Gemma3p5AudioConfig, dict[str, Any]]] = None,
+        audio_config: Optional[Union[Gemma3NanoAudioConfig, dict[str, Any]]] = None,
         audio_soft_tokens_per_image: int = 256,
         vision_soft_tokens_per_image: int = 256,
         boi_token_id: int = 255_999,
@@ -511,10 +511,10 @@ def __init__(
             logger.info("vision_config is None. Using default Gemma3p5VisionConfig.")
 
         if isinstance(audio_config, dict):
-            audio_config = Gemma3p5AudioConfig(**audio_config)
+            audio_config = Gemma3NanoAudioConfig(**audio_config)
         elif audio_config is None:
-            audio_config = Gemma3p5AudioConfig()
-            logger.info("audio_config is None. Using default Gemma3p5AudioConfig.")
+            audio_config = Gemma3NanoAudioConfig()
+            logger.info("audio_config is None. Using default Gemma3NanoAudioConfig.")
 
         self.text_config = text_config
         self.vision_config = vision_config
@@ -531,4 +531,4 @@ def __init__(
         self.initializer_range = initializer_range
 
 
-__all__ = ["Gemma3p5Config", "Gemma3p5AudioConfig", "Gemma3p5TextConfig", "Gemma3p5VisionConfig"]
+__all__ = ["Gemma3p5Config", "Gemma3NanoAudioConfig", "Gemma3p5TextConfig", "Gemma3p5VisionConfig"]

src/transformers/models/gemma3p5/convert_gemma3p5_weights.py

Lines changed: 20 additions & 23 deletions
@@ -39,12 +39,12 @@
 from transformers import (
     AutoConfig,
     Gemma3p5Config,
-    # Gemma3p5AudioEncoder,
     Gemma3p5ForCausalLM,
     Gemma3p5ForConditionalGeneration,
     Gemma3ImageProcessor,
     Gemma3Processor,
-    Gemma3p5AudioConfig,
+    Gemma3NanoAudioConfig,
+    Gemma3NanoAudioEncoder,
     Gemma3p5TextConfig,
     Gemma3p5VisionConfig,
     GemmaTokenizerFast,
@@ -155,12 +155,12 @@
             activation_sparsity_pattern=(0.95,)*10 + (0.0,)*20,
         ),
         vision_config=Gemma3p5VisionConfig(),
-        audio_config=Gemma3p5AudioConfig(),
+        audio_config=Gemma3NanoAudioConfig(),
     ),
     _VARIANT_GEMMA_3_4B: Gemma3p5Config(
         text_config=Gemma3p5TextConfig(),
         vision_config=Gemma3p5VisionConfig(),
-        audio_config=Gemma3p5AudioConfig(),
+        audio_config=Gemma3NanoAudioConfig(),
     ),
 }
@@ -228,7 +228,7 @@
 
 
 def convert_audio_encoder_weights(
-    config: Gemma3p5AudioConfig,
+    config: Gemma3NanoAudioConfig,
     path: str,
     param: str,
     weights: np.ndarray,
@@ -242,7 +242,7 @@ def convert_audio_encoder_weights(
 
         for i, matrix in enumerate(weights):
             if "fflayer_end" in path:
-                base = f"audio_tower.conformer.{i}.ffw_layer_end"
+                base = f"conformer.{i}.ffw_layer_end"
 
                 if path.endswith("ffn_layer1"):
                     converted_paths.append(f"{base}.ffw_layer_1.weight")
@@ -257,7 +257,7 @@ def convert_audio_encoder_weights(
                 converted_paths.append(f"{base}.pre_layer_norm.weight")
                 converted_weights.append(matrix)
             elif "fflayer_start" in path:
-                base = f"audio_tower.conformer.{i}.ffw_layer_start"
+                base = f"conformer.{i}.ffw_layer_start"
 
                 if path.endswith("ffn_layer1"):
                     converted_paths.append(f"{base}.ffw_layer_1.weight")
@@ -272,10 +272,10 @@ def convert_audio_encoder_weights(
                 converted_paths.append(f"{base}.pre_layer_norm.weight")
                 converted_weights.append(matrix)
             elif path.endswith("final_ln"):
-                converted_paths.append(f"audio_tower.conformer.{i}.norm.weight")
+                converted_paths.append(f"conformer.{i}.norm.weight")
                 converted_weights.append(matrix)
             elif "lconv" in path:
-                base = f"audio_tower.conformer.{i}.lconv1d"
+                base = f"conformer.{i}.lconv1d"
 
                 if path.endswith("conv_norm"):
                     converted_paths.append(f"{base}.conv_norm.weight")
@@ -293,7 +293,7 @@ def convert_audio_encoder_weights(
                 converted_paths.append(f"{base}.pre_layer_norm.weight")
                 converted_weights.append(matrix)
             elif "trans_atten" in path:
-                base = f"audio_tower.conformer.{i}.attention"
+                base = f"conformer.{i}.attention"
 
                 if param == "per_dim_scale":
                     converted_paths.append(f"{base}.attn.per_dim_scale")
@@ -312,7 +312,7 @@ def convert_audio_encoder_weights(
                     converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose())
                 elif path.endswith("post"):
                     converted_paths.append(f"{base}.post.weight")
-                    converted_weights.append(matrix.transpose(1, 2, 0).reshape(config.hidden_size, config.hidden_size))
+                    converted_weights.append(matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.hidden_size))
                 elif path.endswith("post_norm"):
                     converted_paths.append(f"{base}.post_norm.weight")
                     converted_weights.append(matrix)
@@ -321,21 +321,18 @@ def convert_audio_encoder_weights(
                 converted_weights.append(matrix)
     elif path.startswith(_AUDIO_ENCODER_SSCP):
         if path.endswith("input_proj"):
-            converted_paths.append(f"audio_tower.subsample_conv_projection.input_proj_linear.weight")
+            converted_paths.append(f"subsample_conv_projection.input_proj_linear.weight")
             converted_weights.append(
-                weights.transpose(1, 2, 0).reshape(config.hidden_size, config.sscp_conv_channel_size[1] ** 2)
+                weights.transpose(2, 0, 1).reshape(config.hidden_size, config.sscp_conv_channel_size[1] ** 2)
             )
         elif "norm_" in path:
            index = int(path[-1])
-            converted_paths.extend([
-                f"audio_tower.subsample_conv_projection.conv_{index}.norm.bias",
-                f"audio_tower.subsample_conv_projection.conv_{index}.norm.weight",
-            ])
-            converted_weights.extend([np.zeros_like(weights), weights])
+            converted_paths.append(f"subsample_conv_projection.conv_{index}.norm.weight")
+            converted_weights.append(weights)
         elif "subsampling_" in path:
             index = int(path[-1])
-            converted_paths.append(f"audio_tower.subsample_conv_projection.conv_{index}.conv.weight")
-            converted_weights.append(weights.transpose())
+            converted_paths.append(f"subsample_conv_projection.conv_{index}.conv.weight")
+            converted_weights.append(weights.transpose(3, 2, 0, 1))
 
     if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
         raise ValueError(
@@ -649,7 +646,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
             update_tree(
                 "embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
             )
-        elif path.startswith(_TRANSFORMER_PARAMETER):
+        if path.startswith(_TRANSFORMER_PARAMETER):
            for path, weights in convert_transformer_weights(config.text_config, path, param, value):
                update_tree(f"language_model.{path}", weights, config.text_config.torch_dtype)
        elif _MOBILE_NET_PREFIX in path:
@@ -659,7 +656,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
            update_tree(f"vision_tower.timm_model.{path}", weights, config.vision_config.torch_dtype)
        elif path.startswith(_AUDIO_ENCODER_PARAMETER):
            for path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value):
-                update_tree(path, weights, config.audio_config.torch_dtype)
+                update_tree(f"audio_tower.{path}", weights, config.audio_config.torch_dtype)
 
     hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
 
@@ -700,7 +697,7 @@ def main(*args):
         variant,
         type(model).__name__,
     )
-    model.save_pretrained(output_path, safe_serialization=True)
+    model.save_pretrained(output_path, state_dict=state_tree, safe_serialization=True)
     logging.info(
         "Saved Gemma 3 (%s) to SafeTensors in %s using %s",
         variant,
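
The transpose fixes in this file change how multi-axis weights are flattened before loading into Linear modules; both orderings yield a tensor of the correct final shape, so only the element layout distinguishes them, and the wrong one scrambles the projection silently instead of raising a shape error. A small self-contained demonstration (the (num_heads, head_dim, hidden_size) shape for the attention `post` matrix is an assumption for illustration):

import numpy as np

num_heads, head_dim = 4, 8
hidden = num_heads * head_dim  # 32

# Stand-in for the attention `post` projection weights.
w = np.arange(num_heads * head_dim * hidden, dtype=np.float32).reshape(num_heads, head_dim, hidden)

old = w.transpose(1, 2, 0).reshape(hidden, hidden)  # ordering before this commit
new = w.transpose(2, 0, 1).reshape(hidden, hidden)  # ordering after this commit

print(old.shape == new.shape)    # True: identical shapes
print(np.array_equal(old, new))  # False: different element layouts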
