Candy Retribution

\(\text{Problem}:\)题目链接

\(\text{Solution}:\)

计数基础内容的综合运用。

首先将第一个限制的下界去掉。记 \(S(x)\) 表示满足 \(a_{1}+a_{2}+...+a_{n}\leq x\) 且满足第二个条件的序列数,则答案为 \(S(R)-S(L-1)\)

考虑满足 \(a_{1}+a_{2}+...+a_{n}<=K\) 的总序列数为 \(C_{n+K}^{K}\),我们可以求出不满足第二个条件的序列数为 \(W\),则 \(S(K)=C_{n+K}^{n}-W\)

设序列 \(a\) 的降序排列为 \(b\)\(b_{m}\not= b_{m+1}\) 等价于:枚举 \(x\)\(m\) 个数大于等于 \(x\)\(b_{m}=x\)\(n-m\) 个数小于 \(x\) 的总序列数。对于 \(m\) 个数中至少一个数为 \(x\) 的限制,套路的,将其转化为:\(m\) 个数大于等于 \(x\) 的序列数,减去 \(m\) 个数大于等于 \(x+1\) 的序列数。记 \(f(x,y)\) 表示当 \(K\) 固定时,有 \(m\) 个数大于等于 \(x\)\(n-m\) 小于等于 \(y\) 的序列数,则 \(W=\sum\limits_{x} f(x,x-1)-f(x+1,x-1)\)

考虑求出 \(f(x,y)\)。对于 \(m\) 个大于等于 \(x\) 的数,把它们都减去 \(x\),则去除了这 \(m\) 个数的范围限制。对于 \(n-m\) 个小于等于 \(x\) 的数,利用容斥思想,设有 \(i\) 个数大于 \(y\),把它们减去 \(y+1\),则去除了这 \(n-m\) 个数的限制。则只需要考虑 \(n\) 个数总和为 \(K-mx-i(y+1)\) 的序列数即可,这个非常好算:

\[\qquad f(x,y)=C_{n}^{m}\sum\limits_{i=0}^{n-m} (-1)^{i}\times C_{n-m}^{i}\times C_{K-mx-i(y+1)+n}^{n} \qquad \]

\(n,m,L,R\) 同阶,则总时间复杂度为 \(O(n\log n)\)

\(\text{Code}:\)

#include <bits/stdc++.h>
#pragma GCC optimize(3)
#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
using namespace std; const int N=600010, Mod=1e9+7;
inline int read()
{
	int s=0, w=1; ri char ch=getchar();
	while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
	while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
	return s*w;
}
int n,m,L,R,fac[N+10],inv[N+10];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=x*x%Mod) if(p&1) res=res*x%Mod; return res; }
inline int C(int x,int y) { if(x<y||x<0||y<0) return 0; return fac[x]*inv[x-y]%Mod*inv[y]%Mod; }
inline int F(int a,int b,int G)
{
	if(b*m>G) return 0;
	int gg=0;
	for(ri int i=0;i<=n-m;i++)
	{
		int tp=(i&1)?(-1):(1);
		if(G-b*m-i*(a+1)<0) break;
		gg=(gg+tp*C(n-m,i)%Mod*C(G-b*m-i*(a+1)+n,n)%Mod+Mod)%Mod;
	}
	return gg*C(n,m)%Mod;
}
inline int Solve(int G)
{
	if(!G) return 1;
	int res=0;
	for(ri int i=1;m*i<=G;i++)
	{
		res=(res+F(i-1,i,G)-F(i-1,i+1,G)+Mod)%Mod;
	}
	return (C(n+G,n)-res+Mod)%Mod;
}
signed main()
{
	fac[0]=1;
	for(ri int i=1;i<=N;i++) fac[i]=fac[i-1]*i%Mod;
	inv[N]=ksc(fac[N],Mod-2);
	for(ri int i=N;i;i--) inv[i-1]=inv[i]*i%Mod;
	n=read(), m=read(), L=read(), R=read();
	printf("%lld\n",(Solve(R)-Solve(L-1)+Mod)%Mod);
	return 0;
}
posted @ 2021-03-04 15:05  zkdxl  阅读(78)  评论(0编辑  收藏  举报