/*====================================================================
BSD 2-Clause License
Copyright (c) 2025, Ruler
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
====================================================================*/
#pragma once
#ifndef __CORE_RB_TREE_H__
#define __CORE_RB_TREE_H__
#include <memory>
#include <stdexcept>
#include <iterator>
#include <functional>
#include <utility>
#include "define.h"
#ifndef DEFAULT_ALLOCATOR
#define DEFAULT_ALLOCATOR(T) std::allocator<T>
#endif // !DEFAULT_ALLOCATOR
// rbt_color: colour tag stored in every tree node.
// The underlying type is bool: `red` is false, `black` is true. The tree's
// header sentinel is also tagged red, which lets iterators tell it apart from
// the root (the root is always recoloured black).
enum class rbt_color : bool { red = false, black = true };
// rbt_type: how an rbt_primitive_iterator reached its current node.
// The value shifted right by 4 (rbt_primitive_iterator::get_depth) is the
// change in depth represented by the step; the low nibble only keeps the
// enumerators distinct:
//   root    ->  0  (initial position)
//   parent  -> -1  (moved up to the parent; -15 >> 4 == -1)
//   left    -> +1  (moved down to a left child)
//   right   -> +1  (moved down to a right child)
//   sibling ->  0  (moved across to the other child of the same parent)
enum class rbt_type : signed char
{
    root    =   0,  // 0x00
    parent  = -15,  // -0x0F
    left    =  18,  // 0x12
    right   =  19,  // 0x13
    sibling =   4   // 0x04
};
// rbt_node: one node of the red-black tree.
// `parent`, `left` and `right` link the node into the tree (a child link is
// nullptr when the child is absent), `color` holds the red-black colour and
// `data` the stored element. The tree's header sentinel uses the same layout,
// with its `parent` pointing at the root.
template <class T>
struct rbt_node
{
    // self-referential aliases used by the allocator and iterator code
    using node_type            = rbt_node<T>;
    using node_pointer         = node_type*;
    using node_reference       = node_type&;
    using const_node_pointer   = const node_type*;
    using const_node_reference = const node_type&;

    node_pointer parent;  // parent node (the header for the root)
    node_pointer left;    // left child, or nullptr
    node_pointer right;   // right child, or nullptr
    rbt_color    color;   // red or black
    T            data;    // stored element
};
// Class template rbt_type_traits
// Member types exposed by the iterators over `Tree`. The primary template
// (IsConst == false) yields mutable pointer/reference types; the partial
// specialisation for IsConst == true yields their const counterparts. Both
// expose the tree's node_pointer unchanged, so const iterators can still
// follow node links.
template <class Tree, bool IsConst>
struct rbt_type_traits
{
    // node-level types
    using node_type       = typename Tree::node_type;
    using node_pointer    = typename Tree::node_pointer;
    // counting types
    using size_type       = typename Tree::size_type;
    using difference_type = typename Tree::difference_type;
    // element access (mutable)
    using value_type      = typename Tree::value_type;
    using pointer         = typename Tree::pointer;
    using reference       = typename Tree::value_type&;
};
template <class Tree>
struct rbt_type_traits<Tree, true>
{
    // node-level types
    using node_type       = typename Tree::node_type;
    using node_pointer    = typename Tree::node_pointer;
    // counting types
    using size_type       = typename Tree::size_type;
    using difference_type = typename Tree::difference_type;
    // element access (read-only)
    using value_type      = typename Tree::value_type;
    using pointer         = typename Tree::const_pointer;
    using reference       = const typename Tree::value_type&;
};
// Class template rbt_iterator
template <class Tree, bool IsConst>
class rbt_iterator
{
public:
// types:
using value_type = typename rbt_type_traits<Tree, IsConst>::value_type;
using pointer = typename rbt_type_traits<Tree, IsConst>::pointer;
using reference = typename rbt_type_traits<Tree, IsConst>::reference;
using size_type = typename rbt_type_traits<Tree, IsConst>::size_type;
using difference_type = typename rbt_type_traits<Tree, IsConst>::difference_type;
using node_type = typename rbt_type_traits<Tree, IsConst>::node_type;
using node_pointer = typename rbt_type_traits<Tree, IsConst>::node_pointer;
using iterator_type = rbt_iterator<Tree, IsConst>;
using iterator_category = std::bidirectional_iterator_tag;
// construct/copy/destroy:
rbt_iterator(void) noexcept
: _node(nullptr)
{}
explicit rbt_iterator(const node_pointer node) noexcept
: _node(node)
{}
rbt_iterator(const iterator_type& other) noexcept
: _node(other.get_pointer())
{}
inline iterator_type& operator=(const iterator_type& other) noexcept
{
if (this != &other)
_node = other.get_pointer();
return *this;
}
inline operator rbt_iterator<Tree, true>(void) const noexcept
{
return rbt_iterator<Tree, true>(_node);
}
// rbt_iterator operations:
inline node_pointer get_parent(void) noexcept
{
return _node->parent;
}
inline const node_pointer get_parent(void) const noexcept
{
return _node->parent;
}
inline node_pointer get_pointer(void) noexcept
{
return _node;
}
inline const node_pointer get_pointer(void) const noexcept
{
return _node;
}
inline rbt_color get_color(void) const noexcept
{
return _node->color;
}
inline reference operator*(void) const noexcept
{
return _node->data;
}
inline pointer operator->(void) const noexcept
{
return &(operator*());
}
// increment / decrement
iterator_type& operator++(void) noexcept
{
if (_node->right)
{
_node = _node->right;
while (_node->left)
_node = _node->left;
}
else
{
node_pointer p = _node->parent;
while (_node == p->right)
{
_node = p;
p = p->parent;
}
if (_node->right != p)
_node = p;
}
return *this;
}
iterator_type& operator--(void) noexcept
{
if (_node->color == rbt_color::red && _node->parent->parent == _node)
_node = _node->right;
else if (_node->left)
{
node_pointer p = _node->left;
while (p->right)
p = p->right;
_node = p;
}
else
{
node_pointer p = _node->parent;
while (_node == p->left)
{
_node = p;
p = p->parent;
}
_node = p;
}
return *this;
}
inline iterator_type operator++(int) noexcept
{
iterator_type itr(*this);
this->operator++();
return itr;
}
inline iterator_type operator--(int) noexcept
{
iterator_type itr(*this);
this->operator--();
return itr;
}
// relational operators:
template <bool is_const>
inline bool operator==(const rbt_iterator<Tree, is_const>& rhs) const noexcept
{
return _node == rhs.get_pointer();
}
template <bool is_const>
inline bool operator!=(const rbt_iterator<Tree, is_const>& rhs) const noexcept
{
return _node != rhs.get_pointer();
}
private:
node_pointer _node;
};
// Class template rbt_leaf_iterator
template <class Tree, bool IsConst>
class rbt_leaf_iterator
{
public:
// types:
using value_type = typename rbt_type_traits<Tree, IsConst>::value_type;
using pointer = typename rbt_type_traits<Tree, IsConst>::pointer;
using reference = typename rbt_type_traits<Tree, IsConst>::reference;
using size_type = typename rbt_type_traits<Tree, IsConst>::size_type;
using difference_type = typename rbt_type_traits<Tree, IsConst>::difference_type;
using node_type = typename rbt_type_traits<Tree, IsConst>::node_type;
using node_pointer = typename rbt_type_traits<Tree, IsConst>::node_pointer;
using iterator_type = rbt_leaf_iterator<Tree, IsConst>;
using iterator_category = std::bidirectional_iterator_tag;
// construct/copy/destroy:
rbt_leaf_iterator(void) noexcept
: _node(nullptr)
{}
explicit rbt_leaf_iterator(const node_pointer node) noexcept
: _node(node)
{}
rbt_leaf_iterator(const iterator_type& other) noexcept
: _node(other.get_pointer())
{}
inline iterator_type& operator=(const iterator_type& other) noexcept
{
if (this != &other)
_node = other.get_pointer();
return *this;
}
inline operator rbt_leaf_iterator<Tree, true>(void) const noexcept
{
return rbt_leaf_iterator<Tree, true>(_node);
}
// rbt_leaf_iterator operations:
inline node_pointer get_parent(void) noexcept
{
return _node->parent;
}
inline const node_pointer get_parent(void) const noexcept
{
return _node->parent;
}
inline node_pointer get_pointer(void) noexcept
{
return _node;
}
inline const node_pointer get_pointer(void) const noexcept
{
return _node;
}
inline rbt_color get_color(void) const noexcept
{
return _node->color;
}
inline reference operator*(void) const noexcept
{
return _node->data;
}
inline pointer operator->(void) const noexcept
{
return &(operator*());
}
std::pair<bool, size_t> verify(void) const noexcept
{
bool valid = true;
size_t count = 1;
rbt_color color = rbt_color::black;
for (node_pointer p = _node; p->parent->parent != p; p = p->parent)
{
// two consecutive red nodes are not permitted
if (p->color == rbt_color::red)
{
if (color == rbt_color::red)
valid = false;
}
// count the number of black nodes
else
++count;
color = p->color;
}
return std::make_pair(valid, count);
}
// increment / decrement
inline iterator_type& operator++(void) noexcept
{
do
{
if (_node->right)
{
_node = _node->right;
while (_node->left)
_node = _node->left;
}
else
{
node_pointer p = _node->parent;
while (_node == p->right)
{
_node = p;
p = p->parent;
}
if (_node->right != p)
_node = p;
}
} while ((_node->color == rbt_color::black || _node->parent->parent != _node) &&
(_node->left || _node->right));
return *this;
}
inline iterator_type& operator--(void) noexcept
{
do
{
if (_node->color == rbt_color::red && _node->parent->parent == _node)
_node = _node->right;
else if (_node->left)
{
node_pointer p = _node->left;
while (p->right)
p = p->right;
_node = p;
}
else
{
node_pointer p = _node->parent;
while (_node == p->left)
{
_node = p;
p = p->parent;
}
_node = p;
}
} while ((_node->color == rbt_color::black || _node->parent->parent != _node) &&
(_node->left || _node->right));
return *this;
}
inline iterator_type operator++(int) noexcept
{
iterator_type itr(*this);
this->operator++();
return itr;
}
inline iterator_type operator--(int) noexcept
{
iterator_type itr(*this);
this->operator--();
return itr;
}
// relational operators:
template <bool is_const>
inline bool operator==(const rbt_leaf_iterator<Tree, is_const>& rhs) const noexcept
{
return _node == rhs.get_pointer();
}
template <bool is_const>
inline bool operator!=(const rbt_leaf_iterator<Tree, is_const>& rhs) const noexcept
{
return _node != rhs.get_pointer();
}
private:
node_pointer _node;
};
// Class template rbt_primitive_iterator
// Walks the tree structurally rather than in key order, recording in `_type`
// how the current node was reached (see rbt_type). Starting from the root,
// operator++ visits a node, then its left subtree, then its right subtree.
// When a subtree is exhausted it steps back up to the parent with type
// `parent` (so interior nodes are seen again on the way up) and does not
// descend from there. Climbing above the root reaches the header node, which
// is the end position.
// NOTE: operator-- performs the mirror-image walk (right subtree before left).
// It is not the inverse of operator++, so std::reverse_iterator over this type
// does not replay the forward sequence backwards.
// Equality compares node pointers only; `_type` is not compared.
template <class Tree, bool IsConst>
class rbt_primitive_iterator
{
public:
    // types:
    using value_type = typename rbt_type_traits<Tree, IsConst>::value_type;
    using pointer = typename rbt_type_traits<Tree, IsConst>::pointer;
    using reference = typename rbt_type_traits<Tree, IsConst>::reference;
    using size_type = typename rbt_type_traits<Tree, IsConst>::size_type;
    using difference_type = typename rbt_type_traits<Tree, IsConst>::difference_type;
    using node_type = typename rbt_type_traits<Tree, IsConst>::node_type;
    using node_pointer = typename rbt_type_traits<Tree, IsConst>::node_pointer;
    using iterator_type = rbt_primitive_iterator<Tree, IsConst>;
    using iterator_category = std::bidirectional_iterator_tag;
    // construct/copy/destroy:
    rbt_primitive_iterator(void) noexcept
        : _node(nullptr)
        , _type(rbt_type::root)
    {}
    // Starts a fresh walk at `node`.
    explicit rbt_primitive_iterator(const node_pointer node) noexcept
        : _node(node)
        , _type(rbt_type::root)
    {}
    // Resumes a walk at `node` as if it had been reached by a step of `type`.
    explicit rbt_primitive_iterator(const node_pointer node, const rbt_type type) noexcept
        : _node(node)
        , _type(type)
    {}
    rbt_primitive_iterator(const iterator_type& other) noexcept
        : _node(other.get_pointer())
        , _type(other.get_type())
    {}
    inline iterator_type& operator=(const iterator_type& other) noexcept
    {
        if (this != &other)
        {
            _node = other.get_pointer();
            _type = other.get_type();
        }
        return *this;
    }
    // Conversion to the read-only iterator keeps the traversal state.
    // Resetting `_type` to `root` would make the next ++ after a `parent` step
    // descend into the children again instead of moving on.
    inline operator rbt_primitive_iterator<Tree, true>(void) const noexcept
    {
        return rbt_primitive_iterator<Tree, true>(_node, _type);
    }
    // rbt_primitive_iterator operations:
    inline node_pointer get_parent(void) noexcept
    {
        return _node->parent;
    }
    inline const node_pointer get_parent(void) const noexcept
    {
        return _node->parent;
    }
    inline node_pointer get_pointer(void) noexcept
    {
        return _node;
    }
    inline const node_pointer get_pointer(void) const noexcept
    {
        return _node;
    }
    // How the current node was reached.
    inline rbt_type get_type(void) const noexcept
    {
        return _type;
    }
    // Depth change of the last step: +1 (left/right), -1 (parent),
    // 0 (root/sibling), decoded from the high nibble of rbt_type.
    inline intptr_t get_depth(void) const noexcept
    {
        return static_cast<intptr_t>(std::underlying_type_t<rbt_type>(_type) >> 4);
    }
    inline rbt_color get_color(void) const noexcept
    {
        return _node->color;
    }
    inline reference operator*(void) const noexcept
    {
        return _node->data;
    }
    inline pointer operator->(void) const noexcept
    {
        return &(operator*());
    }
    // increment / decrement
    // Forward walk: left child, else right child (unless we just came up),
    // else the right sibling, else up to the parent.
    iterator_type& operator++(void) noexcept
    {
        if (_type != rbt_type::parent && _node->left)
        {
            _node = _node->left;
            _type = rbt_type::left;
        }
        else if (_type != rbt_type::parent && _node->right)
        {
            _node = _node->right;
            _type = rbt_type::right;
        }
        // the root is its own grandparent (via the header) and has no sibling
        else if (_node != _node->parent->parent &&
            _node->parent->right && _node != _node->parent->right)
        {
            _node = _node->parent->right;
            _type = rbt_type::sibling;
        }
        else
        {
            _node = _node->parent;
            _type = rbt_type::parent;
        }
        return *this;
    }
    // Mirror walk: right child, else left child (unless we just came up),
    // else the left sibling, else up to the parent.
    iterator_type& operator--(void) noexcept
    {
        if (_type != rbt_type::parent && _node->right)
        {
            _node = _node->right;
            _type = rbt_type::right;
        }
        else if (_type != rbt_type::parent && _node->left)
        {
            _node = _node->left;
            _type = rbt_type::left;
        }
        else if (_node != _node->parent->parent &&
            _node->parent->left && _node != _node->parent->left)
        {
            _node = _node->parent->left;
            _type = rbt_type::sibling;
        }
        else
        {
            _node = _node->parent;
            _type = rbt_type::parent;
        }
        return *this;
    }
    inline iterator_type operator++(int) noexcept
    {
        iterator_type itr(*this);
        this->operator++();
        return itr;
    }
    inline iterator_type operator--(int) noexcept
    {
        iterator_type itr(*this);
        this->operator--();
        return itr;
    }
    // relational operators (node identity only):
    template <bool is_const>
    inline bool operator==(const rbt_primitive_iterator<Tree, is_const>& rhs) const noexcept
    {
        return _node == rhs.get_pointer();
    }
    template <bool is_const>
    inline bool operator!=(const rbt_primitive_iterator<Tree, is_const>& rhs) const noexcept
    {
        return _node != rhs.get_pointer();
    }
private:
    node_pointer _node;  // current node, or the header at the end of the walk
    rbt_type _type;      // how _node was reached
};
// Class template rbt_node_allocator
// Base of rb_tree that owns the allocators a tree needs:
//  - the element allocator (this class derives from Allocator rebound to T),
//    used to construct and destroy the `data` member of a node;
//  - a node allocator (Allocator rebound to rbt_node<T>), used to obtain and
//    release node storage.
// create_node() allocates a node and constructs only its `data` member; the
// link and colour fields are left for the caller to assign.
template <class T, class Allocator>
class rbt_node_allocator
    : public std::allocator_traits<Allocator>::template rebind_alloc<T>
{
public:
    // types:
    using tree_traits_type = std::allocator_traits<Allocator>;
    using tree_node_type = typename rbt_node<T>::node_type;
    using allocator_type = typename tree_traits_type::template rebind_alloc<T>;
    using allocator_traits = typename tree_traits_type::template rebind_traits<T>;
    using node_allocator_type = typename tree_traits_type::template rebind_alloc<tree_node_type>;
    using node_allocator_traits = typename tree_traits_type::template rebind_traits<tree_node_type>;
    using node_type = typename node_allocator_traits::value_type;
    using node_pointer = typename node_allocator_traits::pointer;
    using node_size_type = typename node_allocator_traits::size_type;
    using node_difference_type = typename node_allocator_traits::difference_type;
    // construct/copy/destroy:
    rbt_node_allocator(void)
        : allocator_type()
        , _alloc()
    {}
    // The node allocator is rebound from `alloc` (the Allocator requirements
    // guarantee this conversion), so a stateful allocator's state is shared by
    // element and node allocation instead of being default-constructed.
    explicit rbt_node_allocator(const allocator_type& alloc)
        : allocator_type(alloc)
        , _alloc(alloc)
    {}
    // rbt_node_allocator operations:
    constexpr allocator_type& get_allocator(void) noexcept
    {
        return *this;
    }
    constexpr const allocator_type& get_allocator(void) const noexcept
    {
        return *this;
    }
    constexpr node_allocator_type& get_node_allocator(void) noexcept
    {
        return _alloc;
    }
    // const access must return a const reference: `_alloc` is const here.
    constexpr const node_allocator_type& get_node_allocator(void) const noexcept
    {
        return _alloc;
    }
    // Maximum number of nodes that can be allocated. Goes through
    // allocator_traits because std::allocator::max_size was deprecated in
    // C++17 and removed in C++20.
    inline node_size_type max_size(void) const noexcept
    {
        return node_allocator_traits::max_size(_alloc);
    }
protected:
    // Allocates one node and constructs its `data` from `args`. If the
    // element constructor throws, the node storage is released before the
    // exception propagates.
    template <class ...Args>
    inline node_pointer create_node(Args&&... args)
    {
        node_pointer p = node_allocator_traits::allocate(_alloc, 1);
        try
        {
            allocator_traits::construct(get_allocator(), std::addressof(p->data), std::forward<Args>(args)...);
        }
        catch (...)
        {
            node_allocator_traits::deallocate(_alloc, p, 1);
            throw;
        }
        return p;
    }
    // Destroys the node's `data` and releases the node storage.
    inline void destroy_node(const node_pointer p)
    {
        allocator_traits::destroy(get_allocator(), std::addressof(p->data));
        node_allocator_traits::deallocate(_alloc, p, 1);
    }
private:
    node_allocator_type _alloc;  // allocator for node storage
};
// Class template rb_tree
// Ordered associative container of unique keys, implemented as a red-black
// tree. Elements are std::pair<const Key, Value>, ordered by Pred applied to
// the key. A header (sentinel) node anchors the tree:
//   _head->parent : the root (nullptr when the tree is empty)
//   _head->left   : the leftmost (minimum) node, or _head when empty
//   _head->right  : the rightmost (maximum) node, or _head when empty
//   root->parent  : _head
// The header is coloured red, which lets iterators tell it apart from the
// (always black) root. end() iterators point at it. The header's `data`
// member is never constructed or read.
template <class Key, class Value, class Pred = std::less<Key>, class Allocator = DEFAULT_ALLOCATOR(void)>
class rb_tree : public rbt_node_allocator<std::pair<const Key, Value>, Allocator>
{
public:
    // types:
    using key_type = Key;
    using key_compare = Pred;
    using mapped_type = Value;
    using value_type = std::pair<const Key, Value>;
    using tree_type = rb_tree<Key, Value, Pred, Allocator>;
    using tree_base = rbt_node_allocator<value_type, Allocator>;
    using node_type = typename tree_base::node_type;
    using node_pointer = typename tree_base::node_pointer;
    using allocator_type = typename tree_base::allocator_type;
    using allocator_traits = typename tree_base::allocator_traits;
    using size_type = typename allocator_traits::size_type;
    using difference_type = typename allocator_traits::difference_type;
    using pointer = typename allocator_traits::pointer;
    using const_pointer = typename allocator_traits::const_pointer;
    using reference = value_type&;
    using const_reference = const value_type&;
    using iterator = rbt_iterator<tree_type, false>;
    using const_iterator = rbt_iterator<tree_type, true>;
    using leaf_iterator = rbt_leaf_iterator<tree_type, false>;
    using const_leaf_iterator = rbt_leaf_iterator<tree_type, true>;
    using primitive_iterator = rbt_primitive_iterator<tree_type, false>;
    using const_primitive_iterator = rbt_primitive_iterator<tree_type, true>;
    using reverse_iterator = std::reverse_iterator<iterator>;
    using const_reverse_iterator = std::reverse_iterator<const_iterator>;
    using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
    using const_reverse_leaf_iterator = std::reverse_iterator<const_leaf_iterator>;
    using reverse_primitive_iterator = std::reverse_iterator<primitive_iterator>;
    using const_reverse_primitive_iterator = std::reverse_iterator<const_primitive_iterator>;
    // construct/copy/destroy:
    explicit rb_tree(const key_compare& pred = key_compare(), const allocator_type& alloc = allocator_type())
        : tree_base(alloc)
        , _head(nullptr)
        , _size(0)
        , _comp(pred)
    {
        _head = create_header();
    }
    explicit rb_tree(const allocator_type& alloc)
        : tree_base(alloc)
        , _head(nullptr)
        , _size(0)
        , _comp()
    {
        _head = create_header();
    }
    rb_tree(const tree_type& other)
        : tree_base(other.get_allocator())
        , _head(nullptr)
        , _size(0)
        , _comp(other.key_comp())
    {
        _head = create_header();
        construct_range(other.begin(), other.end());
    }
    rb_tree(const tree_type& other, const allocator_type& alloc)
        : tree_base(alloc)
        , _head(nullptr)
        , _size(0)
        , _comp(other.key_comp())
    {
        _head = create_header();
        construct_range(other.begin(), other.end());
    }
    // Takes over `other`'s elements by swapping headers.
    // NOTE: create_header() allocates; if that throws inside this noexcept
    // constructor, std::terminate is called.
    rb_tree(tree_type&& other) noexcept
        : tree_base(other.get_allocator())
        , _head(nullptr)
        , _size(0)
        , _comp(other.key_comp())
    {
        _head = create_header();
        swap(other);
    }
    // NOTE: nodes are taken over as-is, which assumes `alloc` can release
    // memory obtained from `other`'s allocator (true for std::allocator).
    rb_tree(tree_type&& other, const allocator_type& alloc) noexcept
        : tree_base(alloc)
        , _head(nullptr)
        , _size(0)
        , _comp(other.key_comp())
    {
        _head = create_header();
        swap(other);
    }
    template <class InputIt>
    rb_tree(InputIt first, InputIt last, const allocator_type& alloc = allocator_type())
        : tree_base(alloc)
        , _head(nullptr)
        , _size(0)
        , _comp()
    {
        _head = create_header();
        construct_range(first, last);
    }
    rb_tree(std::initializer_list<value_type> init, const allocator_type& alloc = allocator_type())
        : tree_base(alloc)
        , _head(nullptr)
        , _size(0)
        , _comp()
    {
        _head = create_header();
        construct_range(init.begin(), init.end());
    }
    ~rb_tree(void)
    {
        clear();
        destroy_header(_head);
    }
    // Copy assignment via copy-and-swap: the comparator is copied along with
    // the elements, and *this is left unchanged if copying an element throws.
    // The copy uses this tree's allocator so the nodes swapped in are
    // compatible with it.
    inline tree_type& operator=(const tree_type& other)
    {
        if (this != &other)
        {
            tree_type replica(other, this->get_allocator());
            swap(replica);
        }
        return *this;
    }
    // Move assignment: exchanges contents with `other`.
    inline tree_type& operator=(tree_type&& other) noexcept
    {
        swap(other);
        return *this;
    }
    // iterators (in key order; end() is the header):
    inline iterator begin(void) noexcept
    {
        return iterator(_head->left);
    }
    inline const_iterator begin(void) const noexcept
    {
        return const_iterator(_head->left);
    }
    inline const_iterator cbegin(void) const noexcept
    {
        return const_iterator(_head->left);
    }
    inline iterator end(void) noexcept
    {
        return iterator(_head);
    }
    inline const_iterator end(void) const noexcept
    {
        return const_iterator(_head);
    }
    inline const_iterator cend(void) const noexcept
    {
        return const_iterator(_head);
    }
    inline reverse_iterator rbegin(void) noexcept
    {
        return reverse_iterator(end());
    }
    inline const_reverse_iterator rbegin(void) const noexcept
    {
        return const_reverse_iterator(end());
    }
    inline const_reverse_iterator crbegin(void) const noexcept
    {
        return const_reverse_iterator(cend());
    }
    inline reverse_iterator rend(void) noexcept
    {
        return reverse_iterator(begin());
    }
    inline const_reverse_iterator rend(void) const noexcept
    {
        return const_reverse_iterator(begin());
    }
    inline const_reverse_iterator crend(void) const noexcept
    {
        return const_reverse_iterator(cbegin());
    }
    // leaf iterators: visit only childless nodes, in key order. They start at
    // the first leaf; the leftmost node is not necessarily a leaf, because it
    // may have a right child.
    inline leaf_iterator lbegin(void) noexcept
    {
        return leaf_iterator(first_leaf());
    }
    inline const_leaf_iterator lbegin(void) const noexcept
    {
        return const_leaf_iterator(first_leaf());
    }
    inline const_leaf_iterator clbegin(void) const noexcept
    {
        return const_leaf_iterator(first_leaf());
    }
    inline leaf_iterator lend(void) noexcept
    {
        return leaf_iterator(_head);
    }
    inline const_leaf_iterator lend(void) const noexcept
    {
        return const_leaf_iterator(_head);
    }
    inline const_leaf_iterator clend(void) const noexcept
    {
        return const_leaf_iterator(_head);
    }
    inline reverse_leaf_iterator rlbegin(void) noexcept
    {
        return reverse_leaf_iterator(lend());
    }
    inline const_reverse_leaf_iterator rlbegin(void) const noexcept
    {
        return const_reverse_leaf_iterator(lend());
    }
    inline const_reverse_leaf_iterator crlbegin(void) const noexcept
    {
        return const_reverse_leaf_iterator(clend());
    }
    inline reverse_leaf_iterator rlend(void) noexcept
    {
        return reverse_leaf_iterator(lbegin());
    }
    inline const_reverse_leaf_iterator rlend(void) const noexcept
    {
        return const_reverse_leaf_iterator(lbegin());
    }
    inline const_reverse_leaf_iterator crlend(void) const noexcept
    {
        return const_reverse_leaf_iterator(clbegin());
    }
    // primitive (structural) iterators: start at the root, end at the header.
    inline primitive_iterator pbegin(void) noexcept
    {
        return primitive_iterator(root());
    }
    inline const_primitive_iterator pbegin(void) const noexcept
    {
        return const_primitive_iterator(root());
    }
    inline const_primitive_iterator cpbegin(void) const noexcept
    {
        return const_primitive_iterator(root());
    }
    inline primitive_iterator pend(void) noexcept
    {
        return primitive_iterator(_head);
    }
    inline const_primitive_iterator pend(void) const noexcept
    {
        return const_primitive_iterator(_head);
    }
    inline const_primitive_iterator cpend(void) const noexcept
    {
        return const_primitive_iterator(_head);
    }
    inline reverse_primitive_iterator rpbegin(void) noexcept
    {
        return reverse_primitive_iterator(pend());
    }
    inline const_reverse_primitive_iterator rpbegin(void) const noexcept
    {
        return const_reverse_primitive_iterator(pend());
    }
    inline const_reverse_primitive_iterator crpbegin(void) const noexcept
    {
        return const_reverse_primitive_iterator(cpend());
    }
    inline reverse_primitive_iterator rpend(void) noexcept
    {
        return reverse_primitive_iterator(pbegin());
    }
    inline const_reverse_primitive_iterator rpend(void) const noexcept
    {
        return const_reverse_primitive_iterator(pbegin());
    }
    inline const_reverse_primitive_iterator crpend(void) const noexcept
    {
        return const_reverse_primitive_iterator(cpbegin());
    }
    // capacity:
    inline bool empty(void) const noexcept
    {
        return !_head->parent;
    }
    inline size_type size(void) const noexcept
    {
        return _size;
    }
    // observers:
    inline key_compare key_comp(void) const
    {
        return _comp;
    }
    // element access:
    // Returns the element with `key`.
    // Throws std::domain_error if the tree is empty, std::out_of_range if the
    // key is not present.
    inline reference at(const key_type& key)
    {
        if (empty())
            throw std::domain_error(RBT_NOT_INITIALIZED);
        iterator itr = find(key);
        if (itr == end())
            throw std::out_of_range(RBT_OUT_OF_RANGE);
        return *itr;
    }
    inline const_reference at(const key_type& key) const
    {
        if (empty())
            throw std::domain_error(RBT_NOT_INITIALIZED);
        const_iterator itr = find(key);
        if (itr == cend())
            throw std::out_of_range(RBT_OUT_OF_RANGE);
        return *itr;
    }
    // modifiers:
    // Constructs a value_type from `args` and inserts it unless an element
    // with an equivalent key exists. Returns {position, inserted}.
    template <class... Args>
    inline std::pair<iterator, bool> emplace(Args&&... args)
    {
        std::pair<node_pointer, bool> res = insert_node(std::forward<Args>(args)...);
        return std::make_pair(iterator(res.first), res.second);
    }
    // The hint is accepted for interface compatibility but not used. The
    // position is always found by descending from the root, which keeps the
    // tree ordered whichever position is passed.
    template <class... Args>
    inline iterator emplace_hint(const_iterator hint, Args&&... args)
    {
        static_cast<void>(hint);
        return iterator(insert_node(std::forward<Args>(args)...).first);
    }
    inline std::pair<iterator, bool> insert(const value_type& value)
    {
        std::pair<node_pointer, bool> res = insert_node(value);
        return std::make_pair(iterator(res.first), res.second);
    }
    inline std::pair<iterator, bool> insert(value_type&& value)
    {
        std::pair<node_pointer, bool> res = insert_node(std::move(value));
        return std::make_pair(iterator(res.first), res.second);
    }
    // `pos` is a hint only and is not used (see emplace_hint).
    inline iterator insert(const_iterator pos, const value_type& value)
    {
        static_cast<void>(pos);
        return iterator(insert_node(value).first);
    }
    inline iterator insert(const_iterator pos, value_type&& value)
    {
        static_cast<void>(pos);
        return iterator(insert_node(std::move(value)).first);
    }
    template <class InputIt>
    inline void insert(InputIt first, InputIt last)
    {
        for (; first != last; ++first)
            insert_node(*first);
    }
    inline void insert(std::initializer_list<value_type> init)
    {
        insert(init.begin(), init.end());
    }
    // Removes the element at `pos` (no-op for end()); returns its successor.
    inline iterator erase(const_iterator pos)
    {
        iterator next = iterator(pos.get_pointer());
        if (pos != cend())
        {
            // nodes are relinked, never moved, so `next` stays valid
            ++next;
            erase_node(pos.get_pointer());
        }
        return next;
    }
    // Removes [first, last); returns `last`.
    inline iterator erase(const_iterator first, const_iterator last)
    {
        iterator next = iterator(last.get_pointer());
        if (first == cbegin() && last == cend())
            clear();
        else
            while (first != last)
                erase(first++);
        return next;
    }
    // Removes the element with `key`; returns the number removed (0 or 1,
    // since keys are unique).
    inline size_type erase(const key_type& key)
    {
        iterator itr = find(key);
        if (itr == end())
            return 0;
        erase_node(itr.get_pointer());
        return 1;
    }
    // Exchanges contents and comparators with `other`.
    // NOTE: allocators are not exchanged; this assumes both trees' allocators
    // can release each other's memory (true for std::allocator).
    inline void swap(tree_type& other) noexcept
    {
        if (this != &other)
        {
            std::swap(_head, other._head);
            std::swap(_size, other._size);
            std::swap(_comp, other._comp);
        }
    }
    // Destroys every element and resets the header to the empty state.
    inline void clear(void)
    {
        if (_head->parent)
        {
            erase_root();
            _head->parent = nullptr;
            _head->left = _head;
            _head->right = _head;
            _size = 0;
        }
    }
    // operations:
    inline iterator find(const key_type& key) noexcept
    {
        return iterator(find_node(key));
    }
    inline const_iterator find(const key_type& key) const noexcept
    {
        return const_iterator(find_node(key));
    }
    inline iterator lower_bound(const key_type& key) noexcept
    {
        return iterator(lower_bound_node(key));
    }
    inline const_iterator lower_bound(const key_type& key) const noexcept
    {
        return const_iterator(lower_bound_node(key));
    }
    inline iterator upper_bound(const key_type& key) noexcept
    {
        return iterator(upper_bound_node(key));
    }
    inline const_iterator upper_bound(const key_type& key) const noexcept
    {
        return const_iterator(upper_bound_node(key));
    }
private:
    using node_traits = typename tree_base::node_allocator_traits;
    // The root, or the header when the tree is empty.
    inline node_pointer root(void) const noexcept
    {
        return _head->parent ? _head->parent : _head;
    }
    inline node_pointer leftmost(node_pointer t) const noexcept
    {
        while (t->left)
            t = t->left;
        return t;
    }
    inline node_pointer rightmost(node_pointer t) const noexcept
    {
        while (t->right)
            t = t->right;
        return t;
    }
    // First childless node in key order, or the header when empty: descend
    // to the left where possible, otherwise to the right.
    inline node_pointer first_leaf(void) const noexcept
    {
        node_pointer n = _head->parent;
        if (!n)
            return _head;
        while (n->left || n->right)
            n = n->left ? n->left : n->right;
        return n;
    }
    // Allocates the header sentinel. Only its links and colour are set; its
    // `data` member is never constructed, so Key and Value need not be
    // default-constructible.
    inline node_pointer create_header(void)
    {
        node_pointer head = node_traits::allocate(this->get_node_allocator(), 1);
        head->parent = nullptr;
        head->left = head;
        head->right = head;
        head->color = rbt_color::red;
        return head;
    }
    // Releases the header's storage (its `data` was never constructed, so
    // nothing is destroyed) and nulls the pointer.
    inline void destroy_header(node_pointer& head)
    {
        if (head)
        {
            node_traits::deallocate(this->get_node_allocator(), head, 1);
            head = nullptr;
        }
    }
    // Inserts [first, last) during construction. A throwing constructor does
    // not run the destructor, so on failure the nodes inserted so far and the
    // header are released here before the exception propagates.
    template <class InputIt>
    void construct_range(InputIt first, InputIt last)
    {
        try
        {
            insert(first, last);
        }
        catch (...)
        {
            clear();
            destroy_header(_head);
            throw;
        }
    }
    // Node holding `key`, or the header if absent.
    node_pointer find_node(const key_type& key) const noexcept
    {
        node_pointer candidate = lower_bound_node(key);
        // the header carries no key and must not be compared
        if (candidate == _head || _comp(key, candidate->data.first))
            return _head;
        return candidate;
    }
    // First node whose key is not less than `key`, or the header.
    node_pointer lower_bound_node(const key_type& key) const noexcept
    {
        node_pointer pre = _head;
        node_pointer cur = _head->parent;
        while (cur)
        {
            if (!_comp(cur->data.first, key))
            {
                pre = cur;
                cur = cur->left;
            }
            else
                cur = cur->right;
        }
        return pre;
    }
    // First node whose key is greater than `key`, or the header.
    node_pointer upper_bound_node(const key_type& key) const noexcept
    {
        node_pointer pre = _head;
        node_pointer cur = _head->parent;
        while (cur)
        {
            if (_comp(key, cur->data.first))
            {
                pre = cur;
                cur = cur->left;
            }
            else
                cur = cur->right;
        }
        return pre;
    }
    // Builds a value_type from `args` and inserts it at its ordered position,
    // found by descending from the root. Returns {node, true} for a new node,
    // or {existing node, false} if an equivalent key is already present.
    template <class ...Args>
    std::pair<node_pointer, bool> insert_node(Args&&... args)
    {
        value_type val(std::forward<Args>(args)...);
        node_pointer parent = _head;
        node_pointer cur = _head->parent;
        bool as_left = true;
        while (cur)
        {
            parent = cur;
            if (_comp(val.first, cur->data.first))
            {
                as_left = true;
                cur = cur->left;
            }
            else if (_comp(cur->data.first, val.first))
            {
                as_left = false;
                cur = cur->right;
            }
            else
                return std::make_pair(cur, false);  // equivalent key exists
        }
        node_pointer n = this->create_node(std::move(val));
        n->parent = parent;
        n->left = nullptr;
        n->right = nullptr;
        if (parent == _head)
        {
            // first element: black root, and both minimum and maximum
            n->color = rbt_color::black;
            _head->parent = n;
            _head->left = n;
            _head->right = n;
        }
        else
        {
            n->color = rbt_color::red;
            if (as_left)
            {
                parent->left = n;
                if (parent == _head->left)  // new minimum
                    _head->left = n;
            }
            else
            {
                parent->right = n;
                if (parent == _head->right)  // new maximum
                    _head->right = n;
            }
            insert_rebalance(n);
        }
        ++_size;
        return std::make_pair(n, true);
    }
    // Unlinks node `t`, restores the red-black properties and destroys it.
    void erase_node(node_pointer t)
    {
        node_pointer x, y, p;
        // if node t has two child nodes
        if (t->left && t->right)
        {
            // y: in-order successor of t; x: the child that takes y's place
            y = leftmost(t->right);
            x = y->right;
            // replaces node t with node y and removes node t
            t->left->parent = y;
            y->left = t->left;
            if (y != t->right)
            {
                y->parent->left = y->right;
                if (y->right)
                    y->right->parent = y->parent;
                t->right->parent = y;
                y->right = t->right;
                p = y->parent;
            }
            else
                p = y;
            if (t == _head->parent)
                _head->parent = y;
            else if (t == t->parent->left)
                t->parent->left = y;
            else
                t->parent->right = y;
            y->parent = t->parent;
            // swap the colors of node y and node t
            std::swap(y->color, t->color);
            // t now carries y's old colour; removing a black node needs a fix-up
            if (t->color != rbt_color::red)
                erase_rebalance(p, x);
        }
        // if node t has one child node at most
        else
        {
            x = t->left ? t->left : t->right;
            // removes node t
            if (x)
                x->parent = t->parent;
            if (t == _head->parent)
                _head->parent = x;
            else if (t == t->parent->left)
                t->parent->left = x;
            else
                t->parent->right = x;
            // replaces the minimum value (the header when the tree empties)
            if (t == _head->left)
                _head->left = x ? leftmost(x) : t->parent;
            // replaces the maximum value (the header when the tree empties)
            if (t == _head->right)
                _head->right = x ? rightmost(x) : t->parent;
            // rebalance after deletion
            if (t->color != rbt_color::red)
                erase_rebalance(t->parent, x);
        }
        // destroy node
        this->destroy_node(t);
        --_size;
    }
    // Destroys every node in post-order. Links are cut as nodes are freed;
    // the caller resets the header afterwards.
    void erase_root(void)
    {
        node_pointer next;
        node_pointer cur = _head->parent;
        do
        {
            while (cur->left)
                cur = cur->left;
            if (cur->right)
                cur = cur->right;
            else
            {
                next = cur->parent;
                if (cur == next->left)
                    next->left = nullptr;
                else
                    next->right = nullptr;
                this->destroy_node(cur);
                cur = next;
            }
        } while (cur != _head);
    }
    // Left rotation around `t` (t->right must exist); updates the root link.
    void rotate_left(node_pointer t) noexcept
    {
        node_pointer r = t->right;
        t->right = r->left;
        if (r->left)
            r->left->parent = t;
        r->parent = t->parent;
        if (t == _head->parent)
            _head->parent = r;
        else if (t == t->parent->left)
            t->parent->left = r;
        else
            t->parent->right = r;
        r->left = t;
        t->parent = r;
    }
    // Right rotation around `t` (t->left must exist); updates the root link.
    void rotate_right(node_pointer t) noexcept
    {
        node_pointer l = t->left;
        t->left = l->right;
        if (l->right)
            l->right->parent = t;
        l->parent = t->parent;
        if (t == _head->parent)
            _head->parent = l;
        else if (t == t->parent->right)
            t->parent->right = l;
        else
            t->parent->left = l;
        l->right = t;
        t->parent = l;
    }
    // Restores the red-black properties after inserting red node `x`.
    void insert_rebalance(node_pointer x) noexcept
    {
        node_pointer u, p = x->parent;
        // the header is red too, so it must be excluded explicitly
        while (p != _head && p->color == rbt_color::red)
        {
            if (p == p->parent->left)
            {
                // case 1: color(U) == red
                u = p->parent->right;
                if (u && u->color == rbt_color::red)
                {
                    p->color = rbt_color::black;
                    u->color = rbt_color::black;
                    p->parent->color = rbt_color::red;
                    x = p->parent;
                    p = x->parent;
                }
                // case 2: U == null or color(U) == black
                else
                {
                    if (x == p->right)
                    {
                        rotate_left(p);
                        p = x;
                    }
                    rotate_right(p->parent);
                    p->color = rbt_color::black;
                    p->right->color = rbt_color::red;
                    break;
                }
            }
            else
            {
                // case 3: color(U) == red
                u = p->parent->left;
                if (u && u->color == rbt_color::red)
                {
                    p->color = rbt_color::black;
                    u->color = rbt_color::black;
                    p->parent->color = rbt_color::red;
                    x = p->parent;
                    p = x->parent;
                }
                // case 4: U == null or color(U) == black
                else
                {
                    if (x == p->left)
                    {
                        rotate_right(p);
                        p = x;
                    }
                    rotate_left(p->parent);
                    p->color = rbt_color::black;
                    p->left->color = rbt_color::red;
                    break;
                }
            }
        }
        _head->parent->color = rbt_color::black;
    }
    // Restores the red-black properties after removing a black node. `x` is
    // the (possibly null) node that took its place and `p` is x's parent.
    void erase_rebalance(node_pointer p, node_pointer x) noexcept
    {
        node_pointer s;
        while (x != _head->parent && (!x || x->color == rbt_color::black))
        {
            if (x == p->left)
            {
                s = p->right;
                if (s->color == rbt_color::red)
                {
                    s->color = rbt_color::black;
                    p->color = rbt_color::red;
                    rotate_left(p);
                    s = p->right;
                }
                if ((!s->left || s->left->color == rbt_color::black) &&
                    (!s->right || s->right->color == rbt_color::black))
                {
                    s->color = rbt_color::red;
                    x = p;
                    p = x->parent;
                }
                else
                {
                    if (!s->right || s->right->color == rbt_color::black)
                    {
                        s->color = rbt_color::red;
                        s->left->color = rbt_color::black;
                        rotate_right(s);
                        s = p->right;
                    }
                    s->color = p->color;
                    p->color = rbt_color::black;
                    if (s->right)
                        s->right->color = rbt_color::black;
                    rotate_left(p);
                    break;
                }
            }
            else
            {
                s = p->left;
                if (s->color == rbt_color::red)
                {
                    s->color = rbt_color::black;
                    p->color = rbt_color::red;
                    rotate_right(p);
                    s = p->left;
                }
                if ((!s->right || s->right->color == rbt_color::black) &&
                    (!s->left || s->left->color == rbt_color::black))
                {
                    s->color = rbt_color::red;
                    x = p;
                    p = x->parent;
                }
                else
                {
                    if (!s->left || s->left->color == rbt_color::black)
                    {
                        s->color = rbt_color::red;
                        s->right->color = rbt_color::black;
                        rotate_left(s);
                        s = p->left;
                    }
                    s->color = p->color;
                    p->color = rbt_color::black;
                    if (s->left)
                        s->left->color = rbt_color::black;
                    rotate_right(p);
                    break;
                }
            }
        }
        if (x)
            x->color = rbt_color::black;
    }
private:
    node_pointer _head;  // header sentinel (see class comment)
    size_type _size;     // number of elements
    key_compare _comp;   // key ordering
};
#endif
// rb_tree source code
// First published on 2016-12-28 09:38:15