代码地址:FMexample代码示例
1. 不能下载数据
mnist = fetch_mldata('MNIST original', data_home='./tmp')
fatch_mldata 是从mldata.org fetch数据,由于虚拟机终端不能联网的原因,不能从url获取数据。
因此,先下载MNIST_original.mat数据到本地。然后加载本地.mat格式的数据:
mnist = sio.loadmat(u'/home/song/MNIST_data/mnist-original.mat')
2. 数据维度错误
X_all = scale(mnist['data'][mask].astype(float))
这句报错,是因为数据维度对不上,更改为:
X_all = scale(np.transpose(mnist['data'])[mask[0]].astype(float))
主要到:mnist[‘data’],mnist[‘label’]都是numpy.array数据类型。
备注:上面报错内容,
3. TFFMClassifier输入数据类型错误
from tffm import TFFMClassifier
model = TFFMClassifier(
order=6,
rank=10,
optimizer=tf.train.AdamOptimizer(learning_rate=0.01),
n_epochs=100,
batch_size=-1,
init_std=0.001,
input_type='dense'
)
model.fit(X_tr, y_tr, show_progress=True)
在样例中,X_tr和y_tr是numpy.array格式,如果输入数据框格式,会有以上报错。