【碳硫磷模拟赛】消失的+和* (树形DP)

好久没做过这么恶心的DP题了

题面

题面很简单,有一个计算式,由+号、*号、括号和小于10的正整数组成,现在所有的+*(由于属于违禁词而)都被-号给和谐掉了,现在要求所有可能的原计算式的结果之和。

你知道的信息:计算式总长度 n ∈ [ 1 , 1 0 5 ] n\in[1,10^5] n[1,105](其中保证-号总数 m ≤ 2500 m\leq2500 m2500),原计算式的+号总数 k ∈ [ 0 , m ] k\in[0,m] k[0,m] ,被和谐后的计算式(含括号)。

题解

括号表示了计算间的优先关系,我们可以通过这种关系建棵树,子节点比父节点先算。

然后,设计DP状态: d p [ i ] [ j ] . s u m dp[i][j].sum dp[i][j].sum 表示该子树 i i i 内存在 j j j+号的所有算式结果之和, d p [ i ] [ j ] . c n t dp[i][j].cnt dp[i][j].cnt 表示该子树 i i i 内存在 j j j+号的算式总数。此处 d p [ i ] [ j ] dp[i][j] dp[i][j] 是一个二元组。

经典的树形背包DP枚举+转移思路:记录前面儿子的答案,与下一个儿子合并。此时“前面儿子”不一定两端有括号,但下一个儿子一定是一个整体。

那么对于两个算式间用+号相连( C = A + B C=A+B C=A+B),有转移:
d p [ C ] [ j + k + 1 ] ← ( d p [ A ] [ j ] . s u m ∗ d p [ B ] [ k ] . c n t + d p [ B ] [ k ] . s u m ∗ d p [ A ] [ j ] . c n t   , d p [ A ] [ j ] . c n t ∗ d p [ B ] [ k ] . c n t ) dp[C][j+k+1]\leftarrow (dp[A][j].sum*dp[B][k].cnt+dp[B][k].sum*dp[A][j].cnt~,\\dp[A][j].cnt*dp[B][k].cnt) dp[C][j+k+1](dp[A][j].sumdp[B][k].cnt+dp[B][k].sumdp[A][j].cnt ,dp[A][j].cntdp[B][k].cnt)

但是对于乘法( C = A ∗ ( B ) C=A*(B) C=A(B))的情况就有困难,由于前一个算式不一定两端有括号,所以 B B B 只能乘 A A A 的最后一项。那我们就把 A A A所有情况下的最后一项拿出来求和,记为 g [ A ] [ . . . ] g[A][...] g[A][...](不是二元组),然后可以有一个复杂的转移:
d p [ C ] [ j + k ] ← ( d p [ B ] [ k ] . s u m ∗ g [ A ] [ j ] + ( d p [ A ] [ j ] . s u m − g [ A ] [ j ] ) ∗ d p [ B ] [ k ] . c n t   , d p [ A ] [ j ] . c n t ∗ d p [ B ] [ k ] . c n t ) g [ C ] [ j + k + 1 ] ← d p [ B ] [ k ] . s u m ∗ d p [ A ] [ j ] . c n t g [ C ] [ j + k ] ← g [ A ] [ j ] ∗ d p [ B ] [ k ] . s u m dp[C][j+k]\leftarrow \Big(dp[B][k].sum*g[A][j]+(dp[A][j].sum-g[A][j])*dp[B][k].cnt~,\\ dp[A][j].cnt*dp[B][k].cnt\Big)\\ g[C][j+k+1]\leftarrow dp[B][k].sum*dp[A][j].cnt\\ g[C][j+k]\leftarrow g[A][j]*dp[B][k].sum dp[C][j+k](dp[B][k].sumg[A][j]+(dp[A][j].sumg[A][j])dp[B][k].cnt ,dp[A][j].cntdp[B][k].cnt)g[C][j+k+1]dp[B][k].sumdp[A][j].cntg[C][j+k]g[A][j]dp[B][k].sum

复杂度是经典的树上背包DP时间复杂度, O ( n 2 ) O(n^2) O(n2)

有几点要注意的:

  1. n n n 很大, m m m 很小,说明括号可能很多,得缩掉一些儿子数只有1的废点。
  2. 注意转移的先后顺序。
  3. 注意子树 size ,边界情况卡准。
  4. 回溯的时候,由于在算式两边加上了括号,要把所有的 g [ i ] [ j ] g[i][j] g[i][j] 赋值为 d p [ i ] [ j ] . s u m dp[i][j].sum dp[i][j].sum

CODE

#include<set>
#include<map>
#include<stack>
#include<cmath>
#include<ctime>
#include<queue>
#include<bitset>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 2505
#define LL long long
#define ULL unsigned long long
#define UI unsigned int
#define DB double
#define ENDL putchar('\n')
#define lowbit(x) (-(x) & (x))
#define FI first
#define SE second
#define eps (1e-4)
LL read() {
    LL f=1,x=0;char s = getchar();
    while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
    while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
    return f*x;
}
void putpos(LL x) {
    if(!x) return ;
    putpos(x/10); putchar('0'+(x%10));
}
void putnum(LL x) {
    if(!x) putchar('0');
    else if(x < 0) putchar('-'),putpos(-x);
    else putpos(x);
}
void AIput(LL x,char c) {putnum(x);putchar(c);}
 
const int MOD = 1000000007;
int n,m,s,o,k;
int le;
char ss[100005];
int cnd,sz[MAXN];
struct it{
    int x,y;it(){x=y=0;}
    it(int X,int Y){x=X;y=Y;}
};
it operator + (it a,it b) {return it((a.x+b.x)%MOD,(a.y+b.y)%MOD);}
it Plus(it a,it b) {return it((a.x*1ll*b.y%MOD+a.y*1ll*b.x%MOD)%MOD,a.y*1ll*b.y%MOD);}
it Mult(it a,it b) {return it(a.x*1ll*b.x%MOD,a.y*1ll*b.y%MOD);}
it dp[MAXN][MAXN];
int g[MAXN][MAXN];
int dfs(int ad) {
    if(ss[ad] != '(') {
        int nm = ss[ad]-'0';
        int x = ++ cnd;
        sz[x] = 1;
        for(int i = 1;i <= m;i ++) dp[x][i] = it(),g[x][i] = 0;
        dp[x][0] = it(nm,1);
        g[x][0] = nm;
        return x;
    }
    int le = 0,cc = 1,st = ad;
    vector<int> v;
    v.push_back(0);
    while(cc) {
        ad ++;
        if(ss[ad] != '-') {
            if(ss[ad] == ')') cc --;
            else {
                if(cc == 1) v.push_back(dfs(ad)),le ++;
                if(ss[ad] == '(') cc ++;
            }
        }
    }
    int tl = cnd+1;
    int siz = sz[v[1]],las = v[1];
    for(int i = 2;i <= le;i ++) {
        int y = v[i],p = v[i-1];
        las = y;
        siz += sz[y];
        for(int j = 0;j < siz;j ++) dp[tl][j] = it(),g[tl][j] = 0;
        for(int j = 0;j < sz[y];j ++) {
            for(int k = 0;k < siz-sz[y];k ++) {
                int nm = (dp[y][j].x *1ll* g[p][k] % MOD + (dp[p][k].x+MOD-g[p][k]) % MOD *1ll* dp[y][j].y % MOD) % MOD;
                dp[tl][j+k] = dp[tl][j+k] + it(nm,dp[y][j].y *1ll* dp[p][k].y % MOD);
                dp[tl][j+k+1] = dp[tl][j+k+1] + Plus(dp[y][j],dp[p][k]);
                (g[tl][j+k] += g[p][k] *1ll* dp[y][j].x % MOD) %= MOD;
                (g[tl][j+k+1] += dp[y][j].x *1ll* dp[p][k].y % MOD) %= MOD;
            }
        }
        swap(dp[tl],dp[y]);
        swap(g[tl],g[y]);
        sz[y] = siz;
    }
    for(int i = 0;i < siz;i ++) g[las][i] = dp[las][i].x;
    return las;
}
int main() {
    freopen("operator.in","r",stdin);
    freopen("operator.out","w",stdout);
    le = read();m = read();
    scanf("%s",ss + 1);
    ss[0] = '(';
    ss[le+1] = ')';
    int rt = dfs(0);
//  printf("\n<%d>\n",n);
    AIput(dp[rt][m].x,'\n');
    return 0;
}
posted @ 2021-10-04 22:56  DD_XYX  阅读(25)  评论(0编辑  收藏  举报