Comet OJ - Contest #7 C 临时翻出来的题(容斥+状压)
题意
https://www.cometoj.com/contest/52/problem/C?problem_id=2416
思路
这里提供一种容斥的写法(?好像网上没看到这种写法)
题目要求编号为 \(i\) 的节点不能放在 \(p_i\) 位置,那我们不妨假设没有这些条件,然后再用二进制容斥的方法减去不满足条件的情况(即固定某些 \(i\) 在 \(p_i\) 上,这样会好考虑问题一点)。
然后我们面临的问题就是,计算 \(A\)(二进制)这些数不能选,\(B\)(二进制)这些位置不能填的方案数。我们枚举两个数,计算它们的对答案的贡献。设小的数为 \(i\) ,大的数为 \(j\) ,下面分四种情况讨论:
-
\(i,j\) 的位置均已确定
若 \(p_i>p_j\) ,则造成 \((j-i)(p_i-p_j)(n-cnt_A)!\) 的贡献( \(cnt_A\) 为 \(A\) 的二进制中 \(1\) 的个数)。
-
\(j\) 的位置已经确定
\(i\) 能选的位置为 \(\setminus A\)( \(A\) 对全集取补),尝试让 \(\setminus A\) 与 \(j\) 匹配,只有 \(\setminus A\) 中大于 \(j\) 的数才能得配,匹配结果是这个数减 \(j\) 。设 \(\setminus A\) 所有数的总匹配结果为 \(t\) ,它对答案的贡献即为 \((j-i)t(n-cnt_A-1)!\)
-
\(i\) 的位置已经确定
与上一种其实没什么区别,只是 \(i\) 去匹配 \(\setminus A\) 罢了。
-
\(i,j\) 的位置均未确定
这次是一个集合匹配一个集合,在外面通过情况二来预处理。假设总匹配结果为 \(t\) ,对答案的贡献即为 \((j-i)t(n-cnt_A-2)\)
要注意固定数时如果有两个数共用了一个位置(即 \(p\) 相等),那这个贡献就直接为 \(0\) 了。
代码
#include<bits/stdc++.h>
#define FOR(i,x,y) for(int i=(x),i##END=(y);i<=i##END;++i)
#define DOR(i,x,y) for(int i=(x),i##END=(y);i>=i##END;--i)
#define lowbit(x) ((x)&-(x))
template<typename T,typename _T>inline bool chk_min(T &x,const _T y){return y<x?x=y,1:0;}
template<typename T,typename _T>inline bool chk_max(T &x,const _T y){return x<y?x=y,1:0;}
typedef long long ll;
int cnt[(1<<16)+5],bin[(1<<16)+5],sum[(1<<16)+5];
int Sum[(1<<16)+5];
ll fac[18];
int n,p[18];
ll f1(int S,int pos)
{
S=(S|((1<<(pos+1))-1))^((1<<(pos+1))-1);
return sum[S]-cnt[S]*pos;
}
ll f2(int pos,int S)
{
S=S&((1<<pos)-1);
return cnt[S]*pos-sum[S];
}
ll solve(int A,int B)
{
ll ans=0;
FOR(i,0,n-1)FOR(j,i+1,n-1)
{
if((A>>i&1)&&(A>>j&1))
{
if(p[i]>p[j])
ans+=(j-i)*(p[i]-p[j])*fac[n-cnt[A]];
}
else if(!(A>>i&1)&&(A>>j&1))
{
ans+=(j-i)*f1(((1<<n)-1)^B,p[j])*fac[n-cnt[A]-1];
}
else if((A>>i&1)&&!(A>>j&1))
{
ans+=(j-i)*f2(p[i],((1<<n)-1)^B)*fac[n-cnt[A]-1];
}
else ans+=(j-i)*Sum[((1<<n)-1)^B]*fac[n-cnt[A]-2];
}
return ans;
}
int main()
{
fac[0]=1;FOR(i,1,16)fac[i]=fac[i-1]*i;
FOR(i,1,1<<16)cnt[i]=cnt[i^lowbit(i)]+1;
FOR(i,2,1<<16)bin[i]=bin[i>>1]+1;
FOR(i,1,1<<16)sum[i]=sum[i^lowbit(i)]+bin[lowbit(i)];
FOR(i,1,1<<16)Sum[i]=Sum[i^lowbit(i)]+f1(i^lowbit(i),bin[lowbit(i)]);
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
FOR(i,0,n-1)scanf("%d",&p[i]),p[i]--;
ll ans=0;
FOR(i,0,(1<<n)-1)
{
int S=0;
bool flg=1;
FOR(j,0,n-1)if(i>>j&1)
{
if(S&(1<<p[j]))
{
flg=0;
break;
}
S|=1<<p[j];
}
if(!flg)continue;
if(cnt[i]&1)ans-=solve(i,S);
else ans+=solve(i,S);
}
printf("%lld\n",ans);
}
return 0;
}