import torch
# Build the model; `yourNet` must be defined or imported elsewhere.
net = yourNet()
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location remaps the stored tensors onto `device` as they are loaded.
# Note that `net` itself is not moved to `device` anywhere in this snippet.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source (consider weights_only=True if the installed PyTorch
# version supports it).
state = torch.load('weights.pth', map_location=device)
# The checkpoint is indexed by child-module name, and each entry is passed to
# that child's load_state_dict. Children with no entry in the checkpoint are
# skipped and keep their current weights (partial loading); strict=True still
# requires an exact key match inside every child that is loaded.
for name, child in net.named_children():
    if name in state:
        child.load_state_dict(state[name], strict=True)