Classifying movie reviews: IMDB, a binary classification example
0. Exploring the data
from keras.datasets import imdb
from keras import models
from keras import layers
from keras import callbacks
from keras import regularizers
from keras import backend as K
from keras.wrappers import scikit_learn
from collections import Counter
from sklearn.model_selection import GridSearchCV
import numpy as np
Using TensorFlow backend.
max_features = 10000
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print('Training data shape:{}, training labels shape:{}'.format(x_train.shape, y_train.shape))
print('Test data shape:{}, test labels shape:{}'.format(x_test.shape, y_test.shape))
Training data shape:(25000,), training labels shape:(25000,)
Test data shape:(25000,), test labels shape:(25000,)
print('{:<6}{:<10}{:<60}{:<10}'.format('No.', 'Length', 'Content(first 10 words)', 'Targets'))
for i, (x, y) in enumerate(zip(x_train[:100], y_train[:100])):
    target = 'Positive' if y == 1 else 'Negative'
    print('{:<6}{:<10}{:<60}{:<10}'.format(i, len(x), str(x[:10]), target))
No. Length Content(first 10 words) Targets
0 195 [1, 103, 319, 14, 22, 13, 8033, 8, 61, 719] Negative
1 211 [1, 2894, 2, 2, 9, 6, 307, 21, 5129, 22] Positive
2 104 [1, 14, 9, 6, 55, 163, 20, 13, 28, 57] Positive
3 120 [1, 4, 1132, 5, 2914, 26, 574, 11, 4, 644] Positive
4 330 [1, 447, 4, 204, 65, 69, 55, 312, 1398, 18] Negative
5 319 [1, 14, 20, 739, 8, 28, 77, 35, 23, 4] Negative
6 253 [1, 146, 35, 2, 2, 5, 2, 5, 16, 1346] Negative
7 80 [1, 737, 20, 261, 13, 104, 12, 69, 4, 986] Positive
8 820 [1, 1065, 3184, 523, 2, 31, 7, 4, 91, 1149] Negative
9 98 [1, 4, 20, 165, 47, 6, 1018, 52, 65, 21] Positive
10 112 [1, 13, 66, 40, 14, 22, 54, 13, 645, 8] Positive
11 51 [1, 51, 517, 46, 17, 35, 221, 65, 946, 2] Negative
12 189 [1, 13, 16, 149, 14, 54, 61, 322, 446, 8] Positive
13 135 [1, 13, 546, 244, 6, 194, 337, 7, 364, 352] Negative
14 227 [1, 4, 2, 381, 5092, 3977, 17, 6, 3865, 46] Positive
15 122 [1, 14, 9, 24, 4, 801, 3774, 3228, 22, 12] Positive
16 189 [1, 13, 774, 1498, 14, 254, 33, 6, 20, 11] Positive
17 161 [1, 14, 20, 9, 35, 2152, 437, 7, 58, 4] Negative
18 418 [1, 785, 167, 1560, 2, 218, 618, 573, 18, 87] Negative
19 126 [1, 48, 25, 70, 264, 12, 160, 604, 7, 2521] Negative
20 102 [1, 14, 20, 4342, 53, 4346, 3680, 7, 5536, 34] Negative
21 700 [1, 13, 244, 179, 6, 337, 7, 7858, 2979, 488] Negative
22 168 [1, 3281, 9, 1016, 2238, 625, 196, 416, 2, 336] Negative
23 113 [1, 13, 6037, 1133, 14, 20, 17, 160, 338, 781] Negative
24 139 [1, 14, 22, 9, 44, 6, 223, 269, 8, 216] Negative
25 126 [1, 17, 13, 16, 149, 14, 20, 13, 16, 536] Negative
26 154 [1, 857, 18, 14, 22, 71, 2, 33, 118, 137] Negative
27 205 [1, 73, 13, 215, 135, 15, 14, 16, 31, 609] Positive
28 180 [1, 14, 402, 8748, 2, 2196, 1477, 6, 1081, 5794] Negative
29 103 [1, 4, 2050, 506, 16, 1812, 9988, 86, 22, 237] Positive
30 256 [1, 6092, 5, 4156, 323, 17, 241, 3010, 18, 4] Positive
31 121 [1, 35, 204, 2, 2492, 7, 14, 480, 22, 16] Positive
32 159 [1, 4, 64, 147, 2468, 11, 4, 20, 9, 4] Negative
33 83 [1, 14, 20, 144, 24, 2, 17, 438, 261, 12] Negative
34 72 [1, 4, 7568, 1530, 803, 50, 8, 135, 44, 12] Negative
35 303 [1, 138, 81, 13, 1343, 81, 14, 8, 546, 13] Negative
36 177 [1, 14, 592, 1648, 500, 9, 4, 118, 1648, 500] Positive
37 131 [1, 61, 223, 5, 13, 7930, 9597, 4, 314, 159] Negative
38 477 [1, 2, 8902, 16, 641, 153, 1404, 7, 27, 2] Positive
39 100 [1, 242, 162, 2, 249, 20, 126, 93, 10, 10] Negative
40 307 [1, 13, 219, 14, 20, 11, 6747, 1242, 187, 982] Positive
41 227 [1, 14, 3720, 852, 2, 1545, 4, 6610, 5, 12] Positive
42 148 [1, 14, 9, 4, 91, 9453, 664, 13, 28, 126] Positive
43 297 [1, 14, 9, 51, 13, 1040, 8, 49, 369, 908] Positive
44 138 [1, 57, 1028, 133, 21, 13, 28, 77, 6, 337] Negative
45 194 [1, 13, 296, 14, 22, 1033, 23, 288, 5, 13] Negative
46 174 [1, 75, 28, 312, 1398, 19, 14, 31, 88, 94] Positive
47 101 [1, 1334, 418, 7, 52, 438, 14, 9, 31, 7] Positive
48 150 [1, 14, 9, 869, 31, 7, 4, 249, 102, 13] Negative
49 263 [1, 5259, 5, 2289, 28, 77, 6, 2, 7, 438] Negative
50 176 [1, 2, 9, 165, 6, 1491, 56, 11, 1303, 7] Negative
51 155 [1, 13, 104, 12, 16, 35, 3879, 3245, 2001, 595] Negative
52 309 [1, 13, 2, 14, 5, 2, 457, 17, 4, 118] Positive
53 499 [1, 397, 8, 157, 23, 14, 22, 54, 12, 16] Positive
54 331 [1, 13, 16, 2229, 8, 842, 15, 294, 69, 93] Negative
55 45 [1, 13, 421, 1309, 83, 4, 182, 7, 4, 7673] Positive
56 220 [1, 160, 1295, 119, 5929, 22, 39, 2912, 24, 6] Positive
57 138 [1, 608, 86, 7, 32, 4, 374, 272, 40, 12] Negative
58 165 [1, 14, 20, 9, 373, 33, 4, 130, 7, 12] Negative
59 313 [1, 75, 32, 28, 110, 49, 2, 7445, 11, 263] Negative
60 472 [1, 2891, 185, 7362, 1102, 39, 1831, 8, 162, 782] Positive
61 121 [1, 14, 9, 35, 589, 34, 199, 2133, 4500, 7418] Negative
62 499 [1, 33, 31, 130, 7, 4, 4277, 3864, 4517, 1075] Positive
63 304 [1, 14, 3679, 2, 2, 11, 14, 20, 5, 11] Positive
64 99 [1, 13, 633, 40, 865, 102, 21, 14, 31, 16] Negative
65 161 [1, 4, 20, 2013, 56, 19, 6, 196, 686, 324] Negative
66 127 [1, 14, 22, 127, 6, 897, 292, 7, 5007, 4] Positive
67 332 [1, 198, 24, 43, 61, 1192, 6843, 23, 14, 22] Negative
68 152 [1, 141, 6, 1917, 20, 55, 483, 4051, 31, 191] Positive
69 112 [1, 321, 993, 708, 256, 76, 2, 5, 7569, 74] Positive
70 95 [1, 51, 9, 14, 9, 12, 6, 212, 6, 189] Negative
71 167 [1, 13, 43, 296, 14, 20, 23, 2, 387, 72] Negative
72 258 [1, 261, 4, 973, 18, 14, 22, 16, 5481, 572] Negative
73 96 [1, 14, 20, 16, 40, 4, 910, 1308, 103, 465] Negative
74 123 [1, 11, 35, 589, 8, 721, 145, 4, 1629, 1179] Negative
75 140 [1, 2, 2, 6592, 9, 4, 2, 1094, 1162, 664] Positive
76 119 [1, 1065, 5121, 470, 4, 2, 9, 6, 93, 18] Negative
77 406 [1, 45, 254, 8, 376, 48, 2, 5, 5343, 26] Negative
78 261 [1, 10, 10, 261, 4, 485, 524, 9, 9546, 307] Negative
79 206 [1, 13, 286, 1097, 252, 51, 8, 535, 39, 6] Negative
80 328 [1, 13, 2, 69, 14, 20, 23, 61, 4200, 18] Negative
81 257 [1, 57, 9697, 466, 1115, 206, 10, 10, 1067, 1219] Negative
82 460 [1, 86, 13, 144, 760, 15, 13, 66, 510, 2] Negative
83 178 [1, 103, 1790, 56, 11, 4, 2542, 1986, 7, 7190] Positive
84 121 [1, 61, 2, 13, 258, 14, 20, 8, 30, 55] Negative
85 129 [1, 13, 377, 319, 14, 31, 54, 13, 16, 1542] Positive
86 191 [1, 2, 46, 7, 1827, 6, 1366, 7, 6751, 1298] Negative
87 586 [1, 13, 92, 104, 15, 111, 108, 262, 1290, 28] Positive
88 227 [1, 103, 4, 1023, 7, 4, 333, 2, 745, 3378] Positive
89 126 [1, 1450, 7824, 830, 758, 2, 9, 6, 20, 15] Positive
90 239 [1, 31, 7, 4, 1126, 1936, 7, 2, 2115, 7] Positive
91 148 [1, 13, 473, 8, 40, 14, 22, 422, 94, 6] Negative
92 159 [1, 14, 20, 16, 373, 13, 69, 6, 55, 878] Negative
93 110 [1, 13, 219, 14, 145, 11, 2, 54, 12, 16] Negative
94 47 [1, 66, 45, 164, 76, 13, 64, 386, 149, 12] Negative
95 181 [1, 9, 12, 6, 52, 326, 8, 361, 412, 1389] Positive
96 377 [1, 4, 20, 778, 6, 19, 2985, 5235, 250, 19] Positive
97 55 [1, 321, 2542, 5, 283, 1146, 7, 2, 7033, 113] Positive
98 54 [1, 13, 92, 124, 138, 21, 54, 13, 244, 1803] Negative
99 81 [1, 316, 299, 68, 173, 184, 73, 11, 14, 117] Positive
train_labels_count = Counter(y_train)
test_labels_count = Counter(y_test)
print('Training labels\npositive:{}, negative:{}'.format(train_labels_count[1], train_labels_count[0]))
print('Test labels\npositive:{}, negative:{}'.format(test_labels_count[1], test_labels_count[0]))
Training labels
positive:12500, negative:12500
Test labels
positive:12500, negative:12500
word_index = imdb.get_word_index()
num_to_word = {value: key for (key, value) in word_index.items()}
print('Max word index:', max([max(sequence) for sequence in x_train]))
print('Min word index:', min([min(sequence) for sequence in x_train]))
def decode_review(num_to_word, review):
    # Indices are offset by 3 because 0, 1 and 2 are reserved for
    # "padding", "start of sequence" and "unknown".
    decoded = ' '.join([num_to_word.get(i - 3, '?') for i in review])
    return decoded
print(decode_review(num_to_word, x_train[0]))
Max word index: 9999
Min word index: 1
? i rated this movie as awful 1 after watching the trailer i thought this movie could be pretty cool guaranteed to offend everyone the trailer said well it did offend me because this movie really sucks it is hardly a comedy as i laughed about two seconds during the entire movie and what's with all the gays in this movie i'm not gay and i don't have a problem with those who are but what's the point of adding so many gay scenes in a so called comedy movie when these scenes are absolutely not funny i guess the director is a gay man in denial or something like that br br so my advice to you is if you want to waste good money go rent a good comedy you've already seen a million times you'll be better off than watching this mother of all lousy ? it really is total crap
for i in range(1, 20):
    print('No.%d \t\t %s' % (i, num_to_word[i]))
No.1 the
No.2 and
No.3 a
No.4 of
No.5 to
No.6 is
No.7 br
No.8 in
No.9 it
No.10 i
No.11 this
No.12 that
No.13 was
No.14 as
No.15 for
No.16 with
No.17 movie
No.18 but
No.19 film
0. Conclusions
- There are 25,000 training samples and 25,000 test samples, and the reviews vary in length, so every sample has to be brought to a fixed length: reviews that are too long are truncated and reviews that are too short are padded (see the sketch below).
- Words must be converted to integers before they can be fed to the network. The smaller the index, the more frequent the word, but some high-frequency words such as "the", "and" and "a" carry little useful information, so they can be filtered out. Indices lie in the range [1, max_features].
- The classes are perfectly balanced: both the training set and the test set contain 12,500 positive and 12,500 negative samples.
- The samples are already shuffled.
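Since the reviews vary in length, one way to bring them to a fixed length is Keras' pad_sequences. A minimal sketch, assuming an arbitrary maxlen of 256 (a value chosen only for illustration and not used elsewhere in this notebook):
from keras.preprocessing.sequence import pad_sequences
maxlen = 256  # hypothetical fixed review length, for illustration only
# Truncate reviews longer than maxlen and zero-pad shorter ones,
# producing arrays of shape (num_samples, maxlen).
x_train_fixed = pad_sequences(x_train, maxlen=maxlen)
x_test_fixed = pad_sequences(x_test, maxlen=maxlen)
print(x_train_fixed.shape, x_test_fixed.shape)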
1. Defining the problem
Input: each sample is a sequence x ∈ R^n, where n is the review length and each x_i is an integer index representing a word; n differs from sample to sample.
Output: y ∈ {0, 1}, where 0 is a negative review and 1 is a positive review.
Problem type: binary classification.
2. Evaluation metric
Since the positive and negative classes are balanced, accuracy is used as the metric for this model.
3. Validation strategy
A hold-out validation set is split off from the training data via Keras' validation_split (30% for the baseline model below, 20% for the tuned model).
4. Preparing the data
- Filter out the most frequent words (the code below skips the top 50 via skip_top=50).
- Fix the sample length: truncate reviews that are too long and pad reviews that are too short.
- Vectorize the data. There are two strategies (both are worth trying):
  1. Convert each x into a vector X ∈ R^M, where M is the vocabulary size and X_i ∈ {0, 1}; X_i = 1 means the word with index i occurs in x (multi-hot encoding).
  2. Add an Embedding layer (a sketch of this approach follows the vectorization code below).
max_features = 10000
skips = 50
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features, skip_top=skips)
def vectorize_sequences(sequences, dimension=10000):
    # Multi-hot encode: results[i, j] = 1 if word index j occurs in sequences[i].
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.
    return results
x_train = vectorize_sequences(x_train, max_features)
x_test = vectorize_sequences(x_test, max_features)
print('x_train shape:{}, x_test shape:{}'.format(x_train.shape, x_test.shape))
x_train shape:(25000, 10000), x_test shape:(25000, 10000)
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)
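The second strategy from section 4 (padded sequences fed into an Embedding layer) is not pursued in the rest of this notebook. A minimal sketch of what it could look like, assuming the raw integer sequences are reloaded (x_train above has already been multi-hot encoded) and an arbitrary maxlen of 256:
from keras.preprocessing.sequence import pad_sequences
maxlen = 256  # hypothetical fixed review length
(seq_train, seq_y_train), (seq_test, seq_y_test) = imdb.load_data(num_words=max_features, skip_top=skips)
seq_train = pad_sequences(seq_train, maxlen=maxlen)
seq_test = pad_sequences(seq_test, maxlen=maxlen)
def build_embedding_model():
    model = models.Sequential()
    # Map each of the max_features word indices to a 16-dimensional dense vector.
    model.add(layers.Embedding(max_features, 16, input_length=maxlen))
    model.add(layers.Flatten())
    model.add(layers.Dense(16, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return model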
5. A simple baseline model
The baseline for this problem is 0.5 accuracy. Build a simple model whose accuracy beats 0.5.
def build_base_model():
    model = models.Sequential()
    model.add(layers.Dense(16, activation='relu', input_shape=(max_features,)))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return model
base_model = build_base_model()
callback_list = [callbacks.EarlyStopping(monitor='val_loss', patience=8),
                 callbacks.ModelCheckpoint('best_base_model.h5', save_best_only=True),
                 callbacks.TensorBoard('./logs', histogram_freq=1)]
base_model.fit(x_train, y_train, epochs=20, batch_size=32, callbacks=callback_list, validation_split=0.3)
Train on 17500 samples, validate on 7500 samples
Epoch 1/20
17500/17500 [==============================] - 9s 530us/step - loss: 0.4716 - acc: 0.7906 - val_loss: 0.3832 - val_acc: 0.8315
Epoch 2/20
17500/17500 [==============================] - 9s 507us/step - loss: 0.2629 - acc: 0.8949 - val_loss: 0.4079 - val_acc: 0.8223
Epoch 3/20
17500/17500 [==============================] - 9s 519us/step - loss: 0.1842 - acc: 0.9311 - val_loss: 0.4648 - val_acc: 0.8185
Epoch 4/20
17500/17500 [==============================] - 9s 521us/step - loss: 0.1307 - acc: 0.9549 - val_loss: 0.5308 - val_acc: 0.8133
Epoch 5/20
17500/17500 [==============================] - 9s 533us/step - loss: 0.0938 - acc: 0.9678 - val_loss: 0.6205 - val_acc: 0.8093
Epoch 6/20
17500/17500 [==============================] - 9s 526us/step - loss: 0.0662 - acc: 0.9803 - val_loss: 0.7154 - val_acc: 0.8067
Epoch 7/20
17500/17500 [==============================] - 9s 514us/step - loss: 0.0492 - acc: 0.9860 - val_loss: 0.8079 - val_acc: 0.8025
Epoch 8/20
17500/17500 [==============================] - 9s 513us/step - loss: 0.0348 - acc: 0.9914 - val_loss: 0.9046 - val_acc: 0.8000
Epoch 9/20
17500/17500 [==============================] - 9s 533us/step - loss: 0.0254 - acc: 0.9939 - val_loss: 0.9714 - val_acc: 0.8028
<keras.callbacks.History at 0x7f5cad8a9ef0>
5.1 Conclusions
Validation accuracy peaks at about 83%, but the model starts overfitting right from the first epoch: the validation loss rises while the training loss keeps falling.
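The overfitting is easiest to see by plotting the curves from the History object that fit() returns. A minimal sketch, assuming matplotlib is available (it is not imported in this notebook); it retrains a fresh copy of the baseline just to capture the history:
import matplotlib.pyplot as plt  # assumed available
history = build_base_model().fit(x_train, y_train, epochs=20, batch_size=32, validation_split=0.3)
hist = history.history
epochs = range(1, len(hist['loss']) + 1)
plt.plot(epochs, hist['loss'], 'bo-', label='Training loss')
plt.plot(epochs, hist['val_loss'], 'ro-', label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()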
6. Scaling up: developing a model that overfits
The baseline above already overfits after the first epoch, so there is no need to scale the model up further; the next step is to regularize and tune it.
7. Tuning hyperparameters
- Try Dropout, Batch Normalization, etc.
- Add L1/L2 regularization (see the sketch below).
- Try different network architectures: add or remove layers.
- Try different hyperparameters, e.g. the number of units per layer, batch_size, etc.
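Note that regularizers is imported at the top but never used, and the create_model function below accepts an L parameter without applying it. A minimal sketch of how an L2 penalty could be attached to a Dense layer (an illustration only, not the model that produced the results below):
def build_l2_model(L=0.001):
    model = models.Sequential()
    # kernel_regularizer adds L * sum(w**2) of this layer's weights to the loss.
    model.add(layers.Dense(16, activation='relu', kernel_initializer='he_normal',
                           kernel_regularizer=regularizers.l2(L),
                           input_shape=(max_features,)))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return model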
def create_model(dropout=0.2, L=0.001):
    # Note: L is intended as an L2 regularization factor, but no regularizer
    # is applied in this version of the model.
    model = models.Sequential()
    model.add(layers.Dense(16,
                           activation='relu', kernel_initializer='he_normal',
                           input_shape=(max_features,)))
    model.add(layers.Dropout(dropout))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(16, activation='relu', kernel_initializer='he_normal'))
    model.add(layers.Dropout(dropout))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(1, activation='sigmoid', kernel_initializer='he_normal'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return model
K.clear_session()
dropout = 0.5
L = 0.001
model = create_model(dropout, L)
callback_list = [callbacks.EarlyStopping(monitor='val_loss', patience=7),
                 callbacks.ModelCheckpoint('best_model.h5', save_best_only=True),
                 callbacks.TensorBoard('./logs', histogram_freq=1)]
model.fit(x_train, y_train, epochs=50, batch_size=64, callbacks=callback_list, validation_split=0.2)
Train on 20000 samples, validate on 5000 samples
Epoch 1/50
20000/20000 [==============================] - 10s 478us/step - loss: 0.5407 - acc: 0.7213 - val_loss: 0.3559 - val_acc: 0.8792
Epoch 2/50
20000/20000 [==============================] - 9s 440us/step - loss: 0.3554 - acc: 0.8560 - val_loss: 0.2814 - val_acc: 0.8870
Epoch 3/50
20000/20000 [==============================] - 8s 408us/step - loss: 0.2865 - acc: 0.8887 - val_loss: 0.2762 - val_acc: 0.8900
Epoch 4/50
20000/20000 [==============================] - 8s 423us/step - loss: 0.2525 - acc: 0.9026 - val_loss: 0.2788 - val_acc: 0.8894
Epoch 5/50
20000/20000 [==============================] - 8s 398us/step - loss: 0.2193 - acc: 0.9168 - val_loss: 0.2879 - val_acc: 0.8910
Epoch 6/50
20000/20000 [==============================] - 8s 424us/step - loss: 0.2068 - acc: 0.9197 - val_loss: 0.2954 - val_acc: 0.8884
Epoch 7/50
20000/20000 [==============================] - 8s 404us/step - loss: 0.1847 - acc: 0.9280 - val_loss: 0.3156 - val_acc: 0.8872
Epoch 8/50
20000/20000 [==============================] - 8s 418us/step - loss: 0.1756 - acc: 0.9320 - val_loss: 0.3034 - val_acc: 0.8876
Epoch 9/50
20000/20000 [==============================] - 9s 444us/step - loss: 0.1626 - acc: 0.9365 - val_loss: 0.3270 - val_acc: 0.8862
Epoch 10/50
20000/20000 [==============================] - 10s 479us/step - loss: 0.1655 - acc: 0.9342 - val_loss: 0.3384 - val_acc: 0.8844
<keras.callbacks.History at 0x7fa095863c18>
model = models.load_model('best_model.h5')
model.evaluate(x_test, y_test, batch_size=64)
25000/25000 [==============================] - 3s 133us/step
[0.29331848189353943, 0.87672000013351437]
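The best checkpointed model reaches roughly 87.7% accuracy on the test set. The scikit_learn wrapper and GridSearchCV imported at the top are never used above; a minimal sketch of how they could drive a search over dropout and L (hypothetical grid values, not run in this notebook):
clf = scikit_learn.KerasClassifier(build_fn=create_model, epochs=10, batch_size=64, verbose=0)
param_grid = {'dropout': [0.2, 0.5], 'L': [0.0001, 0.001]}  # hypothetical grid
grid = GridSearchCV(estimator=clf, param_grid=param_grid, cv=3, scoring='accuracy')
grid_result = grid.fit(x_train, y_train)
print('Best: {} using {}'.format(grid_result.best_score_, grid_result.best_params_))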