@@ -841,19 +841,23 @@ def _get_t5_prompt_embeds(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask.clone()
-
-        seq_lengths = attention_mask.sum(dim=1)
-        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
+        tokenizer_mask = text_inputs.attention_mask
+        tokenizer_mask_device = tokenizer_mask.to(device)
 
+        # unlike Flux, Chroma uses the tokenizer's attention mask when generating the T5 embeddings
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=tokenizer_mask_device,
         )[0]
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(device=device)
+
+        # for the text tokens, Chroma requires that all except the first padding token are masked out
+        seq_lengths = tokenizer_mask_device.sum(dim=1)
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
 
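The new masking rule in `_get_t5_prompt_embeds` keeps every real token plus the first padding token and zeroes out the rest. A minimal sketch of that rule in isolation (the toy tokenizer mask and sequence lengths below are assumptions for illustration, not values from the pipeline):

```python
import torch

# toy tokenizer attention mask for a batch of 2 prompts padded to length 6
# (1 = real token, 0 = padding) -- illustrative values only
tokenizer_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
batch_size, max_len = tokenizer_mask.shape

seq_lengths = tokenizer_mask.sum(dim=1)  # real tokens per prompt: tensor([3, 5])
mask_indices = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1)

# keep positions 0..seq_length inclusive: all real tokens plus the first
# padding token stay unmasked, every later padding token is masked out
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).float()
print(attention_mask)
# tensor([[1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1., 1.]])
```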
@@ -1154,7 +1158,15 @@ def _prepare_attention_mask(
             return attention_mask
 
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
+            [
+                attention_mask,
+                torch.ones(
+                    batch_size,
+                    sequence_length,
+                    device=attention_mask.device,
+                    dtype=attention_mask.dtype,
+                ),
+            ],
             dim=1,
         )
 
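In `_prepare_attention_mask`, that text-token mask is then extended with ones for the image-latent positions before joint attention, now in the mask's own dtype rather than `torch.bool`. A rough sketch of the concatenation (the mask values and the 4 image tokens are assumed toy numbers):

```python
import torch

# text-token mask from the step above (batch of 2, 6 text tokens) -- toy values
attention_mask = torch.tensor([[1., 1., 1., 1., 0., 0.],
                               [1., 1., 1., 1., 1., 1.]])
batch_size = attention_mask.size(0)
sequence_length = 4  # assumed number of image-latent tokens

# image tokens are never masked, so append a block of ones in the same
# dtype and on the same device as the existing text mask
joint_mask = torch.cat(
    [
        attention_mask,
        torch.ones(batch_size, sequence_length,
                   device=attention_mask.device, dtype=attention_mask.dtype),
    ],
    dim=1,
)
print(joint_mask.shape)  # torch.Size([2, 10])
```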