线段树区间修改
题目:http://poj.org/problem?id=1436
题目大意:
给出n根竖直的线段, 如果三条线段可两两水平可见,那么这三条线段被叫做一个三角形,求这样的三角形的个数。
思路:
用线段树维护区间被那个线段盖住是很好想到的,但是如何判断三条线段两两可见就不容易了,数据范围是8000,理论上O(n^3)是不可能的,但是没想到还真是这么干,枚举 三条边判断是否两两可见。利用vector记录每条线段可见的线段,利用一个二维数组记录两个线段是否可见,这样三重循环就可以了
提交情况 :
悲剧的调试了2个星期啊,一开始是由于懒, 也为了好看,就给树定义了lsize,rlize,sonl,sonr, color 5个变量, 结果是超时, 以为是算法的问题,很多人思路都是这样 的,最后想可能是空间太大导致缓存读写的时间差异吧,结果改成只记录color后真的AC了。
AC:code
/*
题目:http://poj.org/problem?id=1436
线段树的题目, 找竖直的线段中能够3个线段能两两水平可见看的个数
利用线段树记录每一段被谁覆盖。
最后居然是O(n^3)的查找方式
*/
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
#define maxn 8010
#define mid(a, b) (((a) + (b)) >> 1)
const int bit[10] = {1, 2, 4, 8, 16, 32, 64, 128, 256, 512};
vector<int> see[maxn];
bool map[maxn][maxn];
struct lineNode{
int x, low, high;
}line[maxn];
struct segmentTreeNode{
int color;
};
bool comp(const lineNode &A, const lineNode &B){
return A.x < B.x;
}
struct segmentTree{
segmentTreeNode * tree;
int root;
int l, r;
segmentTree(){tree = new segmentTreeNode[maxn * 8]; }
~segmentTree(){ delete[] tree; }
void init(int _l, int _r){ l = _l, r = _r, root = 1; }
/* 建立线段数 */
void built(int l, int r, int rt){
tree[rt].color = -2;
if(l == r) return;
built(l, mid(l, r), rt * 2);
built(mid(l, r) + 1, r, rt * 2 + 1);
}
/* 标记向下传递 */
void down(int rt, int l, int r){
if(tree[rt].color == -1) return;
if(tree[rt].color >= 0 && l != r) tree[rt * 2].color = tree[rt * 2 + 1].color = tree[rt].color;
tree[rt].color = -1;
}
/* 插入线段 */
void insert(int id, int l, int r, int rtl, int rtr, int rt){
if(l == rtl && r == rtr){
tree[rt].color = id;
return;
}
down(rt, rtl, rtr);
int md = mid(rtl, rtr);
if(l > md) insert(id, l, r, md + 1, rtr, rt * 2 + 1);
else{
if(r <= md) insert(id, l, r, rtl, md, rt * 2);
else{
insert(id, l, md, rtl, md, rt * 2);
insert(id, md + 1, r, md + 1, rtr, rt * 2 + 1);
}
}
}
/* 查找节点 */
void find(int id, int l, int r, int rtl, int rtr, int rt){
int y = tree[rt].color;
if(y == -2) return;
if(y >= 0){
if(!(map[id][y])){
see[id].push_back(y);
map[id][y] = 1;
}
return;
}
if(rtl == rtr) return;
down(rt, rtl, rtr);
int md = mid(rtl, rtr);
if(l > md) find(id, l, r, md + 1, rtr, rt * 2 + 1);
else{
if(r <= md) find(id, l, r, rtl, md, rt * 2);
else{
find(id, l, md, rtl, md, rt * 2);
find(id, md + 1, r, md + 1, rtr, rt * 2 + 1);
}
}
}
/* 计算答案 */
void slove(int n, int &ans){
int i, j, k, p, temp;
ans = 0;
for(i = 2; i < n; ++ i){
for(j = 0; j < see[i].size(); ++ j){
temp = see[i].at(j);
for(k = 0; k < see[temp].size(); ++ k){
p = see[temp].at(k);
if(map[i][p]) ans ++;
}
}
}
}
};
int main(){
int n, ans, CASE, max;
segmentTree seg;
scanf("%d", &CASE);
while(CASE --){
scanf("%d", &n);
max = 0;
for(int i = 0; i < n; ++ i){
scanf("%d %d %d", &line[i].low, &line[i].high, &line[i].x);
if(line[i].high > max) max = line[i].high;
see[i].clear();
}
sort(line, line + n, comp);
seg.init(0, max * 2);
seg.built(seg.l, seg.r, seg.root);
seg.insert(0, line[0].low * 2, line[0].high * 2, seg.l, seg.r, seg.root);
memset(map, 0, sizeof(map));
for(int i = 1; i < n; ++ i){
seg.find(i, line[i].low * 2, line[i].high * 2, seg.l, seg.r, seg.root);
seg.insert(i, line[i].low * 2, line[i].high * 2, seg.l, seg.r, seg.root);
}
seg.slove(n, ans);
printf("%d\n", ans);
}
return 0;
}