交叉熵函数的推导

2025年1月7日 | 阅读 4 分钟

引言

交叉熵(Cross Entropy)是数据科学和人工智能领域中的一个核心概念,它被用作衡量两个概率分布之间差异的关键指标。在分类任务中,它尤其重要,用于衡量模型预测的概率分布与真实概率分布之间的差异。交叉熵的核心思想源于信息论,它衡量使用一个基于另一种概率分布的编码,来表示来自一种概率分布的事件所需的平均比特数。

在机器学习中,尤其是在神经网络中,交叉熵被用作损失函数,指导模型在训练过程中的优化。该函数惩罚模型预测概率与真实标签之间的偏差,从而量化模型的性能。其数学公式涉及对正确类别的预测概率取负对数,强调正确预测,并对错误预测进行严厉惩罚。

推导交叉熵的动机

  • 评估模型的性能

在分类任务中,评估模型预测概率与实际类别标签分布的匹配程度至关重要。交叉熵提供了一种量化方法来衡量真实分布与预测概率之间的差异,从而可以评估模型性能。

  • 知识增益

交叉熵以信息论为基础,衡量用一个概率分布逼近另一个概率分布所获得的信息量。通过使用基于另一种分布的编码,它计算了编码一个分布中的事件所需的平均比特数。理解这个概念对于弄清楚分类算法在多大程度上捕捉了数据中的基本趋势至关重要。

  • 优化的目标

机器学习模型通过优化一个目标函数(也称为损失函数)来训练。由于交叉熵是凸函数且可微,因此在分类问题中常被用作损失函数。这使得它适用于梯度下降等基于梯度的优化方法。

  • 强调准确的预测

当预测结果非常错误,并且分配给正确类别的预测概率远离 1 时,交叉熵会严重惩罚此类预测。这种对准确预测的关注与分类任务的目标一致,即精确识别类别标签至关重要。

交叉熵关于 Logits 的导数

步骤 1:定义交叉熵损失

考虑单个样本的二元交叉熵损失

其中,y 是真实标签(0 或 1),^y 是预测为类别 1 的概率。

步骤 2:使用 Sigmoid 函数

假设预测概率 ^y 是通过将 Sigmoid 函数应用于 Logits z 获得的

步骤 3:用 Logits 表示交叉熵

将 Sigmoid 函数代入交叉熵损失

步骤 4:应用链式法则

应用链式法则计算交叉熵关于 Logits z 的导数

步骤 5:计算偏导数

步骤 6:合并和简化

步骤 7:进一步简化

合并分数并简化表达式

交叉熵关于概率的导数

为了推导交叉熵损失关于概率的梯度,我们考虑一个二元分类场景。我们将真实标签表示为 y(0 或 1),预测概率表示为 ^y。交叉熵损失由下式给出

现在,我们将计算该损失关于 ^y 的导数。

步骤 1:表示交叉熵损失

步骤 2:应用链式法则

应用链式法则计算交叉熵关于 ^y 的导数

步骤 3:计算偏导数

步骤 4:合并偏导数

步骤 5:公分母

为了合并各项,找到一个公分母

步骤 6:合并和简化

简化表达式

实际应用

  • 神经网络的反向传播训练

在许多问题中,交叉熵通常被用作损失函数,尤其是在神经网络中。在训练过程中,利用链式法则来计算交叉熵损失关于模型参数(权重和偏置)的导数。通过以最小化交叉熵损失的方式调整模型参数,梯度下降等优化过程就会受到这些导数的指导,从而提高模型的预测精度。

  • 多类和二元分类

对于二元分类,计算交叉熵关于 Logits 的导数有助于更新模型的参数,并提高其区分两个类别之间的能力。对于多类分类,通过将导数扩展到处理多个类别,模型可以被指导着准确地将样本分类到正确的类别中。

  • 神经网络中的 Softmax 激活

在神经网络的输出层,交叉熵和 Softmax 激活函数常常被用于多类分类。交叉熵关于 Logits 的导数对于在整个网络中反向传播误差并调整训练过程中的偏置和权重至关重要。

  • NLP,即自然语言处理

在神经网络的输出层,交叉熵和 Softmax 激活函数常常被用于多类分类。交叉熵关于 Logits 的导数对于在整个网络中反向传播误差并调整训练过程中的偏置和权重至关重要。

  • 通过强化学习

在涉及强化学习的某些情况,尤其是策略优化中,交叉熵被用作损失函数。交叉熵的导数有助于更新策略参数,以改进环境的导航。

  • 异常检测

当出现与预期概率分布的偏差,显示出异常时,可以使用交叉熵来检测异常。利用导数来指导参数调整,以提高模型检测数据中异常情况的能力。