Skip to content

Commit bf174f9

Browse files
Refactor TFSwinLayer to increase serving compatibility (#18352)
* Refactor `TFSwinLayer` to increase serving compatibility

  Signed-off-by: Seunghwan Hong <[email protected]>

* Fix missed parameters while refactoring

  Signed-off-by: Seunghwan Hong <[email protected]>

* Fix window_reverse to calculate batch size

  Signed-off-by: Seunghwan Hong <[email protected]>
  Co-Authored-By: amyeroberts <[email protected]>

Co-authored-by: amyeroberts <[email protected]>
1 parent 575aa6e commit bf174f9

File tree

1 file changed: +10 -8 lines changed

src/transformers/models/swin/modeling_tf_swin.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
226226
"""
227227
Merges windows to produce higher resolution features.
228228
"""
229-
x = shape_list(windows)[0]
229+
x = tf.shape(windows)[0]
230230
y = tf.cast(height * width / (window_size * window_size), tf.int32)
231-
batch_size = int(x / y)
231+
batch_size = tf.math.floordiv(x, y)
232232
windows = tf.reshape(
233233
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
234234
)
@@ -695,16 +695,18 @@ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: i
695695
img_mask = tf.expand_dims(img_mask, -1)
696696
img_mask = tf.expand_dims(img_mask, 0)
697697

698-
mask_windows = window_partition(img_mask, self.window_size)
699-
mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
698+
mask_windows = window_partition(img_mask, window_size)
699+
mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
700700
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
701701
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
702702
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
703703
return attn_mask
704704

705-
def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
706-
pad_right = (self.window_size - width % self.window_size) % self.window_size
707-
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
705+
def maybe_pad(
706+
self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
707+
) -> Tuple[tf.Tensor, tf.Tensor]:
708+
pad_right = (window_size - width % window_size) % window_size
709+
pad_bottom = (window_size - height % window_size) % window_size
708710
pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
709711
hidden_states = tf.pad(hidden_states, pad_values)
710712
pad_values = tf.reshape(pad_values, (-1,))
@@ -730,7 +732,7 @@ def call(
730732
hidden_states = self.layernorm_before(hidden_states, training=training)
731733
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
732734
# pad hidden_states to multiples of window size
733-
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
735+
hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)
734736

735737
_, height_pad, width_pad, _ = shape_list(hidden_states)
736738
# cyclic shift

0 commit comments

Comments
 (0)