需求背景
我在做本科毕设,题目是是视觉问答,用的keras框架。
视觉问答需要两方面的输入数据,一是图片,二是问题;输出是答案。
提取图片特征是利用vgg16或者faster-rcnn,直接保存到csv文件当中,但是问题就在于图片特征文件可能会太大无法记一次加载进内存。
每一张图片对应的问答对数量是不固定的,4-10个问答对不等。
问答对存储格式:img_id|question|answer
图片:保存有全部图片特征的一整个img_fetures.csv文件,索引为图片id。
思路
今天稍微有了思路,记载一下(但是还没有写代码测试,现在只是写个伪代码简单进行记录)。
没必要把图片特征只存在一个csv文件,完全可以一条数据一个csv文件,
即1.csv 、2.csv 、 3.csv 、 ·······等等。
将csv文件命名为(图片id.csv),以此对应图片与图片特征,这样我们就有好多个csv文件
而问答对数据作为txt文档或者其他格式,其实大小并不大,可以一次性加载内存。
import keras
import os
import math
import pandas as pd
import numpy as np
from keras.layers import Dense
class DataGenerator(keras.utils.Sequence):
def __init__(self, datas, batch_size=128, shuffle=True):
#datas就是已经加载进内存的所有的问答对数据,也可以传入问答对路径,在init函数里读进来,都一样。
self.batch_size = batch_size
self.datas = datas
self.indexes = np.arange(len(self.datas))
self.shuffle = shuffle
def __len__(self):
#计算每一个epoch的迭代次数 即总长度/batch_size
return math.ceil(len(self.datas) / float(self.batch_size))
def __getitem__(self, index):
# 生成batch_size个索引
batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# 根据索引获取问答对即(img_id|question|answer的形式)
batch_datas = [self.datas[k] for k in batch_indexs]
# 生成数据
img,question,answer = self.data_generation(batch_datas)
return img,question,answer
def on_epoch_end(self):
#在每一次epoch结束是否需要进行一次随机,重新随机一下index
if self.shuffle == True:
np.random.shuffle(self.indexes)
def data_generation(self, batch_datas):
#batch_datas形式为batch_szie个mg_id|question|answer
img_feature = np.zeros((batch_size,36,2048))
question_vec = np.zeros(batch_size,14)
answer_label = np.zeros(batch_size,1000)
for i, items in enumerate(batch_datas):
img_id,que,ans = items .split("|")
img_csv_path = os.path.join("path","img_id.csv")
img_feature[i] = pd.read_csv( img_csv_path ) #根据自己的csv文件格式进行调整,正确读取csv数据即可。
question_vec = get_question_vec() #得到question编码
answer_label = get_answer_label() #得到答案编码
return img_feature ,question_vec ,answer_label
#得到Generator实例,提前读取问答对文本,传入Generator
train_generator = DataGenerator(train_datas)
model.fit_generator(training_generator, epochs=50)
其实主要的就是不要将图片特征保存在一个csv文件,每个图片保存一个csv即可。
再次重申:但是还没有写代码测试,现在只是写个伪代码简单进行记录,可能思路还有不足