实现汇率预测LSTM网络的代码
14.1 lstm1
程序:
import tensorflow as tf
import numpy as np
import pandas as pd
import sys
def parse_args(argv, roundT=100, learnRateT=0.001):
    """Extract the training options from command-line arguments.

    Recognised options are ``-round=<int>`` (number of training rounds)
    and ``-learnrate=<float>`` (optimizer learning rate).  Arguments that
    match neither prefix are ignored; when an option is repeated the last
    occurrence wins.

    Args:
        argv: Command-line arguments without the program name.
        roundT: Round count used when ``-round=`` is absent.
        learnRateT: Learning rate used when ``-learnrate=`` is absent.

    Returns:
        A ``(roundT, learnRateT)`` tuple.

    Raises:
        ValueError: If an option value cannot be parsed as a number.
    """
    for v in argv:
        if v.startswith("-round="):
            roundT = int(v[len("-round="):])
        if v.startswith("-learnrate="):
            learnRateT = float(v[len("-learnrate="):])
    return roundT, learnRateT


def make_windows(wholeData, cellCount):
    """Cut a 1-D series into sliding windows and their next-value targets.

    Window ``i`` is ``wholeData[i:i + cellCount]`` and its target is the
    element right after it, ``wholeData[i + cellCount]``.

    Args:
        wholeData: 1-D sequence of values (list or NumPy array).
        cellCount: Length of each input window.

    Returns:
        A ``(rowCount, xData, yTrainData)`` tuple where ``rowCount`` is
        ``len(wholeData) - cellCount``.  When the series is too short
        ``rowCount`` is zero or negative and both lists are empty.
    """
    rowCount = len(wholeData) - cellCount
    xData = [wholeData[i:i + cellCount] for i in range(rowCount)]
    yTrainData = [wholeData[i + cellCount] for i in range(rowCount)]
    return rowCount, xData, yTrainData


def main():
    """Train the LSTM on ``exchangeData.txt`` and predict the next value.

    NOTE: the graph is built with the TensorFlow 1.x API
    (``tf.placeholder``, ``tf.Session``, ``tf.nn.dynamic_rnn``).
    """
    argt = sys.argv[1:]
    print("argt: %s" % argt)
    # Defaults: 100 training rounds, optimizer learning rate 0.001.
    roundT, learnRateT = parse_args(argt)

    fileData = pd.read_csv('exchangeData.txt', dtype=np.float32, header=None)
    # DataFrame.as_matrix() was removed in pandas 1.0; .values is available
    # on both old and new versions.  Reshaping to -1 flattens the 2-D table.
    wholeData = np.reshape(fileData.values, -1)
    print("wholeData: %s" % wholeData)

    cellCount = 3  # number of LSTM steps: each sample is a window of 3 values
    unitCount = 5  # number of units inside the LSTM cell

    # The last window has no known target; it is the input for the
    # final prediction.
    testData = wholeData[-cellCount:]
    print("testData: %s\n" % testData)

    rowCount, xData, yTrainData = make_windows(wholeData, cellCount)
    print("rowCount: %d\n" % rowCount)
    print("xData: %s\n" % xData)
    print("yTrainData: %s\n" % yTrainData)

    x = tf.placeholder(shape=[cellCount], dtype=tf.float32)
    yTrain = tf.placeholder(dtype=tf.float32)

    cellT = tf.nn.rnn_cell.BasicLSTMCell(unitCount)
    initState = cellT.zero_state(1, dtype=tf.float32)
    # dynamic_rnn is fed a [batch=1, steps=cellCount, features=1] tensor.
    h, finalState = tf.nn.dynamic_rnn(
        cellT,
        tf.reshape(x, [1, cellCount, 1]),
        initial_state=initState,
        dtype=tf.float32,
    )
    hr = tf.reshape(h, [cellCount, unitCount])

    # Linear read-out: project every step's output to a scalar and sum them.
    w2 = tf.Variable(tf.random_normal([unitCount, 1]), dtype=tf.float32)
    b2 = tf.Variable(0.0, dtype=tf.float32)
    y = tf.reduce_sum(tf.matmul(hr, w2) + b2)

    loss = tf.abs(y - yTrain)
    optimizer = tf.train.RMSPropOptimizer(learnRateT)
    train = optimizer.minimize(loss)

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(roundT):
            lossSum = 0.0
            for j in range(rowCount):
                result = sess.run(
                    [train, x, yTrain, y, h, finalState, loss],
                    feed_dict={x: xData[j], yTrain: yTrainData[j]},
                )
                lossSum = lossSum + float(result[-1])
                # Report once per round, after the last sample.
                if j == (rowCount - 1):
                    print("i: %d, x: %s, yTrain: %s, y: %s, h: %s, finalState: %s, loss: %s, avgLoss: %10.10f\n" % (
                        i, result[1], result[2], result[3], result[4],
                        result[5], result[6], (lossSum / rowCount)))

        result = sess.run([x, y], feed_dict={x: testData})
        print("x: %s, y: %s\n" % (result[0], result[1]))


if __name__ == "__main__":
    main()
结果:
argt: []
wholeData: [6.5379 6.5428 6.5559 6.5321 6.5062 6.5062 6.5062 6.5062 6.4909 6.5029
6.4933 6.4874 6.4874 6.4874 6.4973 6.5262 6.5054 6.5045 6.4606 6.4606
6.4349 6.4415 6.4329 6.4174 6.3989 6.3989 6.4034 6.4017 6.355 6.3188
6.3198 6.3198]
testData: [6.3188 6.3198 6.3198]
rowCount: 29
xData: [array([6.5379, 6.5428, 6.5559], dtype=float32), array([6.5428, 6.5559, 6.5321], dtype=float32), array([6.5559, 6.5321, 6.5062], dtype=float32), array([6.5321, 6.5062, 6.5062], dtype=float32), array([6.5062, 6.5062, 6.5062], dtype=float32), array([6.5062, 6.5062, 6.5062], dtype=float32), array([6.5062, 6.5062, 6.4909], dtype=float32), array([6.5062, 6.4909, 6.5029], dtype=float32), array([6.4909, 6.5029, 6.4933], dtype=float32), array([6.5029, 6.4933, 6.4874], dtype=float32), array([6.4933, 6.4874, 6.4874], dtype=float32), array([6.4874, 6.4874, 6.4874], dtype=float32), array([6.4874, 6.4874, 6.4973], dtype=float32), array([6.4874, 6.4973, 6.5262], dtype=float32), array([6.4973, 6.5262, 6.5054], dtype=float32), array([6.5262, 6.5054, 6.5045], dtype=float32), array([6.5054, 6.5045, 6.4606], dtype=float32), array([6.5045, 6.4606, 6.4606], dtype=float32), array([6.4606, 6.4606, 6.4349], dtype=float32), array([6.4606, 6.4349, 6.4415], dtype=float32), array([6.4349, 6.4415, 6.4329], dtype=float32), array([6.4415, 6.4329, 6.4174], dtype=float32), array([6.4329, 6.4174, 6.3989], dtype=float32), array([6.4174, 6.3989, 6.3989], dtype=float32), array([6.3989, 6.3989, 6.4034], dtype=float32), array([6.3989, 6.4034, 6.4017], dtype=float32), array([6.4034, 6.4017, 6.355 ], dtype=float32), array([6.4017, 6.355 , 6.3188], dtype=float32), array([6.355 , 6.3188, 6.3198], dtype=float32)]
yTrainData: [6.5321, 6.5062, 6.5062, 6.5062, 6.5062, 6.4909, 6.5029, 6.4933, 6.4874, 6.4874, 6.4874, 6.4973, 6.5262, 6.5054, 6.5045, 6.4606, 6.4606, 6.4349, 6.4415, 6.4329, 6.4174, 6.3989, 6.3989, 6.4034, 6.4017, 6.355, 6.3188, 6.3198, 6.3198]
i: 0, x: [6.355 6.3188 6.3198], yTrain: 6.3198, y: 0.11324437, h: [[[-0.5416661 0.00873393 -0.09274875 0.23681769 -0.13414243]
[-0.60828066 0.01499585 -0.1810691 0.3053319 -0.23316866]
[-0.62211126 0.01780411 -0.25426126 0.3208857 -0.28884727]]], finalState: LSTMStateTuple(c=array([[-0.8638425, 0.039619 , -0.2780444, 1.259472 , -0.8173182]],
dtype=float32), h=array([[-0.62211126, 0.01780411, -0.25426126, 0.3208857 , -0.28884727]],
dtype=float32)), loss: 6.2065554, avgLoss: 6.7232196413
i: 1, x: [6.355 6.3188 6.3198], yTrain: 6.3198, y: 0.90126705, h: [[[-0.5783776 0.00416238 -0.07909829 0.275007 -0.10082621]
[-0.665255 0.00692256 -0.15466554 0.36321616 -0.17512825]
[-0.68763715 0.00817272 -0.2174195 0.38527822 -0.2186326 ]]], finalState: LSTMStateTuple(c=array([[-0.98637336, 0.01926189, -0.23619439, 1.3259457 , -0.6522788 ]],
dtype=float32), h=array([[-0.68763715, 0.00817272, -0.2174195 , 0.38527822, -0.2186326 ]],
dtype=float32)), loss: 5.418533, avgLoss: 5.9221585208
...
i: 99, x: [6.355 6.3188 6.3198], yTrain: 6.3198, y: 6.3430643, h: [[[-7.13031709e-01 -1.07374154e-01 -1.39900476e-01 5.64226091e-01
-6.53162424e-04]
[-9.12595630e-01 -2.98668206e-01 -2.98295170e-01 8.00733328e-01
2.28596898e-03]
[-9.64956701e-01 -4.70945537e-01 -4.48750645e-01 8.70144904e-01
3.40454676e-03]]], finalState: LSTMStateTuple(c=array([[-2.3183544 , -0.6476579 , -0.5008161 , 2.362856 , 0.02529641]],
dtype=float32), h=array([[-0.9649567 , -0.47094554, -0.44875064, 0.8701449 , 0.00340455]],
dtype=float32)), loss: 0.023264408, avgLoss: 0.0433279070
x: [6.3188 6.3198 6.3198], y: 6.315945
用循环神经网络来进行自然语言处理
14.2 lang1
程序:
import gensim
# Toy corpus: four tokenised sentences that share most of their words.
sentences = [
    ['this', 'is', 'a', 'hot', 'pie'],
    ['this', 'is', 'a', 'cool', 'pie'],
    ['this', 'is', 'a', 'red', 'pie'],
    ['this', 'is', 'not', 'a', 'cool', 'pie'],
]

# min_count=1 so that words occurring only once (e.g. 'hot') are kept.
model = gensim.models.Word2Vec(sentences, min_count=1)
wordVectors = model.wv

# Dump the raw embedding of two words, then the embedding length.
for word in ('this', 'is'):
    print(wordVectors[word])
print('vector size: ', len(wordVectors['is']))

# Similarity between 'this' and each of the other vocabulary words.
for other in ('is', 'not', 'a', 'hot', 'cool', 'pie'):
    print(wordVectors.similarity('this', other))

# Words ranked by similarity to 'cool' + 'red' - 'this'.
print(wordVectors.most_similar(positive=['cool', 'red'], negative=['this']))
结果:
[-0.00444952 -0.00301248 -0.00464311 0.00192921 -0.00406584 -0.00273407
-0.00137526 0.00102495 -0.00393232 -0.00480536 -0.00440357 -0.00316634
-0.00176818 -0.00236163 -0.00488326 -0.00206803 -0.00034063 -0.00051817
0.0011918 -0.00175033 -0.0045952 0.00284435 -0.00472876 -0.00148427
0.00497728 -0.0045706 -0.00062199 -0.00061593 0.00035512 0.00220891
-0.00162698 -0.00496666 0.00456833 -0.00263797 0.00453086 0.00246133
-0.00378933 -0.00457572 0.00435866 0.00180793 0.0028313 0.00437243
0.00115309 0.00206793 0.00069872 0.00075151 0.0043659 0.0044
0.00114415 0.00248385 0.00270567 -0.00312238 0.00467412 0.0012327
-0.00358159 -0.00369509 -0.00303954 0.00175266 -0.00300804 -0.00020176
0.00492243 0.00117108 0.0044823 -0.00486554 -0.00078731 0.00020269
-0.00457607 0.00477098 -0.00446973 -0.00394621 -0.00145939 -0.00249522
0.00240969 -0.00070635 0.0045018 0.0002413 0.00128478 0.00285358
-0.00271746 -0.00466185 -0.00405611 -0.00152602 -0.00061854 0.00416803
0.00082446 0.00282294 0.00479311 -0.00361287 -0.00249125 0.00459972
-0.00475971 0.00172046 0.00423085 -0.00262733 0.001689 -0.00054189
0.00455724 -0.00155868 0.00356341 -0.00446598]
[-4.5894543e-03 -1.9179031e-03 -2.8209835e-03 3.2639136e-03
1.8910058e-03 -2.7918709e-03 8.0866978e-04 4.5444611e-03
1.9958478e-03 -9.1378688e-04 2.2512334e-03 -4.5806249e-03
-1.0931346e-03 2.4494289e-03 -3.7682627e-03 1.0529665e-03
1.2763535e-03 4.0838700e-03 -3.6436628e-04 3.1723392e-03
-4.7087111e-03 -3.2658232e-03 -3.7687381e-03 -3.8879539e-03
-4.5129191e-03 -4.5577153e-03 2.8537826e-03 -9.5443864e-04
-3.5305852e-03 6.6543603e-04 -4.0941290e-03 2.6993607e-03
2.9571527e-03 4.2969892e-03 4.6353419e-03 4.1818171e-04
4.4170809e-03 3.9940095e-03 -1.0161035e-03 4.9301172e-03
7.5437099e-04 4.9338117e-03 -3.3024719e-03 -3.4118560e-03
2.7776335e-03 -2.1603450e-03 1.5431269e-03 -4.3047168e-03
4.0522730e-03 -2.1072607e-03 1.7592679e-04 -4.3802727e-03
3.9813002e-03 1.8843423e-03 -1.0693111e-03 2.9390331e-03
-4.8243362e-03 -9.8357757e-04 -1.1988253e-03 3.6845913e-03
-3.1838727e-03 -2.2128222e-03 1.9847061e-03 -4.5792544e-03
-2.7159147e-03 1.8701727e-03 1.3915236e-03 -3.5280767e-03
-1.8214776e-04 1.2135048e-03 4.4915513e-03 -4.2495159e-03
2.6700945e-04 1.6584441e-04 -4.3170187e-03 -3.6475111e-03
4.4147810e-03 3.4587677e-03 -8.6171669e-04 -3.3640568e-03
-1.9558756e-04 3.5415075e-03 -3.6160197e-04 -5.1331805e-04
-8.7224122e-04 4.4182967e-03 -9.2446664e-04 6.6160056e-04
2.1130587e-03 -7.0152535e-05 2.6635902e-03 -4.4978885e-03
4.0106787e-03 -2.1194320e-03 -4.6026031e-03 2.1868090e-03
7.0464576e-04 1.6734154e-03 -2.9403993e-03 -2.9848020e-03]
vector size: 100
0.1163308
0.146698
-0.009531897
-0.08159499
0.03128836
0.1995438
[('pie', 0.006718521937727928), ('hot', -0.007558740675449371), ('a', -0.09354205429553986), ('is', -0.10120982676744461), ('not', -0.12055482715368271)]