「SLYZ Online Judge#74 Be」

只有学校里的电脑才能看的题目

昊哥从牛客搬的,懒得找原题了

题意就是多组询问,每次询问一条树上路径,将这条路径上的点拿下来做\(0/1\)背包,求使得点权和为\(K\)的倍数的方案有几种

\(n<=200000,K<=50,Q<=500000\)

首先这确实是一个背包,我们可以直接用树剖和线段树来维护这些路径,线段树上每个节点存一个数组\(dp[i][j]\),表示\(i\)这个区间选择出的数\(mod\ K=j\)的方案数

之后发现我们每次合并都是一个卷积,于是复杂度\(O(Qk^2logn)\),可以用\(NTT\)优化到\(O(Qklognlogk)\),但是并没有什么用

正解点分治,我们把询问离线,处理好每一组询问在那一个分治中心被处理到

处理当前分治重心的时候,我们直接求出每一个点到分治重心的\(dp\)数组,之后合并答案,由于这个时候我们只需要求\(dp[0]\),所以合并答案\(O(k)\)时间内就能完成

代码

#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define LL long long
#define re register
#define inf 999999999
#define maxn 500005
const LL mod=998244353;
inline int read() {
	char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[maxn<<1];
struct Ask{int x,y,l,rk;}q[maxn];
std::vector<int> v[maxn],t[maxn];
int sum[maxn],mx[maxn],vis[maxn],col[maxn];
int head[maxn],dfn[maxn],st[maxn],Ans[maxn],a[maxn];
int n,m,num,S,now,rt,R,__,Top,K;
inline int cmp(Ask A,Ask B) {return dfn[A.l]<dfn[B.l];}
inline void add(int x,int y) {e[++num].v=y;e[num].nxt=head[x];head[x]=num;}
LL dp[2][maxn][50];
void getroot(int x,int fa) {
	sum[x]=1,mx[x]=0;
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]||e[i].v==fa) continue;
		getroot(e[i].v,x);sum[x]+=sum[e[i].v];
		if(sum[e[i].v]>mx[x]) mx[x]=sum[e[i].v];
	}
	mx[x]=max(mx[x],S-sum[x]);
	if(mx[x]<now) now=mx[x],rt=x;
}
void paint(int x,int fa,int c,int now) {
	col[x]=c;st[++Top]=x;
	for(re int i=0;i<t[x].size();i++) {
		if(col[q[t[x][i]].x]&&col[q[t[x][i]].x]!=c) q[t[x][i]].l=now;
		if(col[q[t[x][i]].y]&&col[q[t[x][i]].y]!=c) q[t[x][i]].l=now;
	}
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]||fa==e[i].v) continue;
		paint(e[i].v,x,c,now);
	}
}
void rebuild(int x) {
	vis[x]=1;dfn[x]=++__;
	int cnt=1;Top=0;col[x]=1;
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]) continue;
		cnt++;paint(e[i].v,0,cnt,x);
	}
	while(Top) col[st[Top--]]=0;col[x]=0;
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]) continue;
		S=sum[e[i].v],now=inf,getroot(e[i].v,0);
		v[x].push_back(rt),rebuild(rt);
	}
}
void getdis(int x,int fa,int o) {
	if(o) st[++Top]=x;
	for(re int i=0;i<K;i++) 
		dp[o][x][i]=dp[o][fa][i];
	for(re int i=0;i<K;i++) 
		dp[o][x][(i+a[x])%K]+=dp[o][fa][i],dp[o][x][(i+a[x])%K]%=mod;
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]||fa==e[i].v) continue;
		getdis(e[i].v,x,o);
	}
}
inline void clear(int x) {
	memset(dp[0][x],0,sizeof(dp[0][x]));
	memset(dp[1][x],0,sizeof(dp[1][x]));
}
void dfs(int x) {
	vis[x]=1;
	Top=0;st[++Top]=x;dp[1][x][0]=1;dp[1][x][a[x]]++,dp[0][x][0]=1;
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]) continue;
		getdis(e[i].v,x,0);getdis(e[i].v,x,1);
	}
	while(q[now].l==x&&now<=m) {
		LL ans=0;int ls=q[now].x,rs=q[now].y;
		for(re int i=0;i<K;i++)
			ans+=(dp[1][ls][i]*dp[0][rs][(K-i)%K]%mod),ans%=mod;
		Ans[q[now].rk]=ans,now++;
	}
	while(Top) clear(st[Top--]);
	for(re int i=0;i<v[x].size();i++) dfs(v[x][i]);
}
signed main() {
	n=read(),K=read();
	for(re int x,y,i=1;i<n;i++) x=read(),y=read(),add(x,y),add(y,x);
	for(re int i=1;i<=n;i++) a[i]=read()%K;m=read();
	for(re int i=1;i<=m;i++) 
		q[i].x=read(),q[i].y=read(),q[i].rk=i,t[q[i].x].push_back(i),t[q[i].y].push_back(i);
	for(re int i=1;i<=m;i++) if(!q[i].l) q[i].l=q[i].x;
	S=n,now=inf,getroot(1,0);R=rt;rebuild(rt);
	std::sort(q+1,q+m+1,cmp);
	memset(vis,0,sizeof(vis));now=1;dfs(R);
	for(re int i=1;i<=m;i++) printf("%d\n",Ans[i]);
	return 0;
}
posted @ 2019-03-03 15:57  asuldb  阅读(136)  评论(1编辑  收藏  举报