下面的代码演示如何对一个二维整数矩阵做 one-hot 编码（keras 的 to_categorical），再用 numpy 的 argmax 解码还原：
# -*- coding: utf-8 -*-
import numpy as np
from keras.utils import to_categorical
def _OneHot_encode(data=None, num_classes=None):
    """One-hot encode an integer label array with Keras ``to_categorical``.

    Prints the input, its shape, the encoded shape and the encoded array,
    then returns the encoded array.

    Args:
        data: Integer array-like of class labels. When omitted, the 4x3 demo
            matrix used by this example is encoded, so a zero-argument call
            behaves exactly as before.
        num_classes: Total number of classes, forwarded to
            ``to_categorical``. When None, ``to_categorical`` infers it from
            the data (the demo output shows 13 columns for a maximum label
            of 12).

    Returns:
        The array returned by ``to_categorical``. For the 2-D demo input it
        has shape (rows, cols, num_classes).
    """
    if data is None:
        data = np.array([[0, 1, 2],
                         [3, 4, 5],
                         [7, 8, 9],
                         [10, 11, 12]])
    # Normalise list input so the ``.shape`` print below always works.
    data = np.asarray(data)
    print(data)
    print(data.shape)
    encoded_data = to_categorical(data, num_classes=num_classes)
    print(encoded_data.shape)
    print(encoded_data)
    return encoded_data
def _OneHot_decode():
encoded_data = _OneHot_encode()
decoded_data = []
for i in range(encoded_data.shape[0]):
decoded = np.argmax(encoded_data[i], axis=1)
decoded_data.append(decoded)
decoded_data = np.array(decoded_data)
print(decoded_data.shape)
return decoded_data
if __name__ == '__main__':
    # Round-trip demo: encode the sample matrix, decode it again and print the
    # recovered labels, which should match the original matrix.
    labels = _OneHot_decode()
    print(labels)
编码后的结果如下（形状为 (4, 3, 13)）。类别数按最大值 12 + 1 推断为 13，所以虽然数据中没有 6，下标为 6 的那一列仍然保留，且始终为 0：
[[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]]
再沿最后一个轴取 argmax 解码，就还原成了原矩阵：
[[ 0 1 2]
[ 3 4 5]
[ 7 8 9]
[10 11 12]]