lucene源码系列(一):HNSW实现

3,970 阅读10分钟

前提

在看本文之前,各位看官最好对HNSW的算法原理有所了解(大白话唠HNSW),这样事半功倍。

当然,如果跟本文从代码实现反推HNSW的算法思路也是可以的,但是会比较吃力。另外,本文是对lucene源码的解析,所以至少要了解lucene的一些基本概念。

注意:本文源码解析基于lucene 9.1.0,其中HNSW的实现和论文中的实现整体上是一样的,部分细节有所区别。

向量数据添加过程

代码示例

    public static void main(String[] args) throws IOException, InterruptedException {
        Directory directory = FSDirectory.open(Paths.get("D:\\my_index_knn"));
        StandardAnalyzer analyzer = new StandardAnalyzer();
        IndexWriterConfig indexWriterConfig = new IndexWriterConfig(analyzer);
        indexWriterConfig.setUseCompoundFile(false);

        IndexWriter indexWriter = new IndexWriter(directory, indexWriterConfig);

        for (int i = 1; i <= 500; i ++) {
            Document document = new Document();
            // 添加向量字段
            document.add(new KnnVectorField("vector1", TestDataGenerator.generateData(128), VectorSimilarityFunction.EUCLIDEAN));
            document.add(new KnnVectorField("vector2", TestDataGenerator.generateData(128), VectorSimilarityFunction.EUCLIDEAN));
            indexWriter.addDocument(document);

            if (i % 100 == 0) {
                indexWriter.flush();
                indexWriter.commit();
            }
        }

        indexWriter.flush();
        indexWriter.commit();

        IndexReader reader = DirectoryReader.open(indexWriter);
        IndexSearcher searcher = new IndexSearcher(reader);

        // 检索
        KnnVectorQuery knnVectorQuery = new KnnVectorQuery("vector1", TestDataGenerator.generateData(128), 10);
        TopDocs search = searcher.search(knnVectorQuery, 10);
    }

Lucene 中添加向量数据是通过KnnVectorField封装向量数据,然后作为文档的一个字段使用IndexWriter.addDocuments接口一起进行索引。

IndexWriter.addDocuments最终会调用到org.apache.lucene.index.IndexingChain#processField处理每个字段的数据,在这个方法中处理向量数据时,使用org.apache.lucene.index.VectorValuesWriter暂存向量的数据,可以简单看下VectorValuesWriter的成员变量:

// 字段的信息,向量相关的是:维度和距离度量  
private final FieldInfo fieldInfo;

// segment当前使用的内存,会配合阈值使用,是触发flush的一个条件
private final Counter iwBytesUsed;

// 存储向量数据
private final List<float[]> vectors = new ArrayList<>();

// 位图,不一定每个文档都有这个字段,用位图来记录包含这个字段的文档id
private final DocsWithFieldSet docsWithField;

使用VectorValuesWriter#addValue暂存数据:

  public void addValue(int docID, float[] vectorValue) {
    // 向量字段,每个文档只能有一个
    if (docID == lastDocID) {
      throw new IllegalArgumentException(
          "VectorValuesField \""
              + fieldInfo.name
              + "\" appears more than once in this document (only one value is allowed per field)");
    }

    。。。(省略)  

    // 记录docId
    docsWithField.add(docID);

    // 存储向量数据
    vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));

    // 更新当前段的内存数据
    updateBytesUsed();

    lastDocID = docID;
  }

可以看到,在数据添加过程中只是把向量数据存起来,并没有构建HNSW。HNSW的构建逻辑在VectorValuesWriter的flush方法中,这边先有个印象后面会详细介绍。

HNSW查找和构建的工具类

为了方便后面源码的理解,我们先熟悉下HNSW查找和构建过程中几个相关的工具类。

NeighborArray

NeighborArray是存储某个节点的邻居信息的数据结构,其中有两个数组,分别存储邻居节点编号和邻居到节点距离,两个数组元素是一一对应的。可以看到NeighborArray的结构非常简单:

  // 邻居个数
  private int size;

  // 所有邻居到节点的距离
  float[] score;

  // 所有邻居节点的编号
  int[] node;

BoundsChecker

BoundsChecker是一个边界检查工具,可以设置一个最大值或者最小值来判断是否满足上界或者下界要求。

比如在HNSW查找的迭代过程中,如果新找到的候选节点集合的中距离目标节点最近都没有比已找到的节点集合的最远距离还小,则停止迭代,这话说的可能有点绕,不要着急,后面用到的时候会有直观的理解。

public abstract class BoundsChecker {
  // 边界值
  float bound;

  // 如果边界更优则更新边界值。
  // 如果是最小值检查,sample比当前最小值小,则更新边界。
  // 如果是最大值检查,sample比当前最大值大,则更新边界。
  public abstract void update(float sample);

  // 设置边界
  public void set(float sample) {
    bound = sample;
  }

  // 边界检查
  // 如果是最小值检查,则判断sample是否大于当前的最小值边界
  // 如果是最大值检查,则判断sample是否小于当前的最大值边界 
  public abstract boolean check(float sample);

  // 根据是否逆序创建Min或者Max,Min和Max是BoundsChecker内部类  
  public static BoundsChecker create(boolean reversed) {
    if (reversed) {
      return new Min();
    } else {
      return new Max();
    }
  }

  // 最大值检查工具类
  public static class Max extends BoundsChecker {
    Max() {
      bound = Float.NEGATIVE_INFINITY;
    }

    @Override
    public void update(float sample) {
      if (sample > bound) {
        bound = sample;
      }
    }

    @Override
    public boolean check(float sample) {
      return sample < bound;
    }
  }

  // 最小值检查工具类
  public static class Min extends BoundsChecker {

    Min() {
      bound = Float.POSITIVE_INFINITY;
    }

    @Override
    public void update(float sample) {
      if (sample < bound) {
        bound = sample;
      }
    }

    @Override
    public boolean check(float sample) {
      return sample > bound;
    }
  }
}

LongHeap

有界最小堆(堆顶元素永远最小),底层是数组实现,所以最大size不能超过数组的最大长度。重点看一个方法:

  // 带溢出检查的方式插入   
  public boolean insertWithOverflow(long value) {
    // 如果已经达到堆的大小限制  
    if (size >= maxSize) {
      // 因为是最小堆,所以比堆顶小就直接返回插入失败  
      if (value < heap[1]) {
        return false;
      }
      // 走到这里说明比堆顶大,直接替换堆顶元素,然后调整堆重新成为最小堆  
      updateTop(value);
      return true;
    }
    // 没有达到堆大小限制直接插入  
    push(value);
    return true;
  }

NeighborQueue

NeighborQueue其实也是一个堆,数据存储是使用成员变量LongHeap。因为LongHeap是个最小堆,因此引入了一个内部类Order根据是否逆序做值转化,就相当于使用LongHeap也可以实现最大堆的功能。

先看NeighborQueue的内部枚举类Order:

  // 因为LongHeap是最小堆,所以如果是需要最大堆的功能,则需要做倒序转化
  private static enum Order {
    // 自然顺序:从小到大  
    NATURAL {
      @Override
      long apply(long v) {
        return v;
      }
    },
    // 倒序:从大到小  
    REVERSED {
      @Override
      long apply(long v) {
        return -1 - v;
      }
    };

    // 值转化
    // 自然顺序,不需要转化,也就是最小堆
    // 逆序,则存-1 - v,相当于是最大堆   
    abstract long apply(long v);
  }

再看下NeighborQueue的成员变量和构造函数:

  // 存储距离和节点编号的复合体,可以简单理解成是节点和目标节点的距离
  private final LongHeap heap;

  // 自然顺序还是逆序
  private final Order order;

  // 遍历过的节点
  private int visitedCount;

  // 是否是提前截断导致查找停止
  private boolean incomplete;

  // initialSize:堆的大小限制
  // reversed:是否逆序,控制是最大堆还是最小堆
  public NeighborQueue(int initialSize, boolean reversed) {
    this.heap = new LongHeap(initialSize);
    this.order = reversed ? Order.REVERSED : Order.NATURAL;
  }

简单看下NeighborQueue中几个重要的方法,其他方法比较简单,可以直接翻源码:

  // 添加新的邻居节点及其距离,经过编码之后插入堆中
  public void add(int newNode, float newScore) {
    heap.push(encode(newNode, newScore));
  }

  // 带溢出检查的方式添加新的邻居节点及其距离,编码之后使用堆的insertWithOverflow方法插入
  public boolean insertWithOverflow(int newNode, float newScore) {
    return heap.insertWithOverflow(encode(newNode, newScore));
  }

  // 高32位存储的是距离,低32位存的是节点编号
  private long encode(int node, float score) {
    return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
  }

  // 删除堆顶元素,返回的是节点编号,直接long转int就行,低32位就是编号
  public int pop() {
    return (int) order.apply(heap.pop());
  }

  // 返回堆顶节点的编号
  public int topNode() {
    return (int) order.apply(heap.top());
  }

  // 返回堆顶节点的距离
  public float topScore() {
    return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
  }

HnswGraph

HnswGraph是HNSW图的数据结构,最主要信息是HNSW的层次信息,层中的节点信息,层中节点的邻居信息。抽象类HnswGraph主要描述的是HNSW数据结构应该可以提供哪些信息,具体如下:

public abstract class HnswGraph {
  // 定位到target节点的邻居的存储位置,然后可以调用nextNeighbor遍历所有的邻居
  public abstract void seek(int level, int target) throws IOException;

  // HNSW中的节点总数,其实也是对底层的节点总数,因为最底层包含了所有的节点
  public abstract int size();

  // 获取邻居节点,如果遍历结束或者没有邻居返回NO_MORE_DOCS
  public abstract int nextNeighbor() throws IOException;

  // 返回HNSW的层数
  public abstract int numLevels() throws IOException;

  // 搜索时,位于顶层的起始遍历节点,只有一个起始节点entry point
  public abstract int entryNode() throws IOException;

  // 指定层的节点迭代器,可以通过迭代器获取某一层的所有节点
  public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
}

HnswGraph有两个实现类,OnHeapHnswGraph和OffHeapHnswGraph。构建的时候使用的OnHeapHnswGraph,持久化生成索引文件之后如果重新加载会重新使用OffHeapHnswGraph打开。从类名也可以看出这两个类使用的内存分别是是堆外和堆内,因为OnHeapHnswGraph在构建过程中需要添加新节点,所以OnHeapHnswGraph中额外多一个方法addNode。下面详细介绍OnHeapHnswGraph,OffHeapHnswGraph大家可以自行翻阅源码。

先来看下OnHeapHnswGraph的成员变量:

  // 每个节点最大的邻居个数
  private final int maxConn;

  // 总层数
  private int numLevels; 

  // 检索的初始节点,在最顶层。在论文中表示为entry point
  private int entryNode; 

  // 每一层的节点
  // 默认第0层是所有的节点,因此不需要存储数据,nodesByLevel.get(0) == null
  private final List<int[]> nodesByLevel;

  // 每一层节点的邻居
  private final List<List<NeighborArray>> graph;

  // 迭代遍历邻居用的
  private int upto;
  private NeighborArray cur;

OnHeapHnswGraph构造函数:

 // levelOfFirstNode是随机生成的,代表的是初始的层编号,注意层编号是从0开始的。
 OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
    this.maxConn = maxConn;
    this.numLevels = levelOfFirstNode + 1;
    this.graph = new ArrayList<>(numLevels);
    // 默认是第一个向量作为顶层的节点,也是HNSW构建的起始点 
    this.entryNode = 0;
    // 为每一层初始化邻居列表 
    for (int i = 0; i < numLevels; i++) {
      graph.add(new ArrayList<>());
      // 根据经验值初始化列表大小
      graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
    }

    this.nodesByLevel = new ArrayList<>(numLevels);
    // 第0层直接设置为null,代表所有的节点
    nodesByLevel.add(null); 
    // 每一层都加上初始节点,因为最顶层的节点肯定在所有层中都存在
    for (int l = 1; l < numLevels; l++) {
      nodesByLevel.add(new int[] {0});
    }
  }

OnHeapHnswGraph中实现HnswGraph的所有方法都比较简单,大家可以自行翻阅源码。这里我们重点看下构建过程新增节点的逻辑:

  // level:要加入哪一层,是新加入节点通过随机函数产生的
  // node:加入的节点编码
  public void addNode(int level, int node) {
    // 因为第0层默认是所有节点都存在的,所以只记录0层以上的
    if (level > 0) {
      // 如果要加入的层大于目前的最高层
      if (level >= numLevels) {
        // 超出目前最高层的每一层都加上该节点  
        for (int i = numLevels; i <= level; i++) {
          graph.add(new ArrayList<>());
          nodesByLevel.add(new int[] {node});
        }
        // 更新层数,层编号是从0开始,所以加1
        numLevels = level + 1;
        // 更新初始节点,这是唯一的更新入口,由此可见 entryNode 是最高层的第一个节点
        entryNode = node;
      } else {
        // 当前层中加入节点,如果超出数组大小,有扩容处理
        int[] nodes = nodesByLevel.get(level);
        int idx = graph.get(level).size();
        if (idx < nodes.length) {
          nodes[idx] = node;
        } else {
          nodes = ArrayUtil.grow(nodes);
          nodes[idx] = node;
          nodesByLevel.set(level, nodes);
        }
      }
    }

    // 新加入的节点无论在那一层,都为这一层初始化一个邻居容器  
    graph.get(level).add(new NeighborArray(maxConn + 1));
  }

HNSW查找

因为HNSW的构建过程其实依赖了查找实现,所以在正式介绍HNSW构建之前,先看下怎么查找目标节点的最近邻。

说明:因为查找过程涉及距离的对比,不同的距离度量方式表示节点相近是不同的,比如欧式距离,距离越近值越小,而consine距离越近,值越大。为了方便起见,后面源码描述统一按照欧式距离来描述。

查找的逻辑全部在org.apache.lucene.util.hnsw.HnswGraphSearcher中实现。我们先看下成员变量,比较简单:

// 距离度量方式
private final VectorSimilarityFunction similarityFunction;

// 最近邻的候选者。以欧式距离为例,则是一个最小堆,队顶元素是距离目标节点最近的节点
private final NeighborQueue candidates;

// 用来标记访问过的节点位图:因为邻居的邻居可能也是自己的邻居,相同节点不需要重复访问
private final BitSet visited;

HnswGraphSearcher中只有三个核心的方法,都是用来搜索目标节点的topK近邻节点,一个是提供在整个HNSW中检索目标节点的topK近邻节点使用的,另外两个是在指定层中查找目标节点的topK近邻节点,区别是是否支持截断。

在指定层带截断的查找

在看具体实现之前,需要先弄清楚两个堆,以欧式距离为例:

  • candidates是个最小堆,用来存储查找过程中候选的节点,堆顶节点是距离目标节点最近的节点。
  • results是个最大堆,存储的是目前为止找到的最近邻的节点,堆顶节点是距离目标节点最远的节点。
private NeighborQueue searchLevel(
    float[] query,  // 要查找的目标向量
    int topK, 
    int level,      // 在那一层查找
    final int[] eps, // 起始遍历节点的编号列表
    RandomAccessVectorValues vectors, // 待检索的向量集合
    HnswGraph graph, 
    Bits acceptOrds,// 相当于是向量编号白名单,用来过滤查找结果。lucene中的段是不可变的,如果段中的数据被删除,
                    // 真正的数据不会被删除,而是用位图记录起来,acceptOrds相当于是记录存活的有效向量id

    int visitedLimit) // 正常是访问迭代到最近邻,这个参数可以控制提前停止迭代,当然找到的一般不是最优结果
    throws IOException {
  int size = graph.size();
  // 最终的topK结果,注意是带大小限制的堆,以欧式距离为例,是个大顶堆
  NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
  // 初始化,清空 HnswGraphSearcher 的两个成员变量,候选者(candidates) 和已经访问过的节点位图(visited)
  clearScratchState();

  // 记录当前访问过的节点个数,用来和visitedLimit判断是否截断,停止搜索
  int numVisited = 0;
  // 遍历起始节点,获取初始的候选节点和结果集  
  for (int ep : eps) {
    // 如果没有访问过  
    if (visited.getAndSet(ep) == false) {
      // 是否要截断
      if (numVisited >= visitedLimit) {
        // 截断标记  
        results.markIncomplete();
        break;
      }
      // 计算entry point和目标节点的距离  
      float score = similarityFunction.compare(query, vectors.vectorValue(ep));
      numVisited++;
      // 加入候选者堆中  
      candidates.add(ep, score);
      if (acceptOrds == null || acceptOrds.get(ep)) {
        // 如果向量在白名单中,则加入最终结果的堆中  
        results.add(ep, score);
      }
    }
  }

  // 以下的流程就是判断候选结果中的最近节点的距离是不是都比结果集中的大,如果是,说明已经找到了最近邻topK,则返回结果集,否则继续查找候选集中的节点的邻居,不断迭代。

  // 以欧式距离为例,设置一个最小值边界检查,如果待检查的值大于这个边界,返回true。
  // 把当前结果集中的最大距离设置为边界
  BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
  if (results.size() >= topK) {
    bound.set(results.topScore());
  }
  // 遍历候选节点堆(最小堆),也就是遍历是从近到远,直到没有候选者或者截断发生
  while (candidates.size() > 0 && results.incomplete() == false) {
    // 获取堆顶节点,是候选列表中的最近节点
    float topCandidateScore = candidates.topScore();
    // 如果候选列表中的距离最近节点都没有满足边界要求,则结束迭代  
    if (bound.check(topCandidateScore)) {
      break;
    }

    int topCandidateNode = candidates.pop();
    // 定位到候选节点的邻居  
    graph.seek(level, topCandidateNode);
    int friendOrd;
    // 遍历候选节点的邻居  
    while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
      // 访问过则忽略  
      if (visited.getAndSet(friendOrd)) {
        continue;
      }

      // 截断发生
      if (numVisited >= visitedLimit) {
        results.markIncomplete();
        break;
      }

      float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
      numVisited++;
      // 如果节点比结果集中的堆顶节点要近。
      // 这个条件可能造成最后结果没有topK,
      // 我觉得是个bug,应该是bound.check(score) == false || results.size() < topK
      if (bound.check(score) == false) {
        // 当前遍历节点加入候选列表  
        candidates.add(friendOrd, score);
        if (acceptOrds == null || acceptOrds.get(friendOrd)) {
          if (results.insertWithOverflow(friendOrd, score) && results.size() >= topK) {
            // 如果超出结果集的大小限制,说明堆顶节点被替换,则重新设置堆顶节点的距离作为边界  
            bound.set(results.topScore());
          }
        }
      }
    }
  }
  // 前面从eps初始化的时候,可能超过topK
  while (results.size() > topK) {
    results.pop();
  }
  results.setVisitedCount(numVisited);
  return results;
}

在指定层不带截断的查找

比较简单,调用的是带截断的指定层查找方法,传入的截断参数是integer的最大值。

  NeighborQueue searchLevel(
      float[] query,
      int topK,
      int level,
      final int[] eps,
      RandomAccessVectorValues vectors,
      HnswGraph graph)
      throws IOException {
    return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
  }

HNSW全局检索使用的接口

HNSW全局检索是从最顶层的entry point开始逐层查找,逐层往下在每一层中寻找最近邻节点作为下一层的entry point,最后在最底层寻找topK个最近邻节点,具体一起来看代码更清楚:

  public static NeighborQueue search(
      float[] query,  // 待查找的目标向量
      int topK, 
      RandomAccessVectorValues vectors,  // 目标向量集合
      VectorSimilarityFunction similarityFunction, 
      HnswGraph graph, // HNSW图,主要是层次信息和节点的邻居信息
      Bits acceptOrds, // 相当于是向量编号白名单,用来过滤查找结果。lucene中的段是不可变的,如果段中的数据被删除,
                       // 真正的数据不会被删除,而是用位图记录起来,acceptOrds相当于是记录存活的有效向量id

      int visitedLimit) // 正常是访问迭代到最近邻,这个参数可以控制提前停止迭代,当然找到的一般不是最优结果
      throws IOException {
    // 因为 graphSearcher 不是线程安全的,所以使用局部变量
    HnswGraphSearcher graphSearcher =
        new HnswGraphSearcher(
            similarityFunction,
            new NeighborQueue(topK, similarityFunction.reversed == false),
            new SparseFixedBitSet(vectors.size()));
    // 存储最后的结果  
    NeighborQueue results;
    // 起始遍历的节点为最顶层的entry point 
    int[] eps = new int[] {graph.entryNode()};
    int numVisited = 0;
    // 从最顶层开始遍历直到倒数第二层,每一层中找到一个最近邻节点作为下一层的起始点
    // 这是一个快速靠近目标节点的过程
    for (int level = graph.numLevels() - 1; level >= 1; level--) {
      // 调用了带截断的在指定层查找方法
      results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
      eps[0] = results.pop();

      numVisited += results.visitedCount();
      visitedLimit -= results.visitedCount();
    }
    // 最后一层包含了所有的节点,从最后一层中找到真正的最近邻topK  
    results =
        graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
    results.setVisitedCount(results.visitedCount() + numVisited);
    return results;
  }

HNSW构建

触发构建时机

HNSW的构建触发是在segment生成的是时候。segment的生成可以是由IndexWriter.flush手动触发,也可以IndexWriter.addDocuments调用过程中满足某些条件自动触发。

构建的入口

segment的flush最终会走到org.apache.lucene.index.IndexingChain#flush,在其中会调用org.apache.lucene.index.IndexingChain#writeVectors方法进行HNSW的构建和落盘:

private void writeVectors(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
    // KnnVectorsWriter是个抽象类,在lucene9.1.0中的实现是Lucene91HnswVectorsWriter。
    KnnVectorsWriter knnVectorsWriter = null;
    boolean success = false;
    try {
        for (int i = 0; i < fieldHash.length; i++) {
            PerField perField = fieldHash[i];
            while (perField != null) {

                    。。。(一些判断和knnVectorsWriter的初始化)
                    // 真正执行flush的逻辑
                    perField.vectorValuesWriter.flush(sortMap, knnVectorsWriter);
                    perField.vectorValuesWriter = null;

                    。。。(一些判断)
            }
        }
        if (knnVectorsWriter != null) {
            // 为相关索引文件添加注脚,主要是校验码
            knnVectorsWriter.finish();
        }
        success = true;
    } finally {
        if (success) {
            IOUtils.close(knnVectorsWriter);
        } else {
            IOUtils.closeWhileHandlingException(knnVectorsWriter);
        }
    }
}

上面方法主要逻辑是在初始化Lucene91HnswVectorsWriter(KnnVectorWriter的实现类),然后将Lucene91HnswVectorsWriter传给

org.apache.lucene.index.VectorValuesWriter#flush继续执行,我们看下VectorValuesWriter#flush的具体逻辑:

  public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
    // KnnVectorsReader 中最主要的方法是getVectorValues获取所有待构建的向量数据
    // getVectorValues  返回的也是封装了向量数据,维度等信息的BufferedVectorValues
    KnnVectorsReader knnVectorsReader =
        new KnnVectorsReader() {
          @Override
          public long ramBytesUsed() {
            return 0;
          }

          @Override
          public void close() throws IOException {
            throw new UnsupportedOperationException();
          }

          @Override
          public void checkIntegrity() throws IOException {
            throw new UnsupportedOperationException();
          }

          @Override
          public VectorValues getVectorValues(String field) throws IOException {
            VectorValues vectorValues =
                new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
            return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
          }

          @Override
          public TopDocs search(
              String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
              throws IOException {
            throw new UnsupportedOperationException();
          }
        };

    // 核心逻辑  
    knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
  }

在介绍knnVectorsWriter.writeField之前,先介绍下Lucene91HnswVectorsWriter,先看下成员变量:

  // 持久化的相关文件的输出流
  // meta:向量元信息文件输出流
  // vectorData: 向量数据文件输出流
  // vectorIndex: 向量索引文件输出流(存储邻居信息)
  private final IndexOutput meta, vectorData, vectorIndex;

  // segment中的文档总个数
  private final int maxDoc;

  // 每个节点邻居的上限
  private final int maxConn;

  // 从第0层到Math.min(nodeLevel, curMaxLevel),在每一层中查询最近邻的候选个数
  private final int beamWidth;

  // 是否成功构建完成
  private boolean finished;

好了,现在来看最底层的核心逻辑:

org.apache.lucene.codecs.lucene91.Lucene91HnswVectorsWriter#writeField

  public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
      throws IOException {
    // 字节对齐(https://www.thinbug.com/q/47510783)
    long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
    // 前面提到的BufferedVectorValues,保存了向量信息
    VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);

    // 临时文件,用来写所有的向量数据,如果hnsw构建失败,则数据还在。
    IndexOutput tempVectorData =
        segmentWriteState.directory.createTempOutput(
            vectorData.getName(), "temp", segmentWriteState.context);
    IndexInput vectorDataInput = null;
    boolean success = false;
    try {
      // 向量数据写入临时文件
      DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
      CodecUtil.writeFooter(tempVectorData);
      IOUtils.close(tempVectorData);

      // 将临时文件中的数据拷贝到真正的segment中的向量数据文件
      vectorDataInput =
          segmentWriteState.directory.openInput(
              tempVectorData.getName(), segmentWriteState.context);
      vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
      CodecUtil.retrieveChecksum(vectorDataInput);
      long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
      long vectorIndexOffset = vectorIndex.getFilePointer();

      // 又把向量数据重新封装到OffHeapVectorValues中
      Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
          new Lucene91HnswVectorsReader.OffHeapVectorValues(
              vectors.dimension(), docsWithField.cardinality(), null, vectorDataInput);

      OnHeapHnswGraph graph =
          offHeapVectors.size() == 0
              ? null
                // 写入HNSW图结构,其中包含了构建流程
              : writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
      long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
      // 写向量元信息  
      writeMeta(
          fieldInfo,
          vectorDataOffset,
          vectorDataLength,
          vectorIndexOffset,
          vectorIndexLength,
          docsWithField,
          graph);
      success = true;
    } finally {
      IOUtils.close(vectorDataInput);
      if (success) {
        segmentWriteState.directory.deleteFile(tempVectorData.getName());
      } else {
        IOUtils.closeWhileHandlingException(tempVectorData);
        IOUtils.deleteFilesIgnoringExceptions(
            segmentWriteState.directory, tempVectorData.getName());
      }
    }
  }

writeGraph的逻辑是先构建HNSW,然后再把图结构持久化落盘,层层嵌套,终于要到关键的构建逻辑了:

  private OnHeapHnswGraph writeGraph(
      RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
      throws IOException {

    // HNSW的构建器
    HnswGraphBuilder hnswGraphBuilder =
        new HnswGraphBuilder(
            vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
    hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
    // 构建HNSW  
    OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());

    // 构建好的HNSW写入文件
    int countOnLevel0 = graph.size();
    for (int level = 0; level < graph.numLevels(); level++) {
      NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
      while (nodesOnLevel.hasNext()) {
        int node = nodesOnLevel.nextInt();
        NeighborArray neighbors = graph.getNeighbors(level, node);
        int size = neighbors.size();
        vectorIndex.writeInt(size);
        // Destructively modify; it's ok we are discarding it after this
        int[] nnodes = neighbors.node();
        Arrays.sort(nnodes, 0, size);
        for (int i = 0; i < size; i++) {
          int nnode = nnodes[i];
          assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
          vectorIndex.writeInt(nnode);
        }
        // if number of connections < maxConn, add bogus values up to maxConn to have predictable
        // offsets
        for (int i = size; i < maxConn; i++) {
          vectorIndex.writeInt(0);
        }
      }
    }
    return graph;
  }

构建流程

HNSW的构建核心逻辑都在HnswGraphBuilder中,我们先看下HnswGraphBuilder的成员变量:

  // 每个节点的最大邻居个数
  private final int maxConn;

  // 构建查找节点邻居时,最多的候选邻居个数
  // 然后再从这些候选邻居中根据启发式选择算法(HNSW论文中)选择出maxConn个真正的邻居
  private final int beamWidth;

  // 新增节点时需要一个随机函数为其生成该节点可以到达的最高层,ml是这个函数的标准化参数
  private final double ml;

  // 候选的邻居节点会按照距离由远到近排序存储在scratch,然后执行启发式选择算法从中选择真正的邻居
  private final NeighborArray scratch;

  // 距离度量
  private final VectorSimilarityFunction similarityFunction;

  // 在查找邻居的时候使用vectorValues获取指定的向量
  private final RandomAccessVectorValues vectorValues;

  // 在执行启发式邻居选择算法时通过buildVectors获取指定的向量。
  // buildVectors底层的向量数据跟vectorValues是一样的。
  // 虽然RandomAccessVectorValues不是线程安全的,
  // 但是当前的构建是单线程的,所以应该可以和vectorValues公用一个,难道是为了多线程做铺垫?
  private RandomAccessVectorValues buildVectors;

  // 随机生成层高的函数会用到
  private final SplittableRandom random;

  // 启发式选择算法用到
  private final BoundsChecker bound;

  // 构建的过程需要为新加入的节点查找近邻候选者
  private final HnswGraphSearcher graphSearcher;

  // 当前正在构建的HNSW图
  final OnHeapHnswGraph hnsw;

再看下HnswGraphBuilder的构造函数:

  public HnswGraphBuilder(
      RandomAccessVectorValuesProducer vectors,
      VectorSimilarityFunction similarityFunction,
      int maxConn,
      int beamWidth,
      long seed) {
    // 可见vectorValues和buildVectors底层是同一份向量数据  
    vectorValues = vectors.randomAccess();
    buildVectors = vectors.randomAccess();
    this.similarityFunction = Objects.requireNonNull(similarityFunction);
    this.maxConn = maxConn;
    this.beamWidth = beamWidth;
    // 标准化参数和maxConn有关
    this.ml = 1 / Math.log(1.0 * maxConn);
    this.random = new SplittableRandom(seed);
    // 前面介绍OnHeapHnswGraph的时候说过,levelOfFirstNode是随机生成的初始层号。
    int levelOfFirstNode = getRandomGraphLevel(ml, random);
    // 初始化HNSW结构,等待加入节点  
    this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
    this.graphSearcher =
        new HnswGraphSearcher(
            similarityFunction,
            new NeighborQueue(beamWidth, similarityFunction.reversed == false),
            new FixedBitSet(vectorValues.size()));
    bound = BoundsChecker.create(similarityFunction.reversed);
    scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
  }

终于可以进入正题了,来看看构建的具体逻辑,入口在build方法中:

  public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
    // vectors 和 vectorValues 底层也是同一份向量数据
    // 之所以要区分,我理解是为了顺序IO。vectors是用来遍历所有的向量数据进行构建的,从头往后遍历,
    // 而vectorValues是随机访问
    if (vectors == vectorValues) {
      throw new IllegalArgumentException(
          "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
    }

    // 第0个节点默认就是第一个向量,在初始化HNSW的时候已经默认构建好了第一个向量,还记得那个levelOfFirstNode吗?
    // 开始遍历节点,添加进HNSW(添加的意思是加入节点信息,也加入节点的邻居信息)
    for (int node = 1; node < vectors.size(); node++) {
      // 添加当前遍历的节点  
      addGraphNode(node, vectors.vectorValue(node));
    }
    return hnsw;
  }

在理解addGraphNode方法之前,先理解下随机生成层的可能情况,当前层的范围是[0, curMaxLevel] ,新加入的节点会随机生成它可以加入的最高层nodeLevel,因此nodeLevel可能的值有三种可能,添加的过程也是分别处理这三种情况:

  • nodeLevel > curMaxLevel,在大于curMaxLevel的每一层中直接加入当前处理的节点,因为只有它自己,没有其他节点。
  • 0 < nodeLevel < curMaxLevel,在(nodeLevel, curMaxLevel]的每一层,从高层到低层依次寻找最近的一个节点,起始节点是上一层的最近邻,最顶层的起始节点是HNSW的entry point。
  • 在[0, min(nodeLevel, curMaxLevel)]中的每一层寻找最近的beamWidth个最近邻,然后从中由启发式选择算法挑选maxCnn个作为真正的邻居
void addGraphNode(int node, float[] value) throws IOException {
    NeighborQueue candidates;
    // 获取当前处理节点的最高层号,随机产生,
    final int nodeLevel = getRandomGraphLevel(ml, random);
    // 当前HNSW的最高层号,层编号是从0开始的
    int curMaxLevel = hnsw.numLevels() - 1;
    // 起始遍历的节点,初始化为最顶层的entry point
    int[] eps = new int[] {hnsw.entryNode()};

    // 如果待加入节点的nodeLevel大于当前最高层,则在超出最高层的每一层中加入这个节点
    for (int level = nodeLevel; level > curMaxLevel; level--) {
      hnsw.addNode(level, node);
    }

    // 如果待加入节点的nodeLevel小于当前最高层,则从最高层开始在超出nodeLevel的每一层找一个最近点,然后作为下一层
    // 的entry point,继续往下迭代
    for (int level = curMaxLevel; level > nodeLevel; level--) {
      candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
      eps = new int[] {candidates.pop()};
    }

    // 在[0, min(nodeLevel, curMaxLevel)]中的每一层寻找最近的beamWidth个最近邻,
    // 然后从中由启发式选择算法挑选maxCnn个作为真正的邻居
    for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
      // 在当前层中寻找beamWidth个候选邻居节点
      candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
      // 当前层找到的候选邻居节点作为下一层的entry point  
      eps = candidates.nodes();
      // 把节点加入当前层,并生成一个空的邻居列表,在下面addDiverseNeighbors中会填充这个空的邻居列表
      hnsw.addNode(level, node);
      // 用启发式选择算法从候选的邻居中选择maxConn个邻居
      addDiverseNeighbors(level, node, candidates);
    }
  }

在上面方法中,其中一个核心步骤是寻找当前hnsw中的最近邻,这个在HSWN的查找中已经介绍了。我们来看另一个核心步骤添加邻居的操作。

启发式选择算法

  private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
      throws IOException {
    // 获取当前节点空的邻居列表  
    NeighborArray neighbors = hnsw.getNeighbors(level, node);
    // 对当前的候选者按照距离由远到近排序,暂存进scratch,底层是个数组,这个方法简单就不展开了
    popToScratch(candidates);
    // 按照启发式算法选择最多样(视觉效果是最发散)的邻居(多样的度量后面介绍)存入neighbors中
    selectDiverse(neighbors, scratch);

    // 邻居是互为邻居,所以也需要为邻居节点把当前节点作为邻居
    int size = neighbors.size();
    for (int i = 0; i < size; i++) {
      int nbr = neighbors.node[i];
      NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
      // 把当前节点作为邻居节点的邻居  
      nbrNbr.add(node, neighbors.score[i]);
      // 因为当前节点的邻居可能因为增加当前节点这个邻居导致邻居总数超过maxConn,
      // 所以需要调整已有的节点的邻居,满足最多邻居个数不超过maxConn
      if (nbrNbr.size() > maxConn) {
        // 如果超出了最大邻居个数,则按照多样性去掉一个 
        diversityUpdate(nbrNbr);
      }
    }
  }

再来详细看看怎么根据启发式选择算法选择多样性的邻居:

  private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
    // 从距离最近的候选节点开始遍历
    for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
      int cNode = candidates.node[i];
      float cScore = candidates.score[i]; 
      assert cNode < hnsw.size();
      // 如果满足多样性的检查,就加入结果集  
      if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
        neighbors.add(cNode, cScore);
      }
    }
  }

多样性的检查逻辑:

  private boolean diversityCheck(
      float[] candidate,
      float score, // 候选节点和当前加入节点的距离
      NeighborArray neighbors,
      RandomAccessVectorValues vectorValues)
      throws IOException {
    // 欧式距离为例  
    // 如果待校验的值比score(候选节点和当前加入节点的距离)小,则bound.check为false
    bound.set(score);
    // 遍历当前已经选出的邻居  
    for (int i = 0; i < neighbors.size(); i++) {
      // 当前候选邻居和其他邻居之间的距离  
      float diversityCheck =
          similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
      // 当前候选邻居和其他邻居之间的距离有一个比当前候选邻居和节点之间的距离小,则该候选邻居多样性不足,不能成为邻居。
      // 也就是邻居之间的距离要足够远,可以想象到要找的邻居是遍布四周的,而不是集中在一堆,
      // 这样查找的时候可以快速从该节点通过邻居出发遍历不同方向的节点  
      if (bound.check(diversityCheck) == false) {
        return false;
      }
    }
    return true;
  }

寻找多样性违规的邻居:

  private int findNonDiverse(NeighborArray neighbors) throws IOException {
    // 遍历寻找第一个多样性违规的邻居
    for (int i = neighbors.size() - 1; i >= 0; i--) {
      int nbrNode = neighbors.node[i];
      bound.set(neighbors.score[i]);
      float[] nbrVector = vectorValues.vectorValue(nbrNode);
      // 不用担心越界问题,能走到这里肯定是maxConn + 1个邻居  
      for (int j = maxConn; j > i; j--) {
        float diversityCheck =
            similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
        // 如果节点i和节点j离的太近了,则返回该ndoe编号,用来删除  
        if (bound.check(diversityCheck) == false) {
          return i;
        }
      }
    }
    // 不存在多样性违规的节点,则返回-1  
    return -1;
  }

多样性去除超出邻居个数限制的逻辑:

  private void diversityUpdate(NeighborArray neighbors) throws IOException {
    assert neighbors.size() == maxConn + 1;
    // 寻找多样性违规的邻居  
    int replacePoint = findNonDiverse(neighbors);
    // 如果没有多样性违规的节点  
    if (replacePoint == -1) {
      // 如果新加入的邻居距离比第1个邻居远,则直接删除新加入的邻居,感觉有点粗暴
      bound.set(neighbors.score[0]);
      if (bound.check(neighbors.score[maxConn])) {
        neighbors.removeLast();
        return;
      } else {
        replacePoint = 0;
      }
    }
    // 用新加入的邻居替换可以替换的邻居  
    neighbors.node[replacePoint] = neighbors.node[maxConn];
    neighbors.score[replacePoint] = neighbors.score[maxConn];
    neighbors.removeLast();
  }

尾声

与论文标准算法的区别

论文中,启发式邻居选择算法如果没有找到足够的邻居,则会按照最近邻补足邻居个数,而lucene中没有做补足操作。

可以讨论的问题

  • 数据删除导致的搜索问题

    因为lucene的删除是标记删除,目前的实现是把存活的白名单传给检索接口,在检索的过程中直接过滤。

    直接考虑最极端的情况,某一层的entry point的邻居都被删除了,则无法找到目标的最近邻。部分删除也会影响召回。

    有人说是否可以用merge清理删除的数据,但是HNSW的merge操作就是由多个段的数据重新构建,相当耗时。

    另一种做法是检索完之后再做过滤,这样不会影响召回,但是过滤之后可能会不满足topK。

    不知道是否还有更优的解决方案?

最后

如有疏漏,欢迎指正讨论。