How KnnFloatVectorQuery Works in Lucene


Preface

  • This article analyzes how KnnFloatVectorQuery executes a kNN search in Lucene, covering both queries without a filter and queries with a filter.
  • The analysis is based on Lucene 9.9.0.

Query without a filter

Code example

For a simple demo, index 10,000 128-dimensional vectors and then run a kNN search:

// Build the vector index (Directory / IndexWriterConfig setup omitted)
IndexWriter indexWriter = new IndexWriter(directory, config);

int count = 10000;
int dim = 128;
List<Document> docs = new ArrayList<>();
for (int i = 0; i < count; i++) {
    Document doc = new Document();
    doc.add(new KeywordField("id", Integer.toString(i), Field.Store.YES));
    doc.add(new KnnFloatVectorField("fvecs", generateFVector(dim)));
    docs.add(doc);
}
indexWriter.addDocuments(docs);
indexWriter.commit();

// kNN search
Directory readDirectory = FSDirectory.open(Path.of("data/lucene_knn_demo"));
IndexReader indexReader = DirectoryReader.open(readDirectory);

IndexSearcher indexSearcher = new IndexSearcher(indexReader);

float[] queryVector = generateFVector(128);

int k = 3;

TopDocs topDocs = indexSearcher.search(new KnnFloatVectorQuery("fvecs", queryVector, k), k);

for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
    System.out.println("score doc " + scoreDoc.doc + ", score: " + scoreDoc.score);
}
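
The generateFVector helper in the snippets above is not a Lucene API; it only needs to produce a float[] of the requested dimension. A minimal sketch of such a helper (random, L2-normalized values; any vector source works):

import java.util.Random;

// Hypothetical helper assumed by the examples above: a random, L2-normalized vector.
private static float[] generateFVector(int dim) {
    Random random = new Random();
    float[] vector = new float[dim];
    double norm = 0;
    for (int i = 0; i < dim; i++) {
        vector[i] = random.nextFloat();
        norm += vector[i] * vector[i];
    }
    // normalize so that dot-product / cosine similarities stay well behaved
    norm = Math.sqrt(norm);
    for (int i = 0; i < dim; i++) {
        vector[i] /= (float) norm;
    }
    return vector;
}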

How the query works

Overall flow

(Figure: overall flow of a KnnFloatVectorQuery search)1

The detailed call stack: lucene_knnfloatvectorquery_search.svg

Key code walkthrough

1. After rewrite completes, the KnnFloatVectorQuery has been turned into a DocAndScoreQuery. First, let's look at how DocAndScoreQuery is defined:
// AbstractKnnVectorQuery 
static class DocAndScoreQuery extends Query {

    private final int k;
    private final int[] docs;
    private final float[] scores;
    // the role of segmentStarts is explained in the createWeight section below
    private final int[] segmentStarts;
    private final Object contextIdentity;
}

As you can see, DocAndScoreQuery already stores the matching docIds and their scores. The detailed call stack above also shows that the kNN search itself happens during rewrite. Since this article focuses on how KnnFloatVectorQuery performs that kNN search, the key method to look at is Lucene99HnswVectorsReader#search.

Lucene99HnswVectorsReader#search
 public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
      throws IOException {
    FieldEntry fieldEntry = fields.get(field);

    if (fieldEntry.size() == 0
        || knnCollector.k() == 0
        || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
      return;
    }
    final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
    final KnnCollector collector =
        new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
    final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
    if (knnCollector.k() < scorer.maxOrd()) {
      // getGraph returns an OffHeapHnswGraph instance
      HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
    } else {
      // if k is larger than the number of vectors, we can just iterate over all vectors
      // and collect them
      for (int i = 0; i < scorer.maxOrd(); i++) {
        if (acceptedOrds == null || acceptedOrds.get(i)) {
          knnCollector.incVisitedCount(1);
          knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
        }
      }
    }
  }

  private HnswGraph getGraph(FieldEntry entry) throws IOException {
    return new OffHeapHnswGraph(entry, vectorIndex);
  }
  
HnswGraphSearcher#search
public static void search(
    RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds)
    throws IOException {
  HnswGraphSearcher graphSearcher =
      new HnswGraphSearcher(
          new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
  search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
}

// the core of the search
private static void search(
    RandomVectorScorer scorer,
    KnnCollector knnCollector,
    HnswGraph graph,
    HnswGraphSearcher graphSearcher,
    Bits acceptOrds)
    throws IOException {
  int initialEp = graph.entryNode();
  if (initialEp == -1) {
    return;
  }
  // descend the upper levels to find the best entry point for level 0
  int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
  int numVisited = epAndVisited[1];
  int ep = epAndVisited[0];
  if (ep == -1) {
    knnCollector.incVisitedCount(numVisited);
    return;
  }
  knnCollector.incVisitedCount(numVisited);
  // perform the final search on level 0
  graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
}
// graphSearcher.findBestEntryPoint
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
    throws IOException {
  int size = getGraphSize(graph);
  int visitedCount = 1;
  prepareScratchState(size);
  int currentEp = graph.entryNode();
  float currentScore = scorer.score(currentEp);
  boolean foundBetter;

  // start at the highest level and work down to level 1
  for (int level = graph.numLevels() - 1; level >= 1; level--) {
    foundBetter = true;
    visited.set(currentEp);
    // Keep searching the given level until we stop finding a better candidate entry point
    while (foundBetter) {
      foundBetter = false;
      // the per-level search: load currentEp's neighbor list on this level
      graphSeek(graph, level, currentEp);
      int friendOrd;
      // iterate over all of currentEp's neighbors at this level
      while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
        assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
        if (visited.getAndSet(friendOrd)) {
          continue;
        }
        if (visitedCount >= visitLimit) {
          return new int[] {-1, visitedCount};
        }
        float friendSimilarity = scorer.score(friendOrd);
        visitedCount++;
        if (friendSimilarity > currentScore) {
          currentScore = friendSimilarity;
          currentEp = friendOrd;
          foundBetter = true;
        }
      }
    }
  }
  return new int[] {currentEp, visitedCount};
}

// the core of graphSeek: OffHeapHnswGraph#seek in Lucene99HnswVectorsReader
public void seek(int level, int targetOrd) throws IOException {
  int targetIndex =
      level == 0
          ? targetOrd
          : Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
  assert targetIndex >= 0;
  // unsafe; no bounds checking
  // node offsets are stored level by level, i.e. a 2-D layout flattened into 1-D, hence the index targetIndex + graphLevelNodeIndexOffsets[level]
  dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
  arcCount = dataIn.readVInt();
  if (arcCount > 0) {
    currentNeighborsBuffer[0] = dataIn.readVInt();
    for (int i = 1; i < arcCount; i++) {
      // neighbors were written delta-encoded, so accumulate the deltas to recover the ordinals
      currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + dataIn.readVInt();
    }
  }
  arc = -1;
  arcUpTo = 0;
}
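
A quick illustration of the delta decoding at the end of seek: if the neighbor list was written as the VInts 5, 3, 7, the decoded neighbor ordinals are 5, 8 and 15. A standalone sketch of that running-sum decoding:

// Standalone sketch of the decoding done in seek(): the stored values are gaps
// between consecutive neighbor ordinals, so decoding is a running sum.
static int[] decodeNeighborDeltas(int[] storedGaps) {
    int[] ordinals = new int[storedGaps.length];
    int running = 0;
    for (int i = 0; i < storedGaps.length; i++) {
        running += storedGaps[i];
        ordinals[i] = running;
    }
    return ordinals;
}
// decodeNeighborDeltas(new int[] {5, 3, 7}) -> [5, 8, 15]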

// the level-0 search; Bits acceptOrds is the filter (null when there is no filter)
void searchLevel(
     KnnCollector results,
     RandomVectorScorer scorer,
     int level,
     final int[] eps,
     HnswGraph graph,
     Bits acceptOrds)
     throws IOException {

   int size = getGraphSize(graph);

   prepareScratchState(size);

   for (int ep : eps) {
     if (visited.getAndSet(ep) == false) {
       if (results.earlyTerminated()) {
         break;
       }
       float score = scorer.score(ep);
       results.incVisitedCount(1);
       candidates.add(ep, score);
       if (acceptOrds == null || acceptOrds.get(ep)) {
         results.collect(ep, score);
       }
     }
   }

   // A bound that holds the minimum similarity to the query vector that a candidate vector must
   // have to be considered.
   float minAcceptedSimilarity = results.minCompetitiveSimilarity();
    // candidates is effectively a priority queue (ordered by score)
   while (candidates.size() > 0 && results.earlyTerminated() == false) {
     // get the best candidate (closest or best scoring)
     float topCandidateSimilarity = candidates.topScore();
     if (topCandidateSimilarity < minAcceptedSimilarity) {
       break;
     }

     int topCandidateNode = candidates.pop();
      // same seek logic as before, but now on level 0
     graphSeek(graph, level, topCandidateNode);
     int friendOrd;
     while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
       assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
       if (visited.getAndSet(friendOrd)) {
         continue;
       }

       if (results.earlyTerminated()) {
         break;
       }
       float friendSimilarity = scorer.score(friendOrd);
       results.incVisitedCount(1);
       if (friendSimilarity > minAcceptedSimilarity) {
         candidates.add(friendOrd, friendSimilarity); // the filter does not keep a node out of candidates
         if (acceptOrds == null || acceptOrds.get(friendOrd)) { // if results rarely gets to collect, the search has to run longer
           if (results.collect(friendOrd, friendSimilarity)) {
             minAcceptedSimilarity = results.minCompetitiveSimilarity();
           }
         }
       }
     }
   }
 }

Fetching the raw vectors for score computation

In the search above, the traversal works on the HNSW graph's node information, but the actual score computation still needs the raw vectors. Scores are computed with a RandomVectorScorer, so let's first see how that scorer is created.

// Lucene99FlatVectorsReader#getRandomVectorScorer
// the key part is loading the raw vectors; vectorData is the file that stores them
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
  FieldEntry fieldEntry = fields.get(field);
  if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
    return null;
  }
  return RandomVectorScorer.createFloats(
      OffHeapFloatVectorValues.load(
          fieldEntry.ordToDoc,
          fieldEntry.vectorEncoding,
          fieldEntry.dimension,
          fieldEntry.vectorDataOffset,
          fieldEntry.vectorDataLength,
          vectorData),
      fieldEntry.similarityFunction,
      target);
} 

// OffHeapFloatVectorValues#load
  public static OffHeapFloatVectorValues load(
      OrdToDocDISIReaderConfiguration configuration,
      VectorEncoding vectorEncoding,
      int dimension,
      long vectorDataOffset,
      long vectorDataLength,
      IndexInput vectorData)
      throws IOException {
    if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) {
      return new EmptyOffHeapVectorValues(dimension);
    }
    IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
    int byteSize = dimension * Float.BYTES;
    if (configuration.docsWithFieldOffset == -1) {
        // our example uses DenseOffHeapVectorValues; note that it does not eagerly load the vectors onto the heap
      return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
    } else {
      return new SparseOffHeapVectorValues(
          configuration, vectorData, bytesSlice, dimension, byteSize);
    }
  }

// DenseOffHeapVectorValues#vectorValue: fetch the raw vector for a given ordinal (targetOrd)
public float[] vectorValue(int targetOrd) throws IOException {
    if (lastOrd == targetOrd) {
        return value;
    }
    // slice is the raw-vector data file; in our example it is a MemorySegmentIndexInput because the Directory is an MMapDirectory
    // since the file is mmap-ed, the data lives in off-heap memory
    slice.seek((long) targetOrd * byteSize);
    slice.readFloats(value, 0, value.length);
    lastOrd = targetOrd;
    return value;
}
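
Putting the pieces together: once the raw vector for an ordinal has been read, the scorer applies the field's similarity function to it and the query vector. A conceptual sketch of what score(ord) boils down to, using the public VectorSimilarityFunction API (the actual wiring inside RandomVectorScorer differs in detail):

import org.apache.lucene.index.VectorSimilarityFunction;

// Conceptual sketch only: the similarity comes from the field's configuration
// (fieldEntry.similarityFunction above), e.g. EUCLIDEAN or DOT_PRODUCT.
static float conceptualScore(float[] queryVector, float[] storedVector, VectorSimilarityFunction similarity) {
    return similarity.compare(queryVector, storedVector);
}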

At this point the kNN search is complete; its results are stored in the DocAndScoreQuery.

2. createWeight

createWeight creates and returns a Weight instance. That Weight is also responsible for creating the Scorer, the Scorer for creating the DocIdSetIterator, and the DocIdSetIterator for iterating over the docIds, which makes createWeight a relatively involved step.
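
To build intuition for that chain, here is a hedged sketch using only public APIs (this is not how IndexSearcher drives scoring internally, which goes through a BulkScorer as shown in step 3): obtain the Weight for the rewritten query, ask it for a Scorer per segment, and walk the Scorer's DocIdSetIterator by hand. It assumes indexSearcher, indexReader, queryVector and k from the earlier example.

// Sketch: manually walking the Weight -> Scorer -> DocIdSetIterator chain
// produced by DocAndScoreQuery#createWeight.
Query rewritten = indexSearcher.rewrite(new KnnFloatVectorQuery("fvecs", queryVector, k));
Weight weight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE, 1f);
for (LeafReaderContext leaf : indexReader.leaves()) {
    Scorer scorer = weight.scorer(leaf);
    if (scorer == null) {
        continue; // this segment contributed none of the top-k hits
    }
    DocIdSetIterator iterator = scorer.iterator();
    for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
        System.out.println("segment " + leaf.ord + " local doc " + doc + " score " + scorer.score());
    }
}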

The main thing to understand here is the role of segmentStarts.

First, the key fields of DocAndScoreQuery, which createWeight relies on:

private final int k; // top K
private final int[] docs; // docIds of all hits across segments; each docId already includes its segment's docBase, i.e. it is a global docId rather than a per-segment one, so per-segment code later has to subtract the base
private final float[] scores; // the corresponding scores
private final int[] segmentStarts; // cumulative count of hits per segment, used to locate which segment a docId in docs belongs to
private final Object contextIdentity;

An example:

// 6 segments in total, each holding 10,000 docs

docs: [9224, 41966, 53837]
// segmentStarts has (number of segments + 1) entries; segments are numbered from 0
segmentStarts: [0, 1, 1, 1, 1, 2, 3]

The first docId is 9224: segmentStarts[1] - segmentStarts[0] = 1, so segment 0 holds 1 hit, and 9224 falls in segment 0.
segmentStarts[2] - segmentStarts[1] = 0, so segment 1 holds no hits; likewise segments 2 and 3 hold none.
The second docId is 41966: segmentStarts[5] - segmentStarts[4] = 1, so segment 4 holds 1 hit, and 41966 falls in segment 4.
The third docId is 53837: segmentStarts[6] - segmentStarts[5] = 1, so segment 5 holds 1 hit, and 53837 falls in segment 5.
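
A tiny standalone sketch (not Lucene code) of that lookup: a result at index i in docs belongs to the segment ord whose range [segmentStarts[ord], segmentStarts[ord + 1]) contains i.

// For the example above: segmentForResult(new int[] {0, 1, 1, 1, 1, 2, 3}, 1) == 4,
// i.e. docs[1] = 41966 belongs to segment 4.
static int segmentForResult(int[] segmentStarts, int resultIndex) {
    for (int ord = 0; ord < segmentStarts.length - 1; ord++) {
        if (resultIndex >= segmentStarts[ord] && resultIndex < segmentStarts[ord + 1]) {
            return ord;
        }
    }
    return -1; // resultIndex out of range
}
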
// DocAndScoreQuery#createWeight
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
       throws IOException {
  if (searcher.getIndexReader().getContext().id() != contextIdentity) {
    throw new IllegalStateException("This DocAndScore query was created by a different reader");
  }
  return new Weight(this) {
    @Override
    public Explanation explain(LeafReaderContext context, int doc) {
      int found = Arrays.binarySearch(docs, doc + context.docBase);
      if (found < 0) {
        return Explanation.noMatch("not in top " + k);
      }
      return Explanation.match(scores[found] * boost, "within top " + k);
    }

    @Override
    public int count(LeafReaderContext context) {
      return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
    }

    @Override
    public Scorer scorer(LeafReaderContext context) {
      // context.ord is the ord (index) of this segment
      if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
        // no hits were found in this segment
        return null;
      }
      return new Scorer(this) {
        final int lower = segmentStarts[context.ord]; // lower is this segment's starting offset into docs
        final int upper = segmentStarts[context.ord + 1]; // upper is the next segment's starting offset into docs
        // lower and upper delimit this segment's hits within the docs (and scores) arrays

        int upTo = -1;

        @Override
        public DocIdSetIterator iterator() {
          return new DocIdSetIterator() {
            @Override
            public int docID() {
              return docIdNoShadow();
            }

            @Override
            public int nextDoc() {
              if (upTo == -1) {
                upTo = lower;
              } else {
                ++upTo;
              }
              return docIdNoShadow();
            }

            @Override
            public int advance(int target) throws IOException {
              return slowAdvance(target);
            }

            @Override
            public long cost() {
              return upper - lower;
            }
          };
        }

        @Override
        public float getMaxScore(int docId) {
          docId += context.docBase;
          float maxScore = 0;
          for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
            maxScore = Math.max(maxScore, scores[idx]);
          }
          return maxScore * boost;
        }

        @Override
        public float score() {
          return scores[upTo] * boost;
        }

        @Override
        public int advanceShallow(int docid) {
          int start = Math.max(upTo, lower);
          int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
          if (docidIndex < 0) {
            docidIndex = -1 - docidIndex;
          }
          if (docidIndex >= upper) {
            return NO_MORE_DOCS;
          }
          return docs[docidIndex];
        }

        /**
         * move the implementation of docID() into a differently-named method so we can call it
         * from DocIDSetIterator.docID() even though this class is anonymous
         *
         * @return the current docid
         */
        private int docIdNoShadow() {
          if (upTo == -1) {
            return -1;
          }
          if (upTo >= upper) {
            return NO_MORE_DOCS;
          }
          // subtract context.docBase because docs stores global docIds, not per-segment docIds
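          // e.g. with the example above, docs[1] = 41966 lives in segment 4, whose docBase
          // is 40000 (6 segments x 10,000 docs), so the per-segment docId returned here
          // is 41966 - 40000 = 1966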
          return docs[upTo] - context.docBase;
        }

        @Override
        public int docID() {
          return docIdNoShadow();
        }
      };
    }

    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
      return true;
    }
  };
}

3. Score computation: grouped by segment and run concurrently via the taskExecutor

// IndexSearcher#search
private <C extends Collector, T> T search(
    Weight weight, CollectorManager<C, T> collectorManager, C firstCollector) throws IOException {
    final LeafSlice[] leafSlices = getSlices();
    if (leafSlices.length == 0) {
        // there are no segments, nothing to offload to the executor, but we do need to call reduce to
        // create some kind of empty result
        assert leafContexts.size() == 0;
        return collectorManager.reduce(Collections.singletonList(firstCollector));
    } else {
        final List<C> collectors = new ArrayList<>(leafSlices.length);
        collectors.add(firstCollector);
        final ScoreMode scoreMode = firstCollector.scoreMode();
        for (int i = 1; i < leafSlices.length; ++i) {
            final C collector = collectorManager.newCollector();
            collectors.add(collector);
            if (scoreMode != collector.scoreMode()) {
                throw new IllegalStateException(
                    "CollectorManager does not always produce collectors with the same score mode");
            }
        }
        final List<Callable<C>> listTasks = new ArrayList<>(leafSlices.length);
        // score each slice of segments (leafSlices) concurrently
        for (int i = 0; i < leafSlices.length; ++i) {
            final LeafReaderContext[] leaves = leafSlices[i].leaves;
            final C collector = collectors.get(i);
            listTasks.add(
                () -> {
                    search(Arrays.asList(leaves), weight, collector);
                    return collector;
                });
        }
        List<C> results = taskExecutor.invokeAll(listTasks);
        return collectorManager.reduce(results);
    }
}
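
One detail worth noting about the concurrency above: the per-slice tasks only actually run in parallel when the IndexSearcher is constructed with an executor; without one, taskExecutor.invokeAll effectively runs them on the calling thread. A minimal sketch (the pool size of 4 is just an example):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hand the searcher an executor so the per-slice search tasks can run concurrently.
ExecutorService executor = Executors.newFixedThreadPool(4);
IndexSearcher concurrentSearcher = new IndexSearcher(indexReader, executor);
// ... search as before, and shut the executor down when finished:
// executor.shutdown();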


  protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
      throws IOException {

    collector.setWeight(weight);

    // TODO: should we make this
    // threaded...? the Collector could be sync'd?
    // always use single thread:
    for (LeafReaderContext ctx : leaves) { // search each subreader
      final LeafCollector leafCollector;
      try {
        leafCollector = collector.getLeafCollector(ctx);
      } catch (
          @SuppressWarnings("unused")
          CollectionTerminatedException e) {
        // there is no doc of interest in this reader context
        // continue with the following leaf
        continue;
      }
      BulkScorer scorer = weight.bulkScorer(ctx);
      if (scorer != null) {
        if (queryTimeout != null) {
          scorer = new TimeLimitingBulkScorer(scorer, queryTimeout);
        }
        try {
          scorer.score(leafCollector, ctx.reader().getLiveDocs());
        } catch (
            @SuppressWarnings("unused")
            CollectionTerminatedException e) {
          // collection was terminated prematurely
          // continue with the following leaf
        } catch (
            @SuppressWarnings("unused")
            TimeLimitingBulkScorer.TimeExceededException e) {
          partialResult = true;
        }
      }
      // Note: this is called if collection ran successfully, including the above special cases of
      // CollectionTerminatedException and TimeExceededException, but no other exception.
      leafCollector.finish();
    }
  }

// Weight.DefaultBulkScorer#scoreAll: the Weight built in createWeight supplies the scorer, whose iterator() drives the docId iteration; the collector here is a SimpleTopScoreDocCollector
static void scoreAll(
        LeafCollector collector,
        DocIdSetIterator iterator,
        TwoPhaseIterator twoPhase,
        Bits acceptDocs)
        throws IOException {
      if (twoPhase == null) {
        for (int doc = iterator.nextDoc();
            doc != DocIdSetIterator.NO_MORE_DOCS;
            doc = iterator.nextDoc()) {
          if (acceptDocs == null || acceptDocs.get(doc)) {
            collector.collect(doc);
          }
        }
      } else {
        // The scorer has an approximation, so run the approximation first, then check acceptDocs,
        // then confirm
        for (int doc = iterator.nextDoc();
            doc != DocIdSetIterator.NO_MORE_DOCS;
            doc = iterator.nextDoc()) {
          if ((acceptDocs == null || acceptDocs.get(doc)) && twoPhase.matches()) {
            collector.collect(doc);
          }
        }
      }
    }

// SimpleTopScoreDocCollector#getLeafCollector, ScorerLeafCollector#collect
public void collect(int doc) throws IOException {
    float score = scorer.score();

    // This collector relies on the fact that scorers produce positive values:
    assert score >= 0; // NOTE: false for NaN

    totalHits++;
    hitsThresholdChecker.incrementHitCount();

    if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) {
      updateGlobalMinCompetitiveScore(scorer);
    }

    if (score <= pqTop.score) {
      if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
        // we just reached totalHitsThreshold, we can start setting the min
        // competitive score now
        updateMinCompetitiveScore(scorer);
      }
      // Since docs are returned in-order (i.e., increasing doc Id), a document
      // with equal score to pqTop.score cannot compete since HitQueue favors
      // documents with lower doc Ids. Therefore reject those docs too.
      return;
    }
    pqTop.doc = doc + docBase;
    pqTop.score = score;
    // update the top-N hits; pq is a priority queue (the HitQueue)
    pqTop = pq.updateTop();
    updateMinCompetitiveScore(scorer);
}

Query with a filter

Code example

Directory readDirectory = FSDirectory.open(Path.of("data/lucene_knn_demo"));
IndexReader indexReader = DirectoryReader.open(readDirectory);
IndexSearcher indexSearcher = new IndexSearcher(indexReader);


float[] queryVector = generateFVector(128);

int k = 3;
BooleanQuery.Builder booleanQueryBuilder = new BooleanQuery.Builder();
TermQuery termQuery = new TermQuery(new Term("id", Integer.toString(1)));
booleanQueryBuilder.add(termQuery, BooleanClause.Occur.FILTER);
BooleanQuery filterQuery = booleanQueryBuilder.build();

// pass the filterQuery to the KnnFloatVectorQuery constructor to use it as a filter
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery("fvecs", queryVector, k, filterQuery);

TopDocs topDocs = indexSearcher.search(knnQuery, k);

for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
    System.out.println("score doc " + scoreDoc.doc + ", score: " + scoreDoc.score);
}

How the query works

Overall it works the same as the unfiltered query; the only difference is an extra filtering step during the search, surfaced through the acceptDocs argument of HnswGraphSearcher#search. In short, the docIdSet produced by the filter is handed to HnswGraphSearcher as acceptDocs, and during the search only docIds that satisfy acceptDocs are collected into the results. This is correct, but it can hurt performance, because the graph traversal effectively has to cover more nodes.

The overall search flow was covered above and is not repeated here; the following focuses on how the filterQuery is turned into acceptDocs.

The filtering work also happens mainly during the query's rewrite phase.

// AbstractKnnVectorQuery#rewrite
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
  IndexReader reader = indexSearcher.getIndexReader();

  final Weight filterWeight;
  if (filter != null) {
    // the filter is wrapped into a BooleanQuery with two Occur.FILTER clauses; the FieldExistsQuery clause filters out docs that do not have the vector field
    BooleanQuery booleanQuery =
        new BooleanQuery.Builder()
            .add(filter, BooleanClause.Occur.FILTER)
            .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
            .build();
    // the two-clause BooleanQuery eventually gets rewritten into a ConstantScoreQuery wrapping only the original filter clause;
    // the FieldExistsQuery is first rewritten to a MatchAllDocsQuery and then optimized away. That is not the focus here; see BooleanQuery#rewrite if you are interested
    Query rewritten = indexSearcher.rewrite(booleanQuery);
    // returns a CachingWrapperWeight from the LRUQueryCache; the innermost wrapped weight is a TermWeight (because the filter in our example is a TermQuery)
    filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
  } else {
    filterWeight = null;
  }
  // ... the code below follows the same logic as the no-filter case
} 

// rewrite of the converted BooleanQuery; since only the filter-related logic is relevant here, the other branches are omitted
// BooleanQuery#rewrite
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
 if (clauses.size() == 0) {
   return new MatchNoDocsQuery("empty BooleanQuery");
 }
 // ... 
 // recursively rewrite
 {
   BooleanQuery.Builder builder = new BooleanQuery.Builder();
   builder.setMinimumNumberShouldMatch(getMinimumNumberShouldMatch());
   boolean actuallyRewritten = false;
   for (BooleanClause clause : this) {
     Query query = clause.getQuery();
     BooleanClause.Occur occur = clause.getOccur();
     Query rewritten;
     if (occur == Occur.FILTER || occur == Occur.MUST_NOT) {
       // Clauses that are not involved in scoring can get some extra simplifications
       // the query from above is wrapped once more in a ConstantScoreQuery and then rewritten
       rewritten = new ConstantScoreQuery(query).rewrite(indexSearcher);
       if (rewritten instanceof ConstantScoreQuery) {
         rewritten = ((ConstantScoreQuery) rewritten).getQuery();
       }
     } else {
       rewritten = query.rewrite(indexSearcher);
     }
     if (rewritten != query || query.getClass() == MatchNoDocsQuery.class) {
       // rewrite clause
       actuallyRewritten = true;
       if (rewritten.getClass() == MatchNoDocsQuery.class) {
         switch (occur) {
           case SHOULD:
           case MUST_NOT:
             // the clause can be safely ignored
             break;
           case MUST:
           case FILTER:
             return rewritten;
         }
       } else {
         builder.add(rewritten, occur);
       }
     } else {
       // leave as-is
       builder.add(clause);
     }
   }
   if (actuallyRewritten) {
     return builder.build();
   }
 // ...
}

The search flow of a filtered query

 private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
    Bits liveDocs = ctx.reader().getLiveDocs();
    int maxDoc = ctx.reader().maxDoc();

    // the no-filter query takes this branch
    if (filterWeight == null) {
      return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
    }

    // since our filter is a TermQuery, this ultimately returns the TermQuery's scorer
    Scorer scorer = filterWeight.scorer(ctx);
    if (scorer == null) {
      return NO_RESULTS;
    }

    // key step: build acceptDocs from the term scorer's iterator
    BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
    int cost = acceptDocs.cardinality();

    // if the filtered docIdSet contains at most k docs, go straight to exactSearch
    if (cost <= k) {
      // If there are <= k possible matches, short-circuit and perform exact search, since HNSW
      // must always visit at least k documents
      return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
    }

    // Perform the approximate kNN search
    // the logic of approximateSearch was analyzed above and is not repeated here
    TopDocs results = approximateSearch(ctx, acceptDocs, cost);
    if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
      return results;
    } else {
      // We stopped the kNN search because it visited too many nodes, so fall back to exact search
      return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
    }
  }

// AbstractKnnVectorQuery#exactSearch

  protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
      throws IOException {
    FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
    if (fi == null || fi.getVectorDimension() == 0) {
      // The field does not exist or does not index vectors
      return NO_RESULTS;
    }

    VectorScorer vectorScorer = createVectorScorer(context, fi);
    // pre-fill a priority queue of size k (with sentinel entries)
    HitQueue queue = new HitQueue(k, true);
    ScoreDoc topDoc = queue.top();
    int doc;
    while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
        // iterate directly over the docIds that passed the filter
      boolean advanced = vectorScorer.advanceExact(doc);
      assert advanced;

      float score = vectorScorer.score();
      if (score > topDoc.score) {
        topDoc.score = score;
        topDoc.doc = doc;
        topDoc = queue.updateTop();
      }
    }

    // Remove any remaining sentinel values
    while (queue.size() > 0 && queue.top().score < 0) {
      // pop the entries whose score is below 0 (the sentinels)
      queue.pop();
    }

    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
      topScoreDocs[i] = queue.pop();
    }

    TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
    return new TopDocs(totalHits, topScoreDocs);
  }

// AbstractKnnVectorQuery#createBitSet
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
      throws IOException {
    if (liveDocs == null && iterator instanceof BitSetIterator) {
      // If we already have a BitSet and no deletions, reuse the BitSet
      return ((BitSetIterator) iterator).getBitSet();
    } else {
      // Create a new BitSet from matching and live docs
      FilteredDocIdSetIterator filterIterator =
          new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
              return liveDocs == null || liveDocs.get(doc);
            }
          };
      // builds a SparseFixedBitSet or a FixedBitSet depending on the filterIterator's cost
      return BitSet.of(filterIterator, maxDoc);
    }
  }

// TermQuery.TermWeight#scorerSupplier
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
  assert termStates == null || termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context))
      : "The top-reader used to create Weight is not the same as the current reader's top-reader ("
          + ReaderUtil.getTopLevelContext(context);

  // obtain the TermsEnum positioned on the query term
  final TermsEnum termsEnum = getTermsEnum(context);
  if (termsEnum == null) {
    return null;
  }
  final int docFreq = termsEnum.docFreq();

  return new ScorerSupplier() {

    private boolean topLevelScoringClause = false;

    @Override
    public Scorer get(long leadCost) throws IOException {
      LeafSimScorer scorer =
          new LeafSimScorer(simScorer, context.reader(), term.field(), scoreMode.needsScores());
      if (scoreMode == ScoreMode.TOP_SCORES) {
        return new TermScorer(
            TermWeight.this,
            termsEnum.impacts(PostingsEnum.FREQS),
            scorer,
            topLevelScoringClause);
      } else {
        return new TermScorer(
            TermWeight.this,
            termsEnum.postings(
                null, scoreMode.needsScores() ? PostingsEnum.FREQS : PostingsEnum.NONE),
            scorer);
      }
    }

    @Override
    public long cost() {
      return docFreq;
    }

    @Override
    public void setTopLevelScoringClause() throws IOException {
      topLevelScoringClause = true;
    }
  };
}

Summary

  • This article analyzed the search process of KnnFloatVectorQuery, covering the basics of vector search and the key source code. Note that for filtered queries, the filter's result is passed to HnswGraphSearcher as acceptDocs, which can cause a performance hit worth keeping in mind.

Footnotes

  1. BulkScorer(一)(Lucene 9.6.0)