def build_model():
    """Return the network to smoke-test.

    Placeholder: the original script called ``'Your model'('Your Parm')``,
    which can never run because a ``str`` is not callable. Replace the body
    of this function with the real constructor call.

    Raises:
        NotImplementedError: until a real model is plugged in.
    """
    raise NotImplementedError(
        "Replace build_model() with the model under test, "
        "e.g. `return YourModel(your_params)`."
    )


def make_dummy_input(shape=(1, 3, 640, 640), device='cpu'):
    """Return a standard-normal tensor of ``shape`` min-max scaled to [0, 1].

    Args:
        shape: Tensor shape; defaults to one 3x640x640 image.
        device: Device to allocate the tensor on.

    Returns:
        A float tensor of ``shape`` whose minimum is 0 and maximum is 1. A
        zero-range tensor (e.g. a single element) is returned as all zeros
        instead of the NaNs that 0/0 would produce.
    """
    x = torch.randn(shape, device=device)
    lo, hi = x.min(), x.max()
    span = hi - lo
    if span > 0:
        return (x - lo) / span
    return torch.zeros_like(x)


if __name__ == '__main__':
    # NOTE(review): neither value is used below (the dummy input has batch 1);
    # presumably meant for the input batch / model constructor — confirm.
    batch_size = 2
    encoder_out_dim = 256

    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = make_dummy_input((1, 3, 640, 640), device=device)
    model = build_model().to(device)
    print(model)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        out = model(x)
    print('out shape:', out.shape)