description

一棵树每条边可涂为1/0,有q个点对(x,y),问你有多少种涂色方案满足x到y至少有一个为1的边。

solution

\(dp[x][y]\)节点\(x\)内的边已经涂好了,从里面往上的没被覆盖的上端点的最深深度为\(y\)(为\(0\)没有)
\(dp[x][i]=\sum\limits_{j=0}^{dep_x}dp[x][i]*dp[y][j]+\sum\limits_{j=0}^{i}dp[x][i]*dp[y][j]+\sum\limits_{j=0}^{i-1}dp[x][j]*dp[y][i]\)
第一个式子表示\(E(x,y)=1\),后两个式子分别代表最深点是否从子树\(y\)贡献来。
发现树上dp+第二维的前缀求和直接用边合并边修改的线段树合并。
注意到一种叶子可能两棵线段树都有,要更新。

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=1e6+5;
const int M=3e7+5;
namespace IO {
	char buf[1<<23],*p1=buf,*p2=buf;
	#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
	inline int rd() {
		int x=0,f=1;char ch=getchar();
		while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
		while(isdigit(ch)) x=x*10+(ch^48),ch=getchar();
		return x*f;
	}
}
int up,n,m,nxt[N],dep[N],to[N],head[N],ecnt,mxd[N];
void add_edge(int u,int v) {nxt[++ecnt]=head[u];to[ecnt]=v;head[u]=ecnt;}
int rt[N],nd,ls[M],rs[M];
ll sum[M],tag[M];
void init(int u,int fa) {
	dep[u]=dep[fa]+1;up=max(up,dep[u]);
	for(int i=head[u];i;i=nxt[i]) {
		int v=to[i];if(v==fa)continue;
		init(v,u);
	}
}
void P_up(int x) {sum[x]=(sum[ls[x]]+sum[rs[x]])%mod;}
void P_dw(int x) {
	if(tag[x]==1)return;
	int L(ls[x]),R(rs[x]);
	if(L)sum[L]=sum[L]*tag[x]%mod,tag[L]=tag[L]*tag[x]%mod;
	if(R)sum[R]=sum[R]*tag[x]%mod,tag[R]=tag[R]*tag[x]%mod;
	tag[x]=1;
}
void Add(int &x,int p,int l,int r) {
	x=++nd;tag[x]=sum[x]=1;
	if(l==r)return;
	int mid=(l+r)>>1;
	(p<=mid)?Add(ls[x],p,l,mid):Add(rs[x],p,mid+1,r);
}
int Merge(int x,int y,int l,int r,ll sx,ll sy) {
	if(!x&&!y)return 0;
	int o=++nd;tag[o]=1;
	if(!x) {
		sum[o]=(sum[y]*sx)%mod;
		tag[o]=(tag[y]*sx)%mod;
		ls[o]=ls[y];rs[o]=rs[y];
		return o;
	}
	if(!y) {
		sum[o]=(sum[x]*sy)%mod;
		tag[o]=(tag[x]*sy)%mod;
		ls[o]=ls[x];rs[o]=rs[x];
		return o;
	}
	if(l==r) {
		sum[o]=(sum[x]*(sy+sum[y])+sum[y]*sx)%mod;
		return o;
	}
	P_dw(x);P_dw(y);
	int mid=(l+r)>>1;
	ls[o]=Merge(ls[x],ls[y],l,mid,sx,sy);
	rs[o]=Merge(rs[x],rs[y],mid+1,r,(sx+sum[ls[x]])%mod,(sy+sum[ls[y]])%mod);
	P_up(o);
//	printf("%d %lld\n",o,sum[o]);
	return o;
}
ll Sum(int x,int l,int r,int p,int q) {
	if(p<=l&&r<=q) return sum[x];
	P_dw(x);
	int mid=(l+r)>>1;
	if(p<=mid) {
		if(q>mid) return (Sum(ls[x],l,mid,p,q)+Sum(rs[x],mid+1,r,p,q))%mod;
		return Sum(ls[x],l,mid,p,q);
	}
	else return Sum(rs[x],mid+1,r,p,q);
}
ll Query(int x,int l,int r,int p) {
	if(l==r) {return sum[x];}
	P_dw(x);
	int mid=(l+r)>>1;
	return (p<=mid)?Query(ls[x],l,mid,p):Query(rs[x],mid+1,r,p);
}
void solve(int u,int fa) {
	Add(rt[u],mxd[u],0,up);
	for(int i=head[u];i;i=nxt[i]) {
		int v=to[i];if(v==fa)continue;
		solve(v,u);
		rt[u]=Merge(rt[u],rt[v],0,up,0,Sum(rt[v],0,up,0,dep[u]));
	}
}
int main() {
//	freopen("data.in","r",stdin);
	n=IO::rd();
	for(int i=1;i<n;i++) {int u=IO::rd(),v=IO::rd();
	add_edge(u,v);add_edge(v,u);}
	init(1,0);
	m=IO::rd();
	for(int i=1;i<=m;i++) {
		int x=IO::rd(),y=IO::rd();
		if(dep[x]<dep[y])swap(x,y);
		mxd[x]=max(mxd[x],dep[y]);
	}
	solve(1,0);
	printf("%lld",Query(rt[1],0,up,0));
	return 0;
}