Environment setup
Environment and third-party libraries used for this EDA:
- Python 3.7.2
- Jupyter Notebook (all code was tested here)
- pandas 0.23.4
- numpy 1.15.4 (data processing)
- matplotlib 3.0.2
- seaborn 0.9.0 (data visualization)
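To check a local setup against the versions listed above, a small helper like the following can be used. It is only a sketch: `installed_versions` and `EXPECTED` are illustrative names, not part of the original notebook, and a version mismatch does not necessarily mean the code will fail.

```python
import importlib

# Versions listed in this post; adjust if your setup differs.
EXPECTED = {
    "pandas": "0.23.4",
    "numpy": "1.15.4",
    "matplotlib": "3.0.2",
    "seaborn": "0.9.0",
}


def installed_versions(names):
    """Return {package: version string}, or None for packages that cannot be imported."""
    found = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            found[name] = None
        else:
            # Most scientific packages expose __version__; fall back if they do not.
            found[name] = getattr(module, "__version__", "unknown")
    return found


for name, version in installed_versions(EXPECTED).items():
    note = "" if version == EXPECTED[name] else "  (post used {})".format(EXPECTED[name])
    print("{}: {}{}".format(name, version, note))
```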
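Data preparation
Field descriptions and download links for the training and test sets:
https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data
Main body
Before starting, import the third-party libraries: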
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
pd.set_option('display.max_columns',None) # show every column when printing a DataFrame
plt.style.use('seaborn-whitegrid')
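The training set is 5.4 GB, too large to load into memory in full, so start by loading the first 2 million rows:
(How to import a large CSV dataset is covered here: https://www.kaggle.com/szelee/how-to-import-a-csv-file-of-55-million-rows)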
df=pd.read_csv(
'./dataset/train.csv',
nrows=2000000,
parse_dates=['pickup_datetime'])
df_test=pd.read_csv(r'./dataset/test.csv')
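If more than a fixed number of rows is needed, one common approach is to read the file in chunks and use narrower dtypes. The sketch below is an illustration rather than the loading code used in this post: `load_in_chunks`, `DTYPES` and the tiny `SAMPLE_CSV` (made-up rows standing in for train.csv) are assumptions. Note that `float32` keeps roughly 7 significant digits, which slightly coarsens the coordinates, and `uint8` only works if `passenger_count` has no missing values and stays below 256.

```python
import io

import pandas as pd

# Made-up stand-in for train.csv so the sketch runs anywhere.
SAMPLE_CSV = (
    "key,fare_amount,pickup_datetime,pickup_longitude,pickup_latitude,"
    "dropoff_longitude,dropoff_latitude,passenger_count\n"
    "a,4.5,2009-06-15 17:26:21,-73.844311,40.721319,-73.841610,40.712278,1\n"
    "b,16.9,2010-01-05 16:52:16,-74.016048,40.711303,-73.979268,40.782004,1\n"
    "c,5.7,2011-08-18 00:35:00,-73.982738,40.761270,-73.991242,40.750562,2\n"
)

# Narrower dtypes than the float64/int64 defaults (an assumption, see above).
DTYPES = {
    "fare_amount": "float32",
    "pickup_longitude": "float32",
    "pickup_latitude": "float32",
    "dropoff_longitude": "float32",
    "dropoff_latitude": "float32",
    "passenger_count": "uint8",
}


def load_in_chunks(source, chunksize):
    """Read a CSV piece by piece, skipping the 'key' column, and concatenate the pieces."""
    reader = pd.read_csv(
        source,
        usecols=list(DTYPES) + ["pickup_datetime"],
        dtype=DTYPES,
        parse_dates=["pickup_datetime"],
        chunksize=chunksize,
    )
    return pd.concat([chunk for chunk in reader], ignore_index=True)


sample = load_in_chunks(io.StringIO(SAMPLE_CSV), chunksize=2)
print(sample.dtypes)
```

For the real file, `source` would be the path to train.csv and `chunksize` something much larger, such as a few million rows.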
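Take a quick look at the data:
display(df.head()) # preview the first 5 rows
df.info() # dtypes and memory usage (info() prints directly and returns None)
display(df.describe()) # summary statistics for numeric columns (NaN values are excluded): count, mean, standard deviation, percentiles, min and max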
| | key | fare_amount | pickup_datetime | pickup_longitude | pickup_latitude | dropoff_longitude | dropoff_latitude | passenger_count |
|---|---|---|---|---|---|---|---|---|
| 0 | 2009-06-15 17:26:21.0000001 | 4.5 | 2009-06-15 17:26:21 | -73.844311 | 40.721319 | -73.841610 | 40.712278 | 1 |
| 1 | 2010-01-05 16:52:16.0000002 | 16.9 | 2010-01-05 16:52:16 | -74.016048 | 40.711303 | -73.979268 | 40.782004 | 1 |
| 2 | 2011-08-18 00:35:00.00000049 | 5.7 | 2011-08-18 00:35:00 | -73.982738 | 40.761270 | -73.991242 | 40.750562 | 2 |
| 3 | 2012-04-21 04:30:42.0000001 | 7.7 | 2012-04-21 04:30:42 | -73.987130 | 40.733143 | -73.991567 | 40.758092 | 1 |
| 4 | 2010-03-09 07:51:00.000000135 | 5.3 | 2010-03-09 07:51:00 | -73.968095 | 40.768008 | -73.956655 | 40.783762 | 1 |
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2000000 entries, 0 to 1999999
Data columns (total 8 columns):
key object
fare_amount float64
pickup_datetime datetime64[ns]
pickup_longitude float64
pickup_latitude float64
dropoff_longitude float64
dropoff_latitude float64
passenger_count int64
dtypes: datetime64(1), float64(5), int64(1), object(1)
memory usage: 122.1+ MB
| | fare_amount | pickup_longitude | pickup_latitude | dropoff_longitude | dropoff_latitude | passenger_count |
|---|---|---|---|---|---|---|
| count | 2.000000e+06 | 2.000000e+06 | 2.000000e+06 | 1.999986e+06 | 1.999986e+06 | 2.000000e+06 |
| mean | 1.134779e+01 | -7.252321e+01 | 3.992963e+01 | -7.252395e+01 | 3.992808e+01 | 1.684113e+00 |
| std | 9.852883e+00 | 1.286804e+01 | 7.983352e+00 | 1.277497e+01 | 1.032382e+01 | 1.314982e+00 |
| min | -6.200000e+01 | -3.377681e+03 | -3.458665e+03 | -3.383297e+03 | -3.461541e+03 | 0.000000e+00 |
| 25% | 6.000000e+00 | -7.399208e+01 | 4.073491e+01 | -7.399141e+01 | 4.073400e+01 | 1.000000e+00 |
| 50% | 8.500000e+00 | -7.398181e+01 | 4.075263e+01 | -7.398016e+01 | 4.075312e+01 | 1.000000e+00 |
| 75% | 1.250000e+01 | -7.396713e+01 | 4.076710e+01 | -7.396369e+01 | 4.076809e+01 | 2.000000e+00 |
| max | 1.273310e+03 | 2.856442e+03 | 2.621628e+03 | 3.414307e+03 | 3.345917e+03 | 2.080000e+02 |
Data preprocessing
The minimum fare_amount is below 0, so those rows need to be filtered out:
print("Old Size: {oldsize}".format(oldsize=len(df)))
df_train=df[df['fare_amount']>=0]
print("New Size: {newsize}".format(newsize=len(df_train)))
Old Size: 2000000
New Size: 1999923
Check for missing values:
df_train.isnull().agg(['any','mean','sum','count'])
# any: whether any value is missing; mean: fraction missing; sum: number missing; count: total rows
| | key | fare_amount | pickup_datetime | pickup_longitude | pickup_latitude | dropoff_longitude | dropoff_latitude | passenger_count |
|---|---|---|---|---|---|---|---|---|
| any | False | False | False | False | False | True | True | False |
| mean | 0 | 0 | 0 | 0 | 0 | 7.00027e-06 | 7.00027e-06 | 0 |
| sum | 0 | 0 | 0 | 0 | 0 | 14 | 14 | 0 |
| count | 1999923 | 1999923 | 1999923 | 1999923 | 1999923 | 1999923 | 1999923 | 1999923 |
The data is nearly complete, so the incomplete rows can simply be dropped:
print("Old Size: {oldsize}".format(oldsize=len(df_train)))
df_train=df_train.dropna(how='any',axis=0)
print("New Size: {newsize}".format(newsize=len(df_train)))
Old Size: 1999923
New Size: 1999909
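The missing-value check and the row drop above can be tried on a small frame to see exactly what each statistic means. This is a sketch on made-up toy data; `missing_summary` is an illustrative helper that builds the same four rows (any, mean, sum, count) explicitly instead of calling `agg`.

```python
import numpy as np
import pandas as pd


def missing_summary(frame):
    """Per-column missing-value statistics, one row per statistic."""
    mask = frame.isnull()
    return pd.DataFrame({
        "any": mask.any(),      # is anything missing in this column?
        "mean": mask.mean(),    # fraction of rows that are missing
        "sum": mask.sum(),      # number of rows that are missing
        "count": mask.count(),  # total number of rows
    }).T


# Toy data (made up): one row lacks both dropoff coordinates.
toy = pd.DataFrame({
    "fare_amount": [4.5, 16.9, 5.7, 7.7],
    "dropoff_longitude": [-73.84, np.nan, -73.99, -73.99],
    "dropoff_latitude": [40.71, np.nan, 40.75, 40.76],
})

summary = missing_summary(toy)
cleaned = toy.dropna(how="any", axis=0)  # drop rows containing any NaN

print(summary)
print("Old Size: {}, New Size: {}".format(len(toy), len(cleaned)))
```

Because the same row is missing both coordinates, `dropna` removes one row here, just as the 14 missing dropoff longitudes and latitudes above account for only 14 dropped rows.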
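Exploratory data analysis
Start with the distribution of the training target, fare_amount. Most values fall between 0 and 20, yet the overall range is very wide, which the describe() output above also shows: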
plt.figure(figsize=(16,4))
# distribution of all fare_amount values
plt.subplot(1,2,1)
sns.distplot(df_train['fare_amount'])
# distribution of fare_amount values below 100
plt.subplot(1,2,2)
sns.distplot(df_train[df_train['fare_amount']<100]['fare_amount'])
plt.show()
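The visual impression that fares cluster between 0 and 20 can also be expressed as numbers. The sketch below uses made-up toy fares, and `fare_concentration` is an illustrative helper, not part of the original notebook; on the real data one would pass `df_train['fare_amount']`.

```python
import pandas as pd


def fare_concentration(fares, low=0, high=20):
    """Return the share of fares within [low, high] and a few upper quantiles."""
    fares = pd.Series(fares, dtype="float64")
    share = fares.between(low, high).mean()       # fraction of rows inside the band
    quantiles = fares.quantile([0.5, 0.9, 0.99])  # median and tail quantiles
    return share, quantiles


# Toy fares (made up): mostly small values plus two large ones.
toy_fares = [4.5, 16.9, 5.7, 7.7, 5.3, 52.0, 120.0, 9.0]
share, quantiles = fare_concentration(toy_fares)

print("Share of fares in [0, 20]: {:.2f}".format(share))
print(quantiles)
```

A large gap between the median and the 99% quantile is the numeric counterpart of the long right tail seen in the first plot.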
Longitude and latitude ranges of the pickup and dropoff points in the test set:
# latitude range of pickup and dropoff points
min(df_test.pickup_latitude.min(),df_test.dropoff_latitude.min()),\
max(df_test.pickup_latitude.max(),df_test.dropoff_latitude.max())
# longitude range of pickup and dropoff points
min(df_test.pickup_longitude.min(),df_test.dropoff_longitude.min()),\
max(df_test.pickup_longitude.max(),df_test.dropoff_longitude.max())
(40.568973, 41.709555)
(-74.263242, -72.986532)
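One way to use these ranges is to keep only the training rows whose pickup and dropoff points fall inside them, which would discard coordinates like the extreme values in the describe() output. This is a sketch, not necessarily the filter applied later in this post: `within_test_range` is an illustrative name, the bounds are the test-set ranges printed above, and the toy rows are made up apart from the first, which is copied from the preview table.

```python
import pandas as pd

# Test-set ranges printed above.
LAT_RANGE = (40.568973, 41.709555)
LON_RANGE = (-74.263242, -72.986532)


def within_test_range(frame, lon_range=LON_RANGE, lat_range=LAT_RANGE):
    """Keep rows whose pickup and dropoff coordinates all lie inside the given ranges."""
    mask = (
        frame["pickup_longitude"].between(*lon_range)
        & frame["dropoff_longitude"].between(*lon_range)
        & frame["pickup_latitude"].between(*lat_range)
        & frame["dropoff_latitude"].between(*lat_range)
    )
    return frame[mask]


toy = pd.DataFrame({
    "pickup_longitude": [-73.844311, -3377.681, 0.0],
    "pickup_latitude": [40.721319, 40.75, 0.0],
    "dropoff_longitude": [-73.841610, -73.98, 0.0],
    "dropoff_latitude": [40.712278, 40.76, 0.0],
})

kept = within_test_range(toy)
print("Old Size: {}, New Size: {}".format(len(toy), len(kept)))
```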
A quick online search gives the longitude and latitude range of New York City:
# map
BB = (-74.5, -72.8,