卷积神经网络模型的验证

17 Mar 2025 | 5 分钟阅读

在训练部分,我们已经在MNIST数据集(无限数据集)上训练了我们的CNN模型,并且它似乎达到了合理的损失和准确率。如果该模型可以利用它所学到的东西并将其推广到新数据,那么这将是对其性能的真正证明。这将以与我们在上一个主题中相同的方式完成。

步骤 1

我们将借助我们在训练部分中创建的训练数据集来创建我们的验证集。这一次,我们将train设置为false,如下所示:

步骤 2

现在,类似于为什么我们在训练部分中声明了一个训练加载器,我们将定义一个验证加载器。验证加载器的创建方式也与我们创建训练加载器的方式相同,但这次我们传递的是训练加载器,而不是训练数据集,并且我们将shuffle设置为false,因为我们不会训练我们的验证数据。没有必要对其进行洗牌,因为它仅用于测试目的。

步骤 3

我们的下一步是分析每个epoch的验证损失和准确率。为此,我们必须创建两个列表,分别用于验证运行损失和验证运行损失校正。

步骤 4

下一步,我们将验证模型。该模型将在同一epoch中进行验证。在我们完成遍历整个训练集以训练我们的数据之后,我们现在将遍历我们的验证集以测试我们的数据。

我们将首先衡量两件事。第一件事是我们模型的性能,即,它在验证集上对测试集进行了多少次正确的分类,以检查是否过度拟合。我们将验证的运行损失和运行校正设置为:

步骤 5

我们现在可以循环遍历我们的测试数据。因此,在else语句之后,我们将为标签和输入定义一个循环语句,如下所示:

步骤 6

我们正在处理卷积神经网络,输入首先要经过它。我们将专注于这些图像的四个维度。因此,没有必要将它们扁平化。

正如我们将模型分配给我们的设备一样,我们也将输入和标签分配给我们的设备。

现在,在这些输入的帮助下,我们得到输出为

步骤 7

借助输出,我们将计算总的类别交叉熵损失,并且输出最终会与实际标签进行比较。

我们没有训练我们的神经网络,因此无需调用zero_grad(),backward()或任何类似的操作。并且也不再需要计算导数。在操作范围内为了节省内存,我们在使用torch的For循环之前调用no_grad()方法,如下所示:

它将暂时将所有require grad标志设置为false。

步骤 8

现在,我们将以与计算训练损失和训练准确率相同的方式计算验证损失和准确率,如下所示:

步骤 9

现在,我们将计算验证epoch损失,这将与我们计算训练epoch损失的方式相同,即将总运行损失除以数据集的长度。因此,它将写成:

步骤 10

我们将打印验证损失和验证准确率,如下所示:


Validation of Convolutional Neural Network

步骤 11

现在,我们将对其进行绘制以进行可视化。我们将绘制它,如下所示:


Validation of Convolutional Neural Network

Validation of Convolutional Neural Network

从上面的图中可以清楚地看出,CNN中发生了过度拟合。为了减少这种过度拟合,我们将介绍另一种名为Dropout Layer的快速技术。

步骤 12

在下一步中,我们将转到LeNet类并添加一种特定的层类型,这将减少我们数据的过度拟合。此层类型称为Dropout层。该层本质上通过在训练期间随机将一部分输入单元设置为0来发挥作用,并且每次更新都会发生这种情况。

Validation of Convolutional Neural Network

上图显示了一个标准的神经网络,以及应用dropout后相同的神经网络。我们可以看到,一些节点已被关闭,不再与网络一起传递信息。

我们将使用多个dropout层,这些层将在给定的网络中使用以获得所需的性能。我们将把这些dropout层放置在卷积层之间以及完全连接的层之间。dropout层用于参数数量较多的层之间,因为这些高参数层更有可能过度拟合和记住训练数据。因此,我们将在完全连接的层之间设置我们的dropout层。

我们将借助nn.Dropout模块初始化我们的dropout层,并在我们的初始化程序中传递dropout率作为参数。给定节点被丢弃的概率将设置为0.5,如下所示:

步骤 13

下一步,我们将在前向函数中的完全连接层之间定义我们的第二个dropout层,如下所示:

现在,我们将运行我们的程序,它将为我们提供更准确的结果,如下所示:

Validation of Convolutional Neural Network

Validation of Convolutional Neural Network

Validation of Convolutional Neural Network

完整代码


下一主题CNN的测试