RedBlackBST
#pragma once
#include <stdexcept>
template<typename Key,typename Value>
class RedBlackBST
{
private:
enum Color
{
BLACK = 0, RED = 1
};
class Node
{
public:
Node* right = nullptr;
Node* left = nullptr;
Key key;
Value value;
Color color;
int count;
public:
Node(const Key& k, const Value& v, const int& n,const Color& c) :key(k), value(v), count(n),color(c)
{
}
};
private:
Node* root = nullptr;
private:
/*******************************************
函数名称: size
函数说明: 得到红黑树节点个数
返回值: int
*******************************************/
int size(Node* r)
{
if (r == nullptr)
return 0;
return r->count;
}
/*******************************************
函数名称: rotateLeft
函数说明: 将右边红链接旋转到左边
返回值: Node*
*******************************************/
Node* rotateLeft(Node* curr)
{
Node* r = curr->right;
r->color = curr->color;
curr->right = r->left;
curr->color = RED;
r->left = curr;
return r;
}
/*******************************************
函数名称: rotateRight
函数说明: 将左边红链接旋转到右边
返回值: Node*
******************************************/
Node* rotateRight(Node* curr)
{
Node* r = curr->left;
r->color = curr->color;
curr->color = RED;
curr->left = r->right;
r->right = curr;
return r;
}
/******************************************************
函数名称: flipColors
函数说明: 若节点左右链接皆为红链接则将其改为黑链接
返回值: void
*******************************************************/
void flipColors(Node* r)
{
r->color = static_cast<Color>(!r->color);
r->left->color = static_cast<Color>(!r->left->color);
r->right->color = static_cast<Color>(!r->right->color);
}
/******************************************
函数名称: isRed
函数说明: 判断该节点连接是否为红链接
返回值: bool
*******************************************/
bool isRed(Node* r)
{
if (r == nullptr)
return false;
return r->color == RED;
}
/*******************************************
函数名称: put
函数说明: 向红黑树中插入key-value新节点
返回值: Node*
*******************************************/
Node* put(Node* r, const Key& k, const Value& v)
{
if (r == nullptr)
return new Node(k, v, 1, RED);
if (r->key > k)
r->left = put(r->left, k, v);
else if (r->key < k)
r->right = put(r->right, k, v);
else
r->value = v;
if (isRed(r->right) && !isRed(r->left)) r = rotateLeft(r);
if (isRed(r->left) && isRed(r->left->left)) r = rotateRight(r);
if (isRed(r->left) && isRed(r->right)) flipColors(r);
r->count = size(r->left) + size(r->right) + 1;
return r;
}
/******************************************
函数名称: get
函数说明: 得到key = k的节点
返回值: Node*
*******************************************/
Node* get(Node* r,const Key& k)
{
if (r == nullptr)
return nullptr;
if (r->key > k)
return get(r->left, k);
else if (r->key < k)
return get(r->right, k);
else
return r;
}
/*******************************************
函数名称: display
函数说明: 中序遍历二叉树
返回值: void
*******************************************/
void display(Node* r)
{
if (r == nullptr)
return;
display(r->left);
cout << "Key: " << r->key << " Value: " << r->value << endl;
display(r->right);
}
/*******************************************
函数名称: min
函数说明: 得到红黑树最小的节点
返回值: Node*
********************************************/
Node* min(Node* r)
{
if (r == nullptr)
return nullptr;
if (r->left == nullptr)
return r;
else
return min(r->left);
}
/******************************************
函数名称: max
函数说明: 得到红黑树中key最大的节点
返回值: Node*
*******************************************/
Node* max(Node* r)
{
if (r == nullptr)
return nullptr;
if (r->right == nullptr)
return r;
else
return max(r->right);
}
/******************************************
函数名称: balance
函数说明: 恢复红黑树的平衡
返回值: Node*
*******************************************/
Node* balance(Node* r)
{
if (isRed(r->right))
r = rotateLeft(r);
if (!isRed(r->left) && isRed(r->right))
r = rotateLeft(r);
if (isRed(r->left) && isRed(r->left->left))
r = rotateRight(r);
if (isRed(r->right) && isRed(r->left))
flipColors(r);
r->count = size(r->left) + size(r->right) + 1;
return r;
}
/******************************************
函数名称: moveRedLeft
函数说明: 向左得到4-结点 即 红 黑 红
返回值: Node*
*******************************************/
Node* moveRedLeft(Node* r)
{
flipColors(r);
if (r->right != nullptr && isRed(r->right->left))
{
r->right = rotateRight(r->right);
r = rotateLeft(r);
}
return r;
}
/******************************************
函数名称: moveRedRight
函数说明: 向左得到4-结点 即 红 黑 红
返回值: Node*
*******************************************/
Node* moveRedRight(Node *r)
{
flipColors(r);
if (isRed(r->left->left))
r = rotateRight(r);
return r;
}
/******************************************
函数名称: deleteMax
函数说明: 删除key最大的节点
返回值: Node*
*******************************************/
Node* deleteMax(Node* r)
{
if (isRed(r->left))
r = rotateRight(r);
if (r->right == nullptr)
return nullptr;
if (!isRed(r->right) && !isRed(r->right->left))
r = moveRedRight(r);
r->right = deleteMax(r->right);
return balance(r);
}
/******************************************
函数名称: deleteMin
函数说明: 删除key最小的节点
返回值: Node*
*******************************************/
Node* deleteMin(Node* r)
{
if (r->left == nullptr)
return nullptr;
if (!isRed(r->left) && !isRed(r->left->left))
r = moveRedLeft(r);
r->left = deleteMin(r->left);
return balance(r);
}
/******************************************
函数名称: erase
函数说明: 删除红黑树中与key相等的结点
返回值: Node*
*******************************************/
Node* erase(Node* r, const Key& k)
{
if (r->key > k)
{
if (!isRed(r->left) && !isRed(r->left->left))
r = moveRedLeft(r);
r->left = erase(r->left, k);
}
else
{
if (isRed(r->left))
r = rotateRight(r);
if (r->key == k && r->right == nullptr)
return nullptr;
if (!isRed(r->right) && !isRed(r->right->left))
r = moveRedRight(r);
if (r->key == k)
{
Node* t = min(r->right);
r->key = t->key;
r->value = t->value;
r->right = deleteMin(r->right);
}
else
r->right = erase(r->right,k);
}
return balance(r);
}
public:
int size()
{
return size(root);
}
void put(const Key& k, const Value& v)
{
root = put(root, k, v);
root->color = BLACK;
}
void display()
{
display(root);
}
Key get(const Key& k)
{
Node* ret;
if ((ret = get(root, k)) == nullptr)
throw std::out_of_range("can't get");
else
return ret->key;
}
void deleteMin()
{
if (root == nullptr)
return;
if (!isRed(root->left) && !isRed(root->right))
root->color = RED;
root = deleteMin(root);
if (root != nullptr)
root->color = BLACK;
}
void deleteMax()
{
if (root == nullptr)
return;
if (!isRed(root->left) && !isRed(root->right))
root->color = RED;
root = deleteMax(root);
if (root != nullptr)
root->color = BLACK;
}
void erase(const Key& k)
{
if (root == nullptr)
return;
if (!isRed(root->left) && !isRed(root->right))
root->color = RED;
root = erase(root, k);
if (root != nullptr)
root->color = BLACK;
}
};
main.cpp
#include <iostream>
#include "RedBlackBST.h"
using namespace std;
int main()
{
RedBlackBST<double,int> rbst;
for (int i = 0; i < 10; ++i)
rbst.put(i + 0.1, i);
cout << rbst.get(4.1);
rbst.deleteMin();
rbst.deleteMin();
rbst.deleteMin();
rbst.deleteMax();
rbst.deleteMax();
rbst.erase(7.1);
rbst.display();
system("pause");
return 0;
}
运行: