平面扫描算法:高效计算线段交点的C++实现详解

引言

在计算几何学中,线段交点问题是一个基础而重要的问题,广泛应用于计算机图形学、地理信息系统(GIS)、电路设计、碰撞检测等领域。给定一组线段,如何高效地找出所有交点是一个具有挑战性的任务。平面扫描算法(Plane Sweep Algorithm)通过引入扫描线的概念,将二维问题转化为一维问题,实现了O((n+k)log n)的时间复杂度,其中n是线段数量,k是交点数量。

本文将深入解析平面扫描算法的原理,并通过完整的C++实现来展示其精妙之处。我们将从算法思想入手,逐步剖析代码实现,最后讨论算法的性能特点和优化空间。

算法原理详解

核心思想

平面扫描算法的核心思想是使用一条虚拟的扫描线(通常从上到下或从左到右)扫过整个平面,在扫描线移动过程中维护当前与扫描线相交的线段集合,并检测这些线段之间的交点。

算法步骤

  1. 初始化事件队列:将所有线段的端点作为事件点加入优先队列
  2. 处理事件点:按y坐标从大到小(从上到下)处理每个事件点
  3. 维护状态结构:使用平衡二叉搜索树维护当前与扫描线相交的线段
  4. 检测交点:在状态结构中相邻的线段可能产生新的交点

伪代码表示

function PlaneSweep(segments):
    events = 所有线段的端点
    status = 空平衡树
    intersections = 空集合
    
    while events不为空:
        p = 取出y坐标最大的事件点
        更新扫描线位置到p.y
        
        U = 上端点在p的线段
        L = 下端点在p的线段  
        C = 包含p但不是端点的线段
        
        if U+L+C中的线段数量 > 1:
            报告在p处的交点
        
        从状态结构中移除L和C
        向状态结构中插入U和C
        
        检查新相邻线段对的交点
        将新发现的交点加入事件队列
    
    return intersections

C++实现逐行解析

基础数据结构

点(Point)结构体
struct Point {
    double x, y;  // 点的x和y坐标
    
    Point(double x = 0, double y = 0);  // 构造函数,默认坐标为(0,0)
    
    // 比较运算符重载
    bool operator==(const Point& other) const;  // 判断两点是否相等
    bool operator!=(const Point& other) const;  // 判断两点是否不等
    bool operator<(const Point& other) const;   // 用于排序的比较操作
};

设计考量

  • 使用double类型存储坐标,提高计算精度
  • 重载比较运算符,便于在数据结构中进行比较和排序
  • 使用容差fTOL处理浮点数精度问题,避免因精度误差导致的错误判断
线段(Segment)结构体
struct Segment {
    Point upper, lower; // 上端点和下端点(y坐标大的为上端点)
    
    Segment(Point p1, Point p2); // 构造函数,自动确定上下端点
    
    // 成员函数
    bool contains(const Point& p) const; // 判断点是否在线段上
    bool operator==(const Segment& other) const; // 判断两线段是否相同
    bool operator<(const Segment& other) const;  // 用于排序的比较操作
    
    // 静态成员函数
    static bool computeIntersection(const Segment& s1, const Segment& s2, Point& result); // 计算两线段交点
    double getXAtY(double y) const; // 获取线段在给定y坐标处的x值
};

关键方法详细解析

contains方法使用叉积和点积双重验证点是否在线段上:

bool Segment::contains(const Point &p) const {
    // 第一步:检查点p的y坐标是否在线段的y坐标范围内
    // 使用容差fTOL来处理边界情况
    if (p.y < min(lower.y, upper.y) - fTOL ||
        p.y > max(lower.y, upper.y) + fTOL) {
        return false; // 点在线段的y范围之外
    }
    
    // 第二步:使用叉积判断三点是否共线
    // 叉积为0表示三点共线
    double cross = (p.x - upper.x) * (lower.y - upper.y) - 
                   (p.y - upper.y) * (lower.x - upper.x);
    if (fabs(cross) > fTOL) return false; // 不共线
    
    // 第三步:使用点积检查点p是否在线段的两端点之间
    // 计算向量UP和UL的点积,以及向量UL的长度的平方
    double dot = (p.x - upper.x) * (lower.x - upper.x) + 
                 (p.y - upper.y) * (lower.y - upper.y);
    double len_sq = (lower.x - upper.x) * (lower.x - upper.x) + 
                    (lower.y - upper.y) * (lower.y - upper.y);
    
    // 如果点积在0和长度平方之间,说明点在线段上
    return dot >= -fTOL && dot <= len_sq + fTOL;
}

computeIntersection方法使用参数方程法计算线段交点:

bool Segment::computeIntersection(const Segment &s1, const Segment &s2, Point& result) {
    // 获取两条线段的端点
    Point a1 = s1.upper, a2 = s1.lower;
    Point b1 = s2.upper, b2 = s2.lower;
    
    // 计算两条线段的方向向量
    double dx1 = a2.x - a1.x, dy1 = a2.y - a1.y;
    double dx2 = b2.x - b1.x, dy2 = b2.y - b1.y;
    
    // 计算方向向量的叉积(行列式)
    double det = dx1 * dy2 - dy1 * dx2;
    
    // 如果叉积接近0,说明线段平行或共线
    if (fabs(det) < fTOL) {
        // 处理平行和共线情况
        // 检查是否共线:计算点a1到b1的向量与线段s1方向向量的叉积
        if (fabs((a1.x - b1.x) * dy1 - (a1.y - b1.y) * dx1) < fTOL) {
            // 共线情况:检查是否有重叠部分
            // 使用参数t来表示点在线段上的位置
            double t0 = 0, t1 = 1;
            if (fabs(dx1) > fTOL) {
                // 使用x坐标计算参数
                t0 = (b1.x - a1.x) / dx1;
                t1 = (b2.x - a1.x) / dx1;
            } else if (fabs(dy1) > fTOL) {
                // 使用y坐标计算参数(处理垂直线段)
                t0 = (b1.y - a1.y) / dy1;
                t1 = (b2.y - a1.y) / dy1;
            }
            
            // 确保t0 <= t1
            if (t0 > t1) swap(t0, t1);
            
            // 检查参数范围是否有重叠
            if (t1 < -fTOL || t0 > 1 + fTOL) return false; // 没有重叠
            
            // 计算重叠部分的起点
            double t = max(0.0, t0);
            result = Point(a1.x + t * dx1, a1.y + t * dy1);
            return true;
        }
        return false; // 平行但不共线,没有交点
    }
    
    // 不平行的情况:使用克莱姆法则求解参数方程
    // 参数u表示点在线段s1上的位置,v表示在线段s2上的位置
    double u = ((b1.x - a1.x) * dy2 - (b1.y - a1.y) * dx2) / det;
    double v = ((b1.x - a1.x) * dy1 - (b1.y - a1.y) * dx1) / det;
    
    // 检查交点是否在线段范围内
    if (u >= -fTOL && u <= 1 + fTOL && v >= -fTOL && v <= 1 + fTOL) {
        // 计算交点坐标
        result = Point(a1.x + u * dx1, a1.y + u * dy1);
        return true;
    }
    
    return false; // 交点在线段的延长线上
}

事件管理系统

事件(Event)结构体
struct Event {
    Point point;                    // 事件点的坐标
    EventType type;                 // 事件类型
    std::vector<Segment> segments;  // 与该事件点相关的线段
    
    // 构造函数
    Event(Point p, EventType t, const std::vector<Segment>& segs = {});
    
    // 比较运算符
    bool operator<(const Event& other) const; // 用于事件队列排序
    bool operator==(const Event& other) const; // 判断事件是否相同
};

事件类型说明

  • UPPER_ENDPOINT:线段上端点事件,当扫描线遇到线段的上端点时触发
  • LOWER_ENDPOINT:线段下端点事件,当扫描线遇到线段的下端点时触发
  • INTERSECTION:交点事件,当扫描线遇到两条线段的交点时触发
事件队列(EventQueue)

使用AVL树实现优先队列,确保高效的事件插入和提取:

class EventQueue {
private:
    AVLTree<Event, EventCompare> tree; // 使用AVL树存储事件
    
public:
    void addEvent(const Event& event);      // 添加事件到队列
    Event nextEvent() const;                // 查看下一个事件(不移除)
    Event extractNextEvent();               // 取出并移除下一个事件
    bool empty() const;                     // 判断队列是否为空
    std::vector<Event> getAllEvents() const; // 获取所有事件(用于调试)
};

状态结构(StatusStructure)

状态结构维护当前与扫描线相交的线段,按它们在扫描线处的x坐标排序:

class StatusStructure {
private:
    double currentScanlineY;                    // 当前扫描线的y坐标
    AVLTree<Segment, SegmentCompare> tree;      // 使用AVL树存储线段
    
public:
    StatusStructure(double initialY);           // 构造函数,初始化扫描线位置
    
    void updateScanline(double y);              // 更新扫描线位置
    void insert(const Segment& s);              // 插入线段到状态结构
    void remove(const Segment& s);              // 从状态结构中移除线段
    bool contains(const Segment& s) const;      // 检查线段是否在状态结构中
    bool predecessor(const Segment& s, Segment& result) const; // 查找前驱线段
    bool successor(const Segment& s, Segment& result) const;   // 查找后继线段
    bool findSegmentContaining(const Point& p, Segment& result) const; // 查找包含点的线段
    std::vector<Segment> getAllSegments() const; // 获取所有线段(按x坐标排序)
    std::vector<std::pair<Segment, Segment>> getAdjacentPairs() const; // 获取相邻线段对
};

**线段比较器(SegmentCompare)**是关键组件,根据当前扫描线位置比较线段:

struct SegmentCompare {
    double scanlineY; // 当前扫描线的y坐标
    
    SegmentCompare(); // 默认构造函数
    SegmentCompare(double y); // 带参数的构造函数
    
    // 比较操作符:根据线段在当前扫描线位置的x坐标进行比较
    bool operator()(const Segment& s1, const Segment& s2) const {
        // 计算两条线段在当前扫描线y坐标处的x值
        double x1 = s1.getXAtY(scanlineY);
        double x2 = s2.getXAtY(scanlineY);
        
        // 如果x坐标非常接近,使用线段的固有顺序进行比较
        if (fabs(x1 - x2) < fTOL) {
            return s1 < s2; // 稳定比较,避免相等时的随机行为
        }
        return x1 < x2; // 正常情况下按x坐标比较
    }
};

平面扫描算法核心

算法初始化
PlaneSweep::PlaneSweep(const vector<Segment> &inputSegments) 
    : segments(inputSegments), status(numeric_limits<double>::max()) {
    // 初始化状态结构的扫描线位置为最大值(从顶部开始扫描)
    
    // 为每条线段创建端点事件并加入事件队列
    for (const auto &seg : segments) {
        // 添加上端点事件:扫描线第一次遇到线段
        eventQueue.addEvent(Event(seg.upper, EventType::UPPER_ENDPOINT, {seg}));
        // 添加下端点事件:扫描线离开线段
        eventQueue.addEvent(Event(seg.lower, EventType::LOWER_ENDPOINT, {seg}));
    }
}
主循环
void PlaneSweep::run() {
    cout << "Starting plane sweep algorithm..." << endl;
    cout << "Total segments: " << segments.size() << endl;
    
    // 主循环:处理所有事件点,直到事件队列为空
    while (!eventQueue.empty()) {
        // 取出y坐标最大的事件点(从上到下扫描)
        Event event = eventQueue.extractNextEvent();
        cout << "\nProcessing event at (" << event.point.x << ", " << event.point.y << ")" << endl;

        // 更新扫描线位置到当前事件点的y坐标
        // 同时重新排序状态结构中的线段
        status.updateScanline(event.point.y);

        // 处理当前事件点
        handleEventPoint(event);
    }
    
    cout << "\nAlgorithm completed. Found " << intersections.size() << " intersections." << endl;
}
事件点处理

这是算法最核心的部分,处理三种类型的事件点:

void PlaneSweep::handleEventPoint(const Event& event) {
    const Point& p = event.point; // 当前事件点的坐标
    
    // 分类与当前事件点相关的线段
    vector<Segment> U, L, C;
    classifySegments(p, U, L, C);
    
    // 如果有多条线段与当前事件点相关,报告交点
    if (U.size() + L.size() + C.size() > 1) {
        // 收集所有相关线段
        vector<Segment> allSegs;
        allSegs.insert(allSegs.end(), U.begin(), U.end());
        allSegs.insert(allSegs.end(), L.begin(), L.end());
        allSegs.insert(allSegs.end(), C.begin(), C.end());
        
        // 使用set去除重复线段
        set<Segment> uniqueSegs(allSegs.begin(), allSegs.end());
        vector<Segment> uniqueVec(uniqueSegs.begin(), uniqueSegs.end());
        
        // 检查所有唯一的线段对组合
        for (size_t i = 0; i < uniqueVec.size(); ++i) {
            for (size_t j = i + 1; j < uniqueVec.size(); ++j) {
                // 检查线段对是否有交点(排除当前事件点)
                checkAndAddIntersection(uniqueVec[i], uniqueVec[j], p);
            }
        }
    }

    // 更新状态结构:
    // 1. 移除下端点和包含当前点的线段(L和C)
    for (const auto &seg : L) {
        if (status.contains(seg)) {
            status.remove(seg); // 线段结束,从状态结构中移除
        }
    }
    for (const auto &seg : C) {
        if (status.contains(seg)) {
            status.remove(seg); // 暂时移除,稍后重新插入以更新顺序
        }
    }

    // 2. 插入上端点和包含当前点的线段(U和C)
    for (const auto &seg : U) {
        status.insert(seg); // 新线段开始,加入状态结构
    }
    for (const auto &seg : C) {
        status.insert(seg); // 重新插入以更新在状态结构中的顺序
    }

    // 检查新相邻线段对的交点
    checkAdjacentIntersections(U, p); // 检查新插入线段与邻居的交点
    checkAdjacentIntersections(C, p); // 检查重新插入线段与邻居的交点
}
线段分类方法
void PlaneSweep::classifySegments(const Point& p,
                                 vector<Segment>& U,  // 输出:上端点在p的线段
                                 vector<Segment>& L,  // 输出:下端点在p的线段
                                 vector<Segment>& C) const { // 输出:包含p但不是端点的线段
    // 遍历所有输入线段,检查它们与点p的关系
    for (const auto& seg : segments) {
        if (seg.upper == p) {
            // 线段的上端点就是当前事件点p
            U.push_back(seg);
        } else if (seg.lower == p) {
            // 线段的下端点就是当前事件点p
            L.push_back(seg);
        } else if (seg.contains(p)) {
            // 线段包含点p,但p不是端点
            // 避免重复添加端点已经在U或L中的线段
            if (!(seg.upper == p || seg.lower == p)) {
                C.push_back(seg);
            }
        }
    }
}
交点检测方法
void PlaneSweep::checkAndAddIntersection(const Segment& s1, const Segment& s2, const Point& p) {
    // 避免检查相同的线段
    if (s1 == s2) return;
    
    Point intersection;
    // 计算两条线段的交点
    if (Segment::computeIntersection(s1, s2, intersection)) {
        // 验证交点确实在线段上且不是当前事件点
        if (s1.contains(intersection) && s2.contains(intersection) && !(intersection == p)) {
            // 创建交点记录
            IntersectionRecord record(s1, s2, intersection);
            
            // 检查是否已经存在相同的交点记录(去重)
            bool exists = false;
            for (const auto& existing : intersections) {
                if (existing == record) {
                    exists = true;
                    break;
                }
            }
            
            // 如果是新发现的交点,添加到结果集和事件队列
            if (!exists) {
                intersections.insert(record);
                // 将交点作为新的事件点加入队列
                eventQueue.addEvent(Event(intersection, EventType::INTERSECTION, {s1, s2}));
                cout << "Found intersection at (" << intersection.x << ", " << intersection.y << ")" << endl;
            }
        }
    }
}

AVL树实现

算法使用AVL树作为底层数据结构,确保所有操作在O(log n)时间内完成。AVL树是一种自平衡二叉搜索树,通过旋转操作维持树的平衡。

AVL树节点结构
template <typename T, typename Compare = std::less<T>>
struct AVLNode {
    T key;  // 节点存储的数据
    std::shared_ptr<AVLNode<T, Compare>> left, right, parent; // 左右子节点和父节点指针
    int height;  // 节点高度,用于计算平衡因子

    // 构造函数
    AVLNode(const T &key) : key(key), left(nullptr), right(nullptr), parent(nullptr), height(1) {}
};
AVL树核心操作
template <typename T, typename Compare = std::less<T>>
class AVLTree {
private:
    std::shared_ptr<AVLNode<T, Compare>> root; // 树的根节点
    Compare comp; // 比较函数对象

public:
    // 基本操作
    void insert(const T &key);  // 插入元素
    void remove(const T &key);  // 删除元素
    bool contains(const T &key) const; // 检查元素是否存在
    
    // 查询操作
    T findMin() const; // 查找最小元素
    T findMax() const; // 查找最大元素
    bool predecessor(const T &key, T& result) const; // 查找前驱
    bool successor(const T &key, T& result) const; // 查找后继
    
    // 工具函数
    bool empty() const; // 判断树是否为空
    std::vector<T> inorder() const; // 中序遍历

private:
    // 内部辅助函数
    int height(std::shared_ptr<AVLNode<T, Compare>> node) const; // 获取节点高度
    void updateHeight(std::shared_ptr<AVLNode<T, Compare>> node); // 更新节点高度
    int balanceFactor(std::shared_ptr<AVLNode<T, Compare>> node) const; // 计算平衡因子
    
    // 旋转操作
    std::shared_ptr<AVLNode<T, Compare>> rotateRight(std::shared_ptr<AVLNode<T, Compare>> y); // 右旋
    std::shared_ptr<AVLNode<T, Compare>> rotateLeft(std::shared_ptr<AVLNode<T, Compare>> x); // 左旋
    
    // 平衡操作
    std::shared_ptr<AVLNode<T, Compare>> balance(std::shared_ptr<AVLNode<T, Compare>> node); // 平衡节点
    
    // 递归辅助函数
    std::shared_ptr<AVLNode<T, Compare>> insertHelper(std::shared_ptr<AVLNode<T, Compare>> node, const T &key);
    std::shared_ptr<AVLNode<T, Compare>> removeHelper(std::shared_ptr<AVLNode<T, Compare>> node, const T &key);
    std::shared_ptr<AVLNode<T, Compare>> findHelper(std::shared_ptr<AVLNode<T, Compare>> node, const T &key) const;
};
旋转操作详解

右旋转操作(处理LL型不平衡):

template <typename T, typename Compare>
std::shared_ptr<AVLNode<T, Compare>> AVLTree<T, Compare>::rotateRight(std::shared_ptr<AVLNode<T, Compare>> y) {
    // y是需要旋转的不平衡节点
    auto x = y->left;    // x是y的左子节点,将成为新的根节点
    auto T2 = x->right;  // T2是x的右子树
    
    // 执行旋转
    x->right = y;        // 将y作为x的右子节点
    y->left = T2;        // 将T2作为y的左子树
    
    // 更新父指针
    if (T2) T2->parent = y;
    x->parent = y->parent;
    y->parent = x;
    
    // 更新节点高度
    updateHeight(y);
    updateHeight(x);
    
    return x; // 返回新的根节点
}

左旋转操作(处理RR型不平衡):

template <typename T, typename Compare>
std::shared_ptr<AVLNode<T, Compare>> AVLTree<T, Compare>::rotateLeft(std::shared_ptr<AVLNode<T, Compare>> x) {
    // x是需要旋转的不平衡节点
    auto y = x->right;   // y是x的右子节点,将成为新的根节点
    auto T2 = y->left;   // T2是y的左子树
    
    // 执行旋转
    y->left = x;         // 将x作为y的左子节点
    x->right = T2;       // 将T2作为x的右子树
    
    // 更新父指针
    if (T2) T2->parent = x;
    y->parent = x->parent;
    x->parent = y;
    
    // 更新节点高度
    updateHeight(x);
    updateHeight(y);
    
    return y; // 返回新的根节点
}
平衡操作
template <typename T, typename Compare>
std::shared_ptr<AVLNode<T, Compare>> AVLTree<T, Compare>::balance(std::shared_ptr<AVLNode<T, Compare>> node) {
    if (!node) return nullptr;
    
    // 更新当前节点高度
    updateHeight(node);
    
    // 计算平衡因子(左子树高度 - 右子树高度)
    int bf = balanceFactor(node);
    
    // LL情况:左子树比右子树高超过1,且左子树的左子树更高或相等
    if (bf > 1 && balanceFactor(node->left) >= 0) {
        return rotateRight(node); // 单次右旋
    }
    
    // LR情况:左子树比右子树高超过1,且左子树的右子树更高
    if (bf > 1 && balanceFactor(node->left) < 0) {
        node->left = rotateLeft(node->left); // 先对左子节点左旋
        return rotateRight(node);            // 再对当前节点右旋
    }
    
    // RR情况:右子树比左子树高超过1,且右子树的右子树更高或相等
    if (bf < -1 && balanceFactor(node->right) <= 0) {
        return rotateLeft(node); // 单次左旋
    }
    
    // RL情况:右子树比左子树高超过1,且右子树的左子树更高
    if (bf < -1 && balanceFactor(node->right) > 0) {
        node->right = rotateRight(node->right); // 先对右子节点右旋
        return rotateLeft(node);               // 再对当前节点左旋
    }
    
    return node; // 节点已经平衡,无需旋转
}

复杂度分析

时间复杂度

  • 最好情况:O(n log n) - 当没有交点时,只需要处理2n个端点事件
  • 最坏情况:O((n+k) log n) - 当有k个交点时,需要处理2n+k个事件
  • 平均情况:O((n+k) log n) - 对于随机线段集合

详细分析

  • 每个线段端点处理:O(log n) - AVL树操作
  • 每个交点处理:O(log n) - 状态结构更新和相邻检查
  • 状态结构操作:O(log n) - 插入、删除、查找
  • 事件队列操作:O(log n) - 事件插入和提取

空间复杂度

  • 主要开销:O(n + k)
  • 事件队列存储:O(n + k) - 存储所有端点和交点事件
  • 状态结构存储:O(n) - 存储当前与扫描线相交的线段
  • 输出结果存储:O(k) - 存储所有发现的交点
  • 递归栈深度:O(log n) - AVL树操作的最大递归深度

优势与局限性

优势

  1. 高效性:相比朴素的O(n²)暴力算法,在大规模数据上性能提升显著
  2. 输出敏感性:时间复杂度与交点数量k相关,对稀疏交点情况特别友好
  3. 通用性:能够处理各种线段配置,包括水平、垂直、退化情况
  4. 数值稳定性:使用容差机制处理浮点数精度问题,提高算法鲁棒性
  5. 理论基础坚实:基于成熟的扫描线算法范式,可靠性高

局限性

  1. 实现复杂度高:需要同时维护事件队列和状态结构两个复杂数据结构
  2. 常数因子较大:由于频繁的平衡树操作,实际运行时的常数因子较明显
  3. 浮点精度挑战:对近乎平行或重合的线段处理仍有数值稳定性问题
  4. 内存占用相对较高:需要存储所有事件点,包括中间发现的交点
  5. 不适合动态更新:线段集合发生变化时需要重新构建整个数据结构

总结与扩展

平面扫描算法是计算几何中的经典算法,通过巧妙的扫描线思想和高效的数据结构,成功将线段交点问题的时间复杂度从O(n²)优化到O((n+k)log n)。本文详细解析了算法的C++实现,包括关键的数据结构设计和算法流程。

算法核心要点总结

  1. 扫描线范式:将二维几何问题转化为一维序列处理问题
  2. 事件驱动:通过事件点来驱动算法流程,有序处理几何变化
  3. 状态维护:动态维护与扫描线相交的线段集合
  4. 相邻检测:只在状态结构中相邻的线段之间检测新交点
  5. 平衡数据结构:使用AVL树保证所有操作的高效性

扩展思考与优化方向

  1. 性能优化

    • 使用更简单的平衡树结构(如红黑树)减少旋转操作
    • 批处理操作减少树的重建次数
    • 使用整数坐标或有理数运算避免浮点误差
  2. 功能扩展

    • 支持线段添加和删除的动态操作
    • 扩展到时序数据或移动线段
    • 支持其他几何形状的交点检测
  3. 应用领域扩展

    • 多边形布尔运算和裁剪
    • VLSI电路设计中的网络分析和验证
    • 地理信息系统的空间索引和查询
    • 计算机图形学的可见性计算
  4. 工程实践建议

    • 添加详细的日志和调试信息
    • 实现序列化接口便于结果保存和验证
    • 提供可视化工具辅助算法理解和调试

平面扫描算法不仅解决了具体的线段交点问题,更重要的是展示了如何通过空间扫描和高效数据结构来降低问题复杂度的方法论。这种"降维打击"的思想在计算几何和算法设计中具有广泛的指导意义,是每个算法学习者都应该掌握的重要范式。


完整代码实现

PlaneSweepAlgorithm.h

#pragma once
#include <vector>
#include <set>
#include "PlaneSweepAlgorithm.hpp"

/************************点结构体********************************/
struct Point {
    double x, y;  // 点的x和y坐标
    
    // 构造函数:初始化点坐标,默认值为(0,0)
    Point(double x = 0, double y = 0);
    
    // 比较运算符重载
    bool operator==(const Point& other) const;  // 判断两点是否相等(考虑浮点容差)
    bool operator!=(const Point& other) const;  // 判断两点是否不相等
    bool operator<(const Point& other) const;   // 用于排序:先按y坐标降序,再按x坐标升序
};

/************************线段结构体********************************/
struct Segment {
    Point upper, lower; // 上端点和下端点(y坐标大的为上端点)
    
    // 构造函数:根据两点自动确定上下端点
    Segment(Point p1, Point p2);
    
    // 判断点p是否在线段上(包括端点)
    bool contains(const Point& p) const;

    // 比较运算符
    bool operator==(const Segment& other) const; // 判断两线段是否相同(考虑端点顺序)
    bool operator<(const Segment& other) const;  // 用于线段的稳定排序

    // 静态方法:计算两条线段的交点
    static bool computeIntersection(const Segment& s1, const Segment& s2, Point& result);
    
    // 获取线段在给定y坐标处的x值(用于扫描线比较)
    double getXAtY(double y) const;
};

/************************事件类型********************************/
enum class EventType {
    UPPER_ENDPOINT,  // 上端点事件:扫描线遇到线段上端点
    LOWER_ENDPOINT,  // 下端点事件:扫描线遇到线段下端点
    INTERSECTION     // 交点事件:扫描线遇到两线段交点
};

/************************事件结构体********************************/
struct Event {
    Point point;                    // 事件点的坐标
    EventType type;                 // 事件类型
    std::vector<Segment> segments; // 与该事件点相关的线段集合
    
    // 构造函数
    Event(Point p, EventType t, const std::vector<Segment>& segs = {});
    
    // 比较运算符
    bool operator<(const Event& other) const; // 用于事件队列排序
    bool operator==(const Event& other) const; // 判断事件是否相同
};

/************************事件点比较函数********************************/
struct EventCompare {
    // 函数调用操作符:比较两个事件的优先级
    bool operator()(const Event& e1, const Event& e2) const;
};

/************************事件队列********************************/
class EventQueue {
private:
    AVLTree<Event, EventCompare> tree; // 使用AVL树实现优先队列
    
public:
    // 添加事件到队列
    void addEvent(const Event& event);
    
    // 获取下一个事件(不移除)
    Event nextEvent() const;
    
    // 取出并移除下一个事件
    Event extractNextEvent();
    
    // 判断队列是否为空
    bool empty() const;
    
    // 获取所有事件(用于调试和测试)
    std::vector<Event> getAllEvents() const;
};

/************************线段比较函数(用于状态结构)********************************/
struct SegmentCompare {
    double scanlineY; // 当前扫描线的y坐标
    
    // 构造函数
    SegmentCompare();
    SegmentCompare(double y);
    
    // 比较操作符:根据线段在当前扫描线位置的x坐标进行比较
    bool operator()(const Segment& s1, const Segment& s2) const;
};

/************************算法状态结构********************************/
class StatusStructure {
private:
    double currentScanlineY;                    // 当前扫描线的y坐标
    AVLTree<Segment, SegmentCompare> tree;      // 使用AVL树存储当前相交的线段
    
public:
    // 构造函数:初始化状态结构,设置初始扫描线位置
    StatusStructure(double initialY);
    
    // 更新扫描线位置,并重新排序线段
    void updateScanline(double y);
    
    // 插入线段到状态结构
    void insert(const Segment& s);
    
    // 从状态结构中移除线段
    void remove(const Segment& s);
    
    // 检查线段是否在状态结构中
    bool contains(const Segment& s) const;
    
    // 获取线段的前驱(按当前扫描线位置的x坐标)
    bool predecessor(const Segment& s, Segment& result) const;

    // 获取线段的后继(按当前扫描线位置的x坐标)
    bool successor(const Segment& s, Segment& result) const;

    // 查找包含给定点的线段
    bool findSegmentContaining(const Point& p, Segment& result) const;
    
    // 获取所有线段(按当前扫描线位置的x坐标排序)
    std::vector<Segment> getAllSegments() const;
    
    // 获取相邻线段对(用于交点检测)
    std::vector<std::pair<Segment, Segment>> getAdjacentPairs() const;
};

/************************相交记录结构体********************************/
struct IntersectionRecord {
    Segment seg1;        // 相交的第一条线段
    Segment seg2;        // 相交的第二条线段
    Point intersection;  // 交点坐标
    
    // 构造函数
    IntersectionRecord(const Segment& s1, const Segment& s2, const Point& p);
        
    // 比较运算符
    bool operator<(const IntersectionRecord& other) const; // 用于set排序
    bool operator==(const IntersectionRecord& other) const; // 判断记录是否相同
};

/************************平面扫描算法********************************/
class PlaneSweep
{
private:
    EventQueue eventQueue;           // 事件队列
    std::vector<Segment> segments;   // 输入线段集合
    StatusStructure status;          // 状态结构
    std::set<IntersectionRecord> intersections; // 发现的交点集合

public:
    // 构造函数:初始化算法,构建初始事件队列
    PlaneSweep(const std::vector<Segment> &inputSegments);

    // 运行平面扫描算法
    void run();

    // 获取所有相交记录
    std::set<IntersectionRecord> getIntersections() const;

private:
    // 处理事件点(算法核心)
    void handleEventPoint(const Event& event);
    
    // 分类线段:U(上端点)、L(下端点)、C(包含点)
    void classifySegments(const Point& p, 
                         std::vector<Segment>& U, 
                         std::vector<Segment>& L, 
                         std::vector<Segment>& C) const;
    
    // 检查并添加交点事件
    void checkAndAddIntersection(const Segment& s1, const Segment& s2, const Point& p);
    
    // 检查相邻线段交点
    void checkAdjacentIntersections(const std::vector<Segment>& newSegments, const Point& p);
};

PlaneSweepAlgorithm.cpp

#include <iostream>
#include <algorithm>
#include <limits>
#include <cassert>
#include <cmath>
#include "PlaneSweepAlgorithm.h"

using namespace std;

#define fTOL 1e-10  // 浮点数比较容差,用于处理精度问题

/************************点结构体实现********************************/
Point::Point(double x, double y) : x(x), y(y) {}

bool Point::operator==(const Point &other) const
{
    // 使用容差比较两个点的坐标是否相等
    return fabs(x - other.x) < fTOL && fabs(y - other.y) < fTOL;
}

bool Point::operator!=(const Point &other) const
{
    return !(*this == other);
}

bool Point::operator<(const Point &other) const
{
    // 优先按y坐标降序排列(从上到下扫描)
    if (fabs(y - other.y) >= fTOL)
        return y > other.y;
    // y坐标相同时,按x坐标升序排列
    return x < other.x;
}

/************************线段结构体实现********************************/
Segment::Segment(Point p1, Point p2)
{
    // 确定上下端点:y坐标大的为上端点
    // y坐标相同时,x坐标小的为上端点
    if (p1.y > p2.y || (fabs(p1.y - p2.y) < fTOL && p1.x < p2.x))
    {
        upper = p1;
        lower = p2;
    }
    else
    {
        upper = p2;
        lower = p1;
    }
}

// 判断点是否在线段上
bool Segment::contains(const Point &p) const
{
    // 第一步:检查点p的y坐标是否在线段的y坐标范围内
    if (p.y < min(lower.y, upper.y) - fTOL ||
        p.y > max(lower.y, upper.y) + fTOL)
    {
        return false;
    }

    // 处理水平线段:检查x坐标范围
    if (fabs(upper.y - lower.y) < fTOL)
    {
        return p.x >= min(upper.x, lower.x) - fTOL &&
               p.x <= max(upper.x, lower.x) + fTOL;
    }

    // 处理垂直线段:检查x坐标是否相等
    if (fabs(upper.x - lower.x) < fTOL)
    {
        return fabs(p.x - upper.x) < fTOL;
    }

    // 一般情况:使用叉积判断三点是否共线
    double cross = (p.x - upper.x) * (lower.y - upper.y) - 
                   (p.y - upper.y) * (lower.x - upper.x);
    if (fabs(cross) > fTOL) return false;

    // 使用点积检查点p是否在线段的两端点之间
    double dot = (p.x - upper.x) * (lower.x - upper.x) + 
                 (p.y - upper.y) * (lower.y - upper.y);
    double len_sq = (lower.x - upper.x) * (lower.x - upper.x) + 
                    (lower.y - upper.y) * (lower.y - upper.y);
    
    return dot >= -fTOL && dot <= len_sq + fTOL;
}

bool Segment::operator==(const Segment &other) const
{
    // 两条线段相等当且仅当它们的端点相同(不考虑顺序)
    return (upper == other.upper && lower == other.lower) ||
           (upper == other.lower && lower == other.upper);
}

bool Segment::operator<(const Segment &other) const
{
    // 定义点比较函数:先按x坐标,再按y坐标
    auto comparePoints = [](const Point &a, const Point &b)
    {
        if (fabs(a.x - b.x) >= fTOL) return a.x < b.x;
        return a.y < b.y;
    };

    // 获取当前线段的最小和最大端点
    Point min1 = comparePoints(upper, lower) ? upper : lower;
    Point max1 = comparePoints(upper, lower) ? lower : upper;

    // 获取比较线段的最小和最大端点
    Point min2 = comparePoints(other.upper, other.lower) ? other.upper : other.lower;
    Point max2 = comparePoints(other.upper, other.lower) ? other.lower : other.upper;

    // 先比较最小端点,再比较最大端点
    if (min1 != min2)
        return comparePoints(min1, min2);
    return comparePoints(max1, max2);
}

// 获取线段在给定y坐标的x值
double Segment::getXAtY(double y) const
{
    // 处理水平线段
    if (fabs(upper.y - lower.y) < fTOL)
        return min(upper.x, lower.x);
    
    // 如果y坐标等于端点y坐标,直接返回端点x坐标
    if (fabs(y - upper.y) < fTOL)
        return upper.x;
    if (fabs(y - lower.y) < fTOL)
        return lower.x;
    
    // 使用线性插值计算x坐标
    double t = (y - lower.y) / (upper.y - lower.y);
    return lower.x + t * (upper.x - lower.x);
}

// 计算两条线段的交点
bool Segment::computeIntersection(const Segment &s1, const Segment &s2, Point& result)
{
    Point a1 = s1.upper, a2 = s1.lower;
    Point b1 = s2.upper, b2 = s2.lower;

    // 计算方向向量
    double dx1 = a2.x - a1.x, dy1 = a2.y - a1.y;
    double dx2 = b2.x - b1.x, dy2 = b2.y - b1.y;

    // 计算叉积(行列式)
    double det = dx1 * dy2 - dy1 * dx2;

    // 检查线段是否平行
    if (fabs(det) < fTOL) {
        // 检查是否共线
        if (fabs((a1.x - b1.x) * dy1 - (a1.y - b1.y) * dx1) < fTOL) {
            // 共线情况:检查是否有重叠
            double t0 = 0, t1 = 1;
            if (fabs(dx1) > fTOL) {
                t0 = (b1.x - a1.x) / dx1;
                t1 = (b2.x - a1.x) / dx1;
            } else if (fabs(dy1) > fTOL) {
                t0 = (b1.y - a1.y) / dy1;
                t1 = (b2.y - a1.y) / dy1;
            }
            
            if (t0 > t1) swap(t0, t1);
            
            if (t1 < -fTOL || t0 > 1 + fTOL) return false;
            
            double t = max(0.0, t0);
            result = Point(a1.x + t * dx1, a1.y + t * dy1);
            return true;
        }
        return false; // 平行但不共线
    }

    // 不平行情况:使用参数方程求解
    double u = ((b1.x - a1.x) * dy2 - (b1.y - a1.y) * dx2) / det;
    double v = ((b1.x - a1.x) * dy1 - (b1.y - a1.y) * dx1) / det;

    // 检查交点是否在线段范围内
    if (u >= -fTOL && u <= 1 + fTOL && v >= -fTOL && v <= 1 + fTOL) {
        result = Point(a1.x + u * dx1, a1.y + u * dy1);
        return true;
    }

    return false;
}

/************************事件结构体实现********************************/
Event::Event(Point p, EventType t, const std::vector<Segment>& segs) 
    : point(p), type(t), segments(segs) {}

bool Event::operator<(const Event& other) const {
    return point < other.point;
}

bool Event::operator==(const Event& other) const {
    return point == other.point && type == other.type;
}

/************************事件点比较函数实现********************************/
bool EventCompare::operator()(const Event& e1, const Event& e2) const {
    return e1 < e2;
}

/************************事件队列实现********************************/
void EventQueue::addEvent(const Event& event) {
    tree.insert(event);
}

Event EventQueue::nextEvent() const {
    if (tree.empty()) {
        throw runtime_error("Event queue is empty");
    }
    return tree.findMin();
}

Event EventQueue::extractNextEvent() {
    if (tree.empty()) {
        throw runtime_error("Event queue is empty");
    }
    Event e = tree.findMin();
    tree.remove(e);
    return e;
}

bool EventQueue::empty() const {
    return tree.empty();
}

vector<Event> EventQueue::getAllEvents() const {
    return tree.inorder();
}

/************************线段比较函数实现********************************/
SegmentCompare::SegmentCompare() : scanlineY(0) {}
SegmentCompare::SegmentCompare(double y) : scanlineY(y) {}

bool SegmentCompare::operator()(const Segment& s1, const Segment& s2) const {
    // 计算两条线段在当前扫描线位置的x坐标
    double x1 = s1.getXAtY(scanlineY);
    double x2 = s2.getXAtY(scanlineY);
    
    // 如果x坐标非常接近,使用线段的稳定比较
    if (fabs(x1 - x2) < fTOL) {
        return s1 < s2;
    }
    return x1 < x2;
}

/************************算法状态结构实现********************************/
StatusStructure::StatusStructure(double initialY) : currentScanlineY(initialY), tree(SegmentCompare(initialY)) {}

void StatusStructure::updateScanline(double y) {
    currentScanlineY = y;
    // 重新构建树以更新比较器中的扫描线位置
    auto segments = tree.inorder();
    tree = AVLTree<Segment, SegmentCompare>(SegmentCompare(y));
    for (const auto& seg : segments) {
        tree.insert(seg);
    }
}

void StatusStructure::insert(const Segment& s) {
    tree.insert(s);
}

void StatusStructure::remove(const Segment& s) {
    tree.remove(s);
}

bool StatusStructure::contains(const Segment& s) const {
    return tree.contains(s);
}

bool StatusStructure::predecessor(const Segment& s, Segment& result) const {
    return tree.predecessor(s, result);
}

bool StatusStructure::successor(const Segment& s, Segment& result) const {
    return tree.successor(s, result);
}

bool StatusStructure::findSegmentContaining(const Point& p, Segment& result) const {
    auto segments = getAllSegments();
    for (const auto& seg : segments) {
        if (seg.contains(p)) {
            result = seg;
            return true;
        }
    }
    return false;
}

vector<Segment> StatusStructure::getAllSegments() const {
    return tree.inorder();
}

vector<pair<Segment, Segment>> StatusStructure::getAdjacentPairs() const {
    vector<Segment> segs = getAllSegments();
    vector<pair<Segment, Segment>> pairs;
    
    for (size_t i = 1; i < segs.size(); i++) {
        pairs.emplace_back(segs[i-1], segs[i]);
    }
    
    return pairs;
}

/************************相交记录结构体实现********************************/
IntersectionRecord::IntersectionRecord(const Segment& s1, const Segment& s2, const Point& p)
    : seg1(s1), seg2(s2), intersection(p) {}

bool IntersectionRecord::operator<(const IntersectionRecord& other) const {
    // 优先按交点坐标排序
    if (!(intersection == other.intersection))
        return intersection < other.intersection;
    
    // 交点相同时,按线段排序
    if (!(seg1 == other.seg1))
        return seg1 < other.seg1;
    
    return seg2 < other.seg2;
}

bool IntersectionRecord::operator==(const IntersectionRecord& other) const {
    // 考虑线段顺序无关性
    return (seg1 == other.seg1 && seg2 == other.seg2 && intersection == other.intersection) ||
           (seg1 == other.seg2 && seg2 == other.seg1 && intersection == other.intersection);
}

/************************平面扫描算法实现********************************/
PlaneSweep::PlaneSweep(const vector<Segment> &inputSegments) 
    : segments(inputSegments), status(numeric_limits<double>::max()) {
    
    // 初始化事件队列:为每条线段添加上下端点事件
    for (const auto &seg : segments) {
        // 添加上端点事件
        eventQueue.addEvent(Event(seg.upper, EventType::UPPER_ENDPOINT, {seg}));
        // 添加下端点事件
        eventQueue.addEvent(Event(seg.lower, EventType::LOWER_ENDPOINT, {seg}));
    }
}

void PlaneSweep::run() {
    cout << "Starting plane sweep algorithm..." << endl;
    cout << "Total segments: " << segments.size() << endl;
    
    // 主循环:处理所有事件点
    while (!eventQueue.empty()) {
        Event event = eventQueue.extractNextEvent();
        cout << "\nProcessing event at (" << event.point.x << ", " << event.point.y << ")" << endl;

        // 更新扫描线位置
        status.updateScanline(event.point.y);

        // 处理事件点
        handleEventPoint(event);
    }
    
    cout << "\nAlgorithm completed. Found " << intersections.size() << " intersections." << endl;
}

set<IntersectionRecord> PlaneSweep::getIntersections() const {
    return intersections;
}

void PlaneSweep::handleEventPoint(const Event& event) {
    const Point& p = event.point;
    
    vector<Segment> U, L, C;
    // 分类与事件点相关的线段
    classifySegments(p, U, L, C);

    // 如果有多条线段与事件点相关,检查交点
    if (U.size() + L.size() + C.size() > 1) {
        vector<Segment> allSegs;
        allSegs.insert(allSegs.end(), U.begin(), U.end());
        allSegs.insert(allSegs.end(), L.begin(), L.end());
        allSegs.insert(allSegs.end(), C.begin(), C.end());
        
        // 去重
        set<Segment> uniqueSegs(allSegs.begin(), allSegs.end());
        vector<Segment> uniqueVec(uniqueSegs.begin(), uniqueSegs.end());
        
        // 检查所有唯一线段对组合
        for (size_t i = 0; i < uniqueVec.size(); ++i) {
            for (size_t j = i + 1; j < uniqueVec.size(); ++j) {
                checkAndAddIntersection(uniqueVec[i], uniqueVec[j], p);
            }
        }
    }

    // 更新状态结构:先移除后插入
    for (const auto &seg : L) {
        if (status.contains(seg)) {
            status.remove(seg);
        }
    }
    for (const auto &seg : C) {
        if (status.contains(seg)) {
            status.remove(seg);
        }
    }

    // 插入新线段和重新插入包含当前点的线段
    for (const auto &seg : U) {
        status.insert(seg);
    }
    for (const auto &seg : C) {
        status.insert(seg);
    }

    // 检查相邻线段的新交点
    checkAdjacentIntersections(U, p);
    checkAdjacentIntersections(C, p);
}

void PlaneSweep::classifySegments(const Point& p,
                                 vector<Segment>& U,
                                 vector<Segment>& L,
                                 vector<Segment>& C) const {
    // 遍历所有线段,分类它们与点p的关系
    for (const auto& seg : segments) {
        if (seg.upper == p) {
            U.push_back(seg);
        } else if (seg.lower == p) {
            L.push_back(seg);
        } else if (seg.contains(p)) {
            // 避免重复添加端点已经在U或L中的线段
            if (!(seg.upper == p || seg.lower == p)) {
                C.push_back(seg);
            }
        }
    }
}

void PlaneSweep::checkAndAddIntersection(const Segment& s1, const Segment& s2, const Point& p) {
    // 避免检查相同的线段
    if (s1 == s2) return;
    
    Point intersection;
    if (Segment::computeIntersection(s1, s2, intersection)) {
        // 验证交点在线段上且不是当前事件点
        if (s1.contains(intersection) && s2.contains(intersection) && !(intersection == p)) {
            IntersectionRecord record(s1, s2, intersection);
            
            // 检查是否已经存在相同的交点记录
            bool exists = false;
            for (const auto& existing : intersections) {
                if (existing == record) {
                    exists = true;
                    break;
                }
            }
            
            if (!exists) {
                intersections.insert(record);
                eventQueue.addEvent(Event(intersection, EventType::INTERSECTION, {s1, s2}));
                cout << "Found intersection at (" << intersection.x << ", " << intersection.y << ")" << endl;
            }
        }
    }
}

void PlaneSweep::checkAdjacentIntersections(const vector<Segment>& newSegments, const Point& p) {
    if (newSegments.empty()) return;
    
    // 检查所有相邻线段对
    auto adjacentPairs = status.getAdjacentPairs();
    for (const auto& pair : adjacentPairs) {
        checkAndAddIntersection(pair.first, pair.second, p);
    }
    
    // 检查新插入线段与邻居的交点
    for (const auto& seg : newSegments) {
        Segment pred(Point(0,0), Point(0,0));
        if (status.predecessor(seg, pred)) {
            checkAndAddIntersection(pred, seg, p);
        }
        Segment succ(Point(0,0), Point(0,0));
        if (status.successor(seg, succ)) {
            checkAndAddIntersection(seg, succ, p);
        }
    }
}

PlaneSweepAlgorithm.hpp

#pragma once
#include <functional>
#include <memory>
#include <vector>
#include <algorithm>

/*******************************************************************
| 特性          | 描述                                            |
| ------------- | ----------------------------------------------- |
| 二叉搜索树性质 | 左子树所有节点 < 根节点 < 右子树所有节点           |
| 自平衡        | 任何节点的左右子树高度差(称为平衡因子)不超过1      |
| 平衡因子      | 定义为:左子树高度 - 右子树高,取值只能是 -1, 0, 1  |
| 旋转操作      | 通过  左旋、右旋、左右旋、右左旋  来维持平衡        |
*********************************************************************/

/************************平衡二叉搜索树节点模板********************************/
/**
 * @brief AVL树节点模板类
 * @tparam T 节点存储的数据类型
 * @tparam Compare 比较函数类型,默认为std::less<T>
 *
 * AVL树节点包含键值、左右子节点指针、父节点指针和高度信息
 * 使用智能指针管理内存,避免内存泄漏
 */
template <typename T, typename Compare = std::less<T>>
struct AVLNode
{
    T key;  ///< 节点存储的键值
    std::shared_ptr<AVLNode<T, Compare>> left, right, parent;  ///< 左子节点、右子节点、父节点指针
    int height;  ///< 节点高度,用于计算平衡因子

    /**
     * @brief 构造函数
     * @param key 节点键值
     *
     * 初始化节点,设置键值,左右子节点和父节点为空,高度为1
     */
    AVLNode(const T &key) : key(key), left(nullptr), right(nullptr), parent(nullptr), height(1) {}
};

/*******************************************************************
| 类型  | 触发条件         | 操作        |
| ---- | ---------------- | ----------- |
| LL型 | 左子树的左子树过高 | 一次 右旋    |
| RR型 | 右子树的右子树过高 | 一次 左旋    |
| LR型 | 左子树的右子树过高 | 先左旋再右旋 |
| RL型 | 右子树的左子树过高 | 先右旋再左旋 |
*********************************************************************/

/************************平衡二叉搜索树模板********************************/
/**
 * @brief AVL平衡二叉搜索树模板类
 * @tparam T 树中存储的数据类型
 * @tparam Compare 比较函数类型,默认为std::less<T>
 *
 * AVL树是一种自平衡二叉搜索树,通过旋转操作保持树的平衡
 * 保证任何节点的左右子树高度差不超过1,从而保证O(log n)的时间复杂度
 */
template <typename T, typename Compare = std::less<T>>
class AVLTree
{
public:
    /**
     * @brief 默认构造函数
     * 创建空的AVL树,使用默认比较函数
     */
    AVLTree() : root(nullptr), comp() {}
    
    /**
     * @brief 带比较函数的构造函数
     * @param comparator 自定义比较函数
     * 创建空的AVL树,使用指定的比较函数
     */
    AVLTree(const Compare &comparator) : root(nullptr), comp(comparator) {}

    /**
     * @brief 插入元素
     * @param key 要插入的键值
     * 在AVL树中插入新元素,插入后自动进行平衡调整
     */
    void insert(const T &key)
    {
        root = insertHelper(root, key);
    }

    /**
     * @brief 删除元素
     * @param key 要删除的键值
     * 从AVL树中删除指定元素,删除后自动进行平衡调整
     */
    void remove(const T &key)
    {
        root = removeHelper(root, key);
    }

    /**
     * @brief 查找元素是否存在
     * @param key 要查找的键值
     * @return 如果元素存在返回true,否则返回false
     */
    bool contains(const T &key) const
    {
        return findHelper(root, key) != nullptr;
    }

    /**
     * @brief 获取最小元素
     * @return 树中的最小元素
     * @throws std::runtime_error 如果树为空
     * 通过不断访问左子树找到最小元素
     */
    T findMin() const
    {
        auto node = findMin(root);
        if (!node)
            throw std::runtime_error("Tree is empty");
        return node->key;
    }

    /**
     * @brief 获取最大元素
     * @return 树中的最大元素
     * @throws std::runtime_error 如果树为空
     * 通过不断访问右子树找到最大元素
     */
    T findMax() const
    {
        auto node = findMax(root);
        if (!node)
            throw std::runtime_error("Tree is empty");
        return node->key;
    }

    /**
     * @brief 删除并返回最小元素
     * @return 被删除的最小元素
     * 先找到最小元素,然后删除它并返回其值
     */
    T extractMin()
    {
        T minVal = findMin();
        remove(minVal);
        return minVal;
    }

    /**
     * @brief 判断树是否为空
     * @return 如果树为空返回true,否则返回false
     */
    bool empty() const
    {
        return !root;
    }

    /**
     * @brief 中序遍历
     * @return 包含所有元素的有序向量
     * 按照升序返回树中的所有元素
     */
    std::vector<T> inorder() const
    {
        std::vector<T> result;
        inorderHelper(root, result);
        return result;
    }

    /**
     * @brief 查找元素节点
     * @param key 要查找的键值
     * @return 指向找到的节点的智能指针,如果未找到返回nullptr
     */
    std::shared_ptr<AVLNode<T, Compare>> findNode(const T &key) const
    {
        return findHelper(root, key);
    }

    /**
     * @brief 获取前驱元素
     * @param key 参考键值
     * @param result 存储前驱元素的结果变量
     * @return 如果找到前驱返回true,否则返回false
     * 前驱是比给定键值小的最大元素
     */
    bool predecessor(const T &key, T& result) const
    {
        auto node = findNode(key);
        if (!node)
            return false;

        auto pred = predecessorHelper(node);
        if (pred) {
            result = pred->key;
            return true;
        }
        return false;
    }

    /**
     * @brief 获取后继元素
     * @param key 参考键值
     * @param result 存储后继元素的结果变量
     * @return 如果找到后继返回true,否则返回false
     * 后继是比给定键值大的最小元素
     */
    bool successor(const T &key, T& result) const
    {
        auto node = findNode(key);
        if (!node)
            return false;

        auto succ = successorHelper(node);
        if (succ) {
            result = succ->key;
            return true;
        }
        return false;
    }

private:
    std::shared_ptr<AVLNode<T, Compare>> root;  ///< 树的根节点
    Compare comp;  ///< 比较函数对象

    /**
     * @brief 获取节点高度
     * @param node 节点指针
     * @return 节点的高度,空节点高度为0
     */
    int height(std::shared_ptr<AVLNode<T, Compare>> node) const
    {
        return node ? node->height : 0;
    }

    /**
     * @brief 更新节点高度
     * @param node 要更新高度的节点
     * 节点高度 = 1 + max(左子树高度, 右子树高度)
     */
    void updateHeight(std::shared_ptr<AVLNode<T, Compare>> node)
    {
        if (node)
        {
            node->height = 1 + std::max(height(node->left), height(node->right));
        }
    }

    /**
     * @brief 获取平衡因子
     * @param node 节点指针
     * @return 平衡因子 = 左子树高度 - 右子树高度
     * 平衡因子应该在[-1, 0, 1]范围内,否则需要旋转调整
     */
    int balanceFactor(std::shared_ptr<AVLNode<T, Compare>> node) const
    {
        return height(node->left) - height(node->right);
    }

    /**
     * @brief 右旋转操作(LL情况)
     * @param y 不平衡的节点
     * @return 旋转后的新根节点
     *
     * 右旋转用于处理左子树过高的不平衡情况
     * 将y的左子节点x提升为新的根节点,y成为x的右子节点
     * x原来的右子树T2成为y的左子树
     */
    std::shared_ptr<AVLNode<T, Compare>> rotateRight(std::shared_ptr<AVLNode<T, Compare>> y)
    {
        auto x = y->left;
        auto T2 = x->right;

        // 执行旋转
        x->right = y;
        y->left = T2;

        // 更新父指针
        if (T2)
            T2->parent = y;
        x->parent = y->parent;
        y->parent = x;

        // 更新节点高度
        updateHeight(y);
        updateHeight(x);

        return x;
    }

    /**
     * @brief 左旋转操作(RR情况)
     * @param x 不平衡的节点
     * @return 旋转后的新根节点
     *
     * 左旋转用于处理右子树过高的不平衡情况
     * 将x的右子节点y提升为新的根节点,x成为y的左子节点
     * y原来的左子树T2成为x的右子树
     */
    std::shared_ptr<AVLNode<T, Compare>> rotateLeft(std::shared_ptr<AVLNode<T, Compare>> x)
    {
        auto y = x->right;
        auto T2 = y->left;

        // 执行旋转
        y->left = x;
        x->right = T2;

        // 更新父指针
        if (T2)
            T2->parent = x;
        y->parent = x->parent;
        x->parent = y;

        // 更新节点高度
        updateHeight(x);
        updateHeight(y);

        return y;
    }

    /**
     * @brief 平衡节点
     * @param node 需要平衡的节点
     * @return 平衡后的节点
     *
     * 根据平衡因子判断不平衡类型,执行相应的旋转操作:
     * - LL型(左左):平衡因子>1且左子节点平衡因子>=0,执行右旋转
     * - LR型(左右):平衡因子>1且左子节点平衡因子<0,先左旋再右旋
     * - RR型(右右):平衡因子<-1且右子节点平衡因子<=0,执行左旋转
     * - RL型(右左):平衡因子<-1且右子节点平衡因子>0,先右旋再左旋
     */
    std::shared_ptr<AVLNode<T, Compare>> balance(std::shared_ptr<AVLNode<T, Compare>> node)
    {
        if (!node)
            return nullptr;

        // 更新节点高度
        updateHeight(node);

        // 计算平衡因子
        int bf = balanceFactor(node);

        // LL情况:左子树过高,且左子树的左子树更高或相等
        if (bf > 1 && balanceFactor(node->left) >= 0)
        {
            return rotateRight(node);
        }

        // LR情况:左子树过高,但左子树的右子树更高
        if (bf > 1 && balanceFactor(node->left) < 0)
        {
            node->left = rotateLeft(node->left);
            return rotateRight(node);
        }

        // RR情况:右子树过高,且右子树的右子树更高或相等
        if (bf < -1 && balanceFactor(node->right) <= 0)
        {
            return rotateLeft(node);
        }

        // RL情况:右子树过高,但右子树的左子树更高
        if (bf < -1 && balanceFactor(node->right) > 0)
        {
            node->right = rotateRight(node->right);
            return rotateLeft(node);
        }

        return node; // 节点已经平衡
    }

    /**
     * @brief 插入辅助函数(递归实现)
     * @param node 当前节点
     * @param key 要插入的键值
     * @return 插入并平衡后的节点
     *
     * 递归地在AVL树中插入新节点:
     * 1. 如果节点为空,创建新节点
     * 2. 根据比较结果递归插入到左子树或右子树
     * 3. 更新父节点指针
     * 4. 如果键值重复,直接返回原节点
     * 5. 插入后调用balance函数进行平衡调整
     */
    std::shared_ptr<AVLNode<T, Compare>> insertHelper(std::shared_ptr<AVLNode<T, Compare>> node, const T &key)
    {
        if (!node)
            return std::make_shared<AVLNode<T, Compare>>(key);

        // 插入值小于当前节点值,向左子树插入
        if (comp(key, node->key))
        {
            node->left = insertHelper(node->left, key);
            if (node->left)
                node->left->parent = node;
        }
        // 插入值大于当前节点值,向右子树插入
        else if (comp(node->key, key))
        {
            node->right = insertHelper(node->right, key);
            if (node->right)
                node->right->parent = node;
        }
        else
        {
            return node; // 键值重复,不插入
        }

        // 平衡当前节点
        return balance(node);
    }

    /**
     * @brief 查找最小节点
     * @param node 起始节点
     * @return 最小节点指针
     *
     * 通过不断访问左子树找到最小节点
     * 在二叉搜索树中,最小节点是最左边的节点
     */
    std::shared_ptr<AVLNode<T, Compare>> findMin(std::shared_ptr<AVLNode<T, Compare>> node) const
    {
        while (node && node->left)
        {
            node = node->left;
        }
        return node;
    }

    /**
     * @brief 查找最大节点
     * @param node 起始节点
     * @return 最大节点指针
     *
     * 通过不断访问右子树找到最大节点
     * 在二叉搜索树中,最大节点是最右边的节点
     */
    std::shared_ptr<AVLNode<T, Compare>> findMax(std::shared_ptr<AVLNode<T, Compare>> node) const
    {
        while (node && node->right)
        {
            node = node->right;
        }
        return node;
    }

    /**
     * @brief 删除辅助函数(递归实现)
     * @param node 当前节点
     * @param key 要删除的键值
     * @return 删除并平衡后的节点
     *
     * 递归地在AVL树中删除节点:
     * 1. 如果节点为空,返回nullptr
     * 2. 根据比较结果递归删除左子树或右子树
     * 3. 找到要删除的节点后:
     *    - 如果节点只有一个子节点或无子节点,直接用子节点替换
     *    - 如果节点有两个子节点,用右子树的最小值替换当前节点值,然后删除右子树的最小值
     * 4. 更新父节点指针
     * 5. 删除后调用balance函数进行平衡调整
     */
    std::shared_ptr<AVLNode<T, Compare>> removeHelper(std::shared_ptr<AVLNode<T, Compare>> node, const T &key)
    {
        if (!node)
            return nullptr;

        // 目标key值小于当前节点值,向左子树继续查找删除
        if (comp(key, node->key))
        {
            node->left = removeHelper(node->left, key);
            if (node->left)
                node->left->parent = node;
        }
        // 目标key值大于当前节点值,向右子树继续查找删除
        else if (comp(node->key, key))
        {
            node->right = removeHelper(node->right, key);
            if (node->right)
                node->right->parent = node;
        }
        else
        {
            // 找到要删除的节点
            if (!node->left || !node->right)
            {
                // 情况1:节点有0个或1个子节点
                node = node->left ? node->left : node->right;
                if (node)
                    node->parent = nullptr;
            }
            else
            {
                // 情况2:节点有2个子节点
                // 找到右子树的最小节点来替换当前节点
                auto temp = findMin(node->right);
                node->key = temp->key; // 用最小节点的值替换当前节点值
                // 递归删除右子树中的最小节点
                node->right = removeHelper(node->right, temp->key);
                if (node->right)
                    node->right->parent = node;
            }
        }

        // 平衡当前节点
        return balance(node);
    }

    /**
     * @brief 查找辅助函数(递归实现)
     * @param node 当前节点
     * @param key 要查找的键值
     * @return 找到的节点指针,如果未找到返回nullptr
     *
     * 递归地在AVL树中查找节点:
     * 1. 如果节点为空,返回nullptr
     * 2. 如果键值小于当前节点键值,递归查找左子树
     * 3. 如果键值大于当前节点键值,递归查找右子树
     * 4. 如果键值等于当前节点键值,返回当前节点
     */
    std::shared_ptr<AVLNode<T, Compare>> findHelper(std::shared_ptr<AVLNode<T, Compare>> node, const T &key) const
    {
        if (!node)
            return nullptr;

        if (comp(key, node->key))
        {
            return findHelper(node->left, key);
        }
        else if (comp(node->key, key))
        {
            return findHelper(node->right, key);
        }
        else
        {
            return node;
        }
    }

    /**
     * @brief 查找前驱节点
     * @param node 参考节点
     * @return 前驱节点指针,如果不存在前驱返回nullptr
     *
     * 前驱节点是比给定节点小的最大节点:
     * 1. 如果节点有左子树,前驱是左子树的最大值
     * 2. 否则,向上查找第一个右子树的祖先(即第一个比节点小的祖先)
     */
    std::shared_ptr<AVLNode<T, Compare>> predecessorHelper(std::shared_ptr<AVLNode<T, Compare>> node) const
    {
        if (!node)
            return nullptr;

        // 如果有左子树,前驱是左子树的最大值
        if (node->left)
        {
            return findMax(node->left);
        }

        // 否则向上查找第一个右子树的祖先
        auto parent = node->parent;
        while (parent && node == parent->left)
        {
            node = parent;
            parent = parent->parent;
        }
        return parent;
    }

    /**
     * @brief 查找后继节点
     * @param node 参考节点
     * @return 后继节点指针,如果不存在后继返回nullptr
     *
     * 后继节点是比给定节点大的最小节点:
     * 1. 如果节点有右子树,后继是右子树的最小值
     * 2. 否则,向上查找第一个左子树的祖先(即第一个比节点大的祖先)
     */
    std::shared_ptr<AVLNode<T, Compare>> successorHelper(std::shared_ptr<AVLNode<T, Compare>> node) const
    {
        if (!node)
            return nullptr;

        // 如果有右子树,后继是右子树的最小值
        if (node->right)
        {
            return findMin(node->right);
        }

        // 否则向上查找第一个左子树的祖先
        auto parent = node->parent;
        while (parent && node == parent->right)
        {
            node = parent;
            parent = parent->parent;
        }
        return parent;
    }

    /**
     * @brief 中序遍历辅助函数(递归实现)
     * @param node 当前节点
     * @param result 存储遍历结果的向量
     *
     * 递归地执行中序遍历:
     * 1. 递归遍历左子树
     * 2. 访问当前节点(将键值添加到结果向量)
     * 3. 递归遍历右子树
     * 中序遍历的结果是有序的
     */
    void inorderHelper(std::shared_ptr<AVLNode<T, Compare>> node, std::vector<T> &result) const
    {
        if (!node)
            return;
        inorderHelper(node->left, result);
        result.push_back(node->key);
        inorderHelper(node->right, result);
    }
};

这个完整的平面扫描算法实现包含了详细的中文注释,解释了每个类、方法和关键代码段的作用和原理。代码实现了高效的线段交点检测,通过事件驱动的扫描线方法和AVL平衡树来保证算法的高效性。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值