[LOJ#2340] [WC2018] 州区划分
题目链接
洛谷题面。
LOJ题面。还是LOJ机子比较快
Solution
设\(f(s)\)表示选\(s\)这些城市的总代价,那么我们可以得到一个比较显然的\(dp\):
\[f(s)=\frac{1}{sum_s^p}\sum_{t\subset s} f(t)g(s-t)
\]
其中\(g(i)\)表示合法的\(sum_i^p\),即若不合法则为\(0\)。
这个\(dp\)是\(O(3^n)\)的,注意到这是个子集卷积的形式,我们可以考虑用\(fwt\)优化它。
那么我们可以一层一层的\(dp\),即按照\(bitcnt(s)\)的顺序从小到大计算\(f\),假设当前算到了\(bitcnt=i\),那么那些\(bitcnt<i\)的\(f\)肯定已经算完了,所以我们可以用\(fwt\)加速来算好当前层的\(f\)。
容易发现这样做的复杂度是\(O(n^22^n)\)的,足以通过此题。
洛谷机子好慢啊...要开O2才能过,loj一半时限就过了
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
const int maxn = (1<<21)+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
int cnt[maxn],sum[maxn],g[22][maxn];
int f[22][maxn],n,m,p,w[30],fa[30],d[30],head[30],tot,inv[maxn];
struct edge{int to,nxt;}e[900];
void Add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {Add(u,v),Add(v,u);}
int find(int x) {return fa[x]==x?x:fa[x]=find(fa[x]);}
int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}
int check(int s) {
if(cnt[s]==1) return 0;
for(int i=1;i<=n;i++) fa[i]=i,d[i]=0;
for(int x=1;x<=n;x++)
if(s>>(x-1)&1) {
for(int u,v,i=head[x];i;i=e[i].nxt) {
if(!(s>>(e[i].to-1)&1)) continue;
d[e[i].to]++;
if((u=find(x))!=(v=find(e[i].to))) fa[u]=v;
}
}
int rt=__builtin_ctz(s&-s)+1;
for(int i=1;i<=n;i++) if(s>>(i-1)&1) if(find(i)!=find(rt)||(d[i]&1)) return 1;
return 0;
}
void fwt(int *r,int N) {
for(int i=1;i<N;i<<=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++)
r[i+j+k]=add(r[i+j+k],r[j+k]);
}
void ifwt(int *r,int N) {
for(int i=1;i<N;i<<=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++)
r[i+j+k]=del(r[i+j+k],r[j+k]);
}
int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a);
return res;
}
int main() {
read(n),read(m),read(p);
for(int i=1,x,y;i<=m;i++) read(x),read(y),ins(x,y);
for(int i=1;i<=n;i++) read(w[i]);
int all=1<<n;
for(int i=1;i<all;i++) cnt[i]=__builtin_popcount(i);
for(int i=1;i<all;i++) sum[i]=w[__builtin_ctz(i&-i)+1]+sum[i^(i&-i)];
for(int i=1;i<all;i++) sum[i]=qpow(sum[i],p);
for(int i=0;i<all;i++) g[cnt[i]][i]=sum[i]*check(i),inv[i]=qpow(sum[i],mod-2);
for(int i=0;i<=n;i++) fwt(g[i],all);
f[0][0]=1;fwt(f[0],all);
for(int i=1;i<=n;i++) {
for(int j=0;j<i;j++)
for(int s=0;s<all;s++)
f[i][s]=add(f[i][s],mul(f[j][s],g[i-j][s]));
ifwt(f[i],all);
for(int s=0;s<all;s++) f[i][s]=cnt[s]==i?mul(f[i][s],inv[s]):0;
if(i!=n) fwt(f[i],all);
}
write(f[n][all-1]);
return 0;
}