Commit ec7f8af

[ConvNeXT] Fix drop_path_rate (#17280)

* Fix drop_path_rate
* Fix TF's drop path rate

1 parent a26ab95

2 files changed: 8 additions & 8 deletions

src/transformers/models/convnext/modeling_convnext.py

Lines changed: 4 additions & 4 deletions
@@ -209,8 +209,9 @@ class ConvNextEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.stages = nn.ModuleList()
-        drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
-        cur = 0
+        drop_path_rates = [
+            x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
+        ]
         prev_chs = config.hidden_sizes[0]
         for i in range(config.num_stages):
             out_chs = config.hidden_sizes[i]
@@ -220,10 +221,9 @@ def __init__(self, config):
                 out_channels=out_chs,
                 stride=2 if i > 0 else 1,
                 depth=config.depths[i],
-                drop_path_rates=drop_path_rates[cur],
+                drop_path_rates=drop_path_rates[i],
             )
             self.stages.append(stage)
-            cur += config.depths[i]
             prev_chs = out_chs
 
     def forward(
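
For illustration only (not part of the commit): a minimal sketch of the pattern the fixed PyTorch code relies on, assuming example values for depths and drop_path_rate (not the library defaults). torch.linspace produces one linearly increasing rate per block across the whole network, and split(depths) groups those rates into one list per stage, so each stage receives depths[i] rates instead of a single scalar, and the cur bookkeeping index is no longer needed.

import torch

depths = [3, 3, 9, 3]        # blocks per stage (example values, not config defaults)
drop_path_rate = 0.5         # peak stochastic-depth rate (example value)

# One rate per block, grouped into one sub-list per stage:
drop_path_rates = [
    x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
]

for i, rates in enumerate(drop_path_rates):
    # Each sub-list has length depths[i] and can be indexed by the stage index i.
    print(f"stage {i}: {[round(r, 3) for r in rates]}")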

src/transformers/models/convnext/modeling_tf_convnext.py

Lines changed: 4 additions & 4 deletions
@@ -235,8 +235,9 @@ class TFConvNextEncoder(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.stages = []
-        drop_path_rates = [x for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
-        cur = 0
+        drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+        drop_path_rates = tf.split(drop_path_rates, config.depths)
+        drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
         prev_chs = config.hidden_sizes[0]
         for i in range(config.num_stages):
             out_chs = config.hidden_sizes[i]
@@ -246,11 +247,10 @@ def __init__(self, config, **kwargs):
                 out_channels=out_chs,
                 stride=2 if i > 0 else 1,
                 depth=config.depths[i],
-                drop_path_rates=drop_path_rates[cur],
+                drop_path_rates=drop_path_rates[i],
                 name=f"stages.{i}",
             )
             self.stages.append(stage)
-            cur += config.depths[i]
             prev_chs = out_chs
 
     def call(self, hidden_states, output_hidden_states=False, return_dict=True):
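
Again for illustration only (not part of the commit): the equivalent grouping on the TensorFlow side, with the same assumed example values. tf.split accepts a list of segment sizes, so the per-stage structure mirrors the PyTorch version.

import tensorflow as tf

depths = [3, 3, 9, 3]        # blocks per stage (example values, not config defaults)
drop_path_rate = 0.5         # peak stochastic-depth rate (example value)

rates = tf.linspace(0.0, drop_path_rate, sum(depths))  # one rate per block
per_stage = tf.split(rates, depths)                     # one tensor per stage
drop_path_rates = [x.numpy().tolist() for x in per_stage]

for i, stage_rates in enumerate(drop_path_rates):
    print(f"stage {i}: {[round(r, 3) for r in stage_rates]}")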
