#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <kmalloc.h>
#include <rb_tree.h>
#include <assert.h>

/* *
 * rb_node_create - obtains storage for a single rb_node from kmalloc and
 * hands back whatever kmalloc returned. No field of the node is set here;
 * callers are expected to test the result against 'NULL'.
 * */
static inline rb_node *
rb_node_create(void) {
    rb_node *fresh = kmalloc(sizeof(*fresh));
    return fresh;
}

/* *
 * rb_tree_empty - reports whether @tree holds no nodes, i.e. whether the
 * left link of the sentinel root still points at nil.
 * */
static inline bool
rb_tree_empty(rb_tree *tree) {
    return tree->root->left == tree->nil;
}

/* *
 * rb_tree_create - allocates and initializes an empty red-black tree that
 * orders its nodes with @compare (which must not be 'NULL'). Returns the
 * new tree, or 'NULL' if any of the three allocations fails; whatever was
 * allocated before the failure is released again.
 *
 * Two sentinel nodes are set up. 'nil' stands in for every missing child:
 * it is black and all of its links initially point at itself. 'root' is a
 * black placeholder whose left link is where the real root of the tree
 * hangs (nil while the tree is empty).
 * */
rb_tree *
rb_tree_create(int (*compare)(rb_node *node1, rb_node *node2)) {
    assert(compare != NULL);

    rb_tree *tree;
    rb_node *nil, *root;

    tree = kmalloc(sizeof(rb_tree));
    if (tree == NULL) {
        goto fail;
    }
    tree->compare = compare;

    nil = rb_node_create();
    if (nil == NULL) {
        goto fail_free_tree;
    }
    nil->red = 0;
    nil->left = nil;
    nil->right = nil;
    nil->parent = nil;
    tree->nil = nil;

    root = rb_node_create();
    if (root == NULL) {
        goto fail_free_nil;
    }
    root->red = 0;
    root->left = nil;
    root->right = nil;
    root->parent = nil;
    tree->root = root;
    return tree;

    /* unwind in reverse order of allocation */
fail_free_nil:
    kfree(nil);
fail_free_tree:
    kfree(tree);
fail:
    return NULL;
}

/* *
 * FUNC_ROTATE - generates a rotation routine as described in
 * "Introduction to Algorithms".
 *
 * For example, FUNC_ROTATE(rb_left_rotate, left, right) expands to a
 * left-rotate function taking a red-black 'tree' and the node 'x' to rotate
 * around. The _right child of 'x' (called 'y') takes the place of 'x' under
 * the former parent of 'x', 'x' becomes the _left child of 'y', and the
 * former _left subtree of 'y' becomes the _right subtree of 'x'.
 *
 * 'x' must be neither the sentinel root nor nil, and its _right child must
 * not be nil; all three conditions are asserted.
 *
 * FUNC_ROTATE(xx, left, right) means left-rotate,
 * and FUNC_ROTATE(xx, right, left) means right-rotate.
 * */
#define FUNC_ROTATE(func_name, _left, _right)                   \
static void                                                     \
func_name(rb_tree *tree, rb_node *x) {                          \
    rb_node *nil = tree->nil;                                   \
    rb_node *y = x->_right;                                     \
    assert(x != tree->root && x != nil && y != nil);            \
    rb_node *moved = y->_left;                                  \
    rb_node *above = x->parent;                                 \
    /* hand the inner subtree of y over to x */                 \
    x->_right = moved;                                          \
    if (moved != nil) {                                         \
        moved->parent = x;                                      \
    }                                                           \
    /* hook y in where x used to hang */                        \
    y->parent = above;                                          \
    if (above->_left == x) {                                    \
        above->_left = y;                                       \
    }                                                           \
    else {                                                      \
        above->_right = y;                                      \
    }                                                           \
    /* finally put x underneath y */                            \
    y->_left = x;                                               \
    x->parent = y;                                              \
    assert(!(nil->red));                                        \
}

FUNC_ROTATE(rb_left_rotate, left, right);
FUNC_ROTATE(rb_right_rotate, right, left);

#undef FUNC_ROTATE

/* COMPARE - orders nodes @a and @b with the comparison callback stored in tree @t */
#define COMPARE(t, a, b)                                        \
    (((t))->compare((a), (b)))

/* *
 * rb_insert_binary - links @node into red-black @tree as a leaf, exactly as
 * in an ordinary binary search tree: no colors are touched and no rotations
 * are made. Both child links of @node are set to nil. A node that does not
 * compare greater than zero against @node sends the descent to its right.
 * This function is only intended to be called by function rb_insert.
 * */
static inline void
rb_insert_binary(rb_tree *tree, rb_node *node) {
    rb_node *nil = tree->nil, *root = tree->root;
    rb_node *parent = root;
    rb_node *cur;

    node->left = nil;
    node->right = nil;

    /* walk down from the real root until we fall off the tree */
    for (cur = root->left; cur != nil; ) {
        parent = cur;
        if (COMPARE(tree, cur, node) > 0) {
            cur = cur->left;
        }
        else {
            cur = cur->right;
        }
    }

    node->parent = parent;
    /* the sentinel root only ever uses its left link */
    if (parent != root && COMPARE(tree, parent, node) <= 0) {
        parent->right = node;
    }
    else {
        parent->left = node;
    }
}

/* *
 * rb_insert - inserts @node into red-black @tree.
 *
 * The node is first linked in as a red leaf by rb_insert_binary. The loop
 * then walks upwards for as long as the parent of the current node 'x' is
 * red, repairing the "red node with a red parent" situation one step at a
 * time. The loop can stop at the top because both sentinels are black, as
 * the final assert checks.
 * */
void
rb_insert(rb_tree *tree, rb_node *node) {
    /* plain binary-tree insertion first; the new node starts out red */
    rb_insert_binary(tree, node);
    node->red = 1;

    rb_node *x = node, *y;

/* *
 * RB_INSERT_SUB - one repair step for the case where the parent of 'x' is
 * the _left child of the grandparent; 'y' is the other child of the
 * grandparent (the uncle of 'x').
 *
 * - uncle red: parent and uncle become black, the grandparent becomes red
 *   and 'x' moves up to the grandparent.
 * - uncle black: if 'x' is the _right child of its parent, 'x' moves up to
 *   the parent and a _left rotation is made there first; then the parent
 *   becomes black, the grandparent red, and the grandparent is rotated
 *   towards _right.
 * */
#define RB_INSERT_SUB(_left, _right)                            \
    do {                                                        \
        y = x->parent->parent->_right;                          \
        if (y->red) {                                           \
            x->parent->red = 0;                                 \
            y->red = 0;                                         \
            x->parent->parent->red = 1;                         \
            x = x->parent->parent;                              \
        }                                                       \
        else {                                                  \
            if (x == x->parent->_right) {                       \
                x = x->parent;                                  \
                rb_##_left##_rotate(tree, x);                   \
            }                                                   \
            x->parent->red = 0;                                 \
            x->parent->parent->red = 1;                         \
            rb_##_right##_rotate(tree, x->parent->parent);      \
        }                                                       \
    } while (0)

    while (x->parent->red) {
        /* pick the mirrored variant matching the side the parent is on */
        if (x->parent == x->parent->parent->left) {
            RB_INSERT_SUB(left, right);
        }
        else {
            RB_INSERT_SUB(right, left);
        }
    }
    /* the real root (left child of the sentinel root) is always black */
    tree->root->left->red = 0;
    assert(!(tree->nil->red) && !(tree->root->red));

#undef RB_INSERT_SUB
}

/* *
 * rb_tree_successor - returns the in-order successor of @node, or nil
 * if no successor exists. @node must belong to @tree. Within this file it
 * is used by rb_delete and rb_node_next.
 * */
static inline rb_node *
rb_tree_successor(rb_tree *tree, rb_node *node) {
    rb_node *nil = tree->nil;
    rb_node *cur = node->right;

    if (cur != nil) {
        /* leftmost node of the right subtree */
        for (rb_node *next = cur->left; next != nil; next = cur->left) {
            cur = next;
        }
        return cur;
    }

    /* climb until we arrive at an ancestor from its left side */
    rb_node *child = node, *up = node->parent;
    while (up->right == child) {
        child = up;
        up = up->parent;
    }
    /* ending up at the sentinel root means @node was the last node */
    return (up == tree->root) ? nil : up;
}

/* *
 * rb_tree_predecessor - returns the in-order predecessor of @node, or nil
 * if no predecessor exists; the mirror image of rb_tree_successor.
 * */
static inline rb_node *
rb_tree_predecessor(rb_tree *tree, rb_node *node) {
    rb_node *nil = tree->nil;
    rb_node *cur = node->left;

    if (cur != nil) {
        /* rightmost node of the left subtree */
        for (rb_node *next = cur->right; next != nil; next = cur->right) {
            cur = next;
        }
        return cur;
    }

    /* climb until we arrive at an ancestor from its right side */
    rb_node *child = node, *up = node->parent;
    while (up->left == child) {
        /* reaching the sentinel root from its left means @node was first */
        if (up == tree->root) {
            return nil;
        }
        child = up;
        up = up->parent;
    }
    return up;
}

/* *
 * rb_search - looks up a node that is 'equal' to @key, where @compare
 * returns zero on a match, a positive value to continue in the left subtree
 * and a negative value to continue in the right subtree. Returns the first
 * match met on the way down from the root (so, among several equal nodes,
 * the one highest in the tree), or 'NULL' if there is none.
 * */
rb_node *
rb_search(rb_tree *tree, int (*compare)(rb_node *node, void *key), void *key) {
    rb_node *nil = tree->nil;
    rb_node *cur = tree->root->left;

    while (cur != nil) {
        int order = compare(cur, key);
        if (order == 0) {
            return cur;
        }
        if (order > 0) {
            cur = cur->left;
        }
        else {
            cur = cur->right;
        }
    }
    return NULL;
}

/* *
 * rb_delete_fixup - performs rotations and changes colors to restore
 * red-black properties after a black node has been unlinked.
 *
 * @node is the node that moved into the unlinked node's position. It may be
 * nil: rb_delete assigns its parent link unconditionally before calling
 * this function, so 'x->parent' is usable here either way.
 *
 * The loop runs while 'x' is black and is not the real root; on exit 'x'
 * is colored black.
 *
 * NOTE(review): 'root' is sampled once on entry and is not refreshed after
 * rotations - confirm this is intentional.
 * */
static void
rb_delete_fixup(rb_tree *tree, rb_node *node) {
    rb_node *x = node, *w, *root = tree->root->left;

/* *
 * RB_DELETE_FIXUP_SUB - one repair step for the case where 'x' is the _left
 * child of its parent; 'w' is the sibling of 'x'.
 *
 * - sibling red: sibling becomes black, parent becomes red, the parent is
 *   rotated towards _left and 'w' is re-read as the new sibling.
 * - both children of the sibling black: the sibling becomes red and 'x'
 *   moves up to its parent.
 * - otherwise: if the _right child of the sibling is black, its _left child
 *   becomes black, the sibling red, the sibling is rotated towards _right
 *   and 'w' is re-read. Then the sibling takes over the parent's color, the
 *   parent and the sibling's _right child become black, the parent is
 *   rotated towards _left, and 'x' is set to 'root' to end the loop.
 * */
#define RB_DELETE_FIXUP_SUB(_left, _right)                      \
    do {                                                        \
        w = x->parent->_right;                                  \
        if (w->red) {                                           \
            w->red = 0;                                         \
            x->parent->red = 1;                                 \
            rb_##_left##_rotate(tree, x->parent);               \
            w = x->parent->_right;                              \
        }                                                       \
        if (!w->_left->red && !w->_right->red) {                \
            w->red = 1;                                         \
            x = x->parent;                                      \
        }                                                       \
        else {                                                  \
            if (!w->_right->red) {                              \
                w->_left->red = 0;                              \
                w->red = 1;                                     \
                rb_##_right##_rotate(tree, w);                  \
                w = x->parent->_right;                          \
            }                                                   \
            w->red = x->parent->red;                            \
            x->parent->red = 0;                                 \
            w->_right->red = 0;                                 \
            rb_##_left##_rotate(tree, x->parent);               \
            x = root;                                           \
        }                                                       \
    } while (0)

    while (x != root && !x->red) {
        /* pick the mirrored variant matching the side 'x' is on */
        if (x == x->parent->left) {
            RB_DELETE_FIXUP_SUB(left, right);
        }
        else {
            RB_DELETE_FIXUP_SUB(right, left);
        }
    }
    x->red = 0;

#undef RB_DELETE_FIXUP_SUB
}

/* *
 * rb_delete - deletes @node from @tree, and calls rb_delete_fixup to
 * restore red-black properties when a black node was unlinked.
 *
 * @node must currently be linked into @tree. Nothing is freed here; the
 * storage of @node is left to the caller.
 * */
void
rb_delete(rb_tree *tree, rb_node *node) {
    rb_node *x, *y, *z = node;
    rb_node *nil = tree->nil, *root = tree->root;

    /* *
     * y is the node that gets unlinked: z itself when z has at most one
     * real child, otherwise the successor of z. x is the child of y that
     * takes y's place (nil if y has no real child).
     * */
    y = (z->left == nil || z->right == nil) ? z : rb_tree_successor(tree, z);
    x = (y->left != nil) ? y->left : y->right;

    assert(y != root && y != nil);

    /* done even when x is nil, so that rb_delete_fixup can walk up from x */
    x->parent = y->parent;
    if (y == y->parent->left) {
        y->parent->left = x;
    }
    else {
        y->parent->right = x;
    }

    /* sample y's own color now: '*y = *z' below overwrites it */
    bool need_fixup = !(y->red);

    if (y != z) {
        /* move y into the position z occupies */
        if (z == z->parent->left) {
            z->parent->left = y;
        }
        else {
            z->parent->right = y;
        }
        /* *
         * If y was the right child of z, z->right already refers to x at
         * this point, so this also repoints x->parent from z to y.
         * */
        z->left->parent = z->right->parent = y;
        /* y takes over z's links and color */
        *y = *z;
    }
    if (need_fixup) {
        rb_delete_fixup(tree, x);
    }
}

/* *
 * rb_tree_destroy - releases the two sentinel nodes and the tree structure
 * itself. Nodes that were inserted into the tree are not touched here.
 * */
void
rb_tree_destroy(rb_tree *tree) {
    rb_node *root = tree->root, *nil = tree->nil;

    kfree(root);
    kfree(nil);
    kfree(tree);
}

/* *
 * rb_node_prev - returns the node preceding @node in @tree,
 * or 'NULL' when @node has no predecessor.
 * */
rb_node *
rb_node_prev(rb_tree *tree, rb_node *node) {
    rb_node *found = rb_tree_predecessor(tree, node);
    if (found == tree->nil) {
        return NULL;
    }
    return found;
}

/* *
 * rb_node_next - returns the node following @node in @tree,
 * or 'NULL' when @node has no successor.
 * */
rb_node *
rb_node_next(rb_tree *tree, rb_node *node) {
    rb_node *found = rb_tree_successor(tree, node);
    if (found == tree->nil) {
        return NULL;
    }
    return found;
}

/* *
 * rb_node_root - returns the real root node of @tree (the left child of
 * the sentinel root), or 'NULL' if the tree is empty.
 * */
rb_node *
rb_node_root(rb_tree *tree) {
    rb_node *top = tree->root->left;
    if (top == tree->nil) {
        return NULL;
    }
    return top;
}

/* *
 * rb_node_left - returns the left child of @node, or 'NULL' if that
 * link points at the nil sentinel of @tree.
 * */
rb_node *
rb_node_left(rb_tree *tree, rb_node *node) {
    rb_node *child = node->left;
    if (child == tree->nil) {
        return NULL;
    }
    return child;
}

/* *
 * rb_node_right - returns the right child of @node, or 'NULL' if that
 * link points at the nil sentinel of @tree.
 * */
rb_node *
rb_node_right(rb_tree *tree, rb_node *node) {
    rb_node *child = node->right;
    if (child == tree->nil) {
        return NULL;
    }
    return child;
}

/* *
 * check_tree - recursively validates the subtree rooted at @node and
 * returns its black height (1 for nil; otherwise the height of the
 * children, plus one if @node itself is black).
 *
 * Asserted along the way: nil is black, each real child is ordered
 * correctly against @node and points back to it as parent, a red node has
 * no red child, and both subtrees have the same black height.
 * */
int
check_tree(rb_tree *tree, rb_node *node) {
    rb_node *nil = tree->nil;

    if (node == nil) {
        assert(!node->red);
        return 1;
    }

    rb_node *lchild = node->left, *rchild = node->right;

    if (lchild != nil) {
        assert(COMPARE(tree, node, node->left) >= 0);
        assert(node->left->parent == node);
    }
    if (rchild != nil) {
        assert(COMPARE(tree, node, node->right) <= 0);
        assert(node->right->parent == node);
    }
    if (node->red) {
        assert(!node->left->red && !node->right->red);
    }

    int hb_left = check_tree(tree, node->left);
    int hb_right = check_tree(tree, node->right);
    assert(hb_left == hb_right);

    /* only black nodes add to the black height */
    return hb_left + (node->red ? 0 : 1);
}

/* *
 * check_safe_kmalloc - kmalloc wrapper for the self-test: asserts that
 * the allocation of @size bytes succeeded and returns the pointer.
 * */
static void *
check_safe_kmalloc(size_t size) {
    void *ret;

    ret = kmalloc(size);
    assert(ret != NULL);
    return ret;
}

/* check_data - record used by the self-test: a key plus an embedded tree node */
struct check_data {
    long data;          /* key that check_compare1/check_compare2 order by */
    rb_node rb_link;    /* node linked into the tree; rbn2data() maps it back */
};

/* rbn2data - converts a pointer to an embedded rb_link into its struct check_data */
#define rbn2data(rbn)               \
    (to_struct(rbn, struct check_data, rb_link))

/* *
 * check_compare1 - node/node ordering used by the self-test tree.
 * Returns a negative value, zero or a positive value when the key of @node1
 * is less than, equal to or greater than the key of @node2.
 *
 * The keys are compared rather than subtracted: a 'long' difference
 * narrowed to 'int' can overflow or lose its sign where 'long' is wider
 * than 'int'.
 * */
static inline int
check_compare1(rb_node *node1, rb_node *node2) {
    long lhs = rbn2data(node1)->data;
    long rhs = rbn2data(node2)->data;
    return (lhs > rhs) - (lhs < rhs);
}

/* *
 * check_compare2 - node/key ordering passed to rb_search by the self-test;
 * @key carries a 'long' value cast to 'void *'.
 * Returns a negative value, zero or a positive value when the key of @node
 * is less than, equal to or greater than @key.
 *
 * The values are compared rather than subtracted: a 'long' difference
 * narrowed to 'int' can overflow or lose its sign where 'long' is wider
 * than 'int'.
 * */
static inline int
check_compare2(rb_node *node, void *key) {
    long lhs = rbn2data(node)->data;
    long rhs = (long)key;
    return (lhs > rhs) - (lhs < rhs);
}

/* *
 * check_rb_tree - self-test of the red-black tree.
 *
 * Builds 1000 records with distinct keys, shuffles them with rand(),
 * inserts them, looks every one of them up, and deletes them in ascending
 * key order. It then repeats insert/search/delete with up to 32 records
 * that all share one key, and finally destroys a tree that still contains
 * nodes. check_tree validates the tree after every insert and delete, and
 * any violation trips an assert.
 * */
void
check_rb_tree(void) {
    rb_tree *tree = rb_tree_create(check_compare1);
    assert(tree != NULL);

    /* a fresh tree has a black nil and nothing under the sentinel root */
    rb_node *nil = tree->nil, *root = tree->root;
    assert(!nil->red && root->left == nil);

    int total = 1000;
    long i;

    /* one record per key 0 .. total-1 */
    struct check_data **all = check_safe_kmalloc(sizeof(*all) * total);
    for (i = 0; i < total; i ++) {
        struct check_data *rec = check_safe_kmalloc(sizeof(*rec));
        rec->data = i;
        all[i] = rec;
    }

    /* mark[] is used to confirm that every key is present */
    size_t mark_size = sizeof(int) * total;
    int *mark = check_safe_kmalloc(mark_size);
    memset(mark, 0, mark_size);

    for (i = 0; i < total; i ++) {
        mark[all[i]->data] = 1;
    }
    for (i = 0; i < total; i ++) {
        assert(mark[i] == 1);
    }

    /* shuffle: swap slot i with a randomly picked slot in [i, total) */
    for (i = 0; i < total; i ++) {
        long pick = i + (rand() % (total - i));
        struct check_data *held = all[pick];
        all[pick] = all[i];
        all[i] = held;
    }

    /* the shuffle must not have lost any record */
    memset(mark, 0, mark_size);
    for (i = 0; i < total; i ++) {
        mark[all[i]->data] = 1;
    }
    for (i = 0; i < total; i ++) {
        assert(mark[i] == 1);
    }

    /* insert in shuffled order, validating the tree after every step */
    for (i = 0; i < total; i ++) {
        rb_insert(tree, &(all[i]->rb_link));
        check_tree(tree, root->left);
    }

    /* keys are distinct, so each lookup must return exactly that record */
    rb_node *node;
    for (i = 0; i < total; i ++) {
        node = rb_search(tree, check_compare2, (void *)(all[i]->data));
        assert(node != NULL && node == &(all[i]->rb_link));
    }

    /* delete in ascending key order, validating the tree after every step */
    for (i = 0; i < total; i ++) {
        node = rb_search(tree, check_compare2, (void *)i);
        assert(node != NULL && rbn2data(node)->data == i);
        rb_delete(tree, node);
        check_tree(tree, root->left);
    }

    assert(!nil->red && root->left == nil);

    /* duplicate keys: the first 'max' records all get the key 'max' */
    long max = (total < 32) ? total : 32;

    for (i = 0; i < max; i ++) {
        all[i]->data = max;
        rb_insert(tree, &(all[i]->rb_link));
        check_tree(tree, root->left);
    }

    for (i = 0; i < max; i ++) {
        node = rb_search(tree, check_compare2, (void *)max);
        assert(node != NULL && rbn2data(node)->data == max);
        rb_delete(tree, node);
        check_tree(tree, root->left);
    }

    assert(rb_tree_empty(tree));

    /* refill the tree so that rb_tree_destroy runs on a non-empty tree */
    for (i = 0; i < total; i ++) {
        rb_insert(tree, &(all[i]->rb_link));
        check_tree(tree, root->left);
    }

    rb_tree_destroy(tree);

    /* the records belong to this test and are released separately */
    for (i = 0; i < total; i ++) {
        kfree(all[i]);
    }

    kfree(mark);
    kfree(all);
}