使用 Python 和 Jax 构建推荐系统(一) 基本概念

30 阅读12分钟

基本定义

推荐系统的本质在于解决“品味的几何学”问题,即利用用户微小的互动信号在抽象空间中进行定位。核心问题框架即给定一组可能被推荐的事物,根据特定的目标,为当前上下文和用户选择有序的少数几个。

核心架构组件

收集器(Collector)

收集器负责从海量候选中检索出可能被推荐的事物集合;基于当前上下文或状态,筛选出符合必要特征或属性的子集。 例如餐厅服务员确认菜单库存,并根据顾客口味(如不吃香菜)缩小范围。

排序器 (Ranker)

排序器负责对收集器提供的集合进行评分和排序(Scoring/Ordering)。利用模型结合上下文和用户特征,对候选集进行精细化排序。 例如服务员根据“热门程度”(大家都点)或“个性化匹配”(顾客喜欢石榴)推荐特定甜点。

服务端 (Server)

服务端负责交付最终结果(Serving)。获取有序子集,执行业务逻辑(如去重、格式化),验证数据模式 (Schema),并返回最终推荐列表。例如服务员口头向顾客陈述最终的几项建议。

基础推荐器模型

平凡推荐器

def get_trivial_recs() -> Optional[List[str]]:
    item_id = random.randint(0, MAX_ITEM_INDEX)
    if get_availability(item_id):
        return [item_id]
    return None

逻辑:随机生成一个 item_id,若可用则返回,否则返回 None

组件行为

  • 收集器:随机生成 ID 并检查可用性。
  • 排序器:无操作 (No-op),恒等函数。
  • 服务端:返回类型为 Optional[List[str]]

最热门商品推荐器

def get_item_popularities() -> Optional[Dict[str, int]]:
    ...
    # 返回字典对: (商品标识符, 商品被选择的次数)
    return item_choice_counts
    return Nonedef get_most_popular_recs(max_num_recs: int) -> Optional[List[str]]:
    items_popularity_dict = get_item_popularities()
    if items_popularity_dict:
        sorted_items = sorted(
            items_popularity_dict.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        return [i[0] for i in sorted_items][:max_num_recs]
    return None

这是具有实用价值的最简模型,常作为冷启动兜底策略

逻辑:返回全局被选择次数最多的前 k 个商品。

组件行为

  • 收集器:获取包含商品ID和选择次数(热门度)的字典 Dict[str, int]
  • 排序器:按字典的值(计数)进行降序排序。
  • 服务端:截取前 max_num_recs 个结果,返回 ID 列表。

JAX 简明入门

JAX 是一个专为高性能数值计算和机器学习设计的 Python 框架,支持即时编译 (JIT) 和自动微分。

特性NumPyJAX (jax.numpy)
可变性 (Mutability)可变 (Mutable),支持 x[0] = 4.0不可变 (Immutable) ,修改元素需创建新数组或使用特定更新语法
硬件加速仅 CPU支持 CPU, GPU, TPU
随机数生成基于全局状态 (Stateful)基于显式密钥 (Stateless/Key-based)
执行模式解释执行支持 JIT 编译 (XLA)

切片

x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.int32)
​
# 打印整个矩阵。
print(x)
# [[1 2 3]
#  [4 5 6]
#  [7 8 9]]
​
# 打印第一行。
print(x[0])
# [1 2 3]
​
# 打印最后一行。
print(x[-1])
# [7 8 9]
​
# 打印第二列。
print(x[:, 1])
# [2 5 8]
​
# 每隔一个元素打印一次
# start:end:stride(开始:结束:步长)
print(x[::2, ::2])
# [[1 3]
#  [7 9]]

广播

当对两个不同形状的张量进行二元运算时,JAX 会自动提升维度为 1 的轴以匹配较大张量。

  • 标量广播:标量乘以矩阵,应用于所有元素。

  • 向量广播

    • (3, 1) \times (3, 3):向量按复制。
    • (1, 3) \times (3, 3):向量按复制。
x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.int32)
​
# 标量广播。
y = 2 * x
print(y)
# [[ 2  4  6]
#  [ 8 10 12]
#  [14 16 18]]
​
# * 逐元素相乘
# 按 列进行复制
vec = jnp.reshape(jnp.array([0.5, 1.0, 2.0]), [3, 1])
y = vec * x
print(y)
# [[0.5, 0.5, 0.5] * [1, 2, 3] = [ 0.5,  1. ,  1.5]] 
# [[1.0, 1.0, 1.0] * [4, 5, 6] = [ 4. ,  5. ,  6. ]] 
# [[2.0, 2.0, 2.0] * [7, 8, 9] = [14. , 16. , 18. ]]  
​
vec = jnp.reshape(vec, [1, 3])
y = vec * x
print(y)
# [[0.5, 1.0, 2.0] * [1, 2, 3] = [ 0.5,  2. ,  6. ]]  
# [[0.5, 1.0, 2.0] * [4, 5, 6] = [ 2. ,  5. , 12. ]]  
# [[0.5, 1.0, 2.0] * [7, 8, 9] = [ 3.5,  8. , 18. ]]  

不可变性

import jax.numpy as jnp
import numpy as np
​
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
print(x)
# [1. 2. 3.]print(x.shape)
# (3,)print(x[0])
# 1.0# 不可变性
x[0] = 4.0
# TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' 
# object does not support item assignment. JAX arrays are immutable.
  • JAX 数组具有静态类型(如 float32)和形状(Shape)。
  • 遵循函数式编程理念:纯函数 (Pure Functions) 无副作用,数据不可变。这使得底层 XLA (Accelerated Linear Algebra) 编译器能够安全地优化并行计算。

随机数

JAX 摒弃了传统的全局随机种子,采用显式的 PRNGKey;给定相同的 Key,生成的随机数永远相同。

密钥分裂 (Splitting) :为了生成新的随机数序列,必须将 Key 分裂为 keysubkey

Keynext,Subkey=split(Keycurrent)Key_{next}, Subkey = split(Key_{current})

这种机制保证了在并行计算中随机数的可复现性。

import jax.random as random
​
key = random.PRNGKey(0)
x = random.uniform(key, shape=[3, 3])
print(x)
# [[0.35490513 0.60419905 0.4275843 ]
#  [0.23061597 0.6735498  0.43953657]
#  [0.25099766 0.27730572 0.7678207 ]]
​
# random.split(key) 接收旧的 key,并确定性地生成新的、统计上独立的密钥
# key 来维持状态,用于下一步继续分裂
# Subkey 用于本次生成随机数
key, subkey = random.split(key)
x = random.uniform(key, shape=[3, 3])
print(x)
# [[0.0045197  0.5135027  0.8613342 ]
#  [0.06939673 0.93825936 0.85599923]
#  [0.706004   0.50679076 0.6072922 ]]
​
y = random.uniform(subkey, shape=[3, 3])
print(y)
# [[0.34896135 0.48210478 0.02053976]
#  [0.53161216 0.48158717 0.78698325]
#  [0.07476437 0.04522789 0.3543167 ]]

即时编译

import jax
x = random.uniform(key, shape=[2048, 2048]) - 0.5
​
def my_function(x):
    x = x @ x
    return jnp.maximum(0.0, x)
​
%timeit my_function(x).block_until_ready()
# 302 ms ± 9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
​
my_function_jitted = jax.jit(my_function)
%timeit my_function_jitted(x).block_until_ready()
# 294 ms ± 5.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

使用 @jax.jit 装饰器或 jax.jit() 函数将 Python 代码追踪并编译为 XLA 优化的机器码。在 GPU/TPU 上性能提升显著,但在首次调用时会有编译开销。

Python 控制流(如依赖数据的循环)有限制,推荐使用 JAX 内部原语。

用户-物品评分与问题定义

用户-物品矩阵

推荐系统的基础数据结构是矩阵,需要将非结构化的用户反馈(如口头评价)转化为结构化数据。

矩阵可视化

在数据量较小时,使用热力图 (Heatmap) 是直观理解数据分布和稀疏性的最佳方式。

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
​
# 使用 np.nan 表示缺失值(即用户未评分/未互动的情况)
# 这种缺失是推荐系统需要预测的目标
_ = np.nan
​
# 构建评分矩阵 (行=用户, 列=物品)
# 对应关系: A, B, C, D, E vs Gouda, Chevre, Emmentaler, Brie
scores = np.array([
    [5, 4, 4, 1],    # User A
    [2, 3, 3, 4.5],  # User B
    [3, 2, 3, 4],    # User C
    [4, 4, 5, _],    # User D (未尝过 Brie)
    [3, _, _, _]     # User E (只尝过 Gouda)
])
​
# 绘图逻辑
plt.figure(figsize=(8, 6))
​
sns.heatmap(
    scores,
    annot=True,     # 在格子上显示具体数值
    fmt=".1f",      # 数值格式保留一位小数
    xticklabels=['Gouda', 'Chevre', 'Emmentaler', 'Brie'],
    yticklabels=['A','B','C','D','E']
)
plt.savefig('photo.jpg', bbox_inches='tight')

image.png

随着用户和物品数量激增,矩阵会变得极度稀疏(绝大多数位置是空的)。

数据表示方法

  • 稠密表示 (Dense)存储所有数据,包括空值(如上面的 numpy 数组)。
  • 稀疏表示 (Sparse)仅存储非空数据。通常使用元组 (user_id, item_id, rating) 或坐标列表 (COO) 格式。
# 稀疏表示
# 这种结构更接近生产环境中的存储方式
data = {
    # 索引列表:每一项是一个 (User_Index, Item_Index) 元组
    'indices': [
        (0,0), (0,1), (0,2), (0,3), # User A 的所有评分坐标
        (4,0)                       # User E 唯一的评分坐标
    ],
    # 值列表:对应上述坐标的具体评分
    'values': [
        5, 4, 4, 1,
        3
    ]
}

这种结构对应线性代数中的坐标格式 (Coordinate Format)

协同过滤

协同过滤(Collaborative Filtering, CF) 的核心假设:品味相似的人对未知物品的评价也相似。

类型侧重点推荐逻辑向量视角
User-User CF优先考虑用户相似性找到与用户 A 相似的用户 B,推荐 B 喜欢但 A 没看过的物品。矩阵的向量相似
Item-Item CF优先考虑物品相似性找到用户 A 喜欢的物品 X,推荐与 X 相似的物品 Y。矩阵的向量相似

向量相似度

通过计算两个向量(数字列表)在隐空间 (Latent Space) 中的距离来衡量相似度。通常先归一化,再计算余弦相似度(或点积)。

评分类型

数据来源决定了信号的质量和处理方式。

特性硬评分 (Hard Ratings)软评分 (Soft Ratings) / 隐式反馈
形式星级 (1-5)、点赞/点踩点击、观看时长、页面停留、加入购物车
触发用户显式响应提示用户行为隐式传达
优点信号明确,意图清晰数据量巨大,覆盖面广
缺点数据极其稀疏,存在幸存者偏差噪音大,含糊(没点≠不喜欢,可能是没看见)
应用评估模型准确性训练大规模召回/排序模型

数据收集与用户日志

现实世界中,用户-物品矩阵初始是空的。数据不会自己跑进模型里。为了让算法运转,需要解决三个工程与业务层面的关键问题:

  1. 数据来源:除了用户直接打分(显式评分),我们还能捕捉哪些行为作为信号(隐式评分)?
  2. 数据管道:如何从技术上捕获、传输和存储这些海量行为数据?
  3. 业务价值:除了做推荐,这些数据还能如何指导公司的商业决策?

显式与隐式信号

显式评分

定义:用户直接表达喜好的行为(如 1-5 星评分、点赞/点踩)。

工程注意:必须持久化存储。如果用户评了分,下次刷新却不见了,会带来极差的用户体验。

局限性:数据稀疏,用户很少主动打分。

隐式评分

由于显式评分太少,现代推荐系统主要依赖用户的行为轨迹。以下是按信号强度排序的关键数据点:

  • 页面加载 & 倾向评分

背景:你不能因为用户没有点击某本书,就断定他“不喜欢”它——很有可能他根本没看见它。

为什么记录加载:记录用户“看到了什么”(曝光集)至关重要。这涉及统计学中的倾向评分(Propensity Scores) 概念:

  • 如果《银河系漫游指南》从未在首页展示,用户就没有“点击它的倾向(概率)”。
  • 应用:我们需要用“曝光量”作为分母,来校正点击率。未曝光≠不喜欢;曝光未点击=负反馈

  • 页面浏览与悬停

背景:用户在做决定前的微小交互。

信号意义

  • 悬停 (Hover) :用户鼠标停留在书皮上,甚至触发了放大效果。这代表了“好奇心”或“微弱的兴趣”。
  • 发现成本:如果用户通过滚动轮播图(Carousel)才看到某本书,这个行为的权重应该比直接看到的高,因为用户付出了额外的交互成本。

  • 点击

背景:电商推荐系统的核心 KPI。

信号意义

  • 它是购买的必要前置条件
  • 它代表了明确的意图 (Intent)
  • 虽然有误触噪音,但数据量巨大,是训练模型的主力数据。

  • 加入购物车 & 点击流

背景:软评分的终点,距离“购买”仅一步之遥。

信号意义:极强的兴趣指标。有些算法甚至认为“加购”比“最终购买”更能反映用户的纯粹偏好(因为未购买可能只是因为支付失败或运费问题)。

点击流 (Click-Stream) :记录用户点击的顺序(如:先看裤子,再看皮带)。这为“序列推荐(Sequential Recommendation)”提供了数据基础。

数据管道与插桩

知道了“收什么数据”,接下来解决“怎么收”。

插桩

在网页或 App 代码中埋入“探针”。当特定行为(如点击、加载)发生时,触发一个事件 (Event)事件结构通常包含 User_ID, Item_ID, Timestamp, Session_ID 等上下文信息。

数据分流架构

当事件被触发后,通常会进入分叉的管道(Path Bifurcation):

  1. 路径一:日志数据库 (Log Database)

    • 存入 MySQL 等传统数据库。
    • 用途:长期存储、历史分析、离线模型训练。
  2. 路径二:实时事件流 (Event Stream)

    • 使用 Apache Kafka 等消息队列技术。
    • 用途实时推荐(用户刚点了一本书,下一秒首页推荐就变了)和实时监控看板。

漏斗分析

推荐系统上线后,怎么知道它好不好?除了看准确率,更要看它是否推动了业务流程。用户从进入网站到完成购买,像流过一个漏斗,每一步都会有人流失(Drop-off)。

三类漏斗

  1. 全局漏斗浏览 -> 加购 的整体转化率。衡量网站整体健康度。
  2. 推荐漏斗推荐曝光 -> 推荐点击 -> 加购。这是评估推荐算法质量的核心。如果推荐位的转化率低于搜索栏,说明推荐不够精准。
  3. 支付漏斗加购 -> 支付成功。这通常不是推荐系统的问题(可能是支付接口挂了),但数据科学家需要监控它,以免误判是推荐算法导致了收入下降。

业务洞察

推荐系统不仅仅是一个自动售货机,它还是一个观察市场趋势的“望远镜”。

归因分析

问题:这周流量大涨,是因为算法变好了,还是因为某部剧火了?

案例:Netflix 的《鱿鱼游戏》。数据暴涨并非算法功劳,而是内容本身爆火。

决策支持:通过识别这种异常,高管可以决定“多投资韩剧”或“增加字幕剧集预算”。

增量收益

定义:推荐系统带来的收益,必须是原本不会发生的那部分。

例子:如果用户本来就要买牙膏,你推荐牙膏就没有“增量价值”。真正的价值在于推荐了他原本不知道但感兴趣的“电动牙刷”。

启动效应

利用推荐系统人为地为新内容“造势”。通过在首页强推(即使初期数据一般),触发网络效应和病毒式传播,从而在长期获得更大的平台增长。