目的:
将one-hot的一组数据转为list数组,统计各个类别的数量
首先:
- 假设我们总共有5类样本,标签分别从0-4,如:label=[0,1,2,3,4]
- 标签转化为‘one-hot’的形式:0: [1, 0,0,0,0]
- 1: [0,1,0,0,0]
- 2: [0,0,1,0,0]
- 3: [0,0,0,1,0]
- 4: [0,0,0,0,1]
那么这样的一个one-hot的形式转为一维标签,就没有函数了。我们可以这样处理 :
- 遍历每一‘one-hot’,找出1的那个位置即可,numpy有函数argmax()实现了这个功能:
np.argmax(array):参数会返回数组指定维度上最大值的索引。这里每一个item是一维的,所以只会返回一维数组中最大值的索引,也就是标签的位置
- 最后使用列表表达式,将其转为列表: