
Commit 54e1cc0

inherit from PreTrainedModel
1 parent d2156bf commit 54e1cc0

2 files changed: +127 -179 lines

src/demo_hf.py

Lines changed: 24 additions & 13 deletions
@@ -6,7 +6,7 @@
 
 import torch
 
-from transformers.models.blt_wip.modeling_blt_wip import ByteLatentTransformer
+from transformers.models.blt_wip.modeling_blt_wip import BLTModel
 from transformers.models.blt_wip.configuration_blt import BLTConfig
 from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer
 
@@ -47,7 +47,7 @@ def sample_top_p(probs, p):
 def generate(
     prompts: list[str] | None,
     *,
-    model: ByteLatentTransformer,
+    model: BLTModel,
     tokenizer: BltTokenizer,
     max_prompt_len: int = 256,
     max_gen_len: int = 256,
@@ -114,8 +114,8 @@ def generate(
 
 
 
-def main(prompt: str = "hi", model_name: str = "blt-1b"):
-    device = "mps"
+def main(prompt: str = "my name is", model_name: str = "blt-1b"):
+    device = "cuda"
 
     #HF
     blt_repo = "facebook/blt-1b"
@@ -134,24 +134,35 @@ def main(prompt: str = "hi", model_name: str = "blt-1b"):
 
     config['args']['attn_bias_type'] = 'causal'
     config['args']['attn_impl'] = 'sdpa'
-
-    model_config = BLTConfig(**config["args"])
 
+    # Create model using normal constructor instead of from_pretrained
+    from transformers.models.blt_wip.configuration_blt import BLTConfig
+    model_config = BLTConfig(**config['args'])
+
+    # Set entropy model parameters manually
     patcher_args = entropy_params["data"]["patcher_args"]
     model_config.patch_in_forward = True
-    model_config.realtime_patching = True # Enable realtime patching
+    model_config.realtime_patching = True
     model_config.patch_size = patcher_args["patch_size"]
-    model_config.patching_mode = "entropy" #patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer
+    model_config.patching_mode = "entropy"
     model_config.patching_threshold = patcher_args["threshold"]
     model_config.patching_threshold_add = patcher_args["threshold_add"]
     model_config.max_patch_length = patcher_args["max_patch_length"]
     model_config.patching_batch_size = patcher_args["patching_batch_size"]
     model_config.patching_device = patcher_args["patching_device"]
     model_config.monotonicity = patcher_args["monotonicity"]
-    model_config.entropy_model_checkpoint_dir = entropy_dir #original config on the hub don't set this
-
-
-    model = ByteLatentTransformer.from_pretrained(blt_repo, config=model_config).to(device)
+    model_config.entropy_model_checkpoint_dir = entropy_dir
+
+    # Use direct construction instead of from_pretrained to avoid meta tensor issues
+    print("Creating model...")
+    model = BLTModel(model_config).to(device)
+
+    # Load model weights manually
+    print("Loading model weights...")
+    from safetensors.torch import load_file
+    checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="model.safetensors")
+    state_dict = load_file(checkpoint_path)
+    model.load_state_dict(state_dict, strict=False)
 
     tokenizer = BltTokenizer(
         vocab_size_unit_1=model_config.vocab_size,
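
For readability, the loading path this hunk switches to is condensed below into a standalone sketch. It is not part of the commit: the helper name build_blt_model is made up, `config` and `entropy_dir` stand for the Hub config dict and entropy-model checkpoint directory that demo_hf.py prepares earlier (not shown in this diff), and most patcher fields are omitted.

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers.models.blt_wip.configuration_blt import BLTConfig
from transformers.models.blt_wip.modeling_blt_wip import BLTModel

def build_blt_model(config: dict, entropy_dir: str, device: str = "cuda",
                    blt_repo: str = "facebook/blt-1b") -> BLTModel:
    # Plain constructor instead of from_pretrained, so parameters are created as
    # real tensors and moved to `device` (the diff cites meta-tensor issues with
    # from_pretrained as the reason for this change).
    model_config = BLTConfig(**config["args"])
    model_config.patch_in_forward = True
    model_config.patching_mode = "entropy"   # patch via the entropy model
    model_config.entropy_model_checkpoint_dir = entropy_dir
    # ...remaining patcher fields (patch_size, thresholds, etc.) as in the hunk above
    model = BLTModel(model_config).to(device)

    # Fetch the checkpoint from the Hub and load it by hand; strict=False mirrors
    # the demo and tolerates keys that do not line up exactly.
    checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="model.safetensors")
    model.load_state_dict(load_file(checkpoint_path), strict=False)
    return model
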
@@ -164,7 +175,7 @@ def main(prompt: str = "hi", model_name: str = "blt-1b"):
         prompts,
         model=model,
         tokenizer=tokenizer,
-        max_gen_len=4,
+        max_gen_len=200,
         device=device
     )
 
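Finally, a hedged sketch of how the updated demo ties together after this commit. The generate() call and the new defaults are taken from the hunks above; how the outputs are consumed afterwards is not shown in the diff and is assumed here.

# Assumes `model` and `tokenizer` were built as in the diff (e.g. via the
# build_blt_model sketch above, plus BltTokenizer(vocab_size_unit_1=...)).
prompts = ["my name is"]        # new default prompt (previously "hi")
outputs = generate(
    prompts,
    model=model,                # annotated as BLTModel rather than ByteLatentTransformer
    tokenizer=tokenizer,
    max_gen_len=200,            # raised from 4
    device="cuda",              # default device switched from "mps"
)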