伸展树的三种旋转:单旋转、一字形旋转、之字形旋转。
为了简化程序,之字形旋转不再做双旋转,而是按单旋转的情形处理(每次只下降一层并执行一次链接),旋转方式如下:
-SplayTree.h 代码:
/*
* by: peige
* 2015/11/29
*/
#ifndef __SPLAY_TREE_H__
#define __SPLAY_TREE_H__
#include <cstddef>
#include <functional>
#include <iostream>
#include <stdexcept>
/**
 * One node of the splay tree: a key plus links to both subtrees.
 * Both links default to NULL, so a node built from a key alone is a leaf.
 */
template <class T>
struct SplayTreeNode {
    T key;                  // stored key
    SplayTreeNode* left;    // left child, NULL when absent
    SplayTreeNode* right;   // right child, NULL when absent

    SplayTreeNode(const T& value = T(),
                  SplayTreeNode* lt = NULL,
                  SplayTreeNode* rt = NULL)
        : key(value),
          left(lt),
          right(rt) {
    }
};
/**
 * Top-down splay tree storing keys of type T ordered by Compare
 * (std::less<T> by default). find, findMin, findMax, insert and remove all
 * splay the accessed key (or the last node on its search path) to the root.
 *
 * The tree owns its nodes through raw pointers and frees them in the
 * destructor. Copying is therefore disabled: a member-wise copy would leave
 * two trees deleting the same nodes.
 */
template <class T, class Compare = std::less<T> >
class SplayTree {
private:
    SplayTreeNode<T>* mRoot;    // root node, NULL while the tree is empty
    Compare isLessThan;         // ordering predicate

    // Non-copyable: declared private and intentionally never defined.
    SplayTree(const SplayTree&);
    SplayTree& operator=(const SplayTree&);

public:
    SplayTree();
    ~SplayTree();

    void preOrder()const;           // print keys in pre-order
    void inOrder()const;            // print keys in in-order
    void postOrder()const;          // print keys in post-order

    bool find(const T& key);        // splay key; true when it is in the tree
    T findMin();                    // smallest key (throws std::runtime_error if empty)
    T findMax();                    // largest key (throws std::runtime_error if empty)

    void splay(const T& key);       // bring the node matching key to the root
    void insert(const T& key);      // add key to the tree (duplicates are ignored)
    void remove(const T& key);      // delete the node holding key, if any

    void destroy();                 // free every node
    void print()const;              // print each node with its parent relation

private:
    void preOrder(SplayTreeNode<T>* node)const;
    void inOrder(SplayTreeNode<T>* node)const;
    void postOrder(SplayTreeNode<T>* node)const;

    bool find(SplayTreeNode<T>*& node, const T& key);
    T findMin(SplayTreeNode<T>*& node);
    T findMax(SplayTreeNode<T>*& node);

    // Splays key within the subtree rooted at node; on return, node refers
    // to the new subtree root (nothing is returned by value).
    void splay(SplayTreeNode<T>*& node, const T& key);
    void destroy(SplayTreeNode<T>*& node);
    void print(SplayTreeNode<T>* node, const T& key, const int& direction)const;

    void rotateWithLeftChild(SplayTreeNode<T>*& k1);    // left child becomes subtree root
    void rotateWithRightChild(SplayTreeNode<T>*& k1);   // right child becomes subtree root
};
// Constructor: creates an empty tree (NULL root).
template <class T, class Compare>
SplayTree<T, Compare>::SplayTree()
    : mRoot(NULL) {
}
// Destructor: frees every node reachable from the root.
template <class T, class Compare>
SplayTree<T, Compare>::~SplayTree() {
    this->destroy(mRoot);
}
/**
 * Internal method: splay
 * The central routine of the splay tree (top-down splaying).
 * Brings the node whose key matches `key` to the root of `tree`. When no
 * node matches, the last node on the search path is splayed instead.
 * insert, remove, findMax, findMin and find all go through this function.
 *
 * `tree` is a reference to the subtree's root pointer and is updated in
 * place. An empty subtree is left untouched.
 */
template <class T, class Compare>
void SplayTree<T, Compare>::splay(SplayTreeNode<T>*& tree, const T& key) {
    if (tree == NULL) {
        return;
    }

    // Scratch node: header.right collects the "left tree" (keys below the
    // target) and header.left the "right tree" (keys above it). Its default
    // constructor already sets both links to NULL.
    SplayTreeNode<T> header;
    SplayTreeNode<T>* leftTreeMax = &header;    // where the next smaller node is attached
    SplayTreeNode<T>* rightTreeMin = &header;   // where the next larger node is attached

    while (true) {
        if (isLessThan(key, tree->key)) {
            // Target lies in the left subtree.
            if (tree->left == NULL) {
                break;                          // not found
            }
            if (isLessThan(key, tree->left->key)) {
                // Zig-zig: rotate first, then link.
                rotateWithLeftChild(tree);
                if (tree->left == NULL) {
                    break;                      // not found
                }
            }
            // Link right: the current root goes into the right tree.
            rightTreeMin->left = tree;
            rightTreeMin = tree;
            tree = tree->left;
            continue;
        }

        if (!isLessThan(tree->key, key)) {
            break;                              // keys are equivalent: found
        }

        // Target lies in the right subtree.
        if (tree->right == NULL) {
            break;                              // not found
        }
        if (isLessThan(tree->right->key, key)) {
            // Zig-zig: rotate first, then link.
            rotateWithRightChild(tree);
            if (tree->right == NULL) {
                break;                          // not found
            }
        }
        // Link left: the current root goes into the left tree.
        leftTreeMax->right = tree;
        leftTreeMax = tree;
        tree = tree->right;
    }

    // Reassemble: hang the remaining subtrees onto the side trees, then make
    // the side trees the children of the new root.
    leftTreeMax->right = tree->left;
    rightTreeMin->left = tree->right;
    tree->left = header.right;
    tree->right = header.left;
}
// splay: public entry point, splays `key` over the whole tree.
template <class T, class Compare>
void SplayTree<T, Compare>::splay(const T& key) {
    this->splay(mRoot, key);
}
// Internal method: rotateWithLeftChild
// Single rotation that lifts k1's left child above k1. On return k1 refers
// to the lifted child. The caller must ensure k1->left is not NULL.
template <class T, class Compare>
void SplayTree<T, Compare>::rotateWithLeftChild(SplayTreeNode<T>*& k1) {
    SplayTreeNode<T>* const oldTop = k1;
    SplayTreeNode<T>* const lifted = oldTop->left;

    oldTop->left = lifted->right;   // lifted's right subtree moves under the old top
    lifted->right = oldTop;
    k1 = lifted;
}
// Internal method: rotateWithRightChild
// Single rotation that lifts k1's right child above k1. On return k1 refers
// to the lifted child. The caller must ensure k1->right is not NULL.
template <class T, class Compare>
void SplayTree<T, Compare>::rotateWithRightChild(SplayTreeNode<T>*& k1) {
    SplayTreeNode<T>* const oldTop = k1;
    SplayTreeNode<T>* const lifted = oldTop->right;

    oldTop->right = lifted->left;   // lifted's left subtree moves under the old top
    lifted->left = oldTop;
    k1 = lifted;
}
/**
 * insert
 * An empty tree simply receives `key` as its root. Otherwise `key` is
 * splayed first (the tree may already contain it) and the new root decides:
 *   1. root key equivalent to `key` -> nothing happens, duplicates are not stored;
 *   2. `key` orders before the root -> the new node adopts the root's left
 *      subtree as its left child and the old root (left link cleared) as
 *      its right child;
 *   3. `key` orders after the root  -> mirror image of case 2.
 * In cases 2 and 3 the new node becomes the root.
 */
template <class T, class Compare>
void SplayTree<T, Compare>::insert(const T& key) {
    if (mRoot == NULL) {
        mRoot = new SplayTreeNode<T>(key);
        return;
    }

    splay(key);

    if (isLessThan(key, mRoot->key)) {
        // New node: left = old root's left subtree, right = old root.
        SplayTreeNode<T>* fresh = new SplayTreeNode<T>(key, mRoot->left, mRoot);
        mRoot->left = NULL;
        mRoot = fresh;
    }
    else if (isLessThan(mRoot->key, key)) {
        // New node: left = old root, right = old root's right subtree.
        SplayTreeNode<T>* fresh = new SplayTreeNode<T>(key, mRoot, mRoot->right);
        mRoot->right = NULL;
        mRoot = fresh;
    }
    // Otherwise the key is already present and the tree stays as splayed.
}
/**
 * remove
 * Splays `key` (which may be absent from the tree), then:
 *   1. if the root's key is not equivalent to `key`, the key was not
 *      present and nothing is deleted;
 *   2. otherwise the root is the node to delete:
 *      (a) no left subtree: the right child becomes the new root;
 *      (b) left subtree present: the left subtree is splayed with the
 *          root's own key, which brings its largest key to its top. That
 *          node then takes over the old root's right subtree and becomes
 *          the new root.
 *      In both cases the old root is deleted afterwards.
 */
template <class T,class Compare>
void SplayTree<T, Compare>::remove(const T& key) {
    if (mRoot == NULL) {
        return;
    }

    splay(key);

    const bool keyAbsent =
        isLessThan(key, mRoot->key) || isLessThan(mRoot->key, key);
    if (keyAbsent) {
        return;
    }

    SplayTreeNode<T>* const doomed = mRoot;
    if (doomed->left != NULL) {
        // Bring the largest key of the left subtree to that subtree's root.
        splay(doomed->left, doomed->key);
        doomed->left->right = doomed->right;
        mRoot = doomed->left;
    }
    else {
        mRoot = doomed->right;
    }
    delete doomed;
}
// findMin: returns the smallest key in the whole tree.
template <class T, class Compare>
T SplayTree<T, Compare>::findMin() {
    return this->findMin(mRoot);
}
// Internal method: findMin
// Follows left links from `node` to the smallest key, splays that key to
// the root of the subtree and returns a copy of it.
// Throws std::runtime_error when the subtree is empty.
template <class T, class Compare>
T SplayTree<T, Compare>::findMin(SplayTreeNode<T>*& node) {
    if (node == NULL) {
        throw std::runtime_error("can not find min in SplayTree.");
    }

    SplayTreeNode<T>* cursor = node;
    for (; cursor->left != NULL; cursor = cursor->left) {
        // keep descending to the left
    }

    // Copy the key before splaying so the value stays valid as the tree is
    // restructured.
    T smallest = cursor->key;
    splay(node, smallest);
    return smallest;
}
// findMax: returns the largest key in the whole tree.
template <class T, class Compare>
T SplayTree<T, Compare>::findMax() {
    return this->findMax(mRoot);
}
// Internal method: findMax
// Follows right links from `node` to the largest key, splays that key to
// the root of the subtree and returns a copy of it.
// Throws std::runtime_error when the subtree is empty.
template <class T,class Compare>
T SplayTree<T, Compare>::findMax(SplayTreeNode<T>*& node) {
    if (node == NULL) {
        throw std::runtime_error("can not find max in SplayTree.");
    }

    SplayTreeNode<T>* cursor = node;
    for (; cursor->right != NULL; cursor = cursor->right) {
        // keep descending to the right
    }

    // Copy the key before splaying so the value stays valid as the tree is
    // restructured.
    T largest = cursor->key;
    splay(node, largest);
    return largest;
}
// find: reports whether `key` is stored in the tree (splays as a side effect).
template <class T, class Compare>
bool SplayTree<T, Compare>::find(const T& key) {
    const bool found = find(mRoot, key);
    return found;
}
// Internal method: find
// Splays `key` within the subtree rooted at `node` and reports whether the
// resulting subtree root holds an equivalent key. An empty subtree contains
// nothing, so it yields false without touching any node.
template <class T, class Compare>
bool SplayTree<T, Compare>::find(SplayTreeNode<T>*& node, const T& key) {
    if (NULL == node)
        return false;   // splay() leaves an empty subtree empty: nothing to compare
    splay(node, key);
    // Equivalent means neither key orders before the other. The check uses
    // `node`, the subtree that was actually splayed.
    return !isLessThan(key, node->key) && !isLessThan(node->key, key);
}
// Pre-order traversal (node, left, right) of the whole tree.
template <class T, class Compare>
void SplayTree<T, Compare>::preOrder()const {
    this->preOrder(mRoot);
}
// Internal method: preOrder
// Writes the keys of the subtree rooted at `node` to std::cout in pre-order
// (node, left, right), each followed by a single space.
template <class T, class Compare>
void SplayTree<T, Compare>::preOrder(SplayTreeNode<T>* node)const {
    if (NULL == node) return;
    // std:: is required: this header has no using-directive, and an
    // unqualified non-dependent name is looked up where the template is defined.
    std::cout << node->key << " ";
    preOrder(node->left);
    preOrder(node->right);
}
// In-order traversal (left, node, right) of the whole tree.
template <class T, class Compare>
void SplayTree<T, Compare>::inOrder()const {
    this->inOrder(mRoot);
}
// Internal method: inOrder
// Writes the keys of the subtree rooted at `node` to std::cout in in-order
// (left, node, right), each followed by a single space.
template <class T, class Compare>
void SplayTree<T, Compare>::inOrder(SplayTreeNode<T>* node)const {
    if (NULL == node) return;
    inOrder(node->left);
    // std:: is required: this header has no using-directive, and an
    // unqualified non-dependent name is looked up where the template is defined.
    std::cout << node->key << " ";
    inOrder(node->right);
}
// Post-order traversal (left, right, node) of the whole tree.
template <class T,class Compare>
void SplayTree<T, Compare>::postOrder()const {
    this->postOrder(mRoot);
}
// Internal method: postOrder
// Writes the keys of the subtree rooted at `node` to std::cout in post-order
// (left, right, node), each followed by a single space.
template <class T, class Compare>
void SplayTree<T, Compare>::postOrder(SplayTreeNode<T>* node)const {
    if (NULL == node) return;
    // Recurse with postOrder itself so every level is visited children-first.
    postOrder(node->left);
    postOrder(node->right);
    // std:: is required: this header has no using-directive, and an
    // unqualified non-dependent name is looked up where the template is defined.
    std::cout << node->key << " ";
}
// destroy: public entry point, frees every node of the tree.
template <class T, class Compare>
void SplayTree<T, Compare>::destroy() {
    this->destroy(mRoot);
}
// Internal: destroy
// Recursively deletes the subtree rooted at `node` (children first) and
// resets the caller's pointer to NULL, so a later destroy() call, the
// destructor, or any other operation sees an empty subtree instead of a
// dangling pointer.
template <class T, class Compare>
void SplayTree<T, Compare>::destroy(SplayTreeNode<T>*& node) {
    if (NULL == node) return;
    destroy(node->left);
    destroy(node->right);
    delete node;
    node = NULL;
}
// print: public entry point, prints the whole tree starting at the root.
// The key and direction arguments are placeholders; the internal print does
// not read them for the root node.
template <class T, class Compare>
void SplayTree<T, Compare>::print()const {
    this->print(mRoot, T(), 0);
}
// Internal method: print
// Prints one line per node of the subtree rooted at `node`, parents before
// children. `key` is the parent's key; `direction` is 1 for a left child
// and any other value is reported as a right child. The node equal to mRoot
// is reported as the root instead.
template <class T, class Compare>
void SplayTree<T, Compare>::print(SplayTreeNode<T>* node, const T& key, const int& direction)const {
    if (node == NULL) {
        return;
    }

    if (node != mRoot) {
        const char* const side = (direction == 1) ? "left" : "right";
        std::cout << node->key << " is " << key << "'s "
                  << side << " son" << std::endl;
    }
    else {
        std::cout << node->key << " is the root" << std::endl;
    }

    print(node->left, node->key, 1);
    print(node->right, node->key, 2);
}
#endif
#include <iostream>
#include "SplayTree.h"
using namespace std;
int main()
{
SplayTree<int>* pist = new SplayTree<int>;
for (int i = 1; i <= 7; ++i)
pist->insert(i);
pist->print();
cout << pist->find(1) << endl << endl;
pist->print();
cout << pist->find(8) << endl << endl;
pist->print();
cout << pist->findMin() << endl << endl;
pist->print();
cout << pist->findMax() << endl << endl;
pist->destroy();
return 0;
}
测试2 -main.cpp
#include <iostream>
#include "SplayTree.h"
using namespace std;
int main()
{
SplayTree<int>* pist = new SplayTree<int>;
int n = 1;
for (int i = 1; i <= 7; n += 2, ++i)
pist->insert(n);
pist->find(1);
pist->print();
cout << endl;
pist->find(10);
pist->print();
pist->destroy();
return 0;
}