@@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
226226 """
227227 Merges windows to produce higher resolution features.
228228 """
229- x = shape_list (windows )[0 ]
229+ x = tf . shape (windows )[0 ]
230230 y = tf .cast (height * width / (window_size * window_size ), tf .int32 )
231- batch_size = int ( x / y )
231+ batch_size = tf . math . floordiv ( x , y )
232232 windows = tf .reshape (
233233 windows , (batch_size , height // window_size , width // window_size , window_size , window_size , - 1 )
234234 )
@@ -695,16 +695,18 @@ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: i
695695 img_mask = tf .expand_dims (img_mask , - 1 )
696696 img_mask = tf .expand_dims (img_mask , 0 )
697697
698- mask_windows = window_partition (img_mask , self . window_size )
699- mask_windows = tf .reshape (mask_windows , (- 1 , self . window_size * self . window_size ))
698+ mask_windows = window_partition (img_mask , window_size )
699+ mask_windows = tf .reshape (mask_windows , (- 1 , window_size * window_size ))
700700 attn_mask = tf .expand_dims (mask_windows , 1 ) - tf .expand_dims (mask_windows , 2 )
701701 attn_mask = tf .where (attn_mask != 0 , float (- 100.0 ), attn_mask )
702702 attn_mask = tf .where (attn_mask == 0 , float (0.0 ), attn_mask )
703703 return attn_mask
704704
705- def maybe_pad (self , hidden_states : tf .Tensor , height : int , width : int ) -> Tuple [tf .Tensor , tf .Tensor ]:
706- pad_right = (self .window_size - width % self .window_size ) % self .window_size
707- pad_bottom = (self .window_size - height % self .window_size ) % self .window_size
705+ def maybe_pad (
706+ self , hidden_states : tf .Tensor , window_size : int , height : int , width : int
707+ ) -> Tuple [tf .Tensor , tf .Tensor ]:
708+ pad_right = (window_size - width % window_size ) % window_size
709+ pad_bottom = (window_size - height % window_size ) % window_size
708710 pad_values = [[0 , 0 ], [0 , pad_bottom ], [0 , pad_right ], [0 , 0 ]]
709711 hidden_states = tf .pad (hidden_states , pad_values )
710712 pad_values = tf .reshape (pad_values , (- 1 ,))
@@ -730,7 +732,7 @@ def call(
730732 hidden_states = self .layernorm_before (hidden_states , training = training )
731733 hidden_states = tf .reshape (hidden_states , (batch_size , height , width , channels ))
732734 # pad hidden_states to multiples of window size
733- hidden_states , pad_values = self .maybe_pad (hidden_states , height , width )
735+ hidden_states , pad_values = self .maybe_pad (hidden_states , window_size , height , width )
734736
735737 _ , height_pad , width_pad , _ = shape_list (hidden_states )
736738 # cyclic shift
0 commit comments