
Commit 93b5436

[BiT] Small patch fix (#20657)

* patch fix for `fp16`
* use `np` instead

1 parent 0526a07 commit 93b5436

File tree: 1 file changed (+3 −1 lines)


src/transformers/models/bit/modeling_bit.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 import math
 from typing import Optional, Tuple
 
+import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -592,7 +593,8 @@ def __init__(self, config: BitConfig):
         dilation = 1
 
         layer_dropouts = [
-            x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
+            x.tolist()
+            for x in torch.Tensor(np.linspace(0, config.drop_path_rate, sum(config.depths))).split(config.depths)
         ]
 
         for stage_idx, (current_depth, current_hidden_size, layer_dropout) in enumerate(
```
