其实题目要求的就是一颗子树上把所有的权值从小到大排序,每个值对应第k大,把k和值乘在一起加和,就是这个子树对应的根的答案。
说到底就是有多少个数大于一个值,这个值就要多加几次。
然后我就死在这个理解上,因为要考虑值相同的情况,这句话没错,但是容易忽略值相同的情况。
对于每个点我们建一颗权值线段树,表示以这个点为根的子树上的权值分部情况。
然后右子树上有多少值大于左子树上的值,以及左子树上有多少值大于右子树上的值就很好求了,线段树合并的时候,左子树对应的线段树的左儿子的值一定小于右子树的线段树的右儿子,反之亦然,维护一下num代表个数,sum代表加和就好了。
这样的话,在左右子树都已经求出结果的情况下,我们只需要求出跨越左右子树需要加的值是多少,而这个过程就是上面的操作。
然后还需要注意下,这时候都没有考虑根节点,插入根节点要统计下子树中有多少值大于自己,所以要最后插入。
最后,就是前面提到的,需要注意下值重复的情况,对于左右子树值相同的情况我们在权值线段树里没有办法通过左右区间得到,然而实际情况是需要我们对相等的值进行一个随机排名的,所以在插入的时候,需要加上跟自己值相同的个数,在合并的时候,当递归到叶子节点,需要让左子树的sum*右子树的num,反之也可以。
思路复杂了,一开始对root数组过量初始化,然后越界hdu莫名tle,后来没考虑值相同的情况也wa了很久。
代码:
#include <bits/stdc++.h>
#define MID int mid=(l+r)>>1;
#define ps push_back
#define LL long long
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
const int maxn=1e5+5;
LL ans[maxn], ans1;
int tree[maxn][3], root[maxn];
vector<int>edg[maxn];
LL sum[maxn*50], num[maxn*50];
int lson[maxn*50], rson[maxn*50];
vector<LL>val;
LL V[maxn];
int cnt, n;
int a[maxn];
void insert(int &root, int l, int r, int x, int k)
{
if(root==0)root=++cnt;
if(l==r)
{
k+=num[root];
ans1+=V[x]*((long long)k+1LL);
sum[root]+=V[x];
num[root]++;
return;
}
MID
if(x<=mid)insert(lson[root], l, mid, x, k+num[rson[root]]);
else
{
ans1+=sum[lson[root]];
insert(rson[root], mid+1, r, x, k);
}
sum[root]=sum[lson[root]]+sum[rson[root]];
num[root]=num[lson[root]]+num[rson[root]];
return;
}
int merg(int l, int r)
{
if(l==0 || r==0)return l|r;
ans1+=sum[lson[l]]*num[rson[r]];
ans1+=sum[lson[r]]*num[rson[l]];
if(lson[l]==0 && rson[l]==0 && lson[r]==0 && rson[r]==0)
{
// printf("%lld %lld\n", sum[l], num[r]);
ans1+=sum[l]*num[r];
}
sum[l]+=sum[r];
num[l]+=num[r];
lson[l]=merg(lson[l], lson[r]);
rson[l]=merg(rson[l], rson[r]);
return l;
}
int len;
void dfs(int x, int fa)
{
int i, to;
ans[x]=0;
tree[x][0]=tree[x][1]=-1;
for(i=0; i<(int)edg[x].size(); i++)
{
to=edg[x][i];
if(to==fa)continue;
if(tree[x][0]==-1)tree[x][0]=to;
else tree[x][1]=to;
dfs(to, x);
ans[x]+=ans[to];
}
ans1=0;
if(tree[x][0]!=-1 && tree[x][1]!=-1)
{
root[x]=merg(root[tree[x][0]], root[tree[x][1]]);
}
else if(tree[x][0]!=-1)root[x]=root[tree[x][0]];
insert(root[x], 1, len, a[x], 0);
ans[x]+=ans1;
return;
}
inline void Read(int &Num)
{
char c = getchar();
while (c < '0' || c > '9') c = getchar();
Num = c - '0'; c = getchar();
while (c >= '0' && c <= '9')
{
Num = Num * 10 + c - '0';
c = getchar();
}
}
void init(int len)
{
for(int i=0; i<=len; i++)
{
root[i]=V[i]=0;
edg[i].clear();
}
val.clear();
}
void init2(int len)
{
for(int i=0; i<=len; i++)
{
sum[i]=num[i]=rson[i]=lson[i]=0;
}
}
int xx[maxn], yy[maxn];
int main()
{
// freopen("C:\\Users\\johsnow\\Desktop\\1001.in", "r", stdin);
// freopen("C:\\Users\\johsnow\\Desktop\\ans.out", "w", stdout);
int x, y, t, i, j;
cin>>t;
while(t--)
{
cnt=0;
Read(n);
init(n);
for(i=1; i<=n; i++)
{
Read(a[i]);
val.ps(a[i]);
}
sort(val.begin(), val.end());
val.erase(unique(val.begin(), val.end()), val.end());
len=val.size();
int pos;
for(i=1; i<=n; i++)
{
pos=lower_bound(val.begin(), val.end(), a[i])-val.begin()+1;
V[pos]=a[i];
a[i]=pos;
}
for(i=1; i<n; i++)
{
Read(x), Read(y);
xx[i]=x, yy[i]=y;
edg[x].ps(y);
edg[y].ps(x);
}
dfs(1, 0);
for(i=1; i<=n; i++)
{
printf("%lld ", ans[i]);
tree[i][0]=tree[i][1]=-1;
}
init2(cnt);
printf("\n");
}
}