使用 java 实现一个二分搜索树

89 阅读2分钟

学习二分搜素树,这种数据结构在搜索方面的速度是非常可佳的。

二分搜索树之所以快,有一个很重要的前提就是可比较性。

public class BST<E extends Comparable<E>> {
    // 包括元素 e 和 左右两个节点
    private class Node {
        E e;
        Node right, left;

        Node(E e) {
            this.e = e;
            this.right = null;
            this.left = null;
        }
    }

    private Node root;
    private int size;

    public BST() {
        this.root = null;
        this.size = 0;
    }

    public int size() {
        return this.size;
    }

    public boolean isEmpty() {
        return this.size == 0;
    }

    // 新增
    public void add(E e) {
        this.root = add(root, e);
    }

    /*
    对于返回值可以看传进来的和返回的,
    我们首先不考虑中间的过程,就考虑第一次传
    进来的是 root,那么返回的 node ,此时
    node 就是传进来的 root, 固然这里的返回不会出现任何问题
     */
    private Node add(Node node, E e) {
        if (node == null) {
            this.size++;
            return new Node(e);
        } else if (e.compareTo(node.e) > 0) {
            node.right = add(node.right, e);
        } else {
            node.left = add(node.left, e);
        }
        return node;
    }

    // 查找当前搜索树是否包括对应的值 e
    public boolean includes(E e) {
        return includes(this.root, e);
    }

    private boolean includes(Node node, E e) {
        if (node == null) return false;
        if (e.equals(node.e)) {
            return true;
        } else if (e.compareTo(node.e) > 0) {
            return includes(node.right, e);
        } else {
            return includes(node.left, e);
        }

    }

    /*
    删除  利用返回值来完成指向的变更,
    比如在链表中可能需要上一个节点来辅助,
    但如果利用返回值则可以很好的利用这一点
     */
    public void delete(E e) {
        root = delete(root, e);
    }

    private Node delete(Node node, E e) {
        if (node == null) return null;
        if (e.compareTo(node.e) > 0) {
            node.right = delete(node.right, e);
            return node;
        } else if (e.compareTo(node.e) < 0) {
            node.left = delete(node.left, e);
            return node;
        } else {
            if (node.left == null) {
                Node newNode = node.right;
                node.right = null;
                size --;
                return newNode;
            }

            if (node.right == null) {
                Node newNode = node.left;
                node.left = null;
                size --;
                return newNode;
            }

            Node replaceNode = findMaxValue(node.left);
            node.left = deleteMaxValue(node.left);
            replaceNode.right = node.right;
            replaceNode.left = node.left;
            node.left = null;
            node.right = null;
            return replaceNode;
        }
    }

    // 查找最大值的节点 方便在删除中使用
    private Node findMaxValue(Node node) {
        Node p = node;
        while (p.right != null) {
            p = p.right;
        }
        return p;
    }

    // 查找最小值的节点 方便在删除中使用
    private Node findMinValue(Node node) {
        Node p = node;
        while (p.left != null) {
            p = p.left;
        }
        return p;
    }
    
    // 删除最大值 跟上面相同利用返回值改变指向
    private Node deleteMaxValue(Node node) {
        if (node.right == null) {
            Node newNode = node.left;
            node.left = null;
            size--;
            return newNode;
        }
        node.right = deleteMaxValue(node.right);
        return node;
    }

    // 删除最小值 跟上面相同利用返回值改变指向
    private Node deleteMinValue(Node node) {
        if (node.left == null) {
            Node newNode = node.right;
            node.right = null;
            size--;
            return newNode;
        }
        node.left = deleteMinValue(node.left);
        return node;
    }

    // 打印 方便自己查看结果是否正确
    public void print() {
        print(this.root);
        System.out.println();
    }

    private void print(Node node) {
        if (node == null) {
            return;
        }
        System.out.print(node.e + " ");
        print(node.left);
        print(node.right);
    }
}

特别感谢 liuyubobobo 老师的课程。