【洛谷P7324】表达式求值
题目
题目链接:https://www.luogu.com.cn/problem/P7324
定义二元操作符 <
:对于两个长度都为 \(n\) 的数组 \(A, B\)(下标从 \(1\) 到 \(n\)),\(A\)<
\(B\) 的结果也是一个长度为 \(n\) 的数组,记为 \(C\)。则有 \(C[i] = \min(A[i], B[i])\)(\(1 \le i \le n\))。
定义二元操作符 >
:对于两个长度都为 \(n\) 的数组 \(A, B\)(下标从 \(1\) 到 \(n\)),\(A\)>
\(B\) 的结果也是一个长度为 \(n\) 的数组,记为 \(C\)。则有 \(C[i] = \max(A[i], B[i])\)(\(1 \le i \le n\))。
现在有 \(m\)(\(1 \le m \le 10\))个长度均为 \(n\) 的整数数组 \(A_0, A_1, \ldots , A_{m-1}\)。给定一个待计算的表达式 \(E\),其满足 \(E\) 中出现的每个操作数都是 \(A_0, A_1, \ldots , A_{m-1}\) 其中之一,且 \(E\) 中只包含 <
和 >;
两种操作符(<
和 >
的运算优先级相同),因此该表达式的结果值也将是一个长度为 \(n\) 的数组。
特殊地,表达式 \(E\) 中还可能出现操作符 ?
,它表示该运算符可能是 <
也可能是 >
。因此若表达式中有 \(t\) 个 ?
,则该表达式可生成 \(2^t\) 个可求确定值的表达式,从而可以得到 \(2^t\) 个结果值,你的任务就是求出这 \(2^t\) 个结果值(每个结果都是一个数组)中所有的元素的和。你只需要给出所有元素之和对 \({10}^9 + 7\) 取模后的值。
\(n,|S|\leq 5\times 10^4;m\leq 10\)。
思路
首先根据给出的表达式建立表达式树。
数组的每一位之间是互相独立的,所以考虑在表达式树上预处理出来所有情况,然后就可以枚举每一位计算答案。
但是如果直接枚举大小排列的话复杂度是 \(O(m!|S|)\),显然不可接受。
观察到本题只关心数字之间的大小关系,有一个经典的 trick 是考虑一个阈值 \(k\),把不超过 \(k\) 的设为 \(0\),超过 \(k\) 的设为 \(1\),然后进行完操作只需要看得到的是 \(0\) 还是 \(1\) 就可以判断答案与 \(k\) 的大小关系。
所以考虑设 \(f[x][s][0/1]\) 表示表达式树上点 \(x\) 为根的子树内,\(m\) 个表达式大小关系状态为 \(s\),且进行玩子树的表达式后答案在小的那部分 / 大的那部分的方案数。
其中一个状态 \(s\) 二进制下第 \(i\) 位表示第 \(i\) 个序列前这一位(上文提到每一个序列的每一位是独立的)是在小的部分 / 大的部分。
考虑转移,记左右儿子分别为 \(lc,rc\),拿当前节点的字符为 <
举例,有
>
的话就把 \(\min\) 改为 \(\max\);?
就两个都要转移。
统计答案的部分就直接枚举每一位,然后把 \(m\) 个序列这一位的数字从小到大排序,依次枚举每一位,假设把这一位划进小的部分前的状态为 \(S\),划进去后状态为 \(T\),那么贡献即为
时间复杂度 \(O(2^m|S|+nm\log m)\)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=50010,M=10,MOD=1e9+7;
int n,m,rt,len,ans,I,a[M][N],b[M],ch[N][2],f[N][1<<M][2];
char s[N];
stack<int> st;
bool cmp(int x,int y)
{
return a[x][I]<a[y][I];
}
void build()
{
for (int i=len;i>=1;i--)
{
if (s[i]!='(') st.push(i);
if (s[i]=='(')
{
int x=st.top(); st.pop();
while (s[st.top()]!=')')
{
int y=st.top(); st.pop();
int z=st.top(); st.pop();
ch[y][0]=z; ch[y][1]=x; x=y;
}
st.pop(); st.push(x);
}
}
rt=st.top();
}
void dfs(int x)
{
if (isdigit(s[x]))
{
int num=s[x]-48;
for (int i=0;i<(1<<m);i++)
f[x][i][(i>>num)&1]=1;
return;
}
int lc=ch[x][0],rc=ch[x][1];
dfs(lc); dfs(rc);
for (int i=0;i<(1<<m);i++)
for (int j=0;j<=1;j++)
for (int k=0;k<=1;k++)
{
int sum=1LL*f[lc][i][j]*f[rc][i][k]%MOD;
if (s[x]=='<' || s[x]=='?') f[x][i][min(j,k)]=(f[x][i][min(j,k)]+sum)%MOD;
if (s[x]=='>' || s[x]=='?') f[x][i][max(j,k)]=(f[x][i][max(j,k)]+sum)%MOD;
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=0;i<m;i++)
for (int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
scanf("%s",s+2);
len=strlen(s+2)+1;
s[1]='('; s[++len]=')';
build();
dfs(rt);
for (int i=0;i<m;i++) b[i]=i;
for (int i=1;i<=n;i++)
{
I=i;
sort(b,b+m,cmp);
for (int j=0,S=(1<<m)-1;j<m;j++)
{
ans=(ans+1LL*(f[rt][S^(1<<b[j])][0]-f[rt][S][0])*a[b[j]][i])%MOD;
S^=(1<<b[j]);
}
}
printf("%d\n",(ans+MOD)%MOD);
return 0;
}