@@ -916,6 +916,7 @@ def sample(
916916 verbose : bool = True ,
917917 seg : torch .Tensor | None = None ,
918918 cfg : float | None = None ,
919+ cfg_fill_value : float = - 1.0 ,
919920 ) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
920921 """
921922 Args:
@@ -929,6 +930,7 @@ def sample(
929930 verbose: if true, prints the progression bar of the sampling process.
930931 seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
931932 cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
933+ cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
932934 """
933935 if mode not in ["crossattn" , "concat" ]:
934936 raise NotImplementedError (f"{ mode } condition is not supported" )
@@ -961,7 +963,7 @@ def sample(
961963 model_input = torch .cat ([image ] * 2 , dim = 0 )
962964 if conditioning is not None :
963965 uncondition = torch .ones_like (conditioning )
964- uncondition .fill_ (- 1 )
966+ uncondition .fill_ (cfg_fill_value )
965967 conditioning_input = torch .cat ([uncondition , conditioning ], dim = 0 )
966968 else :
967969 conditioning_input = None
@@ -1261,6 +1263,7 @@ def sample( # type: ignore[override]
12611263 verbose : bool = True ,
12621264 seg : torch .Tensor | None = None ,
12631265 cfg : float | None = None ,
1266+ cfg_fill_value : float = - 1.0 ,
12641267 ) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
12651268 """
12661269 Args:
@@ -1276,6 +1279,7 @@ def sample( # type: ignore[override]
12761279 seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
12771280 is instance of SPADEAutoencoderKL, segmentation must be provided.
12781281 cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1282+ cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
12791283 """
12801284
12811285 if (
@@ -1300,6 +1304,7 @@ def sample( # type: ignore[override]
13001304 verbose = verbose ,
13011305 seg = seg ,
13021306 cfg = cfg ,
1307+ cfg_fill_value = cfg_fill_value ,
13031308 )
13041309
13051310 if save_intermediates :
@@ -1479,6 +1484,7 @@ def sample( # type: ignore[override]
14791484 verbose : bool = True ,
14801485 seg : torch .Tensor | None = None ,
14811486 cfg : float | None = None ,
1487+ cfg_fill_value : float = - 1.0 ,
14821488 ) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
14831489 """
14841490 Args:
@@ -1493,7 +1499,8 @@ def sample( # type: ignore[override]
14931499 mode: Conditioning mode for the network.
14941500 verbose: if true, prints the progression bar of the sampling process.
14951501 seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
1496- cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1502+ cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1503+ cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
14971504 """
14981505 if mode not in ["crossattn" , "concat" ]:
14991506 raise NotImplementedError (f"{ mode } condition is not supported" )
@@ -1521,7 +1528,7 @@ def sample( # type: ignore[override]
15211528 model_input = torch .cat ([image ] * 2 , dim = 0 )
15221529 if conditioning is not None :
15231530 uncondition = torch .ones_like (conditioning )
1524- uncondition .fill_ (- 1 )
1531+ uncondition .fill_ (cfg_fill_value )
15251532 conditioning_input = torch .cat ([uncondition , conditioning ], dim = 0 )
15261533 else :
15271534 conditioning_input = None
@@ -1839,6 +1846,7 @@ def sample( # type: ignore[override]
18391846 verbose : bool = True ,
18401847 seg : torch .Tensor | None = None ,
18411848 cfg : float | None = None ,
1849+ cfg_fill_value : float = - 1.0 ,
18421850 ) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
18431851 """
18441852 Args:
@@ -1856,6 +1864,7 @@ def sample( # type: ignore[override]
18561864 seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
18571865 is instance of SPADEAutoencoderKL, segmentation must be provided.
18581866 cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1867+ cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
18591868 """
18601869
18611870 if (
@@ -1884,6 +1893,7 @@ def sample( # type: ignore[override]
18841893 verbose = verbose ,
18851894 seg = seg ,
18861895 cfg = cfg ,
1896+ cfg_fill_value = cfg_fill_value ,
18871897 )
18881898
18891899 if save_intermediates :
0 commit comments