// ===== BST.h =====
#pragma once
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <utility>
/// Ordered symbol table (Key -> Value) backed by an unbalanced binary search tree.
///
/// Keys are ordered with `operator<` only; two keys are considered equal when
/// neither is less than the other. Every node caches the size of its subtree,
/// so select()/rank() run in time proportional to the tree height.
///
/// The tree owns its nodes: they are freed by erase()/deleteMin()/deleteMax()
/// and by the destructor, and copying a BST makes a deep copy.
///
/// NOTE: no rebalancing is performed, so inserting keys in sorted order yields a
/// degenerate (list-shaped) tree: operations and recursion depth become O(n).
template<typename Key, typename Value>
class BST
{
private:
    struct Node
    {
        Node* left = nullptr;
        Node* right = nullptr;
        Key key;
        Value value;
        int count = 0;  // number of nodes in the subtree rooted here, including this one

        // Member-initializer list: Key/Value need only be copy-constructible.
        Node(const Key& key, const Value& value, int count)
            : key(key), value(value), count(count)
        {
        }
    };

    Node* root = nullptr;

    // Size of the subtree rooted at r (0 for an empty subtree).
    static int size(const Node* r)
    {
        return r == nullptr ? 0 : r->count;
    }

    // Recomputes r->count from its children; call after relinking r.
    static void refresh(Node* r)
    {
        r->count = size(r->left) + size(r->right) + 1;
    }

    // Inserts (k, v) into subtree r, overwriting the value when k already exists.
    // Returns the (possibly new) subtree root.
    static Node* put(Node* r, const Key& k, const Value& v)
    {
        if (r == nullptr)
            return new Node(k, v, 1);
        if (k < r->key)
            r->left = put(r->left, k, v);
        else if (r->key < k)
            r->right = put(r->right, k, v);
        else
            r->value = v;
        refresh(r);
        return r;
    }

    // Node holding k in subtree r, or nullptr when k is absent.
    static Node* get(Node* r, const Key& k)
    {
        while (r != nullptr)
        {
            if (k < r->key)
                r = r->left;
            else if (r->key < k)
                r = r->right;
            else
                return r;
        }
        return nullptr;
    }

    // Leftmost node of a NON-EMPTY subtree.
    static Node* min(Node* r)
    {
        while (r->left != nullptr)
            r = r->left;
        return r;
    }

    // Rightmost node of a NON-EMPTY subtree.
    static Node* max(Node* r)
    {
        while (r->right != nullptr)
            r = r->right;
        return r;
    }

    // Unlinks the minimum node of the NON-EMPTY subtree r WITHOUT freeing it;
    // erase() re-uses that node as the replacement for the erased one.
    static Node* detachMin(Node* r)
    {
        if (r->left == nullptr)
            return r->right;
        r->left = detachMin(r->left);
        refresh(r);
        return r;
    }

    // Removes and frees the minimum node of the NON-EMPTY subtree r.
    static Node* deleteMin(Node* r)
    {
        if (r->left == nullptr)
        {
            Node* rest = r->right;
            delete r;
            return rest;
        }
        r->left = deleteMin(r->left);
        refresh(r);
        return r;
    }

    // Removes and frees the maximum node of the NON-EMPTY subtree r.
    static Node* deleteMax(Node* r)
    {
        if (r->right == nullptr)
        {
            Node* rest = r->left;
            delete r;
            return rest;
        }
        r->right = deleteMax(r->right);
        refresh(r);
        return r;
    }

    // Removes and frees the node holding `key` (Hibbard deletion). An absent key
    // leaves the tree unchanged. Returns the new subtree root.
    static Node* erase(Node* r, const Key& key)
    {
        if (r == nullptr)
            return nullptr;  // key not present: nothing to remove
        if (key < r->key)
            r->left = erase(r->left, key);
        else if (r->key < key)
            r->right = erase(r->right, key);
        else
        {
            Node* doomed = r;
            if (doomed->left == nullptr || doomed->right == nullptr)
            {
                // Zero or one child: splice the (possibly empty) child in.
                Node* child = doomed->left != nullptr ? doomed->left : doomed->right;
                delete doomed;
                return child;
            }
            // Two children: the in-order successor takes the erased node's place.
            r = min(doomed->right);
            r->right = detachMin(doomed->right);
            r->left = doomed->left;
            delete doomed;
        }
        refresh(r);
        return r;
    }

    // Node with the largest key <= `key` in subtree r, or nullptr if none.
    static Node* floor(Node* r, const Key& key)
    {
        if (r == nullptr)
            return nullptr;
        if (key < r->key)
            return floor(r->left, key);
        if (r->key < key)
        {
            // r qualifies; a larger qualifying key can only be in the right subtree.
            Node* better = floor(r->right, key);
            return better != nullptr ? better : r;
        }
        return r;
    }

    // Node with the smallest key >= `k` in subtree r, or nullptr if none.
    static Node* ceiling(Node* r, const Key& k)
    {
        if (r == nullptr)
            return nullptr;
        if (r->key < k)
            return ceiling(r->right, k);
        if (k < r->key)
        {
            // r qualifies; a smaller qualifying key can only be in the LEFT subtree.
            Node* better = ceiling(r->left, k);
            return better != nullptr ? better : r;
        }
        return r;
    }

    // Node whose key has exactly k smaller keys in subtree r (0-based order
    // statistic), or nullptr when k is outside [0, size(r)).
    static Node* select(Node* r, int k)
    {
        if (r == nullptr)
            return nullptr;
        const int leftSize = size(r->left);
        if (k < leftSize)
            return select(r->left, k);
        if (k > leftSize)
            return select(r->right, k - leftSize - 1);
        return r;
    }

    // Number of keys in subtree r that are strictly less than k.
    static int rank(const Node* r, const Key& k)
    {
        if (r == nullptr)
            return 0;
        if (k < r->key)
            return rank(r->left, k);
        if (r->key < k)
            return 1 + size(r->left) + rank(r->right, k);
        return size(r->left);
    }

    // In-order (ascending key) dump of subtree r, one "Key: <k> Value: <v>" line per node.
    static void before_display(const Node* r, std::ostream& os)
    {
        if (r == nullptr)
            return;
        before_display(r->left, os);
        os << "Key: " << r->key << ' ' << "Value: " << r->value << '\n';
        before_display(r->right, os);
    }

    // Frees every node of subtree r.
    static void destroy(Node* r) noexcept
    {
        if (r == nullptr)
            return;
        destroy(r->left);
        destroy(r->right);
        delete r;
    }

    // Deep copy of subtree r; frees any partial copy if a copy/allocation throws.
    static Node* clone(const Node* r)
    {
        if (r == nullptr)
            return nullptr;
        Node* copy = new Node(r->key, r->value, r->count);
        try
        {
            copy->left = clone(r->left);
            copy->right = clone(r->right);
        }
        catch (...)
        {
            destroy(copy);
            throw;
        }
        return copy;
    }

public:
    BST() = default;

    /// Deep copy: the new tree owns its own nodes.
    BST(const BST& other) : root(clone(other.root)) {}

    /// Steals `other`'s nodes, leaving `other` empty.
    BST(BST&& other) noexcept : root(other.root) { other.root = nullptr; }

    /// Copy-and-swap: serves as both copy and move assignment.
    BST& operator=(BST other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    ~BST() { destroy(root); }

    /// Number of key/value pairs stored.
    int size() const
    {
        return size(root);
    }

    /// Inserts (k, v), or overwrites the value stored for an existing key k.
    void put(const Key& k, const Value& v)
    {
        root = put(root, k, v);
    }

    /// Value stored for k. @throws std::out_of_range if k is absent.
    Value get(const Key& k) const
    {
        const Node* hit = get(root, k);
        if (hit == nullptr)
            throw std::out_of_range("can't find the key's value");
        return hit->value;
    }

    /// Smallest key. @throws std::out_of_range on an empty tree.
    Key min() const
    {
        if (root == nullptr)
            throw std::out_of_range("you must add key in the BST(by min())");
        return min(root)->key;
    }

    /// Largest key. @throws std::out_of_range on an empty tree.
    Key max() const
    {
        if (root == nullptr)
            throw std::out_of_range("you must add key in the BST(by max())");
        return max(root)->key;
    }

    /// Removes the smallest key. @throws std::out_of_range on an empty tree.
    void deleteMin()
    {
        if (root == nullptr)
            throw std::out_of_range("you must add key in the BST(by deleteMin())");
        root = deleteMin(root);
    }

    /// Removes the largest key. @throws std::out_of_range on an empty tree.
    void deleteMax()
    {
        if (root == nullptr)
            throw std::out_of_range("you must add key in the BST(by deleteMax())");
        root = deleteMax(root);
    }

    /// Removes `key` if present; an absent key is silently ignored.
    void erase(const Key& key)
    {
        root = erase(root, key);
    }

    /// Largest stored key <= key. @throws std::out_of_range if there is none.
    Key floor(const Key& key) const
    {
        const Node* hit = floor(root, key);
        if (hit == nullptr)
            throw std::out_of_range("can't floor");
        return hit->key;
    }

    /// Smallest stored key >= key. @throws std::out_of_range if there is none.
    Key ceiling(const Key& key) const
    {
        const Node* hit = ceiling(root, key);
        if (hit == nullptr)
            throw std::out_of_range("can't ceiling");
        return hit->key;
    }

    /// Key at 0-based position k in ascending order.
    /// @throws std::out_of_range when k is not in [0, size()).
    Key select(const int& k) const
    {
        const Node* hit = select(root, k);
        if (hit == nullptr)
            throw std::out_of_range("select out-of-range");
        return hit->key;
    }

    /// Number of stored keys strictly less than k.
    int rank(const Key& k) const
    {
        return rank(root, k);
    }

    /// Writes every pair in ascending key order to `os` (std::cout by default).
    void before_display(std::ostream& os = std::cout) const
    {
        before_display(root, os);
    }
};
// ===== main.cpp =====
#include <iostream>
#include "BST.h"
using namespace std;
// Demo driver: stores value i under key i + 0.1 for i = 0..9, then exercises
// the BST's ordered-symbol-table operations and prints the results.
int main()
{
    BST<double, int> bst;
    for (int i = 0; i < 10; ++i)
        bst.put(i + 0.1, i);
    bst.deleteMin();                             // removes key 0.1
    std::cout << bst.min() << '\n';              // smallest remaining key: 1.1
    bst.before_display();
    bst.erase(5.1);
    bst.before_display();
    // ' ' instead of std::ends: std::ends writes a '\0' byte, not a visible separator.
    std::cout << bst.select(3) << ' ' << bst.size() << '\n';
    std::cout << bst.rank(4.1) << '\n';
    std::cout << bst.get(6.1) << '\n';
    std::cout << bst.floor(5.1) << '\n';         // 5.1 was erased above
    // Portable stand-in for system("pause"), which is Windows-only and needs
    // <cstdlib> (never included here): keep the console open until Enter.
    std::cout << "Press Enter to continue..." << std::flush;
    std::cin.get();
    return 0;
}
// Run output: