
Commit 2dd050e

Merge pull request #1844 from bghira/bugfix/chroma-masking-update
import change from Diffusers upstream for Chroma masking fixes
2 parents 28c7496 + 7fde1cc commit 2dd050e


simpletuner/helpers/models/chroma/pipeline.py

Lines changed: 20 additions & 8 deletions
@@ -841,19 +841,23 @@ def _get_t5_prompt_embeds(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask.clone()
-
-        seq_lengths = attention_mask.sum(dim=1)
-        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
+        tokenizer_mask = text_inputs.attention_mask
+        tokenizer_mask_device = tokenizer_mask.to(device)
 
+        # unlike Flux, Chroma uses the tokenizer's attention mask when generating the T5 embeddings
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=tokenizer_mask_device,
         )[0]
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(device=device)
+
+        # for the text tokens, Chroma requires that all except the first padding token are masked out
+        seq_lengths = tokenizer_mask_device.sum(dim=1)
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
 
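The key detail in the new text-token mask is the "<=" comparison: seq_lengths counts the real (non-padding) tokens per prompt, so positions 0 through seq_length stay unmasked, and position seq_length itself is the first padding token, matching the comment in the diff. A minimal, self-contained sketch of that behaviour (not the pipeline code; the batch and sequence lengths here are made up):

import torch

# toy tokenizer attention mask: 1 for real tokens, 0 for padding
tokenizer_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0, 0, 0, 0],  # 3 real tokens
        [1, 1, 1, 1, 1, 0, 0, 0],  # 5 real tokens
    ]
)
batch_size, max_len = tokenizer_mask.shape

seq_lengths = tokenizer_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1)
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)  # "<=" keeps index == seq_length unmasked

print(attention_mask.int())
# tensor([[1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 0, 0]], dtype=torch.int32)
# row 0: 3 real tokens + the first pad token; row 1: 5 real tokens + the first pad token
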
@@ -1154,7 +1158,15 @@ def _prepare_attention_mask(
             return attention_mask
 
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
+            [
+                attention_mask,
+                torch.ones(
+                    batch_size,
+                    sequence_length,
+                    device=attention_mask.device,
+                    dtype=attention_mask.dtype,
+                ),
+            ],
             dim=1,
         )

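Besides the reformatting, the second hunk changes the block of ones appended for the image tokens from a hard-coded dtype=torch.bool to dtype=attention_mask.dtype, presumably so the concatenated text-plus-image mask keeps a single dtype now that the first hunk produces the text mask as a floating-point tensor. A rough sketch with made-up sizes (not the pipeline code itself):

import torch

batch_size, text_len, image_len = 2, 6, 4  # hypothetical sizes, for illustration only

# text-token mask as a floating-point tensor, as produced after the first hunk
text_mask = torch.ones(batch_size, text_len, dtype=torch.bfloat16)

# image tokens are always attended to; matching the existing mask's dtype keeps the
# concatenated mask uniform instead of mixing bool and floating-point entries
joint_mask = torch.cat(
    [text_mask, torch.ones(batch_size, image_len, device=text_mask.device, dtype=text_mask.dtype)],
    dim=1,
)
print(joint_mask.dtype, joint_mask.shape)  # torch.bfloat16 torch.Size([2, 10])
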