记录一下自己系统完整的实现了一个pytorch项目,(使用了clip模型)
首先,学习pytorch工程规范,如一个项目中应包含data文件夹、modal文件夹、checkpoints文件夹、uitils文件夹、congfig、main、requiment、readme。具体参考链接:http://t.csdnimg.cn/Ts2dT
在数据阶段,学习了如何利用torch.utils.data.Dataloader读取自己的文件夹。在数据集类中应包含--init__,__getitem__,__len__。其中init存放数据集的地址和标签,getitem读取图片信息,len是数据集的长度。按照这种格式创建则可以使用dataloader进行读取。详情参考上面链接。
在具体使用阶段,发现常规代码是在getitem中读取图片并转换为tensor,及得到一个数组在转换为tensor。而clip模型的输入要求pil格式,所以如果在getitem中读取图片会报错,因此我选择将图片的地址信息作为数据送入dataloader,之后在送入clip模型。
网络部分,由于是使用了clip的编码层和自己创建的映射层,定义了两个model,为了实现冻结clip编码层,在Adam函数中只更新映射层参数。实现了clip编码层冻结。
在训练阶段:
clip模型输出的特征是float16,而映射层参数是float32,因此要转换统一。.to(torch,float32)
文本特征使用映射层提取要放在for循环里,及每次loss反向传播更新网络后都要重新提取,不然loss会报错,说变量被修改。