题解 CF1326F2 Wise Men (Hard Version)
CF1326F2 Wise Men (Hard Version)
题目大意
有 \(n\) 个人。给出 \(n\) 个人的「认识情况」(双向且保证合法)。
对于每个长度为 \(n\) 的排列 \(p\),可以用它生成一个长度为 \(n - 1\) 的 \(01\) 串 \(s\)。其中 \(s_i\) 为 \(1\) 当且仅当 \(p_{i}\) 和 \(p_{i + 1}\) 认识。
对于所有 \(2^{n - 1}\) 种 \(01\) 串,分别求出它可以由多少个排列生成。
数据范围:\(2\leq n\leq 18\)。
前置知识:快速沃尔什变换(FWT)
or卷积
给出两个序列\(a,b\),求一个序列\(c\),使得\(c_i=\sum_{j\operatorname{OR}k=i}a_jb_k\)。
仿照FFT的思路,我们构造两个序列\(FWT(a),FWT(b)\)(对应了FFT里的点值),使得\(FWT(c)[i]=FWT(a)[i]\cdot FWT(b)[i]\)。然后再对\(FWT(c)\)做逆变换,得到\(c\)。
FWT算法的结论是:对于or卷积,\(FWT(a)[i]=\sum_{j\operatorname{OR}i=i}a_j\)。可以发现,\(j\operatorname{OR}i=i\)就等价于“\(j\)是\(i\)的一个子集”。
值得一提的是,根据这个定义,FWT-or就相当于是做高维前缀和;FWT-or的逆变换(IFWT-or)就相当于是高维前缀和的逆变换(差分)。
在实现时,对于一个最高次项为\(2^n\)的多项式\(a\),把它分成\(a_0,a_1\)两部分,分别表示前面的\(2^{n-1}\)项和后面的\(2^{n-1}\)项,则:
这个逗号是啥意思?因为\(FWT(a)\)是一个长度为\(2^n\)的序列,因此逗号左边就是序列的前\(2^{n-1}\)项,右边就是序列的后\(2^{n-1}\)项。
而逆变换就把这个过程反过来即可,即:
and卷积
给出两个序列\(a,b\),求一个序列\(c\),使得\(c_i=\sum_{j\operatorname{AND}k=i}a_jb_k\)。
对于and卷积,\(FWT(a)[i]=\sum_{j\operatorname{AND}i=i}a_j\)。可以发现,\(j\operatorname{AND}i=i\)就等价于“\(i\)是\(j\)的一个子集”,和or卷积恰好相反。
同样可以看出,根据这个定义,FWT-and就相当于是做高维后缀和;FWT-and的逆变换(IFWT-and)就相当于是高维后缀和的逆变换(差分)。
在实现时,
同理可以做逆变换:
xor卷积
与本题无关。只是顺带提一下做法:
于是可知,逆变换为:
本题题解
我们设\(ans(s)\)表示串\(s\)的答案。直接求\(ans(s)\)不好求,考虑集合中至少包含\(s\)的答案,即所有\(s\subseteq S\)的\(ans(S)\)之和,记为\(ans'(s)\)。然后我们对\(ans'\)数组做IFWT-and卷积,就可以求出所有\(ans(s)\)。
把朋友之间的关系看做一张无向图。我们定义一条链的长度为它经过的节点数。
那么对于一个长度为\(n-1\)的01串\(s\),它代表的其实是图中的若干条链。具体来讲,如果在串\(s\)后面补上一个\(0\),那么:
- 串中每段连续的\(1\)是一条链。如果有\(x\)个\(1\),则链的长度为\(x+1\)。
- 每个\(0\)是单独的一个节点(也就是一条长度为\(1\)的链)。特别地:一段连续的\(1\)之后的第一个\(0\)除外,它这个位置上的节点已经被计入了上一条连续的\(1\)组成的链中。
按照上述规则,不难发现,所有链的长度之和恰好为\(n\)。而对于一个01串\(s\)来说,\(ans'(s)\)只取决于它划分出的链的长度的可重集。例如:\(ans'(111011)=ans'(110111)\),因为它们的这个可重集都是\(\{1,3,4\}\)。
又因为所有链的长度之和恰好为\(n\),故本质不同的可重集数量只有\(P(n)\)种,其中\(P(n)\)表示\(n\)的划分数。\(P(18)=385\)。于是我们只需要对这\(P(n)\)个“链的长度的可重集”,分别求答案。
设\(f_{i,mask}\)表示对于一个大小为\(i\)的节点集合\(mask\),图中有多少条链,恰好经过\(mask\)中的这些节点。
如果我们求出了\(f_{i,mask}\)数组,那么对于一个“链的长度的可重集”\(T\),它的答案就是\(\displaystyle\sum_{m_1,\dots,m_{|T|}}\ \prod_{i=1}^{|T|}f_{len(T_i),m_i}\)。其中\(len(T_i)\)表示\(T\)中第\(i\)条链的长度。前面的\(\sum\)枚举的是一个\(m_i\)数组,表示对每个\(i\)各取一个大小为\(len(T_i)\)的点集\(m_i\),要求这些\(m_i\)的并为\([1,n]\)且互相不交。容易发现只要并为\([1,n]\)就必然互相不交,因为它们的\(len(T_i)\)之和为\(n\)。所以我们可以做一个FMT-or卷积。把所有\(f_{len(T_i)}\)这\(|T|\)个序列卷起来。卷积结果的\(2^n-1\)项前的系数即为\(T\)这个可重集的答案。
现在最后的问题是\(f_{i,mask}\)数组怎么求。可以做简单的状压DP。设\(dp[mask][j]\)表示经过了\(mask\)中的这些节点,最后一个经过的节点为\(j\)的链的数量。转移时枚举一个不在\(mask\)中切与\(j\)有连边的点作为下一个点即可。则\(f_{i,mask}=\sum_{j=1}^{n}dp[mask][j]\)。
DP求\(f_{i,mask}\)的复杂度为\(O(2^nn^2)\),之后枚举每个可重集,求答案的复杂度为\(O(P(n)2^nn)\),其中\(P(18)=385\)。
参考代码
//problem:CF1326F2
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
if(S==T){
T=(S=buf)+fread(buf,1,MAXN,stdin);
if(S==T)return EOF;
}
return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#endif
inline int read(){
int f=1,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline ll readll(){
ll f=1,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline int readbit(){
char ch=getchar();
while(ch<'0'||ch>'1')ch=getchar();
return ch-'0';
}
/* ------ by:duyi ------ */ // myt天下第一
const int MAXN=18;
int n,a[MAXN+5][MAXN+5];
ll dp[1<<MAXN][MAXN+5],f[MAXN+5][1<<MAXN],h[400],ans[1<<MAXN];
int cnt;
map<vector<int>,int>mp;
vector<int>vec[400],tmp;
void dfs(int cur,int lst){
if(cur==n+1){
mp[tmp]=++cnt;
vec[cnt]=tmp;
return;
}
if(n-cur+1<lst)return;
for(int i=lst;cur+i-1<=n;++i){
tmp.pb(i);
dfs(cur+i,i);
tmp.pop_back();
}
}
int bitcnt(uint x){
int res=0;
for(int j=0;j<=31;++j)res+=((x>>j)&1u);
return res;
}
void fwt_or(ll *f,uint n,int flag){
// FWT_or(A)[i] = \sum_{j|i=i} A[j]
//即:j是i的一个子集
for(uint i=1;i<n;i<<=1){
for(uint j=0;j<n;j+=(i<<1)){
for(uint k=0;k<i;++k){
f[i+j+k]+=f[j+k]*flag;
}
}
}
}
void fwt_and(ll *f,uint n,int flag){
// FWT_and(A)[i] = \sum_{j&i=i} A[j]
//即:i是j的一个子集
for(uint i=1;i<n;i<<=1){
for(uint j=0;j<n;j+=(i<<1)){
for(uint k=0;k<i;++k){
f[j+k]+=f[i+j+k]*flag;
}
}
}
}
int main() {
n=read();
dfs(1,1);//搜出所有划分数
for(int i=1;i<=n;++i)for(int j=1;j<=n;++j)a[i][j]=readbit();
//dp[mask][j] 表示经过了mask中这些点,以j结尾的链有多少.
//用来求出 f[i][mask] 表示经过了大小为i的集合mask的链的数量
for(int i=1;i<=n;++i)dp[1u<<(i-1)][i]=1;
for(uint i=1;i<(1u<<n);++i){
for(int j=1;j<=n;++j)if((i>>(j-1))&1u){
for(int k=1;k<=n;++k)if(a[j][k]&&!((i>>(k-1))&1u)){
dp[i|(1u<<(k-1))][k]+=dp[i][j];
}
}
int t=bitcnt(i);
for(int j=1;j<=n;++j)f[t][i]+=dp[i][j];
}
for(int i=1;i<=n;++i)fwt_or(f[i],1u<<n,1);
static ll IE[1<<MAXN],tmp[1<<MAXN];
IE[0]=1;
fwt_or(IE,1u<<n,1);
for(int i=1;i<=cnt;++i){
for(uint j=0;j<(1u<<n);++j)tmp[j]=IE[j];
for(uint j=0;j<vec[i].size();++j){
for(uint k=0;k<(1u<<n);++k)tmp[k]=(ll)tmp[k]*f[vec[i][j]][k];
}
fwt_or(tmp,1u<<n,-1);
h[i]=tmp[(1u<<n)-1];
}
for(uint i=0;i<(1u<<(n-1));++i){
vector<int>tmp;
for(int j=0;j<=n-1;){
int jj=j+1;
while(jj-1<=n-2 && ((i>>(jj-1))&1u))++jj;
tmp.pb(jj-j);
j=jj;
}
sort(tmp.begin(),tmp.end());
//for(uint j=0;j<tmp.size();++j)cout<<tmp[j]<<" ";cout<<endl;
assert(mp.count(tmp));
ans[i]=h[mp[tmp]];
}
fwt_and(ans,(1u<<(n-1)),-1);
for(uint i=0;i<(1u<<(n-1));++i)printf("%lld ",ans[i]);
return 0;
}