分析
我们如果已经找到了一条满足要求的路径,如果将它继续延伸的话,或值只会增加,所以你不管延伸多长都可以的。
所以考虑二分答案,先确定路径终点,二分最短能满足要求的延伸长度,然后该次贡献的答案就能通过深度轻松算出。
考虑二分的 check,我们用倍增来计算,\(f[i][j]\) 表示倍增到的点编号,\(g[i][j]\) 表示到那个点路上所有点的权值的或值。
然后捆绑点2需要数据分治,因为会爆空间,所以单独用一个 ST 表去处理它。
代码
#include<bits/stdc++.h>
using namespace std;
inline void read(int &res){
char c;
int f=1;
res=0;
c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')res=(res<<1)+(res<<3)+c-48,c=getchar();
res*=f;
}
int n,k;
int lg[1000005];
int st[1000005][20];
int f[500005][20],g[500005][20];
int dep[500005],a[500005];
int head[500005],tot;
int in[500005],root;
struct edge{
int to,nex;
}e[500005];
int specheck(int l,int r){
int kk=lg[r-l+1];
return st[l][kk]|st[r-(1<<kk)+1][kk];
}
void special(){
lg[0]=-1;
for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1;
for(int i=1;i<=n;i++)read(st[i][0]);
for(int i=1;i<=19;i++){
for(int j=1;j+(1<<i)-1<=n;j++){
st[j][i]=st[j][i-1]|st[j+(1<<(i-1))][i-1];//ST
}
}
long long anss=0;
for(int i=1;i<=n;i++){
int l=i,r=n,ans=-1;
while(l<=r){
int mid=(l+r)>>1;
if(specheck(i,mid)>=k){
ans=mid;
r=mid-1;
}
else l=mid+1;
}
if(ans!=-1){
anss+=n-ans+1;
}
}
cout<<anss;
}
inline void add(int qq,int mm){
e[++tot].to=mm;
e[tot].nex=head[qq];
head[qq]=tot;
}
void dfs(int u,int ff){
f[u][0]=ff;
g[u][0]=a[ff];//dfs初始化
dep[u]=dep[ff]+1;
for(int i=head[u];i;i=e[i].nex){
int v=e[i].to;
dfs(v,u);
}
}
void pre(){
for(int i=1;i<=19;i++){
for(int j=1;j<=n;j++){
f[j][i]=f[f[j][i-1]][i-1];
g[j][i]=g[j][i-1]|g[f[j][i-1]][i-1];//倍增预处理
}
}
}
int check(int x,int len){
int sum=a[x];
for(int i=0;i<=19;i++){
if(len&(1<<i)){
sum|=g[x][i];
x=f[x][i];
}
}
return sum;
}
signed main()
{
read(n);read(k);
if(n>500000){
special();
return 0;
}
for(int i=1;i<=n;i++)read(a[i]);
for(int i=1;i<n;i++){
int x,y;
read(x);read(y);
add(x,y);
in[y]++;
}
for(int i=1;i<=n;i++)if(!in[i]){
root=i;
break;
}
dfs(root,0);
pre();
long long anss=0;
for(int i=1;i<=n;i++){
int l=0,r=dep[i]-1,ans=-1;
while(l<=r){//二分
int mid=(l+r)>>1;
if(check(i,mid)>=k){
ans=mid;
r=mid-1;
}
else l=mid+1;
}
if(ans!=-1)anss+=dep[i]-ans;
}
cout<<anss;
return 0;
}