Comet OJ#12E Ternary String Counting 解题报告
Comet OJ - Contest #12 Ternary String Counting](https://cometoj.com/contest/71/problem/E?problem_id=4020)
题意
有一个长为 \(n\) 的三进制串 ( 字符集为 1, 2, 3 ),
有 \(m\) 个形如 \((l,r,x)\) 的限制条件, 表示区间 \(l,r\) 中有且仅有 \(x\) 个字符 \(1 \le x \le 3\),
求满足这些限制条件的字符串个数 \((\bmod 10^9+7)\).
思路
考虑 dp.
先不考虑限制,
设 \(f[i][j][k][l]\) 为填到第 \(i\) 位, 字符 1 最后一次出现在 \(j\) , 字符 2 最后一次出现在 \(k\), 字符 3 最后一次出现在 \(l\) 时的方案数, 复杂度为 \(O(n^4)\).
我们发现, 如果当前填到第 \(i\) 位, 则 \(j,k,l\) 中一定有一个等于 \(i\), 并且, 我们其实并不在乎哪个字符放在了哪里, 我们只在乎第几个字符放在了哪里,
所以, 设 \(f[i][j][k]\) 为当前填到第 \(i\) 位, 第二个字符最后一次出现在 \(j\) 第三个字符最后一次出现在 \(k\) 时的方案数.
\(f[i][j][k]\) 可以转移到
- \(f[i+1][j][k]\) ( \(i+1\) 位置放第一个字符 )
- \(f[i+1][i][k]\) ( \(i+1\) 位置放第二个字符 )
- \(f[i+1][i][j]\) ( \(i+1\) 位置放第三个字符 )
再考虑题目中的限制, 实际上就是限制了 dp 过程中 \(j\) 和 \(k\) 的取值范围, 我们分类讨论一下. 考虑限制\((l,r,x)\), 设 $i =r $,
- \(x=1\), 则 \(j<l.\)
- \(x=2\), 则 \(j \ge l, k < l\)
- \(x=3\), 则 \(j > l, k \ge l\)
然后我们只需在 dp 时据此限制一下 \(j,k\) 的大小就行了.
这样, 我们就得到了一个 \(O(n^3)\) 的 dp, 但还是不能满足 \(n\le5000\) 的数据范围.
我们考虑把这个转移过程形象化 :
有 \(i\) 层, 每一层有一个平面, 平面上的横坐标为 \(k\), 纵坐标为 \(j\),
那么, 按照上面的转移方程, 每一层的状态只能从上一层转移过来,
第一个转移就是直接从上一层的对应点转移, 第二, 三个转移分别是对上一层的 一列 和 一行求和,
而 \(j,k\) 的取值范围就相当于在这一层画了一个矩阵, 只有矩阵内的状态才是合法的.
所以, 现在我们所需要的操作有,
- 对一行或一列求和.
- 把除了一个特定区域外的值清空.
第一反应估计是要用一个数据结构来维护,
但我们考虑一下, 每次会更新的值其实只有 \(j=i-1\) 那一行, 其他的状态都只能从上一层的对应状态获取,
那么也就是说, 如果一个状态在当前被清空了, 那么它之后永远都是 \(0\), 因为它不会再被更新了,
所以, 我们可以开两个数组来维护每行每列的值: \(lsum[i]\) 表示第 \(i\) 行的和, \(csum[i]\) 表示第 \(i\) 列的和, 并维护每一行的有效区域 (没有被清空过) 的左端点和右端点, 并用一个变量 \(lj\) 维护当前最小的有效行,
(由于行数 \(j\) 是随着 \(i\) 递增的, 所以我们无法维护最大的有效行).
因为有效区域是不断减小的, 所以每一行的左右端点最多移动 \(n\) 次, 总共就是 \(n^2\) 次, 所以时间复杂度为 \(O(n^2)\).
代码
#include<bits/stdc++.h>
using namespace std;
const int _=5e3+7;
const int mod=1e9+7;
int T,n,m,lsum[_],csum[_],minj[_],maxj[_],mink[_],maxk[_],f[_][_];
int t[_],lj,rk[_],lk[_];
int dif(int x,int y){ return ((x-y)%mod+mod)%mod; }
void del(int j,int k){
lsum[j]=dif(lsum[j],f[j][k]);
csum[k]=dif(csum[k],f[j][k]);
f[j][k]=0;
}
void clear(int i){
for(int j=lj;j<=i-1;j++){
if(j<minj[i]||j>maxj[i]){
for(int k=0;k<=n;k++) del(j,k);
lk[j]=n; rk[j]=0;
}
else{
for(int k=lk[j];k<mink[i];k++) del(j,k);
for(int k=rk[j];k>maxk[i];k--) del(j,k);
lk[j]=max(lk[j],mink[i]);
rk[j]=min(rk[j],maxk[i]);
}
}
lj=max(lj,minj[i]);
}
int main(){
//freopen("x.in","r",stdin);
//freopen("x.out","w",stdout);
cin>>T;
while(T--){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++){ // 初始化
for(int j=0;j<=n;j++) f[i][j]=0;
lsum[i]=csum[i]=0;
minj[i]=mink[i]=0;
maxj[i]=maxk[i]=n;
lk[i]=0; rk[i]=n;
}
int l,r,x;
for(int i=1;i<=m;i++){ // 限制 j 和 k 的取值范围.
scanf("%d%d%d",&l,&r,&x);
if(x==1) maxj[r]=min(maxj[r],l-1);
else if(x==2){
minj[r]=max(minj[r],l);
maxk[r]=min(maxk[r],l-1);
}
else{
minj[r]=max(minj[r],l+1);
mink[r]=max(mink[r],l);
}
}
lj=0;
f[0][0]=1;
lsum[0]=csum[0]=1;
for(int i=1;i<=n;i++){
for(int k=0;k<i;k++) t[k]=(lsum[k]+csum[k])%mod;
if(minj[i]<=i-1&&maxj[i]>=i-1){
for(int k=mink[i];k<=min(maxk[i],max(0,i-2));k++){ // 更新 f[i][i-1][k]
f[i-1][k]=(f[i-1][k]+t[k])%mod;
csum[k]=(csum[k]+t[k])%mod;
lsum[i-1]=(lsum[i-1]+t[k])%mod;
}
}
clear(i); // 将不合法的状态清零
}
int ans=0;
for(int i=0;i<=n;i++) ans=(ans+lsum[i])%mod;
printf("%d\n",ans);
}
return 0;
}