Commit 0a16089

refactor LMTransformer --> BLTPatcher
1 parent 54e1cc0 commit 0a16089

File tree

2 files changed: +483, -499 lines

src/demo_hf.py

Lines changed: 1 addition & 5 deletions
@@ -60,9 +60,6 @@ def generate(
 ) -> list[list[int]]:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    assert (
-        model.patch_in_forward
-    ), "generate requires model.patch_in_forward=True"
     model.eval()
     prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
     # Truncation
@@ -81,8 +78,7 @@ def generate(
 
     for i, curr_pos in enumerate(range(start_pos, end_pos)):
         current_tokens = tokens[:, :curr_pos]
-        patch_lengths, _ = model.patch(current_tokens, include_next_token=True)
-        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
+        logits = model(current_tokens)[:, -1]
 
         if use_sampling:
             probs = torch.softmax(logits / temp, dim=-1)
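The net effect of the diff: callers no longer drive patching themselves. The explicit model.patch(...) call and the patch_lengths argument to forward() are removed, along with the patch_in_forward assertion, so patching now happens inside the model's forward pass. A minimal sketch of the old vs. new call pattern, using only names that appear in the diff above (the helper next_token_logits is hypothetical, for illustration):

    import torch

    @torch.no_grad()
    def next_token_logits(model, tokens: torch.Tensor) -> torch.Tensor:
        # Old call pattern (removed by this commit): compute patch lengths
        # explicitly, then pass them into forward():
        #   patch_lengths, _ = model.patch(tokens, include_next_token=True)
        #   logits = model(tokens, patch_lengths=patch_lengths)[:, -1]
        # New call pattern: forward() handles patching internally, so the
        # caller only passes the token prefix.
        return model(tokens)[:, -1]  # logits at the next-token position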
