使用 PyTorch 和 PyTorch Geometric 进行图神经网络7 Jan 2025 | 7 分钟阅读 图神经网络 (GNN)图神经网络(GNN)是一类专门用于处理图结构信息的神经网络。由于其在捕捉图中节点之间复杂连接和依赖关系方面的专业能力,它们受到了广泛的认可。GNN 在不同的领域都有广泛的应用:社交网络分析、化学、生物学和推荐系统。 关键概念图的表示 - 节点:图中的实体,每个实体都可以拥有相关的特征。
- 边:节点之间的连接,也可以拥有相关的特征或权重。
- 图(G)通常表示为 ( G = (V, E) ),其中 ( V ) 是节点集合,( E ) 是边集合。
图的类型 - 有向与无向:在有向图中,边有方向;在无向图中,边是双向的。
- 加权与无权:在加权图中,边具有权重,表示关系的强度或值。
消息传递框架 - GNN 的核心思想是通过一种称为消息传递的方式,聚合邻近节点的信息来迭代地更新节点表示。
- 每个节点通过接收其邻居的消息并使用聚合函数组合这些消息来更新其状态。
节点、边和图级别的任务 - 节点分类:基于节点的特征和图结构预测节点的标签。
- 边预测:预测节点之间的边的存在性或类型(也称为链接预测)。
- 图分类:为整个图预测一个标签,常用于分子性质预测。
应用 - 社交网络:分析社交网络中的关系和影响力。
- 推荐系统:利用用户-物品交互图来推荐物品。
- 生物网络:研究蛋白质-蛋白质相互作用或预测分子性质。
- 交通和城市规划:分析交通网络以进行交通预测和路径优化。
PyTorchPyTorch 是由 Facebook AI 研究实验室开发的一个著名的开源机器学习库。它为构建和训练深度学习模型提供了一个灵活且强大的框架。以下是 PyTorch 的概述以及如何将其用于各种机器学习任务。 主要特点- 动态计算图:PyTorch 使用动态计算图,这意味着图是在操作进行时即时构建的。这使得调试和模型构建比静态计算图更直观、更灵活。
- 张量运算:PyTorch 张量类似于 NumPy 数组,但可以在 GPU 上使用以加速计算。它们支持各种运算,包括算术、线性代数等。
- 自动微分:PyTorch 的 `autograd` 模块提供自动微分功能,让您可以轻松计算优化所需的梯度。
- 模块化设计:PyTorch 具有模块化和可扩展的设计,易于创建自定义层、损失函数和优化器。
- 与 Python 集成:PyTorch 与 Python 环境无缝集成,支持原生的 Python 控制流、与 NumPy 等库的集成等等。
使用 PyTorch 实现 GCN让我们看一个使用 PyTorch 实现 GCN 的示例。 示例 让我们通过一个简单的图卷积网络(GCN)进行节点分类的示例。 先决条件:请确保您已安装 PyTorch 和 PyTorch Geometric。您可以使用以下命令进行安装: 代码 输出 Model output: tensor([[-0.8819, -0.5344],
[-0.9129, -0.5131],
[-0.8737, -0.5402]], grad_fn=<LogSoftmaxBackward0>)
Loss: 0.7562562823295593
说明 模型定义 - `GCN` 类继承自 `torch.nn.Module`,并定义了 GCN 层(`GCNConv`)。
- `conv1` 和 `conv2` 是图卷积层。`conv1` 将输入特征映射到 16 维,`conv2` 映射到 2 维(用于分类)。
前向传播方法 - 接收一个包含节点特征(`x`)和边索引(`edge_index`)的图数据对象(`data`)。
- 应用第一个 GCN 层,后跟 ReLU 激活函数。
- 应用第二个 GCN 层和 log softmax 激活函数进行分类。
数据准备 - `x` 是一个包含三个节点特征的张量。
- `edge_index` 定义了节点之间的边。
- `Data` 对象组合了节点特征和边索引。
训练 - 我们通过模型进行一次前向传播。
- 定义一个简单的分类目标,并使用 `F.nll_loss` 计算损失。
- 执行反向传播,并使用优化器更新模型参数。
输出 - 打印模型输出和损失。输出是每个节点的类别对数概率分布。
PyTorch Geometric (PyG)PyTorch Geometric (PyG) 是 PyTorch 的一个扩展库,旨在方便处理图结构数据和实现图神经网络(GNN)。它提供了广泛的工具、数据集和预构建层,使得构建和训练 GNN 变得简单。 主要特点- 图数据处理: PyG 提供了高效的数据结构来表示和处理图数据,包括对大规模图的支持。
- 预构建 GNN 层:该库包含各种流行的 GNN 层,如 GCN(图卷积网络)、GAT(图注意力网络)、GraphSAGE 等。
- 消息传递 API:PyG 实现了一个灵活而强大的消息传递框架,允许您轻松定义数据如何在节点之间聚合和传播。
- 丰富的数据集集合:PyG 提供了用于节点分类、链接预测和图分类任务的基准数据集集合。
- 可扩展性:PyG 经过优化以提高性能,可以有效地处理大规模图,利用稀疏张量运算和 GPU 加速。
使用 PyTorch Geometric 实现 GCN让我们看一个使用 PyTorch Geometric 实现 GCN 的示例。 Code Example 输出 Epoch 0, Loss: 1.9471
Epoch 10, Loss: 0.5516
Epoch 20, Loss: 0.0935
Epoch 30, Loss: 0.0244
Epoch 40, Loss: 0.0137
Epoch 50, Loss: 0.0128
Epoch 60, Loss: 0.0145
Epoch 70, Loss: 0.0163
Epoch 80, Loss: 0.0170
Epoch 90, Loss: 0.0166
Epoch 100, Loss: 0.0157
Epoch 110, Loss: 0.0147
Epoch 120, Loss: 0.0139
Epoch 130, Loss: 0.0132
Epoch 140, Loss: 0.0127
Epoch 150, Loss: 0.0122
Epoch 160, Loss: 0.0117
Epoch 170, Loss: 0.0113
Epoch 180, Loss: 0.0109
Epoch 190, Loss: 0.0106
Accuracy: 0.8080

 说明 数据集加载 - 我们使用 PyTorch Geometric 的 `Planetoid` 类加载 Cora 数据集。
- 这个数据集是一个引用网络,其中每个节点代表一篇论文,边表示论文之间的引用。
模型定义 - 我们定义了一个简单的图卷积网络(GCN),包含两个层。
- 第一个 `GCNConv` 层将节点特征降维至 16 维,后跟 ReLU 激活函数。
- 第二个 `GCNConv` 层将节点特征映射到类别的数量。
正向传播 - 在前向传播方法中,我们将节点特征和边索引通过 GCN 层,应用 log softmax 来获得类别概率。
训练循环 - 我们使用 Adam 优化器和负对数似然损失来训练模型 200 个 epoch。
- 模型在节点上进行训练,特别是通过 `train_mask`,它指示了训练集。
求值 - 我们使用 `test_mask`(指示用于测试的节点)来评估模型在测试集上的准确率。
可视化 - 我们绘制了训练损失随 epoch 的变化图,并使用 NetworkX 和 Matplotlib 可视化了节点按预测类别着色的图。
|