二叉树之实现集合接口

202 阅读7分钟

前言

其实这篇文章本来想写AVL树的,但这几天想了想,上一篇手写二叉树的文章其实还可以再拓展一下,也就是实现集合接口,不再使用上一篇文章中定义的旧接口,这里记录一下实现过程,那就开始吧。

上一篇文章地址:手写二叉树

定义 Tree 接口

interface Tree<E> : Collection<E> {

    override val size: Int

    override fun contains(element: E): Boolean

    override fun containsAll(elements: Collection<E>): Boolean

    override fun isEmpty(): Boolean

    override fun iterator(): Iterator<E>

    /**
     * 前序遍历
     */
    fun preorder(): Iterator<E>

    /**
     * 中序遍历
     */
    fun inorder(): Iterator<E>

    /**
     * 后序遍历
     */
    fun postorder(): Iterator<E>

    /**
     * 层序遍历
     */
    fun levelOrder(): Iterator<E>

    /**
     * 获取树的高度
     */
    fun height(): Int

    /**
     * 是否满树
     */
    fun isFull(): Boolean

    /**
     * 是否完美树
     */
    fun isPerfect(): Boolean

    /**
     * 是否完全树
     */
    fun isComplete(): Boolean
}

interface MutableTree<E> : Tree<E>, MutableCollection<E> {

    override fun add(element: E): Boolean

    override fun addAll(elements: Collection<E>): Boolean

    override fun clear()

    override fun remove(element: E): Boolean

    override fun removeAll(elements: Collection<E>): Boolean

    override fun retainAll(elements: Collection<E>): Boolean

    override fun iterator(): MutableIterator<E>

    override fun preorder(): MutableIterator<E>

    override fun inorder(): MutableIterator<E>

    override fun postorder(): MutableIterator<E>

    override fun levelOrder(): MutableIterator<E>
}

这里需要说明一下,在 Kotlin 中所有集合都被分为不可变与可变,如 ListMutableList ,变指的是能否对集合元素进行添加删除等,因此这里定义了 TreeMutableTree 接口,分别实现 CollectionMutableCollection 接口。对于遍历操作这里统一使用迭代器来实现,而 isFull()isPerfect()isComplete() 这几个接口就暂时先放在 Tree 接口吧,接下来对一些简单操作进行实现。

实现 AbstractTree

abstract class AbstractTree<E> : Tree<E> {

    override fun contains(element: E): Boolean {
        return any { it == element }
    }

    override fun containsAll(elements: Collection<E>): Boolean {
        return elements.all { contains(it) }
    }

    override fun isEmpty(): Boolean {
        return size == 0
    }

    override fun toString(): String {
        return toList().toString()
    }
}

由于实现的这几个方法一般来说都是通用的,所以定义一个 AbstractTree 来完成,现在针对这些方法进行一一解释。

  1. contains:判断是否包含指定元素,这里用到了 any() 这样一个方法,它的原型是这样的:

    public inline fun <T> Iterable<T>.any(predicate: (T) -> Boolean): Boolean {
        if (this is Collection && isEmpty()) return false
        for (element in this) if (predicate(element)) return true
        return false
    }
    

    表示如果集合包含指定元素则返回 true 。

  2. containsAll:判断是否包含指定集合元素,这里用到了 all() ,它的原型是这样的:

    public inline fun <T> Iterable<T>.all(predicate: (T) -> Boolean): Boolean {
        if (this is Collection && isEmpty()) return true
        for (element in this) if (!predicate(element)) return false
        return true
    }
    

    表示如果每个集合元素都满足 predicate 就返回 true 。

  3. isEmpty:判断树是否为空

  4. toString:直接将树转成 List 后 toString ,toList() 原型如下:

    public fun <T> Iterable<T>.toList(): List<T> {
        if (this is Collection) {
            return when (size) {
                0 -> emptyList()
                1 -> listOf(if (this is List) get(0) else iterator().next())
                else -> this.toMutableList()
            }
        }
        return this.toMutableList().optimizeReadOnlyList()
    }
    

实现 BinaryTree

BinaryTree 的实现其实跟上一篇文章差别不大,唯一不同就是实现了迭代器,这里将代码贴出来然后针对部分代码进行解释。

abstract class BinaryTree<E> : AbstractTree<E>() {

    protected open var root: TreeNode<E>? = null

    override fun iterator(): Iterator<E> {
        return levelOrder()
    }

    override fun preorder(): Iterator<E> {
        return PreorderIterator(root)
    }

    override fun inorder(): Iterator<E> {
        return InorderIterator(root)
    }

    override fun postorder(): Iterator<E> {
        return PostorderIterator(root)
    }

    override fun levelOrder(): Iterator<E> {
        return LevelOrderIterator(root)
    }

    override fun height(): Int {
        val node = root ?: return 0
        val queue = LinkedList<TreeNode<E>>().also {
            it.offer(node)
        }
        var height = 0
        var levelSize = 1
        while (queue.isNotEmpty()) {
            val poll = queue.poll()
            levelSize--
            if (poll.left != null) {
                queue.offer(poll.left)
            }
            if (poll.right != null) {
                queue.offer(poll.right)
            }
            if (levelSize == 0) {
                levelSize = queue.size
                height++
            }
        }
        return height
    }

    override fun isFull(): Boolean {
        var isFull = true
        levelOrderTraversal {
            if (!(it.isLeaf || it.hasTwoChildren)) {
                isFull = false
                return@levelOrderTraversal true
            }
            false
        }
        return isFull
    }

    override fun isPerfect(): Boolean {
        val height = height()
        if (height == 0) {
            return false
        }
        return (2 shl (height - 1)) - 1 == size
    }

    override fun isComplete(): Boolean {
        var isComplete = true
        var isLeafFound = false
        levelOrderTraversal {
            if (isLeafFound && !it.isLeaf) {
                isComplete = false
                return@levelOrderTraversal true
            }
            if (it.left == null && it.right != null) {
                isComplete = false
                return@levelOrderTraversal true
            }
            if (it.right == null) {
                isLeafFound = true
            }
            false
        }
        return isComplete
    }

    private fun levelOrderTraversal(onPoll: (TreeNode<E>) -> Boolean) {
        val node = root ?: return
        val queue = LinkedList<TreeNode<E>>().also {
            it.offer(node)
        }
        while (queue.isNotEmpty()) {
            val poll = queue.poll()
            if (onPoll(poll)) break
            if (poll.left != null) {
                queue.offer(poll.left)
            }
            if (poll.right != null) {
                queue.offer(poll.right)
            }
        }
    }

    protected open class TreeNode<E>(var item: E, var parent: TreeNode<E>?) {

        var left: TreeNode<E>? = null

        var right: TreeNode<E>? = null

        val isLeaf: Boolean
            get() = left == null && right == null

        val hasTwoChildren: Boolean
            get() = left != null && right != null

        val isLeftChild: Boolean
            get() = this == parent?.left

        val isRightChild: Boolean
            get() = this == parent?.right
    }

    protected abstract inner class BaseIterator(protected var node: TreeNode<E>?) : Iterator<E> {

        private var index = 0

        override fun hasNext(): Boolean {
            return index < size
        }

        override fun next(): E {
            if (!hasNext()) {
                throw NoSuchElementException()
            }
            index++
            return nextNode().item
        }

        internal abstract fun nextNode(): TreeNode<E>
    }

    protected open inner class PreorderIterator(node: TreeNode<E>?) : BaseIterator(node) {

        private val stack by lazy {
            LinkedList<TreeNode<E>>().also {
                it.push(node!!)
            }
        }

        override fun nextNode(): TreeNode<E> {
            val stack = stack
            val pop = stack.pop()
            if (pop.right != null) {
                stack.push(pop.right)
            }
            if (pop.left != null) {
                stack.push(pop.left)
            }
            return pop
        }
    }

    protected open inner class InorderIterator(node: TreeNode<E>?) : BaseIterator(node) {

        private val stack by lazy {
            LinkedList<TreeNode<E>>()
        }

        override fun nextNode(): TreeNode<E> {
            val stack = stack
            var node = node
            while (node != null || stack.isNotEmpty()) {
                if (node != null) {
                    stack.push(node)
                    node = node.left
                } else {
                    node = stack.pop()
                    this.node = node.right
                    return node
                }
            }
            throw NoSuchElementException()
        }
    }

    protected open inner class PostorderIterator(node: TreeNode<E>?) : BaseIterator(node) {

        private val inStack by lazy {
            LinkedList<TreeNode<E>>().also {
                it.push(node!!)
            }
        }

        private val outStack by lazy {
            LinkedList<TreeNode<E>>()
        }

        override fun nextNode(): TreeNode<E> {
            val inStack = inStack
            val outStack = outStack
            while (inStack.isNotEmpty()) {
                val pop = inStack.pop()
                outStack.push(pop)
                if (pop.left != null) {
                    inStack.push(pop.left)
                }
                if (pop.right != null) {
                    inStack.push(pop.right)
                }
            }
            while (outStack.isNotEmpty()) {
                return outStack.pop()
            }
            throw NoSuchElementException()
        }
    }

    protected open inner class LevelOrderIterator(node: TreeNode<E>?) : BaseIterator(node) {

        private val queue by lazy {
            LinkedList<TreeNode<E>>().also {
                it.offer(node!!)
            }
        }

        override fun nextNode(): TreeNode<E> {
            val queue = queue
            val poll = queue.poll()
            if (poll.left != null) {
                queue.offer(poll.left)
            }
            if (poll.right != null) {
                queue.offer(poll.right)
            }
            return poll
        }
    }
}

除了迭代器之外,其他都是换汤不换药,这里主要说下迭代器的实现。

其实就是将遍历循环的步骤分摊到 nextNode() 里面,外部每调用一次迭代器的 next() 才取下一个,除了后序遍历比较特殊,其他遍历基本都是调用一次执行一次操作然后返回,而由于后序遍历用到了两个栈 inStackoutStack ,只有等 inStack 将所有节点入栈后才会有出栈操作,因此在第一次调用 next() 时会先将所有节点入栈,后面的遍历中就直接从 outStack 取就行了。

实现 BinarySearchTree

由于 BinaryTree 没有实现 MutableTree 接口,因此它是不可变的,毕竟它没办法定义添加删除的逻辑,而这些都得由子类去做,那么二叉搜索树的主要实现逻辑就是实现添加、删除等,当然,对比上一篇文章,大部分逻辑依旧是换汤不换药。

open class BST<E>(private var comparator: Comparator<E>? = null) : BinaryTree<E>(), MutableTree<E> {

    override var size: Int = 0
        protected set

    private var modCount = 0

    override fun add(element: E): Boolean {
        require(element != null) {
            "The element must not be null."
        }
        var node = root
        if (node == null) {
            node = createNode(element, null)
            root = node
            size++
            modCount++
            return true
        }
        var compare = 0
        var parent = node
        while (node != null) {
            parent = node
            compare = compare(element, node.item)
            if (compare < 0) {
                node = node.left
            } else if (compare > 0) {
                node = node.right
            } else {
                node.item = element
                return true
            }
        }
        val newNode = createNode(element, parent)
        if (compare < 0) {
            parent!!.left = newNode
        } else {
            parent!!.right = newNode
        }
        size++
        modCount++
        return true
    }

    override fun addAll(elements: Collection<E>): Boolean {
        elements.forEach {
            add(it)
        }
        return true
    }

    override fun clear() {
        root = null
        size = 0
        modCount++
    }

    override fun remove(element: E): Boolean {
        return remove(node(element))
    }

    private fun remove(node: TreeNode<E>?): Boolean {
        var _node = node ?: return false
        if (_node.hasTwoChildren) {
            val succ = successor(_node)!!
            _node.item = succ.item
            _node = succ
        }
        val replacement = _node.left ?: _node.right
        if (replacement != null) {
            replacement.parent = _node.parent
            if (_node.parent == null) {
                root = replacement
            } else if (_node.isLeftChild) {
                _node.parent!!.left = replacement
            } else {
                _node.parent!!.right = replacement
            }
        } else if (_node.parent == null) {
            root = null
        } else {
            if (_node.isLeftChild) {
                _node.parent!!.left = null
            } else {
                _node.parent!!.right = null
            }
        }
        size--
        modCount++
        return true
    }

    private fun predecessor(node: TreeNode<E>): TreeNode<E>? {
        var _node: TreeNode<E>? = node
        if (node.left != null) {
            _node = node.left
            while (_node?.right != null) {
                _node = _node.right
            }
            return _node
        }
        while (_node?.parent != null && _node.isLeftChild) {
            _node = _node.parent
        }
        return _node?.parent
    }

    private fun successor(node: TreeNode<E>): TreeNode<E>? {
        var _node: TreeNode<E>? = node
        if (node.right != null) {
            _node = node.right
            while (_node?.left != null) {
                _node = _node.left
            }
            return _node
        }
        while (_node?.parent != null && _node.isRightChild) {
            _node = _node.parent
        }
        return _node?.parent
    }

    override fun removeAll(elements: Collection<E>): Boolean {
        var modified = false
        elements.forEach {
            modified = modified or remove(it)
        }
        return modified
    }

    override fun retainAll(elements: Collection<E>): Boolean {
        var modified = false
        val itr = iterator()
        while (itr.hasNext()) {
            val e = itr.next()
            if (!elements.contains(e)) {
                itr.remove()
                modified = true
            }
        }
        return modified
    }

    override fun contains(element: E): Boolean {
        return node(element) != null
    }

    protected open fun node(element: E): TreeNode<E>? {
        var node: TreeNode<E>? = root
        while (node != null) {
            val cmp = compare(element, node.item)
            if (cmp < 0) {
                node = node.left
            } else if (cmp > 0) {
                node = node.right
            } else {
                return node
            }
        }
        return null
    }

    protected open fun compare(e1: E, e2: E): Int {
        val cmp = comparator
        if (cmp != null) {
            return cmp.compare(e1, e2)
        }
        return (e1 as Comparable<E>).compareTo(e2)
    }

    override fun iterator(): MutableIterator<E> {
        return inorder()
    }

    override fun preorder(): MutableIterator<E> {
        return ItrProxy(PreorderIterator(root))
    }

    override fun inorder(): MutableIterator<E> {
        return ItrProxy(InorderIterator(root))
    }

    override fun postorder(): MutableIterator<E> {
        return ItrProxy(PostorderIterator(root))
    }

    override fun levelOrder(): MutableIterator<E> {
        return ItrProxy(LevelOrderIterator(root))
    }

    protected open fun createNode(item: E, parent: TreeNode<E>?): TreeNode<E> {
        return TreeNode(item, parent)
    }

    private inner class ItrProxy(target: BaseIterator) : MutableIterator<E> {

        private var lastIndex = -1

        private var expectedModCount = modCount

        private val list by lazy {
            val temp = arrayListOf<E>()
            for (e in target) {
                temp.add(e)
            }
            temp
        }

        override fun hasNext(): Boolean {
            checkForComodification()
            return lastIndex < size - 1
        }

        override fun next(): E {
            checkForComodification()
            return list[++lastIndex]
        }

        override fun remove() {
            checkForComodification()
            val e = list.removeAt(lastIndex)
            remove(e)
            lastIndex--
            expectedModCount = modCount
        }

        private fun checkForComodification() {
            if (modCount != expectedModCount) {
                throw ConcurrentModificationException()
            }
        }
    }
}

说一下几点不同:

  1. modCount:即 modifiedCount ,表示修改次数,主要是处理迭代器并发修改异常的。

  2. 迭代器:你可能会疑惑,明明父类 BinaryTree 已经写好了迭代器,为何子类还要重写?原因很简单,因为这个 BST 是可变的,必须提供迭代器的 remove() 操作(当然也可以不处理);再者,这里用了一个 ItrProxy 实现 MutableTree 接口,传入 BaseIterator 的作用也只是为了先把所有元素取出来放进列表,这里主要是考虑到自平衡搜索树的问题。

    一棵自平衡搜索树在添加和删除元素之后可能会进行平衡调整,而这就会导致树结构变化,那这个问题就严重了,可能未添加或修改元素之前我们迭代到了第二个元素,这时候突然把一个元素给移除了,那树可能就发生了平衡调整,并且调整后把后面需要迭代到的几个元素的位置调到了刚迭代过的元素前面,这样在后面的迭代中就无法获取到这些元素。

    因此这里做法也比较粗暴,直接将所有元素取出来放到 List ,即使迭代期间移除了元素,也可以保证迭代顺序不会被打乱。

  3. 一些 List 的操作:在之前的实现中没有实现 addAll()removeAll()retailsAll() 等这些标准集合接口,这里对其进行实现。

BinaryPrintTree

另外你可能发现了,之前的打印接口好像不见了,这里主要是不想将这种与实际使用不相干的代码揉到一起,于是使用另一个方法对打印进行实现,如下:

class BinaryPrintTree<E>(private val tree: BinaryTree<E>) : Tree<E> by tree, BinaryTreeInfo {

    override fun root(): Any? {
        return getRoot(tree)
    }

    override fun left(node: Any?): Any? {
        return getNode(node, true)
    }

    override fun right(node: Any?): Any? {
        return getNode(node, false)
    }

    override fun string(node: Any?): Any? {
        return getString(node)
    }

    private fun getRoot(obj: Any?): Any? {
        return try {
            val field = BinaryTree::class.java.getDeclaredField("root").also {
                it.isAccessible = true
            }
            field.get(obj)
        } catch (ignored: Exception) {
            null
        }
    }

    private fun getNode(obj: Any?, left: Boolean): Any? {
        return try {
            val bt = BinaryTree::class.java
            val subClasses = bt.declaredClasses
            val clazz = subClasses.find {
                it.name == "${bt.name}\$TreeNode"
            } ?: return null
            val child = when (left) {
                true -> clazz.getDeclaredField("left")
                else -> clazz.getDeclaredField("right")
            }.also {
                it.isAccessible = true
            }
            child.get(obj)
        } catch (ignored: Exception) {
            null
        }
    }

    private fun getString(obj: Any?): Any? {
        return try {
            val bt = BinaryTree::class.java
            val subClasses = bt.declaredClasses
            val clazz = subClasses.find {
                it.name == "${bt.name}\$TreeNode"
            } ?: return null
            val field = clazz.getDeclaredField("item").also {
                it.isAccessible = true
            }
            field.get(obj)
        } catch (ignored: Exception) {
            null
        }
    }
}

首先它也是一棵树,因为实现了 Tree 接口,并且将具体实现委托给了构造传入的 tree ,这是 Kotlin 的语法,对于节点的获取这里使用了反射,这样处理过后其实在外界使用,如果这棵树想要打印,那就构建一个 BinaryPrintTree 就可以了。

fun main() {
    val tree = getTree()
    val printTree = BinaryPrintTree(tree)
    printTree.println()
}

结尾

这样对集合接口进行实现后,对于一些形参是 Collection 的方法就可以选择使用树结构了,并且可以发现可以和集合接口进行互转了。

用这种方式实现后,对比上一篇文章的二叉树实现,是不是更好呢?

相关代码已经上传 Github ,下一篇终于要写AVL树,嘿嘿~