Lucene源码系列(十二):FST的读取和查找

1,314 阅读7分钟

背景

在前面的文章中我们从源码以及图文示例的层面分别介绍了FST完整的构建逻辑,如果已经掌握了构建的逻辑,那么FST的读取和查找逻辑是比较简单的。

本文就是介绍如何使用FST,首先我们会介绍如何加载持久化的FST,然后介绍一些基本的查找逻辑,最后会介绍FST的几个使用场景。

读取时使用的arc结构(和构建时使用的是不同的)

// arc的label
private int label;
// arc的output
private T output;
// arc指向的节点的起始位置
private long target;
// arc的flag
private byte flags;
// arc的finalOutput
private T nextFinalOutput;
// 有两种情况,一种是下一个arc,另一种是下一个节点的位置(深度优先遍历的时候使用)
private long nextArc;
// 用来判断节点中arc的存储方式
private byte nodeFlags;

// 下面这些变量都是在使用固定长度存储arc的时候使用
private int bytesPerArc;
private long posArcsStart;
private int arcIdx;
private int numArcs;
private long bitTableStart;
private int firstLabel;
private int presenceIndex;

加载FST

public FST(DataInput metaIn, DataInput in, Outputs<T> outputs, FSTStore fstStore)
    throws IOException {
  bytes = null;
  this.fstStore = fstStore;
  this.outputs = outputs;

  this.version = CodecUtil.checkHeader(metaIn, FILE_FORMAT_NAME, VERSION_START, VERSION_CURRENT);
  if (metaIn.readByte() == 1) { // 是1的话说明存在空输入
    BytesStore emptyBytes = new BytesStore(10);
    int numBytes = metaIn.readVInt();
    emptyBytes.copyBytes(metaIn, numBytes);

    BytesReader reader = emptyBytes.getReverseReader();
    if (numBytes > 0) {
      reader.setPosition(numBytes - 1);
    }
    // 读取空输入的输出  
    emptyOutput = outputs.readFinalOutput(reader);
  } else {
    emptyOutput = null;
  }
  final byte t = metaIn.readByte(); // label的大小
  switch (t) {
    case 0:
      inputType = INPUT_TYPE.BYTE1;
      break;
    case 1:
      inputType = INPUT_TYPE.BYTE2;
      break;
    case 2:
      inputType = INPUT_TYPE.BYTE4;
      break;
    default:
      throw new CorruptIndexException("invalid input type " + t, in);
  }
  // root节点的起始位置
  startNode = metaIn.readVLong();
  // 读取fst的信息,存储在fstStore中
  long numBytes = metaIn.readVLong();
  this.fstStore.init(in, numBytes);
}

FST查找的底层依赖逻辑

构造虚拟arc,指向的是root节点

为什么要构造一个虚拟的arc,因为使用虚拟arc,并且其指向root节点,这样遍历或者查找都可以以arc为入口,统一查找的逻辑。查找或者遍历所有输入主要的两个操作是:查找下一个node和查找下一个arc,这两种操作都可以以arc为入口来实现,后面会有具体的方法介绍。

// arc就是最终的虚拟arc
public Arc<T> getFirstArc(Arc<T> arc) {
  T NO_OUTPUT = outputs.getNoOutput();

  if (emptyOutput != null) { 
    // 如果存在空输入,则这是BIT_FINAL_ARC 和 BIT_LAST_ARC
    // BIT_FINAL_ARC:空输入的情况,root就是可接受节点
    // BIT_LAST_ARC:虚拟arc只有一个,肯定是最后一个arc  
    arc.flags = BIT_FINAL_ARC | BIT_LAST_ARC;
    arc.nextFinalOutput = emptyOutput;
    if (emptyOutput != NO_OUTPUT) {
      arc.flags = (byte) (arc.flags() | BIT_ARC_HAS_FINAL_OUTPUT);
    }
  } else {
    // BIT_LAST_ARC:虚拟arc只有一个,肯定是最后一个arc    
    arc.flags = BIT_LAST_ARC;
    arc.nextFinalOutput = NO_OUTPUT;
  }
  arc.output = NO_OUTPUT;

  // 虚拟arc的target设置为root节点的起始位置
  arc.target = startNode;
  return arc;
}

读取node的第一个arc(如果是可接受节点,则是构造一个arc作为第一个arc)

这个方法是在深度优先遍历FST的时候使用的,定位到node的第一个arc。

  // 从follow指向的target中获取第一个arc,存储在arc变量中
  public Arc<T> readFirstTargetArc(Arc<T> follow, Arc<T> arc, BytesReader in) throws IOException {
    if (follow.isFinal()) { // 如果follow指向的是可接受节点,构造一个arc,用END_LABEL标记
      arc.label = END_LABEL;
      arc.output = follow.nextFinalOutput();
      arc.flags = BIT_FINAL_ARC;
      if (follow.target() <= 0) {
        arc.flags |= BIT_LAST_ARC;
      } else {
        // 可以看到深度优先遍历的时候,arc.nextArc指向的是下一个节点。
        // 这是为了处理某个输入是另一个输入前缀的情况,这种情况,深度优先遍历的下一个arc是在下一个节点中。  
        arc.nextArc = follow.target();
      }
      arc.target = FINAL_END_NODE;
      arc.nodeFlags = arc.flags;
      return arc;
    } else { // 如果follow不是可接受节点,则读取真实的第一个arc
      return readFirstRealTargetArc(follow.target(), arc, in);
    }
  }

从node中获取第一个arc

真实的读取node的第一个arc,这个方法的调用必须确保nodeAddress指向的节点存在arc。

public Arc<T> readFirstRealTargetArc(long nodeAddress, Arc<T> arc, final BytesReader in)
    throws IOException {
  // 定位到node的地址  
  in.setPosition(nodeAddress);
  
  // 读取flag  
  byte flags = arc.nodeFlags = in.readByte();  
  if (flags == ARCS_FOR_BINARY_SEARCH || flags == ARCS_FOR_DIRECT_ADDRESSING) { // 节点使用固定长度存储arc
    // 读取固定长度存储模式的各个头部信息  
    arc.numArcs = in.readVInt();
    arc.bytesPerArc = in.readVInt();
    arc.arcIdx = -1;
    if (flags == ARCS_FOR_DIRECT_ADDRESSING) {
      readPresenceBytes(arc, in);
      arc.firstLabel = readLabel(in);
      arc.presenceIndex = -1;
    }
    arc.posArcsStart = in.getPosition();
  } else { // node的起始位置就是第一个arc的位置
    // arc.nextArc指向了第一个arc的位置
    arc.nextArc = nodeAddress;
    arc.bytesPerArc = 0;
  }

  // 这样查找arc的下一个arc就得到了第一个arc  
  return readNextRealArc(arc, in);
}

读取下一个arc

因为node中的arc有3中存储方式,所以读取的时候也是分三种情况,需要注意的是二分查找存储方式在这种情况先并不需要使用二分查找,因为需要查找的是下一个arc,根据当前arc的下标,可以得到下一个arc的下标,有了下标,可以直接根据下标定位arc的位置。

public Arc<T> readNextRealArc(Arc<T> arc, final BytesReader in) throws IOException {
  switch (arc.nodeFlags()) {
    case ARCS_FOR_BINARY_SEARCH: // 二分查找存储的方式,知道下标可以直接定位
      assert arc.bytesPerArc() > 0;
      arc.arcIdx++;
      assert arc.arcIdx() >= 0 && arc.arcIdx() < arc.numArcs();
      in.setPosition(arc.posArcsStart() - arc.arcIdx() * arc.bytesPerArc());
      arc.flags = in.readByte();
      break;

    case ARCS_FOR_DIRECT_ADDRESSING: // 直接寻址
      assert BitTable.assertIsValid(arc, in);
      assert arc.arcIdx() == -1 || BitTable.isBitSet(arc.arcIdx(), arc, in);
      int nextIndex = BitTable.nextBitSet(arc.arcIdx(), arc, in);
      return readArcByDirectAddressing(arc, in, nextIndex, arc.presenceIndex + 1);

    default: // 线性遍历
      assert arc.bytesPerArc() == 0;
      // 定位到下一个arc的位置    
      in.setPosition(arc.nextArc());
      arc.flags = in.readByte();
  }
  // 读取arc的信息  
  return readArc(arc, in);
}

读取arc的信息

填充arc的结构字段信息。

// 注意,这里是已经定位到了in中arc所在的位置了
private Arc<T> readArc(Arc<T> arc, BytesReader in) throws IOException {
  if (arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING) { // 直接寻址的方式是的label需要计算
    // 直接寻址的label是第一个label的值加上arc的下标  
    arc.label = arc.firstLabel() + arc.arcIdx();
  } else {
    arc.label = readLabel(in); 
  }

  if (arc.flag(BIT_ARC_HAS_OUTPUT)) { // 读取output
    arc.output = outputs.read(in);
  } else {
    arc.output = outputs.getNoOutput();
  }

  if (arc.flag(BIT_ARC_HAS_FINAL_OUTPUT)) { // 读取finalOutput
    arc.nextFinalOutput = outputs.readFinalOutput(in);
  } else {
    arc.nextFinalOutput = outputs.getNoOutput();
  }

  if (arc.flag(BIT_STOP_NODE)) { // 这种情况,下一个节点的位置是存储在一起的
    if (arc.flag(BIT_FINAL_ARC)) {
      arc.target = FINAL_END_NODE;
    } else { // 一般逻辑不会出现,剪枝才有可能
      arc.target = NON_FINAL_END_NODE;
    }
    arc.nextArc = in.getPosition();  
  } else if (arc.flag(BIT_TARGET_NEXT)) { // 这种情况,下一个节点的位置是存储在一起的
    arc.nextArc = in.getPosition(); 
    if (!arc.flag(BIT_LAST_ARC)) {
      if (arc.bytesPerArc() == 0) { // 需要使用遍历的方式查找
        seekToNextNode(in);
      } else { // 固定长度的arc存储方式可以通过计算下标直接定位
        int numArcs =
            arc.nodeFlags == ARCS_FOR_DIRECT_ADDRESSING
                ? BitTable.countBits(arc, in)
                : arc.numArcs();
        in.setPosition(arc.posArcsStart() - arc.bytesPerArc() * numArcs);
      }
    }
    arc.target = in.getPosition();
  } else { 
    arc.target = readUnpackedNodeTarget(in);
    arc.nextArc = in.getPosition(); 
  }
  return arc;
}

读取标签

元信息的inputType记录了label占用的存储空间的大小,根据存储空间的大小读取label。

public int readLabel(DataInput in) throws IOException {
  final int v;
  if (inputType == INPUT_TYPE.BYTE1) {
    v = in.readByte() & 0xFF;
  } else if (inputType == INPUT_TYPE.BYTE2) {
    if (version < VERSION_LITTLE_ENDIAN) {
      v = Short.reverseBytes(in.readShort()) & 0xFFFF;
    } else {
      v = in.readShort() & 0xFFFF;
    }
  } else {
    v = in.readVInt();
  }
  return v;
}

在节点中查找目标label的arc

这里也可以想到,会分三种情况查找。

// labelToMatch:目标arc的label
// follow:从follow的target中查找目标arc
// arc:查找到匹配的arc会存储在这个变量中
// in:存储fst的输入流
public Arc<T> findTargetArc(int labelToMatch, Arc<T> follow, Arc<T> arc, BytesReader in)
    throws IOException {
  if (labelToMatch == END_LABEL) { // 如果要找的结束的label
    if (follow.isFinal()) { // 如果要找的是label是END_LABEL的话,必须follow是可接受节点
      if (follow.target() <= 0) {
        arc.flags = BIT_LAST_ARC;
      } else {
        arc.flags = 0;
        arc.nextArc = follow.target();
      }
      arc.output = follow.nextFinalOutput();
      arc.label = END_LABEL;
      arc.nodeFlags = arc.flags;
      return arc;
    } else {
      return null;
    }
  }

  if (!targetHasArcs(follow)) { // 要查找的节点没有arc,直接返回null
    return null;
  }

  // 定位到target的位置  
  in.setPosition(follow.target());

  // 读取flag  
  byte flags = arc.nodeFlags = in.readByte();
  if (flags == ARCS_FOR_DIRECT_ADDRESSING) { // 1.以直接寻址的方式查找
    // 读取用以直接寻址的头部信息  
    arc.numArcs = in.readVInt(); 
    arc.bytesPerArc = in.readVInt();
    readPresenceBytes(arc, in);
    arc.firstLabel = readLabel(in);
    arc.posArcsStart = in.getPosition();

    // 按直接寻址的方式查找,先从位图判断是否存在该label的arc,如果存在则获取存储的下标,直接定位到arc位置读取信息 
    int arcIndex = labelToMatch - arc.firstLabel();
    if (arcIndex < 0 || arcIndex >= arc.numArcs()) {
      return null; 
    } else if (!BitTable.isBitSet(arcIndex, arc, in)) {
      return null; 
    }
    return readArcByDirectAddressing(arc, in, arcIndex);
  } else if (flags == ARCS_FOR_BINARY_SEARCH) { // 2.以二分查找的方式查找
    // 读取用以二分查找的头部信息  
    arc.numArcs = in.readVInt();
    arc.bytesPerArc = in.readVInt();
    arc.posArcsStart = in.getPosition();
    // 二分查找的实现
    int low = 0;
    int high = arc.numArcs() - 1;
    while (low <= high) {
      int mid = (low + high) >>> 1;
      // +1是略过flag,定位到label的位置  
      in.setPosition(arc.posArcsStart() - (arc.bytesPerArc() * mid + 1));
      int midLabel = readLabel(in);
      final int cmp = midLabel - labelToMatch;
      if (cmp < 0) {
        low = mid + 1;
      } else if (cmp > 0) {
        high = mid - 1;
      } else { // 找到了,则下标设置为目标的前一个arc的下标,然后使用readNextRealArc方法读取
        arc.arcIdx = mid - 1;
        return readNextRealArc(arc, in);
      }
    }
    return null;
  }

  // 3.线性查找:先读取到第一个arc,然后使用  readNextRealArc 方法来遍历
  readFirstRealTargetArc(follow.target(), arc, in);
  while (true) {
    if (arc.label() == labelToMatch) {
      return arc;
    } else if (arc.label() > labelToMatch) {
      return null;
    } else if (arc.isLast()) {
      return null;
    } else {
      readNextRealArc(arc, in);
    }
  }
}

FST的常见使用场景

我们先看下示例代码,包含了FST的构建以及我们要介绍的3种使用场景:

public class FSTDemo {
    public static void main(String[] args) throws IOException {
        // 构建FST
        String[] inputValues = {"bat", "cat", "deep", "do", "dog", "dogs"};
        long[] outputvalues = {2, 5, 15, 10, 3, 2};
        PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
        FSTCompiler.Builder<Long> builder = new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
        FSTCompiler<Long> build = builder.build();
        IntsRefBuilder intsRefBuilder = new IntsRefBuilder();
        for (int i = 0; i < inputValues.length; i ++) {
            BytesRef bytesRef = new BytesRef(inputValues[i]);
            build.add(Util.toIntsRef(bytesRef, intsRefBuilder), outputvalues[i]);
        }
        FST<Long> fst = build.compile();
        
        // 跟据key获取value
        BytesRef bytesRef = new BytesRef(inputValues[3]);
        Long aLong = Util.get(fst, Util.toIntsRef(bytesRef, intsRefBuilder));
        System.out.println(aLong);

        // 遍历FST
        IntsRefFSTEnum<Long> fstEnum = new IntsRefFSTEnum<>(fst);
        IntsRefFSTEnum.InputOutput<Long> inputOutput;
        BytesRefBuilder scratch = new BytesRefBuilder();
        while ((inputOutput = fstEnum.next()) != null) {
            String input = Util.toBytesRef(inputOutput.input, scratch).utf8ToString();
            Long output = inputOutput.output;
            System.out.println(input + "\t" + output);
        }
        
        // 跟输入d自动补全
        String userInput = "d";
        BytesRef bytesRefUserInput = new BytesRef(userInput);
        IntsRef intsRefUserInput = Util.toIntsRef(bytesRefUserInput, new IntsRefBuilder());
        FST.Arc<Long> arc = fst.getFirstArc(new FST.Arc<>());
        for (int i = 0; i < intsRefUserInput.length; i ++) {
            if (fst.findTargetArc(intsRefUserInput.ints[intsRefUserInput.offset + i], arc, arc, fst.getBytesReader()) == null) {
                System.out.println("没找到d开头的");
            }
        }

        Util.TopResults<Long> results = Util.shortestPaths(fst, arc, PositiveIntOutputs.getSingleton().getNoOutput(), new Comparator<Long>() {
            @Override
            public int compare(Long o1, Long o2) {
                return o1.compareTo(o2);
            }
        }, 3, false);
        
        BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
        for (Util.Result<Long> result : results) {
            IntsRef intsRef = result.input;
            System.out.println(userInput + Util.toBytesRef(intsRef, bytesRefBuilder).utf8ToString());
        }
    }
}

场景一:根据input获取output

这是一种比较常用的场景,带输出的FST其实就是一个hashmap,可以根据key来查找value。这种场景的实现在org.apache.lucene.util.fst.Util#get(FST<T> fst, BytesRef input)和org.apache.lucene.util.fst.Util#get(FST<T> fst, IntsRef input)中,这两个方法实现一模一样,只是因为入参类型的区别,所以我们看一个就可以,这里我们可以看到finalOutput和output的生效条件。

public static <T> T get(FST<T> fst, BytesRef input) throws IOException {
  assert fst.inputType == FST.INPUT_TYPE.BYTE1;

  final BytesReader fstReader = fst.getBytesReader();

  // 获取虚拟的arc,指向的是root节点
  final FST.Arc<T> arc = fst.getFirstArc(new FST.Arc<>());

  T output = fst.outputs.getNoOutput();
  for (int i = 0; i < input.length; i++) {
    // 以arc的target为起点,查找满足label的arc,如果找到的话,把arc更新为满足查找条件的arc。
    // 这里有点绕,是因为实现上为了避免频繁创建arc对象,所以都是以原地更新的方式复用arc变量。
    if (fst.findTargetArc(input.bytes[i + input.offset] & 0xFF, arc, arc, fstReader) == null) {
      return null;
    }
    output = fst.outputs.add(output, arc.output());
  }

  // 如果最后一个arc指向的是可接受节点,则需要把finalOutput也加上  
  if (arc.isFinal()) {
    return fst.outputs.add(output, arc.nextFinalOutput());
  } else {
    return null;
  }
}

场景二:遍历FST

FST的遍历实现在FSTEnum中,它有两个实现类,分别是BytesRefFSTEnum和IntsRefFSTEnum,这两个实现中大部分逻辑都是一样,只是对input类型有所区别的操作不一样,我们以BytesRefFSTEnum为例子来看就行了。

遍历的入口在next方法中,这里使用了模板模式,真正的逻辑调用了父类的doNext方法:

public InputOutput<T> next() throws IOException {
  doNext();
  return setResult();
}

doNext先找深度优先遍历的起始arc,然后再执行深度优先遍历查找到第一个可接受节点。

protected void doNext() throws IOException {
  // 1.查找深度优先遍历的起始arc
  if (upto == 0) { // 如果寻找的是第一个输入,则起始遍历的arc就是root节点的第一个arc
    upto = 1;
    fst.readFirstTargetArc(getArc(0), getArc(1), fstReader);
  } else { // 否则就是就是当前arc的下一个arc
    while (arcs[upto].isLast()) {
      upto--;
      if (upto == 0) {
        return;
      }
    }  
    fst.readNextArc(arcs[upto], fstReader);
  }

  // 2.深度优先遍历查找可接受节点
  pushFirst();
}

private void pushLast() throws IOException {

  FST.Arc<T> arc = arcs[upto];
  assert arc != null;
  // 深度优先遍历,找到第一个可接受的节点
  while (true) {
    setCurrentLabel(arc.label());
    output[upto] = fst.outputs.add(output[upto - 1], arc.output());
    if (arc.label() == FST.END_LABEL) { // 找到可接受的节点
      break;
    }
    incr();

    arc = fst.readLastTargetArc(arc, getArc(upto), fstReader);
  }
}

场景三:根据前缀自动补全输入

这种场景我们每天都在接触,如下图所示:

query词自动补全.png

搜索引擎的query自动联想方案会更丰富和复杂,后面我们也会有专门的文章介绍。基于FST一般用来做term级的补全,我们把所有的term构建成一个FST,则可以根据输入的前缀,从FST中查找以该前缀开头并且存在可接受节点的路径,就是联想的结果。

因此,这个场景的查找分为两步:

  1. 获取输入前缀的起始路径集合
  2. 查找前缀路径到所有可接受节点的路径(深度优先遍历)
// 获取以node为起点的所有arc
public void addStartPaths(
    FST.Arc<T> node,  // 前缀的最后一个arc
    T startOutput,
    boolean allowEmptyString,
    IntsRefBuilder input,
    float boost,
    CharSequence context,
    int payload)
    throws IOException {
  if (startOutput.equals(fst.outputs.getNoOutput())) {
    startOutput = fst.outputs.getNoOutput();
  }

  FSTPath<T> path = new FSTPath<>(startOutput, node, input, boost, context, payload);
  fst.readFirstTargetArc(node, path.arc, bytesReader);

  // 获取node的所有arc  
  while (true) {
    if (allowEmptyString || path.arc.label() != FST.END_LABEL) {
      addIfCompetitive(path);
    }
    if (path.arc.isLast()) { // node的最后一个arc了
      break;
    }
    fst.readNextArc(path.arc, bytesReader);
  }
}

// 深度优先遍历,实现是借助了queue。
public TopResults<T> search() throws IOException {

  final List<Result<T>> results = new ArrayList<>();

  final BytesReader fstReader = fst.getBytesReader();
  final T NO_OUTPUT = fst.outputs.getNoOutput();

  int rejectCount = 0;

  // 只要topN个结果
  while (results.size() < topN) {

    FSTPath<T> path;

    if (queue == null) {
      break;
    }

    path = queue.pollFirst();

    if (path == null) {
      break;
    }

    if (acceptPartialPath(path) == false) {
      continue;
    }

    if (path.arc.label() == FST.END_LABEL) {
      path.input.setLength(path.input.length() - 1);
      results.add(new Result<>(path.input.get(), path.output));
      continue;
    }

    if (results.size() == topN - 1 && maxQueueDepth == topN) { // 只需要再找最后一个,queue不需要了
      queue = null;
    }

    while (true) {
      fst.readFirstTargetArc(path.arc, path.arc, fstReader);
      boolean foundZero = false;
      boolean arcCopyIsPending = false;
      while (true) {// 深度优先遍历所有path.arc指向的target的可用路径,加入到queue中
        if (comparator.compare(NO_OUTPUT, path.arc.output()) == 0) {
          if (queue == null) {
            foundZero = true;
            break;
          } else if (!foundZero) {
            arcCopyIsPending = true;
            foundZero = true;
          } else {
            addIfCompetitive(path);
          }
        } else if (queue != null) {
          addIfCompetitive(path);
        }
        if (path.arc.isLast()) {
          break;
        }
        if (arcCopyIsPending) {
          scratchArc.copyFrom(path.arc);
          arcCopyIsPending = false;
        }
        fst.readNextArc(path.arc, fstReader);
      }

      assert foundZero;

      if (queue != null && !arcCopyIsPending) {
        path.arc.copyFrom(scratchArc);
      }

      if (path.arc.label() == FST.END_LABEL) { // 找到一个结果
        path.output = fst.outputs.add(path.output, path.arc.output());
        if (acceptResult(path)) {
          results.add(new Result<>(path.input.get(), path.output));
        } else {
          rejectCount++;
        }
        break;
      } else {
        path.input.append(path.arc.label());
        path.output = fst.outputs.add(path.output, path.arc.output());
        if (acceptPartialPath(path) == false) {
          break;
        }
      }
    }
  }
  return new TopResults<>(rejectCount + topN <= maxQueueDepth, results);
}