Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,7 +2207,11 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
if alpha_mask:
if "alpha_mask" not in npz:
return False
if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking validityにおいて、
前述の変更により、alpha_maskが8の倍数になるため、このチェック判定式は使えません。
そのため、後述の変更を加えました。

alpha_mask_size = npz["alpha_mask"].shape[0:2]
if alpha_mask_size[0] != alpha_mask_size[0] // 8 * 8 or alpha_mask_size[1] != alpha_mask_size[1] // 8 * 8: # ...is legacy caching scheme without rounding to divisible by 8
return False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

まず、npzチェック行程において、後方互換性を確保するため、
まず、8でround downしていない旧方式で作成したnpzに出会った場合、
npzを再作成させます。

この変更を行わずに、古いキャッシュを削除するよう口頭で周知する方法もありますが、
おそらく伝わりにくい情報だと思いますので、自動判定させるようにしました。
将来的には削除した方が良いかもしれません。
(もし想定外の不具合報告が発生した場合は削除しましょう)

alpha_mask_size = (alpha_mask_size[0] // 8, alpha_mask_size[1] // 8) # Resize alpha_mask to 1/8 scale, the same as latents
if alpha_mask_size != expected_latents_size: # HxW
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

この変更では、新しいalpha_mask_sizeの判定式と比較するように補正しました。
また、この前後の行と同じように、resoではなくexpected_latents_sizeと比較するようにして、コードの視認性を向上しました。
※機能的にはresoでもexpected_latents_sizeでも同一

return False
else:
if "alpha_mask" in npz:
Expand Down Expand Up @@ -2514,6 +2518,7 @@ def cache_batch_latents(
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = alpha_mask[:image.shape[0] // 8 * 8, :image.shape[1] // 8 * 8] # Without rounding down [H, W], Tensor sizes may not match, so stack the tensors.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

この変更では、alpha_maskの解像度を8の倍数にround downします。
これによって、画像サイズによっては、training stepループ中にtensor sizeの不一致を解消して、スタックしないようにできます。
結果として、保存されるnpzのalpha_maskサイズは8の倍数になります。

alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
else:
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
Expand Down