Preface
- This article analyzes how Lucene's KnnFloatVectorQuery performs its search, covering both queries without a filter and queries with a filter.
- The analysis is based on Lucene 9.9.0.
Query without a filter
Code example
As a simple example, index 10,000 vectors of 128 dimensions each, then run a kNN search:
// build the vector index
IndexWriter indexWriter = new IndexWriter(directory, config);
int count = 10000;
int dim = 128;
List<Document> docs = new ArrayList<>();
for (int i = 0; i < count; i++) {
Document doc = new Document();
doc.add(new KeywordField("id", Integer.toString(i), Field.Store.YES));
doc.add(new KnnFloatVectorField("fvecs", generateFVector(dim)));
docs.add(doc);
}
indexWriter.addDocuments(docs);
indexWriter.commit();
// kNN search
Directory readDirectory = FSDirectory.open(Path.of("data/lucene_knn_demo"));
IndexReader indexReader = DirectoryReader.open(readDirectory);
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
float[] queryVector = generateFVector(128);
int k = 3;
TopDocs topDocs = indexSearcher.search(new KnnFloatVectorQuery("fvecs", queryVector, k), k);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
System.out.println("score doc " + scoreDoc.doc + ", score: " + scoreDoc.score);
}
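The generateFVector helper is not shown in the example above; a minimal sketch of such a helper (an assumption, simply filling the vector with random values via java.util.Random) could look like this:
private static float[] generateFVector(int dim) {
    // hypothetical helper: fill a dim-dimensional vector with random floats
    float[] vector = new float[dim];
    Random random = new Random();
    for (int i = 0; i < dim; i++) {
        vector[i] = random.nextFloat();
    }
    return vector;
}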
How the query works
Overall flow
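At a high level, the KnnFloatVectorQuery is rewritten into a DocAndScoreQuery: the rewrite step runs the actual per-segment HNSW search and stores the matching docIds and scores; createWeight then exposes those precomputed results through a Scorer and its DocIdSetIterator, and the regular IndexSearcher collection path ranks them into the final TopDocs.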
Detailed call stack
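Roughly, for the no-filter case the chain looks like this (method names from Lucene 9.9.0, intermediate frames omitted):
IndexSearcher#search
  -> AbstractKnnVectorQuery#rewrite            // the per-segment kNN search happens here
    -> getLeafResults -> approximateSearch
      -> Lucene99HnswVectorsReader#search
        -> HnswGraphSearcher#search            // findBestEntryPoint, then searchLevel on level 0
  -> DocAndScoreQuery#createWeight             // Weight / Scorer over the precomputed docs and scores
  -> IndexSearcher collection (BulkScorer -> TopScoreDocCollector) producing the final TopDocs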
Key code walkthrough
1. After rewrite, a KnnFloatVectorQuery becomes a DocAndScoreQuery. Let's first look at how DocAndScoreQuery is defined:
// AbstractKnnVectorQuery
static class DocAndScoreQuery extends Query {
private final int k;
private final int[] docs;
private final float[] scores;
// the role of segmentStarts is explained in the createWeight section below
private final int[] segmentStarts;
private final Object contextIdentity;
}
As you can see, DocAndScoreQuery already stores the matching docIds and their scores. The detailed call stack above also shows that the kNN search happens during the rewrite step. Since the focus of this article is how KnnFloatVectorQuery performs the kNN search, we will concentrate on the Lucene99HnswVectorsReader#search method.
Lucene99HnswVectorsReader#search
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
}
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
// getGraph returns an OffHeapHnswGraph instance
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
}
private HnswGraph getGraph(FieldEntry entry) throws IOException {
return new OffHeapHnswGraph(entry, vectorIndex);
}
HnswGraphSearcher#search
public static void search(
RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds)
throws IOException {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
}
// the core search code
private static void search(
RandomVectorScorer scorer,
KnnCollector knnCollector,
HnswGraph graph,
HnswGraphSearcher graphSearcher,
Bits acceptOrds)
throws IOException {
int initialEp = graph.entryNode();
if (initialEp == -1) {
return;
}
// find the best entry point to use for the level 0 search (by descending the upper levels)
int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
int numVisited = epAndVisited[1];
int ep = epAndVisited[0];
if (ep == -1) {
knnCollector.incVisitedCount(numVisited);
return;
}
knnCollector.incVisitedCount(numVisited);
// perform the final search on level 0
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
}
// graphSearcher.findBestEntryPoint
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
throws IOException {
int size = getGraphSize(graph);
int visitedCount = 1;
prepareScratchState(size);
int currentEp = graph.entryNode();
float currentScore = scorer.score(currentEp);
boolean foundBetter;
// start searching from the highest level
for (int level = graph.numLevels() - 1; level >= 1; level--) {
foundBetter = true;
visited.set(currentEp);
// Keep searching the given level until we stop finding a better candidate entry point
while (foundBetter) {
foundBetter = false;
// search within the current level: position at currentEp's neighbor list
graphSeek(graph, level, currentEp);
int friendOrd;
// iterate over all of currentEp's neighbors on this level
while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
}
if (visitedCount >= visitLimit) {
return new int[] {-1, visitedCount};
}
float friendSimilarity = scorer.score(friendOrd);
visitedCount++;
if (friendSimilarity > currentScore) {
currentScore = friendSimilarity;
currentEp = friendOrd;
foundBetter = true;
}
}
}
}
return new int[] {currentEp, visitedCount};
}
// the core of graphSeek: OffHeapHnswGraph#seek (OffHeapHnswGraph is returned by Lucene99HnswVectorsReader#getGraph above)
public void seek(int level, int targetOrd) throws IOException {
int targetIndex =
level == 0
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
assert targetIndex >= 0;
// unsafe; no bounds checking
// nodes are stored level by level, effectively mapping a 2-D layout to a 1-D one, hence the offset targetIndex + graphLevelNodeIndexOffsets[level]
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
if (arcCount > 0) {
currentNeighborsBuffer[0] = dataIn.readVInt();
for (int i = 1; i < arcCount; i++) {
// neighbors were written delta-encoded, so accumulate the deltas back here
currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + dataIn.readVInt();
}
}
arc = -1;
arcUpTo = 0;
}
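For example, if a node's neighbors on some level are ordinals 3, 7 and 12, the file stores the VInts 3, 4 and 5, and the loop above reconstructs 3, 3 + 4 = 7 and 7 + 5 = 12.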
// the level 0 search; Bits acceptOrds is the filter (null when there is no filter)
void searchLevel(
KnnCollector results,
RandomVectorScorer scorer,
int level,
final int[] eps,
HnswGraph graph,
Bits acceptOrds)
throws IOException {
int size = getGraphSize(graph);
prepareScratchState(size);
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (results.earlyTerminated()) {
break;
}
float score = scorer.score(ep);
results.incVisitedCount(1);
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
results.collect(ep, score);
}
}
}
// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
float minAcceptedSimilarity = results.minCompetitiveSimilarity();
// candidates is essentially a priority queue
while (candidates.size() > 0 && results.earlyTerminated() == false) {
// get the best candidate (closest or best scoring)
float topCandidateSimilarity = candidates.topScore();
if (topCandidateSimilarity < minAcceptedSimilarity) {
break;
}
int topCandidateNode = candidates.pop();
// same logic as before, but now on level 0
graphSeek(graph, level, topCandidateNode);
int friendOrd;
while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
}
if (results.earlyTerminated()) {
break;
}
float friendSimilarity = scorer.score(friendOrd);
results.incVisitedCount(1);
if (friendSimilarity > minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity); // the filter does not affect which nodes enter candidates
if (acceptOrds == null || acceptOrds.get(friendOrd)) { // if results keeps failing to collect, the search takes longer
if (results.collect(friendOrd, friendSimilarity)) {
minAcceptedSimilarity = results.minCompetitiveSimilarity();
}
}
}
}
}
}
Fetching the raw vectors for score computation
In the search above, the nodes being traversed come from the HNSW graph, but the actual score computation still relies on the raw vectors. Scores are computed with a RandomVectorScorer, so let's look at how the RandomVectorScorer is created:
// Lucene99FlatVectorsReader#getRandomVectorScorer
// the key part is loading the raw vectors; vectorData is the file that stores them
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return RandomVectorScorer.createFloats(
OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData),
fieldEntry.similarityFunction,
target);
}
// OffHeapFloatVectorValues#load
public static OffHeapFloatVectorValues load(
OrdToDocDISIReaderConfiguration configuration,
VectorEncoding vectorEncoding,
int dimension,
long vectorDataOffset,
long vectorDataLength,
IndexInput vectorData)
throws IOException {
if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) {
return new EmptyOffHeapVectorValues(dimension);
}
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
int byteSize = dimension * Float.BYTES;
if (configuration.docsWithFieldOffset == -1) {
// our example uses DenseOffHeapVectorValues; note that it does not eagerly load the vectors into memory
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
} else {
return new SparseOffHeapVectorValues(
configuration, vectorData, bytesSlice, dimension, byteSize);
}
}
// DenseOffHeapVectorValues#vectorValue: fetch the raw vector for a given ordinal (targetOrd)
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return value;
}
// slice is the file that stores the raw vectors; in this example it is a MemorySegmentIndexInput instance because the Directory is an MMapDirectory
// since the file is mmapped, the data stays off-heap
slice.seek((long) targetOrd * byteSize);
slice.readFloats(value, 0, value.length);
lastOrd = targetOrd;
return value;
}
This completes the kNN search; the results it found are stored in the DocAndScoreQuery.
2. createWeight
createWeight creates and returns a Weight instance, but that Weight is also responsible for creating the Scorer, the Scorer for creating the DocIdSetIterator, and the DocIdSetIterator for iterating over the docIds, so createWeight is a relatively involved step.
The main thing to understand here is the role of segmentStarts.
First, here are the key fields of DocAndScoreQuery that createWeight relies on:
private final int k; // top k
private final int[] docs; // docIds found across all segments; each docId already includes its segment's docBase, i.e. it is an index-wide docId rather than a per-segment one, so later per-segment iteration has to subtract the corresponding base
private final float[] scores; // the corresponding scores
private final int[] segmentStarts; // cumulative count of docIds per segment, used to locate which segment each entry of docs belongs to
private final Object contextIdentity;
An example:
// 6 segments in total, each with 10,000 docs
docs: [9224, 41966, 53837]
// segmentStarts has (number of segments + 1) entries; segments are numbered from 0
segmentStarts:[0, 1, 1, 1, 1, 2, 3]
The first docId is 9224: segmentStarts[1] - segmentStarts[0] = 1, so segment 0 contains 1 doc, and this docId falls into segment 0.
segmentStarts[2] - segmentStarts[1] = 0, so segment 1 contains no docIds.
segmentStarts[3] - segmentStarts[2] = 0, so segment 2 contains no docIds.
segmentStarts[4] - segmentStarts[3] = 0, so segment 3 contains no docIds.
The second docId is 41966: segmentStarts[5] - segmentStarts[4] = 1, so segment 4 contains 1 doc, and this docId falls into segment 4.
The third docId is 53837: segmentStarts[6] - segmentStarts[5] = 1, so segment 5 contains 1 doc, and this docId falls into segment 5.
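To make the mapping concrete, here is a hypothetical helper (not part of Lucene) that, given a position i in the docs array, returns the ordinal of the segment whose range [segmentStarts[ord], segmentStarts[ord + 1]) contains it:
static int segmentOfDocsIndex(int[] segmentStarts, int i) {
    // segment ord owns positions segmentStarts[ord] (inclusive) to segmentStarts[ord + 1] (exclusive)
    for (int ord = 0; ord < segmentStarts.length - 1; ord++) {
        if (i >= segmentStarts[ord] && i < segmentStarts[ord + 1]) {
            return ord;
        }
    }
    return -1; // position out of range
}
With segmentStarts = [0, 1, 1, 1, 1, 2, 3], positions 0, 1 and 2 of docs map to segments 0, 4 and 5 respectively, matching the example above.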
// DocAndScoreQuery#createWeight
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
}
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc + context.docBase);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
}
return Explanation.match(scores[found] * boost, "within top " + k);
}
@Override
public int count(LeafReaderContext context) {
return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
}
@Override
public Scorer scorer(LeafReaderContext context) {
// context.ord is the segment index
if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
// no docIds were found in this segment
return null;
}
return new Scorer(this) {
final int lower = segmentStarts[context.ord]; // lower is this segment's starting offset into the docs array
final int upper = segmentStarts[context.ord + 1]; // upper is the next segment's starting offset into the docs array
// lower and upper bound this segment's docIds within the docs array
int upTo = -1;
@Override
public DocIdSetIterator iterator() {
return new DocIdSetIterator() {
@Override
public int docID() {
return docIdNoShadow();
}
@Override
public int nextDoc() {
if (upTo == -1) {
upTo = lower;
} else {
++upTo;
}
return docIdNoShadow();
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return upper - lower;
}
};
}
@Override
public float getMaxScore(int docId) {
docId += context.docBase;
float maxScore = 0;
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
}
return maxScore * boost;
}
@Override
public float score() {
return scores[upTo] * boost;
}
@Override
public int advanceShallow(int docid) {
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
if (docidIndex < 0) {
docidIndex = -1 - docidIndex;
}
if (docidIndex >= upper) {
return NO_MORE_DOCS;
}
return docs[docidIndex];
}
/**
* move the implementation of docID() into a differently-named method so we can call it
* from DocIDSetIterator.docID() even though this class is anonymous
*
* @return the current docid
*/
private int docIdNoShadow() {
if (upTo == -1) {
return -1;
}
if (upTo >= upper) {
return NO_MORE_DOCS;
}
// subtract context.docBase because docs stores index-wide docIds, not per-segment docIds
return docs[upTo] - context.docBase;
}
@Override
public int docID() {
return docIdNoShadow();
}
};
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
3. Computing scores: grouped by segment and executed concurrently via taskExecutor
// IndexSearcher#search
private <C extends Collector, T> T search(
Weight weight, CollectorManager<C, T> collectorManager, C firstCollector) throws IOException {
final LeafSlice[] leafSlices = getSlices();
if (leafSlices.length == 0) {
// there are no segments, nothing to offload to the executor, but we do need to call reduce to
// create some kind of empty result
assert leafContexts.size() == 0;
return collectorManager.reduce(Collections.singletonList(firstCollector));
} else {
final List<C> collectors = new ArrayList<>(leafSlices.length);
collectors.add(firstCollector);
final ScoreMode scoreMode = firstCollector.scoreMode();
for (int i = 1; i < leafSlices.length; ++i) {
final C collector = collectorManager.newCollector();
collectors.add(collector);
if (scoreMode != collector.scoreMode()) {
throw new IllegalStateException(
"CollectorManager does not always produce collectors with the same score mode");
}
}
final List<Callable<C>> listTasks = new ArrayList<>(leafSlices.length);
// score concurrently per segment (leafSlices)
for (int i = 0; i < leafSlices.length; ++i) {
final LeafReaderContext[] leaves = leafSlices[i].leaves;
final C collector = collectors.get(i);
listTasks.add(
() -> {
search(Arrays.asList(leaves), weight, collector);
return collector;
});
}
List<C> results = taskExecutor.invokeAll(listTasks);
return collectorManager.reduce(results);
}
}
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
throws IOException {
collector.setWeight(weight);
// TODO: should we make this
// threaded...? the Collector could be sync'd?
// always use single thread:
for (LeafReaderContext ctx : leaves) { // search each subreader
final LeafCollector leafCollector;
try {
leafCollector = collector.getLeafCollector(ctx);
} catch (
@SuppressWarnings("unused")
CollectionTerminatedException e) {
// there is no doc of interest in this reader context
// continue with the following leaf
continue;
}
BulkScorer scorer = weight.bulkScorer(ctx);
if (scorer != null) {
if (queryTimeout != null) {
scorer = new TimeLimitingBulkScorer(scorer, queryTimeout);
}
try {
scorer.score(leafCollector, ctx.reader().getLiveDocs());
} catch (
@SuppressWarnings("unused")
CollectionTerminatedException e) {
// collection was terminated prematurely
// continue with the following leaf
} catch (
@SuppressWarnings("unused")
TimeLimitingBulkScorer.TimeExceededException e) {
partialResult = true;
}
}
// Note: this is called if collection ran successfully, including the above special cases of
// CollectionTerminatedException and TimeExceededException, but no other exception.
leafCollector.finish();
}
}
// createWeight provides the scorer; each scorer exposes an iterator that drives docId iteration; the collector here is SimpleTopScoreDocCollector
static void scoreAll(
LeafCollector collector,
DocIdSetIterator iterator,
TwoPhaseIterator twoPhase,
Bits acceptDocs)
throws IOException {
if (twoPhase == null) {
for (int doc = iterator.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = iterator.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
collector.collect(doc);
}
}
} else {
// The scorer has an approximation, so run the approximation first, then check acceptDocs,
// then confirm
for (int doc = iterator.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = iterator.nextDoc()) {
if ((acceptDocs == null || acceptDocs.get(doc)) && twoPhase.matches()) {
collector.collect(doc);
}
}
}
}
// SimpleTopScoreDocCollector#getLeafCollector returns a ScorerLeafCollector; its collect method:
public void collect(int doc) throws IOException {
float score = scorer.score();
// This collector relies on the fact that scorers produce positive values:
assert score >= 0; // NOTE: false for NaN
totalHits++;
hitsThresholdChecker.incrementHitCount();
if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) {
updateGlobalMinCompetitiveScore(scorer);
}
if (score <= pqTop.score) {
if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
// we just reached totalHitsThreshold, we can start setting the min
// competitive score now
updateMinCompetitiveScore(scorer);
}
// Since docs are returned in-order (i.e., increasing doc Id), a document
// with equal score to pqTop.score cannot compete since HitQueue favors
// documents with lower doc Ids. Therefore reject those docs too.
return;
}
pqTop.doc = doc + docBase;
pqTop.score = score;
// update the top N; pq is a priority queue
pqTop = pq.updateTop();
updateMinCompetitiveScore(scorer);
}
Query with a filter
Code example
Directory readDirectory = FSDirectory.open(Path.of("data/lucene_knn_demo"));
IndexReader indexReader = DirectoryReader.open(readDirectory);
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
float[] queryVector = generateFVector(128);
int k = 3;
BooleanQuery.Builder booleanQueryBuilder = new BooleanQuery.Builder();
TermQuery termQuery = new TermQuery(new Term("id", Integer.toString(1)));
booleanQueryBuilder.add(termQuery, BooleanClause.Occur.FILTER);
BooleanQuery filterQuery = booleanQueryBuilder.build();
// pass the filterQuery to the KnnFloatVectorQuery to use it as a filter
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery("fvecs", queryVector, k, filterQuery);
TopDocs topDocs = indexSearcher.search(knnQuery, k);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
System.out.println("score doc " + scoreDoc.doc + ", score: " + scoreDoc.score);
}
How the query works
Overall this works the same way as the query without a filter; the only difference is the extra filter logic during the search, expressed through the acceptDocs parameter of HnswGraphSearcher#search. In short, the docIdSet produced by the filter is passed to HnswGraphSearcher as acceptDocs, and during the search only docIds that satisfy acceptDocs are collected into the results. This is correct, but it has a performance cost, because it effectively widens the search.
The overall search flow was covered above and is not repeated here; below we focus on how the filterQuery produces the acceptDocs.
The filtering also happens mainly in the Query's rewrite phase:
// AbstractKnnVectorQuery#rewrite
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
final Weight filterWeight;
if (filter != null) {
// note that the filter is wrapped into a BooleanQuery with two Occur.FILTER clauses; the FieldExistsQuery removes docs that do not have the vector field
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
// this two-clause BooleanQuery is eventually rewritten into a ConstantScoreQuery wrapping a BooleanQuery that contains only the filter clause;
// the FieldExistsQuery is first rewritten into a MatchAllDocsQuery and then optimized away. That is not the focus of this article; see BooleanQuery#rewrite if you are interested
Query rewritten = indexSearcher.rewrite(booleanQuery);
// returns a CachingWrapperWeight from LRUQueryCache, which at the bottom wraps a TermWeight (since the filter in our example is a TermQuery)
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}
//... the code below is identical to the no-filter logic described earlier
}
// rewrite of the transformed BooleanQuery; since only the filter-related logic matters here, the code for the other branches is omitted
// BooleanQuery#rewrite
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (clauses.size() == 0) {
return new MatchNoDocsQuery("empty BooleanQuery");
}
// ...
// recursively rewrite
{
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.setMinimumNumberShouldMatch(getMinimumNumberShouldMatch());
boolean actuallyRewritten = false;
for (BooleanClause clause : this) {
Query query = clause.getQuery();
BooleanClause.Occur occur = clause.getOccur();
Query rewritten;
if (occur == Occur.FILTER || occur == Occur.MUST_NOT) {
// Clauses that are not involved in scoring can get some extra simplifications
// the query transformed above is wrapped into a ConstantScoreQuery again and then rewritten
rewritten = new ConstantScoreQuery(query).rewrite(indexSearcher);
if (rewritten instanceof ConstantScoreQuery) {
rewritten = ((ConstantScoreQuery) rewritten).getQuery();
}
} else {
rewritten = query.rewrite(indexSearcher);
}
if (rewritten != query || query.getClass() == MatchNoDocsQuery.class) {
// rewrite clause
actuallyRewritten = true;
if (rewritten.getClass() == MatchNoDocsQuery.class) {
switch (occur) {
case SHOULD:
case MUST_NOT:
// the clause can be safely ignored
break;
case MUST:
case FILTER:
return rewritten;
}
} else {
builder.add(rewritten, occur);
}
} else {
// leave as-is
builder.add(clause);
}
}
if (actuallyRewritten) {
return builder.build();
}
// ...
}
The search process for a filtered query
private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
// the query without a filter takes this branch
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
}
// since the filter is a TermQuery, this ultimately returns a TermQuery scorer
Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return NO_RESULTS;
}
// key step: build acceptDocs from the term scorer's iterator
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
int cost = acceptDocs.cardinality();
// check whether the filtered docIdSet has at most k docs; if so, use exactSearch directly
if (cost <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
}
// Perform the approximate kNN search
// approximateSearch was analyzed above, so it is not repeated here
TopDocs results = approximateSearch(ctx, acceptDocs, cost);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
}
}
// AbstractKnnVectorQuery#exactSearch
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorScorer vectorScorer = createVectorScorer(context, fi);
// a priority queue of size k, pre-filled with sentinel entries
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// iterate directly over the docIds accepted by the filter
boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
// pop the remaining sentinel entries (score < 0)
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
// AbstractKnnVectorQuery#createBitSet
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return ((BitSetIterator) iterator).getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator =
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
// builds a SparseFixedBitSet or a FixedBitSet depending on the filterIterator's cost
return BitSet.of(filterIterator, maxDoc);
}
}
// TermQuery.TermWeight#scorerSupplier
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
assert termStates == null || termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context))
: "The top-reader used to create Weight is not the same as the current reader's top-reader ("
+ ReaderUtil.getTopLevelContext(context);
// get the TermsEnum positioned on the term
final TermsEnum termsEnum = getTermsEnum(context);
if (termsEnum == null) {
return null;
}
final int docFreq = termsEnum.docFreq();
return new ScorerSupplier() {
private boolean topLevelScoringClause = false;
@Override
public Scorer get(long leadCost) throws IOException {
LeafSimScorer scorer =
new LeafSimScorer(simScorer, context.reader(), term.field(), scoreMode.needsScores());
if (scoreMode == ScoreMode.TOP_SCORES) {
return new TermScorer(
TermWeight.this,
termsEnum.impacts(PostingsEnum.FREQS),
scorer,
topLevelScoringClause);
} else {
return new TermScorer(
TermWeight.this,
termsEnum.postings(
null, scoreMode.needsScores() ? PostingsEnum.FREQS : PostingsEnum.NONE),
scorer);
}
}
@Override
public long cost() {
return docFreq;
}
@Override
public void setTopLevelScoringClause() throws IOException {
topLevelScoringClause = true;
}
};
}
Summary
- This article analyzed the search process of KnnFloatVectorQuery, covering the basics of vector search and the key source code. For filtered queries, the filter result is passed to HnswGraphSearcher as the acceptDocs parameter, which can incur a performance cost worth keeping in mind.