# 1. 使用 decode_csv 读取数据 (read CSV rows with tf.decode_csv and a queue-based reader)
import tensorflow as tf
# Build a filename queue for the TF1 queue-based input pipeline.
# NOTE: tf.train.string_input_producer / tf.TextLineReader belong to the
# legacy queue API; tf.data (section 2) is the recommended replacement.
filenames = ['./s0000025_1.csv', './s0000025_2.csv']
# shuffle=True (also the default) yields the file names in random order.
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
# skip_header_lines=1 skips the header line at the start of every file.
TFReader = tf.TextLineReader(skip_header_lines=1)
# Each evaluation of `value` reads one line of text from the queued files.
key, value = TFReader.read(filename_queue)
# One default per column; the float defaults make every column float32 and
# are substituted for empty fields.
record_default = [0.0, 0.0, 0.0, 0.0]
# field_delim defaults to ","; it is passed explicitly for clarity.
col1, col2, col3, col4 = tf.decode_csv(
    value, record_defaults=record_default, field_delim=",")
# First three columns form the feature vector (shape [3]); col4 is the label.
features = tf.stack([col1, col2, col3])
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Queue runners fill the filename/reader queues on background threads.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(10):
            example, label = sess.run([features, col4])
            print(example, label)
    finally:
        # Stop and join the runner threads even if sess.run raises, so the
        # session is not closed while they are still running.
        coord.request_stop()
        coord.join(threads)
# 2. 使用 TextLineDataset 读取数据 (read CSV rows with tf.data.TextLineDataset)
def decode_line(line, num_columns=4):
    """Parse one CSV text line into a ``(features, label)`` pair.

    Every column is decoded as float32, with 0.0 used for empty fields.
    All columns except the last are stacked into a 1-D ``features`` tensor.
    The last column is returned as ``label``.

    Args:
        line: string tensor holding CSV text. It is presumably a scalar,
            i.e. one element of a ``TextLineDataset``; confirm at call site.
        num_columns: number of comma-separated columns per line. The default
            of 4 means three feature columns plus one label column.

    Returns:
        Tuple ``(features, label)``. ``features`` stacks the first
        ``num_columns - 1`` decoded columns; ``label`` is the last column.

    Raises:
        ValueError: if ``num_columns`` < 2, since at least one feature column
            and one label column are required.
    """
    if num_columns < 2:
        raise ValueError(
            "num_columns must be >= 2 (features + label), got %r" % (num_columns,))
    # One single-element default per column; 0.0 also fixes the dtype to float32.
    record_defaults = [[0.0] for _ in range(num_columns)]
    items = tf.decode_csv(line, record_defaults)
    # Stack explicitly rather than returning a Python list and relying on
    # tf.data's implicit list-to-tensor conversion.
    features = tf.stack(items[:-1])
    label = items[-1]
    return features, label
def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
    """Build a batched ``(features, label)`` dataset from one or more CSV files.

    The first (header) line of every file is skipped. Each remaining line is
    parsed by ``decode_line``. The examples are then optionally repeated and
    shuffled, and finally batched.

    Args:
        filename: path of a CSV file, a list/tuple of paths, or a string
            tensor of paths. Files are read in the given order.
        batch_size: number of examples per batch.
        is_shuffle: if True, shuffle examples with a 10000-element buffer.
        n_repeats: number of passes over the data. Values <= 0 mean a single
            pass (``repeat`` is not applied).

    Returns:
        A ``tf.data.Dataset`` yielding batched ``(features, label)`` pairs.
    """
    # Normalize a single path or a list of paths into a 1-D string tensor.
    filenames = tf.reshape(tf.convert_to_tensor(filename, dtype=tf.string), [-1])
    # Skip the header of each file. Calling .skip(1) on one TextLineDataset
    # built from several files would drop only the first file's header and
    # feed the other headers into decode_line, where decode_csv fails.
    dataset = tf.data.Dataset.from_tensor_slices(filenames).flat_map(
        lambda path: tf.data.TextLineDataset(path).skip(1))
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)  # for training
    dataset = dataset.map(decode_line)
    if is_shuffle:
        # NOTE: shuffling after repeat() lets examples from adjacent epochs mix.
        dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size)
    return dataset
filename = './s0000025_2.csv'
# Shuffled training pipeline producing batches of 5 examples.
training_dataset = create_dataset(filename=filename, batch_size=5, is_shuffle=True)

# Reinitializable iterator: it is defined by the dataset's structure (types
# and shapes) and is bound to a concrete dataset via make_initializer below.
iterator = tf.data.Iterator.from_structure(
    output_types=training_dataset.output_types,
    output_shapes=training_dataset.output_shapes,
)
features, labels = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)

sess = tf.Session()
# The iterator has to be initialized before get_next() can be evaluated.
sess.run(fetches=training_init_op)
train_batch = sess.run(fetches=[features, labels])
print("TRAIN\n", train_batch)