fix bugs when tracing of swin_transformer

pull/1673/head
root 1 year ago
parent 624266148d
commit 00001f91d0

@ -0,0 +1,13 @@
import timm
import torch
if __name__ == "__main__":
model = timm.create_model("swin_s3_tiny_224", pretrained=False)
model.eval()
input = torch.randn(1, 3, 224, 224)
tracemodel = torch.jit.trace(model,input)
x= torch.randn(5, 3, 224, 224)
y = model(x)
y_traced = tracemodel(x)
print("diff between trace and untraced:", torch.max(abs(y-y_traced)))

@ -126,7 +126,12 @@ def window_reverse(windows, window_size: int, H: int, W: int):
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
b_tmp = windows.shape[0] / (H * W / window_size / window_size)
if torch.is_tensor(b_tmp):
B = b_tmp.int()
else:
B = int(b_tmp)
#B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

Loading…
Cancel
Save