轻松掌握数据结构:平衡二叉树

121 阅读7分钟

平衡二叉树

平衡二叉树(如 AVL 树)是一种自平衡的二叉搜索树(BST),在BST的基础上增加了平衡约束,保证左右子树的高度差控制在1以内,从而保证高效的查询和插入删除操作。

平衡因子:失衡节点左子树高度 减去 右子树高度 的绝对值,值可能大于1,等于0,小于1

举例

这是一个平衡二叉树

image-20250806154708312.png

这不是一个平衡二叉树

image-20250806155209890.png

**解释:**因为根节点6,左子节点的高度是3,右子节点的高度是1,所以平衡因子是3-1=2,所以失衡了,不是平衡二叉树。失衡节点是6。

为什么使用AVL树

BST在极端情况下会退化成链表的数据结构,使得所有操作的时间复杂度由O (log n) 降至 O (n),平衡二叉树就能解决这种问题。

旋转

旋转的作用是让失衡的树重新满足 "==左右子树高度差≤1==" 的条件。

失衡的条件

当某个节点的平衡因子(左子树高度 - 右子树高度)的绝对值>1 时,树就失衡了。

失衡的情况

  • 左左失衡 ------右旋转

    平衡因子>1,并且新增节点在失衡节点的左子树的左子树上。

  • 右右失衡-------左旋转

    平衡因子<1,并且新增节点在失衡节点的右子树的右子树上。

  • 左右失衡-------先左旋转,再右旋转

    平衡因子>1,并且新增节点在失衡节点的左子树的右子树上。

  • 右左失衡-------先右旋转,再左旋转

    平衡因子<1,并且新增节点在失衡节点的右子树的左子树上。

旋转规则

右旋转

旋转前
	      6
       / \
      4   8
     / \
    3   5
   /
  2  

标记节点:

​ 失衡节点x=6

​ 失衡节点的左子节点y=4

​ 失衡节点的左子节点的右子树T=5

        x
       / \
      y   8
     / \
    3   T
   /
  2

右旋转的过程

  1. 提升 y 为新的根节点: 在右旋转中,y(即 4)成为新的根节点。

  2. x 移动到新根节点y的右子树: 原根节点 x(即 6)成为新根节点 y(即 4)的右子节点。

  3. T 放到 x 的左子树: 原本 y 的右子树 T(即 5)被移动到 x 的左子树上。

    	  y
       / \
      3   x
     /   / \
    2   T   8
    
  旋转后
      4
     / \
    3   6
   /   / \
  2   5   8

左旋转

旋转前
	5
   / \
  4   7
     / \          
    6   8
         \
          9

标记节点:

​ 失衡节点x=5

​ 失衡节点的右子节点y=7

​ 失衡节点的右子节点的左子树T2=6

旋转过程:

  1. 第一步:将 y 提升为新的根节点
    • 旋转后,y(即 7)将成为新的根节点。
  2. 第二步:将原根节点 x 移动到 y 的左子树上
    • 旋转后,x(即 5)将成为 y 的左子节点。
  3. 第三步:将 y 的左子树 T2 移到 x 的右子树上
    • 原本 y 的左子树是 T2(即 6),现在将 T2 移到 x 的右子树上。
旋转后          
    7
   / \
  5   8
 / \     \
4   6     9

图解失衡情况

左左失衡 ------右旋转

image-20250806161603065.png

右右失衡-------左旋转

image-20250806164103030.png

左右失衡-------先左旋后右旋

image-20250806174827887.png

右左失衡-------先右旋后左旋

image-20250806180421974.png

java代码实现平衡二叉树

package com.stu.shujujiegou;

/**
 * @title:
 * @auther: eleven
 * @description:
 * @data: 2025/8/5 17:17
 * @parm:
 * @return:
 */
import java.util.LinkedList;
import java.util.Queue;

public class AVLTree<T extends Comparable<T>> {
    // AVL树节点类,包含高度信息
    private static class Node<T> {
        T data;          // 节点数据
        Node<T> left;    // 左子节点
        Node<T> right;   // 右子节点
        int height;      // 节点高度(以该节点为根的子树高度)

        public Node(T data) {
            this.data = data;
            this.height = 1;  // 新节点初始高度为1
            this.left = null;
            this.right = null;
        }
    }

    private Node<T> root;  // 根节点

    // 构造空树
    public AVLTree() {
        root = null;
    }

    /**
     * 获取节点高度(空节点高度为0)
     */
    private int height(Node<T> node) {
        return (node == null) ? 0 : node.height;
    }

    /**
     * 计算平衡因子 = 左子树高度 - 右子树高度
     */
    private int balanceFactor(Node<T> node) {
        return (node == null) ? 0 : height(node.left) - height(node.right);
    }

    /**
     * 更新节点高度 = 1 + 左右子树最大高度
     */
    private void updateHeight(Node<T> node) {
        if (node != null) {
            node.height = 1 + Math.max(height(node.left), height(node.right));
        }
    }

    /**
     * 右旋转操作(处理左左失衡)
     *        y               x
     *       / \             / \
     *      x   T3   →     T1   y
     *     / \                 / \
     *    T1  T2             T2  T3
     */
    private Node<T> rightRotate(Node<T> y) {
        Node<T> x = y.left;
        Node<T> T2 = x.right;

        // 执行旋转
        x.right = y;
        y.left = T2;

        // 更新高度
        updateHeight(y);
        updateHeight(x);

        // 返回新的根节点
        return x;
    }

    /**
     * 左旋转操作(处理右右失衡)
     *    x               y
     *   / \             / \
     *  T1  y     →     x   T3
     *     / \         / \
     *    T2  T3      T1  T2
     */
    private Node<T> leftRotate(Node<T> x) {
        Node<T> y = x.right;
        Node<T> T2 = y.left;

        // 执行旋转
        y.left = x;
        x.right = T2;

        // 更新高度
        updateHeight(x);
        updateHeight(y);

        // 返回新的根节点
        return y;
    }

    /**
     * 插入元素
     */
    public void insert(T data) {
        root = insert(root, data);
    }

    private Node<T> insert(Node<T> node, T data) {
        // 1. 执行标准BST插入
        if (node == null) {
            return new Node<>(data);
        }

        int compareResult = data.compareTo(node.data);
        if (compareResult < 0) {
            // 插入左子树
            node.left = insert(node.left, data);
        } else if (compareResult > 0) {
            // 插入右子树
            node.right = insert(node.right, data);
        } else {
            // 不允许插入重复值
            return node;
        }

        // 2. 更新当前节点高度
        updateHeight(node);

        // 3. 计算平衡因子,检查是否失衡
        int balance = balanceFactor(node);

        // 4. 处理四种失衡情况

        // 情况1:左左失衡(LL)- 右旋转
        if (balance > 1 && data.compareTo(node.left.data) < 0) {
            return rightRotate(node);
        }

        // 情况2:右右失衡(RR)- 左旋转
        if (balance < -1 && data.compareTo(node.right.data) > 0) {
            return leftRotate(node);
        }

        // 情况3:左右失衡(LR)- 先左旋转左子树,再右旋转当前节点
        if (balance > 1 && data.compareTo(node.left.data) > 0) {
            node.left = leftRotate(node.left);
            return rightRotate(node);
        }

        // 情况4:右左失衡(RL)- 先右旋转右子树,再左旋转当前节点
        if (balance < -1 && data.compareTo(node.right.data) < 0) {
            node.right = rightRotate(node.right);
            return leftRotate(node);
        }

        // 未失衡,返回原节点
        return node;
    }

    /**
     * 删除元素
     */
    public void delete(T data) {
        root = delete(root, data);
    }

    private Node<T> delete(Node<T> node, T data) {
        // 1. 执行标准BST删除
        if (node == null) {
            return null; // 未找到要删除的节点
        }

        int compareResult = data.compareTo(node.data);
        if (compareResult < 0) {
            // 左子树中删除
            node.left = delete(node.left, data);
        } else if (compareResult > 0) {
            // 右子树中删除
            node.right = delete(node.right, data);
        } else {
            // 找到要删除的节点

            // 情况1:叶子节点或只有一个子节点
            if (node.left == null || node.right == null) {
                Node<T> temp = (node.left != null) ? node.left : node.right;

                // 子节点为空(叶子节点)
                if (temp == null) {
                    temp = node;
                    node = null;
                } else {
                    // 一个子节点,用子节点替换当前节点
                    node = temp;
                }
            } else {
                // 情况2:有两个子节点
                // 找到右子树的最小值节点
                Node<T> temp = findMinNode(node.right);
                // 用最小值替换当前节点值
                node.data = temp.data;
                // 删除右子树中的最小值节点
                node.right = delete(node.right, temp.data);
            }
        }

        // 如果树为空,返回null
        if (node == null) {
            return null;
        }

        // 2. 更新当前节点高度
        updateHeight(node);

        // 3. 计算平衡因子
        int balance = balanceFactor(node);

        // 4. 处理失衡情况

        // 左左失衡
        if (balance > 1 && balanceFactor(node.left) >= 0) {
            return rightRotate(node);
        }

        // 左右失衡
        if (balance > 1 && balanceFactor(node.left) < 0) {
            node.left = leftRotate(node.left);
            return rightRotate(node);
        }

        // 右右失衡
        if (balance < -1 && balanceFactor(node.right) <= 0) {
            return leftRotate(node);
        }

        // 右左失衡
        if (balance < -1 && balanceFactor(node.right) > 0) {
            node.right = rightRotate(node.right);
            return leftRotate(node);
        }

        return node;
    }

    /**
     * 查找以node为根的树中的最小值节点(最左节点)
     */
    private Node<T> findMinNode(Node<T> node) {
        Node<T> current = node;
        while (current.left != null) {
            current = current.left;
        }
        return current;
    }

    /**
     * 查找以node为根的树中的最大值节点(最右节点)
     */
    private Node<T> findMaxNode(Node<T> node) {
        Node<T> current = node;
        while (current.right != null) {
            current = current.right;
        }
        return current;
    }

    /**
     * 查找元素是否存在
     */
    public boolean contains(T data) {
        return contains(root, data);
    }

    private boolean contains(Node<T> node, T data) {
        if (node == null) {
            return false;
        }

        int compareResult = data.compareTo(node.data);
        if (compareResult < 0) {
            return contains(node.left, data);
        } else if (compareResult > 0) {
            return contains(node.right, data);
        } else {
            return true; // 找到匹配节点
        }
    }

    /**
     * 前序遍历(根 -> 左 -> 右)
     */
    public void preOrder() {
        preOrder(root);
        System.out.println();
    }

    private void preOrder(Node<T> node) {
        if (node != null) {
            System.out.print(node.data + " ");
            preOrder(node.left);
            preOrder(node.right);
        }
    }

    /**
     * 中序遍历(左 -> 根 -> 右)
     * 对于BST,中序遍历结果为升序序列
     */
    public void inOrder() {
        inOrder(root);
        System.out.println();
    }

    private void inOrder(Node<T> node) {
        if (node != null) {
            inOrder(node.left);
            System.out.print(node.data + " ");
            inOrder(node.right);
        }
    }

    /**
     * 后序遍历(左 -> 右 -> 根)
     */
    public void postOrder() {
        postOrder(root);
        System.out.println();
    }

    private void postOrder(Node<T> node) {
        if (node != null) {
            postOrder(node.left);
            postOrder(node.right);
            System.out.print(node.data + " ");
        }
    }

    /**
     * 层序遍历(按层次从上到下,从左到右)
     */
    public void levelOrder() {
        if (root == null) {
            return;
        }

        Queue<Node<T>> queue = new LinkedList<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            Node<T> node = queue.poll();
            System.out.print(node.data + " ");

            if (node.left != null) {
                queue.add(node.left);
            }
            if (node.right != null) {
                queue.add(node.right);
            }
        }
        System.out.println();
    }

    /**
     * 获取树中的最小值
     */
    public T findMin() {
        if (root == null) {
            throw new IllegalStateException("树为空,无法查找最小值");
        }
        return findMinNode(root).data;
    }

    /**
     * 获取树中的最大值
     */
    public T findMax() {
        if (root == null) {
            throw new IllegalStateException("树为空,无法查找最大值");
        }
        return findMaxNode(root).data;
    }

    /**
     * 检查树是否平衡(用于测试)
     */
    public boolean isBalanced() {
        return isBalanced(root);
    }

    private boolean isBalanced(Node<T> node) {
        if (node == null) {
            return true;
        }

        int balance = balanceFactor(node);
        // 平衡因子绝对值必须 <= 1,且左右子树都必须平衡
        return Math.abs(balance) <= 1 && isBalanced(node.left) && isBalanced(node.right);
    }

    /**
     * 清空树
     */
    public void clear() {
        root = null;
    }

    /**
     * 检查树是否为空
     */
    public boolean isEmpty() {
        return root == null;
    }

    // 测试方法
    public static void main(String[] args) {
        AVLTree<Integer> avlTree = new AVLTree<>();

        // 插入测试 - 包含会导致失衡的场景
        avlTree.insert(10);
        avlTree.insert(20);
        avlTree.insert(30); // 触发右右失衡,执行左旋转
        avlTree.insert(40);
        avlTree.insert(50);
        avlTree.insert(25); // 触发右左失衡,先右旋转再左旋转

        System.out.println("插入后的中序遍历(升序):");
        avlTree.inOrder(); // 输出: 10 20 25 30 40 50

        System.out.println("插入后的前序遍历:");
        avlTree.preOrder(); // 输出: 30 20 10 25 40 50

        System.out.println("插入后的层序遍历:");
        avlTree.levelOrder(); // 输出: 30 20 40 10 25 50

        System.out.println("树是否平衡: " + avlTree.isBalanced()); // 输出: true

        // 删除测试
        avlTree.delete(30);
        System.out.println("\n删除30后的中序遍历:");
        avlTree.inOrder(); // 输出: 10 20 25 40 50

        System.out.println("删除30后的前序遍历:");
        avlTree.preOrder(); // 输出: 40 20 10 25 50

        System.out.println("删除后树是否平衡: " + avlTree.isBalanced()); // 输出: true

        // 查找和最值测试
        System.out.println("\n是否包含25: " + avlTree.contains(25)); // 输出: true
        System.out.println("是否包含30: " + avlTree.contains(30)); // 输出: false
        System.out.println("最小值: " + avlTree.findMin()); // 输出: 10
        System.out.println("最大值: " + avlTree.findMax()); // 输出: 50
    }
}