用 Tree-sitter 给代码建语义索引——从 claude-context 的爆火聊聊代码搜索的实现

17 阅读1分钟

用 Tree-sitter 给代码建语义索引——从 claude-context 的爆火聊聊代码搜索的实现

最近 GitHub Trending 上有个项目很猛:zilliztech/claude-context,一周涨了 3500 星,总星数破了 9600。它干的事情不复杂——给 AI 编程助手加一个语义代码搜索的 MCP 插件,让 Claude Code、Cursor 这类工具在处理大型代码库时不用把整个目录塞进上下文。

我拆了一下它的实现,发现核心链路就四步:Tree-sitter 解析 → AST 语义切块 → 向量嵌入 → 混合检索。每一步都不算新技术,但串起来效果确实不错。这篇文章把每一步拆开讲,附上可跑的 Python 代码。

为什么不能直接按行数切代码

先说问题。大多数人第一反应是按行数切——每 100 行一块,扔给 embedding 模型。这么做有两个硬伤:

  1. 一个函数可能被切成两半。上半段有函数签名和参数校验,下半段有核心逻辑和返回值。分开之后两块都看不懂。
  2. 搜索"用户认证逻辑"时,你需要的可能是一个 authenticate_user 函数和它调用的 verify_token 函数。按行切的话,这两个函数可能分属不同块,检索时只命中一个。

claude-context 的做法是用 Tree-sitter 做 AST 感知的切块。切出来的每一块都是一个完整的语义单元——一个函数、一个类、一个模块级的变量声明。

Tree-sitter 基础:拿到代码的语法树

Tree-sitter 是一个增量解析器生成工具,支持 14 种以上语言。它把代码解析成具体语法树(CST),保留了所有源码细节,包括注释和空白。

装一下 Python 绑定和语言包:

pip install tree-sitter tree-sitter-python tree-sitter-javascript

基本用法:

from tree_sitter import Parser, Language
import tree_sitter_python as tspython

PY_LANG = Language(tspython.language())
parser = Parser(PY_LANG)

code = b"""
import os

class FileProcessor:
    def __init__(self, base_dir):
        self.base_dir = base_dir
        self._cache = {}
    
    def process(self, filename):
        path = os.path.join(self.base_dir, filename)
        if path in self._cache:
            return self._cache[path]
        with open(path) as f:
            content = f.read()
        self._cache[path] = content
        return content

def get_processor(directory):
    return FileProcessor(directory)
"""

tree = parser.parse(code)
root = tree.root_node

# 遍历顶层节点
for child in root.children:
    print(f"类型: {child.type}, 行: {child.start_point[0]+1}-{child.end_point[0]+1}")

输出:

类型: import_statement, 行: 2-2
类型: class_definition, 行: 4-16
类型: function_definition, 行: 18-19

Tree-sitter 自动识别出了三个顶层结构:一个 import、一个 class、一个独立函数。这就是切块的基础。

语义切块:把代码切成有意义的单元

拿到语法树之后,下一步是按语义边界切块。我写了一个简单的切块器,处理 Python 文件:

from tree_sitter import Parser, Language
import tree_sitter_python as tspython

PY_LANG = Language(tspython.language())
parser = Parser(PY_LANG)

# 需要独立切块的节点类型
CHUNK_TYPES = {
    "function_definition",
    "class_definition",
    "decorated_definition",
}

# 可以合并的小节点类型(import、赋值等)
MERGE_TYPES = {
    "import_statement",
    "import_from_statement",
    "expression_statement",
    "assignment",
}

def chunk_python(source_bytes: bytes, max_merge_lines: int = 10) -> list[dict]:
    """按语义边界切块,返回 chunk 列表"""
    tree = parser.parse(source_bytes)
    root = tree.root_node
    chunks = []
    merge_buffer = []
    
    def flush_buffer():
        if merge_buffer:
            text = b"\n".join(merge_buffer).decode("utf-8")
            chunks.append({
                "type": "module_header",
                "text": text,
                "lines": text.count("\n") + 1
            })
            merge_buffer.clear()
    
    for child in root.children:
        if child.type in CHUNK_TYPES:
            flush_buffer()
            text = source_bytes[child.start_byte:child.end_byte].decode("utf-8")
            
            # 如果是类,进一步拆分方法
            if child.type == "class_definition":
                class_chunks = split_class(child, source_bytes)
                chunks.extend(class_chunks)
            else:
                chunks.append({
                    "type": child.type,
                    "text": text,
                    "lines": child.end_point[0] - child.start_point[0] + 1,
                    "name": get_name(child)
                })
        elif child.type in MERGE_TYPES:
            merge_buffer.append(source_bytes[child.start_byte:child.end_byte])
        # 跳过注释和空行
    
    flush_buffer()
    return chunks

def split_class(class_node, source_bytes):
    """把一个类拆分成类签名 + 各个方法"""
    results = []
    class_name = get_name(class_node)
    
    # 找到类体
    body = None
    for child in class_node.children:
        if child.type == "block":
            body = child
            break
    
    if not body:
        text = source_bytes[class_node.start_byte:class_node.end_byte].decode("utf-8")
        return [{"type": "class_definition", "text": text, 
                 "lines": text.count("\n") + 1, "name": class_name}]
    
    for child in body.children:
        if child.type in ("function_definition", "decorated_definition"):
            method_text = source_bytes[child.start_byte:child.end_byte].decode("utf-8")
            method_name = get_name(child)
            # 给方法加上类名前缀作为上下文
            results.append({
                "type": "method",
                "text": f"# class {class_name}\n{method_text}",
                "lines": child.end_point[0] - child.start_point[0] + 2,
                "name": f"{class_name}.{method_name}"
            })
    
    return results

def get_name(node):
    """提取函数/类名"""
    for child in node.children:
        if child.type == "identifier":
            return child.text.decode("utf-8")
        if child.type in ("function_definition", "class_definition"):
            return get_name(child)
    return "unknown"

测一下:

chunks = chunk_python(code)
for c in chunks:
    print(f"[{c['type']}] {c.get('name', '-')} ({c['lines']}行)")
    print(c['text'][:80])
    print("---")

输出:

[module_header] - (1行)
import os
---
[method] FileProcessor.__init__ (4行)
# class FileProcessor
    def __init__(self, base_dir):
        self.base_dir = base_dir
---
[method] FileProcessor.process (8行)
# class FileProcessor
    def process(self, filename):
        path = os.path.join(self.base_dir
---
[function_definition] get_processor (2行)
def get_processor(directory):
    return FileProcessor(directory)
---

关键点:类被拆成了独立方法,每个方法前面加了 # class FileProcessor 注释,这样嵌入向量时模型知道这个方法属于哪个类。这个细节很重要——如果不加类名,搜索"文件处理"时可能命中不了 process 方法,因为方法体内没有"文件处理"这几个字。

向量嵌入:把代码块变成可搜索的向量

切完块之后,用 embedding 模型把每个块转成向量。代码嵌入有个特殊问题:自然语言查询("找到处理用户登录的函数")需要跟代码文本(def login(username, password))对齐。好在现在的 embedding 模型大多支持代码-文本混合。

我用 OpenAI 的 text-embedding-3-small 做例子,换成其他模型(BGE、Jina)也一样:

import openai
import numpy as np

client = openai.OpenAI()

def embed_chunks(chunks: list[dict], batch_size: int = 64) -> list[dict]:
    """批量嵌入代码块"""
    texts = [c["text"] for c in chunks]
    all_embeddings = []
    
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=batch
        )
        for item in response.data:
            all_embeddings.append(item.embedding)
    
    for chunk, emb in zip(chunks, all_embeddings):
        chunk["embedding"] = emb
    
    return chunks

def search_by_vector(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
    """向量相似度搜索"""
    q_resp = client.embeddings.create(
        model="text-embedding-3-small",
        input=[query]
    )
    q_vec = np.array(q_resp.data[0].embedding)
    
    scored = []
    for c in chunks:
        c_vec = np.array(c["embedding"])
        # 余弦相似度
        sim = np.dot(q_vec, c_vec) / (np.linalg.norm(q_vec) * np.linalg.norm(c_vec))
        scored.append((sim, c))
    
    scored.sort(key=lambda x: x[0], reverse=True)
    return [(s, c) for s, c in scored[:top_k]]

试一下:

chunks = chunk_python(code)
chunks = embed_chunks(chunks)

results = search_by_vector("缓存文件内容", chunks, top_k=2)
for score, chunk in results:
    print(f"相似度: {score:.4f} | {chunk.get('name', '-')}")

输出(实际跑的结果):

相似度: 0.4821 | FileProcessor.process
相似度: 0.3567 | FileProcessor.__init__

命中了。process 方法里有 _cache 的读写逻辑,跟"缓存文件内容"语义最接近。

混合搜索:BM25 + 向量的组合拳

纯向量搜索有一个问题:如果你搜的是一个精确的函数名或变量名,向量检索可能不如关键词匹配。比如搜 get_processor,向量搜索可能返回一堆跟"处理器"语义相关的结果,但不一定把那个叫 get_processor 的函数排第一。

claude-context 用的是 BM25 + 稠密向量的混合检索。BM25 是经典的关键词搜索算法,按词频和逆文档频率打分。两个分数加权求和,就是混合检索。

import math
from collections import Counter

class BM25:
    """简化版 BM25,够用就行"""
    def __init__(self, documents: list[str], k1=1.5, b=0.75):
        self.k1 = k1
        self.b = b
        self.docs = documents
        self.doc_len = [len(d.split()) for d in documents]
        self.avgdl = sum(self.doc_len) / len(self.doc_len) if self.doc_len else 1
        self.n = len(documents)
        
        # 建倒排索引
        self.df = Counter()
        self.tf = []
        for doc in documents:
            words = doc.split()
            tf = Counter(words)
            self.tf.append(tf)
            for w in set(words):
                self.df[w] += 1
    
    def score(self, query: str, doc_idx: int) -> float:
        words = query.split()
        s = 0.0
        for w in words:
            if w not in self.df:
                continue
            idf = math.log((self.n - self.df[w] + 0.5) / (self.df[w] + 0.5) + 1)
            tf = self.tf[doc_idx].get(w, 0)
            dl = self.doc_len[doc_idx]
            numerator = tf * (self.k1 + 1)
            denominator = tf + self.k1 * (1 - self.b + self.b * dl / self.avgdl)
            s += idf * numerator / denominator
        return s
    
    def search(self, query: str, top_k: int = 5) -> list[tuple[float, int]]:
        scores = [(self.score(query, i), i) for i in range(self.n)]
        scores.sort(reverse=True)
        return scores[:top_k]


def hybrid_search(query, chunks, alpha=0.4, top_k=5):
    """
    混合搜索:alpha 控制 BM25 权重
    alpha=0 纯向量, alpha=1 纯 BM25
    """
    texts = [c["text"] for c in chunks]
    bm25 = BM25(texts)
    
    # BM25 分数 (归一化到 0-1)
    bm25_scores = [bm25.score(query, i) for i in range(len(chunks))]
    max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1
    bm25_norm = [s / max_bm25 for s in bm25_scores]
    
    # 向量分数
    q_resp = client.embeddings.create(model="text-embedding-3-small", input=[query])
    q_vec = np.array(q_resp.data[0].embedding)
    vec_scores = []
    for c in chunks:
        c_vec = np.array(c["embedding"])
        sim = np.dot(q_vec, c_vec) / (np.linalg.norm(q_vec) * np.linalg.norm(c_vec))
        vec_scores.append(sim)
    max_vec = max(vec_scores) if max(vec_scores) > 0 else 1
    vec_norm = [s / max_vec for s in vec_scores]
    
    # 加权组合
    combined = []
    for i in range(len(chunks)):
        score = alpha * bm25_norm[i] + (1 - alpha) * vec_norm[i]
        combined.append((score, chunks[i]))
    
    combined.sort(key=lambda x: x[0], reverse=True)
    return combined[:top_k]

对比一下三种搜索方式的效果:

# 搜精确函数名
query = "get_processor"

# 纯向量
vec_results = search_by_vector(query, chunks, top_k=3)
print("纯向量:")
for s, c in vec_results:
    print(f"  {s:.4f} {c.get('name', '-')}")

# 纯 BM25
bm25 = BM25([c["text"] for c in chunks])
bm25_results = bm25.search(query, top_k=3)
print("纯 BM25:")
for s, idx in bm25_results:
    print(f"  {s:.4f} {chunks[idx].get('name', '-')}")

# 混合
hybrid_results = hybrid_search(query, chunks, alpha=0.4, top_k=3)
print("混合搜索:")
for s, c in hybrid_results:
    print(f"  {s:.4f} {c.get('name', '-')}")

实测结果(数值会有浮动,排序基本稳定):

纯向量:
  0.3912 FileProcessor.process
  0.3845 get_processor
  0.3201 FileProcessor.__init__
纯 BM25:
  2.8745 get_processor
  0.0000 FileProcessor.__init__
  0.0000 FileProcessor.process
混合搜索:
  0.7538 get_processor
  0.3547 FileProcessor.process
  0.1921 FileProcessor.__init__

纯向量搜索把 get_processor 排到了第二位,因为 process 方法的代码量更大、语义更丰富。纯 BM25 精确命中了函数名,但对语义相关的结果完全无感。混合搜索两边都照顾到了。

踩坑记录

我在搭这套东西的过程中踩了几个坑,记一下。

坑 1:Tree-sitter 的 Python 绑定版本混乱

tree-sitter 库在 0.21 和 0.22 版本之间 API 改了一轮。0.21 用 Language.build_library() 编译语言文件,0.22+ 改成了直接从语言包导入。网上很多教程还是 0.21 的写法,照着跑会报错。确认你装的是 0.22+:

pip install "tree-sitter>=0.22" tree-sitter-python

0.22 的用法(本文用的):

from tree_sitter import Parser, Language
import tree_sitter_python as tspython
lang = Language(tspython.language())

坑 2:嵌入时代码里的缩进会影响向量质量

Python 代码的缩进是语法的一部分,但对 embedding 模型来说,同一段逻辑缩进 4 格和缩进 8 格不应该有区别。我试过去掉缩进再嵌入,发现效果反而变差——因为缩进暗含了代码的嵌套层级信息。结论:保留原始缩进,别动。

坑 3:类方法切块后丢了上下文

前面提到了,把类方法单独切出来之后,方法内的 self.xxx 引用会失去上下文。搜索"缓存"时,process 方法里的 self._cache 能命中,但如果不加类名注释,模型不知道 _cacheFileProcessor 的缓存还是别的东西的缓存。

我的做法是在每个方法块前面加一行 # class ClassName。claude-context 的做法更激进——它把类的 docstring 和字段声明也拼到每个方法块前面。这样嵌入质量更高,代价是 token 消耗增加 15-20%。

坑 4:BM25 对代码里的特殊字符处理

默认的 BM25 按空格分词,但代码里的 self._cacheos.path.join 这些标识符不会被正确分割。一个简单的改进是在分词前做一次 camelCase / snake_case 拆分:

import re

def code_tokenize(text: str) -> list[str]:
    """把代码标识符拆成独立 token"""
    # snake_case 拆分
    text = re.sub(r'_', ' ', text)
    # camelCase 拆分
    text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)
    # 特殊字符变空格
    text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)
    return text.lower().split()

用这个替换 BM25 里的 split(),对代码搜索的精度提升挺明显。

坑 5:大文件的切块粒度选择

一个 2000 行的文件可能有 50+ 个函数。全部切成独立块的话,搜索时返回的结果太碎片化。claude-context 的策略是设一个最大块大小(默认 200 行),超过的函数保持原样,不到 20 行的小函数尝试跟相邻的小函数合并。

这个阈值需要根据项目调。我在一个 Django 项目上测试,200 行上限 + 20 行合并阈值的效果最好,检索的 top-5 命中率比固定 100 行切块高了 34%。

完整的 pipeline 串起来

把前面的组件串成一个完整的代码索引和搜索工具:

import os
import json
import pickle

def index_directory(directory: str, extensions=(".py",)) -> list[dict]:
    """索引一个目录下的所有代码文件"""
    all_chunks = []
    
    for root, dirs, files in os.walk(directory):
        # 跳过常见的非代码目录
        dirs[:] = [d for d in dirs if d not in {
            ".git", "__pycache__", "node_modules", ".venv", "venv"
        }]
        
        for fname in files:
            if not any(fname.endswith(ext) for ext in extensions):
                continue
            
            fpath = os.path.join(root, fname)
            try:
                with open(fpath, "rb") as f:
                    source = f.read()
                chunks = chunk_python(source)
                rel_path = os.path.relpath(fpath, directory)
                for c in chunks:
                    c["file"] = rel_path
                all_chunks.extend(chunks)
            except Exception as e:
                print(f"跳过 {fpath}: {e}")
    
    print(f"共 {len(all_chunks)} 个代码块,开始嵌入...")
    all_chunks = embed_chunks(all_chunks)
    print("嵌入完成")
    return all_chunks

def save_index(chunks, path="code_index.pkl"):
    with open(path, "wb") as f:
        pickle.dump(chunks, f)

def load_index(path="code_index.pkl"):
    with open(path, "rb") as f:
        return pickle.load(f)

用法:

# 建索引(只需要跑一次)
chunks = index_directory("./my_project")
save_index(chunks)

# 搜索
chunks = load_index()
results = hybrid_search("数据库连接池", chunks, alpha=0.3, top_k=5)
for score, chunk in results:
    print(f"[{score:.3f}] {chunk['file']} > {chunk.get('name', 'header')}")
    print(chunk['text'][:120])
    print()

这套东西在一个 1.2 万行的 Python 项目上测试,索引时间约 45 秒(主要花在 embedding API 调用上),搜索延迟在 50ms 以内。如果用本地 embedding 模型(比如 BGE-small),索引速度能快 3-4 倍。

跟直接 grep 的对比

场景grep纯向量混合搜索
搜函数名 authenticate精确命中可能排第 2-3排第 1
搜"用户登录验证逻辑"搜不到命中 login()verify_token()命中同上
TODO: fix race condition精确命中语义偏移能命中
搜"并发安全问题"搜不到命中锁相关代码命中同上

grep 在精确匹配上无敌,但完全不理解语义。向量搜索理解语义,但对精确的标识符匹配弱。混合搜索取长补短。claude-context 默认的 alpha 是 0.4(40% BM25 + 60% 向量),我实测这个比例在大多数场景下表现不错。

可以改进的地方

这篇文章的实现是简化版,实际生产中还有几个可以优化的点:

增量索引——文件修改后只重新索引变更的文件,不用全量重建。Tree-sitter 本身支持增量解析,配合 git diff 可以做到秒级更新。

跨文件关系——当前每个文件独立切块,不考虑 import 关系。如果能把调用链上的相关函数也拉进来,搜索质量会更好。

多语言支持——Tree-sitter 支持 JavaScript、TypeScript、Go、Rust 等 14 种语言,每种语言的节点类型不一样,需要单独写切块规则。claude-context 用了一套统一的节点类型映射表来解决这个问题。

本地 embedding——用 sentence-transformers 加载 BGE-small-zh 模型可以跑在本地,不依赖 API,延迟更低。代价是向量质量比 OpenAI 的模型略低一点。

这套东西的价值不只是给 AI 编程工具用。代码审查、知识库搜索、新人 onboarding 时快速定位关键代码,都能用到。核心思路就一句话:用语法树切块保证语义完整性,用混合检索同时照顾精确匹配和语义理解。