题目大意
给定一棵完全二叉树,要求任意一条不拐弯长度为
k
+
1
k+1
k+1的链(即从某个点不断往上跳
k
k
k次
p
a
r
e
n
t
parent
parent),满足链上所有点的和是
m
m
m的倍数。
n
≤
1
0
7
,
k
≤
10
n\le 10^7,k\le 10
n≤107,k≤10
题解
考虑两条相邻的链
a
0
,
a
2
,
.
.
.
,
a
k
a_0,a_2,...,a_k
a0,a2,...,ak与
a
1
,
a
2
,
.
.
.
,
a
k
+
1
a_1,a_2,...,a_{k+1}
a1,a2,...,ak+1,由于他们的和都是
m
m
m的倍数,则显然有
a
0
≡
a
k
+
1
(
m
o
d
m
)
a_0\equiv a_{k+1}~(mod~m)
a0≡ak+1 (mod m)。
也就是说我们最后只需要考虑编号
<
2
k
+
1
<2^{k+1}
<2k+1的那些点。但是我们需要预处理
g
[
i
]
[
j
]
g[i][j]
g[i][j]表示点
i
(
i
<
2
k
+
1
)
i(i<2^{k+1})
i(i<2k+1)及所有需要和点
i
i
i相等的点,全部改成
j
j
j所需要的最小代价。
如果暴力的话,也就是暴力枚举
j
j
j再暴力枚举所有点,复杂度是
O
(
n
m
)
O(nm)
O(nm)的。但是我们可以把权值
m
o
d
m
mod~m
mod m相同的点一起处理,具体的,令
a
l
l
[
i
]
[
j
]
all[i][j]
all[i][j]表示点
i
(
i
<
2
k
+
1
)
i(i<2^{k+1})
i(i<2k+1)及所有需要和点
i
i
i相等的点中,权值
m
o
d
m
=
j
mod~m=j
mod m=j的所有点的单次修改代价之和。
于是我们把问题转化为了
m
m
m个点,暴力枚举
m
m
m次,这样就可以在
O
(
2
k
m
2
)
O(2^km^2)
O(2km2)的时间内解决了。当然前缀和优化一下可以做到
O
(
2
k
m
)
O(2^km)
O(2km),但我懒。
然后我们就可以在只有
2
k
+
1
−
1
2^{k+1}-1
2k+1−1个节点上的树进行dp了。
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示所有叶节点到
i
i
i的权值和
m
o
d
m
mod~m
mod m均为
j
j
j的最小修改代价,暴力枚举当前点上的值转移即可,复杂度
O
(
2
k
m
2
)
O(2^km^2)
O(2km2)。
于是整道题的复杂度就是
O
(
n
+
2
k
m
2
)
O(n+2^km^2)
O(n+2km2)了。
#include <bits/stdc++.h>
namespace IOStream {
const int MAXR = 1 << 23;
char _READ_[MAXR], _PRINT_[MAXR];
int _READ_POS_, _PRINT_POS_, _READ_LEN_;
inline char readc() {
#ifndef ONLINE_JUDGE
return getchar();
#endif
if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
char c = _READ_[_READ_POS_++];
if (_READ_POS_ == MAXR) _READ_POS_ = 0;
if (_READ_POS_ > _READ_LEN_) return 0;
return c;
}
template<typename T> inline void read(T &x) {
x = 0; register int flag = 1, c;
while (((c = readc()) < '0' || c > '9') && c != '-');
if (c == '-') flag = -1; else x = c - '0';
while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
x *= flag;
}
template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
read(a), read(x...);
}
inline int reads(char *s) {
register int len = 0, c;
while (isspace(c = readc()) || !c);
s[len++] = c;
while (!isspace(c = readc()) && c) s[len++] = c;
s[len] = 0;
return len;
}
inline void ioflush() {
fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
fflush(stdout);
}
inline void printc(char c) {
_PRINT_[_PRINT_POS_++] = c;
if (_PRINT_POS_ == MAXR) ioflush();
}
inline void prints(char *s) {
for (int i = 0; s[i]; i++) printc(s[i]);
}
template<typename T> inline void print(T x, char c = '\n') {
if (x < 0) printc('-'), x = -x;
if (x) {
static char sta[20];
register int tp = 0;
for (; x; x /= 10) sta[tp++] = x % 10 + '0';
while (tp > 0) printc(sta[--tp]);
} else printc('0');
printc(c);
}
template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
print(x, ' '), print(y...);
}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
#define cls(a) memset(a, 0, sizeof(a))
const int MAXN = 10000005, MAXK = 2050, MAXM = 205;
ll f[MAXK][MAXM], g[MAXK][MAXM], all[MAXK][MAXM];
int bel[MAXN], T, n, m, K, Q;
unsigned int SA, SB, SC; int pp, A, B;
unsigned int rng61(){
SA ^= SA << 16;
SA ^= SA >> 5;
SA ^= SA << 1;
unsigned int t = SA;
SA = SB;
SB = SC;
SC ^= t ^ SA;
return SC;
}
void gen(){
cls(g), cls(all), memset(f, 0x3f, sizeof(f));
read(n, K, m, pp, SA, SB, SC, A, B);
Q = (1 << (++K)) - 1;
for (int i = 1; i <= Q; i++) bel[i] = i;
for (int i = Q + 1; i <= n; i++) bel[i] = bel[i >> K];
for (int i = 1; i <= pp; i++) {
int a, b; read(a, b);
all[bel[i]][a % m] += b;
}
for (int i = pp + 1; i <= n; i++){
int a = rng61() % A + 1;
int b = rng61() % B + 1;
all[bel[i]][a % m] += b;
}
for (int i = 1; i <= Q; i++) {
for (int j = 0; j < m; j++) {
for (int k = 0; k <= j; k++) g[i][j] += all[i][k] * (j - k);
for (int k = j + 1; k < m; k++) g[i][j] += all[i][k] * (j + m - k);
}
}
}
inline void upd(ll &x, ll y) { x = min(x, y); }
int main() {
for (read(T); T--;) {
gen();
for (int i = Q; i > 0; i--) {
int ls = i << 1, rs = i << 1 | 1;
if (ls > Q) for (int j = 0; j < m; j++) f[i][j] = g[i][j];
else if (rs > Q) {
for (int j = 0; j < m; j++)
for (int k = 0; k < m; k++)
upd(f[i][(j + k) % m], f[ls][k] + g[i][j]);
} else {
for (int j = 0; j < m; j++)
for (int k = 0; k < m; k++)
upd(f[i][(j + k) % m], f[ls][k] + f[rs][k] + g[i][j]);
}
}
printf("%lld\n", f[1][0]);
}
return 0;
}