题意
给N(10^5)个点,每个点有个权值A[i],对于所有非空子集,该子集的值就是其前K大的数的和,不够K个数就是所有数的和。求对于K=1~N,每个K值对应的子集值的和。
思路
首先枚举所有子集,那么和点的顺序无关,可以先排序方便组合数学的计算,这个直觉做法很重要,这里就按sort默认从小到大排序。然后不难想到按照贡献法,可以计算每个值*它在多少个子集中被计算。一个点被计算,它要是前K大,他所出现的子集,比他小的点可以随便选,比他大的点只能选择不超过K个。那么考虑第i个点它左边(小于)的点随意选有2^(i-1)种可能,右边(大于)可以选择1~k-1个,是组合数求和,然后左右乘起来就是被计算的次数。可以得到基本公式
ans[k]=∑(i=0)^(n-1)▒A[i] ⋅2^ⅈ⋅∑(j=0)^(k-1)▒C_(n-1-ⅈ)^j
并定义j>n-1-i时组合数值为0。J部分收到i的影响无法化成卷积形式,那么ans[k]就不是卷积的直接结果。交换求和次序,并展开组合数(别忘了组合数为0的情况)可得
ans[k]=∑(j=0)^(k-1)▒∑(i:0)^(n-1-j)▒A[i] ⋅2^ⅈ⋅(n-1-ⅈ)!/j!(n-1-ⅈ-j)!
发现ans[k]是前缀和的形式,而i和n的关系有点卷积了。我们想要c[i]=sum{a[i]*b[n-i]}的形式,那么不妨设f[j]表示去掉前缀和部分的式子,令t=n-1-j,即j=n-1-t并单独处理j!可得卷积形式
(n-1-t)!f[n-1-t]=∑_(i=0)^t▒A[i] ⋅2^ⅈ⋅(n-1-ⅈ)!÷(t-ⅈ)!
用除号隔开就是两个序列,然后卷积结果除以(n-1-t)!再求后缀和就是答案了。
卷积和是有快速算法的,这道题有模数998244353=7*17*2^23+1,是一个快速数论变换的专用模,那首选没有精度问题的快速数论变换NTT。这个原理挺复杂的,以后有时间在总结,做此题倒是整理了一个自觉非常完美的模板。
AC代码 C++
#include <stdio.h>
#include <algorithm>
using namespace std;
#define NTT_MOD 998244353 //模数998244353=7*17*2^23+1
#define NTT_G 3 //与模数对应的原根值
#define NTT_MAXW 30 //原根数组的最大长度,取log2(MOD)即可
inline int ntt_pow(long long a, int n = NTT_MOD - 2)
{ //快速幂,默认求逆元
long long res = 1;
do
{
if (n & 1)
res = res * a % NTT_MOD;
a = a * a % NTT_MOD;
} while (n >>= 1);
return (int)res;
}
void ntt(int* x, int n, bool inv = false)
{ //快速数论变换,inv表示是否进行反变换,默认false表示正变换,变换长度应为2的倍数,若线性卷积则大于序列长两倍
int i, j, k, d, w, cur, tmp;
static int g[NTT_MAXW], ng[NTT_MAXW]; //原根及其逆
if (!*ng) //bupt_boning专用初始化
for (i = 1, k = NTT_MOD - 1 >> 1, *ng = ntt_pow(NTT_G); !(k & 1); i++, k >>= 1)
{
g[i] = ntt_pow(NTT_G, k);
ng[i] = ntt_pow(*ng, k);
}
for (i = n >> 1, j = 1; j < n; j++, i ^= k)
{ //调整位置
if (i < j)
swap(x[i], x[j]);
for (k = n >> 1; i&k; k >>= 1)
i ^= k;
}
for (d = 2, k = 1; d <= n; d <<= 1, k++) //蝶形运算
for (i = 0,w = inv ? ng[k] : g[k]; i < n; i += d)
for (j = i, cur = 1; j < i + (d >> 1); j++)
{
tmp = (long long)x[j + (d >> 1)] * cur % NTT_MOD;
x[j + (d >> 1)] = (x[j] - tmp + NTT_MOD) % NTT_MOD;
x[j] = (x[j] + tmp) % NTT_MOD;
cur = (long long)cur * w % NTT_MOD;
}
if (inv) //对逆变换的乘以逆元处理
for (i = 0, w = ntt_pow(n); i < n; i++)
x[i] = (long long)x[i] * w % NTT_MOD;
}
#define MAXN 100005
int a[MAXN << 2];
int b[MAXN << 2];
long long ans[MAXN];
long long jc[MAXN] = { 1 };
long long ic[MAXN] = { 1 };
long long p2[MAXN] = { 1 };
int main()
{
int t, n, i, len;
scanf("%d", &t);
for (i = 1; i<MAXN; i++)
{
jc[i] = jc[i - 1] * i % NTT_MOD;
ic[i] = ntt_pow(jc[i], NTT_MOD - 2);
p2[i] = p2[i - 1] * 2 % NTT_MOD;
}
while (t-- && scanf("%d", &n) > 0)
{
for (i = 0; i<n; i++)
scanf("%d", a + i);
sort(a, a + n);
for (len = 1; len<n << 1; len <<= 1);
for (i = 0; i<n; i++)
{
a[i] = p2[i] * a[i] % NTT_MOD * jc[n - 1 - i] % NTT_MOD;
b[i] = ic[i];
}
do a[i] = b[i] = 0;
while (++i < len);
ntt(a, len, 0);
ntt(b, len, 0);
for (i = 0; i<len; i++)
b[i] = (long long)a[i] * b[i] % NTT_MOD;
ntt(b, len, true);
for (i = n, ans[n] = 0; i--;)
printf("%d ", ans[i] = (ans[i + 1] + (long long)b[i] * ic[n - 1 - i] % NTT_MOD) % NTT_MOD);
putchar('\n');
}
return 0;
}