Knowledge Graph-Driven Recommendation Systems

KG-enhanced collaborative filtering, path-based reasoning, explainable recommendation, and engineering implementation

Introduction

Traditional recommender systems face two core challenges: the cold-start problem (new users and new items lack interaction data) and weak explainability (users cannot tell why something was recommended). A knowledge graph, as a structured repository of domain knowledge, naturally carries rich entity attributes and relations, and can mitigate both problems at their root. By linking the user-item interaction graph to a knowledge graph, a recommender can reason over structured information such as an item's attributes, categories, and creators, which not only improves recommendation quality but also yields human-readable explanations. This article systematically covers the four paradigms of KG-driven recommendation and their engineering implementation.

Recommendation Paradigm Overview

Four KG Recommendation Paradigms

KG-driven recommender system architecture

Paradigm 1: KG-Enhanced Embedding
  User/Item ──→ [KG embedding alignment] ──→ joint embedding space ──→ similarity matching
  Representatives: CKE, KGAT, KGIN
  Traits: end-to-end training; uses KG structure to learn better vector representations

Paradigm 2: Path-based Reasoning
  User ──→ [KG path search] ──→ path features ──→ recommendation + explanation
  Representatives: PER, KPRN, PGPR
  Traits: strong explainability; transparent reasoning process

Paradigm 3: Propagation-based
  User ──→ [KG information propagation] ──→ multi-hop aggregation ──→ user preference modeling
  Representatives: RippleNet, KGCN, KGNN-LS
  Traits: automatically discovers long-range associations

Paradigm 4: Hybrid
  User ──→ [collaborative filtering + KG embedding + path reasoning] ──→ fused ranking
  Representatives: KGAT, CKAN
  Traits: combines multiple signals

Paradigm Comparison

Paradigm              Rec. quality  Explainability  Cold start  Compute cost  Complexity
Embedding-enhanced    High          Low             Medium      Low           Low
Path-based reasoning  Medium-high   High            Medium      High          Medium
Propagation-based     High          Medium          Medium      Medium-high   Medium
Hybrid                Highest       Medium-high     Best        High          High

Data Model

Recommendation KG Schema

Recommender system knowledge graph structure

User interaction layer:
  (User)───[:PURCHASED]──→(Item)
  (User)───[:VIEWED]────→(Item)
  (User)───[:RATED {score: 5}]──→(Item)
  (User)───[:BOOKMARKED]──→(Item)

Item knowledge layer:
  (Item)───[:BELONGS_TO]──→(Category)
  (Item)───[:HAS_BRAND]──→(Brand)
  (Item)───[:CREATED_BY]──→(Creator)
  (Item)───[:HAS_TAG]────→(Tag)
  (Item)───[:SIMILAR_TO]──→(Item)

Attribute layer:
  (Item)───[:HAS_FEATURE]──→(Feature)
  (Creator)───[:WORKS_IN]──→(Genre)
  (Category)───[:PARENT]──→(Category)
  (Brand)───[:FROM_COUNTRY]──→(Country)
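In practice the schema above is stored as integer triples. A minimal sketch of building the entity/relation id maps from named triples (the names and ids here are purely illustrative):

```python
# Hypothetical named triples following the schema above
named_triples = [
    ("item:101", "BELONGS_TO", "category:electronics"),
    ("item:101", "HAS_BRAND", "brand:acme"),
    ("item:102", "BELONGS_TO", "category:electronics"),
    ("category:electronics", "PARENT", "category:all"),
]

entity2id: dict[str, int] = {}
relation2id: dict[str, int] = {}

def _eid(name: str) -> int:
    # Assign entity ids in first-seen order
    return entity2id.setdefault(name, len(entity2id))

def _rid(name: str) -> int:
    # Assign relation ids in first-seen order
    return relation2id.setdefault(name, len(relation2id))

kg_triples = [(_eid(h), _rid(r), _eid(t)) for h, r, t in named_triples]

print(kg_triples)      # Integer (head, relation, tail) triples
print(len(entity2id))  # 5 distinct entities
```

The resulting integer triples plug directly into the `kg_triples` field of the data container defined next.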

Data Preparation

from dataclasses import dataclass, field
import numpy as np

@dataclass
class RecKGData:
    """Data container for KG-based recommendation."""
    # User-item interactions
    user_item: list[tuple[int, int, float]]  # (user_id, item_id, rating)

    # Knowledge graph triples
    kg_triples: list[tuple[int, int, int]]   # (head, relation, tail)

    # Entity/relation mappings
    entity2id: dict[str, int] = field(default_factory=dict)
    relation2id: dict[str, int] = field(default_factory=dict)
    item2entity: dict[int, int] = field(default_factory=dict)  # item_id -> entity_id

    @property
    def n_users(self) -> int:
        return len(set(u for u, _, _ in self.user_item))

    @property
    def n_items(self) -> int:
        return len(set(i for _, i, _ in self.user_item))

    @property
    def n_entities(self) -> int:
        return len(self.entity2id)

    @property
    def n_relations(self) -> int:
        return len(self.relation2id)

    def get_user_history(self, user_id: int) -> list[int]:
        """Get items a user has interacted with."""
        return [item for u, item, _ in self.user_item if u == user_id]

    def get_item_neighbors(self, entity_id: int,
                            max_hops: int = 2) -> dict[int, list]:
        """Get KG neighbors of an item entity up to max_hops."""
        neighbors = {0: [entity_id]}
        visited = {entity_id}

        for hop in range(1, max_hops + 1):
            hop_neighbors = []
            for eid in neighbors[hop - 1]:
                for h, r, t in self.kg_triples:
                    if h == eid and t not in visited:
                        hop_neighbors.append(t)
                        visited.add(t)
                    elif t == eid and h not in visited:
                        hop_neighbors.append(h)
                        visited.add(h)
            neighbors[hop] = hop_neighbors

        return neighbors
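Note that `get_item_neighbors` rescans every KG triple for each frontier entity on each hop, i.e. O(hops × frontier × |triples|). Precomputing an adjacency list makes each expansion proportional to node degree instead. A self-contained sketch of the same multi-hop expansion over an adjacency dict (toy ids are illustrative):

```python
# Toy triples: a small chain plus one branch
kg_triples = [(0, 0, 1), (1, 0, 2), (2, 0, 3), (0, 1, 4)]

# Undirected adjacency: entity -> set of neighboring entities
adj: dict[int, set[int]] = {}
for h, _, t in kg_triples:
    adj.setdefault(h, set()).add(t)
    adj.setdefault(t, set()).add(h)

def neighbors_by_hop(entity_id: int, max_hops: int = 2) -> dict[int, list[int]]:
    """Same contract as get_item_neighbors, but O(degree) per expansion."""
    hops = {0: [entity_id]}
    visited = {entity_id}
    for hop in range(1, max_hops + 1):
        frontier = []
        for eid in hops[hop - 1]:
            for nb in adj.get(eid, ()):  # O(degree) instead of O(|triples|)
                if nb not in visited:
                    visited.add(nb)
                    frontier.append(nb)
        hops[hop] = frontier
    return hops

print(neighbors_by_hop(0, max_hops=2))
```

For graphs of millions of triples this precompute is usually the difference between interactive and unusable latency.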

Paradigm 1: KG Embedding-Enhanced Recommendation

Joint Embedding Learning

import torch
import torch.nn as nn
import torch.nn.functional as F

class KGEnhancedCF(nn.Module):
    """Knowledge Graph Enhanced Collaborative Filtering.

    Joint learning of user-item CF and KG entity embeddings.
    """

    def __init__(self, n_users: int, n_items: int,
                 n_entities: int, n_relations: int,
                 embed_dim: int = 64):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, embed_dim)
        self.entity_emb = nn.Embedding(n_entities, embed_dim)
        self.relation_emb = nn.Embedding(n_relations, embed_dim)

        # Item embeddings are shared with entity embeddings
        # item_id -> entity_id mapping is handled externally

        self.fc_combine = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )

        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.entity_emb.weight)
        nn.init.xavier_uniform_(self.relation_emb.weight)

    def get_item_embedding(self, item_entity_ids: torch.Tensor,
                           kg_context: torch.Tensor | None = None) -> torch.Tensor:
        """Get item embedding enriched by KG context.

        Args:
            item_entity_ids: Entity IDs corresponding to items [B]
            kg_context: Average embedding of item's KG neighbors [B, D]
        """
        entity_vec = self.entity_emb(item_entity_ids)  # [B, D]

        if kg_context is not None:
            combined = torch.cat([entity_vec, kg_context], dim=-1)
            return self.fc_combine(combined)
        return entity_vec

    def cf_loss(self, user_ids: torch.Tensor,
                pos_items: torch.Tensor,
                neg_items: torch.Tensor) -> torch.Tensor:
        """BPR loss for collaborative filtering."""
        user_vec = self.user_emb(user_ids)
        pos_vec = self.entity_emb(pos_items)
        neg_vec = self.entity_emb(neg_items)

        pos_score = (user_vec * pos_vec).sum(dim=-1)
        neg_score = (user_vec * neg_vec).sum(dim=-1)

        loss = -torch.log(torch.sigmoid(pos_score - neg_score) + 1e-8).mean()
        return loss

    def kg_loss(self, head: torch.Tensor, relation: torch.Tensor,
                tail: torch.Tensor, neg_tail: torch.Tensor) -> torch.Tensor:
        """TransE loss for KG embedding."""
        h = self.entity_emb(head)
        r = self.relation_emb(relation)
        t = self.entity_emb(tail)
        nt = self.entity_emb(neg_tail)

        pos_dist = torch.norm(h + r - t, p=2, dim=-1)
        neg_dist = torch.norm(h + r - nt, p=2, dim=-1)

        loss = F.relu(pos_dist - neg_dist + 1.0).mean()
        return loss

    def recommend(self, user_id: int, item_entity_ids: torch.Tensor,
                   top_k: int = 10) -> list[tuple[int, float]]:
        """Generate top-K recommendations for a user."""
        self.eval()
        with torch.no_grad():
            user_vec = self.user_emb(torch.tensor([user_id]))  # [1, D]
            item_vecs = self.entity_emb(item_entity_ids)       # [N, D]
            scores = (user_vec * item_vecs).sum(dim=-1)        # [N]

            top_scores, top_idx = torch.topk(scores, top_k)
            return list(zip(
                item_entity_ids[top_idx].tolist(),
                top_scores.tolist(),
            ))
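Training typically optimizes the two losses jointly, e.g. `loss = cf_loss + λ · kg_loss` over alternating mini-batches (λ is a tunable weight). To make the math concrete without a training harness, here is a NumPy sketch of both objectives on toy vectors; all values and the 0.5 weight are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8
user = rng.normal(size=D)
pos_item, neg_item = rng.normal(size=D), rng.normal(size=D)

def bpr_loss(u, pos, neg):
    # BPR: push the positive item's score above the negative item's
    return -np.log(1.0 / (1.0 + np.exp(-(u @ pos - u @ neg))) + 1e-8)

def transe_loss(h, r, t, neg_t, margin=1.0):
    # TransE margin loss: h + r should land near t, away from a corrupted tail
    pos_d = np.linalg.norm(h + r - t)
    neg_d = np.linalg.norm(h + r - neg_t)
    return max(0.0, pos_d - neg_d + margin)

head, rel = rng.normal(size=D), rng.normal(size=D)
total = bpr_loss(user, pos_item, neg_item) + 0.5 * transe_loss(head, rel, pos_item, neg_item)
print(total)  # Non-negative by construction
```

In the PyTorch module above, the same sum would be backpropagated through both embedding tables, which is what couples the CF signal to the KG structure.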

Paradigm 2: Path-Based Reasoning Recommendation

KG Path Features

class PathBasedRecommender:
    """Path-based recommendation with KG reasoning."""

    def __init__(self, kg_triples: list[tuple], entity2name: dict):
        self.adj: dict[int, list[tuple[int, int]]] = {}  # entity -> [(relation, target)]
        self.entity2name = entity2name

        for h, r, t in kg_triples:
            self.adj.setdefault(h, []).append((r, t))
            # Inverse edge; the +1000 offset assumes relation ids stay below 1000
            self.adj.setdefault(t, []).append((r + 1000, h))

    def find_paths(self, source: int, target: int,
                   max_length: int = 4,
                   max_paths: int = 50) -> list[tuple]:
        """Find paths between source and target entities via BFS.

        Returns paths as tuples (entity, relation, entity, ..., entity).
        """
        paths: list[tuple] = []
        queue = [(source,)]

        while queue and len(paths) < max_paths:
            path = queue.pop(0)

            if len(path) >= max_length * 2 + 1:
                continue  # Already at max_length hops; do not expand further

            # Entities sit at even positions, relations at odd positions;
            # checking entities only avoids false positives when a relation
            # id happens to equal an entity id.
            visited = set(path[0::2])

            for relation, neighbor in self.adj.get(path[-1], []):
                if neighbor in visited:
                    continue  # Avoid cycles

                new_path = path + (relation, neighbor)
                if neighbor == target:
                    paths.append(new_path)
                else:
                    queue.append(new_path)

        return paths

    def score_paths(self, user_history: list[int],
                    candidate_item: int,
                    relation2name: dict) -> dict:
        """Score a candidate item based on KG paths from user history."""
        all_paths = []
        for hist_item in user_history[-10:]:  # Recent history
            paths = self.find_paths(hist_item, candidate_item, max_length=3)
            all_paths.extend(paths)

        if not all_paths:
            return {"score": 0, "paths": [], "explanation": "No KG connection found"}

        # Score based on path patterns
        path_scores = []
        for path in all_paths:
            n_hops = len(path) // 2  # Path alternates entity, relation, ..., entity
            path_scores.append(1.0 / n_hops)  # Shorter paths score higher

        total_score = sum(path_scores) / len(path_scores)

        # Generate explanation from best path
        best_path = all_paths[path_scores.index(max(path_scores))]
        explanation = self._path_to_explanation(best_path, relation2name)

        return {
            "score": total_score,
            "n_paths": len(all_paths),
            "explanation": explanation,
            "paths": all_paths[:5],
        }

    def _path_to_explanation(self, path: tuple,
                              relation2name: dict) -> str:
        """Convert a KG path to human-readable explanation."""
        parts = []
        for i in range(0, len(path) - 2, 2):
            entity = self.entity2name.get(path[i], str(path[i]))
            relation = relation2name.get(path[i + 1], str(path[i + 1]))
            parts.append(f"{entity} --[{relation}]-->")
        parts.append(self.entity2name.get(path[-1], str(path[-1])))
        return " ".join(parts)

    def recommend_with_explanation(self, user_id: int,
                                    user_history: list[int],
                                    candidates: list[int],
                                    relation2name: dict,
                                    top_k: int = 10) -> list[dict]:
        """Generate explainable recommendations."""
        scored = []
        for item in candidates:
            result = self.score_paths(user_history, item, relation2name)
            if result["score"] > 0:
                scored.append({
                    "item_id": item,
                    "item_name": self.entity2name.get(item, str(item)),
                    "score": result["score"],
                    "explanation": result["explanation"],
                    "n_paths": result["n_paths"],
                })

        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]
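To make the path-reasoning idea tangible, here is a self-contained toy run; the movie KG, entity names, and the compact `find_path` helper are invented for illustration. The user watched Inception; Interstellar is reachable via the shared director, and that path doubles as the explanation.

```python
from collections import deque

# Toy KG: (head, relation, tail); ids and names are illustrative
name = {0: "Inception", 1: "Christopher Nolan", 2: "Interstellar"}
rel_name = {0: "DIRECTED_BY"}
triples = [(0, 0, 1), (2, 0, 1)]  # Both films directed by the same person

adj: dict[int, list[tuple[int, int]]] = {}
for h, r, t in triples:
    adj.setdefault(h, []).append((r, t))
    adj.setdefault(t, []).append((r, h))  # Treat edges as bidirectional here

def find_path(source: int, target: int, max_hops: int = 3):
    """BFS; returns the first (entity, relation, ..., entity) path found."""
    queue = deque([(source,)])
    while queue:
        path = queue.popleft()
        if path[-1] == target and len(path) > 1:
            return path
        if len(path) >= max_hops * 2 + 1:
            continue
        seen = set(path[0::2])  # Entities at even positions
        for r, nb in adj.get(path[-1], []):
            if nb not in seen:
                queue.append(path + (r, nb))
    return None

path = find_path(0, 2)
steps = [
    f"{name[path[i]]} --[{rel_name[path[i + 1]]}]-->" for i in range(0, len(path) - 1, 2)
]
print(" ".join(steps + [name[path[-1]]]))
# → Inception --[DIRECTED_BY]--> Christopher Nolan --[DIRECTED_BY]--> Interstellar
```

The rendered string is exactly the kind of explanation `_path_to_explanation` produces at scale.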

Paradigm 3: Propagation-Enhanced Recommendation

RippleNet-Style Preference Propagation

class PreferencePropagation(nn.Module):
    """Preference propagation on KG (RippleNet-style).

    User's historical items as seeds, propagate preferences
    through KG to discover relevant items.
    """

    def __init__(self, n_entities: int, n_relations: int,
                 embed_dim: int = 64, n_hops: int = 2):
        super().__init__()
        self.n_hops = n_hops
        self.entity_emb = nn.Embedding(n_entities, embed_dim)
        self.relation_emb = nn.Embedding(n_relations, embed_dim * embed_dim)
        self.embed_dim = embed_dim

        self.transform = nn.ModuleList([
            nn.Linear(embed_dim, embed_dim) for _ in range(n_hops)
        ])

    def get_ripple_set(self, user_history: list[int],
                        kg_adj: dict[int, list[tuple]],
                        max_size: int = 32) -> list[list[tuple]]:
        """Get multi-hop ripple sets for a user's history items.

        Returns: list of [(head, relation, tail), ...] per hop
        """
        ripple_sets = []
        seeds = set(user_history)

        for hop in range(self.n_hops):
            hop_triples = []
            for seed in seeds:
                for rel, tail in kg_adj.get(seed, []):
                    hop_triples.append((seed, rel, tail))

            # Sample if too many
            if len(hop_triples) > max_size:
                indices = np.random.choice(len(hop_triples), max_size, replace=False)
                hop_triples = [hop_triples[i] for i in indices]

            ripple_sets.append(hop_triples)
            seeds = {t for _, _, t in hop_triples}

        return ripple_sets

    def forward(self, item_ids: torch.Tensor,
                ripple_sets: list[list[tuple]]) -> torch.Tensor:
        """Score items based on preference propagation.

        Args:
            item_ids: Candidate item entity IDs [B]
            ripple_sets: Multi-hop ripple sets from user history
        Returns:
            Item scores [B]
        """
        item_emb = self.entity_emb(item_ids)    # [B, D]
        user_repr = torch.zeros_like(item_emb)  # Accumulated preference signal

        for hop in range(self.n_hops):
            if not ripple_sets[hop]:
                continue

            heads, rels, tails = zip(*ripple_sets[hop])
            h_emb = self.entity_emb(torch.tensor(heads))   # [S, D]
            r_emb = self.relation_emb(torch.tensor(rels))  # [S, D*D]
            t_emb = self.entity_emb(torch.tensor(tails))   # [S, D]

            # Reshape relation as matrix
            R = r_emb.view(-1, self.embed_dim, self.embed_dim)  # [S, D, D]

            # Attention: item * (R h) scores each triple against the candidate
            Rh = torch.bmm(R, h_emb.unsqueeze(-1)).squeeze(-1)  # [S, D]
            attn = F.softmax(torch.mm(item_emb, Rh.T), dim=-1)  # [B, S]

            # Weighted sum of tail embeddings forms this hop's context
            context = torch.mm(attn, t_emb)  # [B, D]
            user_repr = user_repr + context

            # Update item embedding for the next hop
            item_emb = self.transform[hop](item_emb + context)

        # Score via inner product of the propagated preference and the item,
        # rather than the item's norm alone (which would ignore the user)
        scores = (user_repr * item_emb).sum(dim=-1)  # [B]
        return torch.sigmoid(scores)
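The attention step inside the loop can be checked by hand. A NumPy sketch of one propagation hop with toy shapes (D = 2, a ripple set of two triples, identity relation matrices; all values are illustrative):

```python
import numpy as np

D = 2
item = np.array([1.0, 0.0])             # Candidate item embedding [D]
h = np.array([[1.0, 0.0], [0.0, 1.0]])  # Head embeddings [S, D]
R = np.stack([np.eye(D), np.eye(D)])    # Relation matrices [S, D, D]
t = np.array([[0.5, 0.5], [0.0, 1.0]])  # Tail embeddings [S, D]

# Attention logits: item · (R h) per triple
Rh = np.einsum("sij,sj->si", R, h)  # [S, D]
logits = Rh @ item                  # [S]
attn = np.exp(logits) / np.exp(logits).sum()

# Hop context: attention-weighted sum of tail embeddings
context = attn @ t  # [D]
print(attn, context)
```

With identity relations, the first triple's head aligns with the item (logit 1 vs 0), so it receives the larger weight (about 0.73) and pulls the context toward its tail.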

Explainable Recommendation

Generating Recommendation Explanations

class ExplainableRecommender:
    """Generate human-readable explanations for recommendations."""

    # Explanation patterns keyed by connection type; _format_explanation below
    # currently uses a generic fallback, and these can be wired in per type.
    EXPLANATION_TEMPLATES = {
        "same_category": "Because you liked {source}, and {target} is in the same category '{category}'",
        "same_creator": "Because you enjoyed works by {creator}, who also created {target}",
        "similar_features": "Because {target} shares features ({features}) with {source} that you liked",
        "popular_in_group": "Popular among users with similar taste to you",
        "collaborative": "Users who liked {source} also liked {target}",
    }

    def __init__(self, kg_data: RecKGData, entity_names: dict, relation_names: dict):
        self.data = kg_data
        self.names = entity_names
        self.rel_names = relation_names

    def explain(self, user_id: int, recommended_item: int,
                user_history: list[int]) -> list[dict]:
        """Generate multiple explanation candidates for a recommendation."""
        explanations = []

        item_entity = self.data.item2entity.get(recommended_item)
        if item_entity is None:  # `not item_entity` would wrongly reject entity id 0
            return [{"type": "default", "text": "Recommended based on your preferences"}]

        # Check category match
        for hist_item in user_history[-5:]:
            hist_entity = self.data.item2entity.get(hist_item)
            if hist_entity:
                shared = self._find_shared_connections(hist_entity, item_entity)
                for conn_type, conn_name in shared:
                    explanations.append({
                        "type": conn_type,
                        "text": self._format_explanation(
                            conn_type, hist_item, recommended_item, conn_name
                        ),
                        "confidence": 0.8,
                    })

        if not explanations:
            explanations.append({
                "type": "default",
                "text": "Recommended based on your browsing history",
                "confidence": 0.5,
            })

        return explanations

    def _find_shared_connections(self, entity_a: int,
                                  entity_b: int) -> list[tuple[str, str]]:
        """Find shared KG connections between two entities."""
        shared = []
        neighbors_a = {}
        neighbors_b = {}

        for h, r, t in self.data.kg_triples:
            if h == entity_a:
                neighbors_a.setdefault(r, set()).add(t)
            if h == entity_b:
                neighbors_b.setdefault(r, set()).add(t)

        for rel in set(neighbors_a) & set(neighbors_b):
            common = neighbors_a[rel] & neighbors_b[rel]
            for entity in common:
                rel_name = self.rel_names.get(rel, str(rel))
                entity_name = self.names.get(entity, str(entity))
                shared.append((rel_name, entity_name))

        return shared

    def _format_explanation(self, conn_type: str, source: int,
                             target: int, connection: str) -> str:
        source_name = self.names.get(source, str(source))
        target_name = self.names.get(target, str(target))
        return (
            f"Because you liked '{source_name}', and '{target_name}' "
            f"shares the same {conn_type}: '{connection}'"
        )
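The shared-connection lookup above boils down to per-relation set intersections. A standalone toy check (ids are illustrative: two items point at the same brand entity 9 under relation 1):

```python
# Toy triples: items 0 and 5 share a brand (entity 9) under relation 1
triples = [(0, 1, 9), (5, 1, 9), (0, 2, 7)]

def shared_connections(a: int, b: int, triples):
    """Return {(relation, entity)} pairs reachable from both a and b."""
    na: dict[int, set[int]] = {}
    nb: dict[int, set[int]] = {}
    for h, r, t in triples:
        if h == a:
            na.setdefault(r, set()).add(t)
        if h == b:
            nb.setdefault(r, set()).add(t)
    # Same relation AND same target entity => shared connection
    return {(r, e) for r in na.keys() & nb.keys() for e in na[r] & nb[r]}

print(shared_connections(0, 5, triples))  # → {(1, 9)}
```

Each pair returned here is one candidate explanation ("both items have brand 9"), which the class above then renders with human-readable names.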

Evaluation Metrics

Recommendation Quality Evaluation

Metric                 Type             Computation                        Use case
Precision@K            Accuracy         hits / recommendations             Top-K recommendation
Recall@K               Recall           hits / actual positives            Coverage evaluation
NDCG@K                 Ranking quality  position-weighted gain             Ranked recommendation
AUC                    Discrimination   area under the ROC curve           Click prediction
Explanation coverage   Explainability   explainable recs / total recs      KG recommendation
Explanation diversity  Explainability   explanation types / recs           KG recommendation

def evaluate_recommendation(predictions: list[list[int]],
                              ground_truth: list[list[int]],
                              k: int = 10) -> dict:
    """Evaluate recommendation quality."""
    precisions, recalls, ndcgs = [], [], []

    for pred, truth in zip(predictions, ground_truth):
        pred_k = pred[:k]
        truth_set = set(truth)

        # Precision@K
        hits = len(set(pred_k) & truth_set)
        precisions.append(hits / k)

        # Recall@K
        recalls.append(hits / max(len(truth_set), 1))

        # NDCG@K
        dcg = sum(
            1 / np.log2(i + 2) for i, item in enumerate(pred_k)
            if item in truth_set
        )
        ideal_dcg = sum(1 / np.log2(i + 2) for i in range(min(len(truth_set), k)))
        ndcgs.append(dcg / max(ideal_dcg, 1e-8))

    return {
        f"Precision@{k}": float(np.mean(precisions)),
        f"Recall@{k}": float(np.mean(recalls)),
        f"NDCG@{k}": float(np.mean(ndcgs)),
    }
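As a sanity check on the metric code, a hand-computed toy case: with pred = [1, 2, 3] and truth = {2, 3}, the hits land at ranks 2 and 3, so DCG = 1/log2(3) + 1/log2(4) ≈ 1.131, IDCG = 1/log2(2) + 1/log2(3) ≈ 1.631, and NDCG ≈ 0.693. The same arithmetic, inlined:

```python
import math

pred, truth = [1, 2, 3], {2, 3}
k = 3

hits = len(set(pred[:k]) & truth)
precision = hits / k        # 2/3: two of three recommendations hit
recall = hits / len(truth)  # 1.0: every relevant item was recommended

# Position-discounted gain: a hit at rank i (0-indexed) earns 1/log2(i + 2)
dcg = sum(1 / math.log2(i + 2) for i, item in enumerate(pred[:k]) if item in truth)
idcg = sum(1 / math.log2(i + 2) for i in range(min(len(truth), k)))
ndcg = dcg / idcg

print(round(precision, 3), round(recall, 3), round(ndcg, 3))
```

NDCG is below 1.0 here precisely because the two hits sit at ranks 2-3 instead of 1-2; a perfect ordering would score 1.0.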

Conclusion

Knowledge graphs deliver threefold value to recommender systems: first, they ease the cold-start problem (a new item obtains an initial representation from its KG attributes and relations); second, they raise recommendation quality (structured knowledge from the KG supplements collaborative filtering's sparse signal); third, they enable explainability (KG paths naturally form recommendation rationales). In engineering practice, start with the KG embedding-enhancement approach (simple to implement, clear gains), then introduce path-based reasoning (better explainability) and propagation (better discovery of long-range associations) once data and compute mature. Explainability is not icing on the cake; it is a core capability for building user trust and meeting compliance requirements.


Maurice | maurice_wen@proton.me