图解 LeetCode 将二叉搜索树变平衡(递归 + 非递归)

321 阅读7分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第27天,点击查看活动详情


大家好呀,我是帅蛋。

今天我们将二叉搜索树变平衡,这道题看起来是二叉搜索树到平衡二叉树的单个过程,实际上却是一道组合问题,而拆分出来的单个问题都是我们之前碰到过的。

我一直说,做题做到现在这个阶段,很多题目的解决都是从我们过去学过的知识中寻找办法,这道题就是很好的例子。

1382-0

下面就让我们来一起搞一下这道题~

1382-1

LeetCode 1382:将二叉搜索树变平衡

题意

给你一棵二叉搜索树,请你返回一棵平衡后的二叉搜索树,新生成的树应该与原来的树有着相同的节点值。如有多种构造方法,请你返回任意一种。

如果一棵二叉搜索树中,每个节点的两棵子树高度差不超过 1,我们称这棵二叉搜索树是平衡的。

示例

输入:root = [1,null,2,null,3,null,4,null,null]

输出:[2,1,3,null,null,null,4]

解释:这不是唯一的正确答案,[3,1,4,null,2,null,null] 也是一个可行的构造方案。

1382-2

提示

  • 树节点的数目在 [1,10^4] 范围内。
  • 1 <= Node.val <= 10^5

题目解析

将二叉搜索树变平衡这道题,如果没有之前的实战题目的沉淀,是蛮头疼的难题,因为单纯的在树上操作难度极大。

但是如果你看过我写的下面这道题:

ACM 选手图解 LeetCode 将有序数组转化为二叉搜索树

其实在上面这道题中,虽然说的是转换成二叉搜索树,其实按照里面的思路,转化成的是一棵平衡二叉树。

那现在就是如何将二叉搜索树转化成有序数组,这个更简单,说过无数次了:对二叉搜索树进行中序遍历时,得到的结果是一个有序的序列

梳理到这儿就明了了,2 步走:

(1) 将二叉搜索树转化为有序序列(中序遍历)。

(2) 将有序序列构造成平衡二叉树。这个分为 3 步:

  • 有序数组中间节点为根节点。
  • 根节点左侧区间为左子树。
  • 根节点右侧区间为右子树。

这里有一点大家稍微注意一下,那就是找【中间节点】。

如果数组是奇数那没问题,中间那个就是根节点,如果数组是偶数个,那中间节点就是两个,可能你有疑问,这个时候该怎么取?

1382-3

其实很简单,取左边那个或者右边那个都可以,都能构造出平衡二叉树。

1382-4

递归法

实现【递归算法】,两步走起来:

  • 找出重复的子问题(递推公式)。
  • 终止条件。

(1) 找出重复的子问题。

在这重复的子问题其实就是“中序遍历的重复子问题”和“有序序列构造平衡二叉树”的重复子问题。

首先中序遍历的重复子问题

中序遍历的顺序是:左子树、根、右子树。

对于左子树、右子树来说,也是同样的遍历顺序。

所以这个重复的子问题就是:先遍历左子树、再取根节点,最后遍历右子树

self.inOrder(root.left, nums)
nums.append(root.val)
self.inOrder(root.right, nums)

其次是有序序列构造平衡二叉树的重复子问题:

在题目解析中讲过了,有序序列构造平衡二叉树:

  • 有序数组中间节点为根节点。
  • 根节点左侧区间为左子树。
  • 根节点右侧区间为右子树。

那重复的子问题就是,找到根节点,递归构造左子树,递归构造右子树。

mid = left + (right - left) // 2
# 根节点
midNode = TreeNode(nums[mid])
# 递归构造左子树
midNode.left  = self.process(nums, left, mid - 1)
# 递归构造右子树
midNode.right = self.process(nums, mid + 1, right)

(2) 确定终止条件。

中序遍历的终止条件是:当前的节点为空,空的没啥好遍历的。

if root == None:
    return 

对于有序序列构造平衡二叉树的终止条件来说,当 left > right 时,就终止,返回 None,因为这个时候就是空节点。

if left > right:
    return None

接下来我们来看详细的代码。

1382-5

Python 代码实现

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
​
    # 中序遍历,将二叉搜索树转为有序数组
    def inOrder(self, root, nums):
        if root == None:
            return
        self.inOrder(root.left, nums)
        nums.append(root.val)
        self.inOrder(root.right, nums)
​
    # 将有序数组转为平衡二叉树
    def process(self, nums, left, right):
        if left > right:
            return None
        # 找数组中间元素
        mid = left + (right - left) // 2
         # 根节点
        midNode = TreeNode(nums[mid])
        # 递归构造左子树
        midNode.left = self.process(nums, left, mid - 1)
        # 递归构造右子树
        midNode.right = self.process(nums, mid + 1, right)
​
        return midNode
​
    def balanceBST(self, root: TreeNode) -> TreeNode:
        nums = []
        self.inOrder(root, nums)
        return self.process(nums, 0, len(nums) - 1)

Java 代码实现

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // 中序遍历,将二叉搜索树转为有序数组
    public void inOrder(TreeNode root, List<Integer> nums){
        if(root == null){
            return ;
        }
        inOrder(root.left, nums);
        nums.add(root.val);
        inOrder(root.right, nums);
    }
    // 将有序数组转为平衡二叉树
    public TreeNode process(List<Integer> nums, int left, int right){
        if(left > right){
            return null;
        }
        // 找数组中间元素
        int mid = left + ((right - left) >> 1);
        // 根节点
        TreeNode midNode = new TreeNode(nums.get(mid));
        // 递归构造左子树
        midNode.left = process(nums, left, mid - 1);
        // 递归构造右子树
        midNode.right = process(nums, mid + 1, right);
​
        return midNode;
    }
​
    public TreeNode balanceBST(TreeNode root) {
        List<Integer> nums = new ArrayList<Integer>();
        inOrder(root, nums);
        TreeNode rootNew = process(nums, 0, nums.size() - 1);
        return rootNew;
    }
}

递归法使用中序遍历将二叉搜索树转化为有序序列的时间复杂度为 O(n),有序序列构造平衡二叉树时每个元素都要访问到,且构造单个节点的时间复杂度为 O(1),即有序序列构造平衡二叉树的时间复杂度也为 O(n),所以最终的时间复杂度为 O(n)

因为初始是二叉搜索树,构造的又是二叉平衡树,所以无论是递归版的中序遍历或者递归版的有序序列构造平衡二叉树,调用的栈大小都是 O(logn),此外还使用了一个数组 nums 存储中序遍历之后的有序序列,这个的空间复杂度是 O(n),所以总的空间复杂度为 O(n)

非递归法(迭代)

非递归法的话其实还是上面说的 2 步:

(1)非递归的中序遍历。

(2)非递归的将有序序列构造成平衡二叉树。

我们以下图为例:

1382-6

非递归的中序遍历我就不在这图解了,都说过多少次了,如果你实在不会,看我下面这篇文章就稳了:

ACM 选手带你玩转二叉树前中后序遍历(非递归版)

中序遍历之后的有序序列 nums = [1, 2, 3, 4]。

下面我们重点看一下有序序列构造平衡二叉树。

这里为了模拟构造二叉搜索树的过程,需要用到 3 个队列:

  • rootQue 存放遍历的节点。
  • leftQue 存放左区间的下标。
  • rightQue 存放右区间的下标。

之后就是不断的模拟寻找根节点,构造左子树和构造右子树。

首先初始化二叉搜索树的根节点以及 3 个队列:

1382-7

# 初始化根节点
root = TreeNode(0)
# 队列存放遍历的节点
rootQue = [root]
# 队列存放左区间下标
leftQue = [0]
# 队列存放右区间下标
rightQue = [len(nums) - 1]

第 1 步,rootQue 不为空,cur 记录 rootQur 出队列节点 0,left 记录 leftQue 出队列下标 0,right 记录 rightQue 出队列下标 3。

cur = rootQue.pop(0)
left = leftQue.pop(0)
right = rightQue.pop(0)

此时的中间下标 mid = 1, nums[1] = 2,赋值给 cur,即此时 cur.val = nums[1] = 2。

1382-8

# 找数组中间元素
mid = left + (right - left) // 2
# 将中间元素值赋值给节点
cur.val = nums[mid]

接下来处理左区间,此时 left = 0,mid = 1,left < mid,初始化 cur 节点的左孩子节点,将左孩子 cur.left、左区间的左下标 0 和左区间的右下标 mid - 1 = 0 分别入队列。

1382-9

if left < mid:
    cur.left = TreeNode(0)
    rootQue.append(cur.left)
    leftQue.append(left)
    rightQue.append(mid - 1)

下面处理右区间,同理初始化 cur 节点的右孩子节点,将右孩子 cur.right、右区间的左下标 mid + 1 = 2 和右区间的右下标 3 分别入队列。

1382-10

if right > mid:
    cur.right = TreeNode(0)
    rootQue.append(cur.right)
    leftQue.append(mid + 1)
    rightQue.append(right)

第 2 步,rootQue 不为空,cur 记录 rootQur 出队列节点 0,left 记录 leftQue 出队列下标 0,right 记录 rightQue 出队列下标 0。

此时的中间下标 mid = 0, nums[0] = 1,赋值给 cur,即此时 cur.val = nums[0] = 1。

1382-11

因为此时 left = 0 并不小于 mid = 0,所以不处理左区间,同样此时 right = 0 并不大于 mid,所以也不处理右区间。

第 3 步,rootQue 不为空,cur 记录 rootQur 出队列节点 0,left 记录 leftQue 出队列下标 2,right 记录 rightQue 出队列下标 3。

此时的中间下标 mid = 2, nums[2] = 3,赋值给 cur,即此时 cur.val = nums[2] = 3。

1382-12

此时 left = 2 并不小于 mid = 2,所以不处理左区间,此时 right = 3 > mid,我们来处理右区间。

初始化 cur 节点的右孩子节点,将右孩子 cur.right、右区间的左下标 mid + 1 = 3 和右区间的右下标 3 分别入队列。

1382-13

第 4 步,rootQue 不为空,cur 记录 rootQur 出队列节点 0,left 记录 leftQue 出队列下标 3,right 记录 rightQue 出队列下标 3。

此时的中间下标 mid = 3, nums[3] = 4,赋值给 cur,即此时 cur.val = nums[3] = 4。

1382-14

此时 left = 3 并不小于 mid = 3,所以不处理左区间,同样此时 right = 3 并不大于 mid,所以也不处理右区间。

至此 rootQue 为空,返回平衡二叉树,程序结束。

Python 代码实现

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:

    def inOrder(self, root):
        # 注意:根节点为空,直接返回空列表
        if not root:
            return []

        stack = []
        nums = []

        while root or stack:
            # 一直向左子树走,每一次将当前节点保存到栈中
            if root:
                stack.append(root)
                root = root.left
            # 当前节点为空,证明走到了最左边,从栈中弹出节点加入结果数组
            # 开始对右子树重复上述过程。
            else:
                cur = stack.pop()
                nums.append(cur.val)
                root = cur.right

        return nums

    def balanceBST(self, root: TreeNode) -> TreeNode:

        nums = self.inOrder(root)

        if len(nums) == 0:
            return None
        # 初始化根节点
        root = TreeNode(0)
        # 队列存放遍历的节点
        rootQue = [root]
        # 队列存放左区间下标
        leftQue = [0]
        # 队列存放右区间下标
        rightQue = [len(nums) - 1]

        while rootQue:
            cur = rootQue.pop(0)
            left = leftQue.pop(0)
            right = rightQue.pop(0)
            # 找数组中间元素
            mid = left + (right - left) // 2
            # 将中间元素值赋值给节点
            cur.val = nums[mid]
            # 处理左区间
            if left < mid:
                cur.left = TreeNode(0)
                rootQue.append(cur.left)
                leftQue.append(left)
                rightQue.append(mid - 1)
            # 处理右区间
            if right > mid:
                cur.right = TreeNode(0)
                rootQue.append(cur.right)
                leftQue.append(mid + 1)
                rightQue.append(right)                
        return root

Java 代码实现

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    
    public List<Integer> inOrder(TreeNode root){
        if(root == null){
            return null;
        }
        List<Integer> nums = new ArrayList<Integer>();
        Stack<TreeNode> stack = new Stack<TreeNode>();
        while(stack.size() > 0 || root != null){
            if(root != null){
                stack.add(root);
                root = root.left;
            }
            else{
                TreeNode cur = stack.pop();
                nums.add(cur.val);
                root = cur.right;
            }
        }
        return nums;
    }

    public TreeNode balanceBST(TreeNode root) {
        List<Integer> nums = new ArrayList<Integer>();
        nums = inOrder(root);
        if(nums.size() == 0){
            return null;
        }
        // 初始化根节点
        TreeNode rootNew = new TreeNode(0);
        // 队列存放遍历的节点
        Queue<TreeNode> rootQue = new LinkedList<>();
        // 队列存放左区间下标
        Queue<Integer> leftQue = new LinkedList<>();
        // 队列存放右区间下标
        Queue<Integer> rightQue = new LinkedList<>();
        // 初始化 3 个队列
        rootQue.offer(rootNew);
        leftQue.offer(0);
        rightQue.offer(nums.size() - 1);

        while (!rootQue.isEmpty()){
            TreeNode cur = rootQue.poll();
            int left = leftQue.poll();
            int right = rightQue.poll();
            // 找数组中间元素
            int mid = left + ((right - left) >> 1);
            // 将中间元素值赋值给节点
            cur.val = nums.get(mid);
            // 处理左区间
            if (left < mid) {
                cur.left = new TreeNode(0);
                rootQue.offer(cur.left);
                leftQue.offer(left);
                rightQue.offer(mid - 1);
            }
            // 处理右区间
            if (right > mid) {
                cur.right = new TreeNode(0);
                rootQue.offer(cur.right);
                leftQue.offer(mid + 1);
                rightQue.offer(right);
            }
        }
        return rootNew;
    }
}

同样对于非递归法,其时间复杂度和空间复杂度都为 O(n)


图解将二叉搜索树变平衡到这就结束辣,这道题看下来我不知道你有什么想法,我反正还是想再提一下【总结】的必要性。

每道题有每道题的总结,每种类型的题有某类题的总结,千万不要怕麻烦,虽然刚开始的时候确实会很麻烦...

习惯性梳理总结,在这个过程中重新产生更多的认识,理解更深,有更多的想法,以后你会感谢你的这种积累。

1382-15

呃,说教有时候确实很烦人,我也适可而止啦,不过还是希望你们能听进去一丢丢~

我是帅蛋,我们下次见!