 
 import torch
 
-from transformers.models.blt_wip.modeling_blt_wip import ByteLatentTransformer
+from transformers.models.blt_wip.modeling_blt_wip import BLTModel
 from transformers.models.blt_wip.configuration_blt import BLTConfig
 from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer
 
@@ -47,7 +47,7 @@ def sample_top_p(probs, p):
 def generate(
     prompts: list[str] | None,
     *,
-    model: ByteLatentTransformer,
+    model: BLTModel,
     tokenizer: BltTokenizer,
     max_prompt_len: int = 256,
     max_gen_len: int = 256,
@@ -114,8 +114,8 @@ def generate(
 
 
 
-def main(prompt: str = "hi", model_name: str = "blt-1b"):
-    device = "mps"
+def main(prompt: str = "my name is", model_name: str = "blt-1b"):
+    device = "cuda"
 
     #HF
     blt_repo = "facebook/blt-1b"
@@ -134,24 +134,35 @@ def main(prompt: str = "hi", model_name: str = "blt-1b"):
 
     config['args']['attn_bias_type'] = 'causal'
     config['args']['attn_impl'] = 'sdpa'
-
-    model_config = BLTConfig(**config["args"])
 
+    # Create model using normal constructor instead of from_pretrained
+    from transformers.models.blt_wip.configuration_blt import BLTConfig
+    model_config = BLTConfig(**config['args'])
+
+    # Set entropy model parameters manually
     patcher_args = entropy_params["data"]["patcher_args"]
     model_config.patch_in_forward = True
-    model_config.realtime_patching = True  # Enable realtime patching
+    model_config.realtime_patching = True
     model_config.patch_size = patcher_args["patch_size"]
-    model_config.patching_mode = "entropy"  # patcher_args["patching_mode"]  # TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer
+    model_config.patching_mode = "entropy"
     model_config.patching_threshold = patcher_args["threshold"]
     model_config.patching_threshold_add = patcher_args["threshold_add"]
     model_config.max_patch_length = patcher_args["max_patch_length"]
     model_config.patching_batch_size = patcher_args["patching_batch_size"]
     model_config.patching_device = patcher_args["patching_device"]
     model_config.monotonicity = patcher_args["monotonicity"]
-    model_config.entropy_model_checkpoint_dir = entropy_dir  # original config on the hub doesn't set this
-
-
-    model = ByteLatentTransformer.from_pretrained(blt_repo, config=model_config).to(device)
+    model_config.entropy_model_checkpoint_dir = entropy_dir
+
+    # Use direct construction instead of from_pretrained to avoid meta tensor issues
+    print("Creating model...")
+    model = BLTModel(model_config).to(device)
+
+    # Load model weights manually
+    print("Loading model weights...")
+    from safetensors.torch import load_file
+    checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="model.safetensors")
+    state_dict = load_file(checkpoint_path)
+    model.load_state_dict(state_dict, strict=False)
 
     tokenizer = BltTokenizer(
         vocab_size_unit_1=model_config.vocab_size,
@@ -164,7 +175,7 @@ def main(prompt: str = "hi", model_name: str = "blt-1b"):
         prompts,
         model=model,
         tokenizer=tokenizer,
-        max_gen_len=4,
+        max_gen_len=200,
         device=device
     )
 
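The new loading path constructs BLTModel directly and loads the safetensors checkpoint with strict=False, which silently drops any mismatched keys. Below is a minimal sketch of the same pattern with an explicit mismatch check; it assumes model_config and device are built as in the diff above, the repo id and filename match the diff, and the printed key counts are only for inspection, not part of the committed script:

    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    # Assumes model_config was built from the Hub config as in the diff above.
    model = BLTModel(model_config).to(device)

    # Download the safetensors checkpoint and load it, reporting the key mismatches
    # that strict=False would otherwise hide.
    checkpoint_path = hf_hub_download(repo_id="facebook/blt-1b", filename="model.safetensors")
    state_dict = load_file(checkpoint_path)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")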