红黑树代码实现

107 阅读4分钟

视频链接:手撕红黑树-C语言版

rb_tree.h

#ifndef RB_TREE_H
#define RB_TREE_H

/*
 * Minimal red-black tree keyed by int.
 *
 * The include guard deliberately avoids a leading underscore followed by an
 * upper-case letter: such identifiers are reserved for the implementation.
 */

#include <stddef.h> /* NULL, used by rb_init() below */

/* Node colours stored in rb_node.color. */
#define RB_RED 0
#define RB_BLACK 1

/* A single tree node. */
struct rb_node {
    struct rb_node *parent;       /* NULL for the root node */
    struct rb_node *left, *right; /* children; NULL when absent */
    int color;                    /* RB_RED or RB_BLACK */
    int key;                      /* ordering key; duplicates are rejected by rb_insert() */
};

/* Handle to a tree; an empty tree has root == NULL. */
struct rb_tree {
    struct rb_node *root;
};

/* Put `tree` into the empty state. Does not free any existing nodes. */
static inline void rb_init(struct rb_tree *tree)
{
    tree->root = NULL;
}

/*
 * Insert `key` into `tree`.
 * Returns 0 if the key is already present or if allocating the node fails;
 * callers treat a non-zero result as success.
 */
int rb_insert(struct rb_tree *tree, int key);

/* Return the node whose key equals `key`, or NULL if there is none. */
const struct rb_node *rb_find(const struct rb_tree *tree, int key);

/*
 * Unlink `node` from `tree`, free it and rebalance.
 * `node` must be a node of `tree` (e.g. a result of rb_find()); it is
 * invalid after the call.
 */
void rb_earase(struct rb_tree *tree, const struct rb_node *node);

/* Free every node of `tree` and leave it empty. */
void rb_destroy(struct rb_tree *tree);

#endif /* RB_TREE_H */

rb_tree.c

#include <stdio.h>
#include <stdlib.h>
#include "rb_tree.h"

/*
 * Rotate the subtree rooted at `node` to the left: node's right child
 * (the pivot) moves up into node's position and node becomes the pivot's
 * left child. `node->right` must be non-NULL.
 */
static void rb_left_rotate(struct rb_tree *tree, struct rb_node *node)
{
    struct rb_node *pivot = node->right;
    struct rb_node *inner = pivot->left;
    struct rb_node *up = node->parent;

    /* The pivot's left subtree is handed over to node. */
    node->right = inner;
    if (inner)
        inner->parent = node;

    /* node goes underneath the pivot. */
    pivot->left = node;
    node->parent = pivot;

    /* The pivot is attached where node used to hang. */
    pivot->parent = up;
    if (!up)
        tree->root = pivot;
    else if (up->left == node)
        up->left = pivot;
    else
        up->right = pivot;
}

/*
 * Rotate the subtree rooted at `node` to the right: node's left child
 * (the pivot) moves up into node's position and node becomes the pivot's
 * right child. `node->left` must be non-NULL.
 */
static void rb_right_rotate(struct rb_tree *tree, struct rb_node *node)
{
    struct rb_node *pivot = node->left;
    struct rb_node *inner = pivot->right;
    struct rb_node *up = node->parent;

    /* The pivot's right subtree is handed over to node. */
    node->left = inner;
    if (inner)
        inner->parent = node;

    /* node goes underneath the pivot. */
    pivot->right = node;
    node->parent = pivot;

    /* The pivot is attached where node used to hang. */
    pivot->parent = up;
    if (!up)
        tree->root = pivot;
    else if (up->right == node)
        up->right = pivot;
    else
        up->left = pivot;
}

/*
 * Repair the tree after `node` has been linked in as a red node.
 *
 * Walks upward while node's parent is red. A red uncle is handled by
 * recolouring and continuing from the grandparent; otherwise one or two
 * rotations around the parent/grandparent finish the repair. If the walk
 * ends at the root, the root is painted black.
 */
static void rb_insert_fixup(struct rb_tree *tree, struct rb_node *node)
{
    for (;;) {
        struct rb_node *pa = node->parent;

        if (!pa || pa->color != RB_RED)
            break;

        struct rb_node *gp = pa->parent;
        int pa_is_left = (gp->left == pa);
        struct rb_node *uncle = pa_is_left ? gp->right : gp->left;

        if (uncle && uncle->color == RB_RED) {
            /* Red uncle: push the blackness down from the grandparent. */
            uncle->color = RB_BLACK;
            pa->color = RB_BLACK;
            gp->color = RB_RED;
            node = gp;
            continue;
        }

        if (pa_is_left) {
            if (pa->right == node) {
                /* Inner grandchild: straighten it first; node is now on top. */
                rb_left_rotate(tree, pa);
                pa = node;
            }
            pa->color = RB_BLACK;
            gp->color = RB_RED;
            rb_right_rotate(tree, gp);
        } else {
            if (pa->left == node) {
                /* Inner grandchild: straighten it first; node is now on top. */
                rb_right_rotate(tree, pa);
                pa = node;
            }
            pa->color = RB_BLACK;
            gp->color = RB_RED;
            rb_left_rotate(tree, gp);
        }
        return;
    }

    if (node == tree->root)
        node->color = RB_BLACK;
}

/*
 * Insert `key` into `tree`.
 *
 * Returns 1 when a new node was inserted, 0 when the key is already present
 * or the node could not be allocated (the tree is left unchanged in both
 * failure cases).
 */
int rb_insert(struct rb_tree *tree, int key)
{
    struct rb_node *cur = tree->root;
    struct rb_node *parent = NULL;

    /* Ordinary BST descent to find the attachment point. */
    while (cur) {
        parent = cur;
        if (key == cur->key)
            return 0;
        if (key < cur->key)
            cur = cur->left;
        else
            cur = cur->right;
    }

    struct rb_node *node = malloc(sizeof(*node));
    if (!node)
        return 0;

    /* New nodes start red so black heights are not disturbed. */
    node->color = RB_RED;
    node->key = key;
    node->parent = parent;
    node->left = node->right = NULL;

    if (parent) {
        if (key < parent->key)
            parent->left = node;
        else
            parent->right = node;
    } else {
        tree->root = node;
    }

    rb_insert_fixup(tree, node);

    /* Report success explicitly: callers test the result for non-zero. */
    return 1;
}

/*
 * Look up `key` in `tree`.
 * Returns the node holding that key, or NULL when no node has it.
 */
const struct rb_node *rb_find(const struct rb_tree *tree, int key)
{
    const struct rb_node *cur;

    /* Descend left or right until the key matches or we fall off the tree. */
    for (cur = tree->root; cur && cur->key != key; )
        cur = (key < cur->key) ? cur->left : cur->right;

    return cur;
}

/*
 * Restore the red-black properties after a black node has been unlinked.
 *
 * tree:   tree being repaired.
 * node:   the child that took the removed node's place; may be NULL.
 * parent: parent of that position (passed separately because node may be
 *         NULL and therefore cannot supply it).
 *
 * The only caller in this file invokes this with tree->root non-NULL, so
 * node is non-NULL by the time the loop ends and the final recolouring
 * below is safe.
 */
static void rb_earase_fixup(struct rb_tree *tree, struct rb_node *node, struct rb_node *parent)
{
    /* Keep going while the current position is black (NULL counts as black)
     * and we have not reached the root. */
    while ((!node || node->color == RB_BLACK) && node != tree->root) {
        if (parent->left == node) {
            struct rb_node *brother = parent->right;

            /* Red sibling: recolour and rotate so the sibling becomes black. */
            if (brother->color == RB_RED) {
                brother->color = RB_BLACK;
                parent->color = RB_RED;
                rb_left_rotate(tree, parent);
                brother = parent->right;
            }

            /* Sibling has no red child: paint it red and move the problem
             * one level up. */
            if ((!brother->left || brother->left->color == RB_BLACK) &&
                (!brother->right || brother->right->color == RB_BLACK)) {
                brother->color = RB_RED;
                node = parent;
                parent = node->parent;
                continue;
            }
            /* Only the near (left) child is red: rotate the sibling so the
             * red child ends up on the far side. */
            if (!brother->right || brother->right->color == RB_BLACK) {
                brother->left->color = RB_BLACK;
                brother->color = RB_RED;
                rb_right_rotate(tree, brother);
                brother = parent->right;
            }

            /* Far (right) child is red: recolour, rotate the parent, done. */
            brother->right->color = RB_BLACK;
            brother->color = parent->color;
            parent->color = RB_BLACK;
            rb_left_rotate(tree, parent);
            return;
        } else {
            /* Mirror image of the branch above with left/right swapped. */
            struct rb_node *brother = parent->left;

            if (brother->color == RB_RED) {
                brother->color = RB_BLACK;
                parent->color = RB_RED;
                rb_right_rotate(tree, parent);
                brother = parent->left;
            }

            if ((!brother->right || brother->right->color == RB_BLACK) &&
                (!brother->left || brother->left->color == RB_BLACK)) {
                brother->color = RB_RED;
                node = parent;
                parent = node->parent;
                continue;
            }
            if (!brother->left || brother->left->color == RB_BLACK) {
                brother->right->color = RB_BLACK;
                brother->color = RB_RED;
                rb_left_rotate(tree, brother);
                brother = parent->left;
            }

            brother->left->color = RB_BLACK;
            brother->color = parent->color;
            parent->color = RB_BLACK;
            rb_right_rotate(tree, parent);
            return;
        }
    }

    /* Loop ended on a red node or on the root: painting it black finishes
     * the repair. */
    node->color = RB_BLACK;
}

/*
 * Unlink `node` from `tree`, free it, and rebalance if a black node was
 * effectively removed.
 *
 * `node` must be non-NULL and belong to `tree`. It is declared const in the
 * interface, but it is freed here, so it must not be used after the call.
 */
void rb_earase(struct rb_tree *tree, const struct rb_node *node)
{
    /* child/parent describe the position that lost a node; color is the
     * colour that disappeared from that position. */
    struct rb_node *child, *parent;
    int color;

    if (node->left && node->right) {
        /* Two children: the leftmost node of the right subtree takes
         * node's place. */
        struct rb_node *replace = node->right;
        while (replace->left)
            replace = replace->left;
        color = replace->color;
        parent = replace->parent;
        child = replace->right;

        if (node == parent) {
            /* The replacement is node's direct right child: it keeps its
             * own right subtree, so it is the parent of `child`. */
            parent = replace;
        } else {
            /* Detach the replacement from its old spot (child moves up)
             * and let it adopt node's right subtree. */
            replace->right = node->right;
            node->right->parent = replace;
            parent->left = child;
            if (child)
                child->parent = parent;
        }

        /* Hook the replacement into node's parent (or make it the root). */
        if (node->parent) {
            if (node->parent->left == node)
                node->parent->left = replace;
            else
                node->parent->right = replace;
        } else {
            tree->root = replace;
        }

        /* The replacement inherits node's left subtree, parent and colour. */
        replace->left = node->left;
        node->left->parent = replace;
        replace->parent = node->parent;
        replace->color = node->color;
        goto fixup;
    } else if (!node->right) {
        /* At most a left child (possibly none). */
        child = node->left;
    } else
        /* Only a right child. */
        child = node->right;

    parent = node->parent;
    color = node->color;

    /* Splice the single child (or NULL) into node's position. */
    if (parent) {
        if (parent->left == node)
            parent->left = child;
        else
            parent->right = child;
    } else
        tree->root = child;

    if (child)
        child->parent = parent;
fixup:
    /* The cast drops the const qualifier from the interface. */
    free((void *)node);
    /* Rebalance only when the tree is non-empty and a black node was lost. */
    if (tree->root && color == RB_BLACK)
        rb_earase_fixup(tree, child, parent);
}

/*
 * Recursively free the subtree referenced by `*root` (children first) and
 * reset `*root` to NULL. A NULL `root` or an empty subtree is a no-op.
 */
static void destroy(struct rb_node **root)
{
    struct rb_node *n;

    if (root == NULL)
        return;

    n = *root;
    if (n == NULL)
        return;

    destroy(&n->left);
    destroy(&n->right);
    free(n);
    *root = NULL;
}

/* Free every node of `tree`; tree->root is NULL afterwards. */
void rb_destroy(struct rb_tree *tree)
{
    destroy(&tree->root);
}

rb_test.c

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include "rb_tree.h"

/* Fill arr[0..len-1] with the values 0, 1, ..., len-1. */
void arr_init(int *arr, int len)
{
    int idx = 0;

    while (idx < len) {
        arr[idx] = idx;
        idx++;
    }
}

/*
 * Randomly permute arr[0..len-1] in place (Fisher-Yates).
 *
 * The generator is seeded from the clock on the first call only, and the
 * draws come from rand() so that they actually match the srand() seed
 * (random() has its own, separate seed function).
 * Arrays with fewer than two elements are left untouched.
 */
void shuffle(int *arr, int len)
{
    static int seeded;

    if (!seeded) {
        srand((unsigned)time(NULL));
        seeded = 1;
    }

    /* Pick the final occupant of slot i from the not-yet-fixed prefix. */
    for (int i = len - 1; i > 0; i--) {
        int j = rand() % (i + 1);

        int t = arr[i];
        arr[i] = arr[j];
        arr[j] = t;
    }
}

int height(struct rb_node *root, int *ret)
{
    if (!root || !*ret)
        return 0;

    int lh = height(root->left, ret);
    int rh = height(root->right, ret);
    if (lh != rh) {
        *ret = 0;
        return 0;
    }
    if (root->color == RB_BLACK)
        ++lh;
    else {
        if ((root->left && root->left->color == RB_RED) || (root->right && root->right->color == RB_RED)) {
            *ret = 0;
            return 0;
        }
    }
    if ((root->left && root->left->key >= root->key) || (root->right && root->right->key <= root->key)) {
        *ret = 0;
        return 0;
    }
    return lh;
}

int judge(struct rb_tree *tree)
{
    if (tree->root && tree->root->color == RB_RED)
        return 0;
    int ret = 1;

    int h = height(tree->root, &ret);
    return ret;
}

/*
 * Exercise the tree: insert the numbers 0..999 in random order, then look
 * each one up and erase it in another random order. The tree is validated
 * with judge() after every modification; the process exits with status 1 on
 * the first failure.
 */
void test(void)
{
    int arr[1000];
    int len = sizeof(arr) / sizeof(arr[0]);
    struct rb_tree tree;

    arr_init(arr, len);
    rb_init(&tree);
    shuffle(arr, len);

    /* Insertion phase. */
    for (int i = 0; i < len; i++) {
        if (!rb_insert(&tree, arr[i]) || !judge(&tree)) {
            printf("insert %d error, exit!\n", arr[i]);
            exit(1);
        }
        printf("insert %d ok!\n", arr[i]);
    }
    printf("insert all numbers ok!\n");

    /* Lookup + erase phase, in a fresh random order. */
    shuffle(arr, len);
    for (int i = 0; i < len; i++) {
        const struct rb_node *hit = rb_find(&tree, arr[i]);

        if (hit == NULL) {
            printf("rb_find %d error, exit!\n", arr[i]);
            exit(1);
        }

        rb_earase(&tree, hit);
        if (!judge(&tree)) {
            printf("earse %d error!\n", arr[i]);
            exit(1);
        }
        printf("earse %d ok!\n", arr[i]);
    }

    /* Disabled alternative: repeatedly erase the root until the tree is empty. */
#if 0
    while (tree.root) {
        int key = tree.root->key;
        rb_earase(&tree, tree.root);
        if (judge(&tree))
        {
            printf("earse %d ok!\n", key);
        } else 
        {
            printf("earse %d error!\n", key);
            exit(1);
        }
    }
    printf("earse all numbers ok!\n");
#endif

    // rb_destroy(&tree);
}

/* Entry point: run the red-black tree self-test and report success. */
int main(void)
{
    test();
    return 0;
}