Tensorflow 和 Keras 中的 Dropout 实现

2025年6月18日 | 阅读 17 分钟

深度神经网络的模型参数非常多。通常需要数万甚至数百万个参数。这些因素极大地增强了它们学习各种复杂数据集的能力。但这并非总是好事。过拟合,即测试集性能差而训练集性能高(偏差低,方差高)的情况,常常是这种能力的结果。由于模型过度依赖训练数据,测试误差率可能会更高。为了避免这种情况,我们尝试通过各种正则化策略来减弱模型的学习能力。Dropout 就是这样一种正则化技术。当应用于未知数据时,正则化可以确保模型按预期运行。

Dropout Implementation in Tensorflow and Keras

上图说明了过拟合模型(绿色边际)与正则化模型(黑色边际)之间的区别。在未知样本(测试集)上,绿色边际不太可能表现得很好,尽管它似乎更适合训练数据。上图是过拟合的一个很好的例子。

深度神经网络中的 Dropout

Dropout 是在神经网络中故意消除的数据或噪声,以加快处理和周转时间。

2012 年,Geoffrey Hinton 在其论文《通过防止特征检测器的协同适应来改进神经网络》中提出了一种非常流行的正则化技术。这个概念非常简单但却非常强大。

在每个训练步骤中,每个神经元都有一个概率“p”,可以暂时停止训练(也称为“掉出”)。在这种情况下,“p”代表可调超参数,称为 dropout 率。

Dropout Implementation in Tensorflow and Keras

上图显示了 dropout 实现对网络连接的影响。(左)密集连接的标准前馈网络。(右)Dropout 显著减少了连接数量。来源:“Dropout:一种防止神经网络过拟合的简单方法”。

例如,如果 p=0.5,则一个神经元在每个 epoch 中有 50% 的概率掉出。如果一个神经元跳过训练阶段,下游层将受到影响,因为它的所有连接都会断开。结果,神经网络的连接密度将大大降低(图 2)。输出层不受 dropout 的影响,而输入层和隐藏层受影响。这是因为为了便于训练,模型必须不断为损失函数产生输出。只有在训练阶段使用 dropout 程序。在推理阶段,网络中的每个神经元都完全参与。

随机关闭神经元可能非常令人震惊。人们可能会认为这会导致训练过程中的极端不稳定。然而,在实践中,它被证明在简化模型方面非常有效。

如果让员工每天早上抛硬币来决定是否上班,公司的绩效会提高吗?你永远不知道,也许会!不用说,公司必须改变组织方式;它不能依赖某个人来补充咖啡机或执行任何其他关键任务。因此,这些知识需要分布在不同的人员之间。员工需要培养与众多同事合作的能力,而不仅仅是少数几个人。公司将变得更加健壮。如果一个人离开,差别不会太大。尽管这种概念是否真的适用于企业还有待商榷,但神经网络肯定可以。

深度神经网络的网络架构在训练过程中也会因 epoch 而异。此外,每个神经元都必须关注其所有输入,而不是过度依赖少量输入连接。因此,它们对输入连接的更改具有更强的弹性。

这样,它就保证了一个更强大的网络,具有更好的泛化能力。

dropout 率 p 是 dropout 中可控的超参数。调整它相当简单。

  • 当你的模型过拟合时,增加 p。
  • 当你的模型欠拟合时,减小 p。
  • 对于大层,保持较高;对于小层,保持较低。

Dropout 层

神经网络的单元在任何一个时刻都会随机处理无数输入,然后发出无数输出,就像人脑中的神经元一样。在产生最终输出或结论之前很久,每个单元的处理和输出可能是中间输出的放电,发送到另一个单元进行进一步处理。这种处理的一部分会产生噪声,这是处理操作的中间但非最终的输出。

数据科学家在将 dropout 应用于神经网络时会考虑到这种随机处理的性质。在决定消除哪些数据噪声后,他们按照以下方式将 dropout 应用于各个神经网络层:

  • 输入层:这是机器学习和人工智能 (AI) 的最高级别,其中吸收了第一个原始数据。根据认为与所考虑的业务问题无关的数据,可以将 dropout 应用于此可见数据层。
  • 中间层:这些是数据摄入后的处理层。由于我们无法确切地看到这些层的作用,因此它们被隐藏起来。处理输入后,这些层(可能包含一个或多个)会将中间但非最终的输出发送到其他神经元以进行进一步处理。数据科学家使用 dropout 来消除部分中间处理,因为其中大部分会变成噪声。
  • 输出层:这是所有神经元单元的最终、可见的处理输出。此层不使用 dropout。

Keras 中的 Dropout 实现

Dropout 是 Keras 的核心层之一

keras.layers.Dropout( rate, noise_shape = None, seed = None)

它可以包含在用 Keras 构建的深度学习模型中。包括以下属性:

  • Rate:建立神经元掉出概率的因子。请记住,如果您没有使用验证集来确定哪个最适合您(请记住,Keras 反转了逻辑,使其成为掉出而非保留神经元的概率!),最好为隐藏层设置 rate=0.5,为输入层设置 rate = 0.1。
  • Noise shape:如果您愿意,可以将 noise shape 设置为在批次、时间步长或特征上散布噪声。
  • Seed:如果您想固定确定 Bernoulli 变量是 1 还是 0 的伪随机生成器,您可以通过在此处放置一个整数值来设置一些种子(例如,以排除与数字生成器相关的问题)。

带 dropout 的卷积神经网络分类器设计

现在,让我们看看如何使用 Dropout 来减少使用 Keras 构建的神经网络中的过拟合。为此,我们将创建一个用于图像分类的卷积神经网络。然后,我们将介绍模型的架构和我们今天使用的数据集。

CIFAR-10 是常见的机器学习数据集之一,它包含一万个小型自然图像的十个类别。例如,它包含船只、车辆和动物的图像。当您想演示特定模型如何工作时,这是默认选项之一。

导入所需模块

代码

将从 Keras 深度学习框架中使用多个功能。我们从 keras.datasets 导入 CIFAR-10 数据集。这是一个方便的快捷方式:只需几行代码,您就可以导入 MNIST 和 CIFAR-10 等数据集,因为 Keras 拥有与它们的 API 链接。这使我们能够完全专注于构建模型,而不是陷入大量数据加载任务。

从 keras.layers 导入 Dense(密集连接层类型)、Dropout(用于正则化)、Flatten(连接卷积层和 Dense 层)以及 Conv2D 和 MaxPooling2D(卷积和相关层)。

此外,我们从 keras.models 导入 Sequential 模型,该模型允许我们将层整齐地堆叠在一起。

然后导入 Keras 后端用于一些数据预处理功能。

导入 max_norm Constraints 可以大大增强模型,这是 Dropout 的推荐做法。

模型配置

我们将 img_width = img_height = 32,因为 CIFAR-10 样本的宽度和高度为 32 像素。批次大小设置为 250,根据我的模型,实验表明它最适合 CIFAR-10。我选择使用 55 个 epoch,因为正如我们将看到的,到那时,dropout 和 no dropout 之间的差异将相当明显。

CIFAR-10 数据集支持 10 个类别,这是我们的模型可以处理的类别数(no_classes)。当详细模式设置为 1(或 True)时,所有输出都会发送到屏幕。对于验证,将使用 20% 的训练数据。

Max_norm_value 最后设置为 2.0。使用 MaxNorm Keras 约束,此数字表示可用于 max-norm 正则化的最大范数。我通过经验发现,2.0 对于今天的模型来说是一个公平的值。但是,如果您将其与另一个模型和/或数据集一起使用,您将需要进行一些实验以自行找到合适的值。

数据加载和准备

为了加载和准备 CIFAR-10 数据集,需要执行以下操作:

使用 Keras 的 load_data 方法,可以轻松地将 CIFAR-10 加载到训练集和测试集的特征和目标变量中。

数据加载后,我们会根据我们使用的后端(如 CNTK、Tensorflow 或 Theano)对其进行重塑,以便数据具有一致的结构,无论后端如何。

然后,将数字解析为浮点数,这似乎会加快训练过程。然后我们标准化数据,这 神经网络 发现很有用。为了确保分类交叉熵损失适用于我们的多类分类问题,我们最后使用 to_categorical。

定义架构

加载数据后,我们可以指定架构。

这与我们之前检查的架构图一致。它具有 Softmax 激活函数,为样本生成多类概率分布,并包含两个 Conv2D 和相关层以及两个 Dense 层。

训练和编译

下一步是编译模型。您可以在编译或配置模型时定义损失函数、优化器以及准确率等其他指标。如前所述,我们通过分类交叉熵损失计算实际目标和预测之间的差异。我们还使用 Adam 优化器,它基本上是当今最常见的优化器之一。

配置好模型后,我们就可以将训练数据拟合到模型中!为此,我们设置 input_train 和 target_train 变量,以及验证分割、详细模式、批次大小和 epoch 数量。它们的值是预先确定的。

模型评估

将评估指标添加到测试集是最后一步,它决定了模型在未见过的数据上的泛化能力。我们现在可以比较不同的模型,这将在下一步进行。

输出

 
Epoch 1/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 105s 641ms/step - accuracy: 0.1872 - loss: 2.3361 - val_accuracy: 0.4160 - val_loss: 1.7017
Epoch 2/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 140s 627ms/step - accuracy: 0.4234 - loss: 1.5871 - val_accuracy: 0.5080 - val_loss: 1.4064
Epoch 3/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 638ms/step - accuracy: 0.5003 - loss: 1.3878 - val_accuracy: 0.5746 - val_loss: 1.2713
Epoch 4/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 149s 688ms/step - accuracy: 0.5458 - loss: 1.2734 - val_accuracy: 0.6004 - val_loss: 1.1754
Epoch 5/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 690ms/step - accuracy: 0.5800 - loss: 1.1841 - val_accuracy: 0.6275 - val_loss: 1.0953
Epoch 6/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 135s 645ms/step - accuracy: 0.6005 - loss: 1.1253 - val_accuracy: 0.6485 - val_loss: 1.0370
Epoch 7/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 647ms/step - accuracy: 0.6223 - loss: 1.0616 - val_accuracy: 0.6577 - val_loss: 1.0075
Epoch 8/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 654ms/step - accuracy: 0.6388 - loss: 1.0198 - val_accuracy: 0.6691 - val_loss: 0.9689
Epoch 9/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 108s 678ms/step - accuracy: 0.6570 - loss: 0.9738 - val_accuracy: 0.6860 - val_loss: 0.9372
Epoch 10/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 685ms/step - accuracy: 0.6668 - loss: 0.9516 - val_accuracy: 0.6919 - val_loss: 0.9016
Epoch 11/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 137s 652ms/step - accuracy: 0.6837 - loss: 0.9042 - val_accuracy: 0.6973 - val_loss: 0.8813
Epoch 12/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 666ms/step - accuracy: 0.6978 - loss: 0.8595 - val_accuracy: 0.7033 - val_loss: 0.8643
Epoch 13/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 115s 718ms/step - accuracy: 0.7065 - loss: 0.8318 - val_accuracy: 0.7166 - val_loss: 0.8444
Epoch 14/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 134s 670ms/step - accuracy: 0.7112 - loss: 0.8140 - val_accuracy: 0.7178 - val_loss: 0.8260
Epoch 15/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 138s 648ms/step - accuracy: 0.7252 - loss: 0.7791 - val_accuracy: 0.7149 - val_loss: 0.8278
Epoch 16/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 661ms/step - accuracy: 0.7350 - loss: 0.7497 - val_accuracy: 0.7284 - val_loss: 0.8028
Epoch 17/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 662ms/step - accuracy: 0.7487 - loss: 0.7247 - val_accuracy: 0.7294 - val_loss: 0.7932
Epoch 18/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 658ms/step - accuracy: 0.7507 - loss: 0.7078 - val_accuracy: 0.7328 - val_loss: 0.7885
Epoch 19/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 102s 637ms/step - accuracy: 0.7596 - loss: 0.6837 - val_accuracy: 0.7358 - val_loss: 0.7725
Epoch 20/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 642ms/step - accuracy: 0.7668 - loss: 0.6622 - val_accuracy: 0.7337 - val_loss: 0.7823
Epoch 21/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 648ms/step - accuracy: 0.7670 - loss: 0.6547 - val_accuracy: 0.7409 - val_loss: 0.7598
Epoch 22/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 661ms/step - accuracy: 0.7766 - loss: 0.6400 - val_accuracy: 0.7331 - val_loss: 0.7781
Epoch 23/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 138s 638ms/step - accuracy: 0.7826 - loss: 0.6194 - val_accuracy: 0.7428 - val_loss: 0.7600
Epoch 24/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 643ms/step - accuracy: 0.7824 - loss: 0.6136 - val_accuracy: 0.7438 - val_loss: 0.7518
Epoch 25/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 644ms/step - accuracy: 0.7884 - loss: 0.5990 - val_accuracy: 0.7436 - val_loss: 0.7569
Epoch 26/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 644ms/step - accuracy: 0.7902 - loss: 0.5956 - val_accuracy: 0.7461 - val_loss: 0.7596
Epoch 27/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 644ms/step - accuracy: 0.8028 - loss: 0.5629 - val_accuracy: 0.7417 - val_loss: 0.7673
Epoch 28/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 643ms/step - accuracy: 0.8041 - loss: 0.5529 - val_accuracy: 0.7426 - val_loss: 0.7552
Epoch 29/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 644ms/step - accuracy: 0.8043 - loss: 0.5504 - val_accuracy: 0.7473 - val_loss: 0.7561
Epoch 30/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 148s 681ms/step - accuracy: 0.8088 - loss: 0.5415 - val_accuracy: 0.7387 - val_loss: 0.7615
Epoch 31/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 134s 635ms/step - accuracy: 0.8147 - loss: 0.5270 - val_accuracy: 0.7489 - val_loss: 0.7465
Epoch 32/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 631ms/step - accuracy: 0.8177 - loss: 0.5213 - val_accuracy: 0.7416 - val_loss: 0.7610
Epoch 33/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 626ms/step - accuracy: 0.8191 - loss: 0.5140 - val_accuracy: 0.7483 - val_loss: 0.7539
Epoch 34/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 151s 682ms/step - accuracy: 0.8252 - loss: 0.5014 - val_accuracy: 0.7476 - val_loss: 0.7492
Epoch 35/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 135s 638ms/step - accuracy: 0.8237 - loss: 0.4978 - val_accuracy: 0.7475 - val_loss: 0.7554
Epoch 36/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 629ms/step - accuracy: 0.8230 - loss: 0.4999 - val_accuracy: 0.7482 - val_loss: 0.7547
Epoch 37/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 630ms/step - accuracy: 0.8296 - loss: 0.4896 - val_accuracy: 0.7460 - val_loss: 0.7635
Epoch 38/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 638ms/step - accuracy: 0.8339 - loss: 0.4729 - val_accuracy: 0.7462 - val_loss: 0.7728
Epoch 39/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 147s 670ms/step - accuracy: 0.8287 - loss: 0.4837 - val_accuracy: 0.7481 - val_loss: 0.7474
Epoch 40/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 135s 625ms/step - accuracy: 0.8310 - loss: 0.4791 - val_accuracy: 0.7460 - val_loss: 0.7637
Epoch 41/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 626ms/step - accuracy: 0.8377 - loss: 0.4668 - val_accuracy: 0.7440 - val_loss: 0.7658
Epoch 42/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 635ms/step - accuracy: 0.8431 - loss: 0.4521 - val_accuracy: 0.7499 - val_loss: 0.7442
Epoch 43/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 140s 623ms/step - accuracy: 0.8431 - loss: 0.4486 - val_accuracy: 0.7495 - val_loss: 0.7578
Epoch 44/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 624ms/step - accuracy: 0.8436 - loss: 0.4444 - val_accuracy: 0.7509 - val_loss: 0.7523
Epoch 45/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 624ms/step - accuracy: 0.8460 - loss: 0.4368 - val_accuracy: 0.7560 - val_loss: 0.7537
Epoch 46/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 622ms/step - accuracy: 0.8416 - loss: 0.4434 - val_accuracy: 0.7491 - val_loss: 0.7621
Epoch 47/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 624ms/step - accuracy: 0.8466 - loss: 0.4329 - val_accuracy: 0.7522 - val_loss: 0.7542
Epoch 48/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 623ms/step - accuracy: 0.8476 - loss: 0.4312 - val_accuracy: 0.7477 - val_loss: 0.7695
Epoch 49/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 100s 623ms/step - accuracy: 0.8501 - loss: 0.4212 - val_accuracy: 0.7433 - val_loss: 0.7813
Epoch 50/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 629ms/step - accuracy: 0.8443 - loss: 0.4364 - val_accuracy: 0.7465 - val_loss: 0.7710
Epoch 51/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 623ms/step - accuracy: 0.8469 - loss: 0.4328 - val_accuracy: 0.7543 - val_loss: 0.7703
Epoch 52/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 625ms/step - accuracy: 0.8519 - loss: 0.4188 - val_accuracy: 0.7497 - val_loss: 0.7562
Epoch 53/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 633ms/step - accuracy: 0.8525 - loss: 0.4174 - val_accuracy: 0.7516 - val_loss: 0.7620
Epoch 54/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 630ms/step - accuracy: 0.8548 - loss: 0.4168 - val_accuracy: 0.7563 - val_loss: 0.7526
Epoch 55/55
160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 624ms/step - accuracy: 0.8519 - loss: 0.4117 - val_accuracy: 0.7503 - val_loss: 0.7769
Test loss: 0.797421395778656 / Test accuracy: 0.7391999959945679   

Keras 中的 Dropout 实现

导入必要的库

首先应导入 TensorFlow 和其他必要的库。

代码

准备数据集

此示例将使用 MNIST 数据集。

使用 Dropout 构建模型

编译模型

使用正确的优化器、损失函数和指标来编译模型。

模型训练

使用训练数据来训练模型。此外,我们将使用测试数据来验证模型。

评估模型

最后,使用测试数据集评估模型的性能。

可视化

在训练和验证过程中,查看准确率和损失,以了解模型在 epoch 期间的表现。

输出

 
Epoch 1/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 57s 110ms/step - accuracy: 0.7675 - loss: 0.7137 - val_accuracy: 0.9797 - val_loss: 0.0642
Epoch 2/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 50s 106ms/step - accuracy: 0.9641 - loss: 0.1198 - val_accuracy: 0.9877 - val_loss: 0.0387
Epoch 3/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 81s 104ms/step - accuracy: 0.9724 - loss: 0.0905 - val_accuracy: 0.9868 - val_loss: 0.0374
Epoch 4/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 80s 102ms/step - accuracy: 0.9780 - loss: 0.0732 - val_accuracy: 0.9891 - val_loss: 0.0305
Epoch 5/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 83s 104ms/step - accuracy: 0.9800 - loss: 0.0678 - val_accuracy: 0.9911 - val_loss: 0.0268
Epoch 6/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 48s 102ms/step - accuracy: 0.9840 - loss: 0.0553 - val_accuracy: 0.9915 - val_loss: 0.0244
Epoch 7/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 83s 104ms/step - accuracy: 0.9846 - loss: 0.0505 - val_accuracy: 0.9915 - val_loss: 0.0231
Epoch 8/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 48s 102ms/step - accuracy: 0.9859 - loss: 0.0455 - val_accuracy: 0.9920 - val_loss: 0.0228
Epoch 9/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 83s 104ms/step - accuracy: 0.9863 - loss: 0.0428 - val_accuracy: 0.9922 - val_loss: 0.0225
Epoch 10/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 81s 101ms/step - accuracy: 0.9876 - loss: 0.0415 - val_accuracy: 0.9924 - val_loss: 0.0212
313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.9903 - loss: 0.0261
Test accuracy: 0.9923999905586243   

Dropout Implementation in Tensorflow and Keras