[WC2018]州区划分
题目
题解
首先,作为一位 \(\text{OIer}\) 你首先需要的是能够判断
如果一个州内部存在一条起点终点相同,不经过任何不属于这个州的城市,且经过这个州的所有内部道路都恰好一次并且经过这个州的所有城市至少一次的路径(路径长度可以为 \(0\)),则称这个州是不合法的。
这句话是指我们划分出来的某个州内部不存在欧拉回路。
如何判断?只需要判断两个条件:
- 没有奇数度的点;
- 州是连通的;
满足这两个条件,这个州就是不合法的。
注意到 \(n\le 21\),那么我们可以暴力 \(\mathcal O(m\times 2^n)\) 处理出所有划分方案的合法性,以及每个划分方案的贡献。
接下来,我们想如何解决这道题。
定义 \(dp[s]\) 表示所有州划分过的城市构成 \(i\) 的二进制串,那么我们就有
\[dp[s]=\sum_{i\mid j=s,i\&j=0} dp[i]\times \text{calc}(s,s-i)
\]
这里的 \(\text{calc}(s,s-i)\) 就是我们的某一个州划分为 \(s-i\) 时的满意度,不难得到
\[\begin{aligned}
\text{calc}(s,s-i)&=\left (\frac{sum[s-i]}{sum[s]} \right )^p \\
&=\frac{(sum[s-i])^p}{(sum[s])^p} \\
\end{aligned}
\]
把它代回原来的状转,可得
\[\begin{aligned}
dp[s]=\sum_{i\mid j=s,i\&j=0} dp[i]\times\frac{(sum[s-i])^p}{(sum[s])^p}
\end{aligned}
\]
由于 \(\frac{1}{(sum[s])^p}\) 与 \(i\) 无关,考虑弄到左边去,可得
\[sum[s]\times dp[s]=\sum_{i\mid j=s,i\&j=0} dp[i]\times (sum[s-i])^p
\]
这就是标准的子集卷积了,用相同的方法进行处理之后,乘上 \(\text{inv}(sum[s])\) 就可以得到 \(dp[s]\) 了。
代码
#include<cstdio>
#include<iostream>
using namespace std;
const int MAXN=21;
const int MAXM=(MAXN*(MAXN-1))>>1;
const int MOD=998244353;
const int MAXSIZE=1<<MAXN;
inline void DWT_or(int a[],const int N,const int opt){
for(int t=2;t<=N;t<<=1)
for(int i=0,p=t>>1;i<N;i+=t)for(int j=i;j<i+p;++j){
a[j+p]+=a[j]*opt;
if(a[j+p]>=MOD)a[j+p]-=MOD;
else if(a[j+p]<0)a[j+p]+=MOD;
}
}
struct edge{int u,v;}e[MAXM+5];
int w[MAXN+5],sum[MAXSIZE+5],sz[MAXN+5][MAXSIZE+5];
int tsum[MAXSIZE+5];
int n,m,p;
inline void Init(){
scanf("%d %d %d",&n,&m,&p);
int u,v;
for(int i=1;i<=m;++i){
scanf("%d %d",&u,&v);
e[i]=edge{u-1,v-1};
}
for(int i=0;i<n;++i)scanf("%d",&w[i]);
}
/*
void print(const int s){
for(int i=n-1;i>=0;--i){
if((s>>i)&1)printf("1");
else printf("0");
}putchar('\n');
}
*/
bool vis[MAXN+5];
int d[MAXN+5];
int pre[MAXN+5],siz[MAXN+5];
int findSet(const int u){
return pre[u]==u?u:pre[u]=findSet(pre[u]);
}
inline void unionSet(const int u,const int v){
int x=findSet(u),y=findSet(v);
if(x^y){
pre[x]=y;
siz[y]+=siz[x];
}
}
inline void calc(const int s){
//printf("get a kind of situation : ");print(s);
/*
for(int i=0;i<n;i++)
cout<<vis[i]<<' ';
cout<<'\n';
*/
for(int i=0;i<n;++i){
if(vis[i])d[i]=0,pre[i]=i,siz[i]=1;
else d[i]=-1;
}
for(int i=1;i<=m;++i)if(vis[e[i].u] && vis[e[i].v])
++d[e[i].u],++d[e[i].v],unionSet(e[i].u,e[i].v);
bool flg=true;
int all=-1,cnt=0;
for(int i=0;i<n;++i)if(vis[i]){
++cnt;
all=siz[findSet(i)];
if(d[i]&1)flg=false;
sum[s]+=w[i];
if(sum[s]>=MOD)sum[s]-=MOD;
}tsum[s]=sum[s];
if(flg && all==cnt)sum[s]=0;
}
void dfs(const int now,const int s){
if(now==n)return calc(s);
dfs(now+1,s<<1);
vis[now]=true;
dfs(now+1,s<<1|1);
vis[now]=false;
}
int dp[MAXN+5][MAXSIZE+5];
int bitcnt[MAXSIZE+5];
inline int inv(int n,int m=MOD-2){
int ret=1;
for(;m>0;m>>=1){
if(m&1)ret=1ll*ret*n%MOD;
n=1ll*n*n%MOD;
}return ret;
}
signed main(){
Init();
dfs(0,0);
for(int i=1;i<=MAXSIZE;++i)bitcnt[i]=bitcnt[i>>1]+(i&1);
dp[0][0]=1;
DWT_or(dp[0],1<<n,1);
for(int i=0;i<(1<<n);++i){
if(p==0)sz[bitcnt[i]][i]=sum[i]=(sum[i]==0?0:1),tsum[i]=1;
else if(p==1)sz[bitcnt[i]][i]=sum[i];
else if(p==2)sz[bitcnt[i]][i]=1ll*sum[i]*sum[i]%MOD,sum[i]=sz[bitcnt[i]][i],tsum[i]=1ll*tsum[i]*tsum[i]%MOD/*,cout<<sz[bitcnt[i]][i]<<'\n'*/;
tsum[i]=inv(tsum[i]);
}
for(int i=0;i<=n;i++)
DWT_or(sz[i],1<<n,1);
for(int i=1;i<=n;++i){
for(int j=0;j<i;++j)
for(int k=0;k<(1<<n);++k){
dp[i][k]=dp[i][k]+1ll*dp[j][k]*sz[i-j][k]%MOD;
if(dp[i][k]>=MOD)dp[i][k]-=MOD;
}
DWT_or(dp[i],1<<n,-1);
for(int k=0;k<(1<<n);++k){
dp[i][k]=1ll*dp[i][k]*tsum[k]%MOD;
if(bitcnt[k]^i)dp[i][k]=0;
}
DWT_or(dp[i],1<<n,1);
}
DWT_or(dp[n],1<<n,-1);
printf("%d\n",dp[n][(1<<n)-1]);
return 0;
}