如何在 PyTorch 中获取模型摘要?

2025年7月22日 | 阅读 8 分钟

打印模型摘要对于了解神经网络的结构和参数至关重要。虽然 Keras 有一个基本的 model.summary() 方法,但在 PyTorch 中,需要通过另一个命令来实现。在本文中,我们将指导您如何使用 torchinfo 包在 PyTorch 中打印详细的模型摘要,该包取代了 torch summary。

为什么模型摘要很重要?

在我们开始介绍如何操作之前,了解模型摘要的用处非常重要。

  • 调试:它有助于在开发的早期阶段确定模型架构中的错误或不匹配之处。
  • 优化:提供关于可训练参数数量和计算成本的信息,这在调整或规划资源时非常有用。
  • 文档:作为模型结构的简要指南,方便与他人分享并解释其构造。

分步获取模型摘要

torch summary 是一个易于使用的工具,可以为 PyTorch 模型生成完整的摘要,就像 Keras 中的 model.summary() 所做的那样。它显示了层的类型、输出的形状以及每层中的参数数量。

1. 基于 torchsummary 包的使用

可以通过 pip 安装 torchsummary。

这是使用 torchsummary 打印 PyTorch 模型摘要的方法。

输出

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 32, 28, 28]             320
            Conv2d-2          [-1, 64, 28, 28]          18,496
            Linear-3               [-1, 128]        64,020,480
            Linear-4                [-1, 10]             1,290
================================================================
Total params: 64,040,586
Trainable params: 64,040,586
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 7.68
Params size (MB): 244.30
Estimated Total Size (MB): 251.99

2. 模型摘要的模型实现

如果您不想使用外部库,也可以自行实现查看模型摘要的功能。下面是一个简化的自定义实现:

输出

----------------------------------------------------------------
        Layer (type)              Output Shape         Param #
================================================================
           Conv2d-1             [1, 32, 28, 28]             320
           Conv2d-2             [1, 64, 28, 28]          18,496
             Linear-3                  [1, 128]      64,020,480
             Linear-4                   [1, 10]           1,290
================================================================
Total params: 64,040,586

3. 使用 torchinfo

您可以使用 torchinfo 包在 PyTorch 中显示详细的模型摘要。它提供了精确的、类似 Keras 的模型架构描述。

示例

输出

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleModel                              [1, 10]                   0
├─Conv2d: 1-1                            [1, 32, 26, 26]           320
├─MaxPool2d: 1-2                         [1, 32, 13, 13]           0
├─Conv2d: 1-3                            [1, 64, 11, 11]           18,496
├─MaxPool2d: 1-4                         [1, 64, 5, 5]             0
├─Linear: 1-5                            [1, 128]                  204,928
├─Linear: 1-6                            [1, 10]                   1,290
==========================================================================================
Total params: 225,034
Trainable params: 225,034
Non-trainable params: 0
Total mult-adds (M): 1.42
------------------------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.32
Param size (MB): 0.86
Estimated Total Size (MB): 1.18

torchsummary 的优势

轻量级且简单

torchsummary 是一个简单明了的包,其操作只需要最小的设置。您可以在几行代码中生成 PyTorch 模型的简洁摘要。这使得它对于初学者或需要快速有效方法来检查模型结构的用户来说特别方便。

类似 Keras 的输出

这是 torchsummary 的重要吸引力之一,因为它与 Keras 的 model.summary() 相似。这对于从 TensorFlow/Keras 迁移到 PyTorch 的开发者来说非常方便,学习曲线更平缓,并且能够以熟悉的方式检查模型。

快速调试工具

当您创建或编辑模型时,torchsummary 是一个非常有用的调试工具。它可以即时显示每层有多少参数,以及该层的输出形状是什么,有助于在模型开发周期的早期识别配置问题。

torchinfo 的优势

丰富的信息和更好的详细输出

torchinfo 在 torchsummary 的基础上提供了更长的模型摘要。除了层的类型和参数数量外,它还提供了内存消耗、参数可训练性以及乘加运算 (FLOPs) 的估计,从而更深入地了解模型的性能特征。

嵌套模块支持

与 torchsummary 不同,torchinfo 还能显示模型的嵌套或分层组件。这在处理较新架构(如 ResNet、U-Net 或 Transformers)时特别有用,这些架构倾向于在块中嵌入块或子模块。分层格式可以更好地理解数据在整个结构中的流动。

高度可定制

torchinfo 允许用户管理摘要输出的许多功能。可以设置详细程度、控制模块(嵌套)显示的深度、指定设备(CPU/GPU),以及调整列宽。这种程度的自定义确保了该工具适用于各种用途,包括研究或设备部署诊断。

打印模型摘要时遇到的常见问题

在使用 torchinfo 或 torchsummary 等工具创建 PyTorch 模型摘要时,用户可能会遇到一些常见问题。如果处理不当,这些问题可能导致错误或不完整的結果。下面描述了最重要的问题。

1. 形状不匹配

打印模型摘要时常见的错误是输入的尺寸与模型第一层的预期尺寸不符。这种不匹配可能导致运行时错误或不正确的摘要。规避此问题的方法是将输入张量的形状与模型进行正确的匹配。

例如,一个期望(灰度)图像的卷积层将期望输入形状为 (1, height, width),而彩色图像则需要 (3, height, width)。通常建议始终确认 summary() 函数在给定模型架构和后续维度下所期望的输入大小。

2. 未注册或定义不当的模块

当使用未注册为模型组件的自定义层或组件进行建模时,会出现另一个常见问题。该模块必须继承自 nn.Module;否则,PyTorch 将无法识别并跟踪它在模型层级结构中的位置。因此,所有此类层都不会显示在摘要中。要解决这个问题,我们需要所有自定义组件都继承自 nn.Module,并在模型的 __init__() 方法中声明。只有正确注册的模块才会被遍历并包含在摘要输出中。

3. forward() 中的动态操作

在 forward() 方法调用中使用的递归,通过循环结构、条件判断或根据输入参数改变形状的逻辑,会干扰模型跟踪,从而影响摘要的生成。torchinfo 等分析器使用固定计算图的跟踪。

如果模型的行为根据输入而改变,或者实现了形状相关的逻辑,那么模型可能无法生成完整或准确的摘要。在调试或审查模型时,forward 传递应保持确定性且尽可能不变。

4. 设备配置丢失

对于在 GPU 上运行的模型,模型和用于构建摘要的虚拟输入张片之间的不匹配可能导致错误。例如,如果您的模型在 CUDA 上,而输入张片默认设置为 CPU,那么摘要将失败。

要解决此问题,有必要在 summary() 调用中指定要使用的设备,例如,device="cuda" 或 device="cpu",具体取决于您的模型所在的位置。这可以保持一致性,并避免在生成摘要时出现与设备相关的问题。

5. Flatten 或 View 操作错误

最基本的一个错误发生在将预处理数据展平以输入全连接层时。当线性层中计算的输入特征应具有的特征数量不正确时,可能会发生运行时错误。

这通常是由于未能正确计算该层之前的卷积层的输出大小。为了避免这种情况,您可以在 forward() 中打印中间张量的形状,或将一个虚拟输入通过模型,检查实际的特征数量,然后定义您的线性层。

6. 超出 Module 范围的 Functional 层的使用

PyTorch 允许在 forward() 方法中调用函数式 API 层(例如 torch.nn.functional.relu)。然而,这些模块未被注册,因此不会出现在层摘要中。虽然这不会导致训练错误,但会导致模型摘要不完整,因为使用模块内省的工具看不到这些操作。为了获得全面的视图,请查看它们的模块版本(例如 nn.ReLU),这些版本在 __init__() 框架中表示。

结论

通过 PyTorch,可以更容易地开发和调试神经网络架构及其可视化,因为可以打印详细的模型摘要。无论是使用 torchsummary、torchinfo 还是自定义函数,清晰地显示模型结构都对开发和调试大有帮助。它尤其有助于检测形状不兼容、未注册模块和参数计算错误等问题。

通过本教程,您应该能够安装和使用 torchsummary 包,开发自己的模型摘要和打印功能,并处理在使用复杂模型时可能遇到的各种问题。摘要可以结构清晰,并对模型中的层、参数和数据流提供很好的解释,因此,它是构建深度学习模型的一个非常有用的组成部分。