P3352 [ZJOI2016] 线段树 思考--zhengjun
有一个显然的 \(O(n^3q)\) 的做法:
-
设 \(f_{i,l,r,x}\) 表示 \(i\) 次操作过后,区间 \([l,r]\) 的数 \(\le x\),\(a_{l-1},a_{r+1}>x\) 的方案数。
-
转移:$$f_{i,l,r,x}=f_{i-1,l,r,x}\times g_{l,r}+\sum\limits_{j<l}f_{i-1,j,r,x}\times(j-1)+\sum\limits_{j>r}f_{i-1,l,j,x}\times(n-j)$$
-
其中 \(g_{l,r}=\frac{(r-l+1)(r-l+2)}{2}+\frac{l(l-1)}{2}+\frac{(n-r)(n-r+1)}{2}\),为无效操作的方案数。
-
最后算答案只需差分一下即可:$$ans_i=\sum\limits_{l\le i\le r}b_x\times \sum \limits_{x}f_{q,l,r,x}-f_{q,l,r,x-1}$$
-
其中,\(b_x\) 为离散化过后的值域数组。
当然,有一个数据随机的限制,复杂度实际上为期望 \(O(n^2q)\),可过。
但是,其实可以做到任意数据下 \(O(n^2q)\),需要一点技巧。
开始推狮子:
\[ans_i=\sum\limits_{l\le i\le r}b_x\times \sum \limits_{x}f_{q,l,r,x}-f_{q,l,r,x-1}\\
=\sum\limits_{l\le i\le r}\sum\limits_{x}f_{q,l,r,x}\times (b_x-b_{x+1})
\]
发现转移其实和 \(x\) 没有关系,所以考虑优化这一维。
设 \(f'_{i,l,r}=\sum\limits_{x}f_{i,l,r,x}\times(b_x-b_{x+1})\)。
那么 \(ans_i=\sum\limits_{l\le i\le r}f'_{q,l,r}\)。
转移变成了:
\[f'_{i,l,r}=f'_{i-1,l,r}\times g_{l,r}+\sum\limits_{j<l}f'_{i-1,j,r}\times(j-1)+\sum\limits_{j>r}f'_{i-1,l,j}\times(n-j)
\]
复杂度即可降为 \(O(n^2q)\)。
代码
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=4e2+10,mod=1e9+7;
int n,m,q,a[N],b[N];
int mx[N][N],f[2][N][N],g[N][N],h[N][N];
int calc(int x){
return x*(x+1)/2;
}
int main(){
freopen(".in","r",stdin);
//freopen(".out","w",stdout);
cin>>n>>q;
for(int i=1;i<=n;i++)cin>>a[i];
copy(a,a+1+n,b),sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+1+m,a[i])-b;
for(int i=1;i<=n;i++){
mx[i][i]=a[i];
for(int j=i+1;j<=n;j++)mx[i][j]=max(mx[i][j-1],a[j]);
}
a[0]=a[n+1]=m+1;
for(int i=1;i<=n;i++){
for(int j=i;j<=n;j++){
if(mx[i][j]<min(a[i-1],a[j+1]))
f[0][i][j]=(b[mx[i][j]]-b[min(a[i-1],a[j+1])]+mod)%mod;
}
}
for(int i=1,now=1,las=0;i<=q;i++,swap(now,las)){
for(int l=1;l<=n;l++)for(int r=n;r>=l;r--){
g[l][r]=(g[l-1][r]+f[las][l][r]*(l-1ll))%mod;
h[l][r]=(h[l][r+1]+1ll*f[las][l][r]*(n-r))%mod;
}
for(int l=1;l<=n;l++)for(int r=l;r<=n;r++){
f[now][l][r]=(1ll*f[las][l][r]*(calc(r-l+1)+calc(l-1)+calc(n-r))+g[l-1][r]+h[l][r+1])%mod;
}
}
for(int i=1;i<=n;i++){
int ans=0;
for(int l=1;l<=i;l++){
for(int r=i;r<=n;r++){
(ans+=f[q&1][l][r])%=mod;
}
}
cout<<ans<<' ';
}
return 0;
}