def Dataset(file_pattern, batch_size, num_epochs=1):
    """Build a batched CSV dataset for the files matching ``file_pattern``.

    Args:
        file_pattern: File path or glob pattern forwarded to
            ``tf.data.experimental.make_csv_dataset``.
        batch_size: Number of records combined into each batch.
        num_epochs: How many times the data is repeated; defaults to 1.

    Returns:
        Whatever ``make_csv_dataset`` returns for these settings, with the
        ``LABEL_NAME`` column split out as the label.
    """
    logging.info('Creating Dataset from %s', file_pattern)
    csv_dataset = tf.data.experimental.make_csv_dataset(
        file_pattern=file_pattern,
        batch_size=batch_size,
        label_name=LABEL_NAME,
        num_epochs=num_epochs,
        # Per the TensorFlow docs this caps the rows sampled for column type
        # inference when no record defaults are supplied.
        num_rows_for_inference=10,
    )
    return csv_dataset
# Build the dataset from `expanded`, with the batch size taken from FLAGS.
# NOTE(review): despite its name, the value bound here is whatever Dataset()
# returns, not a zero-argument callable -- confirm how it is consumed.
input_fn = Dataset(file_pattern=expanded, batch_size=FLAGS.batch_size)
def collect_unique_tokens(input_fn):
    """Collect the distinct values seen in every categorical column.

    The dataset is built and drained inside a private ``tf.Graph`` so that no
    ops are added to the caller's default graph.

    Args:
        input_fn: Zero-argument callable. It is invoked inside the private
            graph and must return an object exposing
            ``make_one_shot_iterator()`` whose elements unpack into a
            ``(features, labels)`` pair, where ``features`` can be indexed by
            each name in ``CATEGORICAL_COLUMNS``.

    Returns:
        A dict mapping each name in ``CATEGORICAL_COLUMNS`` to the set of
        values observed for that column across the whole dataset. The values
        are stored exactly as ``Session.run`` yields them (presumably byte
        strings for string columns -- confirm against the data).
    """
    logging.info('Creating vocabulary...')
    vocabulary_dict = {item: set() for item in CATEGORICAL_COLUMNS}

    graph = tf.Graph()
    with graph.as_default():
        iterator = input_fn().make_one_shot_iterator()
        # Labels are not needed for the vocabulary, so they are never fetched.
        t_features, _ = iterator.get_next()

        with tf.Session(graph=graph) as sess:
            while True:
                # Keep the try body minimal: only fetching the next batch can
                # signal that the iterator is exhausted.
                try:
                    features = sess.run(t_features)
                except tf.errors.OutOfRangeError:
                    break
                for item in CATEGORICAL_COLUMNS:
                    vocabulary_dict[item].update(features[item])

    return vocabulary_dict
TensorFlow 建个小图
最新推荐文章于 2024-08-13 14:39:04 发布