2 files changed: +483 -499 lines changed

transformers/models/blt_wip

@@ -60,9 +60,6 @@ def generate(
 ) -> list[list[int]]:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    assert (
-        model.patch_in_forward
-    ), "generate requires model.patch_in_forward=True"
     model.eval()
     prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
     # Truncation
@@ -81,8 +78,7 @@ def generate(

     for i, curr_pos in enumerate(range(start_pos, end_pos)):
         current_tokens = tokens[:, :curr_pos]
-        patch_lengths, _ = model.patch(current_tokens, include_next_token=True)
-        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
+        logits = model(current_tokens)[:, -1]

         if use_sampling:
             probs = torch.softmax(logits / temp, dim=-1)
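Taken together, the two hunks simplify the decoding loop: instead of the caller computing `patch_lengths` via `model.patch()` and passing them into the forward pass, the model is now expected to derive them itself inside `forward`, which is why the `patch_in_forward` assert is also dropped. A minimal sketch of the calling convention this implies; `BLTModel`, the embedding/head layers, and the one-byte-per-patch placeholder patcher are hypothetical stand-ins, not the real implementation:

import torch


class BLTModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 260, dim: int = 32):
        super().__init__()
        self.patch_in_forward = True  # patching now happens inside forward
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def patch(self, tokens: torch.Tensor, include_next_token: bool = False):
        # Placeholder patcher (assumption): one byte per patch, plus an
        # optional slot for the token about to be generated.
        n = tokens.shape[1] + int(include_next_token)
        return torch.ones(tokens.shape[0], n, dtype=torch.long), None

    def forward(self, tokens: torch.Tensor, patch_lengths=None):
        # After this change, callers may omit patch_lengths entirely;
        # the model computes them itself when patch_in_forward is set.
        if patch_lengths is None and self.patch_in_forward:
            patch_lengths, _ = self.patch(tokens, include_next_token=True)
        h = self.embed(tokens)  # patch_lengths would shape the real mixing
        return self.head(h)


model = BLTModel().eval()
tokens = torch.randint(0, 260, (1, 8))
logits = model(tokens)[:, -1]  # matches the simplified call in generate()
print(logits.shape)  # torch.Size([1, 260])

One consequence of this design is that `generate` no longer needs to know anything about patching, so the loop body works for any model whose forward accepts raw token IDs.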