题目大意
给一棵环套树,每个点有点权 $e_i$,求所有距离 $\le K$ 的点对数,以及这些点对的点权乘积之和。
Data Constraint
$n \le 10^5$,$K \le n$,$e_i \le 10^4$
题解
这题与 bzoj3648 寝室管理类似。先考虑没有环的情况,这时就是裸的点分治,所以我们可以先用点分治求出环上每个点所挂外向树内部的答案,再统计经过环边的答案。
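具体来说就是,先找环上任意一点作为起点,沿任一方向依次给环上的点编号(设环长为 $len$)。然后对于环上一点 $i$,只统计它与编号更小的环上点 $k<i$ 之间的点对 $(u,v)$,其中 $u$ 在 $i$ 的外向树中、深度为 $d_u$,$v$ 在 $k$ 的外向树中、深度为 $d_v$。按环上走哪一侧更短分两种情况:第一种走绕过起点的一侧,环上距离为 $len-i+k$;第二种走不绕过起点的一侧,环上距离为 $i-k$。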
现在分析第二种情况,第一种类似。
如果 $(u,v)$ 合法,那么满足 $d_u + d_v + i - k \le K$,
即 $d_v - k \le K - d_u - i$。
所以只需要以 $d_v - k$ 为下标、按环上编号依次插入建一棵可持久化线段树,查询时用两个版本相减,就可以统计答案了。
SRC
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std ;
// Capacity of every per-vertex array (100000 vertices plus a little slack).
// This used to be the macro `#define N 100000 + 10`; because the expansion was
// not parenthesised, `2*N` and `36*N` evaluated to 2*100000 + 10 and
// 36*100000 + 10. A typed constant multiplies as expected.
const int N = 100000 + 10 ;
typedef long long ll ;
// Smallest key stored in the segment tree; the tree spans the keys [LB, -LB].
const int LB = -1e5 ;
// One node of the persistent segment tree used for pairs that cross the cycle.
struct Tree {
    int Son[2] ;   // indices into T of the left / right child; 0 means "empty"
    ll tot ;       // how many entries were inserted into this node's key range
    ll sum ;       // sum of the weights inserted into this node's key range
} ;

// Node pool. Slot 0 is never handed out by NewNode, so it stays all-zero and
// serves as the empty tree.
Tree T[36*N] ;
// ---- Adjacency lists ----
int Head[N] ;      // Head[u]: first edge index in u's list, 0 = empty list
int Node[2*N] ;    // Node[p]: vertex that edge p leads to
int Next[2*N] ;    // Next[p]: next edge index in the same list, 0 = end
int tot = 1 ;      // last edge index handed out; edges start at 2 so p ^ 1 pairs
                   // the two directions created by consecutive link() calls

// ---- Cycle detection (FindCir) ----
int DFN[N] ;       // visit timestamp, 0 = not visited yet
int LOW[N] ;       // smallest timestamp reachable without re-using the entry edge
int s[N] ;         // vertex stack
int top = 0 ;      // index of the top of s
int Tim = 0 ;      // timestamp counter
int Cir[N] ;       // Cir[0] = number of cycle vertices, Cir[1..Cir[0]] = the vertices

// ---- Traversal masks ----
bool vis[N] ;      // vertices that the tree traversals must not enter
bool ori[N] ;      // copy of vis taken in main right after cycle detection

// ---- Centroid decomposition ----
int Size[N] ;      // subtree size computed by GetSize
int Maxs[N] ;      // balance score of a vertex (GetSize / GetRoot)
int Dis[N] ;       // distance from the current root
int D[N] ;         // vertices collected for the current root, 1-based
int last ;         // D[1..last] belong to child subtrees that were already processed
int tail ;         // D[1..tail] is the part of D currently in use
int Root ;         // best root found so far by GetRoot
int All ;          // vertex DIV was started from (its Size is the component size)
int Mins ;         // score of Root

// ---- Persistent segment tree ----
int Rt[N] ;        // Rt[i]: root node of version i, 0 = empty
int Cnt ;          // number of nodes of T handed out so far

// ---- Input ----
ll e[N] ;          // vertex weights
int n ;            // number of vertices
int K ;            // distance bound

// ---- Results ----
ll ans ;           // number of pairs at distance <= K
ll sum ;           // sum of weight products over those pairs
ll rettot ;        // SearchTree output: number of matching entries
ll retsum ;        // SearchTree output: sum of matching weights
// Orders two vertices by ascending distance from the current root.
bool cmp( int i , int j ) {
    return Dis[j] > Dis[i] ;
}
// Prepends a directed edge u -> v to u's adjacency list.
void link( int u , int v ) {
    const int id = ++tot ;   // fresh edge slot
    Node[id] = v ;
    Next[id] = Head[u] ;     // old list head becomes the successor
    Head[u] = id ;
}
// Low-link DFS that extracts the cycle into Cir[] and marks its vertices in vis[].
//   x  : vertex being visited
//   fe : index of the edge used to enter x (0 for a DFS root); only its reverse
//        edge fe ^ 1 is skipped, so a parallel edge still counts as a back edge
// Once a cycle has been recorded (Cir[0] != 0) every further call returns at once.
void FindCir( int x , int fe ) {
    if ( Cir[0] != 0 ) return ;

    DFN[x] = LOW[x] = ++ Tim ;
    s[++top] = x ;

    for (int p = Head[x] ; p != 0 ; p = Next[p] ) {
        if ( p == (fe ^ 1) ) continue ;
        const int to = Node[p] ;
        if ( DFN[to] == 0 ) {
            FindCir( to , p ) ;
            LOW[x] = min( LOW[x] , LOW[to] ) ;
        } else {
            LOW[x] = min( LOW[x] , DFN[to] ) ;
        }
    }

    if ( LOW[x] != DFN[x] ) return ;

    // x closes a component: everything above x on the stack belongs to it.
    // A component with more than one vertex is recorded as the cycle.
    const bool exist = ( s[top] != x ) ;
    while ( s[top] != x ) {
        const int v = s[top --] ;
        Cir[++Cir[0]] = v ;
        vis[v] = 1 ;
    }
    if ( exist ) {
        Cir[++Cir[0]] = x ;
        vis[x] = 1 ;
    }
    top -- ;   // pop x itself
}
void GetSize( int x , int Fa ) {
Size[x] = Maxs[x] = 1 ;
for (int p = Head[x] ; p ; p = Next[p] ) {
if ( Node[p] == Fa || vis[Node[p]] ) continue ;
GetSize( Node[p] , x ) ;
Size[x] += Size[Node[p]] ;
if ( Size[Node[p]] > Maxs[x] ) Maxs[x] = Size[Node[p]] ;
}
}
// Picks the centroid of the component that GetSize just measured from All.
//   x  : vertex being examined
//   Fa : vertex we came from (0 at the start)
// On return Root holds the vertex with the smallest score and Mins that score.
// Expects Size[]/Maxs[] from GetSize and Mins preset to a large value.
void GetRoot( int x, int Fa ) {
    // Score of x = largest piece left after deleting x: either its heaviest
    // child subtree (Maxs[x]) or everything outside x's subtree, which has
    // Size[All] - Size[x] vertices. The previous code subtracted Maxs[x]
    // instead of Size[x], so the selected root was not guaranteed to be a
    // centroid. The root choice only affects recursion depth, not the answer.
    Maxs[x] = max( Maxs[x] , Size[All] - Size[x] ) ;
    if ( Maxs[x] < Mins ) Mins = Maxs[x] , Root = x ;
    for (int p = Head[x] ; p ; p = Next[p] ) {
        if ( Node[p] == Fa || vis[Node[p]] ) continue ;
        GetRoot( Node[p] , x ) ;
    }
}
// Appends every vertex below x to D[] and records its distance in Dis[].
// Dis[x] must already be set; Fa and vertices marked in vis[] are not entered.
void DFS( int x , int Fa ) {
    int p = Head[x] ;
    while ( p != 0 ) {
        const int child = Node[p] ;
        p = Next[p] ;
        if ( child == Fa || vis[child] ) continue ;
        Dis[child] = Dis[x] + 1 ;
        D[++tail] = child ;
        DFS( child , x ) ;
    }
}
// Sorts D[l..r] (inclusive) by ascending Dis; does nothing for an empty range.
void Qsort( int l , int r ) {
    if ( l <= r ) {
        sort( D + l , D + r + 1 , cmp ) ;
    }
}
void Calc( int Root ) {
Qsort( 1 , last ) ;
Qsort( last + 1 , tail ) ;
ll lsum = 0 ;
int head = 1 ;
for (int j = tail ; j > last ; j -- ) {
while ( head <= last && Dis[D[head]] + Dis[D[j]] <= K ) {
lsum += e[D[head]] ;
head ++ ;
}
ans += head - 1 + (Dis[D[j]] <= K) ;
sum += e[D[j]] * lsum + e[D[j]] * e[Root] * (Dis[D[j]] <= K) ;
}
last = tail ;
}
// Centroid decomposition of the component containing x (vertices marked in vis[]
// are treated as removed). For every chosen centre, the pairs whose path passes
// through it are added to ans/sum via Calc; the centre is then marked in vis[]
// and the remaining pieces are handled recursively.
void DIV( int x ) {
    last = tail = 0 ;
    Mins = 0x7FFFFFFF ;
    Root = All = x ;
    GetSize( x , 0 ) ;
    GetRoot( x , 0 ) ;

    // Keep the centre in a local: the global Root is overwritten by the
    // recursive calls further down.
    const int centre = Root ;
    Dis[centre] = 0 ;
    vis[centre] = 1 ;

    // Gather one child subtree at a time and combine it with the earlier ones.
    for (int p = Head[centre] ; p != 0 ; p = Next[p] ) {
        const int child = Node[p] ;
        if ( vis[child] ) continue ;
        D[++tail] = child ;
        Dis[child] = 1 ;
        DFS( child , centre ) ;
        Calc( centre ) ;
    }

    // Recurse into whatever is still unmarked around the centre.
    for (int p = Head[centre] ; p != 0 ; p = Next[p] ) {
        const int child = Node[p] ;
        if ( !vis[child] ) DIV( child ) ;
    }
}
// Runs DIV on the tree hanging off each cycle vertex. The cycle vertex is
// unmarked only for the duration of its own call, so its cycle neighbours
// (still marked) keep the traversal inside that tree.
void SolveTree() {
    const int len = Cir[0] ;
    for (int idx = 1 ; idx <= len ; ++ idx ) {
        const int c = Cir[idx] ;
        vis[c] = 0 ;
        DIV( c ) ;
        vis[c] = 1 ;
    }
}
// Selects the pass GetDist is running: false = first pass, true = second pass.
bool Sign = false ;
// Returns the largest k in [0, i] for which the direct cycle distance i - k is
// strictly greater than the wrap-around distance Cir[0] - i + k; 0 if none.
int Search( int i ) {
    int lo = 0 ;
    int hi = i ;
    int best = 0 ;
    while ( lo <= hi ) {
        const int mid = (lo + hi) / 2 ;
        const int direct = i - mid ;
        const int wrapped = Cir[0] - i + mid ;
        if ( direct > wrapped ) {
            best = mid ;
            lo = mid + 1 ;
        } else {
            hi = mid - 1 ;
        }
    }
    return best ;
}
int NewNode() {
Cnt ++ ;
T[Cnt].Son[0] = T[Cnt].Son[1] = 0 ;
T[Cnt].tot = T[Cnt].sum = 0 ;
return Cnt ;
}
void Update( int v ) {
int ls = T[v].Son[0] ;
int rs = T[v].Son[1] ;
T[v].tot = T[ls].tot + T[rs].tot ;
T[v].sum = T[ls].sum + T[rs].sum ;
}
// Inserts one entry with key x and weight e below node v, which covers [l, r].
// v itself is modified in place; the child on the path to x is first copied
// into a fresh node, so nodes shared with other versions are left untouched.
void ADDTree( int v , int l , int r , int x , ll e ) {
    if ( l == x && r == x ) {
        T[v].tot += 1 ;
        T[v].sum += e ;
        return ;
    }

    // Floor of (l + r) / 2: integer division truncates toward zero, so step
    // down by one for negative odd sums.
    int mid = (l + r) / 2 ;
    if ( (l + r) < 0 && (l + r) % 2 != 0 ) -- mid ;

    const int side = ( x <= mid ) ? 0 : 1 ;
    const int fresh = NewNode() ;
    T[fresh] = T[T[v].Son[side]] ;
    T[v].Son[side] = fresh ;

    if ( side == 0 ) ADDTree( fresh , l , mid , x , e ) ;
    else ADDTree( fresh , mid + 1 , r , x , e ) ;

    Update( v ) ;
}
void SearchTree( int lv , int rv , int l , int r , int x , int y ) {
if ( l == x && r == y ) {
rettot += T[rv].tot - T[lv].tot ;
retsum += T[rv].sum - T[lv].sum ;
return ;
}
int mid = (l + r) / 2 ;
if ( l + r < 0 && (l + r) / 2 * 2 != l + r ) mid -- ;
if ( y <= mid ) SearchTree( T[lv].Son[0] , T[rv].Son[0] , l , mid , x , y ) ;
else if ( x > mid ) SearchTree( T[lv].Son[1] , T[rv].Son[1] , mid + 1 , r , x , y ) ;
else {
SearchTree( T[lv].Son[0] , T[rv].Son[0] , l , mid , x , mid ) ;
SearchTree( T[lv].Son[1] , T[rv].Son[1] , mid + 1 , r , mid + 1 , y ) ;
}
}
// Visits the tree hanging off cycle position i. For every vertex x in it (at
// depth `deep`) this first counts the partners stored for earlier cycle
// positions, then stores x itself in version i.
//   x    : current vertex          i    : cycle position whose tree is walked
//   j    : split position returned by Search(i)
//   Fa   : vertex we came from     deep : depth of x below the cycle vertex
// First pass (Sign == 0): partners sit at positions j+1 .. i-1, keys are
//   depth - position over [LB, -LB], condition key <= K - deep - i.
// Second pass (Sign == 1): partners sit at positions 1 .. j, keys are
//   depth + position over [0, -LB], condition key <= K - deep - Cir[0] + i.
void GetDist( int x , int i , int j , int Fa , int deep ) {
    const int keyLo = Sign ? 0 : LB ;                               // lowest key of this pass
    const int limit = Sign ? K - deep - Cir[0] + i : K - deep - i ; // largest admissible key
    const int key   = Sign ? deep + i : deep - i ;                  // key stored for x

    // An upper bound below the key range cannot match anything.
    if ( limit >= keyLo ) {
        const int older = Sign ? 0 : Rt[j] ;
        const int newer = Sign ? Rt[j] : Rt[i-1] ;
        rettot = retsum = 0 ;
        SearchTree( older , newer , keyLo , -LB , keyLo , limit ) ;
        ans += rettot ;
        sum += e[x] * retsum ;
    }

    // Version i starts as a copy of version i-1 the first time it is needed.
    if ( Rt[i] == 0 ) {
        Rt[i] = NewNode() ;
        T[Rt[i]] = T[Rt[i-1]] ;
    }
    ADDTree( Rt[i] , keyLo , -LB , key , e[x] ) ;

    for (int p = Head[x] ; p != 0 ; p = Next[p] ) {
        const int child = Node[p] ;
        if ( vis[child] || child == Fa ) continue ;
        GetDist( child , i , j , x , deep + 1 ) ;
    }
}
// Counts the pairs whose path uses cycle edges, in two sweeps over the cycle
// positions. The first sweep runs with Sign as it currently is (0 at start-up);
// before the second sweep the segment tree is reset and Sign is set to 1.
void SolveCircle() {
    for (int pass = 0 ; pass < 2 ; pass ++ ) {
        if ( pass == 1 ) {
            // Only slots 1..Cnt of the pool were ever written (NewNode is the
            // sole allocator), so clearing that prefix restores the all-zero
            // state. The previous memset over sizeof(T) wrote the whole pool
            // (tens of megabytes) regardless of how little was used.
            memset( T , 0 , sizeof(T[0]) * ( static_cast<size_t>(Cnt) + 1 ) ) ;
            memset( Rt , 0 , sizeof(Rt) ) ;
            Cnt = 0 ;
            Sign = 1 ;
        }
        for (int i = 1 ; i <= Cir[0] ; i ++ ) {
            const int j = Search( i ) ;
            GetDist( Cir[i] , i , j , 0 , 0 ) ;
        }
    }
}
void Solve() {
SolveTree() ;
memcpy( vis , ori , sizeof(ori) ) ;
SolveCircle() ;
}
// Reads the graph and weights from pronet.in, computes the number of pairs at
// distance <= K and the sum of their weight products, and writes both to
// pronet.out.
int main() {
    freopen( "pronet.in" , "r" , stdin ) ;
    freopen( "pronet.out" , "w" , stdout ) ;

    scanf( "%d%d" , &n , &K ) ;

    // One number per vertex: the vertex it is connected to, 0 meaning "none".
    for (int v = 1 ; v <= n ; ++ v ) {
        int partner ;
        scanf( "%d" , &partner ) ;
        if ( partner != 0 ) {
            link( partner , v ) ;
            link( v , partner ) ;
        }
    }

    for (int v = 1 ; v <= n ; ++ v ) {
        scanf( "%lld" , &e[v] ) ;
    }

    for (int v = 1 ; v <= n ; ++ v ) {
        if ( DFN[v] == 0 ) FindCir( v , 0 ) ;
    }
    memcpy( ori , vis , sizeof(vis) ) ;   // remember which vertices lie on the cycle

    if ( Cir[0] == 0 ) {
        DIV( 1 ) ;      // no cycle: plain centroid decomposition from vertex 1
    } else {
        Solve() ;
    }

    printf( "%lld %lld\n" , ans , sum ) ;
    return 0 ;
}
以上.