本系列为tensorflow官方教程与文档学习笔记,结合个人理解提取其中的关键内容,便于日后复习。
1. 数据加载
1.1 通过tf.data
加载CSV数据
通过tf.data.experimental.make_csv_dataset
将CSV文件读入dataset对象。几个重要的参数:
batch_size
:指定单个batch下的数据记录数目;column_names
:指定数据的列名,若未给定此参数,默认从数据文件首行获取;label_name
:指定作为label的数据列列名;na_value
:将指定的额外字符也认作NaN;num_epochs
:数据集重复的次数;
例:
TRAIN_DATA_PATH = "E:/Notes/Projects/tensorflow_to_pro/eat_tensorflow2_in_30_days/data/titanic/train.csv"
TEST_DATA_PATH = "E:/Notes/Projects/tensorflow_to_pro/eat_tensorflow2_in_30_days/data/titanic/test.csv"
def get_dataset(file_path):
dataset = tf.data.experimental.make_csv_dataset(
file_path,
batch_size = 12,
label_name = 'Survived',
na_value = '?',
num_epochs = 1,
ignore_errors = True)
return dataset
raw_train_data = get_dataset(TRAIN_DATA_PATH)
raw_test_data = get_dataset(TEST_DATA_PATH)
dataset 中的每个条目都是一个批次,用一个元组(多个样本,多个标签)表示。样本中的数据组织形式是以列为主的张量(而不是以行为主的张量),每条数据中包含的元素个数就是批次大小(这个示例中是 12)。
# dataset可以用于迭代
examples, labels = next(iter(raw_train_data))
print("EXAMPLES: \n", examples,