parent
624266148d
commit
00001f91d0
@ -0,0 +1,13 @@
|
|||||||
|
import timm
|
||||||
|
import torch
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model = timm.create_model("swin_s3_tiny_224", pretrained=False)
|
||||||
|
model.eval()
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
tracemodel = torch.jit.trace(model,input)
|
||||||
|
|
||||||
|
x= torch.randn(5, 3, 224, 224)
|
||||||
|
y = model(x)
|
||||||
|
y_traced = tracemodel(x)
|
||||||
|
print("diff between trace and untraced:", torch.max(abs(y-y_traced)))
|
Loading…
Reference in new issue