Skip to content

Commit e28278f

Browse files
committed
correct shapes
Signed-off-by: CarlosGomes98 <[email protected]>
1 parent adbe082 commit e28278f

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

  • nemo/collections/diffusion/models/flux

nemo/collections/diffusion/models/flux/model.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -518,18 +518,21 @@ def configure_vae(self, vae):
518518
# pylint: disable=C0116
519519
if isinstance(vae, nn.Module):
520520
self.vae = vae.eval().cuda()
521-
self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult))
521+
self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult) - 1)
522+
self.vae_channels = self.vae.params.z_channels
522523
for param in self.vae.parameters():
523524
param.requires_grad = False
524525
elif isinstance(vae, AutoEncoderConfig):
525526
self.vae = AutoEncoder(vae).eval().cuda()
526-
self.vae_scale_factor = 2 ** (len(vae.ch_mult))
527+
self.vae_scale_factor = 2 ** (len(vae.ch_mult) - 1)
528+
self.vae_channels = vae.z_channels
527529
for param in self.vae.parameters():
528530
param.requires_grad = False
529531
else:
530532
logging.info("Vae not provided, assuming the image input is precached...")
531533
self.vae = None
532-
self.vae_scale_factor = 16
534+
self.vae_scale_factor = 8
535+
self.vae_channels = 16
533536

534537
def configure_text_encoders(self, clip, t5):
535538
# pylint: disable=C0116
@@ -618,9 +621,8 @@ def forward_step(self, batch) -> torch.Tensor:
618621

619622
noise_pred = self._unpack_latents(
620623
noise_pred.transpose(0, 1),
621-
int(latents.shape[2] * self.vae_scale_factor // 2),
622-
int(latents.shape[3] * self.vae_scale_factor // 2),
623-
vae_scale_factor=self.vae_scale_factor,
624+
latents.shape[2],
625+
latents.shape[3],
624626
).transpose(0, 1)
625627

626628
target = noise - latents
@@ -717,17 +719,15 @@ def prepare_image_latent(self, latents):
717719
timesteps,
718720
)
719721

720-
def _unpack_latents(self, latents, height, width, vae_scale_factor):
722+
def _unpack_latents(self, latents, height, width):
721723
# pylint: disable=C0116
722724
batch_size, num_patches, channels = latents.shape
723725

724-
height = height // vae_scale_factor
725-
width = width // vae_scale_factor
726-
727-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
726+
# adjust h and w for patching
727+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
728728
latents = latents.permute(0, 3, 1, 4, 2, 5)
729729

730-
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
730+
latents = latents.reshape(batch_size, channels // 4, height, width)
731731

732732
return latents
733733

0 commit comments

Comments (0)