import threading
from queue import Empty, Full, Queue
class Worker:
    """Two-stage batch pipeline meant to be driven by background threads.

    ``produce`` turns consecutive ``"<image path> <label>"`` lines into
    ``[images, labels]`` batches and puts them on ``train_batch_queue``.
    ``comsumer`` moves batches from ``train_batch_queue`` to
    ``cuda_batch_queue``.  Several threads may run ``produce`` on the same
    instance; each call builds its batch in local variables, so the only
    shared mutable state is the line cursor guarded by ``lock_index``.

    Backpressure comes from the queues themselves: both loops block on
    ``put`` while the target queue is full, so the queues should be created
    with a ``maxsize``.

    Args:
        train_batch_queue: queue that receives parsed ``[images, labels]``.
        cuda_batch_queue: queue that receives batches forwarded by
            ``comsumer``.
        lines: sequence of ``"<image path> <label>"`` strings (the original
            author notes it is expected to be shuffled already).
        batch: number of lines taken per batch; must be >= 1.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """

    # Seconds a blocked queue operation waits before re-checking the stop flag.
    _POLL_SECONDS = 0.05

    def __init__(self, train_batch_queue, cuda_batch_queue, lines, batch):
        if batch < 1:
            raise ValueError(
                "batch must be a positive integer, got {!r}".format(batch))
        self.train_batch_queue = train_batch_queue  # parsed batches go here
        self.cuda_batch_queue = cuda_batch_queue    # forwarded batches go here
        self.lines = lines                          # "<path> <label>" records
        self.length = len(lines)                    # number of records
        self.lock_index = threading.Lock()          # guards self.index
        self.index = 0                              # next record to hand out
        self.batch = batch                          # records per batch
        self._stop_event = threading.Event()        # set by stop()

    def stop(self):
        """Ask every running ``produce``/``comsumer`` loop to return."""
        self._stop_event.set()

    def _claim_lines(self):
        """Return the next ``batch`` lines and advance the shared cursor.

        The cursor wraps around cyclically, so no tail records are dropped
        and a ``batch`` larger than the data set still works.
        """
        with self.lock_index:
            start = self.index
            self.index = (start + self.batch) % self.length
        return [self.lines[(start + offset) % self.length]
                for offset in range(self.batch)]

    @staticmethod
    def _parse_line(line):
        """Parse one record into ``(image, label)``; ``None`` if malformed."""
        try:
            img, label = line.strip().split(' ')
            return img, float(label)
        except ValueError:
            # Wrong number of fields or a non-numeric label: skip the record.
            return None

    def _load_batch(self):
        """Build one batch from the next claimed lines, skipping bad records."""
        images, labels = [], []
        for line in self._claim_lines():
            parsed = self._parse_line(line)
            if parsed is None:
                continue
            img, label = parsed
            # Placeholder: image decoding / preprocessing belongs here.
            images.append(img)
            labels.append(label)
        return images, labels

    def _put(self, target_queue, item):
        """Blocking put that gives up once ``stop`` was requested.

        Returns True when the item was enqueued, False when stopped first.
        """
        while not self._stop_event.is_set():
            try:
                target_queue.put(item, timeout=self._POLL_SECONDS)
            except Full:
                continue
            return True
        return False

    def produce(self):
        """Producer loop: enqueue ``[images, labels]`` batches until stopped.

        Returns immediately when there are no lines to read.  Malformed
        records are skipped, so a batch can hold fewer than ``batch`` items;
        a batch with no valid record at all is not enqueued.
        """
        if self.length == 0:
            return
        while not self._stop_event.is_set():
            images, labels = self._load_batch()
            if not images:
                continue
            if not self._put(self.train_batch_queue, [images, labels]):
                return

    def comsumer(self):
        """Consumer loop: forward batches to ``cuda_batch_queue`` until stopped.

        The method name keeps its original spelling for existing callers;
        ``consumer`` is provided as an alias.
        """
        while not self._stop_event.is_set():
            try:
                imgs, labels = self.train_batch_queue.get(
                    timeout=self._POLL_SECONDS)
            except Empty:
                continue
            # Placeholder: the host-to-GPU conversion is not implemented
            # here, the batch is forwarded unchanged.
            if not self._put(self.cuda_batch_queue, [imgs, labels]):
                return

    # Correctly spelled alias; ``comsumer`` stays for backward compatibility.
    consumer = comsumer
if __name__ == '__main__':
    # Demo driver: load the "<image path> <label>" listing, start the loader
    # threads and print every batch that reaches the final queue.

    # The context manager closes the file as soon as the lines are in memory.
    with open('train.txt') as listing:
        txt_lines = listing.readlines()

    train_batch_queue = Queue(maxsize=3)
    cuda_batch_queue = Queue(maxsize=1)
    worker = Worker(train_batch_queue, cuda_batch_queue, txt_lines, 2)

    # The worker loops never return on their own, so the threads are started
    # as daemons: otherwise the process would hang forever after the loop
    # below has finished.
    for _ in range(4):
        t = threading.Thread(target=worker.produce, args=(), daemon=True)
        t.start()
    worker_consumer = threading.Thread(
        target=worker.comsumer, args=(), daemon=True)
    worker_consumer.start()

    epoches = 100000
    for i in range(epoches):
        # Blocks until the consumer thread has forwarded the next batch.
        imgs, labels = cuda_batch_queue.get(block=True)
        print(imgs)
        print(labels)
        print('-'*10)
# NOTE: this code is still incomplete - there is no logic yet for deciding when to stop (kill) the worker threads.