P6773 [NOI2020] 命运(线段树合并维护整体 dp)
P6773
一眼看出 \(\rm dp\) 方程
设 \(dp[i][j]\) 表示以 \(i\) 为根的子树,子树内所有未覆盖路径中结尾点最小的深度为 \(j\)。
考虑合并两个子树
\[dp[u][i] \times dp[v][j] \to dp[u][\max(i,j)]
\]
然后可以得到一个 \(O(n^2)\) 的算法。但这玩意只能得 \(\rm 32pts\)。
再考虑对 \(m\) 条路径进行容斥,可以得到一个 \(O(2^m\log^2)\) 的算法
这样就能获得 \(\rm 64pts\)
上面那个 \(O(n^2)\) 的 dp 方程非常不好看,我们化简一下
\[dp[x][j] = dp[x][j]\times sum[v][i] + dp[v][j] \times sum[x][j-1]
\]
显然这个东西是不能再优化了,毕竟空间复杂度都不太够
考虑整体 \(\rm dp\),用线段树合并维护一下乘法标记就行……
而且我写的时候算的空间复杂度是 \(\rm 1024MB\) 但实际只用了\(\rm 224MB\)
#include<bits/stdc++.h>
using namespace std;
// #define INF 1<<30
#define ill long long
#define pb emplace_back
template<typename _T>
inline void read(_T &x)
{
x= 0 ;int f =1;char s=getchar();
while(s<'0' ||s>'9') {f =1;if(s == '-')f =- 1;s=getchar();}
while('0'<=s&&s<='9'){x = (x<<3) + (x<<1) + s - '0';s= getchar();}
x*=f;
}
const int INF = 1 << 30;
const int mod = 998244353;
const int np = 5e5 + 5;
int head[np],ver[np*2],nxt[np*2];
//int dp[3005][3005];
int dep[np];
int tit,m;
vector<int> vec[np];
inline void add(int x,int y)
{
ver[++tit] = y;
nxt[tit] = head[x];
head[x] = tit;
}
struct OPT{
int l,r;
int add;
};
const int cp = np * 25+ 100;
int ls[cp],rs[cp];
int sum[cp],tag[cp];
int top;
int cnt,rot[np],n;
inline int New()
{
++cnt;
tag[cnt] = 1;
return cnt;
}
inline void mul(int &a,int b)
{
ill tmp = a;
tmp *= b;
tmp %= mod;
a = tmp;
// return 1;
}
inline void maketag(int x,int val)
{
// sum[x] *= val;
mul(sum[x],val);
// sum[x] %= mod;
if(x) mul(tag[x],val);
}
inline void pushdown(int x)
{
if(tag[x]!=1)
{
maketag(ls[x],tag[x]);
maketag(rs[x],tag[x]);
tag[x] = 1;
}
}
inline void pushup(int x)
{
sum[x] = sum[ls[x]] + sum[rs[x]];
sum[x] %= mod;
}
inline void upd(int &qt,int l,int r,int pos,int val)
{
(!qt)&&(qt=New());
if(l == r)
{
sum[qt] += val;
sum[qt] %= mod;
return;
}
int mid = l + r >> 1;
pushdown(qt);
if(pos <= mid) upd(ls[qt],l,mid,pos,val);
else upd(rs[qt],mid+1,r,pos,val);
pushup(qt);
}
inline int query(int qt,int l,int r,int L,int R)
{
if(!qt) return 0;
if(L <= l && r <= R)
{
return sum[qt];
}
int mid = l + r >> 1;
pushdown(qt);
int lsum(0),rsum(0);
if(L<=mid) lsum = query(ls[qt],l,mid,L,R);
if(R >= mid+1) rsum = query(rs[qt],mid+1,r,L,R);
return (lsum + rsum)%mod;
}
inline int Merge(int qu,int qv,int u,int v,int l,int r,int &sumu,int &sumv)
{
if(!u&&!v) return 0;
if(!u)
{
// if(sumu){
// tag[v] *= sumu;
mul(tag[v],sumu);
tag[v] %= mod;
// if(!tag[v]) tag[v] = 1;
// }
sumv += sum[v];
sumv %= mod;
// sum[v] *= sumu;
mul(sum[v],sumu);
sum[v] %= mod;
return v;
}
if(!v)
{
// if(sumv){
// tag[u] *= sumv;
mul(tag[u],sumv);
tag[u] %= mod;
// if(!tag[u]) tag[u] = 1;
sumu += sum[u];
sumu %= mod;
// sum[u] *= sumv;
mul(sum[u],sumv);
sum[u] %= mod;
// }
return u;
}
if(l == r)
{
int a1 = sum[u];
int qop = (sum[v] + sumv)%mod;
mul(sum[u],qop);
// sum[u] *= ();
// sum[u] %= mod;
ill opt = sumu;
opt *= sum[v];
opt %= mod;
sum[u] += opt;//sum[v] * sumu;
sum[u] %= mod;
sumv += sum[v];
sumv %= mod;
sumu += a1;
sumu %= mod;
return u;
}
// if(l == r)
int mid = l + r >> 1;
pushdown(u),pushdown(v);
ls[u] = Merge(qu,qv,ls[u],ls[v],l,mid,sumu,sumv);
rs[u] = Merge(qu,qv,rs[u],rs[v],mid+1,r,sumu,sumv);
pushup(u);
return u;
}
inline void dfs(int x,int ff)
{
// dp[x][0] = 1;
int minn = 0;
for(auto q:vec[x]) minn = max(minn,dep[q]);
upd(rot[x],0,n,minn,1);
for(int i=head[x],v;i;i=nxt[i])
{
v = ver[i];
if(v == ff) continue;
dfs(v,x);
int sumu(0),sumv(0);
Merge(rot[x],rot[v],rot[x],rot[v],0,n,sumu,sumv);
}
if(x!=1)
{
// int val = sum[qt[x]];
int op = query(rot[x],0,n,0,dep[x]-1);
// printf("%lld\n",op);
upd(rot[x],0,n,0,op);
// upd(x,0,val);
// for(int i=0;i < dep[x];i ++) dp[x][0] += dp[x][i],dp[x][0] %= mod;//,MOD(dp[x][0]);
}
}
inline void dfs0(int x,int ff)
{
dep[x] = dep[ff] + 1;
for(int i=head[x],v;i;i=nxt[i])
{
v = ver[i];
if(v == ff) continue;
dfs0(v,x);
}
}
signed main()
{
read(n);
for(int i=1,a,b;i <= n-1;i ++)
{
read(a),read(b);
add(a,b);
add(b,a);
}
dfs0(1,0);
read(m);
for(int i=1,v,u;i <= m;i ++)
{
read(u),read(v);
if(dep[v] < dep[u])
{
vec[u].pb(v);
}
else vec[v].pb(u);
}
dfs(1,0);
printf("%d",query(rot[1],0,n,0,0));
}
辰星凌姐姐这个题拿了 \(\rm 16pts\)。
要努力超越我心中的目标哦。