Keras-6 IMDB, a binary classification example

Classifying movie reviews: IMDB, a binary classification example

0. Exploring the data

from keras.datasets import imdb
from keras import models
from keras import layers
from keras import callbacks
from keras import regularizers
from keras import backend as K
from keras.wrappers import scikit_learn

from collections import Counter

from sklearn.model_selection import GridSearchCV

import numpy as np
Using TensorFlow backend.
# Load the data

# Keep only the 10,000 most frequently occurring words
max_features = 10000
(x_train, y_train),(x_test, y_test) = imdb.load_data(num_words=max_features)
print('Training data shape:{}, training labels shape:{}'.format(x_train.shape, y_train.shape))
print('Test data shape:{}, test labels shape:{}'.format(x_test.shape, y_test.shape))
Training data shape:(25000,), training labels shape:(25000,)
Test data shape:(25000,), test labels shape:(25000,)
# Show a sample of the data
print('{:<6}{:<10}{:<60}{:<10}'.format('No.', 'Length', 'Content(first 10 words)', 'Targets'))
for i, (x, y) in enumerate(zip(x_train[:100], y_train[:100])):
    target = 'Positive' if y == 1 else 'Negative'
    print('{:<6}{:<10}{:<60}{:<10}'.format(i, len(x), str(x[:10]), target))
No.   Length    Content(first 10 words)                                     Targets   
0     195       [1, 103, 319, 14, 22, 13, 8033, 8, 61, 719]                 Negative  
1     211       [1, 2894, 2, 2, 9, 6, 307, 21, 5129, 22]                    Positive  
2     104       [1, 14, 9, 6, 55, 163, 20, 13, 28, 57]                      Positive  
3     120       [1, 4, 1132, 5, 2914, 26, 574, 11, 4, 644]                  Positive  
4     330       [1, 447, 4, 204, 65, 69, 55, 312, 1398, 18]                 Negative  
5     319       [1, 14, 20, 739, 8, 28, 77, 35, 23, 4]                      Negative  
6     253       [1, 146, 35, 2, 2, 5, 2, 5, 16, 1346]                       Negative  
7     80        [1, 737, 20, 261, 13, 104, 12, 69, 4, 986]                  Positive  
8     820       [1, 1065, 3184, 523, 2, 31, 7, 4, 91, 1149]                 Negative  
9     98        [1, 4, 20, 165, 47, 6, 1018, 52, 65, 21]                    Positive  
10    112       [1, 13, 66, 40, 14, 22, 54, 13, 645, 8]                     Positive  
11    51        [1, 51, 517, 46, 17, 35, 221, 65, 946, 2]                   Negative  
12    189       [1, 13, 16, 149, 14, 54, 61, 322, 446, 8]                   Positive  
13    135       [1, 13, 546, 244, 6, 194, 337, 7, 364, 352]                 Negative  
14    227       [1, 4, 2, 381, 5092, 3977, 17, 6, 3865, 46]                 Positive  
15    122       [1, 14, 9, 24, 4, 801, 3774, 3228, 22, 12]                  Positive  
16    189       [1, 13, 774, 1498, 14, 254, 33, 6, 20, 11]                  Positive  
17    161       [1, 14, 20, 9, 35, 2152, 437, 7, 58, 4]                     Negative  
18    418       [1, 785, 167, 1560, 2, 218, 618, 573, 18, 87]               Negative  
19    126       [1, 48, 25, 70, 264, 12, 160, 604, 7, 2521]                 Negative  
20    102       [1, 14, 20, 4342, 53, 4346, 3680, 7, 5536, 34]              Negative  
21    700       [1, 13, 244, 179, 6, 337, 7, 7858, 2979, 488]               Negative  
22    168       [1, 3281, 9, 1016, 2238, 625, 196, 416, 2, 336]             Negative  
23    113       [1, 13, 6037, 1133, 14, 20, 17, 160, 338, 781]              Negative  
24    139       [1, 14, 22, 9, 44, 6, 223, 269, 8, 216]                     Negative  
25    126       [1, 17, 13, 16, 149, 14, 20, 13, 16, 536]                   Negative  
26    154       [1, 857, 18, 14, 22, 71, 2, 33, 118, 137]                   Negative  
27    205       [1, 73, 13, 215, 135, 15, 14, 16, 31, 609]                  Positive  
28    180       [1, 14, 402, 8748, 2, 2196, 1477, 6, 1081, 5794]            Negative  
29    103       [1, 4, 2050, 506, 16, 1812, 9988, 86, 22, 237]              Positive  
30    256       [1, 6092, 5, 4156, 323, 17, 241, 3010, 18, 4]               Positive  
31    121       [1, 35, 204, 2, 2492, 7, 14, 480, 22, 16]                   Positive  
32    159       [1, 4, 64, 147, 2468, 11, 4, 20, 9, 4]                      Negative  
33    83        [1, 14, 20, 144, 24, 2, 17, 438, 261, 12]                   Negative  
34    72        [1, 4, 7568, 1530, 803, 50, 8, 135, 44, 12]                 Negative  
35    303       [1, 138, 81, 13, 1343, 81, 14, 8, 546, 13]                  Negative  
36    177       [1, 14, 592, 1648, 500, 9, 4, 118, 1648, 500]               Positive  
37    131       [1, 61, 223, 5, 13, 7930, 9597, 4, 314, 159]                Negative  
38    477       [1, 2, 8902, 16, 641, 153, 1404, 7, 27, 2]                  Positive  
39    100       [1, 242, 162, 2, 249, 20, 126, 93, 10, 10]                  Negative  
40    307       [1, 13, 219, 14, 20, 11, 6747, 1242, 187, 982]              Positive  
41    227       [1, 14, 3720, 852, 2, 1545, 4, 6610, 5, 12]                 Positive  
42    148       [1, 14, 9, 4, 91, 9453, 664, 13, 28, 126]                   Positive  
43    297       [1, 14, 9, 51, 13, 1040, 8, 49, 369, 908]                   Positive  
44    138       [1, 57, 1028, 133, 21, 13, 28, 77, 6, 337]                  Negative  
45    194       [1, 13, 296, 14, 22, 1033, 23, 288, 5, 13]                  Negative  
46    174       [1, 75, 28, 312, 1398, 19, 14, 31, 88, 94]                  Positive  
47    101       [1, 1334, 418, 7, 52, 438, 14, 9, 31, 7]                    Positive  
48    150       [1, 14, 9, 869, 31, 7, 4, 249, 102, 13]                     Negative  
49    263       [1, 5259, 5, 2289, 28, 77, 6, 2, 7, 438]                    Negative  
50    176       [1, 2, 9, 165, 6, 1491, 56, 11, 1303, 7]                    Negative  
51    155       [1, 13, 104, 12, 16, 35, 3879, 3245, 2001, 595]             Negative  
52    309       [1, 13, 2, 14, 5, 2, 457, 17, 4, 118]                       Positive  
53    499       [1, 397, 8, 157, 23, 14, 22, 54, 12, 16]                    Positive  
54    331       [1, 13, 16, 2229, 8, 842, 15, 294, 69, 93]                  Negative  
55    45        [1, 13, 421, 1309, 83, 4, 182, 7, 4, 7673]                  Positive  
56    220       [1, 160, 1295, 119, 5929, 22, 39, 2912, 24, 6]              Positive  
57    138       [1, 608, 86, 7, 32, 4, 374, 272, 40, 12]                    Negative  
58    165       [1, 14, 20, 9, 373, 33, 4, 130, 7, 12]                      Negative  
59    313       [1, 75, 32, 28, 110, 49, 2, 7445, 11, 263]                  Negative  
60    472       [1, 2891, 185, 7362, 1102, 39, 1831, 8, 162, 782]           Positive  
61    121       [1, 14, 9, 35, 589, 34, 199, 2133, 4500, 7418]              Negative  
62    499       [1, 33, 31, 130, 7, 4, 4277, 3864, 4517, 1075]              Positive  
63    304       [1, 14, 3679, 2, 2, 11, 14, 20, 5, 11]                      Positive  
64    99        [1, 13, 633, 40, 865, 102, 21, 14, 31, 16]                  Negative  
65    161       [1, 4, 20, 2013, 56, 19, 6, 196, 686, 324]                  Negative  
66    127       [1, 14, 22, 127, 6, 897, 292, 7, 5007, 4]                   Positive  
67    332       [1, 198, 24, 43, 61, 1192, 6843, 23, 14, 22]                Negative  
68    152       [1, 141, 6, 1917, 20, 55, 483, 4051, 31, 191]               Positive  
69    112       [1, 321, 993, 708, 256, 76, 2, 5, 7569, 74]                 Positive  
70    95        [1, 51, 9, 14, 9, 12, 6, 212, 6, 189]                       Negative  
71    167       [1, 13, 43, 296, 14, 20, 23, 2, 387, 72]                    Negative  
72    258       [1, 261, 4, 973, 18, 14, 22, 16, 5481, 572]                 Negative  
73    96        [1, 14, 20, 16, 40, 4, 910, 1308, 103, 465]                 Negative  
74    123       [1, 11, 35, 589, 8, 721, 145, 4, 1629, 1179]                Negative  
75    140       [1, 2, 2, 6592, 9, 4, 2, 1094, 1162, 664]                   Positive  
76    119       [1, 1065, 5121, 470, 4, 2, 9, 6, 93, 18]                    Negative  
77    406       [1, 45, 254, 8, 376, 48, 2, 5, 5343, 26]                    Negative  
78    261       [1, 10, 10, 261, 4, 485, 524, 9, 9546, 307]                 Negative  
79    206       [1, 13, 286, 1097, 252, 51, 8, 535, 39, 6]                  Negative  
80    328       [1, 13, 2, 69, 14, 20, 23, 61, 4200, 18]                    Negative  
81    257       [1, 57, 9697, 466, 1115, 206, 10, 10, 1067, 1219]           Negative  
82    460       [1, 86, 13, 144, 760, 15, 13, 66, 510, 2]                   Negative  
83    178       [1, 103, 1790, 56, 11, 4, 2542, 1986, 7, 7190]              Positive  
84    121       [1, 61, 2, 13, 258, 14, 20, 8, 30, 55]                      Negative  
85    129       [1, 13, 377, 319, 14, 31, 54, 13, 16, 1542]                 Positive  
86    191       [1, 2, 46, 7, 1827, 6, 1366, 7, 6751, 1298]                 Negative  
87    586       [1, 13, 92, 104, 15, 111, 108, 262, 1290, 28]               Positive  
88    227       [1, 103, 4, 1023, 7, 4, 333, 2, 745, 3378]                  Positive  
89    126       [1, 1450, 7824, 830, 758, 2, 9, 6, 20, 15]                  Positive  
90    239       [1, 31, 7, 4, 1126, 1936, 7, 2, 2115, 7]                    Positive  
91    148       [1, 13, 473, 8, 40, 14, 22, 422, 94, 6]                     Negative  
92    159       [1, 14, 20, 16, 373, 13, 69, 6, 55, 878]                    Negative  
93    110       [1, 13, 219, 14, 145, 11, 2, 54, 12, 16]                    Negative  
94    47        [1, 66, 45, 164, 76, 13, 64, 386, 149, 12]                  Negative  
95    181       [1, 9, 12, 6, 52, 326, 8, 361, 412, 1389]                   Positive  
96    377       [1, 4, 20, 778, 6, 19, 2985, 5235, 250, 19]                 Positive  
97    55        [1, 321, 2542, 5, 283, 1146, 7, 2, 7033, 113]               Positive  
98    54        [1, 13, 92, 124, 138, 21, 54, 13, 244, 1803]                Negative  
99    81        [1, 316, 299, 68, 173, 184, 73, 11, 14, 117]                Positive  
# Count the positive/negative labels and their ratio

# Label counts for the training and test sets
train_labels_count = Counter(y_train)
test_labels_count = Counter(y_test)
print('Training labels\npositive:{}, negative:{}'.format(train_labels_count[1], train_labels_count[0]))
print('Test labels\npositive:{}, negative:{}'.format(test_labels_count[1], test_labels_count[0]))
Training labels
positive:12500, negative:12500
Test labels
positive:12500, negative:12500
# Map word indices back to words
word_index = imdb.get_word_index()
num_to_word = { value:key for (key, value) in word_index.items()}
print('Max word index:', max([max(sequence) for sequence in x_train]))
print('Min word index:', min([min(sequence) for sequence in x_train]))

def decode_review(num_to_word, review):
    # Subtract 3 because indices 0, 1 and 2 are reserved for 'padding',
    # 'start of sequence' and 'unknown'; actual word indices start at 3
    decoded = ' '.join([num_to_word.get(i - 3, '?') for i in review])
    return decoded

print(decode_review(num_to_word, x_train[0]))
Max word index: 9999
Min word index: 1
? i rated this movie as awful 1 after watching the trailer i thought this movie could be pretty cool guaranteed to offend everyone the trailer said well it did offend me because this movie really sucks it is hardly a comedy as i laughed about two seconds during the entire movie and what's with all the gays in this movie i'm not gay and i don't have a problem with those who are but what's the point of adding so many gay scenes in a so called comedy movie when these scenes are absolutely not funny i guess the director is a gay man in denial or something like that br br so my advice to you is if you want to waste good money go rent a good comedy you've already seen a million times you'll be better off than watching this mother of all lousy ? it really is total crap
# Print some of the most frequently used words
for i in range(1, 20):
    print('No.%d \t\t %s'%(i, num_to_word[i]))
No.1         the
No.2         and
No.3         a
No.4         of
No.5         to
No.6         is
No.7         br
No.8         in
No.9         it
No.10        i
No.11        this
No.12        that
No.13        was
No.14        as
No.15        for
No.16        with
No.17        movie
No.18        but
No.19        film

0. Conclusions

  • There are 25,000 training reviews and 25,000 test reviews, and their lengths vary. Every review therefore has to be brought to a fixed length: long reviews are truncated and short ones are padded (the length statistics sketched after this list can guide that choice).
  • Words must be converted to integers before they can be fed to the network; the smaller the integer, the more frequent the word. However, some very frequent words such as "the", "and" and "a" carry little useful information, so they can be filtered out. The indices fall in the range [1, max_features].
  • The classes are balanced 1:1: both the training set and the test set contain 12,500 positive and 12,500 negative samples.
  • The samples are already shuffled.
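A minimal sketch (not in the original notebook) of the review-length statistics that would guide the choice of a fixed length:

# Length distribution of the raw training reviews; useful when picking a
# maxlen for truncation/padding (assumes x_train as loaded above)
lengths = np.array([len(seq) for seq in x_train])
print('min:{}, median:{}, mean:{:.1f}, 95th percentile:{}, max:{}'.format(
    lengths.min(), int(np.median(lengths)), lengths.mean(),
    int(np.percentile(lengths, 95)), lengths.max()))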

1. Defining the problem

  • Input: a review x ∈ R^n, where n is the review length; each component x_i is an integer encoding a word, and n varies from review to review
  • Output: y ∈ {0, 1}, where 0 is a negative review and 1 a positive one
  • Problem type: binary classification

2. Choosing a metric

  • The positive and negative classes are balanced, so accuracy is an appropriate metric for the model

3. Validation strategy

  • There is enough data, so a hold-out validation strategy is used (see the sketch below)
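The notebook itself passes validation_split to model.fit, which holds out the last fraction of the training data before shuffling. The manual equivalent, as a minimal sketch:

# Manual hold-out split, equivalent in spirit to fit(validation_split=0.3).
# The first n_val samples are held out here, which is fine because the
# IMDB samples are already shuffled.
n_val = int(0.3 * len(x_train))
x_val, y_val = x_train[:n_val], y_train[:n_val]
partial_x_train, partial_y_train = x_train[n_val:], y_train[n_val:]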

4. Preparing the data

  • Filter out the words that occur too frequently (the code below skips the 50 most frequent words via skip_top)
  • Fix the sample length: truncate samples that are too long and pad samples that are too short
  • Vectorize the data; there are two strategies (both worth trying):
    1. Convert each x into a vector X ∈ {0, 1}^M, where M is the vocabulary size and X_i = 1 means the word with index i appears in x (multi-hot encoding)
    2. Use an Embedding layer (only strategy 1 is run below; a sketch of strategy 2 follows this list)
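A minimal sketch of strategy 2, which the notebook mentions but does not run: pad the integer sequences to a fixed length and learn word vectors with an Embedding layer. maxlen=256 is an assumed cutoff, guided by the length statistics above:

# Strategy 2 (sketch only): pad to a fixed length, then learn an embedding
from keras.preprocessing import sequence

maxlen = 256  # assumed value; covers most reviews per the length statistics
x_train_pad = sequence.pad_sequences(x_train, maxlen=maxlen)

embed_model = models.Sequential()
embed_model.add(layers.Embedding(max_features, 16, input_length=maxlen))
embed_model.add(layers.Flatten())
embed_model.add(layers.Dense(16, activation='relu'))
embed_model.add(layers.Dense(1, activation='sigmoid'))
embed_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])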
# Vocabulary size
max_features = 10000
skips = 50

# Reload the data, setting the number of words to skip and the vocabulary size
(x_train, y_train),(x_test, y_test) = imdb.load_data(num_words=max_features, skip_top=skips)
# Vectorize (strategy 1: multi-hot encoding)
def vectorize_sequences(sequences, dimension=10000):
    # One row per review, one column per word index
    results = np.zeros((len(sequences), dimension))

    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.  # mark every word that appears in this review

    return results

x_train = vectorize_sequences(x_train, max_features)
x_test = vectorize_sequences(x_test, max_features)
print('x_train shape:{}, x_test shape:{}'.format(x_train.shape, x_test.shape))
x_train shape:(25000, 10000), x_test shape:(25000, 10000)
# Vectorize the targets
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)

5. A simple model

  • The baseline for this problem is 0.5 (random guessing on balanced classes). Build a simple model whose accuracy beats 0.5.
def build_base_model():
    model = models.Sequential()

    model.add(layers.Dense(16, activation='relu', input_shape=(max_features, )))
    model.add(layers.Dense(1, activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return model

# Build the model
base_model = build_base_model()
# Callbacks
callback_list = [callbacks.EarlyStopping(monitor='val_loss', patience=8),
                callbacks.ModelCheckpoint('best_base_model.h5', save_best_only=True),
                callbacks.TensorBoard('./logs', histogram_freq=1)]
base_model.fit(x_train, y_train, epochs=20, batch_size=32, callbacks=callback_list, validation_split=0.3)
Train on 17500 samples, validate on 7500 samples
Epoch 1/20
17500/17500 [==============================] - 9s 530us/step - loss: 0.4716 - acc: 0.7906 - val_loss: 0.3832 - val_acc: 0.8315
Epoch 2/20
17500/17500 [==============================] - 9s 507us/step - loss: 0.2629 - acc: 0.8949 - val_loss: 0.4079 - val_acc: 0.8223
Epoch 3/20
17500/17500 [==============================] - 9s 519us/step - loss: 0.1842 - acc: 0.9311 - val_loss: 0.4648 - val_acc: 0.8185
Epoch 4/20
17500/17500 [==============================] - 9s 521us/step - loss: 0.1307 - acc: 0.9549 - val_loss: 0.5308 - val_acc: 0.8133
Epoch 5/20
17500/17500 [==============================] - 9s 533us/step - loss: 0.0938 - acc: 0.9678 - val_loss: 0.6205 - val_acc: 0.8093
Epoch 6/20
17500/17500 [==============================] - 9s 526us/step - loss: 0.0662 - acc: 0.9803 - val_loss: 0.7154 - val_acc: 0.8067
Epoch 7/20
17500/17500 [==============================] - 9s 514us/step - loss: 0.0492 - acc: 0.9860 - val_loss: 0.8079 - val_acc: 0.8025
Epoch 8/20
17500/17500 [==============================] - 9s 513us/step - loss: 0.0348 - acc: 0.9914 - val_loss: 0.9046 - val_acc: 0.8000
Epoch 9/20
17500/17500 [==============================] - 9s 533us/step - loss: 0.0254 - acc: 0.9939 - val_loss: 0.9714 - val_acc: 0.8028

<keras.callbacks.History at 0x7f5cad8a9ef0>

5.1 Conclusions

  • Validation accuracy peaks at about 83%, but the model overfits right from the start: training loss keeps falling while validation loss rises from epoch 2 onward. The loss curves can be visualized as sketched below.
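A minimal sketch for plotting the curves (it assumes the fit call above is changed to history = base_model.fit(...); matplotlib is not imported in the original notebook):

# Visualize training vs. validation loss to see the overfitting
import matplotlib.pyplot as plt

history_dict = history.history
epochs = range(1, len(history_dict['loss']) + 1)

plt.plot(epochs, history_dict['loss'], 'bo-', label='Training loss')
plt.plot(epochs, history_dict['val_loss'], 'r^-', label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()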

6. Scaling up: develop a model that overfits

  • The basic model already overfits, so this step is skipped

7. Tuning the model

  • Try Dropout, Batch Normalization, etc.
  • Add L1/L2 regularization
  • Try different network architectures: add or remove layers
  • Try different hyperparameters, e.g. the number of units, batch_size, and so on (a grid-search sketch follows this list)
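A hedged sketch of a systematic hyperparameter search, wiring up the scikit_learn wrapper and GridSearchCV that are imported at the top of the notebook but never used; the grid values and make_model helper are assumptions, not part of the original:

# Grid search over dropout and batch size (sketch; expensive to actually run)
def make_model(dropout=0.2):
    m = models.Sequential()
    m.add(layers.Dense(16, activation='relu', input_shape=(max_features,)))
    m.add(layers.Dropout(dropout))
    m.add(layers.Dense(1, activation='sigmoid'))
    m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return m

clf = scikit_learn.KerasClassifier(build_fn=make_model, epochs=4, verbose=0)
param_grid = {'dropout': [0.2, 0.5], 'batch_size': [32, 64]}
search = GridSearchCV(clf, param_grid, cv=3)
# result = search.fit(x_train, y_train)  # uncomment to run
# print(result.best_params_, result.best_score_)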
def create_model(dropout=0.2, L=0.001):
    # NOTE: the L argument is intended for an L2 penalty
    # (e.g. kernel_regularizer=regularizers.l2(L)) but is not applied below,
    # so this run trains without weight regularization
    model = models.Sequential()

    model.add(layers.Dense(16,
                           activation='relu', kernel_initializer='he_normal', 
                           input_shape=(max_features, )))
    model.add(layers.Dropout(dropout))

    model.add(layers.BatchNormalization()) 
    model.add(layers.Dense(16,activation='relu', kernel_initializer='he_normal'))
    model.add(layers.Dropout(dropout))

    model.add(layers.BatchNormalization()) 
    #model.add(layers.Dense(16, activation='relu', kernel_initializer='he_normal'))
    #model.add(layers.Dropout(dropout))

    model.add(layers.Dense(1, activation='sigmoid', kernel_initializer='he_normal'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    return model
K.clear_session()
dropout = 0.5
L = 0.001
model = create_model(dropout, L)

# Callbacks
callback_list = [callbacks.EarlyStopping(monitor='val_loss', patience=7),
                callbacks.ModelCheckpoint('best_model.h5', save_best_only=True),
                callbacks.TensorBoard('./logs', histogram_freq=1)]

model.fit(x_train, y_train, epochs=50, batch_size=64, callbacks=callback_list, validation_split=0.2)
Train on 20000 samples, validate on 5000 samples
Epoch 1/50
20000/20000 [==============================] - 10s 478us/step - loss: 0.5407 - acc: 0.7213 - val_loss: 0.3559 - val_acc: 0.8792
Epoch 2/50
20000/20000 [==============================] - 9s 440us/step - loss: 0.3554 - acc: 0.8560 - val_loss: 0.2814 - val_acc: 0.8870
Epoch 3/50
20000/20000 [==============================] - 8s 408us/step - loss: 0.2865 - acc: 0.8887 - val_loss: 0.2762 - val_acc: 0.8900
Epoch 4/50
20000/20000 [==============================] - 8s 423us/step - loss: 0.2525 - acc: 0.9026 - val_loss: 0.2788 - val_acc: 0.8894
Epoch 5/50
20000/20000 [==============================] - 8s 398us/step - loss: 0.2193 - acc: 0.9168 - val_loss: 0.2879 - val_acc: 0.8910
Epoch 6/50
20000/20000 [==============================] - 8s 424us/step - loss: 0.2068 - acc: 0.9197 - val_loss: 0.2954 - val_acc: 0.8884
Epoch 7/50
20000/20000 [==============================] - 8s 404us/step - loss: 0.1847 - acc: 0.9280 - val_loss: 0.3156 - val_acc: 0.8872
Epoch 8/50
20000/20000 [==============================] - 8s 418us/step - loss: 0.1756 - acc: 0.9320 - val_loss: 0.3034 - val_acc: 0.8876
Epoch 9/50
20000/20000 [==============================] - 9s 444us/step - loss: 0.1626 - acc: 0.9365 - val_loss: 0.3270 - val_acc: 0.8862
Epoch 10/50
20000/20000 [==============================] - 10s 479us/step - loss: 0.1655 - acc: 0.9342 - val_loss: 0.3384 - val_acc: 0.8844

<keras.callbacks.History at 0x7fa095863c18>
# Load the best model
model = models.load_model('best_model.h5')
model.evaluate(x_test, y_test, batch_size=64)
25000/25000 [==============================] - 3s 133us/step


[0.29331848189353943, 0.87672000013351437]
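A minimal inference sketch (not in the original notebook): encode a raw review with the same index convention as imdb.load_data, multi-hot it, and predict. The review text and encode_review helper are made-up examples, and the skip_top filtering is ignored for simplicity:

# Score a new review with the loaded model
def encode_review(text, word_index, num_words=10000, index_from=3):
    ids = []
    for w in text.lower().split():
        idx = word_index.get(w, -1) + index_from  # load_data shifts indices by 3
        ids.append(idx if 0 < idx < num_words else 2)  # 2 = out-of-vocabulary
    return ids

review = 'this movie was a total waste of time'
x = vectorize_sequences([encode_review(review, word_index)], max_features)
print('Positive probability:', model.predict(x)[0, 0])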