【牛客】排列计数机(线段树 二项式定理)

题目描述

  定义一个长为k的序列A1,A2,…,Ak的权值为:对于所有1≤i≤k,max⁡(A1,A2,…,Ai)有多少种不同的取值。
  给出一个1到n的排列B1,B2,…,Bn,求B的所有非空子序列的权值的m次方之和。
  答案对109+7取模。

输入描述

  第一行两个整数n、m。
  接下来一行n个整数,第i个整数为Bi。

输出描述

  输出一个整数,表示答案。
  示例1
  输入
  3 2
  1 3 2
  输出
  16
  说明
  在所有非空子序列中:
  (1), (3), (2), (3, 2)权值为1,
  (1, 3), (1, 2), (1, 3, 2)权值为2。
  那么所有非空子序列权值的2次方和为4×12+3×22=16。
备注:
  对于前10%的数据,n≤20。
  对于前20%的数据,n≤100。
  对于前40%的数据,n≤1000。
  对于另外20%的数据,m=1。
  对于所有数据,1≤n≤105,1≤m≤20,保证B是1到n的排列。

分析

  直接求看样子是不可能的,出题人这辈子都不可能让你直接求的

  所以考虑从左到右一个一个加入数。

  当加入一个数的时候,只有最大值小于这个数的子序列,权值才会被更新(即加1)

  我们不可能把每个子序列的权值都求出来后才乘方再加起来,铁定会T

  但我们根据二项式定理,可以发现对于任意一个数x

  $$(x+1)^m=C_{m}^{0}x^{m}+C_{m}^{1}x^{m-1}+C_{m}^{2}x^{m-2}+\cdots\cdots+C_{m}^{m-2}x^{2}+C_{m}^{m-1}x^{1}+C_{m}^{m}x^{0}$$

  通过这个

  我们可以通过维护每个子序列权值的m次方和,m-1次方和,m-2次方和。。。。。。1次方和,0次方和,就可以通过上面的公式得到这些子序列权值集体加1后的乘方的和

  由于每次插入值的时候更新只跟最大值有关,而且题目中保证了B是一个排列。

  所以,我们可以把最大值相同的子序列一起处理,维护它们的m次方和,m-1次方和,m-2次方和。。。。。。

  加入一个新的数的时候,找到所有最大值比它小的子序列,将它们的m次方和,m-1次方和。。。。。。加起来,再用二项式定理得到加1后的m次方和,得到一组新的子序列的信息

  对于那些最大值比它的子序列,因为无法更新但可以形成新的子序列,直接将和乘以2就好了

  这个可以用线段树维护,线段树的每个位置维护相同最大值的子序列的一些信息

  Code

#include<cstdio>
#include<cstring>
const int mod=1e9+7;
const int maxn=100005;
int n,m,ans,a[maxn],tmp[30],pre[30],C[30][30],sum[maxn<<2][30],mk[maxn<<2][30];
void pushdown(int id,int l,int r,int k)
{
    mk[id<<1][k]=1ll*mk[id<<1][k]*mk[id][k]%mod;
    mk[id<<1|1][k]=1ll*mk[id<<1|1][k]*mk[id][k]%mod;
    sum[id<<1][k]=1ll*sum[id<<1][k]*mk[id][k]%mod;
    sum[id<<1|1][k]=1ll*sum[id<<1|1][k]*mk[id][k]%mod;
    mk[id][k]=1;return;
}
void add(int id,int l,int r,int x,int v,int k)
{
    if(l==r){sum[id][k]=v;return;}
    int mid=(l+r)>>1;pushdown(id,l,r,k);
    x<=mid?add(id<<1,l,mid,x,v,k):add(id<<1|1,mid+1,r,x,v,k);
    sum[id][k]=(sum[id<<1][k]+sum[id<<1|1][k])%mod;
}
int que(int id,int l,int r,int x,int k)
{
    if(r<=x)return sum[id][k];
    int mid=(l+r)>>1;pushdown(id,l,r,k);
    return (que(id<<1,l,mid,x,k)+(x>mid?que(id<<1|1,mid+1,r,x,k):0))%mod;
}
void mul(int id,int l,int r,int x,int k)
{
    if(r<x)return;
    if(x<=l){sum[id][k]=2ll*sum[id][k]%mod;mk[id][k]=mk[id][k]*2ll%mod;return;}
    int mid=(l+r)>>1;pushdown(id,l,r,k);
    x<=mid?mul(id<<1,l,mid,x,k),1:0;
    mul(id<<1|1,mid+1,r,x,k);
    sum[id][k]=(sum[id<<1][k]+sum[id<<1|1][k])%mod;
}
int main()
{
    C[0][0]=1;
    for(int i=1;i<=20;i++)
    {C[i][0]=1;for(int j=1;j<=i;j++)C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;}
    scanf("%d%d",&n,&m);for(int i=0;i<=n<<2;i++)for(int k=0;k<=m;k++)mk[i][k]=1;
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);memset(tmp,0,sizeof tmp);memset(pre,0,sizeof pre);
        for(int k=0;k<=m;k++)pre[k]=que(1,1,n,a[i],k);
        for(int k=m;k>=0;k--)for(int j=k;j>=0;j--)
        tmp[k]=(tmp[k]+1ll*pre[j]*C[k][j]%mod)%mod;
        for(int k=0;k<=m;k++)add(1,1,n,a[i],tmp[k]+1,k),mul(1,1,n,a[i]+1,k);
    }
    printf("%d",(que(1,1,n,n,m)%mod+mod)%mod);
}

 

 

 

posted @ 2019-11-07 08:10  散樗  阅读(358)  评论(0编辑  收藏  举报