如何使用线段树求解区间最长连续相同字符串?

450 阅读1分钟

题目描述:

给你一个下标从 0 开始的字符串 s 。另给你一个下标从 0 开始、长度为 k 的字符串 queryCharacters ,一个下标从 0 开始、长度也是 k 的整数 下标 数组 queryIndices ,这两个都用来描述 k 个查询。第 i 个查询会将 s 中位于下标 queryIndices[i] 的字符更新为 queryCharacters[i] 。返回一个长度为 k 的数组 lengths ,其中 lengths[i] 是在执行第 i 个查询 之后 s 中仅由 单个字符重复 组成的 最长子字符串 的 长度 。

样例:

输入:s = "babacc", queryCharacters = "bcb", queryIndices = [1,3,3] 输出:[3,3,4] 解释:

  • 第 1 次查询更新后 s = "bbbacc" 。由单个字符重复组成的最长子字符串是 "bbb" ,长度为 3
  • 第 2 次查询更新后 s = "bbbccc" 。由单个字符重复组成的最长子字符串是 "bbb" 或 "ccc",长度为 3
  • 第 3 次查询更新后 s = "bbbbcc" 。由单个字符重复组成的最长子字符串是 "bbbb" ,长度为 4
  • 因此,返回 [3,3,4]

范围:

  • 1 <= s.length <= 10^5
  • s 由小写英文字母组成
  • k == queryCharacters.length == queryIndices.length
  • 1 <= k <= 10^5
  • queryCharacters 由小写英文字母组成
  • 0 <= queryIndices[i] < s.length

解题步骤

  • 解题第一步,先看数据范围,我们可以看见字符串的长度为 10^5,并且查询次数也为 10^5,每次查询还附带了更新,这个时候我们可以大概估计一下正解的复杂度,O(n^2)是肯定解决不了的,那么我们应该想O(nlogn)的算法。
  • 解题第二步,既然我们能想到O(nlogn)解决这个问题,那么思考什么数据结构能够保证更新和查询都是O(logn)的呢?很显然,有点数据结构基础的人都能想到树形结构,这道题的正解就是线段树。
  • 解题第三步,既然大概确定了线段树可能是正解,那么在想想如何构造这颗树,首先更新就不用说了,没啥巧妙的地方。主要是查询,我们需要每次查询到当前字符串最大的连续相同字符长度,那么我们必然要用线段树维护这个结构才行。
  • 解题第四步,如何维护这个线段树结构呢?我们每个节点维护左右区间(这是标准线段树必有的),既然要取最大的,那么我们还需要记录一个mx(记录当前区间最长的连续相同字符串),那么这个mx我们应该如何获取呢?我们能想到线段树的合并操作应该都是由左右子树进行合并来维护父节点的信息,如何合并?既然是相同字符,那么如果左子树的最右边一个字符等于右子树最左边的一个字符,那么是不是就可以合并?合并什么?左子树最右边连续相同字符的长度与右子树最左边连续相同的字符长度就可能变成当前节点区间的最长连续相同字符长度。那么由此思路我们可以知道我们还需要以下属性,lc,rc,表示区间左右边界字符,lx,rx表示左右边界字符连续相同的最大长度。

程序实现

class Tree{
    int l, r, mx, lx, rx;
    char lc, rc;
    public Tree(int l, int r) {
        this.l = l;
        this.r = r;
    }
}
public void pushUp(int k, Tree[] tree) {
    tree[k].mx = Math.max(tree[k << 1].mx, tree[k << 1 | 1].mx);
    tree[k].lx = tree[k << 1].lx;
    tree[k].rx = tree[k << 1 | 1].rx;
    tree[k].lc = tree[k << 1].lc;
    tree[k].rc = tree[k << 1 | 1].rc;
    if (tree[k << 1].rc == tree[k << 1 | 1].lc) {
        tree[k].mx = Math.max(tree[k].mx, tree[k << 1].rx + tree[k << 1 | 1].lx);
        if (tree[k << 1].lx == (tree[k << 1].r - tree[k << 1].l + 1)) {
            tree[k].lx = tree[k << 1].lx + tree[k << 1 | 1].lx;
        }
        if (tree[k << 1 | 1].rx == (tree[k << 1 | 1].r - tree[k << 1 | 1].l + 1)) {
            tree[k].rx = tree[k << 1 | 1].rx + tree[k << 1].rx;
        }
    }
}
public void build(int l, int r, int k, String s, Tree[] tree){
    tree[k] = new Tree(l, r);
    if(l == r) {
        tree[k].mx = tree[k].lx = tree[k].rx = 1;
        tree[k].lc = tree[k].rc = s.charAt(l - 1);
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, k << 1, s, tree);
    build(mid + 1, r, k << 1 | 1, s, tree);
    pushUp(k, tree);
}
public void update(int pos, char val, int k, Tree[] tree){
    if(tree[k].l == tree[k].r){
        tree[k].lc = tree[k].rc = val;
        return;
    }
    int mid = (tree[k].l + tree[k].r) >> 1;
    if(pos <= mid) {
        update(pos, val, k<<1, tree);
    }
    else {
        update(pos, val,k<<1|1, tree);
    }
    pushUp(k, tree);
}
public int query(int l, int r, int k, Tree[] tree){
    if(l <= tree[k].l && tree[k].r <= r){
        return tree[k].mx;
    }
    int mid = (tree[k].l + tree[k].r) >> 1, ans = 0;
    if(l <= mid) {
        ans = Math.max(ans, query(l, r, k << 1, tree));
    }
    if(r > mid) {
        ans = Math.max(ans, query(l, r, k<<1 | 1, tree));
    }
    return ans;
}

public int[] longestRepeating(String s, String queryCharacters, int[] queryIndices) {
    Tree[] tree = new Tree[s.length() << 2 | 1];
    build(1, s.length(), 1, s, tree);
    int[] ans = new int[queryIndices.length];
    for (int i = 0; i < queryIndices.length; i++) {
        update(queryIndices[i] + 1, queryCharacters.charAt(i), 1, tree);
        ans[i] = query(1, s.length(), 1, tree);
    }
    return ans;
}