eagle/model/cnets.py (1 addition, 1 deletion)
@@ -667,7 +667,7 @@ def reset_kv(self):
         self.stable_kv = None
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor):
 
         input_ids = input_ids.to(hidden_states.device)
         total_tokens = self.total_tokens
eagle/model/cnets1.py (1 addition, 1 deletion)
@@ -670,7 +670,7 @@ def reset_kv(self):
         self.stable_kv = None
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor):
 
         input_ids = input_ids.to(hidden_states.device)
         total_tokens = self.total_tokens
eagle/model/utils.py (3 additions, 3 deletions)
@@ -223,7 +223,7 @@ def initialize_tree0(input_ids, model, past_key_values, logits_processor):
 #     input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
 #     # Clone the output hidden states
 #
-#     draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head)
+#     draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_generate(hidden_states, input_ids, self.base_model.lm_head)
 #     if output_orig:
 #         return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token
 #     return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token
@@ -250,7 +250,7 @@ def initialize_tree(input_ids, model, past_key_values, logits_processor):
     if outputs["hidden_states"][0].device != ea_device:
         outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]]
     hidden_states=torch.cat(outputs["hidden_states"],dim=-1)
-    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_generate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
     return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token
 
 
@@ -463,7 +463,7 @@ def update_inference_inputs(
         token = torch.argmax(prob)
         token = token[None, None]
     # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
-    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_generate(accept_hidden_state_new,
                                                   input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
                                                   head=model.base_model.lm_head,logits_processor=logits_processor)
 
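Note: this rename touches eight files, so a stale call site is easy to miss. A small helper sketch, not part of this PR (the function name find_stale_references and the "eagle" root path are illustrative), that walks the tree and reports any remaining use of the old spelling:

    # Hypothetical helper (not in this PR): report leftover uses of the old name.
    import re
    from pathlib import Path

    OLD_NAME = re.compile(r"\btopK_genrate\b")

    def find_stale_references(root: str = ".") -> list[tuple[str, int, str]]:
        """Return (path, line number, line) for every remaining old-name use."""
        hits = []
        for path in Path(root).rglob("*.py"):
            text = path.read_text(encoding="utf-8")
            for lineno, line in enumerate(text.splitlines(), start=1):
                if OLD_NAME.search(line):
                    hits.append((str(path), lineno, line.strip()))
        return hits

    if __name__ == "__main__":
        for path, lineno, line in find_stale_references("eagle"):
            print(f"{path}:{lineno}: {line}")

An empty result would confirm that the diff below and above covers every call site in the package.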
eagle/modeling_eagle.py (3 additions, 3 deletions)
@@ -860,7 +860,7 @@ def sample(self, logits, logits_processor, k=1):
         return sampled_indices, sampled_probs, probabilities
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_length=4, use_cache=True,
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor, max_length=4, use_cache=True,
                      attention_mask=None, len_posi=None, ):
         top_k = 5
         bs = input_ids.shape[0]
@@ -1272,7 +1272,7 @@ def initialize_tree(input_ids, model, logits_processor, attention_mask=None):
     token = token[:, None]
     input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
 
-    tree_logits = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head, logits_processor,
+    tree_logits = model.ea_layer.topK_generate(hidden_states, input_ids, model.base_model.lm_head, logits_processor,
                                               attention_mask=attention_mask)
 
 
@@ -1543,7 +1543,7 @@ def update_inference_inputs(
 
 
 
-    tree_logits = model.ea_layer.topK_genrate(draft_hidden,
+    tree_logits = model.ea_layer.topK_generate(draft_hidden,
                                               input_ids=draft_input_ids,
                                               head=model.base_model.lm_head, logits_processor=logits_processor,attention_mask=attention_mask,len_posi=input_ids.shape[1])
 
eagle/testbug/model/cnets.py (1 addition, 1 deletion)
@@ -569,7 +569,7 @@ def sample(self,logits, logits_processor,k=1, replacement=False):
 
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor,max_length=4, use_cache=True):
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor,max_length=4, use_cache=True):
         input_ids = input_ids[:, 1:]
         ss_token,ss_prob,ss_op = [],[],[]
         len_posi=input_ids.shape[1]
eagle/testbug/model/ea_model.py (1 addition, 1 deletion)
@@ -92,7 +92,7 @@ def forward(
         input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
         # Clone the output hidden states
 
-        ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, None, logits_processor)
+        ea_logits = self.ea_layer.topK_generate(hidden_states, input_ids, None, logits_processor)
         if output_orig:
             return ea_logits, outputs, orig, hidden_states, token
         return ea_logits, hidden_states, token
eagle/testbug/model/ea_modelbs.py (1 addition, 1 deletion)
@@ -140,7 +140,7 @@ def forward(
         input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
         # Clone the output hidden states
 
-        ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head, logits_processor,attention_mask=attention_mask)
+        ea_logits = self.ea_layer.topK_generate(hidden_states, input_ids, self.base_model.lm_head, logits_processor,attention_mask=attention_mask)
         if output_orig:
             return ea_logits, outputs, orig, hidden_states, token
         return ea_logits, hidden_states, token
eagle/testbug/model/utils.py (1 addition, 1 deletion)
@@ -441,7 +441,7 @@ def update_inference_inputs(
         token = torch.argmax(prob)
         token = token[None, None]
     # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
-    tree_logits = model.ea_layer.topK_genrate(None,
+    tree_logits = model.ea_layer.topK_generate(None,
                                              input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
                                              head=None, logits_processor=logits_processor)
 
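Note: because topK_genrate was part of the public surface of the drafting layer, this rename breaks any external code that calls the old name. A minimal sketch of a deprecation shim, assuming the method shape shown in the hunks above (the class name Model is a hypothetical stand-in, not a class from this repo; the shim is not part of this PR):

    # Hypothetical compatibility shim (not in this PR): keep the old misspelled
    # method working for downstream callers while steering them to the new name.
    import warnings

    class Model:
        def topK_generate(self, hidden_states, input_ids, head, logits_processor):
            # The real drafting logic lives here in the actual class.
            raise NotImplementedError

        def topK_genrate(self, *args, **kwargs):
            """Deprecated alias for topK_generate; remove once callers migrate."""
            warnings.warn(
                "topK_genrate is deprecated; use topK_generate instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return self.topK_generate(*args, **kwargs)

Keeping such an alias for a release or two would let downstream callers migrate without breaking on upgrade.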