Candy Retribution
\(\text{Problem}:\)题目链接
\(\text{Solution}:\)
计数基础内容的综合运用。
首先将第一个限制的下界去掉。记 \(S(x)\) 表示满足 \(a_{1}+a_{2}+...+a_{n}\leq x\) 且满足第二个条件的序列数,则答案为 \(S(R)-S(L-1)\)。
考虑满足 \(a_{1}+a_{2}+...+a_{n}<=K\) 的总序列数为 \(C_{n+K}^{K}\),我们可以求出不满足第二个条件的序列数为 \(W\),则 \(S(K)=C_{n+K}^{n}-W\)。
设序列 \(a\) 的降序排列为 \(b\),\(b_{m}\not= b_{m+1}\) 等价于:枚举 \(x\),\(m\) 个数大于等于 \(x\) 且 \(b_{m}=x\),\(n-m\) 个数小于 \(x\) 的总序列数。对于 \(m\) 个数中至少一个数为 \(x\) 的限制,套路的,将其转化为:\(m\) 个数大于等于 \(x\) 的序列数,减去 \(m\) 个数大于等于 \(x+1\) 的序列数。记 \(f(x,y)\) 表示当 \(K\) 固定时,有 \(m\) 个数大于等于 \(x\),\(n-m\) 小于等于 \(y\) 的序列数,则 \(W=\sum\limits_{x} f(x,x-1)-f(x+1,x-1)\)。
考虑求出 \(f(x,y)\)。对于 \(m\) 个大于等于 \(x\) 的数,把它们都减去 \(x\),则去除了这 \(m\) 个数的范围限制。对于 \(n-m\) 个小于等于 \(x\) 的数,利用容斥思想,设有 \(i\) 个数大于 \(y\),把它们减去 \(y+1\),则去除了这 \(n-m\) 个数的限制。则只需要考虑 \(n\) 个数总和为 \(K-mx-i(y+1)\) 的序列数即可,这个非常好算:
设 \(n,m,L,R\) 同阶,则总时间复杂度为 \(O(n\log n)\)。
\(\text{Code}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
using namespace std; const int N=600010, Mod=1e9+7;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,m,L,R,fac[N+10],inv[N+10];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=x*x%Mod) if(p&1) res=res*x%Mod; return res; }
inline int C(int x,int y) { if(x<y||x<0||y<0) return 0; return fac[x]*inv[x-y]%Mod*inv[y]%Mod; }
inline int F(int a,int b,int G)
{
if(b*m>G) return 0;
int gg=0;
for(ri int i=0;i<=n-m;i++)
{
int tp=(i&1)?(-1):(1);
if(G-b*m-i*(a+1)<0) break;
gg=(gg+tp*C(n-m,i)%Mod*C(G-b*m-i*(a+1)+n,n)%Mod+Mod)%Mod;
}
return gg*C(n,m)%Mod;
}
inline int Solve(int G)
{
if(!G) return 1;
int res=0;
for(ri int i=1;m*i<=G;i++)
{
res=(res+F(i-1,i,G)-F(i-1,i+1,G)+Mod)%Mod;
}
return (C(n+G,n)-res+Mod)%Mod;
}
signed main()
{
fac[0]=1;
for(ri int i=1;i<=N;i++) fac[i]=fac[i-1]*i%Mod;
inv[N]=ksc(fac[N],Mod-2);
for(ri int i=N;i;i--) inv[i-1]=inv[i]*i%Mod;
n=read(), m=read(), L=read(), R=read();
printf("%lld\n",(Solve(R)-Solve(L-1)+Mod)%Mod);
return 0;
}