卸载torchsummary,按照torch-summary,中间多了一个横杠,它是torchsummary的加强版。这里我们卸掉torchsummary库,安装torch-summary库,上面的问题就可以解决了。
在anaconda的环境里面,先卸载再按照
pip uninstall torchsummary
pip install torch-summary==1.4.4
我尝试了好几种方式打印yolov5 中6.0版本的网络结构,还是torchsummary好用!!
from models.yolo import * from torchsummary import summary FILE = Path(__file__).resolve() ROOT = FILE.parents[1] # YOLOv5 root directory if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) # add ROOT to PATH # ROOT = ROOT.relative_to(Path.cwd()) # relative if __name__ == '__main__': parser = argparse.ArgumentParser() # 官方权重 torch.float16 parser.add_argument('--weights', type=str, default=r'D:\codePythonBinocularvision\yolov5-6.0\weight\yolov5s.pt', help='weights path') # 自训练权重 torch.float16 # parser.add_argument('--weights', type=str, default=r'D:\codePythonBinocularvision\yolov5-6.0\weight\best.pt', # help='weights path') opt = parser.parse_args() # Load pytorch model on GPU device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 检查是否有可用的 GPU model = torch.load(opt.weights, map_location=device)['model'] # 将模型加载到设备上 # 转换为FloatTensor float32 model = model.float() # 将模型的权重和参数移到 GPU 上 model.to(device) # 打印模型结构 print(model) # 遍历各个层次并打印结构信息 # for name, layer in model.named_children(): # print(f"\nLayer: {name}") # print(layer) for name, parameters in model.named_parameters(): # print(name, ':', parameters.size()) print(parameters.dtype) # 官方代码打印结构 model_info(model, True) # yolo.py文件结构里面的打印模型 parser = argparse.ArgumentParser() parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--profile', action='store_true', help='profile model speed') opt = parser.parse_args() opt.cfg = check_yaml(opt.cfg) # check YAML print_args(FILE.stem, opt) set_logging() device = select_device(opt.device) # Create model model = Model(opt.cfg).to(device) model.train() # 使用summary函数来打印模型 input_shape = [640, 640] m = Model(opt.cfg).to(device) summary(m, (3, input_shape[0], input_shape[1]))