题目链接

https://www.lydsy.com/JudgeOnline/problem.php?id=4361

题解

f[i][j]f[i][j]表示前ii个中选择jj个数,第ii个必须选,得到的方案数。转移从值i\leq i位置的值的位置中选择,可以用树状数组优化。设g[i]g[i]表示选择ii个数,删掉其他的数得到一个不降序列的方案数。g[i]=f[j][i]g[i]=\sum f[j][i]

但是这样不考虑得到不降序列后必须停止。如果要除掉这种情况,那么g[i]=g[i](i+1)×g[i+1]g[i]=g[i]-(i+1)\times g[i+1]即可,因为只要删掉i+1i+1中任意一个就是非法情况了。

代码

#include <cstdio>
#include <algorithm>

int read()
{
  int x=0,f=1;
  char ch=getchar();
  while((ch<'0')||(ch>'9'))
    {
      if(ch=='-')
        {
          f=-f;
        }
      ch=getchar();
    }
  while((ch>='0')&&(ch<='9'))
    {
      x=x*10+ch-'0';
      ch=getchar();
    }
  return x*f;
}

const int maxn=2000;
const int mod=1000000007;

struct data
{
  int val,pos,v;

  bool operator <(const data &other) const
  {
    return val<other.val;
  }
};

int n,f[maxn+10][maxn+10],g[maxn+10],fac[maxn+10],ans;
data d[maxn+10];

bool cmp(data x,data y)
{
  return x.pos<y.pos;
}

struct tree_array
{
  int val[maxn+10];

  int lowbit(int x)
  {
    return x&(-x);
  }

  int add(int x,int v)
  {
    while(x<=n)
      {
        val[x]+=v;
        if(val[x]>=mod)
          {
            val[x]-=mod;
          }
        x+=lowbit(x);
      }
    return 0;
  }

  int getsum(int x)
  {
    int res=0;
    while(x)
      {
        res+=val[x];
        if(res>=mod)
          {
            res-=mod;
          }
        x-=lowbit(x);
      }
    return res;
  }
};

tree_array ta[maxn+10];

int main()
{
  n=read();
  for(int i=1; i<=n; ++i)
    {
      d[i].val=read();
      d[i].pos=i;
    }
  fac[0]=1;
  for(int i=1; i<=n; ++i)
    {
      fac[i]=1ll*fac[i-1]*i%mod;
    }
  std::sort(d+1,d+n+1);
  for(int i=1; i<=n; ++i)
    {
      if((i==1)||(d[i].val!=d[i-1].val))
        {
          d[i].v=d[i-1].v+1;
        }
      else
        {
          d[i].v=d[i-1].v;
        }
    }
  std::sort(d+1,d+n+1,cmp);
  ta[0].add(1,1);
  for(int i=1; i<=n; ++i)
    {
      for(int j=i; j; --j)
        {
          f[i][j]=ta[j-1].getsum(d[i].v);
          ta[j].add(d[i].v,f[i][j]);
        }
    }
  for(int i=1; i<=n; ++i)
    {
      for(int j=1; j<=i; ++j)
        {
          g[j]+=f[i][j];
          if(g[j]>=mod)
            {
              g[j]-=mod;
            }
        }
    }
  for(int i=1; i<=n; ++i)
    {
      g[i]=1ll*g[i]*fac[n-i]%mod;
    }
  for(int i=1; i<=n; ++i)
    {
      g[i]=(g[i]-1ll*g[i+1]*(i+1))%mod;
      if(g[i]<0)
        {
          g[i]+=mod;
        }
      ans+=g[i];
      if(ans>=mod)
        {
          ans-=mod;
        }
    }
  printf("%d\n",ans);
  return 0;
}