@@ -518,18 +518,21 @@ def configure_vae(self, vae):
518 518          # pylint: disable=C0116
519 519          if isinstance(vae, nn.Module):
520 520              self.vae = vae.eval().cuda()
521     -            self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult))
    521 +            self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult) - 1)
    522 +            self.vae_channels = self.vae.params.z_channels
522 523              for param in self.vae.parameters():
523 524                  param.requires_grad = False
524 525          elif isinstance(vae, AutoEncoderConfig):
525 526              self.vae = AutoEncoder(vae).eval().cuda()
526     -            self.vae_scale_factor = 2 ** (len(vae.ch_mult))
    527 +            self.vae_scale_factor = 2 ** (len(vae.ch_mult) - 1)
    528 +            self.vae_channels = vae.z_channels
527 529              for param in self.vae.parameters():
528 530                  param.requires_grad = False
529 531          else:
530 532              logging.info("Vae not provided, assuming the image input is precached...")
531 533              self.vae = None
532     -            self.vae_scale_factor = 16
    534 +            self.vae_scale_factor = 8
    535 +            self.vae_channels = 16
533 536
534537 def configure_text_encoders (self , clip , t5 ):
535538 # pylint: disable=C0116
@@ -618,9 +621,8 @@ def forward_step(self, batch) -> torch.Tensor:
618 621
619 622          noise_pred = self._unpack_latents(
620 623              noise_pred.transpose(0, 1),
621     -            int(latents.shape[2] * self.vae_scale_factor // 2),
622     -            int(latents.shape[3] * self.vae_scale_factor // 2),
623     -            vae_scale_factor=self.vae_scale_factor,
    624 +            latents.shape[2],
    625 +            latents.shape[3],
624 626          ).transpose(0, 1)
625 627
626 628          target = noise - latents
626628 target = noise - latents
@@ -717,17 +719,15 @@ def prepare_image_latent(self, latents):
717 719              timesteps,
718 720          )
719 721
720     -    def _unpack_latents(self, latents, height, width, vae_scale_factor):
    722 +    def _unpack_latents(self, latents, height, width):
721 723          # pylint: disable=C0116
722 724          batch_size, num_patches, channels = latents.shape
723 725
724     -        height = height // vae_scale_factor
725     -        width = width // vae_scale_factor
726     -
727     -        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
    726 +        # adjust h and w for patching
    727 +        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
728 728          latents = latents.permute(0, 3, 1, 4, 2, 5)
729 729
730     -        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
    730 +        latents = latents.reshape(batch_size, channels // 4, height, width)
731 731
732 732          return latents
733 733
733733
0 commit comments