题目大意:
有镜子’/’ ‘\’n面。然后一个激光在(bx,by),可向四周(上下左右)发射,问可以在几个地方加一面镜子,使得激光可以从北面进入原点(0,0)
样例:
4 1 2
-2 1 \
2 1 /
2 2 \
-2 2 /
输出:
2
解释:
在(0,2)放’/’ 或(0,1)放’/’
题解:
我实在懒得自己想,(实际上是不想想那么多的细节)。然后就开始扒正解,把其中一些地方理解了一下。
我将在代码里告诉你怎么理解每个模块。
我有好多define 不习惯的可以全部替换
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<map>
#include<vector>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define M 100005
#define MAXN (1 <<17 )
#define oo 1000000000
int n,bx,by;
map<int ,vector< pair<int ,char > > >row;
map<int ,vector< pair<int ,char > > >line;
pair<pair<int,int>,char > getnext(int x,int y,int flag,char c){
if(flag==1){//up 接着激光向上走
int id=lower_bound(row[x].begin(),row[x].end(),mp(y,c))-row[x].begin();//x不变,y变大
id++;
if(id==row[x].size())return mp(mp(x,oo+1),'?');//没有点,激光向(x,oo)延伸
else return mp(mp(x,row[x][id].fi),row[x][id].se);
}else if(flag==2){//down
int id=lower_bound(row[x].begin(),row[x].end(),mp(y,c))-row[x].begin();
id--;
if(id<0)return mp(mp(x,-oo-1),'?');
else return mp(mp(x,row[x][id].fi),row[x][id].second);
}else if(flag==3){//right
int id=lower_bound(line[y].begin(),line[y].end(),mp(x,c))-line[y].begin();
id++;
if(id==line[y].size())return mp(mp(oo+1,y),'?');
else return mp(mp(line[y][id].fi,y),line[y][id].se);
}else if(flag==4){//left
int id=lower_bound(line[y].begin(),line[y].end(),mp(x,c))-line[y].begin();
id--;
if(id<0)return mp(mp(-oo-1,y),'?');
else return mp(mp(line[y][id].fi,y),line[y][id].se);
}
}
vector<pair<int,int > > getpath(int x,int y,int flag,char c){//收集点,按顺序放入vector
pair<int,int >pos(x,y);
vector<pair<int ,int > >path(1,pos);//不知道为什么这样,应该可以先定义,再push_back
for(;;){
pair<pair<int,int >,char > res= getnext(pos.fi,pos.se,flag,c);
pos=res.fi;
c=res.se;
path.pb(pos);
if(c=='/'){
flag=(flag+2)%4;
if(!flag)flag=4;//其实是枚举原先是什么,然后变成什么,找规律
}else if(c=='\\'){
flag+=(flag&1)?3:1;//同上
flag%=4;
if(!flag)flag=4;
}else break;
}return path;
}
vector <pair<int,pair<int,int> > > lin(vector<pair<int,int> >& path){//将路径上的点,还原成竖着的边
vector<pair<int,pair<int,int> > > ret;
int sz=path.size();
for(int i=0;i<sz-1;++i)
if(path[i].fi==path[i+1].fi){
int a=path[i].se,b=path[i+1].se;
ret.pb(mp(path[i].fi,a<b?mp(a,b):mp(b,a)));
}
return ret;
}
vector <pair<int,pair<int,int> > > ro(vector<pair<int,int> >& path){//还原成横着的边
vector<pair<int,pair<int,int> > > ret;
int sz=path.size();
for(int i=0;i<sz-1;++i)
if(path[i].se==path[i+1].se){
int a=path[i].fi,b=path[i+1].fi;
ret.pb(mp(path[i].se,a<b?mp(a,b):mp(b,a)));
}
return ret;
}
//标程的树状数组
//int BT[MAXN];
///* Logically executes array[x] += v. */
//void bit_add(int x, int v) {
// for(int i = x | MAXN; i < (MAXN << 1); i += i & -i) {
// BT[i ^ MAXN] += v;
// }
//}
///* Returns the sum of array[i] for 0 <= i < x */
//int bit_get(int x) {
// int ret = 0;
// for(int i = x - 1; x != 0; i &= i - 1) {
// ret += BT[i];
// if(!i) break;
// }
// return ret;
//}
//正常版的,我算了一下,M*4的大小绝对够,不晓得能不能再小的,像标程一样。。QAQ
int cnt[M<<2];
void bit_add(int x,int v){
while(x<(M<<2)){
cnt[x]+=v;
x+=x&-x;
}
}
int bit_get(int x){
int res=0;
while(x){
res+=cnt[x];
x-=x&-x;
}
return res;
}
int Count(vector<pair<int,pair<int,int > > > vs,vector<pair<int,pair<int,int> > > hs){//判断交点。。。
//前者是竖着的,后者横着
vector<int>ys;
int szvs=vs.size();
for(int i=0;i<szvs;++i){
ys.pb(vs[i].se.fi);
ys.pb(vs[i].se.se);
}
int szhs=hs.size();
for(int i=0;i<szhs;++i)ys.pb(hs[i].fi);
sort(ys.begin(),ys.end());
ys.resize(unique(ys.begin(),ys.end())-ys.begin());
for(int i=0;i<szvs;++i){
vs[i].se.fi=lower_bound(ys.begin(),ys.end(),vs[i].se.fi)-ys.begin()+1;
vs[i].se.se=lower_bound(ys.begin(),ys.end(),vs[i].se.se)-ys.begin()+1;
}
for(int i=0;i<szhs;++i)hs[i].fi=lower_bound(ys.begin(),ys.end(),hs[i].fi)-ys.begin()+1;
//以上均为离散处理。。。我们是按照x的大小扫描而来。
//记得加1 因为正常的树状数组从1开始。。
sort(vs.begin(),vs.end());
vector<pair<pair<int,int>,int> > E;
for(int i=0;i<szhs;++i){//扫描线时的增加,删除工作
E.pb(mp(mp(hs[i].se.fi,hs[i].fi),1));
E.pb(mp(mp(hs[i].se.se,hs[i].fi),-1));
}
sort(E.begin(),E.end());
int res=0;
int szE=E.size();
//标程
// memset(BT,0,sizeof(BT));
// for(int i = 0, j = 0; i <E.size(); i++) {
// int x = E[i].first.first;
// for(; j < vs.size() && vs[j].first < x; j++)
// res += bit_get(vs[j].second.second) - bit_get(vs[j].second.first + 1);
// bit_add(E[i].first.second, E[i].second);
// }
memset(cnt,0,sizeof(cnt));
for(int i = 0, j = 0; i < E.size(); i++) {
int x = E[i].first.first;
for(; j < vs.size() && vs[j].first < x; j++)
res += bit_get(vs[j].second.second-1) - bit_get(vs[j].second.first );//两端不要(这很关键)
bit_add(E[i].first.second, E[i].second);
}
return res;
}
int main(){
scanf("%d %d %d",&n,&bx,&by);
row[0].pb(mp(0,'S'));
line[0].pb(mp(0,'S'));
row[bx].pb(mp(by,'B'));
line[by].pb(mp(bx,'B'));
for(int i=1;i<=n;i++){
char a[5];
int x,y;
scanf("%d %d %s",&x,&y,a);
row[x].pb(mp(y,a[0]));
line[y].pb(mp(x,a[0]));
}
// 变得有序,才可以lower_bound
for(map<int,vector<pair<int ,char > > >::iterator it=row.begin();it!=row.end();it++)sort(it->se.begin(),it->se.end());
for(map<int,vector<pair<int ,char > > >::iterator it=line.begin();it!=line.end();it++)sort(it->se.begin(),it->se.end());
vector<pair<int,int > > old=getpath(0,0,1,'S');
//应该是可以先把old 的lin 和ro 求出来
int ans=0;
for(int i=1;i<5;i++){//枚举激光从哪个方向发射
vector<pair<int,int> > now=getpath(bx,by,i,'B');
int res=Count(lin(old),ro(now))+Count(lin(now),ro(old));
if(now[0]==now.back())ans+=res;//若回到出发点,则代表重复计算 另一方向也是一样。。
else ans+=res*2;//只是因为上面的原因 ,如果上面的除以二,可能有余数,然后就整体乘2
}
printf("%d\n",ans/2);
return 0;
}
下面的是标程
#include <iostream>
#include <vector>
#include <map>
#include <cstring>
#include <algorithm>
#include <cstdio>
using namespace std;
#define MAXN (1 << 17)
#define MAXVAL 1000000000
int dx[] = {0, 1, 0, -1};
int dy[] = {1, 0, -1, 0};
map<int, vector<pair<int, char> > > objx;
map<int, vector<pair<int, char> > > objy;
pair<pair<int, int>, char> getnext(int x, int y, int dir) {
bool vmove = dir % 2 == 0;
int a = vmove ? x : y;
int b = vmove ? y : x;
int db = vmove ? dy[dir] : dx[dir];
vector<pair<int, char> >& objs = (vmove ? objx : objy)[a];
int id = lower_bound(objs.begin(), objs.end(), make_pair(b, (char)0))
- objs.begin();
id += db;
char ch = '?';
if(id < 0) {
b = -(MAXVAL + 1);
} else if(id == objs.size()) {
b = MAXVAL + 1;
} else {
b = objs[id].first;
ch = objs[id].second;
}
return make_pair(vmove ? make_pair(a, b) : make_pair(b, a), ch);
}
vector<pair<int, int> > getpath(int x, int y, int dir) {
pair<int, int> pos(x, y);
vector<pair<int, int> > path(1, pos);
for(;;) {
pair<pair<int, int>, char> res = getnext(pos.first, pos.second,
dir);
pos = res.first;
path.push_back(pos);
if(res.second == '/') {
dir = (dir + (dir % 2 != 0 ? 3 : 1)) % 4;
} else if(res.second == '\\') {
dir = (dir + (dir % 2 == 0 ? 3 : 1)) % 4;
} else {
break;
}
}
return path;
}
vector<pair<int, pair<int, int> > >
getverts(vector<pair<int, int> >& path) {
vector<pair<int, pair<int, int> > > ret;
for(int i = 0; i + 1 < path.size(); i++) {
if(path[i].first == path[i + 1].first) {
ret.push_back(make_pair(path[i].first,
make_pair(path[i].second, path[i + 1].second)));
if(ret.back().second.second < ret.back().second.first) {
swap(ret.back().second.first, ret.back().second.second);
}
}
}
return ret;
}
vector<pair<int, pair<int, int> > >
gethorz(vector<pair<int, int> >& path) {
vector<pair<int, pair<int, int> > > ret;
for(int i = 0; i + 1 < path.size(); i++) {
if(path[i].second == path[i + 1].second) {
ret.push_back(make_pair(path[i].second,
make_pair(path[i].first, path[i + 1].first)));
if(ret.back().second.second < ret.back().second.first) {
swap(ret.back().second.first, ret.back().second.second);
}
}
}
return ret;
}
int BT[MAXN];
/* Logically executes array[x] += v. */
void bit_add(int x, int v) {
for(int i = x | MAXN; i < (MAXN << 1); i += i & -i) {
BT[i ^ MAXN] += v;
}
}
/* Returns the sum of array[i] for 0 <= i < x */
int bit_get(int x) {
int ret = 0;
for(int i = x - 1; x != 0; i &= i - 1) {
ret += BT[i];
if(!i) break;
}
return ret;
}
int countints(vector<pair<int, pair<int, int> > > vs,
vector<pair<int, pair<int, int> > > hs) {
/* Start with a coordinate compression of y values. */
vector<int> ys;
for(int i = 0; i < vs.size(); i++) {
ys.push_back(vs[i].second.first);
ys.push_back(vs[i].second.second);
}
for(int i = 0; i < hs.size(); i++) {
ys.push_back(hs[i].first);
}
sort(ys.begin(), ys.end());
ys.resize(unique(ys.begin(), ys.end()) - ys.begin());
for(int i = 0; i < vs.size(); i++) {
vs[i].second.first = lower_bound(ys.begin(), ys.end(),
vs[i].second.first) - ys.begin();
vs[i].second.second = lower_bound(ys.begin(), ys.end(),
vs[i].second.second) - ys.begin();
}
for(int i = 0; i < hs.size(); i++) {
hs[i].first = lower_bound(ys.begin(), ys.end(), hs[i].first) - ys.begin();
}
/* Sort vertical intervals by x, create event list. */
sort(vs.begin(), vs.end());
vector<pair<pair<int, int>, int> > events;
for(int i = 0; i < hs.size(); i++) {
events.push_back(make_pair(make_pair(hs[i].second.first, hs[i].first), 1));
events.push_back(make_pair(make_pair(hs[i].second.second,
hs[i].first), -1));
}
sort(events.begin(), events.end());
/* Finally, count the intersections using a Fenwick tree. */
int result = 0;
memset(BT, 0, sizeof(BT));
for(int i = 0, j = 0; i < events.size(); i++) {
int x = events[i].first.first;
for(; j < vs.size() && vs[j].first < x; j++) {
result += bit_get(vs[j].second.second) - bit_get(vs[j].second.first + 1);
}
bit_add(events[i].first.second, events[i].second);
}
return result;
}
int main() {
freopen("optics.in", "r", stdin);
freopen("optics.out", "w", stdout);
int N, bx, by;
cin >> N >> bx >> by;
objx[0].push_back(make_pair(0, 'S'));
objy[0].push_back(make_pair(0, 'S'));
objx[bx].push_back(make_pair(by, 'B'));
objy[by].push_back(make_pair(bx, 'B'));
for(int i = 0; i < N; i++) {
int x, y;
string mr;
cin >> x >> y >> mr;
objx[x].push_back(make_pair(y, mr[0]));
objy[y].push_back(make_pair(x, mr[0]));
}
for(map<int, vector<pair<int, char> > >::iterator it =
objx.begin();
it != objx.end(); ++it) {
sort(it->second.begin(), it->second.end());
}
for(map<int, vector<pair<int, char> > >::iterator it =
objy.begin();
it != objy.end(); ++it) {
sort(it->second.begin(), it->second.end());
}
int result = 0;
vector<pair<int, int> > plaser = getpath(0, 0, 0);
for(int i = 0; i < 4; i++) {
vector<pair<int, int> > pbarn = getpath(bx, by, i);
int res = countints(getverts(plaser), gethorz(pbarn)) +
countints(getverts(pbarn), gethorz(plaser));
if(pbarn[0] == pbarn.back()) {
result += res;
} else {
result += 2 * res;
}
}
cout << result / 2 << endl;
return 0;
}