bzoj 4361 isn

题目大意

给出一个长度为\(n\)的序列A(A1,A2...AN)
如果序列\(A\)不是非降的,你必须从中删去一个数(随便删都行)
反复执行这种操作,直到\(A\)非降为止
求有多少种不同的操作方案,答案模\(10^9+7\)
(剩下的数相同,操作顺序不同算不同方案)

思路1

自己想了下,如果最后剩下一个长度为\(i\)的合法状态,那么它上一步只能是
长度为\(i\)的非降序列加了一个 降点
\(f[len][0][mx],f[len][1][mx]\)表示当前长度,是否存在降点,当前最大值mx来dp
但发现有些dp转移难以优化

思路2

先不考虑中途非法的限制,再进行容斥

做法

\(g[i]\)为长度为\(i\)的非降序列有多少个
这个可以通过\(f[len][mx]\)+树状数组优化简单求出
然后\(g[i]*(n-i)!\)得到最终串为长度\(i\)的全部方案数
那么我们减掉最终串为长度\(i\)的非法方案数即可
根据思路1,到达长度\(i\)的所有非法方案,上一步都是来自长度\(i+1\)的非降串
(否则有一个降点保底,前面顺序不会导致非法)

而分析一下可以发现
每个长度\(i+1\)的非降串,都会有\(i+1\)条边,且全都是指到长度\(i\)的非降串的
所以非法路径数为\([n-(i+1)]!*g[i+1]*(i+1)\)

solution

#include <cstdio>
#include <cstdlib>
#include <cctype>
#include <cmath>
#include <algorithm>
#include <cstring>
using namespace std;
const int M=2e3+7;
const int Q=1e9+7;
typedef long long LL;
const LL INF=9223372036854775807;

inline int pls(int x,int y){return ((LL)x+y)%Q;}
inline int mul(int x,int y){return 1LL*x*y%Q;}
inline int mns(int x,int y){return pls(x,Q-y);}

inline int ri(){
	int x=0;bool f=1;char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
	for(;isdigit(c);c=getchar()) x=x*10+c-48;
	return f?x:-x;
}

inline LL rl(){
	LL x=0;bool f=1;char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
	for(;isdigit(c);c=getchar()) x=x*10+c-48;
	return f?x:-x;
}

int n,m;
int a[M];
LL val[M],b[M];

struct Bitarr{
	int c[M];
	Bitarr(){memset(c,0,sizeof c);}
	inline int lb(int x){return x&-x;}
	inline int add(int x,int d){for(;x<=m;x+=lb(x)) c[x]=pls(c[x],d);}
	inline int sum(int x){
		int res=0;
		for(;x>0;x-=lb(x)) res=pls(res,c[x]);
		return res;
	}
	inline int sum(int x,int y){return mns(sum(y),sum(x-1));}
}f[M];

int g[M],fac[M];

int main(){

	int i,j,tp;

	n=ri();
	for(i=1;i<=n;i++) val[i]=b[i]=ri();
	b[n+1]=-INF;
	sort(b+1,b+n+2); m=unique(b+1,b+n+2)-(b+1);
	for(i=1;i<=n;i++) a[i]=lower_bound(b+1,b+m+1,val[i])-b;

	f[0].add(1,1);
	for(i=1;i<=n;i++){
		for(j=i-1;j>=0;j--){
			tp=f[j].sum(1,a[i]);
			f[j+1].add(a[i],tp);
		}
	}

	for(i=1;i<=n;i++) g[i]=f[i].sum(1,m);
	for(i=1,fac[0]=1;i<=n;i++) fac[i]=mul(fac[i-1],i);

	int ans=0;
	for(i=1;i<n;i++) ans=pls(ans, mns(mul(g[i],fac[n-i]),mul(mul(g[i+1],fac[n-i-1]),i+1)) );
	ans=pls(ans,g[n]);

	printf("%d\n",ans);

	return 0;
}
posted @ 2017-07-11 16:02  _zwl  阅读(117)  评论(0编辑  收藏  举报