[AGC045D] Lamps and Buttons 题解
[AGC045D] Lamps and Buttons 题解
首先,由于排列生成随机,所以最优决策就是不决策(反正你也不知道),也就是,让 Snuke 从左往右依次按。
那么,什么情况下 Snuke 会输呢?我们可以把每个 \(p_i\) 向 \(i\) 连边,我们发现,如果灭着的灯里面存在自环,也就是只能自己打开自己的,或者在打开所有灯之前,把亮着的灯中有自环的灭掉了,都会输。那么,我们可以考虑枚举 \([1, A]\) 中第一个自环的位置 \(t\),(还有可能没有自环,那么就设为 \(A+1\) ),获胜的条件就变为,在按到 \(t\) 之前,能够把其它的灯全打开。而这等效于,对于 $ \forall x \in [A+1, n] , \exists i \in [1, t-1] $,使得 \(x\) 与 \(i\) 在一个环内。
现在我们就要求 \(t\) 内没有自环的情况,考虑二项式反演。我们钦定 \(t\) 内有 \(k\) 个自环,则整个序列被划分为五部分:\([1, t-1]\) 内的自环,\([1, t-1]\) 内其他的点,\(t\) 点,\([t+1, A]\)的亮灯部分,以及最初没有亮的 \([A+1, n]\) 部分。对于第一个部分是在 \(t-1\) 中取出 \(k\) 个点,第三个部分不用考虑,我们来讨论剩下的部分。
因为要成环,我们就通过插入来考虑。每次插入一个点 \(p\),都相当于是断开前面连接两个点 \(u, v\) 的一条边,然后加上 \(u\) 到 \(p\) 和 \(p\) 到 \(v\) 的边。当然,也可以自己成为一个新环。那么,对于第二个部分内的点,可以随意插入,也可以自成环,所以每枚举到第 \(i\) 个点都有 \(i\) 种方案;而为了保证 “对于 $ \forall x \in [A+1, n] , \exists i \in [1, t-1] $,使得 \(x\) 与 \(i\) 在一个环内。” 这个条件,连完第二部分后就要去考虑第五部分。这个部分中的点不能连自环,故每次枚举贡献 \(i-1\);然后第四部分也是随意插入,每次贡献 \(i\)。注意这里的 \(i\) 是连续枚举而非分别枚举。我们令这三部分的大小分别为 \(a, b,
c\),整理一下,则有 \(g_k = {t-1 \choose k} \frac{a(a+b+c)!}{a+b}\)。直接套二项式反演即可。
如果按照我一开始每次单独求 \(a+b\) 逆元,总复杂度为 \(O(n+A^2 \log n)\)(机子快能跑过去)。当然这里可以直接利用阶乘和阶乘逆元求出来,这样复杂度就是 \(O(n+A^2)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int N = 1e7+10, M = 5050;
int fac[N], inv[N];
inline int C(int n, int m){
if(n<0 || m<0 || n<m) return 0;
return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int n, m;
inline int fpow(int a, int b){
a%=mod;
int ret = 1;
while(b){
if(b & 1){
ret = (1ll*ret*a)%mod;
}
b>>=1;
a = (1ll*a*a)%mod;
}
return ret;
}
void prework(){
fac[0] = 1;
for(int i = 1; i<=n; ++i){
fac[i] = (1ll*fac[i-1]*i)%mod;
}
inv[n] = fpow(fac[n], mod-2);
for(int i = n-1; i>=0; --i){
inv[i] = (1ll*inv[i+1]*(i+1))%mod;
}
}
inline int calc(int a, int b, int c){
return 1ll*a*fac[a+b+c]%mod*fac[a+b-1]%mod*inv[a+b]%mod;
}
int ans;
int main(){
scanf("%d%d", &n, &m);
prework();
int a, b, c;
for(int t = 1; t<=m+1; ++t){
for(int i = 0; i<t; ++i){
a = t-i-1, b = n-m, c = m-t;
if(t == m+1) c = 0;
int fu = (i&1)?-1:1;
ans = (1ll*ans+1ll*fu*C(t-1, a)*calc(a, b, c)%mod)%mod;
ans = (ans+mod)%mod;
}
}
printf("%d\n", ans);
return 0;
}