【牛客】排列计数机(线段树 二项式定理)
题目描述
定义一个长为k的序列A1,A2,…,Ak的权值为:对于所有1≤i≤k,max(A1,A2,…,Ai)有多少种不同的取值。
给出一个1到n的排列B1,B2,…,Bn,求B的所有非空子序列的权值的m次方之和。
答案对109+7取模。
输入描述
第一行两个整数n、m。
接下来一行n个整数,第i个整数为Bi。
输出描述
输出一个整数,表示答案。
示例1
输入
3 2
1 3 2
输出
16
说明
在所有非空子序列中:
(1), (3), (2), (3, 2)权值为1,
(1, 3), (1, 2), (1, 3, 2)权值为2。
那么所有非空子序列权值的2次方和为4×12+3×22=16。
备注:
对于前10%的数据,n≤20。
对于前20%的数据,n≤100。
对于前40%的数据,n≤1000。
对于另外20%的数据,m=1。
对于所有数据,1≤n≤105,1≤m≤20,保证B是1到n的排列。
分析
直接求看样子是不可能的,出题人这辈子都不可能让你直接求的
所以考虑从左到右一个一个加入数。
当加入一个数的时候,只有最大值小于这个数的子序列,权值才会被更新(即加1)
我们不可能把每个子序列的权值都求出来后才乘方再加起来,铁定会T
但我们根据二项式定理,可以发现对于任意一个数x
$$(x+1)^m=C_{m}^{0}x^{m}+C_{m}^{1}x^{m-1}+C_{m}^{2}x^{m-2}+\cdots\cdots+C_{m}^{m-2}x^{2}+C_{m}^{m-1}x^{1}+C_{m}^{m}x^{0}$$
通过这个
我们可以通过维护每个子序列权值的m次方和,m-1次方和,m-2次方和。。。。。。1次方和,0次方和,就可以通过上面的公式得到这些子序列权值集体加1后的乘方的和
由于每次插入值的时候更新只跟最大值有关,而且题目中保证了B是一个排列。
所以,我们可以把最大值相同的子序列一起处理,维护它们的m次方和,m-1次方和,m-2次方和。。。。。。
加入一个新的数的时候,找到所有最大值比它小的子序列,将它们的m次方和,m-1次方和。。。。。。加起来,再用二项式定理得到加1后的m次方和,得到一组新的子序列的信息
对于那些最大值比它的子序列,因为无法更新但可以形成新的子序列,直接将和乘以2就好了
这个可以用线段树维护,线段树的每个位置维护相同最大值的子序列的一些信息
Code
#include<cstdio>
#include<cstring>
const int mod=1e9+7;
const int maxn=100005;
int n,m,ans,a[maxn],tmp[30],pre[30],C[30][30],sum[maxn<<2][30],mk[maxn<<2][30];
void pushdown(int id,int l,int r,int k)
{
mk[id<<1][k]=1ll*mk[id<<1][k]*mk[id][k]%mod;
mk[id<<1|1][k]=1ll*mk[id<<1|1][k]*mk[id][k]%mod;
sum[id<<1][k]=1ll*sum[id<<1][k]*mk[id][k]%mod;
sum[id<<1|1][k]=1ll*sum[id<<1|1][k]*mk[id][k]%mod;
mk[id][k]=1;return;
}
void add(int id,int l,int r,int x,int v,int k)
{
if(l==r){sum[id][k]=v;return;}
int mid=(l+r)>>1;pushdown(id,l,r,k);
x<=mid?add(id<<1,l,mid,x,v,k):add(id<<1|1,mid+1,r,x,v,k);
sum[id][k]=(sum[id<<1][k]+sum[id<<1|1][k])%mod;
}
int que(int id,int l,int r,int x,int k)
{
if(r<=x)return sum[id][k];
int mid=(l+r)>>1;pushdown(id,l,r,k);
return (que(id<<1,l,mid,x,k)+(x>mid?que(id<<1|1,mid+1,r,x,k):0))%mod;
}
void mul(int id,int l,int r,int x,int k)
{
if(r<x)return;
if(x<=l){sum[id][k]=2ll*sum[id][k]%mod;mk[id][k]=mk[id][k]*2ll%mod;return;}
int mid=(l+r)>>1;pushdown(id,l,r,k);
x<=mid?mul(id<<1,l,mid,x,k),1:0;
mul(id<<1|1,mid+1,r,x,k);
sum[id][k]=(sum[id<<1][k]+sum[id<<1|1][k])%mod;
}
int main()
{
C[0][0]=1;
for(int i=1;i<=20;i++)
{C[i][0]=1;for(int j=1;j<=i;j++)C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;}
scanf("%d%d",&n,&m);for(int i=0;i<=n<<2;i++)for(int k=0;k<=m;k++)mk[i][k]=1;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);memset(tmp,0,sizeof tmp);memset(pre,0,sizeof pre);
for(int k=0;k<=m;k++)pre[k]=que(1,1,n,a[i],k);
for(int k=m;k>=0;k--)for(int j=k;j>=0;j--)
tmp[k]=(tmp[k]+1ll*pre[j]*C[k][j]%mod)%mod;
for(int k=0;k<=m;k++)add(1,1,n,a[i],tmp[k]+1,k),mul(1,1,n,a[i]+1,k);
}
printf("%d",(que(1,1,n,n,m)%mod+mod)%mod);
}