Reading data this way is fairly cumbersome: map processes the file line by line, so you have to reorganize the fields into columns yourself. Using decode_csv, the data can be parsed directly into columns, which is much simpler.
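For reference, a minimal sketch of what the decode_csv variant could look like, assuming the same '|'-delimited format as data.txt below; the record_defaults and the function name are illustrative, and the full example that follows still uses the manual map-based parsing:
def parse_line_with_decode_csv(line):
    # one default per column: the label plus the four feature columns, all read as strings
    record_defaults = [["0"], [""], [""], [""], [""]]
    columns = tf.decode_csv(line, record_defaults=record_defaults, field_delim="|")
    label, features = columns[0], columns[1:]
    return features, label
# dataset = tf.data.TextLineDataset(["./data.txt"]).map(parse_line_with_decode_csv)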
import tensorflow as tf
# lookup-table utilities imported for a vocabulary table (table_used) that is
# built outside this snippet; they are not used directly below
from tensorflow.contrib.lookup import HashTable
from tensorflow.contrib.lookup import TextFileIdTableInitializer
from tensorflow.contrib.lookup import IdTableWithHashBuckets
label_idx = 0
fields_type = ["int", "tags", "weights", "indexes", "weights"]
fields_idx = [1, 2, 3, 4]
fields_count = len(fields_idx)
new_fields_idx = []
for i in fields_idx:
    if i < label_idx:
        new_fields_idx.append(i)
    else:
        new_fields_idx.append(i - 1)
## feature parse ##
def input_fn(file_list, epoches=1, batch_size=2, shuffle=False):
    def parse_index(indexes, sep=",", default_value="0"):
        # split the index strings on sep, replace empty entries with default_value,
        # and pack them into a SparseTensor of int32 ids
        w = tf.string_split(indexes, sep)
        process_str = tf.map_fn(lambda x: tf.cond(tf.equal(tf.string_strip(x), ""),
                                                  lambda: default_value,
                                                  lambda: x),
                                elems=w.values)
        indexes_number = tf.string_to_number(process_str, tf.int32)
        spt_index = tf.SparseTensor(indices=w.indices,
                                    values=indexes_number,
                                    dense_shape=w.dense_shape)
        return spt_index
    def parse_weight(weights, sep=",", default_value="1"):
        # same as parse_index, but the values are parsed as float32 weights
        w = tf.string_split(weights, sep)
        process_str = tf.map_fn(lambda x: tf.cond(tf.equal(tf.string_strip(x), ""),
                                                  lambda: default_value,
                                                  lambda: x),
                                elems=w.values)
        wgt_number = tf.string_to_number(process_str, tf.float32)
        spt_wgt = tf.SparseTensor(indices=w.indices,
                                  values=wgt_number,
                                  dense_shape=w.dense_shape)
        return spt_wgt
    def parse_split(line):
        # split one raw line on '|' and keep the label plus the selected fields
        parse_res = tf.string_split([line], delimiter='|')
        values = parse_res.values
        label = values[label_idx]
        features_values = [label]
        for idx in fields_idx:
            features_values.append(values[idx])
        return features_values, label
    def parse_feature(f, y):
        ## parse feature-value ##
        # f is the batched tuple produced by parse_split:
        # (label, tags, weights, indexes, weights), each a [batch_size] string tensor
        spt_weight = parse_weight(f[-1])       # last column: the weight strings
        spt_index = parse_index(f[-2], ",")    # second-to-last column: the index strings
        return spt_index, spt_weight, y
    # read the input files line by line
    dataset = tf.data.TextLineDataset(file_list)
    # parse lines in parallel
    dataset = dataset.map(parse_split, num_parallel_calls=2)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=5000)
    dataset = dataset.repeat(count=epoches)
    # prefetch to overlap reading with computation; the buffer size is a tunable
    # (e.g. some multiple of the number of CPU cores)
    dataset = dataset.prefetch(batch_size * 12)
    # drop the last batch if it has fewer than batch_size examples
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    # one-shot iterator: the data can only be consumed once
    iterator = dataset.make_one_shot_iterator()
    features_values, y = iterator.get_next()
    spt_index, spt_weight, y = parse_feature(features_values, y)
    return spt_index, spt_weight, y
# https://blog.youkuaiyun.com/cjopengler/article/details/78150650
from tensorflow.python.training import coordinator
with tf.Session() as sess:
    # table_used.init.run()  # vocabulary lookup table built elsewhere (not shown); initialize it here if used
    # build the input pipeline; it returns the batched features and targets
    spt_index, spt_weight, y = input_fn(["./data.txt"], batch_size=3)
    # coordinator / queue runners kept from the original queue-based code;
    # a pure tf.data pipeline does not strictly need them
    coord = coordinator.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    num_step = 1
    for step in range(num_step):
        idx_val, wgt_val, y_val = sess.run([spt_index, spt_weight, y])
        print('features:', idx_val, wgt_val, y_val)
    coord.request_stop()
    coord.join(threads)
Data (data.txt); each line has the format label|tags|weights|indexes|weights:
1|click,show,李志林,股灾,演变|21.0,120.0,1,1,1|1,2,3|0.1
1|click,show,李志林,股灾,演变|21.0,120.0,1,1,1|9,2,3|0.2,0.3
1|click,show,杨幂,股灾,演变|21.0,120.0,1,1,1|8,2,3|0.4,0.5,0.6
1|click,show,开心,股灾,演变|21.0,120.0,1,1,1|7,2,3|0.1,0.2,0.3
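As an aside (an assumption, not part of the original code): SparseTensors like spt_index and spt_weight are typically consumed by a weighted sparse embedding lookup. The vocab_size and embedding_dim below are made-up values, and this only works when the index and weight columns of an example contain the same number of entries:
# graph construction like this would go right after input_fn returns, before sess.run
vocab_size, embedding_dim = 100, 8  # assumed sizes, for illustration only
embeddings = tf.get_variable("embeddings", shape=[vocab_size, embedding_dim])
# embedding_lookup_sparse expects int64 ids, so cast the parsed int32 indexes
sp_ids = tf.SparseTensor(indices=spt_index.indices,
                         values=tf.cast(spt_index.values, tf.int64),
                         dense_shape=spt_index.dense_shape)
pooled = tf.nn.embedding_lookup_sparse(embeddings,
                                       sp_ids=sp_ids,
                                       sp_weights=spt_weight,
                                       combiner="sum")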