luogu P7324 [WC2021] 表达式求值
https://www.luogu.com.cn/problem/P7324
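70 分的 DP 非常好想，也很好写。
关键观察：我们不用关心具体的值是多少，只需关心同一位置上 m 个值的相对大小。
对每个位置，把这 m 个值从小到大排成 $b_1 \le \dots \le b_m$ 并记 $b_0 = 0$。数值非负时，表达式的值满足 $v = \sum_{j=1}^{m} (b_j - b_{j-1})\,[v \ge b_j]$。
把值 $\ge b_j$ 的数组看成 1、其余看成 0，则 `<`（取 min）就是与，`>`（取 max）就是或，$[v \ge b_j]$ 正是这个布尔表达式的真值。
于是对全部 $2^m$ 种真值表各做一次树形 DP，预处理 $g[T]$：取 1 的数组集合为 $T$ 时，使结果为 1 的 `?` 取法数。
答案为 $\sum_i \sum_j (b_j - b_{j-1})\, g[T_j]$，其中 $T_j$ 是该位置上值 $\ge b_j$ 的数组集合。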
code:
#include<bits/stdc++.h> // 11 : < 12 : > 13 : ?
// Capacity of every per-node / per-position / per-character array below.
constexpr int N = 100050;
// Modulus applied to every count and to the final answer.
constexpr int mod = 1000000007;
using ll = long long;
using namespace std;
// add(x, y): x = (x + y) mod `mod`.
// Precondition: x and y already lie in [0, mod), so one subtraction suffices
// and x + y (< 2 * mod) cannot overflow an int.
void add(int &x, int y) {
    x += y;
    if (x >= mod) x -= mod;
}
int val[N], ch[N][2], sz;
char st[N];
int build(int l, int r) {
while(st[l] == '(' && st[r] == ')') l ++, r --;
if(l > r) return 0;
int rt = ++ sz;
if(l == r) {
val[rt] = st[l] - '0' + 1;
return rt;
}
int gs = 0; // printf("%d %d\n", l, r);
for(int i = r; i >= l; i --) {
if(st[i] == ')') gs ++;
if(st[i] == '(') gs --;
// printf("$%d ", gs);
if(gs == 0 && !(st[i] >= '0' && st[i] <= '9') && st[i] != '(' && st[i] != ')') {
if(st[i] == '<') val[rt] = 11;
if(st[i] == '>') val[rt] = 12;
if(st[i] == '?') val[rt] = 13;
ch[rt][0] = build(l, i - 1);
ch[rt][1] = build(i + 1, r);
return rt;
}
}
return rt;
}
int a[15][N], f[N][2];
int n, m, rt, ok[15];
#define ls ch[u][0]
#define rs ch[u][1]
// dfs(u): tree DP for one fixed truth assignment ok[1..m].
// Afterwards f[u][b] = number of ways (mod `mod`) to turn every '?' in the
// subtree of u into '<' or '>' so that the subtree evaluates to bit b, where
// leaf k evaluates to ok[k], '<' acts as AND and '>' acts as OR on these bits.
void dfs(int u) {
    f[u][0] = f[u][1] = 0;
    if (val[u] <= m) {  // leaf: its bit is fixed, exactly one way
        f[u][ok[val[u]]] = 1;
        return;
    }
    dfs(ls);
    dfs(rs);
    // pr[x][y]: ways for (left child evaluates to x, right child evaluates to y).
    long long pr[2][2];
    for (int x = 0; x < 2; ++x)
        for (int y = 0; y < 2; ++y)
            pr[x][y] = 1ll * f[ls][x] * f[rs][y] % mod;
    const bool as_min = (val[u] == 11 || val[u] == 13);  // '<', or '?' chosen as '<'
    const bool as_max = (val[u] == 12 || val[u] == 13);  // '>', or '?' chosen as '>'
    long long zero = 0, one = 0;
    if (as_min) {  // AND: 1 only when both sides are 1
        zero += pr[0][0] + pr[0][1] + pr[1][0];
        one += pr[1][1];
    }
    if (as_max) {  // OR: 0 only when both sides are 0
        zero += pr[0][0];
        one += pr[0][1] + pr[1][0] + pr[1][1];
    }
    f[u][0] = static_cast<int>(zero % mod);
    f[u][1] = static_cast<int>(one % mod);
}
pair<int, int> b[15];
ll g[1 << 12], ans;
void solve() {
int lim = (1 << m) - 1;
for(int S = 0; S <= lim; S ++) {
for(int i = 1; i <= m; i ++) ok[i] = ((S >> (i - 1)) & 1);
dfs(rt); g[S] = f[rt][1];
// printf("%d %lld\n", S, g[S]);
}
for(int i = 1; i <= n; i ++) {
for(int j = 1; j <= m; j ++) b[j] = make_pair(a[j][i], j);
sort(b + 1, b + 1 + m);
b[0] = make_pair(0, 0);
int pos = 1, S = 0;
for(int j = 1; j <= m; j ++) {
while(pos < j && b[pos].first < b[j].first) S |= (1 << (b[pos ++].second - 1));
//printf(" %d %d\n", j, S);
(ans += 1ll * g[lim ^ S] * (b[j].first - b[j - 1].first) % mod) %= mod;
}
}
}
// Reads n, m, the m arrays (array k is stored in a[k][1..n]) and the
// expression (into st[1..]), builds the tree, and prints the answer.
int main() {
    scanf("%d%d", &n, &m);
    for (int k = 1; k <= m; ++k) {
        int *row = a[k];
        for (int pos = 1; pos <= n; ++pos) scanf("%d", row + pos);
    }
    scanf(" %s", st + 1);
    const int len = static_cast<int>(strlen(st + 1));
    rt = build(1, len);
    solve();
    printf("%lld", ans);
    return 0;
}