题目链接
原题链接
由于异或的优先级最高,我们可以将表达式划分为若干个异或和的加减运算,这里我们仅考虑划分为 s + x s + x s+x,其中 x x x 表示最后一段异或和, s s s 表示除了 x x x 的前缀和。
考虑 d p \rm{dp} dp 维护三个信息:
- d p s [ i ] dps[i] dps[i] 表示 [ 0 , i ) [0, i) [0,i) 里的所有添加方案的前缀和之和;
- d p x [ i ] dpx[i] dpx[i] 表示 [ 0 , i ] [0, i] [0,i] 里所有添加方案的后缀异或和之和;
- c n t [ i ] cnt[i] cnt[i] 表示向 [ 0 , i ] [0, i] [0,i] 添加运算符的合法方案数。
考虑想要在 i i i 之后插入运算符的状态转移。
- 若为
+
+
+,则应该有
- d p s [ i + 1 ] = d p s [ i ] + d p x [ i ] dps[i + 1] = dps[i] + dpx[i] dps[i+1]=dps[i]+dpx[i],表示将上文中的 s + x s + x s+x 加入总和;
- d p x [ i + 1 ] = c n t [ i ] × A [ i + 1 ] dpx[i + 1] = cnt[i] \times A[i + 1] dpx[i+1]=cnt[i]×A[i+1],表示后缀异或和换成单独的一个 A [ i + 1 ] A[i + 1] A[i+1],不过要乘上方案数 c n t [ i ] cnt[i] cnt[i]。
- 若为
−
-
−,则应该有
- d p s [ i + 1 ] = d p s [ i ] + d p x [ i ] dps[i + 1] = dps[i] + dpx[i] dps[i+1]=dps[i]+dpx[i],表示将上文中的 s + x s + x s+x 加入总和;
- d p x [ i + 1 ] = c n t [ i ] × − A [ i + 1 ] dpx[i + 1] = cnt[i] \times -A[i + 1] dpx[i+1]=cnt[i]×−A[i+1],表示后缀异或和换成单独的一个 A [ i + 1 ] A[i + 1] A[i+1],不过要乘上方案数 c n t [ i ] cnt[i] cnt[i]。
- 若为
⊕
\oplus
⊕,则应该有
- d p s [ i + 1 ] = d p s [ i ] dps[i + 1] = dps[i] dps[i+1]=dps[i],表示将上文的 s s s 加入总和;
- d p x [ i + 1 ] = d p x [ i ] ⊕ A [ i + 1 ] dpx[i + 1] = dpx[i] \oplus A[i + 1] dpx[i+1]=dpx[i]⊕A[i+1],表示所有后缀异或和都需要再异或一个 A [ i + 1 ] A[i + 1] A[i+1]。
三者合起来,发现 c n t [ i ] cnt[i] cnt[i] 抵消,不用维护,得到最终的转移方程。
- d p s [ i + 1 ] = 2 ( d p s [ i ] + d p x [ i ] ) + d p s [ i ] , dps[i + 1] = 2(dps[i] + dpx[i]) + dps[i], dps[i+1]=2(dps[i]+dpx[i])+dps[i],
- d p x [ i + 1 ] = d p x [ i ] ⊕ A [ i + 1 ] . dpx[i + 1] = dpx[i] \oplus A[i + 1]. dpx[i+1]=dpx[i]⊕A[i+1].
初始状态。
- d p s [ 0 ] = 0 , dps[0] = 0, dps[0]=0,
- d p x [ 0 ] = A [ 0 ] . dpx[0] = A[0]. dpx[0]=A[0].
最终答案为 d p s [ N − 1 ] + d p x [ N − 1 ] dps[N - 1] + dpx[N - 1] dps[N−1]+dpx[N−1]。
时间复杂度 O ( n ) O(n) O(n)
C++ Code
#include <bits/stdc++.h>
using i64 = long long;
template<class T>
std::istream &operator>>(std::istream &is, std::vector<T> &v) {
for (auto &x: v) {
is >> x;
}
return is;
}
template<class T>
constexpr T power(T a, i64 b) {
T res = 1;
for (; b; b /= 2, a *= a) {
if (b % 2) {
res *= a;
}
}
return res;
}
template<int P>
struct MInt {
int x;
constexpr MInt() : x{} {}
constexpr MInt(i64 x) : x{norm(x % getMod())} {}
static int Mod;
constexpr static int getMod() {
if (P > 0) {
return P;
} else {
return Mod;
}
}
constexpr static void setMod(int Mod_) {
Mod = Mod_;
}
constexpr int norm(int x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr int val() const {
return x;
}
explicit constexpr operator int() const {
return x;
}
constexpr MInt operator-() const {
MInt res;
res.x = norm(getMod() - x);
return res;
}
constexpr MInt inv() const {
assert(x != 0);
return power(*this, getMod() - 2);
}
constexpr MInt &operator*=(MInt rhs) & {
x = 1LL * x * rhs.x % getMod();
return *this;
}
constexpr MInt &operator+=(MInt rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MInt &operator-=(MInt rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MInt &operator/=(MInt rhs) & {
return *this *= rhs.inv();
}
friend constexpr MInt operator*(MInt lhs, MInt rhs) {
MInt res = lhs;
res *= rhs;
return res;
}
friend constexpr MInt operator+(MInt lhs, MInt rhs) {
MInt res = lhs;
res += rhs;
return res;
}
friend constexpr MInt operator-(MInt lhs, MInt rhs) {
MInt res = lhs;
res -= rhs;
return res;
}
friend constexpr MInt operator/(MInt lhs, MInt rhs) {
MInt res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
i64 v;
is >> v;
a = MInt(v);
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
return os << a.val();
}
friend constexpr bool operator==(MInt lhs, MInt rhs) {
return lhs.val() == rhs.val();
}
friend constexpr bool operator!=(MInt lhs, MInt rhs) {
return lhs.val() != rhs.val();
}
};
template<>
int MInt<0>::Mod = 998244353;
template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
constexpr int P = 1000000007;
using Z = MInt<P>;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int N;
std::cin >> N;
std::vector<int> A(N);
std::cin >> A;
std::vector<Z> dps(N);
std::vector<int> dpx(N);
dpx[0] = A[0];
for (int i = 0; i + 1 < N; i++) {
dps[i + 1] += 2 * (dps[i] + dpx[i]) + dps[i];
dpx[i + 1] = dpx[i] ^ A[i + 1];
}
std::cout << dps.back() + dpx.back() << "\n";
return 0;
}