loss优化
- 这篇是我在看完bert预训练之后,bert的两个任务的loss直接加起来,之后去学习怎么调整loss
- 这里参考了天池比赛给出的loss调整方法:
第一篇
- 这篇文章是cv里面的,但是它讲了两种loss的组合方式,第一种是连续性的output(如输出某个物体的距离,文章中的depth regression),第二种是分类的output(如语义分割,每个像素点是不是边界)
- 假设我们的模型输出值是$\mathbf{f}^{\mathbf{W}}(\mathbf{x})$,这里的$\mathbf{x}$是模型的输入值,如
[batch_size, seq_len, hidden_size]
这种,$\mathbf{W}$是模型的参数 - 下面的$\mathbf{y_i} \ i=1,\dots,K$是每个任务,比如$\mathbf{y_1}$是回归任务,$\mathbf{y_2}$是分类任务。
- 这里有一个很强的假设,就是在给定输出值之后,每个任务的预测相互独立
- 对于回归问题,怎么计算这个$p\left(\mathbf{y}_{1} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right)$呢
- 对于分类问题,怎么计算这个$p\left(\mathbf{y}_{2} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right)$呢
- 换句话说,对于回归,假设了输出值的条件概率是正态分布;对于分类,输出值通过一个scaled的softmax
例子
- 具体来讲,假设我们的模型有两个任务,那么就有两个loss,依然假设$\mathbf{y_1}$是回归任务,$\mathbf{y_2}$是分类任务
- 那么整体的loss方程$\mathcal{L}\left(\mathbf{W}, \sigma_{1}, \sigma_{2}\right)$就是,
其中$\mathcal{L}_{1}(\mathbf{W})=\left|\mathbf{y}_{1}-\mathbf{f}^{\mathbf{W}}(\mathbf{x})\right|^{2}$,$\mathcal{L}_{2}(\mathbf{W})=-\log \operatorname{Softmax}\left(\mathbf{y}_{2}, \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right)$
注意上面的$\mathcal{L}_{2}(\mathbf{W})$应该是交叉熵函数,而且$ \mathbf{f}^{\mathbf{W}}(\mathbf{x})$没有scaled
- 那么这样就需要更新$\mathbf{W}$和两个$\sigma$参数了
代码
- 下面用torch写一份loss的代码,github上的代码其实是有问题的,它没有cross_entropy的内容
1 | log_var_a = torch.zeros((1,), requires_grad=True) |
- 这里贴一份
torch.nn.CrossEntropyLoss()
的源代码,疯狂嵌套!!!
1 | class CrossEntropyLoss(_WeightedLoss): |
第二篇
- 这篇文章的思想是,困难的任务优先处理,所以重点在于,怎么定义是困难的任务?
- 这里文章定义了总体的loss,总共有$|T|$个任务,其中$\mathrm{FL}\left(\bar{\kappa}_{t} ; \gamma_{t}\right)$就可以看作是各个任务的权重,就是一个Focal loss
- $\mathcal{L}_{t}^{*}(\cdot)$就是某个任务的常规loss,比如cross entropy
Focal loss是什么
- Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。
- 正常的二分类cross entropy是:
- $loss =-y\log(p)-(1-y)\log(1-p)$,注意真正的loss只会有一个项,要么$y\log(p)$,要么$(1-y)\log(1-p)$
- cross entropy的问题在哪里呢?我们当然会希望,模型预测的$p$值能区分出来,比如$y=1$的样本,$p$值特别高,但是这个很难做到。如果$p$值在0.4~0.6之间,我们怎么判断这些样本是正样本或负样本?
- 苏剑林的博客提到,模型不要注意那些正样本且$p>0.5$、负样本且$p<0.5$的这些,即,已经预测得不错的样本,不要再关注了
- 那么怎么改正这个loss呢?注意下面的$\hat{y}=p$。类别不均衡本质上就是分类难度差异的体现
- 上面就是Focal loss
Focal loss相比正常的cross entropy,$y=1$的时候,多了一个$(1-\hat{y})^{\gamma}=(1-p)^{\gamma}$,当p值越小的时候,前面的项越大,相当于提高了这个样本的loss权重
这里的$\overline{\kappa_{t}}$定义如下,$\alpha$是一个超参数,$\kappa_{t}^{(\tau)}$定义为第$\tau$次的训练、第$t$个任务的某种性能参数(如正确率)
代码
1 | # 这里按照上面的理解,写一个loss的伪代码,其实可以封装成一个class |