#include <iostream>
#include <utility>
#include <vector>
using namespace std;
/// A single node of a binary search tree. Child pointers are NULL when the
/// corresponding subtree is empty.
class Node {
public:
    Node(int key_) : left(NULL), right(NULL), key(key_) {}
    Node* left;   ///< Subtree with keys strictly smaller than `key` (as maintained by BST).
    Node* right;  ///< Subtree with keys greater than or equal to `key` (as maintained by BST).
    int key;
};

/// Unbalanced binary search tree of int keys that owns all of its nodes.
///
/// Invariant: for every node, keys in its left subtree are strictly smaller
/// and keys in its right subtree are greater than or equal to its key, so
/// duplicate keys are allowed and are stored to the right.
/// Operations cost O(height); because the tree is not rebalanced, height can
/// be O(n) (e.g. sorted input), so traversals here are iterative rather than
/// recursive to avoid stack exhaustion.
class BST {
public:
    BST() : root(NULL) {}

    /// Deep-copies `other`; the two trees never share nodes.
    BST(const BST& other) : root(copyTree(other.root)) {}

    /// Replaces this tree's contents with a deep copy of `other`.
    /// The copy is built before the old nodes are released, so if an
    /// allocation throws, this tree is left unchanged.
    BST& operator=(const BST& other) {
        if (this != &other) {
            Node* copy = copyTree(other.root);
            clear(root);
            root = copy;
        }
        return *this;
    }

    ~BST() {
        clear(root);
    }

    /// Deletes `node` and every node below it (no-op for NULL).
    /// The pointer that links `node` into its parent (or `root`) is NOT
    /// reset; the caller is responsible for that.
    void clear(Node* node) {
        while (node != NULL) {
            if (node->left != NULL) {
                // Rotate right so the left child becomes the local root. This
                // gradually flattens the subtree into a right-leaning chain
                // that can be freed without recursion or extra storage.
                Node* left_child = node->left;
                node->left = left_child->right;
                left_child->right = node;
                node = left_child;
            } else {
                Node* next = node->right;
                delete node;
                node = next;
            }
        }
    }

    /// Returns the first node on the search path whose key equals `key`
    /// (the one closest to the root when duplicates exist), or NULL.
    Node* searchNode(int key) {
        Node* result = root;
        while (result != NULL) {
            if (result->key == key) return result;
            result = key < result->key ? result->left : result->right;
        }
        return NULL;
    }

    /// Returns the leftmost (smallest-key) node of the subtree rooted at
    /// `node`, or NULL if `node` is NULL.
    Node* minSubNode(Node* node) {
        if (node == NULL) return NULL;
        Node* result = node;
        while (result->left != NULL) result = result->left;
        return result;
    }

    /// Links `new_node` as a leaf into the subtree referenced by `node`
    /// (smaller keys go left, equal-or-greater go right). `node` is updated
    /// in place when the subtree is empty. A NULL `new_node` is ignored.
    void insertNode(Node*& node, Node* new_node) {
        if (new_node == NULL) return;
        // Walk a pointer to the link that must receive the new node, so the
        // empty-subtree case and the leaf case are handled uniformly.
        Node** link = &node;
        while (*link != NULL)
            link = new_node->key < (*link)->key ? &(*link)->left : &(*link)->right;
        *link = new_node;
    }

    /// Allocates a node holding `key` and inserts it; the tree owns it.
    void insertKey(int key) {
        insertNode(root, new Node(key));
    }

    /// Returns the parent of `child_node`, located by following the key
    /// search path from the root. Returns NULL if `child_node` is NULL, is
    /// the root, or is not found on that path (i.e. not in this tree).
    Node* parNode(Node* child_node) {
        if (child_node == NULL || child_node == root) return NULL;
        Node* result = root;
        while (result != NULL && result->left != child_node && result->right != child_node)
            result = child_node->key < result->key ? result->left : result->right;
        return result;
    }

    /// Unlinks `slice_node` from the tree, putting its only child (or NULL)
    /// in its place. The node itself is not freed, but its child pointers are
    /// reset so that clearing it later cannot free nodes still in the tree.
    ///
    /// Precondition: `slice_node` is non-NULL, belongs to this tree and has at
    /// most one child. With two children the right subtree would be dropped
    /// from the tree. If the node is not found in this tree, nothing changes.
    void sliceNode(Node* slice_node) {
        Node* slice_child_node = slice_node->left != NULL ? slice_node->left : slice_node->right;
        if (slice_node == root) {
            root = slice_child_node;
        } else {
            Node* par = parNode(slice_node);
            if (par == NULL) return;  // Not in this tree: nothing to unlink.
            if (par->left == slice_node) par->left = slice_child_node;
            else par->right = slice_child_node;
        }
        slice_node->left = NULL;
        slice_node->right = NULL;
    }

    /// Removes `del_node` (which must belong to this tree) and frees the
    /// memory of one node. A NULL argument is ignored.
    void deleteNode(Node* del_node) {
        if (del_node == NULL) return;
        if (del_node->left == NULL || del_node->right == NULL) {
            sliceNode(del_node);
            delete del_node;
            return;
        }
        // Two children: the in-order successor is the minimum of the RIGHT
        // subtree. Its key is >= every left-subtree key and <= every other
        // right-subtree key, so copying it into del_node keeps the ordering.
        // Track its parent while descending instead of re-searching by key.
        Node* suc_par = del_node;
        Node* suc_node = del_node->right;
        while (suc_node->left != NULL) {
            suc_par = suc_node;
            suc_node = suc_node->left;
        }
        // The successor has no left child; splice its right child upward.
        if (suc_par == del_node) suc_par->right = suc_node->right;
        else suc_par->left = suc_node->right;
        del_node->key = suc_node->key;
        delete suc_node;
    }

    /// Removes one node with key `key`. Returns false if no such key exists.
    bool deleteKey(int key) {
        Node* del_node = searchNode(key);
        if (del_node == NULL) return false;
        deleteNode(del_node);
        return true;
    }

private:
    /// Returns a freshly allocated deep copy of the subtree at `src` (NULL for
    /// an empty subtree). Uses an explicit work list instead of recursion; if
    /// an allocation throws, the partial copy is freed and the exception is
    /// rethrown.
    Node* copyTree(const Node* src) {
        typedef std::pair<const Node*, Node**> Job;  // (source node, slot to fill)
        Node* copy_root = NULL;
        std::vector<Job> pending;
        pending.push_back(Job(src, &copy_root));
        try {
            while (!pending.empty()) {
                const Node* from = pending.back().first;
                Node** slot = pending.back().second;
                pending.pop_back();
                if (from == NULL) continue;
                *slot = new Node(from->key);  // Linked immediately so cleanup can reach it.
                pending.push_back(Job(from->left, &(*slot)->left));
                pending.push_back(Job(from->right, &(*slot)->right));
            }
        } catch (...) {
            clear(copy_root);
            throw;
        }
        return copy_root;
    }

    Node* root;
};