36 changes: 26 additions & 10 deletions src/transformers/integrations/executorch.py
@@ -824,11 +824,15 @@ def __init__(self, model, max_static_cache_length, batch_size):
self.lm_head = model.lm_head
self.config = model.config

+ # Detect the device of the exported models by checking a parameter
+ # We'll use the model's device as the target device
+ model_device = next(model.parameters()).device
[Collaborator review comment] Cool stuff! I think the comment at line 965 should be here


# Initialize static cache for decoder and DynamicCache for encoder
self.static_cache = StaticCache(config=self.config, max_cache_len=max_static_cache_length)
head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")
self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())

register_dynamic_cache_export_support()
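The pattern introduced in this hunk — detecting the wrapped model's device from one of its parameters and allocating the cache there instead of hardcoding "cpu" — can be illustrated outside of transformers. A minimal sketch, assuming a hypothetical `DecoderWrapper` and a plain tensor buffer standing in for `StaticCache`:

```python
import torch
import torch.nn as nn


class DecoderWrapper(nn.Module):
    """Hypothetical wrapper, analogous in spirit to the exportable module in this PR."""

    def __init__(self, model: nn.Module, max_cache_len: int, hidden_size: int):
        super().__init__()
        self.model = model
        # Detect the device of the wrapped model from one of its parameters,
        # rather than hardcoding "cpu", so the cache is allocated next to the weights.
        model_device = next(model.parameters()).device
        self.register_buffer(
            "cache", torch.zeros(max_cache_len, hidden_size, device=model_device)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


inner = nn.Linear(16, 16)            # stand-in for the decoder
# inner = inner.to("cuda")           # if the model lives on a GPU, the cache follows it
wrapped = DecoderWrapper(inner, max_cache_len=8, hidden_size=16)
print(wrapped.cache.device)          # matches the device of `inner`'s parameters
```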
@@ -891,16 +895,22 @@ def _export_encoder(self, encoder_input_ids):
return exported_encoder

def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
+ target_device = self.full_model.device
wrapped_decoder = (
Seq2SeqLMDecoderExportableModuleWithStaticCache(
model=self.full_model,
- max_static_cache_length=self.generation_config.cache_config.max_cache_len,
- batch_size=self.generation_config.cache_config.batch_size,
+ max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
+ batch_size=self.generation_config.cache_config.get("batch_size"),
)
.to("cpu")
.to(target_device)
.eval()
)

+ # Move input tensors to the same device as the wrapped decoder
+ decoder_input_ids = decoder_input_ids.to(target_device)
+ encoder_hidden_states = encoder_hidden_states.to(target_device)
+ cache_position = cache_position.to(target_device)

# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
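For context on the line above, `torch.export.Dim` is the public torch.export mechanism for declaring a dimension as dynamic at export time. A small self-contained sketch with a toy module (the module, shapes, and the max bound are illustrative, not taken from this PR):

```python
import torch
from torch.export import Dim, export


class Toy(torch.nn.Module):
    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        return encoder_hidden_states.sum(dim=1)


# Mark the sequence-length dimension (dim 1) as dynamic, with an assumed upper bound.
seq_len_dim = Dim("encoder_hidden_seq_length", max=4096)
example_input = (torch.zeros(1, 10, 8),)
exported = export(
    Toy(),
    example_input,
    dynamic_shapes={"encoder_hidden_states": {1: seq_len_dim}},
)
# The exported program now accepts any sequence length up to the declared max.
print(exported.module()(torch.zeros(1, 33, 8)).shape)
```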

@@ -938,7 +948,7 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_
encoder_hidden_states
if encoder_hidden_states is not None
else torch.zeros(
- (self.generation_config.cache_config.batch_size, 10, self.config.d_model),
+ (self.generation_config.cache_config.get("batch_size"), 10, self.config.d_model),
dtype=torch.float32,
device=device,
)
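The switch from `cache_config.batch_size` to `cache_config.get("batch_size")` suggests `cache_config` is treated as a dict-like object here. A hedged sketch of a lookup that tolerates both dict-style and attribute-style configs (the `read_cfg` helper is made up for illustration, not part of transformers):

```python
from typing import Any


def read_cfg(cfg: Any, key: str, default: Any = None) -> Any:
    """Hypothetical helper: read `key` from a dict-like or attribute-style config."""
    if isinstance(cfg, dict):
        return cfg.get(key, default)
    return getattr(cfg, key, default)


print(read_cfg({"batch_size": 1}, "batch_size"))   # dict-style  -> 1


class _Cfg:  # attribute-style stand-in
    batch_size = 2


print(read_cfg(_Cfg(), "batch_size"))              # attribute-style -> 2
```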
@@ -953,26 +963,32 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_

def generate(self, prompt_token_ids, max_new_tokens):
with torch.no_grad():
+ model_device = self.full_model.device

+ # Move input to the model's device if it's on a different device
+ if prompt_token_ids.device != model_device:
+ prompt_token_ids = prompt_token_ids.to(model_device)

# Run encoder
encoder_output = self.exported_encoder.module()(prompt_token_ids)

- # Initialize with start token (0 for T5)
- decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
+ # Initialize with start token (0 for T5) on the correct device
+ decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=model_device)
generated_ids = [0]

# Generate tokens one by one
for i in range(max_new_tokens - 1):
# Run decoder for next token prediction
logits = self.exported_decoder.module()(
- decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long)
+ decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long, device=model_device)
)

# Get next token
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_ids.append(next_token)

- # Update input for next iteration
- decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)
+ # Update input for next iteration on the correct device
+ decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long, device=model_device)

# Check if EOS token
if next_token == self.config.eos_token_id:
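As a standalone illustration of the device handling in `generate`, here is a minimal greedy-decoding loop in which the start token and every per-step tensor are created directly on the target device; `toy_decoder`, the vocabulary size, and the EOS id are invented stand-ins for the exported decoder module:

```python
import torch


def toy_decoder(input_ids: torch.Tensor, position: torch.Tensor) -> torch.Tensor:
    """Stand-in for the exported decoder: returns random logits on the inputs' device."""
    vocab_size = 32
    return torch.randn(input_ids.shape[0], input_ids.shape[1], vocab_size, device=input_ids.device)


def greedy_generate(max_new_tokens: int, device: torch.device, eos_id: int = 1) -> list[int]:
    # Creating every per-step tensor on `device` avoids implicit CPU<->GPU copies in the loop.
    decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=device)
    generated = [0]
    for i in range(max_new_tokens - 1):
        logits = toy_decoder(decoder_input_ids, torch.tensor([i], dtype=torch.long, device=device))
        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
        generated.append(next_token)
        decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long, device=device)
        if next_token == eos_id:
            break
    return generated


print(greedy_generate(max_new_tokens=5, device=torch.device("cpu")))
```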