图神经网络在知识图谱中的应用

GCN/GAT/R-GCN用于知识图谱补全、链接预测、节点分类与PyTorch Geometric实现

引言

图神经网络(Graph Neural Network, GNN)是处理图结构数据的深度学习方法。传统的知识图谱嵌入方法(TransE/DistMult/ComplEx)将实体和关系映射为低维向量,但忽略了图的局部结构信息。GNN通过消息传递机制,让每个节点聚合邻居信息来更新自身表示,从而同时捕获实体语义和图结构模式。本文将系统阐述GNN在知识图谱中的核心应用——链接预测、节点分类和知识补全,并提供基于PyTorch Geometric(PyG)的工程实现。

GNN基础

消息传递范式

GNN消息传递(Message Passing)

节点v的更新过程:

Step 1: 消息生成 (Message)
  对每个邻居u, 生成消息:
  m_{u->v} = MSG(h_u, h_v, e_{u,v})

Step 2: 消息聚合 (Aggregate)
  聚合所有邻居消息:
  M_v = AGG({m_{u->v} | u in N(v)})

Step 3: 状态更新 (Update)
  更新节点表示:
  h_v' = UPD(h_v, M_v)

常见聚合方式:
  - SUM:   M_v = sum(m_{u->v})        # GCN
  - MEAN:  M_v = mean(m_{u->v})       # GraphSAGE
  - MAX:   M_v = max(m_{u->v})        # GraphSAGE
  - ATT:   M_v = sum(alpha * m_{u->v})  # GAT (注意力加权)

示意图:

      h_u1 ──msg──┐
      h_u2 ──msg──┤
      h_u3 ──msg──┼──→ AGG ──→ UPD ──→ h_v'
      h_u4 ──msg──┤              ↑
                  │            h_v (自身)
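上面的三步消息传递可以用纯 PyTorch 的 scatter 操作写成一个最小示意(此处以 SUM 聚合、线性变换作 MSG/UPD 为例;图结构、维度与变量名均为演示假设):

```python
import torch

# 玩具图: 4 个节点, 4 条有向边 (src -> dst)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
h = torch.randn(4, 8)                      # 节点表示 h_v, 维度 8

W_msg = torch.nn.Linear(8, 8, bias=False)  # Step 1: MSG(h_u) = W * h_u
W_upd = torch.nn.Linear(16, 8)             # Step 3: UPD([h_v || M_v])

src, dst = edge_index
msg = W_msg(h[src])                        # 每条边生成一条消息 [E, 8]

# Step 2: SUM 聚合到目标节点
M = torch.zeros_like(h)
M.scatter_add_(0, dst.unsqueeze(-1).expand_as(msg), msg)

# Step 3: 结合自身表示更新
h_new = torch.relu(W_upd(torch.cat([h, M], dim=-1)))
print(h_new.shape)  # torch.Size([4, 8])
```

后文的各类 GNN 层本质上都是在替换这三步中的 MSG/AGG/UPD 具体形式。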

主流GNN架构对比

| 架构 | 聚合方式 | 关系感知 | 注意力 | 参数量 | 适用场景 |
| --- | --- | --- | --- | --- | --- |
| GCN | 对称归一化求和 | ✗ | ✗ | 少 | 同质图/通用 |
| GAT | 注意力加权求和 | ✗ | ✓ | 中 | 邻居重要性不同 |
| GraphSAGE | 采样+聚合 | ✗ | 可选 | 少 | 大规模归纳 |
| R-GCN | 关系特定矩阵 | ✓ | ✗ | 多(需分解) | 知识图谱 |
| CompGCN | 实体-关系组合操作 | ✓ | 可选 | 中 | 知识图谱 |
| HGT | 异构类型注意力 | ✓ | ✓ | 多 | 异构图谱 |

GCN与GAT实现

基础GCN

import torch
import torch.nn as nn
import torch.nn.functional as F

class GCNLayer(nn.Module):
    """Graph Convolutional Network layer.

    h_v' = sigma(sum_{u in N(v)} (1/sqrt(d_u * d_v)) * W * h_u)
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_dim, out_dim))
        self.bias = nn.Parameter(torch.FloatTensor(out_dim)) if bias else None
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Node features [N, in_dim]
            adj: Normalized adjacency matrix [N, N]
        Returns:
            Updated node features [N, out_dim]
        """
        support = torch.mm(x, self.weight)  # [N, out_dim]
        output = torch.spmm(adj, support)   # [N, out_dim]
        if self.bias is not None:
            output += self.bias
        return output
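GCNLayer 假定传入的 adj 已经是归一化邻接矩阵。一个构造 D^{-1/2}(A+I)D^{-1/2} 的示意如下(稠密实现,仅用于小图演示;大图应改用稀疏张量):

```python
import torch

def normalize_adj(edge_index: torch.Tensor, n_nodes: int) -> torch.Tensor:
    """构造 GCN 所需的对称归一化邻接矩阵 D^{-1/2} (A + I) D^{-1/2}."""
    adj = torch.zeros(n_nodes, n_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj = adj + torch.eye(n_nodes)          # 加自环
    deg = adj.sum(dim=1)                    # 度向量 d_v
    d_inv_sqrt = deg.pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # 防御孤立节点
    return d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
adj = normalize_adj(edge_index, 3)
```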


class GATLayer(nn.Module):
    """Graph Attention Network layer.

    Attention: alpha_{ij} = softmax(LeakyReLU(a^T [Wh_i || Wh_j]))
    Output: h_i' = sigma(sum_j alpha_{ij} * W * h_j)
    """

    def __init__(self, in_dim: int, out_dim: int,
                 n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        assert out_dim % n_heads == 0, "out_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = out_dim // n_heads

        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a = nn.Parameter(torch.FloatTensor(n_heads, 2 * self.head_dim))
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.a)

    def forward(self, x: torch.Tensor,
                edge_index: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Node features [N, in_dim]
            edge_index: Edge indices [2, E]
        Returns:
            Updated node features [N, out_dim]
        """
        N = x.size(0)
        h = self.W(x).view(N, self.n_heads, self.head_dim)  # [N, H, D]

        src, dst = edge_index
        h_src = h[src]  # [E, H, D]
        h_dst = h[dst]  # [E, H, D]

        # Compute attention scores
        edge_feat = torch.cat([h_src, h_dst], dim=-1)  # [E, H, 2D]
        attn = (edge_feat * self.a).sum(dim=-1)         # [E, H]
        attn = self.leaky_relu(attn)

        # Numerically stable softmax per destination node; initialize the
        # running max with -inf (zeros would bias the max when scores are
        # negative after LeakyReLU)
        attn_max = torch.full((N, self.n_heads), float('-inf'), device=x.device)
        attn_max.scatter_reduce_(0, dst.unsqueeze(-1).expand_as(attn),
                                  attn, reduce="amax")
        attn = torch.exp(attn - attn_max[dst])
        attn_sum = torch.zeros(N, self.n_heads, device=x.device)
        attn_sum.scatter_add_(0, dst.unsqueeze(-1).expand_as(attn), attn)
        attn = attn / (attn_sum[dst] + 1e-8)
        attn = self.dropout(attn)

        # Aggregate
        msg = h_src * attn.unsqueeze(-1)  # [E, H, D]
        out = torch.zeros(N, self.n_heads, self.head_dim, device=x.device)
        out.scatter_add_(0, dst.unsqueeze(-1).unsqueeze(-1).expand_as(msg), msg)

        return out.view(N, -1)  # [N, out_dim]
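GATLayer 里最容易出错的是按目标节点分组的 softmax(即式中的 alpha_{ij})。它可以单独抽出来验证:每个目标节点的入边权重之和应恰为 1。下面是一个独立的 scatter 版 segment-softmax 示意(自行实现的辅助函数,非 PyG API):

```python
import torch

def segment_softmax(scores: torch.Tensor, dst: torch.Tensor,
                    n_nodes: int) -> torch.Tensor:
    """对每个目标节点的入边分数做数值稳定的 softmax."""
    smax = torch.full((n_nodes,), float('-inf'))
    smax.scatter_reduce_(0, dst, scores, reduce="amax")  # 每组的最大值
    exp = torch.exp(scores - smax[dst])                  # 减最大值防溢出
    ssum = torch.zeros(n_nodes)
    ssum.scatter_add_(0, dst, exp)                       # 每组的分母
    return exp / ssum[dst]

dst = torch.tensor([0, 0, 1, 1, 1])   # 前两条边指向节点0, 后三条指向节点1
scores = torch.randn(5)
alpha = segment_softmax(scores, dst, 2)
# 每个目标节点的入边权重之和为 1
```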

R-GCN:关系感知图卷积

R-GCN核心思想

R-GCN (Relational Graph Convolutional Network)

核心改进:为每种关系类型使用不同的变换矩阵

标准GCN:
  h_v' = sigma(sum_{u in N(v)} W * h_u)
  所有边共享同一个W

R-GCN:
  h_v' = sigma(sum_{r in R} sum_{u in N_r(v)} (1/|N_r(v)|) * W_r * h_u + W_0 * h_v)
  每种关系r有自己的W_r

问题:关系太多时参数爆炸
解决方案:
  1. 基分解(Basis Decomposition):
     W_r = sum_b a_{rb} * V_b
     所有关系共享B个基矩阵V_b,每个关系用不同系数a_rb组合

  2. 块对角分解(Block Diagonal):
     W_r = diag(W_r^1, W_r^2, ..., W_r^B)
     每个关系的变换矩阵是块对角的
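基分解带来的参数节省可以直接算出来(以下关系数与维度仅为示例规模):

```python
# R-GCN 参数量对比: 每关系独立矩阵 vs 基分解
n_relations, in_dim, out_dim, n_bases = 200, 128, 128, 8

full = n_relations * in_dim * out_dim            # 每关系一个独立 W_r
basis = n_bases * in_dim * out_dim + n_relations * n_bases
# 基分解: B 个共享基矩阵 V_b + 每关系 B 个组合系数 a_rb

print(full, basis)  # 3276800 132672
```

在这个规模下,基分解把参数量压缩到原来的约 4%,且压缩比随关系数增多而更显著。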

R-GCN实现

class RGCNLayer(nn.Module):
    """Relational Graph Convolutional Network layer with basis decomposition."""

    def __init__(self, in_dim: int, out_dim: int,
                 n_relations: int, n_bases: int = 4):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_relations = n_relations
        self.n_bases = n_bases

        # Basis matrices shared across relations
        self.bases = nn.Parameter(
            torch.FloatTensor(n_bases, in_dim, out_dim)
        )
        # Coefficients per relation
        self.coefficients = nn.Parameter(
            torch.FloatTensor(n_relations, n_bases)
        )
        # Self-loop transformation
        self.self_loop = nn.Linear(in_dim, out_dim)

        nn.init.xavier_uniform_(self.bases)
        nn.init.xavier_uniform_(self.coefficients)

    def forward(self, x: torch.Tensor,
                edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Node features [N, in_dim]
            edge_index: Edge indices [2, E]
            edge_type: Relation type per edge [E]
        Returns:
            Updated node features [N, out_dim]
        """
        N = x.size(0)

        # Compute relation-specific weight matrices via basis decomposition
        # W_r = sum_b coeff[r,b] * bases[b]
        weights = torch.einsum('rb,bij->rij', self.coefficients, self.bases)
        # weights: [n_relations, in_dim, out_dim]

        # Message passing per relation, normalized by the per-relation
        # in-degree |N_r(v)| as in the R-GCN formula
        out = torch.zeros(N, self.out_dim, device=x.device)
        src, dst = edge_index

        for r in range(self.n_relations):
            mask = edge_type == r
            if mask.sum() == 0:
                continue

            src_r = src[mask]
            dst_r = dst[mask]

            # Transform source node features with relation-specific weight
            h_src = torch.mm(x[src_r], weights[r])  # [E_r, out_dim]

            # Per-relation in-degree |N_r(v)| (>= 1 for every index in dst_r)
            deg = torch.zeros(N, device=x.device)
            deg.scatter_add_(0, dst_r,
                             torch.ones_like(dst_r, dtype=torch.float))

            # Aggregate messages normalized by 1/|N_r(v)|
            out.scatter_add_(0, dst_r.unsqueeze(-1).expand_as(h_src),
                             h_src / deg[dst_r].unsqueeze(-1))

        # Add self-loop
        out = out + self.self_loop(x)

        return out


class RGCNModel(nn.Module):
    """Multi-layer R-GCN for knowledge graph tasks."""

    def __init__(self, n_entities: int, n_relations: int,
                 hidden_dim: int = 128, n_layers: int = 2,
                 n_bases: int = 4, dropout: float = 0.2):
        super().__init__()
        self.embedding = nn.Embedding(n_entities, hidden_dim)
        self.layers = nn.ModuleList()

        for i in range(n_layers):
            self.layers.append(
                RGCNLayer(hidden_dim, hidden_dim, n_relations, n_bases)
            )

        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Get all entity embeddings."""
        x = self.embedding.weight

        for layer in self.layers:
            x = layer(x, edge_index, edge_type)
            x = F.relu(x)
            x = self.dropout(x)

        return x

链接预测

基于GNN的链接预测

class LinkPredictor(nn.Module):
    """Link prediction using R-GCN encoder + DistMult decoder."""

    def __init__(self, n_entities: int, n_relations: int,
                 hidden_dim: int = 128, n_bases: int = 4):
        super().__init__()
        self.encoder = RGCNModel(n_entities, n_relations,
                                  hidden_dim, n_bases=n_bases)
        # DistMult decoder: score = h^T * diag(r) * t
        self.relation_emb = nn.Embedding(n_relations, hidden_dim)

    def forward(self, edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Encode all entities."""
        return self.encoder(edge_index, edge_type)

    def score(self, entity_emb: torch.Tensor,
              head: torch.Tensor, relation: torch.Tensor,
              tail: torch.Tensor) -> torch.Tensor:
        """Score triples using DistMult.

        Args:
            entity_emb: All entity embeddings [N, D]
            head: Head entity indices [B]
            relation: Relation indices [B]
            tail: Tail entity indices [B]
        Returns:
            Scores [B]
        """
        h = entity_emb[head]                 # [B, D]
        r = self.relation_emb(relation)      # [B, D]
        t = entity_emb[tail]                 # [B, D]
        return (h * r * t).sum(dim=-1)       # [B]

    def predict_tail(self, entity_emb: torch.Tensor,
                     head: int, relation: int,
                     top_k: int = 10) -> list[tuple[int, float]]:
        """Predict top-K most likely tail entities."""
        h = entity_emb[head].unsqueeze(0)         # [1, D]
        r = self.relation_emb(
            torch.tensor([relation], device=entity_emb.device))  # [1, D]

        # Score against all entities
        scores = (h * r * entity_emb).sum(dim=-1)  # [N]
        top_scores, top_indices = torch.topk(scores, top_k)

        return list(zip(
            top_indices.tolist(),
            top_scores.tolist(),
        ))


def train_link_prediction(model: LinkPredictor,
                           train_triples: torch.Tensor,
                           edge_index: torch.Tensor,
                           edge_type: torch.Tensor,
                           n_entities: int,
                           epochs: int = 100,
                           lr: float = 0.01):
    """Train link prediction model with negative sampling."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Encode entities
        entity_emb = model(edge_index, edge_type)

        # Positive samples
        head, relation, tail = train_triples.T
        pos_scores = model.score(entity_emb, head, relation, tail)

        # Negative sampling: corrupt tail entities
        neg_tail = torch.randint(0, n_entities, tail.shape, device=tail.device)
        neg_scores = model.score(entity_emb, head, relation, neg_tail)

        # Margin ranking loss
        target = torch.ones_like(pos_scores)
        loss = F.margin_ranking_loss(pos_scores, neg_scores, target, margin=1.0)

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
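解码器与损失可以脱离 R-GCN 编码器单独理解:DistMult 的打分是三向量逐元素乘积求和 s(h,r,t) = <h, r, t>,配合负采样与 margin ranking loss 拉开正负样本得分。下面是一个不依赖上文各类的独立示意(随机初始化嵌入,实体/关系规模为假设):

```python
import torch
import torch.nn.functional as F

n_entities, n_relations, dim = 50, 5, 16
ent = torch.randn(n_entities, dim, requires_grad=True)
rel = torch.randn(n_relations, dim, requires_grad=True)

def distmult(h, r, t):
    """DistMult 打分: score = sum(h * r * t)"""
    return (ent[h] * rel[r] * ent[t]).sum(dim=-1)

# 正样本三元组与打乱尾实体得到的负样本
triples = torch.tensor([[0, 1, 2], [3, 0, 4], [5, 2, 6]])
h, r, t = triples.T
neg_t = torch.randint(0, n_entities, t.shape)

pos = distmult(h, r, t)
neg = distmult(h, r, neg_t)
# target=1 表示希望 pos 得分高于 neg 至少 margin
loss = F.margin_ranking_loss(pos, neg, torch.ones_like(pos), margin=1.0)
loss.backward()  # 梯度回传到实体/关系嵌入
```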

节点分类

基于GNN的实体类型预测

class EntityClassifier(nn.Module):
    """Classify entity types using R-GCN features."""

    def __init__(self, n_entities: int, n_relations: int,
                 n_classes: int, hidden_dim: int = 128):
        super().__init__()
        self.encoder = RGCNModel(n_entities, n_relations, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, n_classes),
        )

    def forward(self, edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Predict entity types for all nodes."""
        entity_emb = self.encoder(edge_index, edge_type)  # [N, D]
        logits = self.classifier(entity_emb)               # [N, C]
        return logits

    def predict(self, edge_index, edge_type, node_ids=None):
        """Predict entity types."""
        self.eval()
        with torch.no_grad():
            logits = self.forward(edge_index, edge_type)
            if node_ids is not None:
                logits = logits[node_ids]
            probs = F.softmax(logits, dim=-1)
            pred = probs.argmax(dim=-1)
        return pred, probs
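节点分类训练通常是半监督的:只有少量实体带类型标签,交叉熵损失只在这些节点上计算,GNN 的消息传递让标签信息沿图结构传播。下面用一个占位的线性"编码器"演示掩码损失的写法(与上文 EntityClassifier 无关的独立示意,节点数与类别数为假设):

```python
import torch
import torch.nn.functional as F

n_nodes, n_classes, dim = 100, 4, 16
node_emb = torch.randn(n_nodes, dim)
logits = torch.nn.Linear(dim, n_classes)(node_emb)  # 占位"编码器"输出

# 假设只有前 20 个节点有标签
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[:20] = True
labels = torch.randint(0, n_classes, (n_nodes,))

# 半监督: 损失只在带标签的节点上计算
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
```

实际训练时只需把占位编码器换成上文的 R-GCN 输出,再对 loss 反向传播即可。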

PyTorch Geometric实战

使用PyG构建KG模型

# pip install torch-geometric

import torch
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data

class PyGKGModel(torch.nn.Module):
    """Knowledge graph model using PyG's built-in R-GCN."""

    def __init__(self, n_entities: int, n_relations: int,
                 hidden_dim: int = 128, out_dim: int = 64,
                 n_bases: int = 30):
        super().__init__()
        self.emb = torch.nn.Embedding(n_entities, hidden_dim)
        self.conv1 = RGCNConv(hidden_dim, hidden_dim,
                               n_relations, num_bases=n_bases)
        self.conv2 = RGCNConv(hidden_dim, out_dim,
                               n_relations, num_bases=n_bases)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, edge_index, edge_type):
        x = self.emb.weight
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_type)
        return x


def build_pyg_data(triples: list[tuple[int, int, int]],
                    n_entities: int) -> tuple[Data, int]:
    """Convert KG triples to a PyG Data object with inverse edges.

    Args:
        triples: List of (head, relation, tail) integer tuples
        n_entities: Total number of entities
    Returns:
        (data, total number of relation types including inverses)
    """
    heads, relations, tails = zip(*triples)

    # Make bidirectional (add inverse edges)
    edge_index = torch.tensor(
        [list(heads) + list(tails),
         list(tails) + list(heads)],
        dtype=torch.long,
    )

    # Inverse relation types offset by n_relations
    n_rels = max(relations) + 1
    edge_type = torch.tensor(
        list(relations) + [r + n_rels for r in relations],
        dtype=torch.long,
    )

    data = Data(
        edge_index=edge_index,
        edge_type=edge_type,
        num_nodes=n_entities,
    )
    return data, n_rels * 2  # Total relation types including inverses
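build_pyg_data 的关键步骤——加逆边并把逆关系 id 偏移 n_rels——只用 torch 即可验证(不依赖 PyG 的独立示意,玩具三元组为假设数据):

```python
import torch

triples = [(0, 0, 1), (1, 1, 2)]           # (head, relation, tail)
heads, relations, tails = zip(*triples)
n_rels = max(relations) + 1                # 2 种原始关系

# 正向边 + 逆向边
edge_index = torch.tensor([list(heads) + list(tails),
                           list(tails) + list(heads)])
# 逆关系的类型 id 偏移 n_rels, 避免与原始关系冲突
edge_type = torch.tensor(list(relations) +
                         [r + n_rels for r in relations])

print(edge_index.shape)    # torch.Size([2, 4])
print(edge_type.tolist())  # [0, 1, 2, 3]
```

加逆边让信息可以沿两个方向传播,这对 R-GCN 在稀疏 KG 上的收敛通常很重要。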

评估指标

链接预测评估

| 指标 | 定义 | 计算方式 | 越高/低越好 |
| --- | --- | --- | --- |
| MRR | 正确实体排名的倒数均值 | mean(1/rank) | 越高越好 |
| Hits@1 | 排名第1的比例 | count(rank==1)/total | 越高越好 |
| Hits@3 | 排名前3的比例 | count(rank<=3)/total | 越高越好 |
| Hits@10 | 排名前10的比例 | count(rank<=10)/total | 越高越好 |
| MR | 平均排名 | mean(rank) | 越低越好 |

def evaluate_link_prediction(model: LinkPredictor,
                              test_triples: torch.Tensor,
                              edge_index: torch.Tensor,
                              edge_type: torch.Tensor,
                              n_entities: int) -> dict:
    """Evaluate link prediction with standard KG metrics."""
    model.eval()
    ranks = []

    with torch.no_grad():
        entity_emb = model(edge_index, edge_type)

        for triple in test_triples:
            h, r, t = triple

            # Score all possible tails
            all_tails = torch.arange(n_entities)
            h_repeat = h.expand(n_entities)
            r_repeat = r.expand(n_entities)

            scores = model.score(entity_emb, h_repeat, r_repeat, all_tails)

            # Raw (unfiltered) rank of the correct tail;
            # ties are counted pessimistically via >=
            correct_score = scores[t]
            rank = (scores >= correct_score).sum().item()
            ranks.append(rank)

    ranks = torch.tensor(ranks, dtype=torch.float)

    return {
        "MRR": float((1.0 / ranks).mean()),
        "MR": float(ranks.mean()),
        "Hits@1": float((ranks <= 1).float().mean()),
        "Hits@3": float((ranks <= 3).float().mean()),
        "Hits@10": float((ranks <= 10).float().mean()),
    }
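这些指标本身很容易从一组 rank 手工验证(下面的 rank 列表为假设数据):

```python
import torch

ranks = torch.tensor([1.0, 2.0, 5.0, 12.0])

mrr = float((1.0 / ranks).mean())           # (1 + 1/2 + 1/5 + 1/12) / 4
hits1 = float((ranks <= 1).float().mean())  # 只有第 1 个样本排名第 1
hits10 = float((ranks <= 10).float().mean())

print(round(mrr, 4), hits1, hits10)  # 0.4458 0.25 0.75
```

注意 MRR 对头部排名敏感而对长尾鲁棒,MR 则相反,报告时两者通常一起给出。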

结论

图神经网络为知识图谱带来了结构感知的表示学习能力。R-GCN通过关系特定的消息传递,让模型能够区分不同类型的边在信息聚合中的不同作用。在链接预测任务上,GNN编码器+评分函数解码器的架构已成为主流;在节点分类任务上,GNN的半监督学习能力可以用少量标注数据预测大量未标注实体的类型。工程实践中,PyTorch Geometric提供了成熟的R-GCN实现,基分解策略有效控制了参数量,使得大规模知识图谱上的GNN训练成为可能。未来方向包括更高效的异构图注意力网络(HGT)、与LLM的联合训练以及归纳式GNN在新实体上的泛化能力。


Maurice | maurice_wen@proton.me