代码完成的功能主要为:设置fits源文件路径和保存分类结果路径;读取fits文件并进行预处理和分类;保存分类结果
把主要代码贴出来:
设置目录:
if __name__ == '__main__':
# app = QApplication(sys.argv)
# ex = Example()
# ex.show()
# sys.exit(app.exec_())
#file = 'F:\Auroral\experiment\datasetfortest'
file = '/home/zhongjia/aurora_data/ASI_data/test'
savepath = '/home/zhongjia/aurora_data/ASI_data/test/predict_result'
predict(file,savepath)
寻找file目录下的fits文件
def __getitem__(self, idx):
filename = self.root_dir + '/' + self.metas[idx]
if filename[-4:]=='fits':
img, standertime = readfits(filename)
print('yes')
if filename[-3:] =='img':
img,standertime = readimg(filename)
if self.transform is not None:
img = self.transform(img)
return img
调用模型并将结果保存到savepath目录:
def predict(files,saveresult):
batchsize=100
model = DenseNet169()
model.eval()
load_path = 'checkpoint/_200.pth.tar'
load_model(load_path,model)
cuda_flag = False
val_dataset = McDataset(
files,
transforms.Compose([
ResizeCV2(224, 224),
trans2pil(),
transforms.ToTensor(),
White(),]))
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda_flag else {}
test_loader = torch.utils.data.DataLoader(
val_dataset, batch_size= batchsize, shuffle=False, **kwargs)
originfile = os.listdir(files)
for batch_idx, data in enumerate(test_loader):
# data, target = data.cuda(), target.cuda()
data= Variable(data)
output = model(data)
output = F.softmax(output, dim=1)
pred = output.data.max(1)[1] # get the index of the max log-probability
for i in range(len(pred)):
if pred[i] == 1:
oldname = os.path.join(files,originfile[i+batch_idx * batchsize])
newname = os.path.join(saveresult,'throat_aurora')
if not os.path.exists(newname):
os.makedirs(newname)
newname = os.path.join(newname, originfile[i + batch_idx * batchsize])
shutil.copyfile(oldname, newname)
else:
oldname = os.path.join(files, originfile[i + batch_idx * batchsize])
newname = os.path.join(saveresult, 'not_throat_aurora')
if not os.path.exists(newname):
os.makedirs(newname)
newname = os.path.join(newname,originfile[i + batch_idx * batchsize])
shutil.copyfile(oldname, newname)
以上是部分功能代码,整个程序第一次成功运行完成后,我在test目录下看到了predict_result的预测结果文件:
当我对代码进行修改时,如果修改正确,程序正常运行,每次都能生成predict_result文件。
重点是!!!!!当我错误地修改代码后,程序运行报错,然后我再次将代码还原,即使和之前的代码一模一样,还是出现了以下报错:UnboundLocalError: local variable 'img' referenced before assignment
然后我一看是 img = self.transform(img)变量未定义就开始引用,妈的,我当时想到为什么之前同样的程序就能运行,难道之前程序是存在bug,只是偶尔成功运行。于是乎我对img的transform相关的函数进行了一步步仔细的检查,发现真的没有问题啊,不应该报错呢。别慌,还有一个地方未检查,那就是可能根本没找到fits,导致执行到transform出img未定义!!!
于是我 print('yes')发现还真的未执行,证明未找到fits, 我在jupyter lab和linux终端ls命令查看了test目录下文件,就只有一个fits文件啊,为什么会未找到fits呢???然后我想到隐藏文件:
ls -a 查看test目录下究竟有什么东西:
.ipynb_checkpoints文件隐藏在此!!!!!
删除它!!!!
然后程序正常运行了!
notebook运行出错的时候会生成.ipynb_checkpoints的隐藏文件,很可能导致你在该目录下寻找其他文件出错。