一文搞懂F.cross_entropy中的weight参数

本文介绍了在样本不平衡的情况下,如何利用Pytorch中的F.cross_entropy损失函数的weight参数来调整不同类别的权重,以解决类别不平衡问题。通过实例展示了加权后损失值增大的现象,并解析了weight参数如何影响损失计算,从而验证了加权的正确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

交叉熵是在分类任务中常用的损失函数,对于样本均衡的分类任务我们可以直接使用。但当我们面对样本类别失衡的情况时,导致训练过程中的损失被数据量最多的类别的主导,从而导致模型不能被有效的训练。我们需要通过为不同的样本损失赋予不同的权重以平衡不同类别间数据量的差异。这时了解一下F.cross_entrpy中的weight参数的底层是如何实现的,是非常有必要的!
关于Pytorch中F.cross_entropy详细的实现过程请看https://blog.youkuaiyun.com/code_plus/article/details/115420223这篇博客。
下面分析weight是如何作用的。
1、数据准备

input = torch.tensor([[[[0.5546, 0.1304, 0.9288],
                        [0.6879, 0.3553, 0.9984],
                        [0.1474, 0.6745,0.8948]],
		               [[0.8524, 0.2278, 0.6476],
                        [0.6203, 0.6977, 0.3352],
                        [0.4946, 0.4613, 0.6882]]]])
target = torch.tensor([[[0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 1]]])
weight = torch.tensor([1.0, 9.0])

2、Pytorch中的实现结果
首先我们看一下普通的交叉熵和加权的交叉熵的结果,为了方便结果直接注释在了代码中。

loss = F.cross_entropy(input, target)
print(loss)  # tensor(0.7099)
weight = torch.tensor([1.0, 9.0])
loss = F.cross_entropy(input, target, weight)
print(loss)  # tensor(0.7531)

从交叉熵的结果熵我们发现,加权后的交叉熵在数值上变大了,然后可能就会有人问,那结果是不是一定会变大呢?(当然,我们理想的结果就是让损失值变大)
这个要分析数据的类别失衡问题,我本人目前在做变化检测,对于变化检测来说,变化的像素是极少数的,而未变化的像素是多数的。以上面input的数据为例,假设未变化的类别标签是0,变化的类别标签是1,可以发现0占了大多数,变化和未变化的样本数是失衡的,然后我们给的权重是[1.0,9.0]也就是说我们加大了变化样本(类别标签是1)的权重,然后最终得到的损失是放大的。
总结:当我们增加少数样本的权重时,计算出的损失值应该是放大的。

3、weight参数是如何作用的?
交叉熵的实现过程请看上篇文章,下面加权的交叉熵代码和未加权的交叉熵在实现上只是torch.log前面多加了一个weight[target[b][i][j]]参数,就这么简单。

input = F.softmax(input, dim=1)
loss = 0.0
for b in range(target.shape[0]):
    for i in range(target.shape[1]):
        for j in range(target.shape[2]):
            # loss -= torch.log(input[b][target[b][i][j]][i][j])
            loss -= weight[target[b][i][j]] * torch.log(input[b][target[b][i][j]][i][j])
print(loss/(1*8+9*1))

当我们在计算损失前,当遇到真实标签是0的样本时,我们就乘上我们为它附上的权重,也就是1;当我们遇到真实标签是1的样本时,我们就乘上给该类样本附的权重,也就是9。在对所用样本计算完损失后,再求平均,也就是loss/(1*8+9*1):这个含义是input数据中标签为0样本的个数是8,权重是1,即1个类别为0的样本我们仍然将其视为1个;标签为1的样本的个数是1,权重是9,即1个类别为1的样本我们将其视为9个;这样总的样本数就是1*8+9*1=17个,在对loss求平均,也就是最终的损失。该值也是0.7531,确认我们的分析是没有问题的。

4、总结
weight的作用就是扩大不同类别样本的个数(同时,总的样本个数也跟随着扩大了)。

注:如有错误还请指出!

### 反向传播算法概念 反向传播算法是一种用于训练人工神经网络的监督学习方法。此算法的核心在于利用链式法则计算损失函数相对于各权重的梯度,从而调整这些权重以最小化预测误差[^1]。 ### 原理阐述 #### 正向传递阶段 在网络接收输入数据并进行正向传播的过程中,每层节点依据前一层传来的加权和加上偏置项后的激活值作为自身的输出。这一过程持续至最终输出层产生模型对于给定样本集的预测结果[^4]。 ```python def forward_pass(input_data, weights, biases): activations = input_data for w_layer, b_layer in zip(weights[:-1], biases[:-1]): z = np.dot(w_layer.T, activations) + b_layer activations = sigmoid(z) # For the output layer using softmax activation function. final_z = np.dot(weights[-1] predictions = softmax(final_z) return predictions ``` #### 计算误差 当获得预测值之后,会将其与实际标签对比得出差异程度——即所谓的“误差”。通常采用均方差(MSE)或交叉熵(Cross Entropy Loss)等形式表达这种差距大小。 #### 反向传播阶段 一旦确定了整体系统的总误差,则可以开始执行真正的核心部分—反向传播: - **从末端到前端逐步回溯**:按照从输出层往回走的方式依次处理各个隐藏单元; - **应用链式法则**:借助于微分学里的链式法则来解析当前层级上的局部变化如何影响全局性能指标; - **累积贡献量**:记录下每个连接边所携带的信息流强度及其方向性特征,以便后续更新操作时能够精准定位责任归属[^3]。 ```python def backward_pass(predictions, actual_labels, weights, learning_rate=0.01): error = predictions - actual_labels # Calculate delta at output layer. deltas = [] for i in reversed(range(len(weights))): if i == len(weights)-1: delta = error * derivative_softmax(predictions) else: delta = (deltas[-1].dot(weights[i+1])) * derivative_sigmoid(layer_activations[i]) deltas.append(delta) deltas.reverse() gradients = [] for d, a in zip(deltas, [input_data]+layer_activations[:-1]): gradient = d[:,np.newaxis]*a[np.newaxis,:] gradients.append(gradient) updated_weights = [ weight - learning_rate*grad for grad, weight in zip(gradients, weights)] return updated_weights ``` #### 参数更新 最后一步便是根据之前收集好的信息完成对原始参数矩阵Wij以及biases的具体修正工作。一般而言,这涉及到简单的减去按比例缩放过的梯度向量。 ### 实现总结 综上所述,在构建一个多层感知机架构下的深度学习框架时,除了要精心设计好每一级之间的映射关系外,更重要的是掌握一套有效的优化策略使得整个体系能够在有限次数内收敛到较优解附近。而BP正是这样一种经典而又实用的技术手段之一[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值