您的位置:首页 > 其它

论文解读(GraphSAGE)《Inductive Representation Learning on Large Graphs》

2022-01-17 08:24 1161 查看

 《Inductive Representation Learning on Large Graphs》

  论文标题:Inductive Representation Learning on Large Graphs

  论文作者: William L. Hamilton (wleif@stanford.edu), Rex Ying (rexying@stanford.edu)

  论文来源:NIPS 2017

  论文链接:chrome-extension://ibllepbpahcoppkjjllbabhnigcbffpi/https://arxiv.org/pdf/1706.02216.pdf

  论文代码:https://github.com/williamleif/GraphSAGE

2 介绍及相关工作    

Transductive Learning

  假设要测试的节点和训练的节点在一个图中,并且训练过程中图结构中的所有节点都被考虑进去。它们只能得到已经包含在训练过程中的节点嵌入,对于训练过程中没有出现过的未知节点则束手无策。由于它们在一个固定的图上直接生成最终的节点嵌入,如果这个图的结构稍后有所改变,就需要重新训练。

   直推式学习已经预先观察了所有数据,含训练和测试数据集。 从已经观察到的数据集中学习,然后预测测试数据集的标签。 即过程会利用这些不知道数据标签的测试集数据的模式和其他信息。还有一个区别是,一旦有新的节点出现,直推式学习需要重新训练模型。

Inductive Learning

  主要观点是:节点的嵌入可以通过一个共同的聚合邻居节点信息的函数得到,在训练时只要得到这个聚合函数,就可以将其泛化到未知的节点上。

Factorization-based embedding approaches
  • 一些使用随机游走统计和基于矩阵分解的学习目标的节点嵌入方法。
  • 这些嵌入算法中的大多数直接为单个节点训练节点嵌入。 因此需要昂贵的额外训练(例如,通过随机梯度下降)来对新节点进行预测。

Supervised learning over graphs

  • 基于 Graph kernel  的方法,其中图的特征向量来自不同的图内核。
Graph convolutional networks  
  • 如果半监督时带有 label 的节点过少,GCN 的性能会有比较严重的下降;
  • 浅层的 GCN 网络不能大范围地传播 label 信息 (层级越深,节点的感受野越大);
  • 深层的 GCN 网络会导致过度平滑 (smooth) 的问题;

  本文提出的 GraphSAGE(Inductive Method) 可以利用所有图中存在的结构特征(如:节点度,邻居信息),去推测 Unseen Node 的节点 Embeeding。

      

  1. 先对邻居随机采样,降低计算复杂度(Figure 1 :一跳邻居采样数=3,二跳邻居采样数=5)
  2. 生成目标节点 Emebedding:先聚合2跳邻居特征,生成一跳邻居 Embedding,再聚合一跳邻居 Embedding,生成目标节点 Embedding,从而获得二跳邻居信息。
  3. 将 Embedding 作为全连接层的输入,预测目标节点的标签。

3 GraphSAGE Method

  GraphSAGE 的核心思想:不是试图学习一个图上所有 Node Embedding,而是学习一个为每个 Node 产生 Embedding 的映射。

3.1 Embedding generation algorithm

  该部分假设模型已经被训练过了,并且参数是固定的。

  我们假设我们已经学习了 $K$ 个聚合器函数的参数,

    $\text { AGGREGATE } \left._{k}, \forall k \in\{1, \ldots, K\}\right)$

  用模型的不同层或“搜索深度”之间传播信息。

  步骤:

  GraphSAGE 的前向传播算法如下,前向传播描述了如何使用聚合函数对节点的邻居信息进行聚合,从而生成节点 Embedding:

      

  • $ \mathcal{G}=(\mathcal{V}, \mathcal{E})$ 表示一个图;
  • $ K$  是网络的层数,也代表着每个顶点能够聚合的邻接点的跳数,因为每增加一层,可以聚合更远的一层邻居的信息;
  • $ x_{v}$,$\forall v \in \mathcal{V}$  表示节点 $v$  的特征向量,并且作为输入;
  • $ \left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}$  表示在 $k-1$  层中节点 $v$  的邻居节点 $u$  的 Embedding;
  • $ \mathbf{h}_{\mathcal{N}(v)}^{k}$  表示在第 $k$  层,节点 $v$  的所有邻居节点的特征表示;
  • $ \mathbf{h}_{v}^{k}$,$\forall v \in \mathcal{V}$  表示在第 $k$  层,节点 $v$  的特征表示;
  • $ \mathcal{N}(v) $ 定义为从集合 $\{u \in v:(u, \mathcal{V}) \in \mathcal{E}\}$  中的固定 $size$ 的均匀取出,即 GraphSAGE 中每一层的节点邻居都是是从上一层网络采样的,并不是所有邻居参与,并且采样后的邻居 $size$ 是固定的;

3.2 Learning the parameters of GraphSAGE

  损失函数分为基于图的无监督损失有监督损失

  • 基于图的无监督损失:目标是使节点 $u$ 与 “邻居” $v$ 的 Embedding 相似,与无边相连的节点 $v_n$ 不相似。

    $J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)$

  • $z_{u}$ 为节点通过 GraphSAGE 生成的 Embedding ;
  • 节点 $v$  是节点 $u$  结果固定长度的 Random walk 到达的"邻居";
  • $v_{n} \sim P_{n}(u)$  表示负采样:节点 $v_{n}$  是从节点 $u$  的负采样分布 $P_{n}$  采样的, $Q$  为采样样本数;
  • Embedding 之间的相似度通过向量点积计算得到;
  • 基于图的有监督损失:无监督损失函数的设定来学习节点 Embedding 可以供下游多个任务使用,若仅使用在特定某个任务上,则可以替代上述损失函数符合特定任务目标,如交叉熵。

3.3 Aggregator Architectures

  算法可以应用于任意顺序的节点表示向量(即:排列不变性),所以聚集函数(aggregation function)应该是对称的。

  排列不变性(permutation invariance):指输入的顺序改变不会影响输出的值。

  这里采用 Mean aggregator 、LSTM aggregator 、Pooling aggregator。

  • Mean aggregator

    Mean aggregator 将目标顶点和邻居顶点的第 $k−1$ 层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第 $k$ 层表示向量。
    GCN 的 inductive 变形:

    $h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{mean}\left(\left\{h_{v}^{k-1}\right\} \cup\left\{h_{u}^{k-1}, \forall u \in N(v)\right\}\right)\right.$

  Convolutional aggregator     $\begin{array}{c}h_{N(v)}^{k}=\operatorname{mean}\left(\left\{h_{u}^{k-1}, u \in N(v)\right\}\right) \\h_{v}^{k}=\sigma\left(W^{k} \cdot C O N C A T\left(h_{v}^{k-1}, h_{N(u)}^{k}\right)\right)\end{array}$
  • LSTM聚合:LSTM函数不符合 "排列不变性" 的性质,需要先对邻居随机排序,然后将随机的邻居序列 Embedding $ \left\{x_{t}, t \in N(v)\right\}$  作为 LSTM 输入。
  • Pooling 聚合:

   它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的 Embedding 向量进行一次非线性变换,之后进行一次 Pooling 操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第 $k$ 层表示向量。

  一个element-wise max pooling操作应用在邻居集合上来聚合信息:
    $\text { AGGREGATE }_{k}^{\mathrm{pool}}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right)$

    $\mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right)$

  其中

  • $max$ 表示 $element-wise$ 最大值操作, 取每个特征的最大值
  • $\sigma$  是非线性激活函数
  • 所有相邻节点的向量共享权重, 先经过一个非线性全连接层, 然后做 $max-pooling$
  • 按维度应用 $max / mean \quad pooling$,可以捕获邻居集上在某一个维度的突出的综合的表现。

4 Experiments

  在三个基准任务上测试了GraphSAGE的性能。

  datasets

  • [li]Citation 论文引用网络(节点分类)
  • Reddit 帖子论坛 (节点分类)
  • PPI 蛋白质网络 (graph分类)
[/li]

  four baselines

  • [li]Random classifer,随机分类器
  • Raw features,手工特征(非图特征)
  • Deepwalk(图拓扑特征)
  • DeepWalk + features, deepwalk+手工特征
[/li]

  基于图的无监督损失

    $J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)$

  基于图的有监督损失

     交叉熵

  实验设置

  • $K=2$,聚合两跳内邻居特征
  • $S_1=25,S_2=10$: 对一跳邻居抽样25个,二跳邻居抽样10个
  • RELU 激活单元
  • Adam 优化器(仅 DeepWalk 使用 SGD )
  • 文中所有的模型都是用 TensorFlow 实现
  • 对每个节点进行步长为 5 的 50 次随机游走
  • 负采样参考 Word2vec,按平滑 degree 进行,对每个节点采样 20 个
  • 保证公平性:所有版本都采用相同的minibatch迭代器、损失函数、邻居采样器
  • 实验测试了根据式1的损失函数训练的GraphSAGE的各种变体,还有在分类交叉熵损失上训练的可监督变体
  • 对于Reddit和citation数据集,使用”online”的方式来训练DeepWalk
  • 在多图情况下,不能使用DeepWalk,因为通过DeepWalk在不同不相交的图上运行后生成的embedding空间对它们彼此说可能是arbitrarily rotated的。

  实验结果1:分类准确率

      

  结论:

  • [li]GraphSAGE的性能显著优于baseline方法。
  • 三个数据集显示:一般是 LSTM 或 pooling 效果比较好,有监督都比无监督好。
  • LSTM 是为有序数据而不是无序集设计的,但是基于 LSTM 的聚合器显示了强大的性能。
  • 可以看到无监督 GraphSAGE 的性能与完全监督的版本相比具有相当的竞争力,这表明文中的框架可以在不进行特定于任务的微调( task-specific fine-tuning )的情况下实现强大的性能。
[/li]

  实验结果2:Timing experiments on Reddit data

        

  1. 计算时间:下图A中GraphSAGE中LSTM训练速度最慢,但相比DeepWalk,GraphSAGE在预测时间减少100-500倍(因为对于未知节点,DeepWalk要重新进行随机游走以及通过SGD学习embedding)
  2. 邻居抽样数量:上图B中邻居抽样数量递增,边际收益递减(F1),但计算时间也变大。 平衡F1和计算时间,将S1设为25。
  3. 聚合  K  跳内信息:在  GraphSAGE, K=2 相比  K=1 有10-15%的提升;但将K设置超过2,边际效果上只有  0-5%  的提升,但是计算时间却变大了10-100倍。

『总结不易,加个关注呗!』

 

 

   

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐