[LOJ][WC2018]州区划分

链接

根据题意我们可以列出这样的一个转移方程

\[dp_{S} = sum_{S}^{-1} * \sum_{T \subset S} dp_{T} * (sum_{S - T})^p \]

这玩意很明显是不能直接算的 , 如果我们直接算它的话是 O(3 ^ n) 的一个复杂度 ( $ \sum_{m <= n} C_{n}^{m} * 2 ^ m = 3 ^ n $ ) , 丢到WC里面就是 $ 10 ~ 20pt $ 了

根据子集卷积的常规套路qwq , 给 $ dp $ 再加一个下标 $ i $ 表示集合的大小。
这时我们得到数组 $ dp_{i , S} $ , 然后我们再对每个 $ 0 <= i <= n $ 做一次 FMT后得到 $ dp' $ 数组 , 且 $$ dp'{i , S} = \sum [|T| = i] dp_{i , T} $$

接着我们只要枚举原式中 $ S $ 的大小 $ i $ 和 $ T $ 的大小 $ i - j $ 做一次二维 $ dp $ 就行了 ,转移方式可以看代码 qwq 。(这玩意再过几年就是模板题了)
(ps. 不要忘记每次算完一个 $ i $ 以后要 IFMT 回来乘 $ sum_{S}^{-1} $ 再 FMT 回去)


#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<queue>
#include<set>
#include<vector>
#include<cstdio>
#define lson k << 1
#define rson k << 1 | 1

using namespace std;
typedef long long ll;
const int N = 1 << 22 , Mod = 998244353 , G = 3 , INF = 0x3f3f3f3f;
double Pi = acos(-1);

struct Virt{
    double x , y;
    Virt(double _x = 0.0 , double _y = 0.0):x(_x) , y(_y){}
};

Virt operator + (Virt x , Virt y){return Virt(x.x + y.x , x.y + y.y);}
Virt operator - (Virt x , Virt y){return Virt(x.x - y.x , x.y - y.y);}
Virt operator * (Virt x , Virt y){return Virt(x.x * y.x - x.y * y.y , x.x * y.y + x.y * y.x);}

int read(){
	int x = 0 , f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -1 ; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = x * 10 + ch - '0'; ch = getchar();}
	return x * f;
}

int n , m , d[22] , num , head[22] , p , w[N] , g[N] , sum[N] , fa[N] , size[N];
ll inv[N] , temp[22][N] , dp[22][N];

struct edge{
	int pos , nx;
}e[600];

void add(int u , int v){
	e[++num].pos = v; e[num].nx = head[u]; head[u] = num;
}

ll power(ll x , int y){
	ll ans = 1;
	for(; y ; y >>= 1 , x = x * x % Mod) if(y & 1) ans = ans * x % Mod;
	return ans;
}

int find(int x){return fa[x] == x ? x : (fa[x] = find(fa[x]));}

bool check(int s){
	int cnt = 0;
	for(int i = 0 ; i < n ; i++) if((s >> i) & 1) d[i] = 0 , sum[s] += w[i] , fa[i] = i , cnt++;
	size[s] = cnt;
	for(int i = 0 ; i < n ; i++){
		if(((s >> i) & 1) == 0) continue;
		for(int j = head[i] ; j ; j = e[j].nx){
			if(((s >> e[j].pos) & 1) == 0) continue;
			d[i]++ , d[e[j].pos]++;
			if(find(i) != find(e[j].pos)){
				fa[find(i)] = find(e[j].pos);
				cnt--;
			}
		}
	}
	if(cnt > 1) return 1;
	for(int i = 0 ; i < n ; i++) if(((s >> i) & 1) && (d[i] & 1)) return 1;
	return 0;
}

ll calc(ll val){
	if(p == 0) return 1;
	else if(p == 1) return val;
	else return val * val % Mod;
}

void FMT(ll F[] , int len){
	for(int i = 1 ; i < len ; i <<= 1)
		for(int j = 0 ; j < len ; j++)
			if(i & j) F[j] = (F[j] + F[j ^ i]) % Mod;
}

void IFMT(ll F[] , int len){
	for(int i = 1 ; i < len ; i <<= 1)
		for(int j = 0 ; j < len ; j++)
			if(i & j) F[j] = (F[j] - F[j ^ i] + Mod) % Mod;
}

void Set_Convolution(){
	int L = 1 << n;
	dp[0][0] = 1; FMT(dp[0] , L);
	for(int i = 0 ; i <= n ; i++) FMT(temp[i] , L);
	for(int i = 1 ; i <= n ; i++){
		for(int j = 0 ; j < i ; j++)
			for(int k = 0 ; k < L ; k++)
				dp[i][k] = (dp[i][k] + dp[j][k] * temp[i - j][k]) % Mod;
		IFMT(dp[i] , L);
		for(int j = 1 ; j < L ; j++) dp[i][j] = dp[i][j] * inv[j] % Mod;
		if(i == n) continue;
		FMT(dp[i] , L);
	}
}

int main(){
	int u , v;
	n = read(); m = read(); p = read();
	for(int i = 1 ; i <= m ; i++){
		u = read() - 1; v = read() - 1;
		add(u , v);
	}
	for(int i = 0 ; i < n ; i++) w[i] = read();
	for(int i = 0 ; i < (1 << n) ; i++) g[i] = check(i);
	for(int i = 0 ; i < (1 << n) ; i++){
		sum[i] = calc(sum[i]);
		inv[i] = power(sum[i] , Mod - 2);
		if(g[i]) temp[size[i]][i] = sum[i];
	}
	Set_Convolution();
	printf("%lld\n",dp[n][(1 << n) - 1]);
	return 0;
}

posted @ 2018-04-30 21:49  FranceDisco  阅读(278)  评论(0编辑  收藏  举报