Graph Neural Networks for Knowledge Graphs
Original · 灵阙教研团队
Updated 2026-02-28
GCN, GAT, and R-GCN for knowledge graph completion, link prediction, and node classification, with a PyTorch Geometric implementation
Introduction
Graph neural networks (GNNs) are deep learning methods for graph-structured data. Traditional knowledge graph embedding methods (TransE, DistMult, ComplEx) map entities and relations to low-dimensional vectors but ignore the local structure of the graph. Through message passing, a GNN lets each node aggregate information from its neighbors to update its own representation, capturing entity semantics and structural patterns at the same time. This article works through the core applications of GNNs to knowledge graphs: link prediction, node classification, and knowledge graph completion, together with an engineering implementation based on PyTorch Geometric (PyG).
GNN Fundamentals
The Message Passing Paradigm
GNN message passing

Updating node v:

Step 1: Message
For each neighbor u, generate a message:
    m_{u->v} = MSG(h_u, h_v, e_{u,v})

Step 2: Aggregate
Combine all neighbor messages:
    M_v = AGG({m_{u->v} | u in N(v)})

Step 3: Update
Update the node representation:
    h_v' = UPD(h_v, M_v)

Common aggregators:
- SUM:  M_v = sum(m_{u->v})          # GCN
- MEAN: M_v = mean(m_{u->v})         # GraphSAGE
- MAX:  M_v = max(m_{u->v})          # GraphSAGE
- ATT:  M_v = sum(alpha * m_{u->v})  # GAT (attention-weighted)

Schematic:
h_u1 ──msg──┐
h_u2 ──msg──┤
h_u3 ──msg──┼──→ AGG ──→ UPD ──→ h_v'
h_u4 ──msg──┤             ↑
                         h_v (self)
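The three steps above can be sketched in a few lines of plain PyTorch. This is a minimal sketch using sum aggregation; `message_passing_sum` and the toy graph are illustrative, not from any library:

```python
import torch

def message_passing_sum(h: torch.Tensor, edge_index: torch.Tensor,
                        W: torch.Tensor) -> torch.Tensor:
    """One round of sum-aggregation message passing (no normalization).

    h: [N, D] node features; edge_index: [2, E] (src, dst); W: [D, D_out].
    """
    src, dst = edge_index
    msg = h[src] @ W                                  # Step 1: per-edge messages
    agg = torch.zeros(h.size(0), W.size(1))
    agg.scatter_add_(0, dst.unsqueeze(-1).expand_as(msg), msg)  # Step 2: sum
    return torch.relu(agg + h @ W)                    # Step 3: update with self term

h = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 0->1, 1->2, 2->3
W = torch.randn(8, 16)
out = message_passing_sum(h, edge_index, W)
print(out.shape)  # torch.Size([4, 16])
```

Node 0 receives no messages in this toy graph, so its output reduces to the self term alone.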
Comparison of Mainstream GNN Architectures
| Architecture | Aggregation | Relation-aware | Attention | Params | Best for |
|---|---|---|---|---|---|
| GCN | symmetric normalization | no | no | low | homogeneous/general graphs |
| GAT | attention weighting | no | yes | medium | neighbors of unequal importance |
| GraphSAGE | sampling + aggregation | no | optional | medium | large-scale inductive settings |
| R-GCN | relation-specific matrices | yes | no | high | knowledge graphs |
| CompGCN | composition operators | yes | optional | medium | knowledge graphs |
| HGT | heterogeneous attention | yes | yes | high | heterogeneous graphs |
Implementing GCN and GAT
A Basic GCN
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
"""Graph Convolutional Network layer.
h_v' = sigma(sum_{u in N(v)} (1/sqrt(d_u * d_v)) * W * h_u)
"""
def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
super().__init__()
self.weight = nn.Parameter(torch.FloatTensor(in_dim, out_dim))
self.bias = nn.Parameter(torch.FloatTensor(out_dim)) if bias else None
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Node features [N, in_dim]
adj: Normalized adjacency matrix [N, N]
Returns:
Updated node features [N, out_dim]
"""
        support = torch.mm(x, self.weight)          # [N, out_dim]
        if adj.is_sparse:
            output = torch.sparse.mm(adj, support)  # [N, out_dim]
        else:
            output = torch.mm(adj, support)         # [N, out_dim]
if self.bias is not None:
output += self.bias
return output
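The `adj` argument of `GCNLayer` is expected to be the symmetrically normalized adjacency Â = D^{-1/2}(A + I)D^{-1/2}. A minimal dense construction follows; `normalized_adjacency` is our helper name, and for real graphs a sparse tensor would be used instead:

```python
import torch

def normalized_adjacency(edge_index: torch.Tensor, n_nodes: int) -> torch.Tensor:
    """Build A_hat = D^{-1/2} (A + I) D^{-1/2} as a dense tensor."""
    A = torch.zeros(n_nodes, n_nodes)
    A[edge_index[0], edge_index[1]] = 1.0
    A = A + torch.eye(n_nodes)        # add self-loops; degree is now >= 1
    deg = A.sum(dim=1)
    d_inv_sqrt = deg.pow(-0.5)
    return d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)

# Undirected toy graph 0-1, 1-2 (both directions listed explicitly)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
adj = normalized_adjacency(edge_index, 3)
print(adj)
```

Passing this `adj` together with node features to `GCNLayer.forward` reproduces the propagation rule stated in the layer's docstring.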
class GATLayer(nn.Module):
"""Graph Attention Network layer.
Attention: alpha_{ij} = softmax(LeakyReLU(a^T [Wh_i || Wh_j]))
Output: h_i' = sigma(sum_j alpha_{ij} * W * h_j)
"""
def __init__(self, in_dim: int, out_dim: int,
n_heads: int = 4, dropout: float = 0.1):
super().__init__()
self.n_heads = n_heads
self.head_dim = out_dim // n_heads
self.W = nn.Linear(in_dim, out_dim, bias=False)
self.a = nn.Parameter(torch.FloatTensor(n_heads, 2 * self.head_dim))
self.leaky_relu = nn.LeakyReLU(0.2)
self.dropout = nn.Dropout(dropout)
nn.init.xavier_uniform_(self.a)
def forward(self, x: torch.Tensor,
edge_index: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Node features [N, in_dim]
edge_index: Edge indices [2, E]
Returns:
Updated node features [N, out_dim]
"""
N = x.size(0)
h = self.W(x).view(N, self.n_heads, self.head_dim) # [N, H, D]
src, dst = edge_index
h_src = h[src] # [E, H, D]
h_dst = h[dst] # [E, H, D]
# Compute attention scores
edge_feat = torch.cat([h_src, h_dst], dim=-1) # [E, H, 2D]
attn = (edge_feat * self.a).sum(dim=-1) # [E, H]
attn = self.leaky_relu(attn)
# Softmax per destination node
attn_max = torch.zeros(N, self.n_heads, device=x.device)
attn_max.scatter_reduce_(0, dst.unsqueeze(-1).expand_as(attn),
attn, reduce="amax")
attn = torch.exp(attn - attn_max[dst])
attn_sum = torch.zeros(N, self.n_heads, device=x.device)
attn_sum.scatter_add_(0, dst.unsqueeze(-1).expand_as(attn), attn)
attn = attn / (attn_sum[dst] + 1e-8)
attn = self.dropout(attn)
# Aggregate
msg = h_src * attn.unsqueeze(-1) # [E, H, D]
out = torch.zeros(N, self.n_heads, self.head_dim, device=x.device)
out.scatter_add_(0, dst.unsqueeze(-1).unsqueeze(-1).expand_as(msg), msg)
return out.view(N, -1) # [N, out_dim]
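The per-destination softmax is the subtle part of the layer above. Here it is isolated as a standalone helper; `segment_softmax` is our name for the pattern, not a torch API:

```python
import torch

def segment_softmax(scores: torch.Tensor, index: torch.Tensor,
                    n_segments: int) -> torch.Tensor:
    """Numerically stable softmax over edges grouped by destination node,
    using the same scatter trick as GATLayer."""
    smax = torch.full((n_segments,), float("-inf"))
    smax.scatter_reduce_(0, index, scores, reduce="amax")  # per-group max
    exp = torch.exp(scores - smax[index])                  # shift for stability
    ssum = torch.zeros(n_segments).scatter_add_(0, index, exp)
    return exp / (ssum[index] + 1e-16)

scores = torch.tensor([1.0, 2.0, 3.0, 0.5])  # one raw score per edge
dst = torch.tensor([0, 0, 1, 1])             # destination node per edge
alpha = segment_softmax(scores, dst, 2)
print(alpha)  # attention weights sum to 1 within each destination group
```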
R-GCN: Relation-Aware Graph Convolution
The Core Idea of R-GCN
R-GCN (Relational Graph Convolutional Network)

Key improvement: a separate transformation matrix for each relation type.

Standard GCN:
    h_v' = sigma(sum_{u in N(v)} W * h_u)
    All edges share a single W.

R-GCN:
    h_v' = sigma(sum_{r in R} sum_{u in N_r(v)} (1/|N_r(v)|) * W_r * h_u + W_0 * h_v)
    Each relation r has its own W_r.

Problem: with many relations, the parameter count explodes.

Remedies:
1. Basis decomposition:
       W_r = sum_b a_{rb} * V_b
   All relations share B basis matrices V_b; each relation combines them
   with its own coefficients a_{rb}.
2. Block-diagonal decomposition:
       W_r = diag(W_r^1, W_r^2, ..., W_r^B)
   Each relation's transformation matrix is block-diagonal.
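To make the savings concrete, here is a sketch of both schemes under assumed sizes (200 relations, 128-dimensional features, B = 4; all numbers are illustrative):

```python
import torch

# Basis decomposition: shared bases V_b plus per-relation coefficients a_{rb}.
n_rel, d_in, d_out, n_bases = 200, 128, 128, 4
full_params = n_rel * d_in * d_out                       # one dense W_r each
basis_params = n_bases * d_in * d_out + n_rel * n_bases  # V_b + a_{rb}
print(full_params, basis_params)  # 3276800 66336

V = torch.randn(n_bases, d_in, d_out)
a = torch.randn(n_rel, n_bases)
W = torch.einsum("rb,bij->rij", a, V)  # reconstruct all W_r: [200, 128, 128]

# Block-diagonal decomposition: B blocks of size (d/B, d/B) per relation,
# i.e. d^2/B parameters instead of d^2 for each W_r.
B, d = 4, 16
blocks = [torch.randn(d // B, d // B) for _ in range(B)]
W_r = torch.block_diag(*blocks)  # [16, 16], zero outside the diagonal blocks
print(W_r.shape, sum(b.numel() for b in blocks))  # torch.Size([16, 16]) 64
```

With these sizes, basis decomposition cuts the relation-specific parameters by roughly 50x.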
Implementing R-GCN
class RGCNLayer(nn.Module):
"""Relational Graph Convolutional Network layer with basis decomposition."""
def __init__(self, in_dim: int, out_dim: int,
n_relations: int, n_bases: int = 4):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.n_relations = n_relations
self.n_bases = n_bases
# Basis matrices shared across relations
self.bases = nn.Parameter(
torch.FloatTensor(n_bases, in_dim, out_dim)
)
# Coefficients per relation
self.coefficients = nn.Parameter(
torch.FloatTensor(n_relations, n_bases)
)
# Self-loop transformation
self.self_loop = nn.Linear(in_dim, out_dim)
nn.init.xavier_uniform_(self.bases)
nn.init.xavier_uniform_(self.coefficients)
def forward(self, x: torch.Tensor,
edge_index: torch.Tensor,
edge_type: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Node features [N, in_dim]
edge_index: Edge indices [2, E]
edge_type: Relation type per edge [E]
Returns:
Updated node features [N, out_dim]
"""
N = x.size(0)
# Compute relation-specific weight matrices via basis decomposition
# W_r = sum_b coeff[r,b] * bases[b]
weights = torch.einsum('rb,bij->rij', self.coefficients, self.bases)
# weights: [n_relations, in_dim, out_dim]
# Message passing per relation
out = torch.zeros(N, self.out_dim, device=x.device)
src, dst = edge_index
for r in range(self.n_relations):
mask = edge_type == r
if mask.sum() == 0:
continue
src_r = src[mask]
dst_r = dst[mask]
            # Transform source node features with the relation-specific weight
            h_src = torch.mm(x[src_r], weights[r])  # [E_r, out_dim]
            # Per-relation in-degree gives the 1/|N_r(v)| normalization
            deg = torch.zeros(N, device=x.device)
            deg.scatter_add_(0, dst_r, torch.ones_like(dst_r, dtype=torch.float))
            h_src = h_src / (deg[dst_r].unsqueeze(-1) + 1e-8)
            # Aggregate normalized messages into destination nodes
            out.scatter_add_(0, dst_r.unsqueeze(-1).expand_as(h_src), h_src)
        # Self-loop transformation W_0 * h_v
        out = out + self.self_loop(x)
        return out
class RGCNModel(nn.Module):
"""Multi-layer R-GCN for knowledge graph tasks."""
def __init__(self, n_entities: int, n_relations: int,
hidden_dim: int = 128, n_layers: int = 2,
n_bases: int = 4, dropout: float = 0.2):
super().__init__()
self.embedding = nn.Embedding(n_entities, hidden_dim)
self.layers = nn.ModuleList()
for i in range(n_layers):
self.layers.append(
RGCNLayer(hidden_dim, hidden_dim, n_relations, n_bases)
)
self.dropout = nn.Dropout(dropout)
nn.init.xavier_uniform_(self.embedding.weight)
def forward(self, edge_index: torch.Tensor,
edge_type: torch.Tensor) -> torch.Tensor:
"""Get all entity embeddings."""
x = self.embedding.weight
for layer in self.layers:
x = layer(x, edge_index, edge_type)
x = F.relu(x)
x = self.dropout(x)
return x
Link Prediction
GNN-Based Link Prediction
class LinkPredictor(nn.Module):
"""Link prediction using R-GCN encoder + DistMult decoder."""
def __init__(self, n_entities: int, n_relations: int,
hidden_dim: int = 128, n_bases: int = 4):
super().__init__()
self.encoder = RGCNModel(n_entities, n_relations,
hidden_dim, n_bases=n_bases)
# DistMult decoder: score = h^T * diag(r) * t
self.relation_emb = nn.Embedding(n_relations, hidden_dim)
def forward(self, edge_index: torch.Tensor,
edge_type: torch.Tensor) -> torch.Tensor:
"""Encode all entities."""
return self.encoder(edge_index, edge_type)
def score(self, entity_emb: torch.Tensor,
head: torch.Tensor, relation: torch.Tensor,
tail: torch.Tensor) -> torch.Tensor:
"""Score triples using DistMult.
Args:
entity_emb: All entity embeddings [N, D]
head: Head entity indices [B]
relation: Relation indices [B]
tail: Tail entity indices [B]
Returns:
Scores [B]
"""
h = entity_emb[head] # [B, D]
r = self.relation_emb(relation) # [B, D]
t = entity_emb[tail] # [B, D]
return (h * r * t).sum(dim=-1) # [B]
def predict_tail(self, entity_emb: torch.Tensor,
head: int, relation: int,
top_k: int = 10) -> list[tuple[int, float]]:
"""Predict top-K most likely tail entities."""
h = entity_emb[head].unsqueeze(0) # [1, D]
r = self.relation_emb(torch.tensor([relation])) # [1, D]
# Score against all entities
scores = (h * r * entity_emb).sum(dim=-1) # [N]
top_scores, top_indices = torch.topk(scores, top_k)
return list(zip(
top_indices.tolist(),
top_scores.tolist(),
))
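The decoder arithmetic, worked through on made-up 3-dimensional embeddings; the second print illustrates a known limitation that follows directly from the formula:

```python
import torch

# DistMult scores a triple as a trilinear dot product:
#   score(h, r, t) = sum_i h_i * r_i * t_i
h = torch.tensor([1.0, 2.0, 0.5])
r = torch.tensor([0.5, 1.0, 2.0])
t = torch.tensor([2.0, 0.5, 1.0])
score_hrt = (h * r * t).sum().item()
score_trh = (t * r * h).sum().item()
print(score_hrt, score_trh)  # 3.0 3.0

# The score is symmetric in h and t, so DistMult cannot model asymmetric
# relations; ComplEx was proposed precisely to address this.
```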
def train_link_prediction(model: LinkPredictor,
train_triples: torch.Tensor,
edge_index: torch.Tensor,
edge_type: torch.Tensor,
n_entities: int,
epochs: int = 100,
lr: float = 0.01):
"""Train link prediction model with negative sampling."""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
# Encode entities
entity_emb = model(edge_index, edge_type)
# Positive samples
head, relation, tail = train_triples.T
pos_scores = model.score(entity_emb, head, relation, tail)
# Negative sampling: corrupt tail
neg_tail = torch.randint(0, n_entities, tail.shape)
neg_scores = model.score(entity_emb, head, relation, neg_tail)
# Margin ranking loss
target = torch.ones_like(pos_scores)
loss = F.margin_ranking_loss(pos_scores, neg_scores, target, margin=1.0)
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
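The loss at the heart of `train_link_prediction`, worked through by hand on made-up scores:

```python
import torch
import torch.nn.functional as F

# Margin ranking loss with target = 1 pushes positive scores above negative
# ones by at least the margin: loss = mean(max(0, margin - (pos - neg))).
pos = torch.tensor([3.0, 0.5])
neg = torch.tensor([1.0, 0.8])
target = torch.ones_like(pos)
loss = F.margin_ranking_loss(pos, neg, target, margin=1.0)
# Per pair: max(0, 1 - 2) = 0 and max(0, 1 + 0.3) = 1.3, so mean = 0.65
print(loss.item())
```

The first pair is already separated by more than the margin and contributes nothing; all gradient comes from the second pair.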
Node Classification
GNN-Based Entity Type Prediction
class EntityClassifier(nn.Module):
"""Classify entity types using R-GCN features."""
def __init__(self, n_entities: int, n_relations: int,
n_classes: int, hidden_dim: int = 128):
super().__init__()
self.encoder = RGCNModel(n_entities, n_relations, hidden_dim)
self.classifier = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim // 2, n_classes),
)
def forward(self, edge_index: torch.Tensor,
edge_type: torch.Tensor) -> torch.Tensor:
"""Predict entity types for all nodes."""
entity_emb = self.encoder(edge_index, edge_type) # [N, D]
logits = self.classifier(entity_emb) # [N, C]
return logits
def predict(self, edge_index, edge_type, node_ids=None):
"""Predict entity types."""
self.eval()
with torch.no_grad():
logits = self.forward(edge_index, edge_type)
if node_ids is not None:
logits = logits[node_ids]
probs = F.softmax(logits, dim=-1)
pred = probs.argmax(dim=-1)
return pred, probs
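How a classifier like the one above is trained in the typical semi-supervised setting: the loss is computed only on the labeled subset. The toy tensors below stand in for `EntityClassifier` output; the mask and labels are made up:

```python
import torch
import torch.nn.functional as F

# Cross-entropy over the labeled nodes only (train_mask). With a GNN encoder
# the unlabeled nodes would still influence the loss through message passing;
# here the logits are free parameters, so their gradient is exactly zero
# outside the mask.
N, C = 6, 3
logits = torch.randn(N, C, requires_grad=True)
labels = torch.tensor([0, 2, 1, 0, 0, 1])
train_mask = torch.tensor([True, True, False, False, True, False])

loss = F.cross_entropy(logits[train_mask], labels[train_mask])
loss.backward()
print(logits.grad[~train_mask].abs().sum().item())  # 0.0
```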
PyTorch Geometric in Practice
Building a KG Model with PyG
# pip install torch-geometric
import torch
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data
class PyGKGModel(torch.nn.Module):
"""Knowledge graph model using PyG's built-in R-GCN."""
def __init__(self, n_entities: int, n_relations: int,
hidden_dim: int = 128, out_dim: int = 64,
n_bases: int = 30):
super().__init__()
self.emb = torch.nn.Embedding(n_entities, hidden_dim)
self.conv1 = RGCNConv(hidden_dim, hidden_dim,
n_relations, num_bases=n_bases)
self.conv2 = RGCNConv(hidden_dim, out_dim,
n_relations, num_bases=n_bases)
self.dropout = torch.nn.Dropout(0.2)
def forward(self, edge_index, edge_type):
x = self.emb.weight
x = self.conv1(x, edge_index, edge_type)
x = F.relu(x)
x = self.dropout(x)
x = self.conv2(x, edge_index, edge_type)
return x
def build_pyg_data(triples: list[tuple[int, int, int]],
                   n_entities: int) -> tuple[Data, int]:
    """Convert KG triples to a PyG Data object.
    Args:
        triples: List of (head, relation, tail) integer tuples
        n_entities: Total number of entities
    Returns:
        The Data object and the total number of relation types
        (original plus inverse).
    """
heads, relations, tails = zip(*triples)
# Make bidirectional (add inverse edges)
edge_index = torch.tensor(
[list(heads) + list(tails),
list(tails) + list(heads)],
dtype=torch.long,
)
# Inverse relation types offset by n_relations
n_rels = max(relations) + 1
edge_type = torch.tensor(
list(relations) + [r + n_rels for r in relations],
dtype=torch.long,
)
data = Data(
edge_index=edge_index,
edge_type=edge_type,
num_nodes=n_entities,
)
return data, n_rels * 2 # Total relation types including inverses
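What `build_pyg_data` constructs, traced in plain torch on a two-triple toy graph:

```python
import torch

# Adding inverse edges: every triple (h, r, t) also contributes an edge
# (t, r + n_rels, h), so messages can flow against the original direction.
triples = [(0, 0, 1), (1, 1, 2)]
heads, rels, tails = zip(*triples)
n_rels = max(rels) + 1  # 2 original relation types

edge_index = torch.tensor([list(heads) + list(tails),
                           list(tails) + list(heads)])
edge_type = torch.tensor(list(rels) + [r + n_rels for r in rels])
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 2, 0, 1]]
print(edge_type.tolist())   # [0, 1, 2, 3]
```

The inverse relations occupy ids n_rels..2*n_rels-1, which is why the function reports `n_rels * 2` total relation types.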
Evaluation Metrics
Link Prediction Metrics
| Metric | Definition | Computation | Direction |
|---|---|---|---|
| MRR | mean reciprocal rank of the correct entity | mean(1/rank) | higher is better |
| Hits@1 | fraction ranked 1st | count(rank==1)/total | higher is better |
| Hits@3 | fraction ranked in the top 3 | count(rank<=3)/total | higher is better |
| Hits@10 | fraction ranked in the top 10 | count(rank<=10)/total | higher is better |
| MR | mean rank | mean(rank) | lower is better |
def evaluate_link_prediction(model: LinkPredictor,
test_triples: torch.Tensor,
edge_index: torch.Tensor,
edge_type: torch.Tensor,
n_entities: int) -> dict:
"""Evaluate link prediction with standard KG metrics."""
model.eval()
ranks = []
with torch.no_grad():
entity_emb = model(edge_index, edge_type)
for triple in test_triples:
h, r, t = triple
# Score all possible tails
all_tails = torch.arange(n_entities)
h_repeat = h.expand(n_entities)
r_repeat = r.expand(n_entities)
scores = model.score(entity_emb, h_repeat, r_repeat, all_tails)
            # Rank of the correct tail (raw ranking; the standard filtered
            # protocol would also exclude other known true triples)
            correct_score = scores[t]
            rank = (scores >= correct_score).sum().item()
ranks.append(rank)
ranks = torch.tensor(ranks, dtype=torch.float)
return {
"MRR": float((1.0 / ranks).mean()),
"MR": float(ranks.mean()),
"Hits@1": float((ranks <= 1).float().mean()),
"Hits@3": float((ranks <= 3).float().mean()),
"Hits@10": float((ranks <= 10).float().mean()),
}
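A quick sanity check of the metric arithmetic above, on made-up ranks:

```python
import torch

# Four toy ranks: rank 1 means the correct entity was scored highest.
ranks = torch.tensor([1.0, 2.0, 5.0, 10.0])
mrr = (1.0 / ranks).mean().item()             # (1 + 0.5 + 0.2 + 0.1) / 4 = 0.45
mr = ranks.mean().item()                      # 4.5
hits1 = (ranks <= 1).float().mean().item()    # 0.25
hits10 = (ranks <= 10).float().mean().item()  # 1.0
print(mrr, mr, hits1, hits10)
```

Note how a single very bad rank drags MR far more than MRR, which is why MRR is the preferred headline number.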
Conclusion
Graph neural networks bring structure-aware representation learning to knowledge graphs. Through relation-specific message passing, R-GCN lets a model distinguish the roles that different edge types play during aggregation. For link prediction, the GNN-encoder-plus-scoring-decoder architecture has become the standard; for node classification, the semi-supervised nature of GNNs lets a small amount of labeled data predict types for many unlabeled entities. In engineering practice, PyTorch Geometric ships a mature R-GCN implementation, and basis decomposition keeps the parameter count under control, making GNN training on large knowledge graphs feasible. Future directions include more efficient heterogeneous graph attention networks (HGT), joint training with LLMs, and inductive GNNs that generalize to unseen entities.