用Java实现线段树

1,448 阅读4分钟

线段树是为区间更新和区间查询而生的数据结构,旨在快速解决区间问题。

一般来说,线段树是不会加节点的,也不支持动态添加节点。线段树也是二叉树的一种,不过它的节点是以一个区间来定义节点的。具有一个单一区间的就是叶子节点。所以线段树,本质上就是一棵区间树。

我们在查找的时候,只需要找出结果区间由哪些子区间构成即可。

实现代码

首先定义出基础的结构

public class SegmentTree {
    
    private Integer value;
    private Integer maxValue;

    private Integer l;
    private Integer r;
    
    private SegmentTree leftChild;
    private SegmentTree rightChild;
}

l和r用来唯一刻画这个区间。然后其他的内容,与标准的二叉树没得任何区别。

建树过程

与二叉树建树没得区别,我们这里采用前序建树的方式进行。代码如下:

public static SegmentTree buildTree(int left, int right, int[] value) {
    if (left > right) {
        return null;
    }

    SegmentTree node = new SegmentTree();
    node.setValue(value[left]);
    node.setL(left);
    node.setR(right);
    if (left == right) {
        // TODO: 2022/1/17 退出条件
        node.setMaxValue(node.getValue());
        return node;
    }
    int mid = (left + right) >>> 1;
    node.setLeftChild(buildTree(left, mid, value));
    node.setRightChild(buildTree(mid + 1, right, value));
    if (Objects.isNull(node.getLeftChild())) {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getValue());
        } else {
            node.setMaxValue(node.getRightChild().getMaxValue());
        }
    } else {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getLeftChild().getMaxValue());
        } else {
            node.setMaxValue(Math.max(node.getLeftChild().getMaxValue(),
                                      node.getRightChild().getMaxValue()));
        }
    }
    return node;
}

可以看见,这里的叶子节点判断条件就是 left == right。其他方面和二叉树没有任何区别。

查询区间最大值

public static Integer getMaxValue(SegmentTree segmentTree, int left, int right) {
    if (Objects.isNull(segmentTree)) return null;
    if (segmentTree.getL() == left && segmentTree.getR() == right) {
        System.out.println("获取了区间 [" + left + "," + right + "] 的最大值" + segmentTree.getMaxValue());
        return segmentTree.getMaxValue();
    }
    int segMid = (segmentTree.getL() + segmentTree.getR()) >>> 1;
    if (segMid < left) {
        return getMaxValue(segmentTree.getRightChild(), left, right);
    }
    if (segMid >= right) {
        return getMaxValue(segmentTree.getLeftChild(), left, right);
    }
    // TODO: 2022/1/17 左半边答案
    Integer leftMax = getMaxValue(segmentTree.getLeftChild(), left, segMid);
    Integer rightMax = getMaxValue(segmentTree.getRightChild(), segMid + 1, right);
    if (Objects.isNull(leftMax)) {
        if (Objects.isNull(rightMax)) {
            return -100000;
        } else {
            return rightMax;
        }
    } else {
        if (Objects.isNull(rightMax)) {
            return leftMax;
        } else {
            return Math.max(leftMax, rightMax);
        }
    }
}

从上面的代码分析,设当前节点的区间为【L,R】,那么对于区间[l,r]的最大值来说,就需要进行分类讨论,如果LR的区间中点Mid在lr区间的左边,那么max(lr) = max(右子树,l,r);如果LR的区间中点在lr区间的右边,则max(lr) = max(左子树,l,r);如果Mid在lr区间里面,则 max(lr) = max(左子树,l,mid) 和 max(右子树,mid+1,r)中的较大值。

下面我们来看看测试用例和运行结果:

public static void main(String[] args) {
    int[] a = new int[]{2, 5, 4, 7, 6, 0, 1, -1, 2, 3, 6, 7, 0, 2, 9, 8, 5, 4, 7, 2};
    SegmentTree segmentTree = buildTree(0, a.length - 1, a);
    System.out.println(getMaxValue(segmentTree, 0, 16));
}

结果如下

获取了区间 [0,9] 的最大值7 获取了区间 [10,14] 的最大值9 获取了区间 [15,16] 的最大值8 9

获取区间和

现在需要对原来的建树过程进行改造,首先,在基础结构中添加sum字段

public class SegmentTree {

    private Integer value;
    private Integer maxValue;
    private Integer sum;

    private Integer l;
    private Integer r;

    private SegmentTree leftChild;
    private SegmentTree rightChild;
}

然后在建树方法中,添加对和的维护

public static SegmentTree buildTree(int left, int right, int[] value) {
    if (left > right) {
        return null;
    }

    SegmentTree node = new SegmentTree();
    node.setValue(value[left]);
    node.setL(left);
    node.setR(right);
    if (left == right) {
        // TODO: 2022/1/17 退出条件
        node.setMaxValue(node.getValue());
        node.setSum(node.getValue());
        return node;
    }
    int mid = (left + right) >>> 1;
    node.setLeftChild(buildTree(left, mid, value));
    node.setRightChild(buildTree(mid + 1, right, value));
    if (Objects.isNull(node.getLeftChild())) {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getValue());
            node.setSum(node.getValue());
        } else {
            node.setMaxValue(node.getRightChild().getMaxValue());
            node.setSum(node.getRightChild().getSum());
        }
    } else {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getLeftChild().getMaxValue());
            node.setSum(node.getLeftChild().getSum());
        } else {
            node.setMaxValue(Math.max(node.getLeftChild().getMaxValue(),
                                      node.getRightChild().getMaxValue()));
            node.setSum(node.getLeftChild().getSum() + node.getRightChild().getSum());
        }
    }
    return node;
}

然后获取总和

public static Integer getSum(SegmentTree segmentTree, int left, int right) {
    if (Objects.isNull(segmentTree)) return null;
    if (segmentTree.getL() == left && segmentTree.getR() == right) {
        System.out.println("获取了区间 [" + left + "," + right + "] 的和" + segmentTree.getSum());
        return segmentTree.getSum();
    }
    int segMid = (segmentTree.getL() + segmentTree.getR()) >>> 1;
    if (segMid < left) {
        return getSum(segmentTree.getRightChild(), left, right);
    }
    if (segMid >= right) {
        return getSum(segmentTree.getLeftChild(), left, right);
    }
    // TODO: 2022/1/17 左半边答案
    Integer leftSum = getSum(segmentTree.getLeftChild(), left, segMid);
    Integer rightSum = getSum(segmentTree.getRightChild(), segMid + 1, right);
    if (Objects.isNull(leftSum)) {
        if (Objects.isNull(rightSum)) {
            return segmentTree.getSum();
        } else {
            return rightSum;
        }
    } else {
        if (Objects.isNull(rightSum)) {
            return leftSum;
        } else {
            return leftSum + rightSum;
        }
    }
}

测试程序和结果如下:

public static void main(String[] args) {
    int[] a = new int[]{2, 5, 4, 7, 6, 0, 1, -1, 2, 3, 6, 7, 0, 2, 9, 8, 5, 4, 7, 2};
    SegmentTree segmentTree = buildTree(0, a.length - 1, a);
    System.out.println(getSum(segmentTree,0,3));
}

获取了区间 [0,2] 的和11 获取了区间 [3,3] 的和7 18

单点更新

/**
     * 这里的left == right
     *
     * @param segmentTree
     * @param left
     * @param right
     * @param value
     */
public static void update(SegmentTree segmentTree, int left, int right, int value) {
    if (segmentTree.getL() == left && segmentTree.getR() == right) {
        segmentTree.setValue(value);
        segmentTree.setMaxValue(value);
        segmentTree.setSum(value);
        return;
    }
    int mid = (segmentTree.getL() + segmentTree.getR()) >>> 1;
    if (mid >= left) {
        update(segmentTree.getLeftChild(), left, right, value);
    }
    if (mid < left) {
        update(segmentTree.getRightChild(), left, right, value);
    }
    segmentTree.setMaxValue(Math.max(segmentTree.getLeftChild().getMaxValue(),segmentTree.getRightChild().getMaxValue()));
    segmentTree.setSum(segmentTree.getLeftChild().getSum() + segmentTree.getRightChild().getSum());
}

更新的时候也是利用递归的方法,不断从左右节点中寻找到需要被更新的区间,同时更新上级节点的最新值。

总结

可以按需进行延伸,记住一点,线段树是以区间为维度的二叉树,或者说,是以二维维度进行刻画的二叉树,普通二叉树只有一维。这样一来,我们在计算多维度的值的时候,其实也可以利用这样的方式构建线段树(二维树,三维树,n维树)。不管几维树,找到结束状态和下级子状态就是关键中的关键。典型的方法就是分类讨论,前期不用怕分得过细,细了可以进行合并。