题意:给定两数组 A A A, B B B,定义 C C C数组为: C k = m a x ( A i ∗ B j ) C_k=max(A_i*B_j) Ck=max(Ai∗Bj), i i i& j j j ≥ k \geq k ≥k,求 ∑ i = 1 n C i \sum_{i=1}^{n}C_i ∑i=1nCi m o d p modp modp
解析:我们定义
D
D
D数组为:
D
k
=
m
a
x
(
A
i
∗
B
j
,
D
k
+
1
)
D_k=max(A_i*B_j,D_{k+1})
Dk=max(Ai∗Bj,Dk+1),
i
i
i&
j
=
k
j=k
j=k,则
C
i
=
D
i
C_i=D_i
Ci=Di,我们从
k
=
n
k=n
k=n开始求
D
k
D_k
Dk,每次用四个数组
m
a
x
a
,
m
i
n
a
,
m
a
x
b
,
m
i
n
b
maxa,mina,maxb,minb
maxa,mina,maxb,minb记录和
i
i
i进行与操作后可以大于等于
i
i
i的所有下标的最大
A
A
A,最小
A
A
A,最大
B
B
B,最小
B
B
B。
接下来感性理解一下:以
m
a
x
a
maxa
maxa进行举例:
m
a
x
a
i
=
m
a
x
(
m
a
x
a
i
∣
(
1
<
<
j
)
,
A
i
)
,
i
∣
(
1
<
<
j
)
≤
n
−
1
maxa_i = max(maxa_{i|(1<<j)},A_i),{i|(1<<j) \leq n-1}
maxai=max(maxai∣(1<<j),Ai),i∣(1<<j)≤n−1,
i
∣
(
1
<
<
j
)
i|(1<<j)
i∣(1<<j)可以得到至多将1个0变为1的的最大A,而
m
a
x
a
i
∣
1
<
<
j
maxa_{i|1<<j}
maxai∣1<<j得到的是
i
∣
(
1
<
<
j
)
i|(1<<j)
i∣(1<<j)至多将1个0变为1的的最大A,以此类推,
m
a
x
a
i
maxa_i
maxai可以得到所有
i
i
i&
j
j
j
≥
i
\geq i
≥i的A。其他三数组同理。
因为
∣
A
∣
,
∣
B
∣
≤
1
e
9
|A|,|B| \leq 1e9
∣A∣,∣B∣≤1e9,为了防止负数,所以需要
m
i
n
a
,
m
i
n
b
mina,minb
mina,minb来处理负数情况,
C
i
=
m
a
x
(
C
i
+
1
,
m
a
x
a
∗
m
a
x
b
,
m
i
n
a
∗
m
i
n
b
,
m
a
x
a
∗
m
i
n
b
,
m
a
x
b
∗
m
i
n
a
)
C_i=max(C_{i+1},maxa*maxb,mina*minb,maxa*minb,maxb*mina)
Ci=max(Ci+1,maxa∗maxb,mina∗minb,maxa∗minb,maxb∗mina)
AcCode:
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#define inf 2e18
#define int long long
const int N = 1e6 + 100;
const int mod = 998244353;
int a[N], b[N], maxa[N], maxb[N], mina[N], minb[N];
inline int max(int a, int b) { return (a > b) ? a : b; }
inline int min(int a, int b) { return (a < b) ? a : b; }
inline int max(int a, int b, int c, int d, int e) {
int res = a;
if (b > res) res = b;
if (c > res) res = c;
if (d > res) res = d;
if (e > res) res = e;
return res;
}
signed main() {
int t; scanf("%lld", &t);
const int sign = 1;
while (t--) {
int n; scanf("%lld", &n);
for (int i = 0; i < n; i++) scanf("%lld", &a[i]), maxa[i] = mina[i] = a[i];
for (int i = 0; i < n; i++) scanf("%lld", &b[i]), maxb[i] = minb[i] = b[i];
for (int i = n - 1; i >= 0; i--) {
for (int j = 0; (sign << j) < n; j++) {
int temp = (sign << j);
temp = temp | i;
if (temp >= n) continue;
maxa[i] = max(maxa[i], maxa[temp]);
mina[i] = min(mina[i], mina[temp]);
maxb[i] = max(maxb[i], maxb[temp]);
minb[i] = min(minb[i], minb[temp]);
}
}
int ans = 0, rem = -inf;
for (int i = n - 1; i >= 0; i--) {
rem = max(rem, maxa[i] * maxb[i], mina[i] * minb[i], maxa[i] * minb[i], mina[i] * maxb[i]);
ans = (ans + rem) % mod;
}
printf("%lld\n", (ans % mod + mod) % mod);
}
}