if __name__ == '__main__':
    # Smoke test: push a constant batch through sa_layer and print the
    # resulting output shape. Runs only when this file is executed directly.
    # (The original had this guard duplicated with an unindented body, which
    # is an IndentationError; it is now a single, properly indented block.)
    t = torch.ones(32, 3, 64, 64)
    # NOTE(review): sa_layer's signature is not visible here. If its first
    # argument is the channel count, this 3-channel input may not match 64 —
    # confirm against the sa_layer definition.
    model = sa_layer(64)
    y = model(t)
    print("print(y.shape)", y.shape)  # shape