Skip to content

Commit a4a1252

Browse files
ahadnagyNielsRogge
authored andcommitted
T5 test and target device fixes (huggingface#40313)
* Fix cache setup related issues * Fix target-device-related issues * Ruff * Address review comments
1 parent b179cef commit a4a1252

1 file changed

Lines changed: 26 additions & 10 deletions

File tree

src/transformers/integrations/executorch.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -820,11 +820,15 @@ def __init__(self, model, max_static_cache_length, batch_size):
820820
self.lm_head = model.lm_head
821821
self.config = model.config
822822

823+
# Detect the device of the exported models by checking a parameter
824+
# We'll use the model's device as the target device
825+
model_device = next(model.parameters()).device
826+
823827
# Initialize static cache for decoder and DynamicCache for encoder
824828
self.static_cache = StaticCache(config=self.config, max_cache_len=max_static_cache_length)
825829
head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
826830
num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
827-
self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")
831+
self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
828832
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
829833

830834
register_dynamic_cache_export_support()
@@ -887,16 +891,22 @@ def _export_encoder(self, encoder_input_ids):
887891
return exported_encoder
888892

889893
def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
894+
target_device = self.full_model.device
890895
wrapped_decoder = (
891896
Seq2SeqLMDecoderExportableModuleWithStaticCache(
892897
model=self.full_model,
893-
max_static_cache_length=self.generation_config.cache_config.max_cache_len,
894-
batch_size=self.generation_config.cache_config.batch_size,
898+
max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
899+
batch_size=self.generation_config.cache_config.get("batch_size"),
895900
)
896-
.to("cpu")
901+
.to(target_device)
897902
.eval()
898903
)
899904

905+
# Move input tensors to the same device as the wrapped decoder
906+
decoder_input_ids = decoder_input_ids.to(target_device)
907+
encoder_hidden_states = encoder_hidden_states.to(target_device)
908+
cache_position = cache_position.to(target_device)
909+
900910
# Define dynamic dimension for encoder output sequence length
901911
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
902912

@@ -934,7 +944,7 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_
934944
encoder_hidden_states
935945
if encoder_hidden_states is not None
936946
else torch.zeros(
937-
(self.generation_config.cache_config.batch_size, 10, self.config.d_model),
947+
(self.generation_config.cache_config.get("batch_size"), 10, self.config.d_model),
938948
dtype=torch.float32,
939949
device=device,
940950
)
@@ -949,26 +959,32 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_
949959

950960
def generate(self, prompt_token_ids, max_new_tokens):
951961
with torch.no_grad():
962+
model_device = self.full_model.device
963+
964+
# Move input to the model's device if it's on a different device
965+
if prompt_token_ids.device != model_device:
966+
prompt_token_ids = prompt_token_ids.to(model_device)
967+
952968
# Run encoder
953969
encoder_output = self.exported_encoder.module()(prompt_token_ids)
954970

955-
# Initialize with start token (0 for T5)
956-
decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
971+
# Initialize with start token (0 for T5) on the correct device
972+
decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=model_device)
957973
generated_ids = [0]
958974

959975
# Generate tokens one by one
960976
for i in range(max_new_tokens - 1):
961977
# Run decoder for next token prediction
962978
logits = self.exported_decoder.module()(
963-
decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long)
979+
decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long, device=model_device)
964980
)
965981

966982
# Get next token
967983
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
968984
generated_ids.append(next_token)
969985

970-
# Update input for next iteration
971-
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)
986+
# Update input for next iteration on the correct device
987+
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long, device=model_device)
972988

973989
# Check if EOS token
974990
if next_token == self.config.eos_token_id:

0 commit comments

Comments
 (0)