对任意shape的label用numpy进行one_hot编码

def get_one_hot(labels, nb_classes):
    res = np.eye(nb_classes)[np.array(labels).reshape(-1)]
    return res.reshape(list(labels.shape)+[nb_classes])

解释:

  1. np.array(labels).reshape(-1)是将labels展平, 比如将[[2,1],[3,2],[0,0]](shape为[3,2])展平为[2,1,3,2,0,0](shape为[6, ])
  2. np.eye(nb_classes)生成对角矩阵, 左上-右下对角线上的值为1, 其余为0
  3. res = np.eye(nb_classes)[np.array(labels).reshape(-1)]根据展平后的结果, 取np.eyes()中的对应行, 得到新的矩阵
  4. 最后把res按照labels的shape复原, 增加一个nb_class维度
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值