首先导入需要的库
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import seaborn as sns
import tensorflow.keras.layers as layers
#import tensorflow.keras.model as Model
import matplotlib.pyplot as plt
下载数据
# Download the Auto MPG data file; Keras caches it locally (here under
# the user's .keras/datasets folder, as the printed path shows).
dataset_path = keras.utils.get_file(
    fname='auto-mpg.data',
    origin='https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data',
)
print(dataset_path)
结果:
C:\Users\john\.keras\datasets\auto-mpg.data
数据读取
# Parse the whitespace-separated file: '?' marks a missing value (read as
# NaN), and text after a tab is treated as a comment and skipped
# (presumably the car-name field — confirm against the data file).
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']
raw_dataset = pd.read_csv(
    dataset_path,
    names=column_names,
    na_values='?',
    comment='\t',
    sep=' ',
    skipinitialspace=True,
)
# Work on a copy so that raw_dataset stays untouched.
dataset = raw_dataset.copy()
dataset.tail()  # display only the last few rows
结果:
MPG Cylinders Displacement Horsepower Weight Acceleration Model Year Origin
102 26.0 4 97.0 46.0 1950.0 21.0 73.0 2.0
103 11.0 8 400.0 150.0 4997.0 14.0 73.0 1.0
104 12.0 8 400.0 167.0 4906.0 12.5 73.0 1.0
105 13.0 8 360.0 170.0 4654.0 13.0 73.0 1.0
106 12.0 8 350.0 180.0 NaN NaN NaN NaN
清洗数据
# Clean the data: show, cell by cell, which of the last rows hold missing values.
print(dataset.tail().isnull())
结果:
MPG Cylinders Displacement Horsepower Weight Acceleration
101 False False False False False False
102 False False False False False False
103 False False False False False False
104 False False False False False False
105 False False False False False False
Model Year USA Europe Japan
101 False False False False
102 False False False False
103 False False False False
104 False False False False
105 False False False False
# Count missing values per column over the WHOLE dataset. Using tail() here
# would inspect only the last few rows and could miss the NaNs that
# read_csv produced from '?' entries (na_values='?'), which is exactly what
# the following dropna() step is meant to remove.
print(dataset.isna().sum())
结果:
MPG 0
Cylinders 0
Displacement 0
Horsepower 0
Weight 0
Acceleration 0
Model Year 0
USA 0
Europe 0
Japan 0
dtype: int64
常用的清洗数据方法
# Drop every row that contains at least one missing value.
dataset = dataset.dropna()
# One-hot encode the categorical 'Origin' code (1 = USA, 2 = Europe,
# 3 = Japan) into three float 0.0/1.0 indicator columns.
origin = dataset.pop('Origin')
dataset = dataset.assign(
    USA=(origin == 1) * 1.0,
    Europe=(origin == 2) * 1.0,
    Japan=(origin == 3) * 1.0,
)
dataset.tail()
# Split into training and test sets: a reproducible 80% sample
# (random_state=0) for training; every remaining row goes to the test set.
train_dataset = dataset.sample(frac=0.8, random_state=0)
test_dataset = dataset.drop(index=train_dataset.index)
# Pair plot of four columns. Each off-diagonal panel is a scatter plot of one
# column against another, so the grid mirrors itself across the diagonal.
# With diag_kind="kde", each diagonal panel shows a smoothed density
# estimate of that single column instead of a self-vs-self scatter.
# Cylinders holds a few integer values, so its panels appear as stripes
# rather than clouds of points.
sns.pairplot(
    train_dataset.loc[:, ["MPG", "Cylinders", "Displacement", "Weight"]],
    diag_kind="kde",
)
# Summary statistics of the training data. MPG is the prediction target, so
# its column is removed. The table is transposed so that each row is a
# feature and 'mean'/'std' are columns (as norm() expects).
train_stats = train_dataset.describe().drop(columns="MPG").T
train_stats
# Standardize the data (z-score) using the training-set statistics.
def norm(x, stats=None):
    """Return ``x`` standardized column-wise as ``(x - mean) / std``.

    Args:
        x: DataFrame (or Series) of feature values.
        stats: Table with ``'mean'`` and ``'std'`` columns indexed by feature
            name (e.g. a transposed ``describe()`` result). Defaults to the
            module-level ``train_stats``, so train and test data are scaled
            with the same training statistics.

    Returns:
        The standardized data. pandas aligns on labels, so a column of ``x``
        that has no entry in ``stats`` comes back as all-NaN.
    """
    if stats is None:
        stats = train_stats
    return (x - stats['mean']) / stats['std']
# Separate the target (MPG) from the features BEFORE normalizing:
# - train_stats has no 'MPG' row, so normalizing with MPG still present would
#   yield an all-NaN MPG column in the model input;
# - build_model() sizes its input layer from train_dataset's columns;
# - model.fit(...) below expects train_labels.
train_labels = train_dataset.pop('MPG')
test_labels = test_dataset.pop('MPG')
normed_train_data = norm(train_dataset)
normed_test_data = norm(test_dataset)
# Run the model on the first 10 normalized training rows as a quick check of
# its output shape (one prediction per row).
# NOTE(review): `model` is only created further down (`model = build_model()`);
# executed top to bottom this raises NameError — run it after that step.
# Presumably this happens before model.fit, i.e. with an untrained model.
example_batch = normed_train_data[:10]
example_result = model.predict(example_batch)
example_result
结果:
array([[ 9.936947], [13.620609], [17.324503], [13.222853], [13.925085], [24.418594], [19.316984], [22.01549 ], [23.534117], [21.643145]], dtype=float32)
# Build the model
def build_model(input_dim=None):
    """Build and compile a small fully-connected regression network.

    Architecture: two hidden ``Dense(64, relu)`` layers followed by a single
    linear output unit.

    Args:
        input_dim: Number of input features. Defaults to the number of
            columns in the module-level ``train_dataset``.

    Returns:
        A compiled ``keras.Sequential`` model that uses RMSprop
        (learning rate 0.001) and mean-squared-error loss.
    """
    if input_dim is None:
        input_dim = len(train_dataset.keys())
    model = keras.Sequential([
        # Declare the input shape with an Input layer; passing `input_shape`
        # to a Dense layer is deprecated in Keras 3.
        keras.Input(shape=(input_dim,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(1),
    ])
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
    # Only `loss` is minimized. `metrics` do not affect training: Keras
    # computes them every epoch and records them in History.history under
    # their names ('mae', 'mse'). With validation data it also records
    # 'val_'-prefixed copies, which is where plot_history() reads
    # 'mae'/'val_mae' and 'mse'/'val_mse'.
    model.compile(loss='mse',
                  optimizer=optimizer,
                  metrics=['mae', 'mse'])
    return model


model = build_model()
model.summary()
# Train the model
class PrintDot(keras.callbacks.Callback):
    """Compact progress indicator for ``model.fit(..., verbose=0)``.

    Callbacks are passed to ``fit`` via ``callbacks=[...]``. Keras then calls
    their hook methods (``on_epoch_end``, ``on_batch_end``, ...) at the
    matching points of training. This one overrides only ``on_epoch_end``:
    it prints one dot per finished epoch and starts a new line before
    epochs 0, 100, 200, ...
    """

    def on_epoch_end(self, epoch, logs=None):
        # `logs` holds this epoch's metric values; it is unused here.
        # Default it to None to match the base Callback signature.
        if epoch % 100 == 0:
            print('')
        print('.', end='')
EPOCHS = 1000
# Hold out 20% of the training rows for validation. verbose=0 silences
# Keras' own progress output; the PrintDot callback prints dots instead.
history = model.fit(
    normed_train_data,
    train_labels,
    epochs=EPOCHS,
    validation_split=0.2,
    verbose=0,
    callbacks=[PrintDot()],
)
# Inspect the training log: one row per epoch with the loss/metric values,
# plus an explicit 'epoch' column.
hist = pd.DataFrame(history.history).assign(epoch=history.epoch)
hist.tail()
绘制训练曲线
def _plot_metric(hist, key, ylabel, ylim):
    """Draw one figure comparing the training and validation curves of ``key``."""
    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.plot(hist['epoch'], hist[key], label='Train Error')
    plt.plot(hist['epoch'], hist['val_' + key], label='Val Error')
    plt.ylim(ylim)
    plt.legend()


def plot_history(history, mae_ylim=(0, 5), mse_ylim=(0, 20)):
    """Plot training vs. validation MAE and MSE from a Keras ``History``.

    Args:
        history: Object returned by ``model.fit``. Its ``history`` dict must
            contain 'mae', 'val_mae', 'mse' and 'val_mse', i.e. the model was
            compiled with those metrics and trained with validation data.
        mae_ylim: (bottom, top) y-axis limits for the MAE figure.
        mse_ylim: (bottom, top) y-axis limits for the MSE figure.
    """
    hist = pd.DataFrame(history.history)
    hist['epoch'] = history.epoch
    _plot_metric(hist, 'mae', 'Mean Abs Error [MPG]', mae_ylim)
    _plot_metric(hist, 'mse', 'Mean Square Error [$MPG^2$]', mse_ylim)
    plt.show()


plot_history(history)
# Early stopping: retrain a fresh model but stop once validation loss stalls.
model = build_model()
# val_loss is the MSE (loss='mse') on the 20% validation split. Training stops
# after 10 consecutive epochs without improvement. Plain ASCII quotes are
# required here; typographic quotes are a SyntaxError.
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,
                    validation_split=0.2, verbose=0,
                    callbacks=[early_stop, PrintDot()])
plot_history(history)