[提高组集训2021] 消失的运算符
一、题目
给定一个长度为 \(n\) 的表达式,表达式只出现括号、减号和数字 \(1\sim 9\),设一共有 \(m\) 个减号。
求出把 \(m\) 个减号其中 \(k\) 个替换成加号,\(m-k\) 个替换个乘号的所有表达式之和,答案模 \(1e9+7\)
\(n\leq 10^5,m\leq 2500\)
二、解法
首先考虑没有括号怎么做,注意这里不能直接设 \(f[i][j]\) 表示前 \(i\) 个数字选 \(j\) 个加号的值,因为如果下一个符号是乘号的话你是转移不动的。那么我们特殊处理这种情况,设 \(f[i][j]\) 表示前 \(i\) 个数字选 \(j\) 个加号,不计入最后一段连乘的值,\(g[i][j]\) 表示计入最后一段连乘的值,\(num[i][j]\) 表示方案数。
因为我们的 \(dp\) 记录的是所有情况的和,利用分配律就可以简单转移了。
如果有括号怎么办呢?我们用类似分治的方法建出一棵括号树,把每一层最浅的那些括号段取出来递归下去,合并的时候当成没有括号的情况按顺序合并,复杂度就是树背包的 \(O(m^2)\)
当然本题的难点是实现,我们如何找到最浅的括号段呢?对于每个减号记录左边未匹配的左括号个数 \(match[i]\),那么每一层都是 \(match\) 最小的减号作为分隔符号,以其为端点分治即可。
//I wish that I was bulletproof
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
const int M = 2505;
const int N = 5005;
const int MOD = 1e9+7;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,f[N][M],g[N][M],num[N][M],len[N];
int cnt,l1,l2,val[M],mt[M],tf[M],tg[M],tn[M];
char s[1000005];
void div(int x,int l,int r)
{
if(l==r)
{
f[x][0]=val[l];
num[x][0]=1;
return ;
}
int mi=MOD;vector<int> v;
for(int i=l;i<r;i++)
mi=min(mi,mt[i]);
for(int i=l;i<r;i++)
if(mt[i]==mi) v.push_back(i);
int y=++cnt;
div(y,l,v[0]);
len[x]=len[y];
for(int i=0;i<=len[x];i++)
{
f[x][i]=0;g[x][i]=f[y][i];
num[x][i]=num[y][i];
}
for(int i=0;i<v.size();i++)
{
y=++cnt;
div(y,v[i]+1,(i+1==v.size())?r:v[i+1]);
for(int j=0;j<=len[x];j++)
{
tf[j]=f[x][j];tg[j]=g[x][j];tn[j]=num[x][j];
f[x][j]=g[x][j]=num[x][j]=0;
}
for(int j=0;j<=len[x];j++)
for(int k=0;k<=len[y];k++)
{
//*
int t=j+k;
f[x][t]=(f[x][t]+tf[j]*num[y][k])%MOD;
g[x][t]=(g[x][t]+tg[j]*f[y][k])%MOD;
num[x][t]=(num[x][t]+tn[j]*num[y][k])%MOD;
//+
t=j+k+1;
f[x][t]=(f[x][t]+(tf[j]+tg[j])*num[y][k])%MOD;
g[x][t]=(g[x][t]+tn[j]*f[y][k])%MOD;
num[x][t]=(num[x][t]+tn[j]*num[y][k])%MOD;
}
len[x]+=len[y]+1;
}
for(int j=0;j<=len[x];j++)
f[x][j]=(f[x][j]+g[x][j])%MOD;
}
signed main()
{
freopen("operator.in","r",stdin);
freopen("operator.out","w",stdout);
n=read();m=read();scanf("%s",s+1);
for(int i=1,zz=0;i<=n;i++)
{
if(s[i]=='(') zz++;
else if(s[i]==')') zz--;
else if('0'<=s[i] && s[i]<='9')
val[++l1]=s[i]-'0';
else mt[++l2]=zz;
}
div(cnt=1,1,l1);
printf("%lld\n",f[1][m]);
}