二项式反演的三道练习题
板子题,所用到的柿子为
\[f(n)=\sum_{i=n}^m\mathrm{C}_i^ng(i)\Leftrightarrow g(n)=\sum_{i=n}^m(-1)^{i-n}\mathrm{C}_i^nf(i)
\]
先考虑需要多少组糖果比药片能量大的,就是\(\frac{n+k}{2}\),这个随便推推就有了。
考虑dp,设\(f_{i,j}\)表示前\(i\)中 糖果比恰好糖果比药片能量大的组数钦定\(j\)组的方案数。
记\(r_i\)表示比第\(i\)个糖果能量小的药片的个数。
如果序列无序的话,很不好维护这个东西,发现答案与顺序无关,直接排序即可。
然后dp就好维护了,转移方程为
\[f_{i,j} = f_{i-1,j-1}\times (r_i-j+1) + f_{i-1,j}
\]
但是发现这个是不完整的,因为除了钦定的\(j\)个,其它的\(n-j\)个无所谓,要取\(n-j\)的全排列,所以最后要将所有的\(f_{n,i}\)乘上\((n-i)\)的全排列。
然后套二项式反演的柿子求出\(g_k\)即可。
点此查看代码
#include<bits/stdc++.h>
#include<bits/extc++.h>
// using namespace __gnu_pbds;
// using namespace __gnu_cxx;
using namespace std;
#define infile(x) freopen(#x".in","r",stdin)
#define outfile(x) freopen(#x".out","w",stdout)
#define errfile(x) freopen(#x".err","w",stderr)
#define ansfile(x) freopen(#x".ans","w",stdout)
#define rep(i,s,t,p) for(int i = s;i <= t; i += p)
#define drep(i,s,t,p) for(int i = s;i >= t; i -= p)
#ifdef LOCAL
FILE *InFile = infile(in),*OutFile = outfile(out);
// FILE *ErrFile = errfile(err);
#else
FILE *Infile = stdin,*OutFile = stdout;
//FILE *ErrFile = stderr;
#endif
using ll=long long;using ull=unsigned long long;
using db = double;using ldb = long double;
const int N = 2010,mod = 1e9 + 9;
int n,k,a[N],b[N],f[2][N],fac[N],inv[N],r[N],ans;
//f_{i,j}表示前i个 糖果中比药片能量多的组数 至少为j个的方案数。
inline int power(int a,int b,int mod){
int res = 1;
for(;b;b >>= 1,a = 1ll*a*a%mod)
if(b&1) res = 1ll*res*a%mod;
return res;
}
inline int Inv(int a){return power(a,mod-2,mod);}
inline int C(int n,int m){return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;}
inline void solve(){
cin>>n>>k;
if((n-k)%2) return cout<<"0\n",void();
k = (n+k)>>1;
rep(i,1,n,1) cin>>a[i];rep(i,1,n,1) cin>>b[i];
sort(a+1,a+1+n);sort(b+1,b+1+n);
int now = 0;
rep(i,1,n,1){
while(now < n && b[now + 1] < a[i]) ++now;
r[i] = now;
}
fac[0] = 1;
rep(i,1,n,1) fac[i] = 1ll*fac[i-1]*i%mod;
inv[n] = Inv(fac[n]);
drep(i,n-1,0,1) inv[i] = 1ll*inv[i+1]*(i+1)%mod;
f[0][0] = 1;
rep(i,1,n,1) rep(j,0,i,1)
f[i&1][j] = (f[(i-1)&1][j] + (j?1ll*(r[i]-j+1)*f[(i-1)&1][j-1]%mod:0ll))%mod;
int *dp = f[n&1];
rep(i,0,n,1) dp[i] = 1ll*dp[i]*fac[n-i]%mod;
int i = k;
rep(j,i,n,1){
int sgn = ((j-i)&1)?-1:1;
ans = (ans + 1ll*sgn*C(j,i)*dp[j]%mod + mod)%mod;
}
cout<<ans<<'\n';
}
signed main(){
cin.tie(nullptr)->sync_with_stdio(false);
cout.tie(nullptr)->sync_with_stdio(false);
solve();
}
还是套那个柿子。
\[f(n)=\sum_{i=n}^m\mathrm{C}_i^ng(i)\Leftrightarrow g(n)=\sum_{i=n}^m(-1)^{i-n}\mathrm{C}_i^nf(i)
\]
记\(f_i\)表示选出的集合中交集元素个数至少为\(i\)的方案数,有\(f_i=\mathrm{C}_n^i(2^{2^{n-i}})\)
简单解释一下,就是从\(n\)个数中钦定\(i\)个作为交集,然后剩下的数的子集个数为\(2^{n-2}\),可以选择的方案为\(2^{2^{n-i}}\)个。
然后套柿子容斥就可以了。
点此查看代码
#include<bits/stdc++.h>
#include<bits/extc++.h>
// using namespace __gnu_pbds;
// using namespace __gnu_cxx;
using namespace std;
#define infile(x) freopen(#x".in","r",stdin)
#define outfile(x) freopen(#x".out","w",stdout)
#define errfile(x) freopen(#x".err","w",stderr)
#define ansfile(x) freopen(#x".ans","w",stdout)
#define rep(i,s,t,p) for(int i = s;i <= t; i += p)
#define drep(i,s,t,p) for(int i = s;i >= t; i -= p)
#ifdef LOCAL
FILE *InFile = infile(in),*OutFile = outfile(out);
// FILE *ErrFile = errfile(err);
#else
FILE *Infile = stdin,*OutFile = stdout;
//FILE *ErrFile = stderr;
#endif
using ll=long long;using ull=unsigned long long;
using db = double;using ldb = long double;
const int N = 1e7 + 10,mod = 1e9 + 7;
int fac[N],inv[N],n,k,ans;
inline int power(int a,int b,int mod){
int res = 1;
for(;b;b >>= 1,a = 1ll*a*a%mod)
if(b&1) res = 1ll*res*a%mod;
return res;
}
inline int C(int n,int m){return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;}
signed main(){
cin.tie(nullptr)->sync_with_stdio(false);
cout.tie(nullptr)->sync_with_stdio(false);
cin>>n>>k;
fac[0] = 1;
rep(i,1,n,1) fac[i] = 1ll*fac[i-1]*i%mod;
inv[n] = power(fac[n],mod-2,mod);
drep(i,n-1,0,1) inv[i] = 1ll*inv[i+1]*(i+1)%mod;
rep(i,k,n,1){
int sgn = ((i-k)&1)?-1:1;
ans = (ans + 1ll*sgn*C(n,i)*C(i,k)%mod*power(2,power(2,n-i,mod-1),mod)%mod + mod) % mod;
}
cout<<ans<<'\n';
}
没错,还是上面那个柿子。再粘一遍
\[f(n)=\sum_{i=n}^m\mathrm{C}_i^ng(i)\Leftrightarrow g(n)=\sum_{i=n}^m(-1)^{i-n}\mathrm{C}_i^nf(i)
\]
设\(f_{x,i}\)表示以\(x\)为根的子树中钦定\(i\)对点会比出胜负,然后树形背包计算,注意最后考虑钦定这个点和子树中的某个节点的情况。
点此查看代码
#include<bits/stdc++.h>
#include<bits/extc++.h>
// using namespace __gnu_pbds;
// using namespace __gnu_cxx;
using namespace std;
#define infile(x) freopen(#x".in","r",stdin)
#define outfile(x) freopen(#x".out","w",stdout)
#define errfile(x) freopen(#x".err","w",stderr)
#define ansfile(x) freopen(#x".ans","w",stdout)
#define rep(i,s,t,p) for(int i = s;i <= t; i += p)
#define drep(i,s,t,p) for(int i = s;i >= t; i -= p)
#ifdef LOCAL
FILE *InFile = infile(in),*OutFile = outfile(out);
// FILE *ErrFile = errfile(err);
#else
FILE *Infile = stdin,*OutFile = stdout;
//FILE *ErrFile = stderr;
#endif
using ll=long long;using ull=unsigned long long;
using db = double;using ldb = long double;
#define eb emplace_back
const int N = 5010,mod = 998244353;
int n,m,fac[N],inv[N],f[N][N],siz[N][2],g[N];
char s[N];
vector<int> e[N];
inline int power(int a,int b,int mod){
int res = 1;
for(;b;b >>= 1,a = 1ll*a*a%mod)
if(b&1) res = 1ll*res*a%mod;
return res;
}
inline int Inv(int a){return power(a,mod-2,mod);}
inline int C(int n,int m){return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;}
void DP(int x,int fa){
f[x][0] = 1,siz[x][s[x]-'0']++;
for(int y:e[x]){
if(y == fa) continue;
DP(y,x);
drep(j,min(m,(siz[x][0]+siz[x][1])/2),0,1){
drep(k,min((siz[y][0]+siz[y][1])/2,m-j),1,1){
f[x][j+k] = (f[x][j+k]+1ll*f[x][j]*f[y][k]%mod)%mod;
}
}
siz[x][0] += siz[y][0];
siz[x][1] += siz[y][1];
}
if(s[x] == '1')
drep(i,min(m-1,siz[x][0]-1),0,1)
f[x][i+1] = (f[x][i+1] + 1ll*f[x][i]*(siz[x][0]-i)%mod)%mod;
else
drep(i,min(m-1,siz[x][1]-1),0,1)
f[x][i+1] = (f[x][i+1] + 1ll*f[x][i]*(siz[x][1]-i)%mod)%mod;
}
inline void solve(){
cin>>n>>(s+1);m = n>>1;
rep(i,2,n,1){int u,v;cin>>u>>v;e[u].eb(v),e[v].eb(u);}
fac[0] = 1;
rep(i,1,m,1) fac[i] = 1ll*fac[i-1]*i%mod;
inv[m] = Inv(fac[m]);
drep(i,m-1,0,1) inv[i] = 1ll*inv[i+1]*(i+1)%mod;
DP(1,0);
rep(i,0,m,1) g[i] = 1ll*f[1][i]*fac[m-i]%mod;
rep(i,0,m,1){
int ans = 0;
rep(j,i,m,1){
int sgn = ((j-i)&1)?-1:1;
ans = (ans + 1ll*sgn*C(j,i)*g[j]%mod + mod)%mod;
}
cout<<ans<<'\n';
}
}
signed main(){
cin.tie(nullptr)->sync_with_stdio(false);
cout.tie(nullptr)->sync_with_stdio(false);
solve();
}
__________________________________________________________________________________________
本文来自博客园,作者:CuFeO4,转载请注明原文链接:https://www.cnblogs.com/hzoi-Cu/p/18470986