HDU 5396 区间DP 数学 Expression
题意:有n个数字,n-1个运算符,每个运算符的顺序可以任意,因此一共有 (n - 1)! 种运算顺序,得到 (n - 1)! 个运算结果,然后求这些运算结果之和 MOD 1e9+7.
分析:
类比最优矩阵链乘,枚举区间[l, r]中最后一个运算符的位置k。
如果运算符为乘法的话,那么根据乘法分配率这个乘法会分配进去。
这个区间中一共有r - l个运算符,其中最后一个运算符已经定了是第k个,左区间[l, k]有k - l个运算符,右区间[k + 1, r]有 r - k - 1 个运算符。
而且左、右区间运算符的先后顺序确定以后,两个区间之间的顺序是互不影响的,因此这样相同的结果一共有C(r - l - 1, k - l)
因此答案还要乘上这个数,d(i, j) += d(i, k) * d(k + 1, r) * C(r - l - 1, k - l) | op[k] = *
但如果是加减法的话就不能直接按照运算符进行区间合并了。
对于左区间的确定的一个运算顺序,右区间一共有 (r - k - 1)! 个运算结果,所以答案累加一个 d(l, k) * (r - k - 1)!
同样地,对于右区间一个确定的操作顺序,左区间对应有 (k - l)! 个运算结果,答案累加一个 d(k + 1, r) * (k - l)!
最后确定两个区间 r - l - 1 个运算符的顺序,最终答案乘上 C(r - l - 1, k - l)
最后总结一下答案就是:
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 7 typedef long long LL; 8 9 const int maxn = 100 + 10; 10 const LL MOD = 1000000007; 11 12 int n; 13 LL a[maxn]; 14 LL fac[maxn], C[maxn][maxn]; 15 char op[maxn]; 16 17 int vis[maxn][maxn]; 18 LL d[maxn][maxn]; 19 20 LL dp(int l, int r) 21 { 22 if(vis[l][r]) return d[l][r]; 23 LL& ans = d[l][r]; 24 ans = 0; 25 vis[l][r] = true; 26 if(l == r) return ans = a[l]; 27 if(l + 1 == r) 28 { 29 if(op[l] == '*') return ans = a[l] * a[r] % MOD; 30 if(op[l] == '+') return ans = (a[l] + a[r]) % MOD; 31 if(op[l] == '-') return ans = (((a[l] - a[r]) % MOD) + MOD) % MOD; 32 } 33 for(int k = l; k < r; k++) 34 { 35 LL t1 = dp(l, k), t2 = dp(k + 1, r); 36 LL t; 37 if(op[k] == '*') 38 { 39 t = t1 * t2 % MOD; 40 t = t * C[r - l - 1][k - l]; 41 ans = (ans + t) % MOD; 42 continue; 43 } 44 45 t1 = t1 * fac[r - k - 1] % MOD; 46 t2 = t2 * fac[k - l] % MOD; 47 if(op[k] == '+') t = (t1 + t2) % MOD; 48 else t = (((t1 - t2) % MOD) + MOD) % MOD; 49 t = t * C[r - l - 1][k - l]; 50 ans = (ans + t) % MOD; 51 } 52 53 return ans; 54 } 55 56 int main() 57 { 58 fac[0] = 1; 59 for(int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % MOD; 60 for(int i = 0; i < maxn; i++) C[i][0] = C[i][i] = 1LL; 61 for(int i = 2; i < maxn; i++) 62 for(int j = 1; j < i; j++) C[i][j] = (C[i-1][j] + C[i-1][j-1]) % MOD; 63 64 while(scanf("%d", &n) == 1 && n) 65 { 66 for(int i = 1; i <= n; i++) scanf("%I64d", a + i); 67 scanf("%s", op + 1); 68 memset(vis, false, sizeof(vis)); 69 memset(vis, 0, sizeof(vis)); 70 printf("%I64d\n", dp(1, n)); 71 } 72 73 return 0; 74 }