P6773 [NOI2020] 命运(线段树合并维护整体 dp)

P6773

一眼看出 \(\rm dp\) 方程

\(dp[i][j]\) 表示以 \(i\) 为根的子树,子树内所有未覆盖路径中结尾点最小的深度为 \(j\)

考虑合并两个子树

\[dp[u][i] \times dp[v][j] \to dp[u][\max(i,j)] \]

然后可以得到一个 \(O(n^2)\) 的算法。但这玩意只能得 \(\rm 32pts\)

再考虑对 \(m\) 条路径进行容斥,可以得到一个 \(O(2^m\log^2)\) 的算法

这样就能获得 \(\rm 64pts\)

上面那个 \(O(n^2)\) 的 dp 方程非常不好看,我们化简一下

\[dp[x][j] = dp[x][j]\times sum[v][i] + dp[v][j] \times sum[x][j-1] \]

显然这个东西是不能再优化了,毕竟空间复杂度都不太够

考虑整体 \(\rm dp\),用线段树合并维护一下乘法标记就行……

而且我写的时候算的空间复杂度是 \(\rm 1024MB\) 但实际只用了\(\rm 224MB\)

#include<bits/stdc++.h>

using namespace std;

// #define INF 1<<30
#define ill long long
#define pb emplace_back

template<typename _T>
inline void read(_T &x)
{
    x= 0 ;int f =1;char s=getchar();
    while(s<'0' ||s>'9') {f =1;if(s == '-')f =- 1;s=getchar();}
    while('0'<=s&&s<='9'){x = (x<<3) + (x<<1) + s - '0';s= getchar();}
    x*=f;
}
const int INF = 1 << 30;
const int mod = 998244353;
const int np = 5e5 + 5;
int head[np],ver[np*2],nxt[np*2];
//int dp[3005][3005];
int dep[np];
int tit,m;
vector<int> vec[np];

inline void add(int x,int y)
{
    ver[++tit] = y;
    nxt[tit] = head[x];
    head[x] = tit;
}

struct OPT{
    int l,r;
    int add;
};
const int cp = np * 25+ 100;
int ls[cp],rs[cp];
int sum[cp],tag[cp];
int top;
int cnt,rot[np],n;
inline int New()
{
    ++cnt;
    tag[cnt] = 1;
    return cnt;
}

inline void mul(int &a,int b)
{
	ill tmp = a;
	tmp *= b;
	tmp %= mod;
	a = tmp;
//	return 1;
}

inline void maketag(int x,int val)
{
//    sum[x] *= val;
    mul(sum[x],val);
//    sum[x] %= mod;
    if(x) mul(tag[x],val);
}
inline void pushdown(int x)
{
    if(tag[x]!=1)
    {
        maketag(ls[x],tag[x]);
        maketag(rs[x],tag[x]);
        tag[x] = 1;
    }
}
inline void pushup(int x)
{
    sum[x] = sum[ls[x]] + sum[rs[x]];
    sum[x] %= mod;
}

inline void upd(int &qt,int l,int r,int pos,int val)
{
    (!qt)&&(qt=New());
    if(l == r)
    {
       sum[qt] += val;
       sum[qt] %= mod;
       return;
    }
    int mid = l + r >> 1;
    pushdown(qt);
    if(pos <= mid) upd(ls[qt],l,mid,pos,val);
    else upd(rs[qt],mid+1,r,pos,val);
    pushup(qt);
}
inline int query(int qt,int l,int r,int L,int R)
{
    if(!qt) return 0;
    if(L <= l && r <= R)
    {
        return sum[qt];
    }
    int mid = l + r >> 1;
    pushdown(qt);
    int lsum(0),rsum(0);
    if(L<=mid) lsum = query(ls[qt],l,mid,L,R);
    if(R >= mid+1) rsum = query(rs[qt],mid+1,r,L,R);
    return (lsum + rsum)%mod;
}
inline int Merge(int qu,int qv,int u,int v,int l,int r,int &sumu,int &sumv)
{
    if(!u&&!v) return 0;
    if(!u)
    {
//        if(sumu){
//	    tag[v] *= sumu;
	    mul(tag[v],sumu);
	    tag[v] %= mod;
	    
//		if(!tag[v]) tag[v] = 1;	
//		}
        sumv += sum[v];
        sumv %= mod;
//        sum[v] *= sumu;
        mul(sum[v],sumu);
        sum[v] %= mod;
        return v;
    }
    if(!v)
    {
//    	if(sumv){
//	    tag[u] *= sumv;
	    mul(tag[u],sumv);
	    tag[u] %= mod;
//	    if(!tag[u]) tag[u] = 1;
	    sumu += sum[u];
	    sumu %= mod;
//	    sum[u] *= sumv;
	    mul(sum[u],sumv);
		sum[u] %= mod;    		
//		}
        return u;
    }
    if(l == r)
    {
    	int a1 = sum[u];
    	int qop = (sum[v] + sumv)%mod;
    	mul(sum[u],qop);
//        sum[u] *= ();
//        sum[u] %= mod;
		ill opt = sumu;
		opt *= sum[v];
		opt %= mod;
        sum[u] += opt;//sum[v] * sumu;
        sum[u] %= mod;
        sumv += sum[v];
        sumv %= mod;
        sumu += a1;
        sumu %= mod;
        return u;
    }
    // if(l == r)
    int mid = l + r >> 1;
    pushdown(u),pushdown(v);
    ls[u] = Merge(qu,qv,ls[u],ls[v],l,mid,sumu,sumv);
    rs[u] = Merge(qu,qv,rs[u],rs[v],mid+1,r,sumu,sumv);
    pushup(u);
    return u;
}

inline void dfs(int x,int ff)
{
    // dp[x][0] = 1;
    int minn = 0;
    for(auto q:vec[x]) minn = max(minn,dep[q]);
	upd(rot[x],0,n,minn,1); 
    for(int i=head[x],v;i;i=nxt[i])
    {
        v = ver[i];
        if(v ==  ff) continue;
        dfs(v,x);
        int sumu(0),sumv(0);
        Merge(rot[x],rot[v],rot[x],rot[v],0,n,sumu,sumv);
    }
	if(x!=1)
	{
        // int val = sum[qt[x]];
        int op = query(rot[x],0,n,0,dep[x]-1);
//        printf("%lld\n",op);
        upd(rot[x],0,n,0,op);
        // upd(x,0,val);
		// for(int i=0;i < dep[x];i ++) dp[x][0] += dp[x][i],dp[x][0] %= mod;//,MOD(dp[x][0]);   
	}
}

inline void dfs0(int x,int ff)
{
    dep[x] = dep[ff] + 1;
    for(int i=head[x],v;i;i=nxt[i])
    {
        v = ver[i];
        if(v == ff) continue;
        dfs0(v,x);
    }
}

signed main()
{
    read(n);
    for(int i=1,a,b;i <= n-1;i ++)
    {
        read(a),read(b);    
        add(a,b);
        add(b,a);
    }
    dfs0(1,0);
    read(m);
    for(int i=1,v,u;i <= m;i ++)
    {
        read(u),read(v);
        if(dep[v] < dep[u])
        {
            vec[u].pb(v);
        }
        else vec[v].pb(u);
    }
    dfs(1,0);
    printf("%d",query(rot[1],0,n,0,0));
    
}

辰星凌姐姐这个题拿了 \(\rm 16pts\)

要努力超越我心中的目标哦。

posted @ 2021-09-28 20:33  ·Iris  阅读(84)  评论(0编辑  收藏  举报