[WC2018]州区划分
就当那个判断一个州不合法的条件是存在欧拉回路吧
一张无向图存在欧拉回路的条件是
-
图连通
-
不存在度数为奇数的点
于是我们枚举每一个子集,可以在\(O(2^nn^2)\)的时间内判断一个集合是否能独立成为一个州
之后我们设\(dp_i\)表示选取状态为\(i\)的时候的答案,\(s_i\)为这个状态对应的城市的人口之和
于是就有
\[dp_i=\sum_{j\subseteq i} dp_j(\frac{s_{i\bigoplus j}}{s_i})^p
\]
我们可以枚举子集转移,于是复杂度\(O(3^n)\)
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define lb(x) ((x)&-(x))
const int mod=998244353;
const int maxn=(1<<21)+1;
inline int read() {
char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[2222];
int n,num,m,p,S,len;
int fa[22],sz[22],c[22],w[22],head[22];
LL s[maxn],dp[maxn],g[maxn],inv[100*22];
int f[maxn];
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
inline int find(int x) {
return (x==fa[x])?x:fa[x]=find(fa[x]);
}
inline void merge(int x,int y) {
int xx=find(x),yy=find(y);
if(xx==yy) return;
if(sz[xx]<sz[yy]) fa[xx]=yy,sz[yy]+=sz[xx];
else fa[yy]=xx,sz[xx]+=sz[yy];
}
inline int chk(int S) {
int tot=0;
for(re int i=1;i<=n;i++) c[i]=0,fa[i]=i,sz[i]=1;
for(re int i=1;i<=n;i++) {
if(!(S>>(i-1)&1)) continue;++tot;s[S]+=w[i];
for(re int j=head[i];j;j=e[j].nxt)
if(S>>(e[j].v-1)&1) c[e[j].v]++,merge(i,e[j].v);
}
int t=0;
for(re int i=1;i<=n;i++) {
if(!(S>>(i-1)&1)) continue;
if(sz[find(i)]<tot) return 1;
t|=(c[i]&1);
}
return t;
}
inline LL calc(int t,int i) {
if(p==0) return 1;
if(p==1) return s[t]*inv[s[i]]%mod;
if(p==2) return s[t]*inv[s[i]]%mod*s[t]%mod*inv[s[i]]%mod;
}
int main() {
n=read();m=read();p=read();
for(re int x,y,i=1;i<=m;i++)
x=read(),y=read(),add(x,y),add(y,x);
for(re int i=1;i<=n;i++) w[i]=read(),S+=w[i];
inv[1]=1;
for(re int i=2;i<=S;i++) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
len=(1<<n);
for(re int i=0;i<len;i++) f[i]=chk(i);
dp[0]=1;
for(re int i=1;i<len;i++) {
for(re int t=i;t;t=(t-1)&i) {
if(!f[t]) continue;
dp[i]=(dp[i]+dp[i^t]*calc(t,i)%mod)%mod;
}
}
std::cout<<dp[len-1];
return 0;
}
发现这样不是很科学啊,我们考虑改写一下我们的柿子
\[dp_i=\frac{\sum_{j|k=i,b_j+b_k=b_i}dp_js_k^p}{s_i^p}
\]
也就是
\[dp_is_i^p=\sum_{j|k=i,b_j+b_k=b_i}dp_js_k^p
\]
\(b_i\)表示\(i\)这个状态\(1\)的个数
满足这两个条件我们就可以认为\(j\)是\(i\)的一个子集,同时\(k\)是\(j\)关于\(i\)的补集了
看起来有点像或卷积啦,但是有了那个限制条件看起来又不是很好做
但是我们发现我们可以强行一下,就是对于每一种\(b_i\)分别来做
考虑当前我们求得是\(b_i=T\)的\(dp_i\)的值,从小到大枚举\(T\)的值
我们把\(s_i^p\)和\(dp_i\)按照\(b_i\)分别存好,之后\(fwt\)变换一下
我们可以枚举\(b_j\)的取值,那样相应的\(b_k=T-b_j\),之后我们利用已经得到的\(fwt\)数组直接对应求积就好了
但是我们这样求得是\(dp_i\times s_i^p\),我们求完一种\(T\)就要把这个数组给\(fwt\)回来,乘上\(s_i^p\)的逆元,再\(fwt\)回去,因为之后还需要用
由于我们强行枚举\(b_i\)的取值,所以复杂度是\(O(2^nn^2)\)
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define lb(x) ((x)&-(x))
const int mod=998244353;
const int maxn=(1<<21)+1;
inline int read() {
char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
std::vector<int> v[22];
struct E{int v,nxt;}e[2222];
int n,num,m,p,S,len;
int fa[22],sz[22],c[22],w[22],head[22];
int s[maxn],inv[100*22];
int a[22][maxn],b[22][maxn];
int f[maxn],cnt[maxn];
inline int calc(int t) {
if(!p) return 1;
if(p==1) return t;
return 1ll*t*t%mod;
}
inline int Inv(int t) {
if(!p) return 1;
if(p==1) return inv[t];
return 1ll*inv[t]*inv[t]%mod;
}
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
inline int find(int x) {
return (x==fa[x])?x:fa[x]=find(fa[x]);
}
inline void merge(int x,int y) {
int xx=find(x),yy=find(y);
if(xx==yy) return;
if(sz[xx]<sz[yy]) fa[xx]=yy,sz[yy]+=sz[xx];
else fa[yy]=xx,sz[xx]+=sz[yy];
}
inline int chk(int S) {
int tot=0;
for(re int i=1;i<=n;i++) c[i]=0,fa[i]=i,sz[i]=1;
for(re int i=1;i<=n;i++) {
if(!(S>>(i-1)&1)) continue;++tot;s[S]+=w[i];
for(re int j=head[i];j;j=e[j].nxt)
if(S>>(e[j].v-1)&1) c[e[j].v]++,merge(i,e[j].v);
}
int t=0;
for(re int i=1;i<=n;i++) {
if(!(S>>(i-1)&1)) continue;
if(sz[find(i)]<tot) return 1;
t|=(c[i]&1);
}
return t;
}
inline void Fwt(int *f,int o) {
for(re int i=2;i<=len;i<<=1)
for(re int ln=i>>1,l=0;l<len;l+=i)
for(re int x=l;x<l+ln;++x) {
f[x+ln]=(f[x+ln]+o*f[x])%mod;
if(o<0&&f[x+ln]<0) f[x+ln]=(f[ln+x]+mod)%mod;
}
}
int main() {
n=read();m=read();p=read();
for(re int x,y,i=1;i<=m;i++)
x=read(),y=read(),add(x,y),add(y,x);
for(re int i=1;i<=n;i++) w[i]=read(),S+=w[i];
inv[1]=1;
for(re int i=2;i<=S;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
len=(1<<n);
for(re int i=1;i<len;i++) f[i]=chk(i);
b[0][0]=1;
for(re int i=0;i<len;i++) {
cnt[i]=cnt[i>>1]+(i&1);
v[cnt[i]].push_back(i);
a[cnt[i]][i]=calc(s[i])*f[i];
if(cnt[i]==1) b[1][i]=1ll*calc(s[i])*f[i]*Inv(s[i])%mod;
}
for(re int i=1;i<=n;i++) Fwt(a[i],1);
Fwt(b[0],1);Fwt(b[1],1);
for(re int i=2;i<=n;i++) {
for(re int j=1;j<=i;j++)
for(re int k=0;k<len;k++)
b[i][k]=(b[i][k]+1ll*a[j][k]*b[i-j][k]%mod)%mod;
Fwt(b[i],-1);
for(re int j=0;j<v[i].size();j++) {
int k=v[i][j];
b[i][k]=(1ll*b[i][k]*Inv(s[k]))%mod;
}
if(i!=n) Fwt(b[i],1);
}
printf("%d\n",b[n][len-1]);
return 0;
}