【C++】【二叉树】【AVL 树】构建avl tree 类

本文深入讲解AVL树的原理及实现,包括恢复平衡的四种情况、插入与删除节点的方法,并提供完整的AVL树类实现代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


前言

AVL树的引入是为了改造BST 的不平衡性,防止退化为链表。BST 类的构造在文末有链接。

AVL树使得其任一节点的左右子树高度差不超过1。平衡性相对宽松。使得其可以有看起来不平衡但却认为平衡的结构。

而二叉堆、线段树等平衡二叉树是叶子节点的高度差不超过1。

其在插入和删除的时候,都需要从插入、删除节点处向上维护平衡。

一、重点解析

1、恢复平衡

恢复平衡分为了四种情况,分别记为LL、RR、LR、RL;

LL:向左子树的左孩子插入节点导致不平衡;解决方案:右旋

RR:向右子树的右孩子插入节点导致不平衡;解决方案:左旋

LR:向左子树的右孩子插入节点导致不平衡;解决方案:左子树左旋+当前节点右旋

RL:向右子树的左孩子插入节点导致不平衡;解决方案:右子树右旋+当前节点左旋

下面分别是LL、RR、LR、RL;最后一张图为LR 进行左旋+右旋后的结果
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

这里是向左子树的右孩子插入节点 6.使得其变成LR,然后对左子树节点5进行左旋,然后再整体右旋。

在这里插入图片描述

2、插入节点

相对于BST,插入节点加入了:

  1. 对当前节点的高度计算;
  2. 对平衡因子的计算(左右子树的高度计算);
  3. 根据平衡性,分四种情况进行左旋右旋->以恢复平衡。
// 向以node为根的二叉搜索树中,插入节点(key, value)
// 返回插入新节点后的二叉搜索树的根
Node *add(Node *node, Key key, Value value) {
    if (node == nullptr) {
        size++;
        return new Node(key, value);
    }
    if (key == node->key) {
        node->value = value;
    } else if (key < node->key) {
        node->left = add(node->left, key, value);
    } else {
        node->right = add(node->right, key, value);
    }
    node->height = 1 + std::max(getHeight(node->left), getHeight(node->right));
    int balanceFactor = getBalanceFactor(node);

    if (balanceFactor > 1 && getBalanceFactor(node->left) >= 0) {
        return rightRotate(node);
    }

    if (balanceFactor < -1 && getBalanceFactor(node->right) <= 0) {
        return leftRotate(node);
    }

    if (balanceFactor > 1 && getBalanceFactor(node->left) < 0) {
        node->left = leftRotate(node->left);
        return rightRotate(node);
    }

    if (balanceFactor < -1 && getBalanceFactor(node->right) > 0) {
        node->right = rightRotate(node->right);
        return leftRotate(node);
    }

    return node;
}

3、删除节点

相对于BST,删除节点加入了:

  1. 删除节点后,将要返回的根节点暂存。
  2. 对当前节点的高度计算;
  3. 对平衡因子的计算(左右子树的高度计算);
  4. 根据平衡性,分四种情况进行左旋右旋->以恢复平衡。

取消了删除最值,因为还要给删除最值中添加恢复平衡的代码,不如就只复用remove 的代码即可。

// 删除掉以node为根的二分搜索树中键值为key的节点
// 返回删除节点后新的二分搜索树的根
Node *remove(Node *node, Key key) {
    if (node == nullptr) {
        return nullptr;
    }

    Node *retNode;
    if (key < node->key) {
        node->left = remove(node->left, key);
        retNode = node;
    } else if (key > node->key) {
        node->right = remove(node->right, key);
        retNode = node;
    } else {
        if (node->left == nullptr) {
            Node *rightNode = node->right;
            delete node;
            size--;
            retNode = rightNode;
        } else if (node->right == nullptr) {
            Node *leftNode = node->left;
            delete node;
            size--;
            retNode = leftNode;
        } else {
            Node *successor = new Node(minimum(node->right));
            size++;

            successor->right = remove(node->right, successor->key);
            successor->left = node->left;

            delete node;
            size--;

            retNode = successor;
        }
    }
    if (retNode == nullptr) 
    	return nullptr;

    retNode->height = 1 + std::max(getHeight(retNode->left), getHeight(retNode->right));
    int balanceFactor = getBalanceFactor(retNode);

    if (balanceFactor > 1 && getBalanceFactor(retNode->left) >= 0) {
        return rightRotate(retNode);
    }

    if (balanceFactor < -1 && getBalanceFactor(retNode->right) <= 0) {
        return leftRotate(retNode);
    }

    if (balanceFactor > 1 && getBalanceFactor(retNode->left) < 0) {
        retNode->left = leftRotate(retNode->left);
        return rightRotate(retNode);
    }

    if (balanceFactor < -1 && getBalanceFactor(retNode->right) > 0) {
        retNode->right = rightRotate(retNode->right);
        return leftRotate(retNode);
    }

    return retNode;
}

3、析构函数

析构函数,需要用到后序遍历的思想。

即,删除当前节点,得先将其左右子树都删除掉,否则无法删除整个树。

下面代码,其实就是后序遍历的递归形式的析构函数。

    // 析构函数, 释放二分搜索树的所有空间
    ~BST(){
        destroy( root );
    }

    // 释放以node为根的二分搜索树的所有节点
    // 采用后续遍历的递归算法
    void destroy(Node* node){

        if( node != NULL ){
            destroy( node->left );
            destroy( node->right );

            delete node;
            count --;
        }
    }

二、完整AVL tree类

完整代码取自bobo老师及相关同学的github,文末有路径。

template<typename Key, typename Value>
class AVLTree {
private:
    struct Node {
        Key key;
        Value value;
        Node *left;
        Node *right;
        int height;

        Node(Key key, Value value) {
            this->key = key;
            this->value = value;
            this->left = this->right = nullptr;
            height = 1;
        }

        Node(Node *node) {
            this->key = node->key;
            this->value = node->value;
            this->left = node->left;
            this->right = node->right;
            this->height = node->height;
        }
    };

    Node *root;
    int size;

public:

    AVLTree() {
        root = nullptr;
        size = 0;
    }

    ~AVLTree() {
        destroy(root);
    }

    int getSize() {
        return size;
    }

    int isEmpty() {
        return size == 0;
    }

    int getHeight(Node *node) {
        if (node == nullptr) {
            return 0;
        }
        return node->height;
    }

    int getBalanceFactor(Node *node) {
        if (node == nullptr) {
            return 0;
        }
        return getHeight(node->left) - getHeight(node->right);
    }

    bool isBST() {
        std::vector<Key> keys;
        inOrder(root, keys);
        for (int i = 1; i < keys.size(); ++i) {
            if (keys.at(i - 1) < keys.at(i)) {
                return false;
            }
        }
        return true;
    }

    bool isBalanced() {
        return isBalanced(root);
    }

    void add(Key key, Value value) {
        root = add(root, key, value);
    }

    bool contains(Key key) {
        return getNode(root, key) != nullptr;
    }

    Value *get(Key key) {
        Node *node = getNode(root, key);
        return node == nullptr ? nullptr : &(node->value);
    }

    void set(Key key, Value newValue) {
        Node *node = getNode(root, key);
        if (node != nullptr) {
            node->value = newValue;
        }
    }

    // 从二叉树中删除键值为key的节点
    Value *remove(Key key) {
        Node *node = getNode(root, key);
        if (node != nullptr) {
            root = remove(root, key);
            return &(node->value);
        }
        return nullptr;
    }

private:

    // 向以node为根的二叉搜索树中,插入节点(key, value)
    // 返回插入新节点后的二叉搜索树的根
    Node *add(Node *node, Key key, Value value) {
        if (node == nullptr) {
            size++;
            return new Node(key, value);
        }
        if (key == node->key) {
            node->value = value;
        } else if (key < node->key) {
            node->left = add(node->left, key, value);
        } else {
            node->right = add(node->right, key, value);
        }
        node->height = 1 + std::max(getHeight(node->left), getHeight(node->right));
        int balanceFactor = getBalanceFactor(node);

        if (balanceFactor > 1 && getBalanceFactor(node->left) >= 0) {
            return rightRotate(node);
        }

        if (balanceFactor < -1 && getBalanceFactor(node->right) <= 0) {
            return leftRotate(node);
        }

        if (balanceFactor > 1 && getBalanceFactor(node->left) < 0) {
            node->left = leftRotate(node->left);
            return rightRotate(node);
        }

        if (balanceFactor < -1 && getBalanceFactor(node->right) > 0) {
            node->right = rightRotate(node->right);
            return leftRotate(node);
        }

        return node;
    }

    // 在以node为根的二叉搜索树中查找key所对应的Node
    Node *getNode(Node *node, Key key) {
        if (node == nullptr) {
            return nullptr;
        }
        if (key == node->key) {
            return node;
        } else if (key < node->key) {
            return getNode(node->left, key);
        } else {
            return getNode(node->right, key);
        }
    }

    void destroy(Node *node) {
        if (node != nullptr) {
            destroy(node->left);
            destroy(node->right);
            delete node;
            size--;
        }
    }

    // 在以node为根的二叉搜索树中,返回最小键值的节点
    Node *minimum(Node *node) {
        if (node->left == nullptr)
            return node;
        return minimum(node->left);
    }

    // 在以node为根的二叉搜索树中,返回最大键值的节点
    Node *maximum(Node *node) {
        if (node->right == nullptr)
            return node;
        return maximum(node->right);
    }

    // 删除掉以node为根的二分搜索树中的最小节点
    // 返回删除节点后新的二分搜索树的根
    Node *removeMin(Node *node) {
        if (node->left == nullptr) {
            Node *rightNode = node->right;
            delete node;
            size--;
            return rightNode;
        }

        node->left = removeMin(node->left);
        return node;
    }

    // 删除掉以node为根的二分搜索树中的最大节点
    // 返回删除节点后新的二分搜索树的根
    Node *removeMax(Node *node) {
        if (node->right == nullptr) {
            Node *leftNode = node->left;
            delete node;
            size--;
            return leftNode;
        }

        node->right = removeMax(node->right);
        return node;
    }

    // 删除掉以node为根的二分搜索树中键值为key的节点
    // 返回删除节点后新的二分搜索树的根
    Node *remove(Node *node, Key key) {
        if (node == nullptr) {
            return nullptr;
        }

        Node *retNode;
        if (key < node->key) {
            node->left = remove(node->left, key);
            retNode = node;
        } else if (key > node->key) {
            node->right = remove(node->right, key);
            retNode = node;
        } else {
            if (node->left == nullptr) {
                Node *rightNode = node->right;
                delete node;
                size--;
                retNode = rightNode;
            } else if (node->right == nullptr) {
                Node *leftNode = node->left;
                delete node;
                size--;
                retNode = leftNode;
            } else {
                Node *successor = new Node(minimum(node->right));
                size++;

                successor->right = remove(node->right, successor->key);
                successor->left = node->left;

                delete node;
                size--;

                retNode = successor;
            }
        }
        if (retNode == nullptr) {
            return nullptr;
        }
        retNode->height = 1 + std::max(getHeight(retNode->left), getHeight(retNode->right));
        int balanceFactor = getBalanceFactor(retNode);

        if (balanceFactor > 1 && getBalanceFactor(retNode->left) >= 0) {
            return rightRotate(retNode);
        }

        if (balanceFactor < -1 && getBalanceFactor(retNode->right) <= 0) {
            return leftRotate(retNode);
        }

        if (balanceFactor > 1 && getBalanceFactor(retNode->left) < 0) {
            retNode->left = leftRotate(retNode->left);
            return rightRotate(retNode);
        }

        if (balanceFactor < -1 && getBalanceFactor(retNode->right) > 0) {
            retNode->right = rightRotate(retNode->right);
            return leftRotate(retNode);
        }

        return retNode;
    }

    void inOrder(Node *node, std::vector<Key> keys) {
        if (node == nullptr) {
            return;
        }
        inOrder(node->left, keys);
        keys.push_back(node->key);
        inOrder(node->right, keys);
    }

    bool isBalanced(Node *node) {
        if (node == nullptr) {
            return true;
        }

        int balanceFactor = getBalanceFactor(node);
        if (std::abs(balanceFactor) > 1) {
            return false;
        }

        return isBalanced(node->left) && isBalanced(node->right);
    }

    Node *leftRotate(Node *y) {
        Node *x = y->right;
        Node *tmp = x->left;
        x->left = y;
        y->right = tmp;
        y->height = std::max(getHeight(y->left), getHeight(y->right)) + 1;
        x->height = std::max(getHeight(x->left), getHeight(x->right)) + 1;
        return x;
    }

    Node *rightRotate(Node *y) {
        Node *x = y->left;
        Node *tmp = x->right;
        x->right = y;
        y->left = tmp;
        y->height = std::max(getHeight(y->left), getHeight(y->right)) + 1;
        x->height = std::max(getHeight(x->left), getHeight(x->right)) + 1;
        return x;
    }
};

参考

模板类、模板函数设计

二叉搜索树类 实现

bobo老师github

houpengfei的github

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值