原题
AT_abc306_h [ABC306Ex] Balance Scale
题目描述
$ 1,2,\ \dots,N $ の番号の付いた $ N $ 個のおもりがあります。
これから天秤を用いて $ M $ 回の重さの比較を行います。
- 比較開始前に、空文字列 $ S $ を用意する。
- $ i $ 回目の比較では、左の皿におもり $ A_i $ のみを、右の皿におもり $ B_i $ のみを乗せる。
- この際、以下の $ 3 $ 通りのうちいずれかの結果が得られる。
- おもり $ A_i $ の方がおもり $ B_i $ より重い。
- この際 $ S $ の末尾に
>
を加える。
- この際 $ S $ の末尾に
- おもり $ A_i $ とおもり $ B_i $ は同じ重さである。
- この際 $ S $ の末尾に
=
を加える。
- この際 $ S $ の末尾に
- おもり $ B_i $ の方がおもり $ A_i $ より重い。
- この際 $ S $ の末尾に
<
を加える。
- この際 $ S $ の末尾に
- おもり $ A_i $ の方がおもり $ B_i $ より重い。
- 天秤が誤った結果を出すことはない。
実験終了後、長さ $ M $ の文字列 $ S $ が得られます。
>
, =
, <
からなる長さ $ M $ の文字列は $ 3^M $ 通りありますが、そのうち実験で得られた $ S $ として考えられるものは何通りありますか?
答えは非常に大きくなることがあるので、 $ 998244353 $ で割った余りを出力してください。
输入格式
入力は以下の形式で標準入力から与えられる。
$ N $ $ M $ $ A_1 $ $ B_1 $ $ A_2 $ $ B_2 $ $ \vdots $ $ A_M $ $ B_M $
输出格式
答えを整数として出力せよ。
输入输出样例 #1
输入 #1
3 3
1 2
1 3
2 3
输出 #1
13
输入输出样例 #2
输入 #2
4 4
1 4
2 3
1 3
3 4
输出 #2
39
输入输出样例 #3
输入 #3
14 15
1 2
1 3
2 4
2 5
2 6
4 8
5 6
6 8
7 8
9 10
9 12
9 13
10 11
11 12
11 13
输出 #3
1613763
思路
题目可化为一个有 M M M 条边的有向无环图。对于每一条边,进行一个定向与合并的操作。
先考虑定向。
可以用拓扑排序对图进行分层(不需要真正进行出来)。然后就可以考虑进行一个状压 D P DP DP 了。
考虑有一个集合 S S S,它表示已经加入图中的点,并设计状态 d p S dp_S dpS。再定义一个集合 T T T,它表示准备加入图中的点,且集合中的点无边相连与 S ∩ T = ∅ S \cap T = \empty S∩T=∅。
那么如果 i ∈ S , j ∈ T i \in S,j \in T i∈S,j∈T,就可以定一条 i → j i \rightarrow j i→j 的边。
但是这样会算重复,便可以在 D P DP DP 是加一个系数, ( − 1 ) ∣ T ∣ − 1 (-1)^{|T|−1} (−1)∣T∣−1。这样就可以实现容斥。
根据上面的便可推出方程式 d p S ∪ T = ∑ S ⊆ T ( − 1 ) ∣ T ∣ − 1 d p S dp_{S \cup T} = \sum_{ S\subseteq T} (-1)^{|T|−1} dp_S dpS∪T=∑S⊆T(−1)∣T∣−1dpS。
再考虑合并。
同样考虑有一个集合 S S S。还有一个集合 T T T。如果集合 T T T 中有两个点相连,那么可以把两个点合并,得到一个新的集合。
Code
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,a[200005],b[200005],f[1<<20],s[1<<20],mod=998244353,dp[1<<20];
int find(int x){
if(x==f[x]){
return x;
}
return f[x]=find(f[x]);
}
signed main(){
cin>>n>>m;
for(int i=1;i<=m;i++){
cin>>a[i]>>b[i];
}
for(int i=1;i<1<<n;i++){
for(int j=1;j<=n;j++){
f[j]=j;
}
for(int j=1;j<=m;j++){
if(i&(1<<a[j]-1)&&i&(1<<b[j]-1)){
f[find(a[j])]=find(b[j]);
}
}
for(int j=1;j<=n;j++){
if(i&(1<<(j-1))&&f[j]==j){
s[i]++;
}
}
}
dp[0]=1;
for(int i=1;i<1<<n;i++){
for(int j=i;j;j=(j-1)&i){
if((s[j]-1)%2==0){
dp[i]+=1*dp[i^j];
}
else{
dp[i]+=-1*dp[i^j];
}
dp[i]=(dp[i]+mod)%mod;
}
}
cout<<dp[(1<<n)-1];
return 0;
}