分析

我们如果已经找到了一条满足要求的路径,如果将它继续延伸的话,或值只会增加,所以你不管延伸多长都可以的。

所以考虑二分答案,先确定路径终点,二分最短能满足要求的延伸长度,然后该次贡献的答案就能通过深度轻松算出。

考虑二分的 check,我们用倍增来计算,\(f[i][j]\) 表示倍增到的点编号,\(g[i][j]\) 表示到那个点路上所有点的权值的或值。

然后捆绑点2需要数据分治,因为会爆空间,所以单独用一个 ST 表去处理它。

代码

#include<bits/stdc++.h>
using namespace std;
inline void read(int &res){
	char c;
	int f=1;
	res=0;
	c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9')res=(res<<1)+(res<<3)+c-48,c=getchar();
	res*=f;
}
int n,k;
int lg[1000005];
int st[1000005][20];
int f[500005][20],g[500005][20];
int dep[500005],a[500005];
int head[500005],tot;
int in[500005],root;
struct edge{
	int to,nex;
}e[500005];
int specheck(int l,int r){
	int kk=lg[r-l+1];
	return st[l][kk]|st[r-(1<<kk)+1][kk];
}
void special(){
	lg[0]=-1;
	for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1;
	for(int i=1;i<=n;i++)read(st[i][0]);
	for(int i=1;i<=19;i++){
		for(int j=1;j+(1<<i)-1<=n;j++){
			st[j][i]=st[j][i-1]|st[j+(1<<(i-1))][i-1];//ST
		}
	}
	long long anss=0;
	for(int i=1;i<=n;i++){
		int l=i,r=n,ans=-1;
		while(l<=r){
			int mid=(l+r)>>1;
			if(specheck(i,mid)>=k){
				ans=mid;
				r=mid-1;
			}
			else l=mid+1;
		}
		if(ans!=-1){
			anss+=n-ans+1;
		}
	}
	cout<<anss;
}
inline void add(int qq,int mm){
	e[++tot].to=mm;
	e[tot].nex=head[qq];
	head[qq]=tot;
}
void dfs(int u,int ff){
	f[u][0]=ff;
	g[u][0]=a[ff];//dfs初始化 
	dep[u]=dep[ff]+1;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		dfs(v,u);
	}
}
void pre(){
	for(int i=1;i<=19;i++){
		for(int j=1;j<=n;j++){
			f[j][i]=f[f[j][i-1]][i-1];
			g[j][i]=g[j][i-1]|g[f[j][i-1]][i-1];//倍增预处理 
		}
	}
}
int check(int x,int len){
	int sum=a[x];
	for(int i=0;i<=19;i++){
		if(len&(1<<i)){
			sum|=g[x][i];
			x=f[x][i];
		}
	}
	return sum;
}
signed main()
{
	read(n);read(k);
	if(n>500000){
		special();
		return 0;
	}
	for(int i=1;i<=n;i++)read(a[i]);
	for(int i=1;i<n;i++){
		int x,y;
		read(x);read(y);
		add(x,y);
		in[y]++;
	}
	for(int i=1;i<=n;i++)if(!in[i]){
		root=i;
		break;
	}
	dfs(root,0);
	pre();
	long long anss=0;
	for(int i=1;i<=n;i++){
		int l=0,r=dep[i]-1,ans=-1;
		while(l<=r){//二分 
			int mid=(l+r)>>1;
			if(check(i,mid)>=k){
				ans=mid;
				r=mid-1;
			}
			else l=mid+1;
		}
		if(ans!=-1)anss+=dep[i]-ans;
	}
	cout<<anss;
	return 0;
}




posted on 2021-11-12 16:16  漠寒·  阅读(31)  评论(0编辑  收藏  举报