/*====================================================================
BSD 2-Clause License
Copyright (c) 2025, Ruler
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
====================================================================*/
#pragma once
#ifndef __CORE_AB_TREE_H__
#define __CORE_AB_TREE_H__
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include "define.h"
// Allocator template used by ab_tree when no allocator argument is given.
// Define DEFAULT_ALLOCATOR(Ty) before including this header to substitute
// a different allocator family.
#if !defined(DEFAULT_ALLOCATOR)
#define DEFAULT_ALLOCATOR(Ty) std::allocator<Ty>
#endif // !defined(DEFAULT_ALLOCATOR)
// Enum class abt_type
//
// Records the last move made by an abt_primitive_iterator.  The high nibble
// of each value is the signed change in depth caused by that move (recovered
// with an arithmetic shift right by 4); the low nibble only keeps the
// enumerators distinct from one another.
enum class abt_type : signed char
{
    root    =  0 * 16 + 0,  // initial position, no move yet   (depth +0)
    parent  = -1 * 16 + 1,  // moved up to the parent           (depth -1)
    left    =  1 * 16 + 2,  // moved down to the left child     (depth +1)
    right   =  1 * 16 + 3,  // moved down to the right child    (depth +1)
    sibling =  0 * 16 + 4   // moved across to a sibling        (depth +0)
};
// Class template abt_node
//
// One node of an ab_tree: three links, the number of nodes in the subtree
// rooted here (the node itself included; ab_tree keeps 0 in its header node)
// and the stored element.
template <class T>
struct abt_node
{
    // Self-referential aliases (the injected class name denotes abt_node<T>).
    using node_type            = abt_node;
    using node_pointer         = abt_node*;
    using const_node_pointer   = const abt_node*;
    using node_reference       = abt_node&;
    using const_node_reference = const abt_node&;

    node_pointer parent;  // parent link (the root's parent is the tree header)
    node_pointer left;    // left child, or nullptr
    node_pointer right;   // right child, or nullptr
    size_t size;          // node count of the subtree rooted here
    T data;               // the stored element
};
// Class template abt_type_traits
//
// Type aliases consumed by the ab_tree iterators.  The mutable
// (IsConst == false) and constant (IsConst == true) flavours share every
// alias except pointer and reference, so the shared ones live in a base.
template <class Tree>
struct abt_type_traits_common
{
    using value_type      = typename Tree::value_type;
    using size_type       = typename Tree::size_type;
    using difference_type = typename Tree::difference_type;
    using node_type       = typename Tree::node_type;
    using node_pointer    = typename Tree::node_pointer;
};
// Mutable flavour: pointer and reference allow writing the element.
template <class Tree, bool IsConst>
struct abt_type_traits : abt_type_traits_common<Tree>
{
    using pointer   = typename Tree::pointer;
    using reference = typename abt_type_traits_common<Tree>::value_type&;
};
// Constant flavour: pointer and reference are read-only.
template <class Tree>
struct abt_type_traits<Tree, true> : abt_type_traits_common<Tree>
{
    using pointer   = typename Tree::const_pointer;
    using reference = const typename abt_type_traits_common<Tree>::value_type&;
};
// Class template abt_iterator
//
// Bidirectional in-order iterator over an ab_tree.  It holds a single node
// pointer.  end() is represented by the tree's header node: the header's
// parent is the root, its left/right links are the first/last elements, and
// its size is 0.  Element nodes always have size >= 1, so the header is the
// only node with size 0.
template <class Tree, bool IsConst>
class abt_iterator
{
public:
    // types:
    using value_type = typename abt_type_traits<Tree, IsConst>::value_type;
    using pointer = typename abt_type_traits<Tree, IsConst>::pointer;
    using reference = typename abt_type_traits<Tree, IsConst>::reference;
    using size_type = typename abt_type_traits<Tree, IsConst>::size_type;
    using difference_type = typename abt_type_traits<Tree, IsConst>::difference_type;
    using node_type = typename abt_type_traits<Tree, IsConst>::node_type;
    using node_pointer = typename abt_type_traits<Tree, IsConst>::node_pointer;
    using iterator_type = abt_iterator<Tree, IsConst>;
    using iterator_category = std::bidirectional_iterator_tag;
    // construct/copy/destroy:
    // Singular iterator; must not be dereferenced or moved.
    abt_iterator(void) noexcept
        : _node(nullptr)
    {}
    // Iterator referring to node (an element node or the tree header).
    explicit abt_iterator(const node_pointer node) noexcept
        : _node(node)
    {}
    abt_iterator(const iterator_type& other) noexcept
        : _node(other.get_pointer())
    {}
    inline iterator_type& operator=(const iterator_type& other) noexcept
    {
        if (this != &other)
            _node = other.get_pointer();
        return *this;
    }
    // Implicit conversion to the constant iterator for the same node.
    inline operator abt_iterator<Tree, true>(void) const noexcept
    {
        return abt_iterator<Tree, true>(_node);
    }
    // abt_iterator operations:
    inline node_pointer get_parent(void) noexcept
    {
        return _node->parent;
    }
    inline const node_pointer get_parent(void) const noexcept
    {
        return _node->parent;
    }
    inline node_pointer get_pointer(void) noexcept
    {
        return _node;
    }
    inline const node_pointer get_pointer(void) const noexcept
    {
        return _node;
    }
    // Number of nodes in the subtree rooted at the current node.
    inline size_t get_size(void) const noexcept
    {
        return _node->size;
    }
    inline reference operator*(void) const noexcept
    {
        return _node->data;
    }
    inline pointer operator->(void) const noexcept
    {
        return &(operator*());
    }
    // increment / decrement
    // Moves to the in-order successor; after the last element the iterator
    // reaches the header, i.e. end().
    iterator_type& operator++(void) noexcept
    {
        if (_node->right)
        {
            // the successor is the leftmost node of the right subtree
            _node = _node->right;
            while (_node->left)
                _node = _node->left;
        }
        else
        {
            // climb while we are a right child; the first ancestor reached
            // from its left subtree is the successor
            node_pointer p = _node->parent;
            while (_node == p->right)
            {
                _node = p;
                p = p->parent;
            }
            // if the root was the last element the climb has already landed
            // on the header (whose right link is the root); stay there
            if (_node->right != p)
                _node = p;
        }
        return *this;
    }
    // Moves to the in-order predecessor; from end() it moves to the last
    // element (for an empty tree the header's right link is the header).
    iterator_type& operator--(void) noexcept
    {
        // The header must be recognised explicitly.  The former test
        // "_node->parent->parent == _node" also holds for the root
        // (root->parent is the header and header->parent is the root), which
        // made --iterator(root) jump to root->right.  The header is the only
        // node whose size is 0.
        if (_node->size == 0)
            _node = _node->right;
        else if (_node->left)
        {
            // the predecessor is the rightmost node of the left subtree
            node_pointer p = _node->left;
            while (p->right)
                p = p->right;
            _node = p;
        }
        else
        {
            // climb while we are a left child; the first ancestor reached
            // from its right subtree is the predecessor
            node_pointer p = _node->parent;
            while (_node == p->left)
            {
                _node = p;
                p = p->parent;
            }
            _node = p;
        }
        return *this;
    }
    inline iterator_type operator++(int) noexcept
    {
        iterator_type itr(*this);
        this->operator++();
        return itr;
    }
    inline iterator_type operator--(int) noexcept
    {
        iterator_type itr(*this);
        this->operator--();
        return itr;
    }
    // relational operators (mutable and constant iterators compare freely):
    template <bool is_const>
    inline bool operator==(const abt_iterator<Tree, is_const>& rhs) const noexcept
    {
        return _node == rhs.get_pointer();
    }
    template <bool is_const>
    inline bool operator!=(const abt_iterator<Tree, is_const>& rhs) const noexcept
    {
        return _node != rhs.get_pointer();
    }
private:
    node_pointer _node;
};
// Class template abt_primitive_iterator
template <class Tree, bool IsConst>
class abt_primitive_iterator
{
public:
// types:
using value_type = typename abt_type_traits<Tree, IsConst>::value_type;
using pointer = typename abt_type_traits<Tree, IsConst>::pointer;
using reference = typename abt_type_traits<Tree, IsConst>::reference;
using size_type = typename abt_type_traits<Tree, IsConst>::size_type;
using difference_type = typename abt_type_traits<Tree, IsConst>::difference_type;
using node_type = typename abt_type_traits<Tree, IsConst>::node_type;
using node_pointer = typename abt_type_traits<Tree, IsConst>::node_pointer;
using iterator_type = abt_primitive_iterator<Tree, IsConst>;
using iterator_category = std::bidirectional_iterator_tag;
// construct/copy/destroy:
abt_primitive_iterator(void) noexcept
: _node(nullptr)
, _type(abt_type::root)
{}
explicit abt_primitive_iterator(const node_pointer node) noexcept
: _node(node)
, _type(abt_type::root)
{}
abt_primitive_iterator(const iterator_type& other) noexcept
: _node(other.get_pointer())
, _type(other.get_type())
{}
inline iterator_type& operator=(const iterator_type& other) noexcept
{
if (this != &other)
{
_node = other.get_pointer();
_type = other.get_type();
}
return *this;
}
inline operator abt_primitive_iterator<Tree, true>(void) const noexcept
{
return abt_primitive_iterator<Tree, true>(_node);
}
// abt_primitive_iterator operations:
inline node_pointer get_parent(void) noexcept
{
return _node->parent;
}
inline const node_pointer get_parent(void) const noexcept
{
return _node->parent;
}
inline node_pointer get_pointer(void) noexcept
{
return _node;
}
inline const node_pointer get_pointer(void) const noexcept
{
return _node;
}
inline abt_type get_type(void) const noexcept
{
return _type;
}
inline intptr_t get_depth(void) const noexcept
{
return static_cast<intptr_t>(std::underlying_type_t<abt_type>(_type) >> 4);
}
inline size_t get_size(void) const noexcept
{
return _node->size;
}
inline reference operator*(void) const noexcept
{
return _node->data;
}
inline pointer operator->(void) const noexcept
{
return &(operator*());
}
// increment / decrement
iterator_type& operator++(void) noexcept
{
if (_type != abt_type::parent && _node->left)
{
_node = _node->left;
_type = abt_type::left;
}
else if (_type != abt_type::parent && _node->right)
{
_node = _node->right;
_type = abt_type::right;
}
else if (_node != _node->parent->parent &&
_node->parent->right && _node != _node->parent->right)
{
_node = _node->parent->right;
_type = abt_type::sibling;
}
else
{
_node = _node->parent;
_type = abt_type::parent;
}
return *this;
}
iterator_type& operator--(void) noexcept
{
if (_type != abt_type::parent && _node->right)
{
_node = _node->right;
_type = abt_type::right;
}
else if (_type != abt_type::parent && _node->left)
{
_node = _node->left;
_type = abt_type::left;
}
else if (_node != _node->parent->parent &&
_node->parent->left && _node != _node->parent->left)
{
_node = _node->parent->left;
_type = abt_type::sibling;
}
else
{
_node = _node->parent;
_type = abt_type::parent;
}
return *this;
}
inline iterator_type operator++(int) noexcept
{
iterator_type itr(*this);
this->operator++();
return itr;
}
inline iterator_type operator--(int) noexcept
{
iterator_type itr(*this);
this->operator--();
return itr;
}
// relational operators:
template <bool is_const>
inline bool operator==(const abt_primitive_iterator<Tree, is_const>& rhs) const noexcept
{
return _node == rhs.get_pointer();
}
template <bool is_const>
inline bool operator!=(const abt_primitive_iterator<Tree, is_const>& rhs) const noexcept
{
return _node != rhs.get_pointer();
}
private:
node_pointer _node;
abt_type _type;
};
// Class template abt_node_allocator
//
// Allocation policy shared by ab_tree.  It derives from the element
// allocator (Allocator rebound to T), which is used to construct/destroy the
// element stored in a node, and owns a second allocator rebound to
// abt_node<T>, which provides the node storage itself.
template <class T, class Allocator>
class abt_node_allocator
    : public std::allocator_traits<Allocator>::template rebind_alloc<T>
{
public:
    // types:
    using tree_traits_type = std::allocator_traits<Allocator>;
    using tree_node_type = typename abt_node<T>::node_type;
    using allocator_type = typename tree_traits_type::template rebind_alloc<T>;
    using allocator_traits = typename tree_traits_type::template rebind_traits<T>;
    using node_allocator_type = typename tree_traits_type::template rebind_alloc<tree_node_type>;
    using node_allocator_traits = typename tree_traits_type::template rebind_traits<tree_node_type>;
    using node_type = typename node_allocator_traits::value_type;
    using node_pointer = typename node_allocator_traits::pointer;
    using node_size_type = typename node_allocator_traits::size_type;
    using node_difference_type = typename node_allocator_traits::difference_type;
    // construct/copy/destroy:
    abt_node_allocator(void)
        : allocator_type()
        , _alloc()
    {}
    // Both allocators derive from alloc.  The node allocator is rebind-
    // constructed from it (required of every Allocator) so stateful
    // allocators hand out node storage from the same resource.
    explicit abt_node_allocator(const allocator_type& alloc)
        : allocator_type(alloc)
        , _alloc(alloc)
    {}
    // abt_node_allocator operations:
    constexpr allocator_type& get_allocator(void) noexcept
    {
        return *this;
    }
    constexpr const allocator_type& get_allocator(void) const noexcept
    {
        return *this;
    }
    constexpr node_allocator_type& get_node_allocator(void) noexcept
    {
        return _alloc;
    }
    // const overload must return a const reference: _alloc is const here.
    constexpr const node_allocator_type& get_node_allocator(void) const noexcept
    {
        return _alloc;
    }
    // Largest number of nodes the node allocator can theoretically provide.
    // Queried through allocator_traits: std::allocator::max_size was
    // deprecated in C++17 and removed in C++20.
    inline node_size_type max_size(void) const noexcept
    {
        return node_allocator_traits::max_size(_alloc);
    }
protected:
    // Allocates one node and constructs only its data member from args; the
    // link and size members are left for the caller to set.  If the element
    // constructor throws, the node storage is released before rethrowing.
    template <class ...Args>
    inline node_pointer create_node(Args&&... args)
    {
        node_pointer p = node_allocator_traits::allocate(_alloc, 1);
        try
        {
            allocator_traits::construct(get_allocator(), std::addressof(p->data), std::forward<Args>(args)...);
        }
        catch (...)
        {
            node_allocator_traits::deallocate(_alloc, p, 1);
            throw;
        }
        return p;
    }
    // Destroys the element held by p and releases the node storage.
    inline void destroy_node(const node_pointer p)
    {
        allocator_traits::destroy(get_allocator(), std::addressof(p->data));
        node_allocator_traits::deallocate(_alloc, p, 1);
    }
private:
    node_allocator_type _alloc;
};
// Class template ab_tree
//
// Sequence container that keeps its elements, in order, in a binary tree
// whose nodes each record the size of their subtree.  Index i of the
// sequence is the i-th node of an in-order walk, so select(), operator[]
// and at() locate an element by descending with those sizes.  After every
// insertion or erasure the affected ancestors are rebalanced with size-based
// rotations (the comparisons in insert_rebalance / erase_rebalance follow
// the "maintain" cases of a Size Balanced Tree).
//
// Layout: _head is a sentinel header node.  _head->parent is the root
// (nullptr when empty), _head->left / _head->right are the first / last
// elements (_head itself when empty), _head->size stays 0 and the root's
// parent is _head.  end() refers to _head.
template <class T, class Allocator = DEFAULT_ALLOCATOR(void)>
class ab_tree : public abt_node_allocator<T, Allocator>
{
    // SFINAE guard for the iterator-range overloads: they take part in
    // overload resolution only for non-integral argument types, so calls
    // such as assign(3, 5) on an ab_tree<int> pick the (count, value)
    // overloads instead of trying to dereference an int.
    template <class InputIt>
    using enable_if_not_integral = std::enable_if_t<!std::is_integral<InputIt>::value, int>;
public:
    // types:
    using tree_type = ab_tree<T, Allocator>;
    using tree_base = abt_node_allocator<T, Allocator>;
    using node_type = typename tree_base::node_type;
    using node_pointer = typename tree_base::node_pointer;
    using allocator_type = typename tree_base::allocator_type;
    using allocator_traits = typename tree_base::allocator_traits;
    using value_type = typename allocator_traits::value_type;
    using size_type = typename allocator_traits::size_type;
    using difference_type = typename allocator_traits::difference_type;
    using pointer = typename allocator_traits::pointer;
    using const_pointer = typename allocator_traits::const_pointer;
    using reference = value_type&;
    using const_reference = const value_type&;
    using iterator = abt_iterator<tree_type, false>;
    using const_iterator = abt_iterator<tree_type, true>;
    using primitive_iterator = abt_primitive_iterator<tree_type, false>;
    using const_primitive_iterator = abt_primitive_iterator<tree_type, true>;
    using reverse_iterator = std::reverse_iterator<iterator>;
    using const_reverse_iterator = std::reverse_iterator<const_iterator>;
    using reverse_primitive_iterator = std::reverse_iterator<primitive_iterator>;
    using const_reverse_primitive_iterator = std::reverse_iterator<const_primitive_iterator>;
    // construct/copy/destroy:
    // Creates an empty tree that allocates with alloc.
    explicit ab_tree(const allocator_type& alloc = allocator_type())
        : tree_base(alloc)
        , _head(nullptr)
    {
        _head = create_header();
    }
    // Deep-copies other, allocating with a copy of other's allocator.
    ab_tree(const tree_type& other)
        : tree_base(other.get_allocator())
        , _head(nullptr)
    {
        _head = create_header();
        try
        {
            if (other._head->parent)
                copy_node(other._head->parent);
        }
        catch (...)
        {
            // a throwing constructor never runs the destructor, so release
            // the partial copy and the header here
            release_all();
            throw;
        }
    }
    // Deep-copies other, allocating with alloc.
    ab_tree(const tree_type& other, const allocator_type& alloc)
        : tree_base(alloc)
        , _head(nullptr)
    {
        _head = create_header();
        try
        {
            if (other._head->parent)
                copy_node(other._head->parent);
        }
        catch (...)
        {
            release_all();
            throw;
        }
    }
    // Takes over other's nodes; other is left empty (it receives the freshly
    // allocated header).  Because the constructor is noexcept, a failure to
    // allocate that header calls std::terminate.
    ab_tree(tree_type&& other) noexcept
        : tree_base()
        , _head(nullptr)
    {
        _head = create_header();
        swap(other);
    }
    // As above, but this tree's allocator is constructed from alloc.  The
    // nodes are taken over as-is, whatever allocator created them.
    ab_tree(tree_type&& other, const allocator_type& alloc) noexcept
        : tree_base(alloc)
        , _head(nullptr)
    {
        _head = create_header();
        swap(other);
    }
    // Builds the tree from the elements of [first, last), in order.
    template <class InputIt, enable_if_not_integral<InputIt> = 0>
    ab_tree(InputIt first, InputIt last, const allocator_type& alloc)
        : tree_base(alloc)
        , _head(nullptr)
    {
        _head = create_header();
        try
        {
            insert(cend(), first, last);
        }
        catch (...)
        {
            release_all();
            throw;
        }
    }
    // Builds the tree from the elements of init, in order.
    ab_tree(std::initializer_list<value_type> init, const allocator_type& alloc = allocator_type())
        : tree_base(alloc)
        , _head(nullptr)
    {
        _head = create_header();
        try
        {
            assign(init.begin(), init.end());
        }
        catch (...)
        {
            release_all();
            throw;
        }
    }
    ~ab_tree(void)
    {
        clear();
        destroy_header(_head);
    }
    // Copy assignment (copy-and-swap): the copy is built first with this
    // tree's allocator, so *this is left unchanged if copying throws.
    inline tree_type& operator=(const tree_type& other)
    {
        if (this != &other)
        {
            tree_type tmp(other, this->get_allocator());
            swap(tmp);
        }
        return *this;
    }
    // Move assignment: exchanges contents; other receives this tree's
    // previous elements.
    inline tree_type& operator=(tree_type&& other) noexcept
    {
        swap(other);
        return *this;
    }
    // Replaces the contents with n copies of value.
    inline void assign(size_type n, const_reference value)
    {
        clear();
        insert(cend(), n, value);
    }
    // Replaces the contents with the elements of [first, last).
    template <class InputIt, enable_if_not_integral<InputIt> = 0>
    inline void assign(InputIt first, InputIt last)
    {
        clear();
        insert(cend(), first, last);
    }
    inline void assign(std::initializer_list<value_type> init)
    {
        assign(init.begin(), init.end());
    }
    // iterators (in-order; begin() is _head->left, end() is _head):
    inline iterator begin(void) noexcept
    {
        return iterator(_head->left);
    }
    inline const_iterator begin(void) const noexcept
    {
        return const_iterator(_head->left);
    }
    inline const_iterator cbegin(void) const noexcept
    {
        return const_iterator(_head->left);
    }
    inline iterator end(void) noexcept
    {
        return iterator(_head);
    }
    inline const_iterator end(void) const noexcept
    {
        return const_iterator(_head);
    }
    inline const_iterator cend(void) const noexcept
    {
        return const_iterator(_head);
    }
    inline reverse_iterator rbegin(void) noexcept
    {
        return reverse_iterator(end());
    }
    inline const_reverse_iterator rbegin(void) const noexcept
    {
        return const_reverse_iterator(end());
    }
    inline const_reverse_iterator crbegin(void) const noexcept
    {
        return const_reverse_iterator(cend());
    }
    inline reverse_iterator rend(void) noexcept
    {
        return reverse_iterator(begin());
    }
    inline const_reverse_iterator rend(void) const noexcept
    {
        return const_reverse_iterator(begin());
    }
    inline const_reverse_iterator crend(void) const noexcept
    {
        return const_reverse_iterator(cbegin());
    }
    // structural (primitive) iterators: start at the root, end at _head.
    inline primitive_iterator pbegin(void) noexcept
    {
        return primitive_iterator(root());
    }
    inline const_primitive_iterator pbegin(void) const noexcept
    {
        return const_primitive_iterator(root());
    }
    inline const_primitive_iterator cpbegin(void) const noexcept
    {
        return const_primitive_iterator(root());
    }
    inline primitive_iterator pend(void) noexcept
    {
        return primitive_iterator(_head);
    }
    inline const_primitive_iterator pend(void) const noexcept
    {
        return const_primitive_iterator(_head);
    }
    inline const_primitive_iterator cpend(void) const noexcept
    {
        return const_primitive_iterator(_head);
    }
    inline reverse_primitive_iterator rpbegin(void) noexcept
    {
        return reverse_primitive_iterator(pend());
    }
    inline const_reverse_primitive_iterator rpbegin(void) const noexcept
    {
        return const_reverse_primitive_iterator(pend());
    }
    inline const_reverse_primitive_iterator crpbegin(void) const noexcept
    {
        return const_reverse_primitive_iterator(cpend());
    }
    inline reverse_primitive_iterator rpend(void) noexcept
    {
        return reverse_primitive_iterator(pbegin());
    }
    inline const_reverse_primitive_iterator rpend(void) const noexcept
    {
        return const_reverse_primitive_iterator(pbegin());
    }
    inline const_reverse_primitive_iterator crpend(void) const noexcept
    {
        return const_reverse_primitive_iterator(cpbegin());
    }
    // capacity:
    inline bool empty(void) const noexcept
    {
        return !_head->parent;
    }
    // The root's subtree size is the element count.
    inline size_type size(void) const noexcept
    {
        return _head->parent ? _head->parent->size : 0;
    }
    // element access:
    // Unchecked: an index >= size() yields the header's (default-constructed)
    // data rather than an element.
    inline reference operator[](size_type idx) noexcept
    {
        return select_node(idx)->data;
    }
    inline const_reference operator[](size_type idx) const noexcept
    {
        return select_node(idx)->data;
    }
    // Checked access: throws std::domain_error when empty and
    // std::out_of_range when idx >= size().
    inline reference at(size_type idx)
    {
        if (empty())
            throw std::domain_error(ABT_NOT_INITIALIZED);
        if (idx >= size())
            throw std::out_of_range(ABT_OUT_OF_RANGE);
        return select_node(idx)->data;
    }
    inline const_reference at(size_type idx) const
    {
        if (empty())
            throw std::domain_error(ABT_NOT_INITIALIZED);
        if (idx >= size())
            throw std::out_of_range(ABT_OUT_OF_RANGE);
        return select_node(idx)->data;
    }
    inline reference front(void)
    {
        return *begin();
    }
    inline const_reference front(void) const
    {
        return *begin();
    }
    inline reference back(void)
    {
        return *rbegin();
    }
    inline const_reference back(void) const
    {
        return *rbegin();
    }
    // modifiers:
    template <class... Args>
    inline void emplace_front(Args&&... args)
    {
        insert_node(_head->left, std::forward<Args>(args)...);
    }
    template <class... Args>
    inline void emplace_back(Args&&... args)
    {
        insert_node(_head, std::forward<Args>(args)...);
    }
    // Constructs an element immediately before pos; returns an iterator to it.
    template <class... Args>
    inline iterator emplace(const_iterator pos, Args&&... args)
    {
        return iterator(insert_node(pos.get_pointer(), std::forward<Args>(args)...));
    }
    // Constructs an element so that it ends up at index idx.
    template <class... Args>
    inline iterator emplace(size_type idx, Args&&... args)
    {
        return emplace(select(idx), std::forward<Args>(args)...);
    }
    inline void push_front(const_reference value)
    {
        insert_node(_head->left, value);
    }
    inline void push_front(value_type&& value)
    {
        insert_node(_head->left, std::forward<value_type>(value));
    }
    inline void push_back(const_reference value)
    {
        insert_node(_head, value);
    }
    inline void push_back(value_type&& value)
    {
        insert_node(_head, std::forward<value_type>(value));
    }
    // No-op on an empty tree.
    inline void pop_front(void)
    {
        if (_head->parent)
            erase_node(_head->left);
    }
    // No-op on an empty tree.
    inline void pop_back(void)
    {
        if (_head->parent)
            erase_node(_head->right);
    }
    inline iterator insert(const_iterator pos, const_reference value)
    {
        return iterator(insert_node(pos.get_pointer(), value));
    }
    inline iterator insert(const_iterator pos, value_type&& value)
    {
        return iterator(insert_node(pos.get_pointer(), std::forward<value_type>(value)));
    }
    // Inserts n copies of value before pos; returns an iterator to the first
    // inserted element, or pos when n == 0.
    inline iterator insert(const_iterator pos, size_type n, const_reference value)
    {
        node_pointer r, t = pos.get_pointer();
        if (n > 0)
        {
            r = insert_node(t, value);
            for (; n > 1; --n)
                insert_node(t, value);
        }
        else
            r = t;
        return iterator(r);
    }
    // Inserts the elements of [first, last) before pos, keeping their order;
    // returns an iterator to the first inserted element, or pos if none.
    template <class InputIt, enable_if_not_integral<InputIt> = 0>
    inline iterator insert(const_iterator pos, InputIt first, InputIt last)
    {
        node_pointer r, t = pos.get_pointer();
        if (first != last)
        {
            r = insert_node(t, *first);
            for (++first; first != last; ++first)
                insert_node(t, *first);
        }
        else
            r = t;
        return iterator(r);
    }
    inline iterator insert(const_iterator pos, std::initializer_list<value_type> init)
    {
        return insert(pos, init.begin(), init.end());
    }
    // Index-based overloads: insert before the element at idx (idx == size()
    // appends).
    inline iterator insert(size_type idx, const_reference value)
    {
        return insert(select(idx), value);
    }
    inline iterator insert(size_type idx, value_type&& value)
    {
        return insert(select(idx), std::forward<value_type>(value));
    }
    inline iterator insert(size_type idx, size_type n, const_reference value)
    {
        return insert(select(idx), n, value);
    }
    template <class InputIt, enable_if_not_integral<InputIt> = 0>
    inline iterator insert(size_type idx, InputIt first, InputIt last)
    {
        return insert(select(idx), first, last);
    }
    inline iterator insert(size_type idx, std::initializer_list<value_type> init)
    {
        return insert(idx, init.begin(), init.end());
    }
    // Erases the element at pos (no-op for end()); returns the next position.
    inline iterator erase(const_iterator pos)
    {
        iterator next = iterator(pos.get_pointer());
        if (pos != cend())
        {
            ++next;
            erase_node(pos.get_pointer());
        }
        return next;
    }
    // Erases [first, last); returns last.  Other iterators stay valid because
    // erase_node relinks nodes instead of moving element values.
    inline iterator erase(const_iterator first, const_iterator last)
    {
        iterator next = iterator(last.get_pointer());
        if (first == cbegin() && last == cend())
            clear();
        else
            while (first != last)
                erase(first++);
        return next;
    }
    // Erases the element at idx, if any.
    inline void erase(size_type idx)
    {
        node_pointer t = select_node(idx);
        if (t != _head)
            erase_node(t);
    }
    // Erases up to n elements starting at idx.
    inline void erase(size_type idx, size_type n)
    {
        iterator itr = select(idx);
        node_pointer t = itr.get_pointer();
        for (; n > 0 && itr != cend(); --n)
        {
            ++itr;
            erase_node(t);
            t = itr.get_pointer();
        }
    }
    // Exchanges headers (and thus all nodes); allocators are not exchanged.
    inline void swap(tree_type& other) noexcept
    {
        if (this != &other)
            std::swap(_head, other._head);
    }
    // Destroys every element and resets the header to the empty state.
    inline void clear(void)
    {
        if (_head->parent)
        {
            erase_root();
            _head->parent = nullptr;
            _head->left = _head;
            _head->right = _head;
        }
    }
    // operations:
    // Iterator to the element at index idx, or end() if idx >= size().
    inline iterator select(size_type idx) noexcept
    {
        return iterator(select_node(idx));
    }
    inline const_iterator select(size_type idx) const noexcept
    {
        return const_iterator(select_node(idx));
    }
private:
    // Root node, or the header when the tree is empty.
    inline node_pointer root(void) const noexcept
    {
        return _head->parent ? _head->parent : _head;
    }
    inline node_pointer leftmost(node_pointer t) const noexcept
    {
        while (t->left)
            t = t->left;
        return t;
    }
    inline node_pointer rightmost(node_pointer t) const noexcept
    {
        while (t->right)
            t = t->right;
        return t;
    }
    // Allocates the sentinel header in its empty state (no root, both
    // extremes pointing at itself, size 0).  Its data member is a
    // default-constructed value_type, so T must be default-constructible.
    inline node_pointer create_header(void)
    {
        node_pointer head = this->create_node(value_type());
        head->parent = nullptr;
        head->left = head;
        head->right = head;
        head->size = 0;
        return head;
    }
    inline void destroy_header(node_pointer& head)
    {
        if (head)
        {
            this->destroy_node(head);
            head = nullptr;
        }
    }
    // Frees every element node and the header.  Used only when a
    // constructor throws after the header has been created.
    void release_all(void)
    {
        clear();
        destroy_header(_head);
    }
    // Returns the node at in-order index k, or _head if k >= size().
    node_pointer select_node(size_type k) const noexcept
    {
        node_pointer t = _head->parent;
        while (t)
        {
            size_type left_size = t->left ? t->left->size : 0;
            if (left_size < k)
            {
                t = t->right;
                k -= (left_size + 1);
            }
            else if (k < left_size)
                t = t->left;
            else
                return t;
        }
        return _head;
    }
    // Copies the subtree rooted at t (a root node of another tree) into this
    // tree, which must be empty.  The walk is iterative: src moves over the
    // source tree while dst mirrors it in the copy; flag == false means we
    // just moved up, so children must not be revisited.
    void copy_node(const node_pointer t)
    {
        bool flag = true;
        node_pointer src = t;
        node_pointer dst = _head;
        // copies the node t
        node_pointer n = this->create_node(t->data);
        n->parent = dst;
        n->left = nullptr;
        n->right = nullptr;
        n->size = t->size;
        dst->parent = n;
        // update dst to root node
        dst = n;
        do
        {
            if (flag && src->left)
            {
                src = src->left;
                // copies the left child node
                n = this->create_node(src->data);
                n->parent = dst;
                n->left = nullptr;
                n->right = nullptr;
                n->size = src->size;
                dst->left = n;
                // update dst to left child node
                dst = n;
            }
            else if (flag && src->right)
            {
                src = src->right;
                // copies the right child node
                n = this->create_node(src->data);
                n->parent = dst;
                n->left = nullptr;
                n->right = nullptr;
                n->size = src->size;
                dst->right = n;
                // update dst to right child node
                dst = n;
            }
            else if (src->parent->right && src != src->parent->right)
            {
                src = src->parent->right;
                // copies the sibling node
                n = this->create_node(src->data);
                n->parent = dst->parent;
                n->left = nullptr;
                n->right = nullptr;
                n->size = src->size;
                dst->parent->right = n;
                // update dst to sibling node
                dst = n;
                flag = true;
            }
            else
            {
                // return to parent node
                src = src->parent;
                dst = dst->parent;
                flag = false;
            }
        } while (src != t);
        _head->left = leftmost(_head->parent);
        _head->right = rightmost(_head->parent);
    }
    // Creates a node from args and links it immediately before t in in-order
    // position (t == _head appends at the end); returns the new node.
    template <class ...Args>
    node_pointer insert_node(node_pointer t, Args&&... args)
    {
        // creates a new node (nothing is modified if this throws)
        node_pointer n = this->create_node(std::forward<Args>(args)...);
        n->left = nullptr;
        n->right = nullptr;
        n->size = 1;
        if (t == _head)
        {
            if (!_head->parent)
            {
                // empty tree: n becomes the root and both extremes
                n->parent = _head;
                _head->parent = n;
                _head->left = n;
                _head->right = n;
                return n;
            }
            // append as right child of the current last element
            t = _head->right;
            n->parent = t;
            t->right = n;
            _head->right = n;
        }
        else if (t->left)
        {
            // the slot right after t's predecessor
            t = rightmost(t->left);
            n->parent = t;
            t->right = n;
        }
        else
        {
            // t has no left child: n becomes it
            n->parent = t;
            t->left = n;
            if (t == _head->left)
                _head->left = n;
        }
        insert_fixup(t);
        return n;
    }
    // After a leaf was attached below t: bumps subtree sizes from t up to the
    // root, then rebalances t's ancestors bottom-up.  The loop stops at the
    // root so the header sentinel is never handed to the rotation code (t
    // itself needs no rebalancing: its new child is a leaf).
    void insert_fixup(node_pointer t) noexcept
    {
        for (node_pointer p = t; p != _head; p = p->parent)
            ++p->size;
        while (t->parent != _head)
            t = insert_rebalance(t->parent, t == t->parent->right);
    }
    // Unlinks and destroys element node t (never the header).  A node with
    // two children is replaced by its in-order neighbour taken from the
    // larger subtree; nodes are relinked rather than element values moved.
    void erase_node(node_pointer t)
    {
        bool flag;
        node_pointer x;
        node_pointer parent;
        // if node t has two child nodes
        if (t->left && t->right)
        {
            if (t->left->size < t->right->size)
            {
                x = leftmost(t->right);
                // the rebalance flag: true when x's parent loses a right node
                flag = (x == x->parent->right);
                // reduces the size of nodes
                for (node_pointer p = x->parent; p != _head; p = p->parent)
                    --p->size;
                // replaces node t with node x and removes node t
                t->left->parent = x;
                x->left = t->left;
                if (x != t->right)
                {
                    x->parent->left = x->right;
                    if (x->right)
                        x->right->parent = x->parent;
                    t->right->parent = x;
                    x->right = t->right;
                    parent = x->parent;
                }
                else
                    parent = x;
                if (t == _head->parent)
                    _head->parent = x;
                else if (t == t->parent->left)
                    t->parent->left = x;
                else
                    t->parent->right = x;
                x->parent = t->parent;
                x->size = t->size;
            }
            else
            {
                x = rightmost(t->left);
                // the rebalance flag: true when x's parent loses a right node
                flag = (x == x->parent->right);
                // reduces the size of nodes
                for (node_pointer p = x->parent; p != _head; p = p->parent)
                    --p->size;
                // replaces node t with node x and removes node t
                t->right->parent = x;
                x->right = t->right;
                if (x != t->left)
                {
                    x->parent->right = x->left;
                    if (x->left)
                        x->left->parent = x->parent;
                    t->left->parent = x;
                    x->left = t->left;
                    parent = x->parent;
                }
                else
                    parent = x;
                if (t == _head->parent)
                    _head->parent = x;
                else if (t == t->parent->left)
                    t->parent->left = x;
                else
                    t->parent->right = x;
                x->parent = t->parent;
                x->size = t->size;
            }
            // rebalance after deletion
            erase_fixup(parent, flag);
        }
        // if node t has one child node at most
        else
        {
            x = t->left ? t->left : t->right;
            // the rebalance flag: true when t was a right child
            flag = (t == t->parent->right);
            // removes node t
            if (x)
                x->parent = t->parent;
            if (t == _head->parent)
                _head->parent = x;
            else if (t == t->parent->left)
                t->parent->left = x;
            else
                t->parent->right = x;
            if (t == _head->left)
                _head->left = x ? leftmost(x) : t->parent;
            if (t == _head->right)
                _head->right = x ? rightmost(x) : t->parent;
            // reduces the size of nodes
            for (node_pointer p = t->parent; p != _head; p = p->parent)
                --p->size;
            // rebalance after deletion (no-op when t was the root)
            erase_fixup(t->parent, flag);
        }
        // destroy node
        this->destroy_node(t);
    }
    // Rebalances p (whose right subtree shrank if flag, else its left) and
    // then each ancestor up to and including the root.  Stops at the root so
    // the header sentinel is never passed to the rotation code.
    void erase_fixup(node_pointer p, bool flag) noexcept
    {
        if (p == _head)
            return;
        p = erase_rebalance(p, flag);
        while (p->parent != _head)
            p = erase_rebalance(p->parent, p == p->parent->right);
    }
    // Destroys every element node, children before their parent, without
    // recursion.  Leaves the header links for clear() to reset.
    void erase_root(void)
    {
        node_pointer next;
        node_pointer cur = _head->parent;
        do
        {
            while (cur->left)
                cur = cur->left;
            if (cur->right)
                cur = cur->right;
            else
            {
                // cur is a leaf: detach it from its parent and destroy it
                next = cur->parent;
                if (cur == next->left)
                    next->left = nullptr;
                else
                    next->right = nullptr;
                this->destroy_node(cur);
                cur = next;
            }
        } while (cur != _head);
    }
    // Moves t's right child r into t's place (t becomes r's left child),
    // fixing parent links and subtree sizes; returns r.
    node_pointer rotate_left(node_pointer t) const noexcept
    {
        node_pointer r = t->right;
        t->right = r->left;
        if (r->left)
            r->left->parent = t;
        r->parent = t->parent;
        if (t == _head->parent)
            _head->parent = r;
        else if (t == t->parent->left)
            t->parent->left = r;
        else
            t->parent->right = r;
        r->left = t;
        r->size = t->size;
        t->parent = r;
        t->size = (t->left ? t->left->size : 0) + (t->right ? t->right->size : 0) + 1;
        return r;
    }
    // Mirror of rotate_left: moves t's left child l into t's place; returns l.
    node_pointer rotate_right(node_pointer t) const noexcept
    {
        node_pointer l = t->left;
        t->left = l->right;
        if (l->right)
            l->right->parent = t;
        l->parent = t->parent;
        if (t == _head->parent)
            _head->parent = l;
        else if (t == t->parent->right)
            t->parent->right = l;
        else
            t->parent->left = l;
        l->right = t;
        l->size = t->size;
        t->parent = l;
        t->size = (t->left ? t->left->size : 0) + (t->right ? t->right->size : 0) + 1;
        return l;
    }
    // Restores size balance at element node t after an insertion into its
    // right subtree (flag == true) or left subtree (flag == false); returns
    // the node now rooting that subtree.
    node_pointer insert_rebalance(node_pointer t, bool flag) const noexcept
    {
        if (flag)
        {
            if (t->right)
            {
                size_type left_size = t->left ? t->left->size : 0;
                // case 1: size(T.left) < size(T.right.left)
                if (t->right->left && left_size < t->right->left->size)
                {
                    t->right = rotate_right(t->right);
                    t = rotate_left(t);
                    t->left = insert_rebalance(t->left, false);
                    t->right = insert_rebalance(t->right, true);
                    t = insert_rebalance(t, true);
                }
                // case 2. size(T.left) < size(T.right.right)
                else if (t->right->right && left_size < t->right->right->size)
                {
                    t = rotate_left(t);
                    t->left = insert_rebalance(t->left, false);
                    t = insert_rebalance(t, true);
                }
            }
        }
        else
        {
            if (t->left)
            {
                size_type right_size = t->right ? t->right->size : 0;
                // case 3. size(T.right) < size(T.left.right)
                if (t->left->right && right_size < t->left->right->size)
                {
                    t->left = rotate_left(t->left);
                    t = rotate_right(t);
                    t->left = insert_rebalance(t->left, false);
                    t->right = insert_rebalance(t->right, true);
                    t = insert_rebalance(t, false);
                }
                // case 4. size(T.right) < size(T.left.left)
                else if (t->left->left && right_size < t->left->left->size)
                {
                    t = rotate_right(t);
                    t->right = insert_rebalance(t->right, true);
                    t = insert_rebalance(t, false);
                }
            }
        }
        return t;
    }
    // Restores size balance at element node t after its left subtree
    // (flag == false) or right subtree (flag == true) shrank; returns the
    // node now rooting that subtree.
    node_pointer erase_rebalance(node_pointer t, bool flag) const noexcept
    {
        if (!flag)
        {
            if (t->right)
            {
                size_type left_size = t->left ? t->left->size : 0;
                // case 1: size(T.left) < size(T.right.left)
                if (t->right->left && left_size < t->right->left->size)
                {
                    t->right = rotate_right(t->right);
                    t = rotate_left(t);
                    t->left = erase_rebalance(t->left, true);
                    t->right = erase_rebalance(t->right, false);
                    t = erase_rebalance(t, false);
                }
                // case 2. size(T.left) < size(T.right.right)
                else if (t->right->right && left_size < t->right->right->size)
                {
                    t = rotate_left(t);
                    t->left = erase_rebalance(t->left, true);
                    t = erase_rebalance(t, false);
                }
            }
        }
        else
        {
            if (t->left)
            {
                size_type right_size = t->right ? t->right->size : 0;
                // case 3. size(T.right) < size(T.left.right)
                if (t->left->right && right_size < t->left->right->size)
                {
                    t->left = rotate_left(t->left);
                    t = rotate_right(t);
                    t->left = erase_rebalance(t->left, true);
                    t->right = erase_rebalance(t->right, false);
                    t = erase_rebalance(t, true);
                }
                // case 4. size(T.right) < size(T.left.left)
                else if (t->left->left && right_size < t->left->left->size)
                {
                    t = rotate_right(t);
                    t->right = erase_rebalance(t->right, false);
                    t = erase_rebalance(t, true);
                }
            }
        }
        return t;
    }
private:
    node_pointer _head;
};
#endif
// ab_tree 源代码 (ab_tree source code)
// 于 2024-12-16 21:22:47 首次发布 (first published 2024-12-16 21:22:47)