Skip to content

Commit 61e1cbc

Browse files
dingyiweiglenn-jocher
authored andcommitted
Use permute() instead of 2 transpose()
1 parent 540ef0d commit 61e1cbc

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

models/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def forward(self, x):
8686
if self.conv is not None:
8787
x = self.conv(x)
8888
b, _, w, h = x.shape
89-
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
90-
return self.tr(p + self.linear(p)).unsqueeze(3).transpose(0, 3).reshape(b, self.c2, w, h)
89+
p = x.flatten(2).permute(2, 0, 1)
90+
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
9191

9292

9393
class Bottleneck(nn.Module):

0 commit comments

Comments
 (0)