
Commit fb3cd1d

Merge branch 'main' into mistral
2 parents 06cdfeb + e8fd39a commit fb3cd1d

File tree: 8 files changed, +123 -61 lines changed

benchmarks/run_mmstar.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ def main(args):
     print(f"Created temporary image directory: {cache_dir}")

     # Read data
-    dataset = load_dataset("Lin-Chen/MMStar")["validation"]
+    dataset = load_dataset("Lin-Chen/MMStar")["val"]
     questions = []
     for idx, q in enumerate(dataset):
         if idx >= args.num_questions:
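
The only functional change is the dataset split name. A minimal standalone sketch (not part of this commit) of checking which splits a Hugging Face dataset exposes before indexing into it, so a split rename like "validation" -> "val" surfaces immediately:

# Sketch: list the available splits before indexing; the dataset name mirrors
# the script above, but the snippet itself is illustrative only.
from datasets import load_dataset

ds = load_dataset("Lin-Chen/MMStar")  # returns a DatasetDict keyed by split
print(list(ds.keys()))                # e.g. ["val"], per the change above
dataset = ds["val"]
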
File renamed without changes.

scripts/train_eagle3_offline.py

Lines changed: 17 additions & 1 deletion

@@ -376,7 +376,14 @@ def main():
         epoch_acces = [[] for _ in range(eagle3_model.module.length)]
         epoch_plosses = [[] for _ in range(eagle3_model.module.length)]

-        for data in tqdm(train_dataloader, desc=f"Training Epoch {epoch}"):
+        if dist.get_rank() == 0:
+            progress_bar = tqdm(
+                train_dataloader, desc=f"Training Epoch {epoch}", leave=True
+            )
+        else:
+            progress_bar = train_dataloader
+
+        for data in progress_bar:
             batch_index += 1
             if args.profile:
                 if batch_index == args.profile_start_step:

@@ -410,6 +417,7 @@ def main():
                 hidden_states=data["hidden_state"].cuda(),  # [B, S, D]
                 target=data["target"].cuda(),  # [B, S, D*3]
             )
+            acces = torch.stack(acces).cpu().tolist()

             # calculate weighted loss
             ploss_weight = [0.8**i for i in range(len(plosses))]

@@ -444,6 +452,13 @@ def main():
                 )
                 last_time = time.time()

+            if dist.get_rank() == 0:
+                avg_loss = sum(pl.item() for pl in plosses) / len(plosses)
+                avg_acc = sum(acces) / len(acces)
+                progress_bar.set_postfix(
+                    {"loss": f"{avg_loss:.2f}", "acc": f"{avg_acc:.2f}"}
+                )
+
         # Log epoch-level training metrics
         train_epoch_logdict = {}
         for i in range(len(epoch_acces)):

@@ -479,6 +494,7 @@ def main():
                 hidden_states=data["hidden_state"].cuda(),
                 target=data["target"].cuda(),
             )
+            acces = torch.stack(acces).cpu().tolist()
            eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]
            eval_plosses = [
                eval_plosses[i] + [plosses[i].item()] for i in range(len(plosses))
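
Both training scripts now gate the tqdm bar on the process rank so that only rank 0 renders output while the other ranks iterate the raw dataloader. A minimal standalone sketch of that pattern (assumes torch.distributed is already initialized; the helper name is illustrative):

# Sketch: rank-0-only progress bar for a distributed training loop.
import torch.distributed as dist
from tqdm import tqdm

def wrap_loader(train_dataloader, epoch):
    if dist.get_rank() == 0:
        # Only rank 0 draws the bar; leave=True keeps it visible after the epoch.
        return tqdm(train_dataloader, desc=f"Training Epoch {epoch}", leave=True)
    # Other ranks iterate the plain dataloader, so calls such as
    # progress_bar.set_postfix(...) must also stay behind the rank check.
    return train_dataloader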

scripts/train_eagle3_online.py

Lines changed: 24 additions & 3 deletions

@@ -12,7 +12,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

 from specforge import (
     AutoDistributedTargetModel,

@@ -283,6 +283,10 @@ def main():
         .eval()
         .cuda()
     )
+
+    for p in target_model.parameters():
+        p.requires_grad = False
+
     print_with_rank("Initialized target model")

     # load model with resume

@@ -402,6 +406,7 @@ def main():
             draft_model=draft_model,
             processor=processor,
             length=args.ttt_length,
+            attention_backend=args.attention_backend,
         )
     else:
         eagle3_model = OnlineEagle3Model(

@@ -426,7 +431,7 @@ def main():

     # build other components
     optimizer = BF16Optimizer(
-        eagle3_model,
+        draft_model,
         lr=args.learning_rate,
         max_grad_norm=args.max_grad_norm,
         warmup_ratio=args.warmup_ratio,

@@ -469,7 +474,14 @@ def main():
         epoch_acces = [[] for _ in range(eagle3_model.module.length)]
         epoch_plosses = [[] for _ in range(eagle3_model.module.length)]

-        for data in tqdm(train_dataloader, desc=f"Training Epoch {epoch}"):
+        if dist.get_rank() == 0:
+            progress_bar = tqdm(
+                train_dataloader, desc=f"Training Epoch {epoch}", leave=True
+            )
+        else:
+            progress_bar = train_dataloader
+
+        for data in progress_bar:
             batch_index += 1
             if args.profile:
                 if batch_index == args.profile_start_step:

@@ -506,6 +518,7 @@ def main():
                 attention_mask=data["attention_mask"].cuda(),
                 loss_mask=data["loss_mask"].cuda(),
             )
+            acces = torch.stack(acces).cpu().tolist()

             # calculate weighted loss
             ploss_weight = [0.8**i for i in range(len(plosses))]

@@ -539,6 +552,13 @@ def main():
                )
                last_time = time.time()

+            if dist.get_rank() == 0:
+                avg_loss = sum(pl.item() for pl in plosses) / len(plosses)
+                avg_acc = sum(acces) / len(acces)
+                progress_bar.set_postfix(
+                    {"loss": f"{avg_loss:.2f}", "acc": f"{avg_acc:.2f}"}
+                )
+
        epoch_logdict = {}
        for i in range(len(epoch_acces)):
            acc_i = torch.tensor(epoch_acces[i]).cuda().mean()

@@ -581,6 +601,7 @@ def main():
                attention_mask=data["attention_mask"].cuda(),
                loss_mask=data["loss_mask"].cuda(),
            )
+            acces = torch.stack(acces).cpu().tolist()

            eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]
            eval_plosses = [
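
Two of these changes follow a common fine-tuning pattern: the target model is frozen so gradients only flow into the draft model, and the optimizer is built over the draft model rather than the combined wrapper, which keeps the frozen parameters out of the optimizer state. A minimal sketch with plain nn.Module stand-ins (not the SpecForge classes, and a stock optimizer in place of BF16Optimizer):

# Sketch: freeze the target (teacher) model, optimize only the draft (student).
import torch
import torch.nn as nn

target_model = nn.Linear(16, 16).eval()  # stand-in for the large target model
draft_model = nn.Linear(16, 16)          # stand-in for the trainable draft model

for p in target_model.parameters():
    p.requires_grad = False              # no gradients ever flow into the target

optimizer = torch.optim.AdamW(
    (p for p in draft_model.parameters() if p.requires_grad), lr=1e-4
)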

specforge/core/eagle3.py

Lines changed: 53 additions & 42 deletions

@@ -463,7 +463,12 @@ class QwenVLOnlineEagle3Model(Eagle3Model):
     """

     def __init__(
-        self, target_model, draft_model: Eagle3DraftModel, processor, length: int = 7
+        self,
+        target_model,
+        draft_model: Eagle3DraftModel,
+        processor,
+        length: int = 7,
+        attention_backend: str = "sdpa",
     ):
         """
         Args:

@@ -476,6 +481,7 @@ def __init__(
         self.draft_model = draft_model
         self.processor = processor
         self.length = length
+        self.attention_backend = attention_backend

     @torch.no_grad()
     def _prepare_data(

@@ -605,11 +611,20 @@ def forward(
            pixel_values: batch image pixel values, used for VLM models
            image_grid_thw: (batch, 3), image grid thw, used for VLM models
        """
-        # Step 1: prepare data with the target model
+        # Step 0: prepare data with the target model
        hidden_states, target, loss_mask, input_ids = self._prepare_data(
            input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw
        )

+        # Step 1: handle vocab size
+        target_p_padded, position_mask = _compute_target_p_padded(
+            target=target,
+            t2d=self.draft_model.t2d,
+            loss_mask=loss_mask,
+            length=self.length,
+        )
+        del target
+
        # basic info
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length

@@ -656,21 +671,28 @@ def forward(
                dtype=torch.bool,
                device=hidden_states.device,
            )
-        attention_mask = self.draft_model.prepare_decoder_attention_mask(
-            attention_mask=attention_mask,
-            hidden_states=hidden_states,
-            batch_size=batch_size,
-            seq_length=seq_length,
-            past_key_values_length=past_key_values_length,
-        )
+        if self.attention_backend == "sdpa":
+            attention_mask = self.draft_model.prepare_decoder_attention_mask(
+                attention_mask=attention_mask,
+                hidden_states=hidden_states,
+                batch_size=batch_size,
+                seq_length=seq_length,
+                past_key_values_length=past_key_values_length,
+            )

        # Step 5: run TTT
        plosses = []
        vlosses = []
        acces = []
-        cache_hidden = [[], []]
+        if self.attention_backend == "sdpa":
+            cache_hidden = [[], []]
+            past_key_values = None
+        elif self.attention_backend == "flex_attention":
+            cache_hidden = None
+            past_key_values = DynamicCache()

        for idx in range(self.length):
+            target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()
            is_last = idx == self.length - 1

            # Step 5.1: embed the input ids

@@ -685,55 +707,44 @@ def forward(
                cache_hidden=cache_hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
+                past_key_values=past_key_values,
                use_cache=True,
            )

-            # Step 5.3: handle vocab size
-            with torch.no_grad():
-                target_head = target
-                target_max_token = target_head.argmax(-1)
-                target_mask = self.draft_model.t2d[target_max_token]
-                target_mask = target_mask[..., None].int()
-                position_mask = target_mask * loss_mask
-                target_head = target_head[..., self.draft_model.t2d]
-                target_head = target_head.float()
-                target_p = nn.Softmax(dim=2)(target_head)
-                target_p = target_p.detach()
-
            # update hidden states for next step
            hidden_states = hidden_states_out

            # Step 5.4: get logits
            logits = self.draft_model.compute_logits(hidden_states)
-            logits = logits.float()
-
-            # Step 5.5: calculate loss
-            out_logp = nn.LogSoftmax(dim=2)(logits)
-            plogp = target_p * out_logp
-            loss = -torch.sum(position_mask * plogp, 2).mean()

-            # Step 5.6: record metrics
-            plosses.append(loss)
+            # Step 5.5: record metrics first as we in-place modify logits
            with torch.no_grad():
                acces.append(
-                    (
-                        (logits.argmax(-1) == target_p.argmax(-1))
-                        * position_mask.squeeze(-1)
+                    _compute_metric_acc(
+                        logits=logits,
+                        target_p=target_p,
+                        position_mask=position_mask,
+                        loss_mask=loss_mask,
                    )
-                    .sum()
-                    .item()
-                    / (loss_mask.sum().item() + 1e-6)
                )

+            # Step 5.6: calculate loss, in-place modifies logits!
+            loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
+            plosses.append(loss)
+
            if not is_last:
                # Step 5.7: we need to update the loss mask
                input_ids = padding(input_ids, left=False)
-                target = padding(target, left=False)
+                position_mask = padding(position_mask, left=False)
                loss_mask = padding(loss_mask, left=False)
-                ind = torch.arange(seq_length, device=attention_mask.device)
-                ind0 = ind[idx:]
-                ind1 = ind[: seq_length - idx]
-                attention_mask[:, :, ind0, ind1] = torch.finfo(attention_mask.dtype).min
+                if self.attention_backend == "sdpa":
+                    ind = torch.arange(seq_length, device=attention_mask.device)
+                    ind0 = ind[idx:]
+                    ind1 = ind[: seq_length - idx]
+                    attention_mask[:, :, ind0, ind1] = torch.finfo(
+                        attention_mask.dtype
+                    ).min
+                # Flex attention mask shrinking is handled inside the attention module
        return plosses, vlosses, acces

@@ -775,4 +786,4 @@ def _compute_target_p(target, t2d, loss_mask):
 def _compute_metric_acc(logits, target_p, position_mask, loss_mask):
     return (
         (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)
-    ).sum().item() / (loss_mask.sum().item() + 1e-6)
+    ).sum() / loss_mask.sum().clamp_min(1e-6)
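
The main refactor here hoists the target-distribution softmax out of the TTT loop: the padded target probabilities are computed once, and each step slices a shifted window of them instead of recomputing softmax over the target vocabulary every iteration. A schematic sketch of that slicing only (names, shapes, and the padding amount are illustrative; this is not the SpecForge _compute_target_p_padded implementation):

# Sketch: compute the softmax once, then slice a shifted window per TTT step.
import torch
import torch.nn.functional as F

B, S, V, length = 2, 8, 32, 4
target_logits = torch.randn(B, S, V)

target_p_full = F.softmax(target_logits.float(), dim=-1)
# Pad along the sequence dimension so every step can read a full-size window.
target_p_padded = F.pad(target_p_full, (0, 0, 0, length - 1))

for idx in range(length):
    target_p = target_p_padded[:, idx : idx + S, :].contiguous()
    assert target_p.shape == (B, S, V)  # one fixed-size window per step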

specforge/core/loss.py

Lines changed: 6 additions & 6 deletions

@@ -170,9 +170,9 @@ class LogSoftmaxLoss(torch.autograd.Function):
    def forward(ctx, logits, target, position_mask):
        B, T, V = logits.shape
        loss = torch.zeros((B * T, 1), device=logits.device)
-        logits_flat = logits.view(B * T, V).contiguous()
-        target_flat = target.view(B * T, V).contiguous()
-        position_mask_flat = position_mask.view(B * T, 1).contiguous().bool()
+        logits_flat = logits.contiguous().view(B * T, V)
+        target_flat = target.contiguous().view(B * T, V)
+        position_mask_flat = position_mask.contiguous().view(B * T, 1).bool()
        grid = (B * T,)
        m = torch.zeros((B * T,), device=logits.device, dtype=torch.float32)
        d = torch.zeros((B * T,), device=logits.device, dtype=torch.float32)

@@ -200,9 +200,9 @@ def backward(ctx, grad_output):
        logits, target, position_mask, m, d = ctx.saved_tensors
        B, T, V = logits.shape
        scaling_factor = 1.0 / (B * T)
-        logits = logits.view(B * T, V).contiguous()
-        target = target.view(B * T, V).contiguous()
-        position_mask = position_mask.view(B * T, 1).contiguous().bool()
+        logits = logits.contiguous().view(B * T, V)
+        target = target.contiguous().view(B * T, V)
+        position_mask = position_mask.contiguous().view(B * T, 1).bool()
        grid = (B * T,)
        BLOCK_SIZE, num_warps = _calculate_settings(V)
        log_softmax_backward_kernel[grid](
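
The loss.py change only swaps the order of .contiguous() and .view(): .view() requires a memory layout compatible with the requested shape and raises on many non-contiguous tensors, so making the tensor contiguous first is the safe ordering. A tiny sketch reproducing the failure mode this avoids:

# Sketch: .view() on a non-contiguous tensor can fail; .contiguous().view() works.
import torch

x = torch.randn(4, 3).t()           # transpose makes the tensor non-contiguous
try:
    x.view(-1)                      # the old ordering can raise here
except RuntimeError as err:
    print("view failed:", err)

flat = x.contiguous().view(-1)      # copy into contiguous memory, then view
print(flat.shape)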

specforge/data/parse.py

Lines changed: 5 additions & 2 deletions

@@ -68,7 +68,11 @@ def parse(
        convroles = ["user", "assistant"]
        for j, sentence in enumerate(conversation):
            role = sentence["role"]
-            assert role == convroles[j % 2], f"unexpected role {role}"
+            if role != convroles[j % 2]:
+                warnings.warn(
+                    f"Conversation truncated due to unexpected role '{role}'. Expected '{convroles[j % 2]}'."
+                )
+                break
            messages.append({"role": role, "content": sentence["content"]})

        conversation = self.tokenizer.apply_chat_template(

@@ -150,7 +154,6 @@ def parse(
            reasoning_level = "Low"

        for j, message in enumerate(conversation):
-            print(message)
            if message["role"] == "user":
                user_message = message["content"]
            if message["role"] == "assistant_analysis":
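
The parse change downgrades a hard assert on strict user/assistant alternation into a warning plus truncation, so one malformed conversation no longer aborts preprocessing (this relies on the warnings module being imported in parse.py). A small self-contained sketch of the same warn-and-truncate pattern:

# Sketch: keep the well-formed prefix of a conversation and warn on the first
# role that breaks user/assistant alternation.
import warnings

def truncate_on_bad_roles(conversation):
    convroles = ["user", "assistant"]
    messages = []
    for j, sentence in enumerate(conversation):
        role = sentence["role"]
        if role != convroles[j % 2]:
            warnings.warn(
                f"Conversation truncated due to unexpected role '{role}'. "
                f"Expected '{convroles[j % 2]}'."
            )
            break
        messages.append({"role": role, "content": sentence["content"]})
    return messages

# The second "user" turn breaks alternation, so only the first message survives.
print(truncate_on_bad_roles(
    [{"role": "user", "content": "hi"}, {"role": "user", "content": "hi again"}]
))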

specforge/modeling/draft/llama3_eagle.py

Lines changed: 17 additions & 6 deletions

@@ -575,12 +575,23 @@ def forward(
        ).transpose(1, 2)

        lck = past_seen_tokens // q_len
-        cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
-        cos, sin = cos.to(query_states.device), sin.to(query_states.device)
-        # Keep positions ids aligned when padding so the KV cache is unaffected.
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids + lck
-        )
+        if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
+            cos, sin = self.rotary_emb(query_states, position_ids + lck)
+            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+            query_states, key_states = apply_multimodal_rotary_pos_emb(
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+            )
+        else:
+            cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
+            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+            # Keep positions ids aligned when padding so the KV cache is unaffected.
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin, position_ids + lck
+            )

        cache_position: torch.Tensor = torch.arange(
            past_seen_tokens, past_seen_tokens + q_len, device=hidden_states.device
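
The rotary-embedding change dispatches on the embedding type: a multimodal (mrope) embedding is driven by the shifted position ids and an mrope_section split, while the plain path keeps the previous seq_len-based call. A schematic sketch of the dispatch shape only (the classes and return values below are stand-ins, not the real LlamaMutiRotaryEmbedding or rotary helpers):

# Sketch: choose the RoPE call signature based on the embedding type.
class PlainRope:
    def __call__(self, q, seq_len):
        return f"cos/sin built from seq_len={seq_len}"

class MultimodalRope:
    def __call__(self, q, position_ids):
        return f"cos/sin built from position_ids={position_ids}"

def rope_cos_sin(rotary_emb, q, position_ids, lck, q_len):
    if isinstance(rotary_emb, MultimodalRope):
        # mrope path: indexed by (shifted) position ids; the multimodal helper
        # later splits cos/sin per mrope_section across the head dimension
        return rotary_emb(q, position_ids=[p + lck for p in position_ids])
    # plain path: only the total sequence length matters
    return rotary_emb(q, seq_len=q_len + lck)

print(rope_cos_sin(MultimodalRope(), None, [0, 1, 2], lck=4, q_len=3))
print(rope_cos_sin(PlainRope(), None, [0, 1, 2], lck=4, q_len=3))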
