CycleGAN2025年3月17日 | 阅读18分钟 图像到图像的翻译是创建现有图像的合成修改新版本的过程。例如,将夏日风景转换为冬日风景。通常需要大量匹配的实例集合来训练图像到图像的翻译模型。某些文件,例如已故画家的艺术品照片,可能非常昂贵、复杂,甚至不可能收集。一种名为 CycleGAN 的方法可在没有配对实例的情况下自动训练图像到图像的翻译模型。通过使用来自源域和目标域的照片集——这些域之间不必有任何连接——模型会自动进行训练。 CycleGAN 由两种类型的网络组成:判别器和生成器。在此示例中,判别器负责将图像分类为真实或伪造(对于 X 和 Y 两种类型的图像)。生成器负责为两种类型的图像生成逼真的伪造图像。 它因其无配对图像翻译能力而受到青睐,这使得它能够在不要求训练集中有匹配对的情况下学习跨不同图像域的映射。它提供了更大的灵活性和适应性,因为它以无监督的方式运行,并且可以从源域和目标域的图片集中学习,而无需明确的关联。循环一致性是指翻译后的图像在多次翻译后仍能保持其源的真实性,从而产生更逼真的结果。由于该方法大大减少了对配对数据集的依赖,因此 CycleGAN 在难以获取标注数据的场景中非常有用。 CycleGAN 的实现导入库加载数据集get_data_loader 函数返回可以快速加载数据并按预定批次加载数据的训练和测试 DataLoader。该函数具有以下参数:
测试数据旨在供我们未来的生成器使用,以便我们可以查看固定在测试数据上的一些生成样本。 正如你所见,此函数还负责确保我们的图像被转换为 Tensor 图像类型并具有正确的方形尺寸(128x128x3)。 注意:建议将这些设置保留为默认值。如果你尝试在另一组数据上运行此代码,更高的 image_size 和 batch_size 选项可能会产生更好的结果。在训练循环中调整 batch_size 之前,请务必构建完整的批次,因为这样做可能会在尝试存储样本数据时导致错误。可视化数据输出 X 数据可视化 ![]() 输出 Y 数据可视化 ![]() 缩放由于我们知道 tanh 激活的生成器输出的像素值将在 -1 到 1 之间变化,因此我们需要进行一些预处理。因此,我们必须将训练图像重新缩放到此范围内。(目前,它们在 0 到 1 之间。) 输出 ![]() 输出 ![]() 定义模型CycleGAN 由两个生成器网络和两个判别器网络组成。 判别器在此 CycleGAN 中,判别器 DX 和 DY 是卷积神经网络,它们分析图像并尝试确定它是真实的还是伪造的。在这里,输出接近 1 表示真实,输出接近 0 表示伪造。判别器具有以下架构: ![]() 将 256x256x3 大小的图像输入到该网络,并通过 5 个卷积层进行处理,这些卷积层将其下采样 2 倍。BatchNorm 和 ReLu 激活函数应用于前四个卷积层的输出,而最后一层充当分类层并产生单个值。 卷积辅助函数你应该使用提供的 conv 函数,它会生成一个卷积层加上一个可选的 batch norm 层,来定义判别器。 判别器架构使用上述五层卷积网络设计,挑战在于完成 __init__ 函数。我们只需要指定一个类,然后实例化两个判别器,因为 DX 和 DY 共享相同的设计。 forward 函数决定了图像如何进入判别器;重要的是按顺序将图像通过每个卷积层,对除最后一层之外的所有层使用 ReLu 激活函数。 由于我们要使用平方误差损失进行训练,因此在这种情况下不应在输出中添加 sigmoid 激活函数。稍后你可以在笔记本中了解有关此损失函数的更多信息。 生成器生成器 G_XtoY 和 G_YtoX(有时称为 F)由一个编码器(一个将图像压缩为较小特征表示的卷积网络)和一个解码器(一个将该表示转换为修改后图像的转置卷积网络)组成。从 Y 到 X 以及从 X 到 Y 的这些生成器的构造如下: ![]() 当该网络接收到 256x256x3 的图像时,它会将其压缩为特征表示,并通过三个卷积层,然后进入一组残差块。它会通过多个此类残差块——通常是六个或更多——然后通过三个转置卷积层,也称为反卷积层,它们将残差块的输出上采样以生成新图像! 除了最后一个转置卷积层应用 tanh 激活函数到输出之外,请注意,大多数卷积层和转置卷积层在其输出上应用了 BatchNorm 和 ReLu 函数。此外,卷积层和批归一化层构成了残差块;我们稍后将更详细地讨论这些。 残差块类为了定义生成器,我们必须构建一个 ResidualBlock 类。此类将使我们能够连接生成器的编码器和解码器部分。也许你正在想,Resnet 块具体是什么?它可能看起来与图像分类系统 ResNet50 相似,如下所示。 ![]() 通过使用残差块,我们可以学习所谓的残差函数,当它们应用于层输入时,这是解决此问题的一种方法。 ![]() 残差函数典型的深度学习模型由许多带有激活函数的层组成,其任务是学习从输入 (x) 到输出 (y) 的映射 M。 通过定义残差函数,我们可以避免学习从 x 到 y 的直接映射。 这会检查原始输入 x 和应用于 x 的映射之间的差异。通常,F(x) 由一个归一化层、两个卷积层和一个中间的 ReLu 组成。这些卷积层的输入和输出的数量应相等。然后,映射可以表示为输入 x 和残差函数的函数。通过加法步骤,在输入 (x) 和输出 (y) 之间形成一个几乎圆形的连接。 定义 ResidualBlock 类我们将构建残差函数,这是一组层,将它们应用于输入 x,然后将它们添加到相同的输入,以定义 ResidualBlock 类。这与任何其他神经网络一样,使用相同的 __init__ 函数和 forward 函数加法步骤来定义。 在这种情况下,残差块应定义如下:
接下来,在 forward 函数中将输入 x 添加到此残差块。你可以使用上面提到的辅助 conv 方法来创建此块。 转置卷积辅助函数![]() 然后,我们使用 ResidualBlock 类、上面的 conv 方法和下面的 deconv 辅助函数来定义生成器。这些将生成一个转置卷积层以及一个可选的 batchnorm 层。 生成器架构
由于 GXtoY 和 GYtoX 的架构相同,因此我们只需要编写一个类,然后实例化两个生成器。 完成网络我们可以指定构建完整 CycleGAN 所需的生成器和判别器,使用您已经建立的类。提供的设置对于训练应该很有效。 首先,创建两个判别器:一个用于验证 X 样本图像的真实性,另一个用于验证 Y 样本图像的真实性。然后是生成器。创建两个实例:一个用于将一幅画转换为逼真的图像,另一个用于将一张图像转换为一幅画。 输出 ![]() 生成器和判别器的损失![]()
最小二乘 GAN如前所述,常规 GAN 使用 sigmoid 交叉熵损失函数将判别器视为分类器。然而,在学习阶段,此损失函数可能会导致梯度消失问题。我们将为判别器使用最小二乘损失函数来解决此问题。此结构也称为 LSGAN,即最小二乘 GAN。 判别器损失判别器损失定义为判别器的输出(图像)与目标值之间的均方误差,目标值可以是 0 或 1,具体取决于判别器应该将图像分类为真实还是伪造。例如,使用均方误差,我们可以通过检查 DX 在识别真实图像 x 时的接近程度来训练它。 生成器损失生成器损失计算过程将包含与判别器损失计算过程相似的阶段;这些过程包括创建看起来属于 X 图像集但实际上基于 Y 真实图像的伪造图像,反之亦然。这次,你的生成器试图让判别器将这些伪造图像识别为真实图像,因此你将通过检查判别器对这些伪造图像的应用来计算这些伪造图像的“真实损失”。 循环一致性损失除了对抗损失之外,生成器损失项还将包含循环一致性损失。此损失是用于评估重建图像质量与原始图像质量的指标。 假设你有一个生成的伪造图像 x_hat 和一张真实图像 y。应用 G_XtoY(x_hat) = y_hat 将得到一个重建的 y_hat。然后,你可以验证这个重建的 y_hat 与原始图像 y 是否匹配。为此,我们建议计算原始图像和重建图像之间的 L1 损失——绝对差值。为了强调此损失的重要性,你还可以选择将其乘以权重值 lambda_weight。 生成器损失总额将由生成器损失以及前向和后向循环中的一致性损失的总和决定。 定义优化器![]() 训练CycleGAN 在看到 X 和 Y 集合的一个真实图像批次后,通过执行以下操作进行训练: 判别器训练
生成器训练
![]() 辅助函数训练和损失模式找到理想的超参数,使得判别器和生成器不会互相压倒,这需要大量的反复试验。我建议阅读这篇 DCGAN 研究以及原始的 CycleGAN 论文,看看他们是如何做的。查看现有论文以了解早期研究中哪些有效通常是个好主意。之后,你将有一个坚实的基础来测试你自己的实验。 判别器损失请记住,我们正在尝试创建一个能够生成高质量“伪造”图像的模型,因此当你绘制生成器和判别器损失时,你应该注意到总有一些判别器损失。因此,总会有一些损失,因为完美的判别器将无法区分真实图像和伪造图像。此外,你应该注意到 DX 和 DY 的损失水平大致相同。如果不是这样,这表明你的训练中偏向于某一种判别器,你可能需要检查你的模型或数据中的偏差。 生成器损失由于生成器损失同时考虑了生成器损失和加权的重建误差,因此它应该比判别器损失高得多。由于最初生成的图像往往离好的伪造相去甚远,你应该注意到在训练初期损失会大幅下降。随着训练的进行,判别器和生成器都会进步,因此通常在一段时间后会趋于平稳。如果你注意到损失随着时间的推移波动很大,可以尝试调整循环一致性损失的权重,使其稍多或稍少,或者降低学习率。 输出 ![]() ![]() ![]() ![]() 转换可视化输出 ![]() 转换翻译后,我们可以看到伪造图像得到了改进。 模型似乎在每个 epoch 中都显示出判别器(d_X_loss, d_Y_loss)和生成器(g_total_loss)的损失下降,这可能表明性能尚可。 注意:但是,仅凭这些损失无法精确确定模型的有效性。下一个主题DNN 机器学习 |
我们请求您订阅我们的新闻通讯以获取最新更新。