This example uses TensorFlow 2 to load CSV data into a tf.data.Dataset, using the classic Titanic passenger dataset.
1. Import the required libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import functools

# Print the name and version of each library
for i in [tf, np, pd]:
    print(i.__name__, ": ", i.__version__, sep="")
Output:
tensorflow: 2.2.0
numpy: 1.17.4
pandas: 0.25.3
2. Download and import the data
2.1 Download the data locally
trainDataUrl = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
testDataUrl = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"
trainFilePath = tf.keras.utils.get_file("trainTitanic.csv",trainDataUrl)
testFilePath = tf.keras.utils.get_file("testTitanic.csv",testDataUrl)
Output:
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 1s 29us/step
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/eval.csv
16384/13049 [=====================================] - 0s 28us/step
On Windows, the downloaded files are saved in <system drive>:\users\<username>\.keras\datasets.
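As a minimal sketch, the default download location can be worked out with the standard library alone. This assumes the TensorFlow 2.x defaults of tf.keras.utils.get_file, where the cache directory is ~/.keras and the subdirectory is datasets, and that ~/.keras is writable; newer Keras releases may also honor a KERAS_HOME environment variable, which this sketch ignores. The helper name default_keras_dataset_path is hypothetical.

```python
import os

def default_keras_dataset_path(fname):
    """Hypothetical helper: the path where tf.keras.utils.get_file stores
    `fname` by default (assumes cache dir ~/.keras, subdir "datasets")."""
    home = os.path.expanduser("~")
    return os.path.join(home, ".keras", "datasets", fname)

# On Windows this expands to something like C:\Users\<username>\.keras\datasets\...
print(default_keras_dataset_path("trainTitanic.csv"))
```

In practice the simplest check is to print trainFilePath, since get_file returns the full local path of the cached file; because the file stays cached there, later calls with the same file name reuse it instead of downloading it again.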
2.2 Load the data
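labelColumn = "survived"  # name of the column that holds the label
labels = [0, 1]

def getDataset(filePath, **kwargs):
    dataset = tf.data.experimental.make_csv_dataset(
        filePath,
        batch_size=5,            # 5 samples per batch, small enough to inspect
        label_name=labelColumn,  # split this column out as the label
        na_value="?",            # treat "?" as a missing value
        num_epochs=1,            # make a single pass over the file
        ignore_errors=True,      # skip records that fail to parse
        **kwargs)
    return dataset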
rawTrainData = getDataset(trainFilePath)
rawTestData = getDataset(testFilePath)
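make_csv_dataset yields (features, label) pairs, where features is a dict mapping each column name to one batch of values and the label column has been split out. The plain-Python sketch below imitates only that output structure so it can be inspected without TensorFlow; it is not how TensorFlow implements the function. The function name toy_csv_batches and the sample rows are hypothetical, and unlike make_csv_dataset (whose shuffle argument defaults to True) the sketch does not shuffle or infer column types: it keeps values as strings and maps the na_value string to None.

```python
import csv
import io

# Hypothetical sample rows in a Titanic-like layout; "?" marks a missing value.
CSV_TEXT = """survived,sex,age,class
0,male,22,Third
1,female,?,First
1,female,26,Third
"""

def toy_csv_batches(text, label_name, batch_size, na_value="?"):
    """Hypothetical stand-in for make_csv_dataset with num_epochs=1:
    yields (features, labels) per batch, where features maps each
    non-label column name to a list of values for that batch."""
    rows = list(csv.DictReader(io.StringIO(text)))
    for start in range(0, len(rows), batch_size):
        features, labels = {}, []
        for row in rows[start:start + batch_size]:
            for key, value in row.items():
                value = None if value == na_value else value  # NA -> None
                if key == label_name:
                    labels.append(value)
                else:
                    features.setdefault(key, []).append(value)
        yield features, labels

# Single pass over the rows; the last batch may be smaller than batch_size.
for features, labels in toy_csv_batches(CSV_TEXT, "survived", batch_size=2):
    print(features, labels)
```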
def showBatch(dataset):
    # Take a single batch, print each feature column, then the labels
    for batch, label in dataset.take(1):
        for key, value in batch.items():
            print("{:20s}:{}".format(key, value.numpy()))
        print("{:20s}:{}".format("label", label.numpy()))
showBatch(rawTrainData)
Output:
sex                 :[b'male' b'female' b'male' b'male' b'male']
age                 :[50. 30. 28. 31. 27.]
n_siblings_spouses  :[0 0 0 1 0]
parch               :[0 0 0 0 0]
fare                :[ 13. 106.425 8.4583 52. 8.6625]
class               :[b'Second' b'First' b'Third' b'First' b'Third']
deck                :[b'unknown' b'unknown' b'unknown' b'B' b'unknown']
embark_town         :[b'Southampton' b'Cherbourg' b'Queenstown' b'Southampton' b'Southampton']
alone               :[b'y' b'y' b'y' b'n' b'y']
label               :[0 1 0 0 1]
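The aligned key column in this output comes from the format spec "{:20s}" used in showBatch, which left-aligns a string and pads it with spaces to a width of 20. A small plain-Python check, with no TensorFlow involved (the helper format_row is hypothetical):

```python
def format_row(key, value):
    # "{:20s}" left-aligns `key` and pads it with spaces to width 20;
    # keys longer than 20 characters are printed in full, not truncated.
    return "{:20s}:{}".format(key, value)

for key in ["sex", "n_siblings_spouses", "label"]:
    print(format_row(key, [0, 1]))
```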
3. Data preprocessing
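The columns of data imported from a CSV file may each have a different data type, so the data must be preprocessed before it is fed to the model. You can preprocess it with tools such as sklearn and then pass the data to Tens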