@@ -820,11 +820,15 @@ def __init__(self, model, max_static_cache_length, batch_size):
820820 self .lm_head = model .lm_head
821821 self .config = model .config
822822
823+ # Detect the device of the exported models by checking a parameter
824+ # We'll use the model's device as the target device
825+ model_device = next (model .parameters ()).device
826+
823827 # Initialize static cache for decoder and DynamicCache for encoder
824828 self .static_cache = StaticCache (config = self .config , max_cache_len = max_static_cache_length )
825829 head_dim = getattr (self .config , "head_dim" , self .config .hidden_size // self .config .num_attention_heads )
826830 num_heads = getattr (self .config , "num_key_value_heads" , self .config .num_attention_heads )
827- self .static_cache .early_initialization (batch_size , num_heads , head_dim , torch .float32 , "cpu" )
831+ self .static_cache .early_initialization (batch_size , num_heads , head_dim , torch .float32 , model_device )
828832 self .cache = EncoderDecoderCache (self .static_cache , DynamicCache ())
829833
830834 register_dynamic_cache_export_support ()
@@ -887,16 +891,22 @@ def _export_encoder(self, encoder_input_ids):
887891 return exported_encoder
888892
889893 def _export_decoder (self , decoder_input_ids , encoder_hidden_states , cache_position ):
894+ target_device = self .full_model .device
890895 wrapped_decoder = (
891896 Seq2SeqLMDecoderExportableModuleWithStaticCache (
892897 model = self .full_model ,
893- max_static_cache_length = self .generation_config .cache_config .max_cache_len ,
894- batch_size = self .generation_config .cache_config .batch_size ,
898+ max_static_cache_length = self .generation_config .cache_config .get ( " max_cache_len" ) ,
899+ batch_size = self .generation_config .cache_config .get ( " batch_size" ) ,
895900 )
896- .to ("cpu" )
901+ .to (target_device )
897902 .eval ()
898903 )
899904
905+ # Move input tensors to the same device as the wrapped decoder
906+ decoder_input_ids = decoder_input_ids .to (target_device )
907+ encoder_hidden_states = encoder_hidden_states .to (target_device )
908+ cache_position = cache_position .to (target_device )
909+
900910 # Define dynamic dimension for encoder output sequence length
901911 encoder_seq_len_dim = torch .export .Dim ("encoder_hidden_seq_length" , max = self .max_hidden_seq_length )
902912
@@ -934,7 +944,7 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_
934944 encoder_hidden_states
935945 if encoder_hidden_states is not None
936946 else torch .zeros (
937- (self .generation_config .cache_config .batch_size , 10 , self .config .d_model ),
947+ (self .generation_config .cache_config .get ( " batch_size" ) , 10 , self .config .d_model ),
938948 dtype = torch .float32 ,
939949 device = device ,
940950 )
@@ -949,26 +959,32 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_
949959
950960 def generate (self , prompt_token_ids , max_new_tokens ):
951961 with torch .no_grad ():
962+ model_device = self .full_model .device
963+
964+ # Move input to the model's device if it's on a different device
965+ if prompt_token_ids .device != model_device :
966+ prompt_token_ids = prompt_token_ids .to (model_device )
967+
952968 # Run encoder
953969 encoder_output = self .exported_encoder .module ()(prompt_token_ids )
954970
955- # Initialize with start token (0 for T5)
956- decoder_input_ids = torch .tensor ([[0 ]], dtype = torch .long )
971+ # Initialize with start token (0 for T5) on the correct device
972+ decoder_input_ids = torch .tensor ([[0 ]], dtype = torch .long , device = model_device )
957973 generated_ids = [0 ]
958974
959975 # Generate tokens one by one
960976 for i in range (max_new_tokens - 1 ):
961977 # Run decoder for next token prediction
962978 logits = self .exported_decoder .module ()(
963- decoder_input_ids , encoder_output , torch .tensor ([i ], dtype = torch .long )
979+ decoder_input_ids , encoder_output , torch .tensor ([i ], dtype = torch .long , device = model_device )
964980 )
965981
966982 # Get next token
967983 next_token = torch .argmax (logits [:, - 1 , :], dim = - 1 ).item ()
968984 generated_ids .append (next_token )
969985
970- # Update input for next iteration
971- decoder_input_ids = torch .tensor ([[next_token ]], dtype = torch .long )
986+ # Update input for next iteration on the correct device
987+ decoder_input_ids = torch .tensor ([[next_token ]], dtype = torch .long , device = model_device )
972988
973989 # Check if EOS token
974990 if next_token == self .config .eos_token_id :
0 commit comments