if __name__ == '__main__':
    # Smoke test: push one random tensor through Net, then report the parameter
    # count and the output shape.
    # NOTE(review): 5-D input, presumably (batch, channels, D, H, W) for a 3-D
    # network -- confirm against Net's definition.
    x = torch.randn(1, 3, 20, 20, 16)  # named `x` to avoid shadowing builtin input()
    net = Net()
    print("params: ", sum(p.numel() for p in net.parameters()))
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        out = net(x)
    print(out.shape)
# Ensure the ./dataset output directory exists. exist_ok=True makes this
# idempotent and avoids the check-then-create race of exists() + mkdir().
os.makedirs('./dataset', exist_ok=True)
import torch

# Environment report: versions of torch / CUDA / cuDNN and the GPU in use.
print(torch.__version__)
# CUDA version torch was built with; may be None on CPU-only builds.
print(torch.version.cuda)
# cuDNN version; may be None when cuDNN is unavailable.
print(torch.backends.cudnn.version())
# get_device_name(0) raises when no CUDA device is usable, which would abort
# the whole script on CPU-only machines, so only query it when CUDA is present.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available")
def seed_everything(seed=3407):
    """Seed all global RNGs and force deterministic cuDNN behaviour.

    Seeds Python's ``random`` module, NumPy's global RNG and torch's CPU and
    CUDA generators with *seed*. It then selects deterministic cuDNN kernels
    and disables cuDNN benchmark autotuning so that repeated runs match.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import keeps the file's import block untouched

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe on machines without CUDA: torch silently ignores it there.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different kernels run-to-run, which breaks determinism.
    torch.backends.cudnn.benchmark = False


seed_everything(3407)
if __name__ == '__main__':
    # Profile UNet3Stage on CPU: run one forward pass, count MACs and
    # parameters with thop, then print GMACs, millions of params and the
    # output shape.
    from thop import profile  # third-party; imported first so a missing package fails fast

    device = torch.device('cpu')
    # Named `x` to avoid shadowing builtin input().
    # NOTE(review): 5-D input, presumably (batch, channels, D, H, W) -- confirm.
    x = torch.randn(1, 1, 48, 48, 32).to(device)
    net = UNet3Stage().to(device)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        out = net(x)
    macs, params = profile(net, inputs=(x,))
    print(macs / 1000000000)  # GMACs
    print(params / 1000000)   # parameters, in millions
    print(out.shape)