机器学习中的猫分类

2025年3月17日 | 阅读 12 分钟
Cat Classification Using Machine Learning

猫的分类是一个判断图像是否包含猫的过程。虽然这对人类来说很简单,但由于猫的外观、姿势和背景的多样性,这对机器人来说是一个相当大的挑战。机器学习算法通过识别区分猫与其他物体或动物的模式和特征来解决这个问题。

这个问题可以通过深度学习来解决。深度学习是机器学习的一个分支,研究基于大脑结构和功能(尤其是人工神经网络)的算法的设计和开发。这些神经网络由许多层组成,可以在不同抽象层次上学习数据的特征和表示。这使得深度学习模型在图像识别、自然语言处理和语音识别等各种任务上都能达到最先进的性能。

在这里,我们将构建一个用于图像分类的CNN模型,并使用PyTorch使用该模型进行数据预测。

代码

本课程使用的数据集是猫狗数据集,其中包含猫和狗的图片。让我们看看数据集文件夹中的文件。我将从创建一个函数开始。

输出

Cat Classification Using Machine Learning

我们已经看到了数据集文件夹中的文件。现在让我们创建训练和测试路径。

输出

Cat Classification Using Machine Learning

EDA

理解数据集对于深度学习分析至关重要,因为它是任何机器学习或深度学习模型的基础。深度学习模型的性能与其训练数据的质量相当,而对数据集的理解不足可能导致模型性能不佳或产生偏差。让我们来看一下数据集中的一张图片。

输出

Cat Classification Using Machine Learning
Cat Classification Using Machine Learning

我们也可以使用 matplotlib 进行数据可视化,我们来尝试一下。

输出

Cat Classification Using Machine Learning

转换数据

数据转换,也称为预处理,是深度学习研究中的一个关键步骤,因为它能提高模型性能并降低偏差的可能性。让我们使用 transform 方法对图片进行一些操作。

为了更好地理解如何转换图片,让我们使用数据可视化。为此,我们将创建一个名为 plot_transformed_images 的函数。

输出

Cat Classification Using Machine Learning
Cat Classification Using Machine Learning
Cat Classification Using Machine Learning

您可以看到我们是如何转换数据的。

加载图像数据

到目前为止,我们已经构建了一个数据转换函数。我们已经准备好使用这个函数来加载我们的数据集。使用 PyTorch 的 ImageFolder 方法是加载数据的最简单方法。让我们使用这个函数来加载数据集。

输出

Cat Classification Using Machine Learning

现在让我们使用下面的属性来探索数据集。

输出

Cat Classification Using Machine Learning

我们了解了一些关于数据集的信息。让我们获取一张图片并查看其特征。

输出

Cat Classification Using Machine Learning

让我们用 matplotlib 来可视化这张图片。

输出

Cat Classification Using Machine Learning
Cat Classification Using Machine Learning

到目前为止,我们已经加载了图片。PyTorch 中的 'DataLoader' 是一个工具,它以并行方式从数据集对象加载数据。它允许用户分批导入数据,这对于训练深度学习模型可能很有益,因为它允许模型一次分析多个样本,从而加速训练过程。此外,它还允许用户对数据进行洗牌,这有助于避免过拟合。

DataLoader 接受一个数据集对象以及各种其他可选参数,包括批次大小、用于数据加载的工作线程数以及一个表示是否洗牌数据的布尔标志。然后,DataLoader 将提供一个迭代器,允许您按批次迭代数据。

输出

Cat Classification Using Machine Learning

我们将数据集转换为 DataLoader 对象。现在让我们获取一个批次的图像并检查该批次的形状。

输出

Cat Classification Using Machine Learning

模型构建

数据增强是人工增加数据集大小的技术,通过对现有数据进行随机修改来实现。这可以通过为模型提供更广泛的训练样本来帮助提高深度学习模型的性能。当数据集有限或模型容易过拟合时,数据增强可能很有效。

通过利用数据增强方法,模型可以学习更有效地泛化,并对数据中的微小波动更具弹性。这有助于防止过拟合并提高模型在新数据上的性能。需要强调的是,数据增强应在数据预处理之前进行。请记住,数据增强仅应应用于训练集,而不应应用于验证集或测试集。

TrivialAugmentWide 是一种 PyTorch 数据增强技术,可以随机调整图像大小并裁剪图像。该方法被设计为一种“宽泛”的数据增强技术,这意味着它对图像进行大量随机修改以增加训练数据的多样性。这有助于提高机器学习模型的弹性和泛化能力。

现在让我们再次使用数据增强来加载数据集。

输出

Cat Classification Using Machine Learning

输出

Cat Classification Using Machine Learning

CNN 图像分类器

卷积神经网络 (CNN) 是一种深度学习神经网络,主要执行图像和视频识别任务。CNN 设计用于处理具有网格状拓扑结构的数据,例如图像,图像由排列在二维网格中的像素组成。CNN 的架构包括多个层,例如卷积层、池化层和全连接层。

卷积层使用一组可学习的滤波器来识别输入图像中的特征。这些滤波器与输入图像进行卷积,生成一组特征图,然后通过池化层进行处理,以减小空间尺寸并仅保留最重要的特征。然后,全连接层分析池化特征图以产生最终输出,该输出可能是一个分类标签。

CNN 已被用于在各种计算机视觉应用(包括图像分类、对象检测和语义分割)中实现最先进的性能。现在让我们使用 Pytorch 的 nn.Module 来创建一个 CNN 模型。

我们现在将输入一张图片,以便我们能理解这个模型是否在工作。

输出

Cat Classification Using Machine Learning

模型描述

理解模型架构非常重要。幸运的是,torchinfo 包可以帮助您检查模型的架构。

输出

Cat Classification Using Machine Learning
Cat Classification Using Machine Learning

训练模型

现在模型已在训练数据上构建,并在验证集上进行了测试。现在让我们创建两个函数来训练和测试模型。


现在让我们创建一个名为 train 的函数,将 train_step 和 test_step 函数结合起来。

到目前为止,我们已经建立了训练和测试阶段。我们已经准备好使用这些步骤来训练模型。

输出

Cat Classification Using Machine Learning

评估模型

为了更好地理解模型的性能,可以考虑可视化损失和准确性指标。

现在是时候使用此函数来可视化损失和准确性值了。

输出

Cat Classification Using Machine Learning

我们的模型在训练集和测试集上表现都相当好。让我们看看如何进行预测。

进行预测

现在我们拥有了一个不错的图像分类模型。为了理解这一点,我将对自定义图像进行预测。

输出

Cat Classification Using Machine Learning

现在让我们创建一个变换管道来调整图像大小。

输出

Cat Classification Using Machine Learning

首先,我们将模型的图片拟合到我们生成的函数中,然后进行预测。

输出

Cat Classification Using Machine Learning

现在让我们看一下我们模型的预测结果。

输出

Cat Classification Using Machine Learning

我们已经看到了模型的预测值。让我们看一下预测类别。

输出

Cat Classification Using Machine Learning

输出

Cat Classification Using Machine Learning

输出

Cat Classification Using Machine Learning

我们的模型预测正确了。它确实是猫。


下一个主题AIC 和 BIC