Tensorflow 和 Keras 中的 Dropout 实现2025年6月18日 | 阅读 17 分钟 深度神经网络的模型参数非常多。通常需要数万甚至数百万个参数。这些因素极大地增强了它们学习各种复杂数据集的能力。但这并非总是好事。过拟合,即测试集性能差而训练集性能高(偏差低,方差高)的情况,常常是这种能力的结果。由于模型过度依赖训练数据,测试误差率可能会更高。为了避免这种情况,我们尝试通过各种正则化策略来减弱模型的学习能力。Dropout 就是这样一种正则化技术。当应用于未知数据时,正则化可以确保模型按预期运行。 ![]() 上图说明了过拟合模型(绿色边际)与正则化模型(黑色边际)之间的区别。在未知样本(测试集)上,绿色边际不太可能表现得很好,尽管它似乎更适合训练数据。上图是过拟合的一个很好的例子。 深度神经网络中的 DropoutDropout 是在神经网络中故意消除的数据或噪声,以加快处理和周转时间。 2012 年,Geoffrey Hinton 在其论文《通过防止特征检测器的协同适应来改进神经网络》中提出了一种非常流行的正则化技术。这个概念非常简单但却非常强大。 在每个训练步骤中,每个神经元都有一个概率“p”,可以暂时停止训练(也称为“掉出”)。在这种情况下,“p”代表可调超参数,称为 dropout 率。 ![]() 上图显示了 dropout 实现对网络连接的影响。(左)密集连接的标准前馈网络。(右)Dropout 显著减少了连接数量。来源:“Dropout:一种防止神经网络过拟合的简单方法”。 例如,如果 p=0.5,则一个神经元在每个 epoch 中有 50% 的概率掉出。如果一个神经元跳过训练阶段,下游层将受到影响,因为它的所有连接都会断开。结果,神经网络的连接密度将大大降低(图 2)。输出层不受 dropout 的影响,而输入层和隐藏层受影响。这是因为为了便于训练,模型必须不断为损失函数产生输出。只有在训练阶段使用 dropout 程序。在推理阶段,网络中的每个神经元都完全参与。 随机关闭神经元可能非常令人震惊。人们可能会认为这会导致训练过程中的极端不稳定。然而,在实践中,它被证明在简化模型方面非常有效。 如果让员工每天早上抛硬币来决定是否上班,公司的绩效会提高吗?你永远不知道,也许会!不用说,公司必须改变组织方式;它不能依赖某个人来补充咖啡机或执行任何其他关键任务。因此,这些知识需要分布在不同的人员之间。员工需要培养与众多同事合作的能力,而不仅仅是少数几个人。公司将变得更加健壮。如果一个人离开,差别不会太大。尽管这种概念是否真的适用于企业还有待商榷,但神经网络肯定可以。 深度神经网络的网络架构在训练过程中也会因 epoch 而异。此外,每个神经元都必须关注其所有输入,而不是过度依赖少量输入连接。因此,它们对输入连接的更改具有更强的弹性。 这样,它就保证了一个更强大的网络,具有更好的泛化能力。 dropout 率 p 是 dropout 中可控的超参数。调整它相当简单。
Dropout 层神经网络的单元在任何一个时刻都会随机处理无数输入,然后发出无数输出,就像人脑中的神经元一样。在产生最终输出或结论之前很久,每个单元的处理和输出可能是中间输出的放电,发送到另一个单元进行进一步处理。这种处理的一部分会产生噪声,这是处理操作的中间但非最终的输出。 数据科学家在将 dropout 应用于神经网络时会考虑到这种随机处理的性质。在决定消除哪些数据噪声后,他们按照以下方式将 dropout 应用于各个神经网络层:
Keras 中的 Dropout 实现Dropout 是 Keras 的核心层之一 keras.layers.Dropout( rate, noise_shape = None, seed = None) 它可以包含在用 Keras 构建的深度学习模型中。包括以下属性:
带 dropout 的卷积神经网络分类器设计现在,让我们看看如何使用 Dropout 来减少使用 Keras 构建的神经网络中的过拟合。为此,我们将创建一个用于图像分类的卷积神经网络。然后,我们将介绍模型的架构和我们今天使用的数据集。 CIFAR-10 是常见的机器学习数据集之一,它包含一万个小型自然图像的十个类别。例如,它包含船只、车辆和动物的图像。当您想演示特定模型如何工作时,这是默认选项之一。 导入所需模块代码 将从 Keras 深度学习框架中使用多个功能。我们从 keras.datasets 导入 CIFAR-10 数据集。这是一个方便的快捷方式:只需几行代码,您就可以导入 MNIST 和 CIFAR-10 等数据集,因为 Keras 拥有与它们的 API 链接。这使我们能够完全专注于构建模型,而不是陷入大量数据加载任务。 从 keras.layers 导入 Dense(密集连接层类型)、Dropout(用于正则化)、Flatten(连接卷积层和 Dense 层)以及 Conv2D 和 MaxPooling2D(卷积和相关层)。 此外,我们从 keras.models 导入 Sequential 模型,该模型允许我们将层整齐地堆叠在一起。 然后导入 Keras 后端用于一些数据预处理功能。 导入 max_norm Constraints 可以大大增强模型,这是 Dropout 的推荐做法。 模型配置我们将 img_width = img_height = 32,因为 CIFAR-10 样本的宽度和高度为 32 像素。批次大小设置为 250,根据我的模型,实验表明它最适合 CIFAR-10。我选择使用 55 个 epoch,因为正如我们将看到的,到那时,dropout 和 no dropout 之间的差异将相当明显。 CIFAR-10 数据集支持 10 个类别,这是我们的模型可以处理的类别数(no_classes)。当详细模式设置为 1(或 True)时,所有输出都会发送到屏幕。对于验证,将使用 20% 的训练数据。 Max_norm_value 最后设置为 2.0。使用 MaxNorm Keras 约束,此数字表示可用于 max-norm 正则化的最大范数。我通过经验发现,2.0 对于今天的模型来说是一个公平的值。但是,如果您将其与另一个模型和/或数据集一起使用,您将需要进行一些实验以自行找到合适的值。 数据加载和准备为了加载和准备 CIFAR-10 数据集,需要执行以下操作: 使用 Keras 的 load_data 方法,可以轻松地将 CIFAR-10 加载到训练集和测试集的特征和目标变量中。 数据加载后,我们会根据我们使用的后端(如 CNTK、Tensorflow 或 Theano)对其进行重塑,以便数据具有一致的结构,无论后端如何。 然后,将数字解析为浮点数,这似乎会加快训练过程。然后我们标准化数据,这 神经网络 发现很有用。为了确保分类交叉熵损失适用于我们的多类分类问题,我们最后使用 to_categorical。 定义架构加载数据后,我们可以指定架构。 这与我们之前检查的架构图一致。它具有 Softmax 激活函数,为样本生成多类概率分布,并包含两个 Conv2D 和相关层以及两个 Dense 层。 训练和编译下一步是编译模型。您可以在编译或配置模型时定义损失函数、优化器以及准确率等其他指标。如前所述,我们通过分类交叉熵损失计算实际目标和预测之间的差异。我们还使用 Adam 优化器,它基本上是当今最常见的优化器之一。 配置好模型后,我们就可以将训练数据拟合到模型中!为此,我们设置 input_train 和 target_train 变量,以及验证分割、详细模式、批次大小和 epoch 数量。它们的值是预先确定的。 模型评估将评估指标添加到测试集是最后一步,它决定了模型在未见过的数据上的泛化能力。我们现在可以比较不同的模型,这将在下一步进行。 输出 Epoch 1/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 105s 641ms/step - accuracy: 0.1872 - loss: 2.3361 - val_accuracy: 0.4160 - val_loss: 1.7017 Epoch 2/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 140s 627ms/step - accuracy: 0.4234 - loss: 1.5871 - val_accuracy: 0.5080 - val_loss: 1.4064 Epoch 3/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 638ms/step - accuracy: 0.5003 - loss: 1.3878 - val_accuracy: 0.5746 - val_loss: 1.2713 Epoch 4/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 149s 688ms/step - accuracy: 0.5458 - loss: 1.2734 - val_accuracy: 0.6004 - val_loss: 1.1754 Epoch 5/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 690ms/step - accuracy: 0.5800 - loss: 1.1841 - val_accuracy: 0.6275 - val_loss: 1.0953 Epoch 6/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 135s 645ms/step - accuracy: 0.6005 - loss: 1.1253 - val_accuracy: 0.6485 - val_loss: 1.0370 Epoch 7/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 647ms/step - accuracy: 0.6223 - loss: 1.0616 - val_accuracy: 0.6577 - val_loss: 1.0075 Epoch 8/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 654ms/step - accuracy: 0.6388 - loss: 1.0198 - val_accuracy: 0.6691 - val_loss: 0.9689 Epoch 9/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 108s 678ms/step - accuracy: 0.6570 - loss: 0.9738 - val_accuracy: 0.6860 - val_loss: 0.9372 Epoch 10/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 685ms/step - accuracy: 0.6668 - loss: 0.9516 - val_accuracy: 0.6919 - val_loss: 0.9016 Epoch 11/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 137s 652ms/step - accuracy: 0.6837 - loss: 0.9042 - val_accuracy: 0.6973 - val_loss: 0.8813 Epoch 12/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 666ms/step - accuracy: 0.6978 - loss: 0.8595 - val_accuracy: 0.7033 - val_loss: 0.8643 Epoch 13/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 115s 718ms/step - accuracy: 0.7065 - loss: 0.8318 - val_accuracy: 0.7166 - val_loss: 0.8444 Epoch 14/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 134s 670ms/step - accuracy: 0.7112 - loss: 0.8140 - val_accuracy: 0.7178 - val_loss: 0.8260 Epoch 15/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 138s 648ms/step - accuracy: 0.7252 - loss: 0.7791 - val_accuracy: 0.7149 - val_loss: 0.8278 Epoch 16/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 661ms/step - accuracy: 0.7350 - loss: 0.7497 - val_accuracy: 0.7284 - val_loss: 0.8028 Epoch 17/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 662ms/step - accuracy: 0.7487 - loss: 0.7247 - val_accuracy: 0.7294 - val_loss: 0.7932 Epoch 18/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 658ms/step - accuracy: 0.7507 - loss: 0.7078 - val_accuracy: 0.7328 - val_loss: 0.7885 Epoch 19/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 102s 637ms/step - accuracy: 0.7596 - loss: 0.6837 - val_accuracy: 0.7358 - val_loss: 0.7725 Epoch 20/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 642ms/step - accuracy: 0.7668 - loss: 0.6622 - val_accuracy: 0.7337 - val_loss: 0.7823 Epoch 21/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 648ms/step - accuracy: 0.7670 - loss: 0.6547 - val_accuracy: 0.7409 - val_loss: 0.7598 Epoch 22/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 144s 661ms/step - accuracy: 0.7766 - loss: 0.6400 - val_accuracy: 0.7331 - val_loss: 0.7781 Epoch 23/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 138s 638ms/step - accuracy: 0.7826 - loss: 0.6194 - val_accuracy: 0.7428 - val_loss: 0.7600 Epoch 24/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 643ms/step - accuracy: 0.7824 - loss: 0.6136 - val_accuracy: 0.7438 - val_loss: 0.7518 Epoch 25/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 644ms/step - accuracy: 0.7884 - loss: 0.5990 - val_accuracy: 0.7436 - val_loss: 0.7569 Epoch 26/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 644ms/step - accuracy: 0.7902 - loss: 0.5956 - val_accuracy: 0.7461 - val_loss: 0.7596 Epoch 27/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 644ms/step - accuracy: 0.8028 - loss: 0.5629 - val_accuracy: 0.7417 - val_loss: 0.7673 Epoch 28/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 643ms/step - accuracy: 0.8041 - loss: 0.5529 - val_accuracy: 0.7426 - val_loss: 0.7552 Epoch 29/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 103s 644ms/step - accuracy: 0.8043 - loss: 0.5504 - val_accuracy: 0.7473 - val_loss: 0.7561 Epoch 30/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 148s 681ms/step - accuracy: 0.8088 - loss: 0.5415 - val_accuracy: 0.7387 - val_loss: 0.7615 Epoch 31/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 134s 635ms/step - accuracy: 0.8147 - loss: 0.5270 - val_accuracy: 0.7489 - val_loss: 0.7465 Epoch 32/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 631ms/step - accuracy: 0.8177 - loss: 0.5213 - val_accuracy: 0.7416 - val_loss: 0.7610 Epoch 33/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 626ms/step - accuracy: 0.8191 - loss: 0.5140 - val_accuracy: 0.7483 - val_loss: 0.7539 Epoch 34/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 151s 682ms/step - accuracy: 0.8252 - loss: 0.5014 - val_accuracy: 0.7476 - val_loss: 0.7492 Epoch 35/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 135s 638ms/step - accuracy: 0.8237 - loss: 0.4978 - val_accuracy: 0.7475 - val_loss: 0.7554 Epoch 36/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 629ms/step - accuracy: 0.8230 - loss: 0.4999 - val_accuracy: 0.7482 - val_loss: 0.7547 Epoch 37/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 630ms/step - accuracy: 0.8296 - loss: 0.4896 - val_accuracy: 0.7460 - val_loss: 0.7635 Epoch 38/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 638ms/step - accuracy: 0.8339 - loss: 0.4729 - val_accuracy: 0.7462 - val_loss: 0.7728 Epoch 39/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 147s 670ms/step - accuracy: 0.8287 - loss: 0.4837 - val_accuracy: 0.7481 - val_loss: 0.7474 Epoch 40/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 135s 625ms/step - accuracy: 0.8310 - loss: 0.4791 - val_accuracy: 0.7460 - val_loss: 0.7637 Epoch 41/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 626ms/step - accuracy: 0.8377 - loss: 0.4668 - val_accuracy: 0.7440 - val_loss: 0.7658 Epoch 42/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 635ms/step - accuracy: 0.8431 - loss: 0.4521 - val_accuracy: 0.7499 - val_loss: 0.7442 Epoch 43/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 140s 623ms/step - accuracy: 0.8431 - loss: 0.4486 - val_accuracy: 0.7495 - val_loss: 0.7578 Epoch 44/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 624ms/step - accuracy: 0.8436 - loss: 0.4444 - val_accuracy: 0.7509 - val_loss: 0.7523 Epoch 45/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 624ms/step - accuracy: 0.8460 - loss: 0.4368 - val_accuracy: 0.7560 - val_loss: 0.7537 Epoch 46/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 622ms/step - accuracy: 0.8416 - loss: 0.4434 - val_accuracy: 0.7491 - val_loss: 0.7621 Epoch 47/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 624ms/step - accuracy: 0.8466 - loss: 0.4329 - val_accuracy: 0.7522 - val_loss: 0.7542 Epoch 48/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 623ms/step - accuracy: 0.8476 - loss: 0.4312 - val_accuracy: 0.7477 - val_loss: 0.7695 Epoch 49/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 100s 623ms/step - accuracy: 0.8501 - loss: 0.4212 - val_accuracy: 0.7433 - val_loss: 0.7813 Epoch 50/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 143s 629ms/step - accuracy: 0.8443 - loss: 0.4364 - val_accuracy: 0.7465 - val_loss: 0.7710 Epoch 51/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 623ms/step - accuracy: 0.8469 - loss: 0.4328 - val_accuracy: 0.7543 - val_loss: 0.7703 Epoch 52/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 142s 625ms/step - accuracy: 0.8519 - loss: 0.4188 - val_accuracy: 0.7497 - val_loss: 0.7562 Epoch 53/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 101s 633ms/step - accuracy: 0.8525 - loss: 0.4174 - val_accuracy: 0.7516 - val_loss: 0.7620 Epoch 54/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 630ms/step - accuracy: 0.8548 - loss: 0.4168 - val_accuracy: 0.7563 - val_loss: 0.7526 Epoch 55/55 160/160 ━━━━━━━━━━━━━━━━━━━━ 141s 624ms/step - accuracy: 0.8519 - loss: 0.4117 - val_accuracy: 0.7503 - val_loss: 0.7769 Test loss: 0.797421395778656 / Test accuracy: 0.7391999959945679 Keras 中的 Dropout 实现导入必要的库首先应导入 TensorFlow 和其他必要的库。 代码 准备数据集此示例将使用 MNIST 数据集。 使用 Dropout 构建模型编译模型使用正确的优化器、损失函数和指标来编译模型。 模型训练使用训练数据来训练模型。此外,我们将使用测试数据来验证模型。 评估模型最后,使用测试数据集评估模型的性能。 可视化在训练和验证过程中,查看准确率和损失,以了解模型在 epoch 期间的表现。 输出 Epoch 1/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 57s 110ms/step - accuracy: 0.7675 - loss: 0.7137 - val_accuracy: 0.9797 - val_loss: 0.0642 Epoch 2/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 50s 106ms/step - accuracy: 0.9641 - loss: 0.1198 - val_accuracy: 0.9877 - val_loss: 0.0387 Epoch 3/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 81s 104ms/step - accuracy: 0.9724 - loss: 0.0905 - val_accuracy: 0.9868 - val_loss: 0.0374 Epoch 4/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 80s 102ms/step - accuracy: 0.9780 - loss: 0.0732 - val_accuracy: 0.9891 - val_loss: 0.0305 Epoch 5/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 83s 104ms/step - accuracy: 0.9800 - loss: 0.0678 - val_accuracy: 0.9911 - val_loss: 0.0268 Epoch 6/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 48s 102ms/step - accuracy: 0.9840 - loss: 0.0553 - val_accuracy: 0.9915 - val_loss: 0.0244 Epoch 7/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 83s 104ms/step - accuracy: 0.9846 - loss: 0.0505 - val_accuracy: 0.9915 - val_loss: 0.0231 Epoch 8/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 48s 102ms/step - accuracy: 0.9859 - loss: 0.0455 - val_accuracy: 0.9920 - val_loss: 0.0228 Epoch 9/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 83s 104ms/step - accuracy: 0.9863 - loss: 0.0428 - val_accuracy: 0.9922 - val_loss: 0.0225 Epoch 10/10 469/469 ━━━━━━━━━━━━━━━━━━━━ 81s 101ms/step - accuracy: 0.9876 - loss: 0.0415 - val_accuracy: 0.9924 - val_loss: 0.0212 313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.9903 - loss: 0.0261 Test accuracy: 0.9923999905586243 ![]() |
我们请求您订阅我们的新闻通讯以获取最新更新。