分享朋友的机器学习应用案例:使用机器学习实现财富自由(www.abuquant.com)
虽然也可以直接使用常见的数据格式(例如 CSV),但使用 TFRecord 更好:
- protobuf 格式的传输更快。
- 结构统一,相当于屏蔽了底层的数据结构。
import tensorflow as tf
import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import pandas as pd
# Use a large default canvas for every matplotlib figure drawn afterwards.
plt.rcParams.update({"figure.figsize": (20, 10)})

# Load the training CSV into a DataFrame and preview its first five rows.
train_df = pd.read_csv(filepath_or_buffer='train.csv')
display(train_df.head(n=5))
| | label | pixel0 | pixel1 | pixel2 | pixel3 | pixel4 | pixel5 | pixel6 | pixel7 | pixel8 | … | pixel774 | pixel775 | pixel776 | pixel777 | pixel778 | pixel779 | pixel780 | pixel781 | pixel782 | pixel783 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 785 columns
# Split the label column off from the pixel columns. ``pop`` removes the
# column from ``train_df`` in place and returns it, so ``train_df`` holds
# only the pixel columns afterwards.
label_df = train_df.pop('label')
train_labels = label_df.values
train_values = train_df.values

# Show the container type and shape of the features, then of the labels.
for array in (train_values, train_labels):
    display(type(array))
    display(array.shape)
numpy.ndarray
(42000, 784)
numpy.ndarray
(42000,)
Example protobuf:
message Example {
Features features = 1;
};
message Features {
map<string, Feature> feature = 1;
};
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
# Create the TFRecord writer. The ``with`` block guarantees the output file
# is flushed and closed even if serialization fails part-way through.
with tf.python_io.TFRecordWriter('csv_train.tfrecords') as writer:
    # Walk the feature rows and labels in lockstep (works on Python 2 and 3).
    for pixels, label in zip(train_values, train_labels):
        # The reader side decodes 'image_raw' with tf.uint8, so the bytes
        # written here must be one byte per pixel. read_csv normally parses
        # integer columns as int64, which would otherwise serialize 8 bytes
        # per pixel and decode to 8x too many values.
        # NOTE(review): assumes every pixel value fits in 0..255 - confirm.
        # ``tobytes`` replaces the deprecated ``tostring`` alias.
        image_raw = pixels.astype(np.uint8).tobytes()

        # Build the Example protobuf: one bytes feature holding the raw
        # image and one int64 feature holding the class label.
        example = tf.train.Example(
            features=tf.train.Features(feature={
                'image_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image_raw])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
            }))
        writer.write(record=example.SerializeToString())
从TFRecord中读取数据
# Reader that pulls serialized records out of TFRecord files, fed by a queue
# of input file names.
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(['csv_train.tfrecords'])
_, serialized_record = reader.read(filename_queue)

# Parsing spec for one serialized Example.
#   tf.FixedLenFeature -> parsed into a dense Tensor
#   tf.VarLenFeature   -> parsed into a SparseTensor
feature_spec = {
    "image_raw": tf.FixedLenFeature([], tf.string),
    "label": tf.FixedLenFeature([], tf.int64),
}
features = tf.parse_single_example(serialized_record, features=feature_spec)

# Reinterpret the raw byte string as uint8 values and narrow the label
# from int64 to int32.
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)
# Run the input pipeline and show the first two (label, image) records.
with tf.Session() as session:
    session.run(tf.local_variables_initializer())
    session.run(tf.global_variables_initializer())

    # The queue runners fill the filename queue on background threads; the
    # coordinator is what lets us shut those threads down again.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)
    try:
        for _step in range(2):
            image, label = session.run([images, labels])
            display(label)
            display(image)
            # Parenthesized form prints the same text on Python 2 and 3.
            print('-' * 40)
    finally:
        # Stop and join the queue-runner threads BEFORE the session closes.
        # Closing the session under running threads is what logs
        # "CancelledError: Enqueue operation was cancelled".
        coord.request_stop()
        coord.join(threads)
1
array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)
----------------------------------------
0
array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)
----------------------------------------
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.CancelledError'>, Enqueue operation was cancelled
[[Node: input_producer_3/input_producer_3_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer_3, input_producer_3/RandomShuffle)]]