[ZJOI2016] 线段树

一、题目

点此看题

二、解法

思维方式太低级了,好像最近状态还不及省选之前,要准备开始复习了。

直接使用 01-principle,先考虑对于一个 \(01\) 序列如何计算。对于序列上每个 \(0\) 的极长连续段都是独立的,可以分开来考虑,它们被两个 \(1\) 包夹住(特别地,序列边界也视为有 \(1\)

\(dp[i][l][r]\) 表示经过前 \(i\) 次操作,区间 \((l,r)\)\(l,r\) 处的 \(1\) 给包夹住的方案数,转移:

\[\begin{aligned} dp[i][l][r]=&dp[i-1][l][r]\cdot\frac{l(l+1)+(n-r+1)(n-r+2)+(r-l)(r-l-1)}{2}\\ +&dp[i-1][l'][r]\cdot l'+dp[i-1][l][r']\cdot (n-r'+1) \end{aligned} \]

转移的意义是不难理解的,其中需要满足 \(l'\leq l\)\(r\leq r'\),容易看出可以使用前缀和优化。

再推广到任意序列的情况,对于 \(w\) 把贡献拆成 \(\max a_i-\sum_{i=0}^{\max a_i} [w<i]\),那么对于离散化之后的每一小段,我们在原序列上设置 \(0\)\(1\),然后按照 \(01\) 序列的方法跑 \(dp\),算贡献即可,时间复杂度大概 \(O(n^3q)\)

注意到 \(n\)\(dp\) 只有初始化不同,转移方式都是完全相同的。那么可以使用整体 \(dp\) 的技巧,合理地初始化之后跑一遍 \(dp\) 即可,时间复杂度 \(O(n^2q)\)

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 405;
const int MOD = 1e9+7;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,a[M],p[M],v[M],ans[M];
int dp[M][M],s1[M][M],s2[M][M];
int qkpow(int a,int b)
{
	int r=1;
	while(b>0)
	{
		if(b&1) r=r*a%MOD;
		a=a*a%MOD;
		b>>=1;
	}
	return r;
}
int f(int l,int r)
{
	return ((l+1)*l+(n-r+1)*(n-r+2)+(r-l)*(r-l-1))/2%MOD;
}
signed main()
{
	n=read();m=read();v[n+1]=1;
	for(int i=1;i<=n;i++)
		a[i]=read(),p[i]=i;
	sort(p+1,p+1+n,[&](int x,int y){return a[x]<a[y];});
	for(int i=n;i>=1;i--)
	{
		int ls=0;v[p[i]]=1;
		for(int j=1;j<=n+1;j++) if(v[j])
			dp[ls][j]+=a[p[i]]-a[p[i-1]],ls=j;
	}
	ans[1]=a[p[n]]*qkpow(n*(n+1)/2,m)%MOD;
	for(int i=2;i<=n;i++) ans[i]=ans[1];
	for(int i=1;i<=m;i++)
	{
		for(int l=0;l<=n+1;l++)
		for(int r=n+1;r>l+1;r--)
		{
			s1[l][r]=((l?s1[l-1][r]:0)+l*dp[l][r])%MOD;
			s2[l][r]=((r<=n?s2[l][r+1]:0)+(n-r+1)*dp[l][r])%MOD;
		}
		for(int l=0;l<=n+1;l++)
		for(int r=n+1;r>l+1;r--)
		{
			dp[l][r]=dp[l][r]*f(l,r)%MOD;
			if(l) dp[l][r]=(dp[l][r]+s1[l-1][r])%MOD;
			if(r<=n) dp[l][r]=(dp[l][r]+s2[l][r+1])%MOD;
		}
	}
	for(int l=0;l<=n+1;l++)
		for(int r=n+1;r>l+1;r--)
			for(int i=l+1;i<r;i++)
				ans[i]=(ans[i]-dp[l][r])%MOD;
	for(int i=1;i<=n;i++)
		printf("%lld ",(ans[i]+MOD)%MOD);
}
posted @ 2022-07-18 22:27  C202044zxy  阅读(206)  评论(0编辑  收藏  举报