松鼠的新家 (lca+树上差分)或(树链剖分)
题目链接:https://www.luogu.com.cn/problem/P3258
题意:给出一个n 再给出走这n个点的顺序,再给出这n个点的连接方式(n-1条边,形成树)
思路:我们考虑lca+树上差分,首先介绍一下树上差分;
树上差分:想法跟普通的差不多,举个例子:假如我们要在某节点以及其到根节点所设计的节点都加上1;
我们就只需要在某节点加上1即可(用数组p来表示)
然后本题要求:假如从a走到b 我们则需要如下操作:
这样得出的答案就是这一条路径都+1;
那么本题呢,就需要再求lca,这里我们用倍增法;
求出最后答案后,我们再将重复走的点--即可;
因为走的方式是 (a,b) (b,c) (c,d) 所以中间的点(除去左右端点)我们在计算的时候多算了一次,减去 ,而最后一个节点题目中有说明不需要放置糖果,所以也减去
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int maxn=3e5+10; 4 struct node 5 { 6 int v; 7 int nxt; 8 }G[maxn<<1]; 9 int f[maxn][30]; int dep[maxn]; 10 int b[maxn]; 11 int head[maxn];int num=-1; 12 int p[maxn]; 13 void add(int u,int v) 14 { 15 G[++num].v=v;G[num].nxt=head[u];head[u]=num; 16 G[++num].v=u;G[num].nxt=head[v];head[v]=num; 17 } 18 void dfs(int u,int father)//对应深搜预处理f数组 19 { 20 dep[u]=dep[father]+1; 21 for(int i=1;(1<<i)<=dep[u];i++){ 22 f[u][i]=f[f[u][i-1]][i-1]; 23 } 24 for(int i=head[u];i!=-1;i=G[i].nxt){ 25 int v=G[i].v; 26 if(v==father)continue;//双向图需要判断是不是父亲节点 27 f[v][0]=u; 28 dfs(v,u); 29 } 30 } 31 int lca(int x,int y) 32 { 33 if(dep[x]<dep[y]) swap(x,y); 34 for(int i=20;i>=0;i--)//从大到小枚举使x和y到了同一层 35 { 36 if(dep[f[x][i]]>=dep[y]) 37 x=f[x][i]; 38 if(x==y)return x; 39 } 40 for(int i=20;i>=0;i--)//从大到小枚举 41 { 42 if(f[x][i]!=f[y][i])//尽可能接近 43 { 44 x=f[x][i];y=f[y][i]; 45 } 46 } 47 return f[x][0];//随便找一个**输出 48 } 49 void solve(int x,int fa){ 50 for(int i=head[x];i!=-1;i=G[i].nxt){ 51 if(fa==G[i].v) continue; 52 solve(G[i].v,x); 53 p[x]+=p[G[i].v]; 54 } 55 } 56 int main() 57 { 58 int n; 59 memset(head,-1,sizeof(head)); 60 scanf("%d",&n); 61 for(int i=1;i<=n;i++) 62 scanf("%d",&b[i]); 63 for(int i=1;i<n;i++){ 64 int x,y; 65 scanf("%d%d",&x,&y); 66 add(x,y); 67 } 68 dfs(1,0); 69 for(int i=1;i<n;i++){ 70 p[b[i]]++;p[b[i+1]]++; 71 int tmp=lca(b[i],b[i+1]); 72 p[tmp]--; 73 p[f[tmp][0]]--; 74 } 75 solve(1,0); 76 for(int i=2;i<=n;i++) p[b[i]]--; 77 for(int i=1;i<=n;i++) 78 printf("%d\n",p[i]); 79 //printf("\n"); 80 }
再贴一份树链剖分的代码;(代码贴自其他题目,所以有些冗余的东西)
1 #include<algorithm> 2 #include<iostream> 3 #include<cstdlib> 4 #include<cstring> 5 #include<cstdio> 6 #define Rint register int 7 #define mem(a,b) memset(a,(b),sizeof(a)) 8 #define Temp template<typename T> 9 using namespace std; 10 typedef long long LL; 11 const int maxn=3e5+10; 12 const int mod=1e9+7; 13 14 //见题意 15 int b[maxn]; 16 int w[maxn],wt[maxn]; 17 int ans[maxn]; 18 //链式前向星数组,w[]、wt[]初始点权数组 19 int son[maxn],id[maxn],fa[maxn],cnt,dep[maxn],siz[maxn],top[maxn]; 20 //son[]重儿子编号,id[]新编号,fa[]父亲节点,cnt dfs_clock/dfs序,dep[]深度,siz[]子树大小,top[]当前链顶端节点 21 //查询答案 22 struct node 23 { 24 int v,nxt; 25 }G[maxn<<2]; int head[maxn]; int num; 26 struct tre 27 { 28 int l,r,lazy,sum; 29 }tree[maxn<<2]; 30 void add(int u,int v) 31 { 32 G[++num].v=v;G[num].nxt=head[u];head[u]=num; 33 G[++num].v=u;G[num].nxt=head[v];head[v]=num; 34 } 35 void pushdown(int rt,int lenn){ 36 tree[rt<<1].lazy+=tree[rt].lazy; 37 tree[rt<<1|1].lazy+=tree[rt].lazy; 38 tree[rt<<1].sum+=tree[rt].lazy*(lenn-(lenn>>1)); 39 tree[rt<<1|1].sum+=tree[rt].lazy*(lenn>>1); 40 tree[rt<<1].sum%=mod; 41 tree[rt<<1|1].sum%=mod; 42 tree[rt].lazy=0; 43 } 44 45 void build(int l,int r,int root){ 46 tree[root].l=l;tree[root].r=r; 47 tree[root].sum=tree[root].lazy=0; 48 if(l==r){ 49 tree[root].sum=wt[l]; 50 if(tree[root].sum>mod)tree[root].sum%=mod; 51 return; 52 } 53 int mid=l+r>>1; 54 build(l,mid,root<<1); 55 build(mid+1,r,root<<1|1); 56 tree[root].sum=(tree[root<<1].sum+tree[root<<1|1].sum)%mod; 57 } 58 59 int query(int l,int r,int root){ 60 int L=tree[root].l;int R=tree[root].r; 61 if(l<=L&&R<=r){ 62 return tree[root].sum%mod; 63 } 64 if(tree[root].lazy)pushdown(root,R-L+1); 65 int mid=L+R>>1; 66 int ans=0; 67 if(l<=mid) ans+=query(l,r,root<<1),ans%=mod; 68 if(r>mid) ans+=query(l,r,root<<1|1),ans%=mod; 69 return ans; 70 } 71 void update(int l,int r,int val,int root){ 72 int L=tree[root].l;int R=tree[root].r; 73 if(l<=L&&R<=r){ 74 tree[root].lazy+=val; 75 tree[root].sum+=val*(R-L+1); 76 } 77 else{ 78 if(tree[root].lazy)pushdown(root,R-L+1); 79 int mid=L+R>>1; 80 if(l<=mid)update(l,r,val,root<<1); 81 if(r>mid) update(l,r,val,root<<1|1); 82 tree[root].sum=(tree[root<<1].sum+tree[root<<1|1].sum)%mod; 83 } 84 } 85 int qRange(int x,int y){ 86 int ans=0; 87 while(top[x]!=top[y]){//当两个点不在同一条链上 88 if(dep[top[x]]<dep[top[y]])swap(x,y);//把x点改为所在链顶端的深度更深的那个点 89 ans+=query(id[top[x]],id[x],1);//ans加上x点到x所在链顶端 这一段区间的点权和 90 ans%=mod;//按题意取模 91 x=fa[top[x]];//把x跳到x所在链顶端的那个点的上面一个点 92 } 93 //直到两个点处于一条链上 94 if(dep[x]>dep[y])swap(x,y);//把x点深度更深的那个点 95 ans+=query(id[x],id[y],1);//这时再加上此时两个点的区间和即可 96 return ans%mod; 97 } 98 99 void updRange(int x,int y,int k){//同上 100 k%=mod; 101 while(top[x]!=top[y]){ 102 if(dep[top[x]]<dep[top[y]])swap(x,y); 103 update(id[top[x]],id[x],k,1); 104 x=fa[top[x]]; 105 } 106 if(dep[x]>dep[y])swap(x,y); 107 update(id[x],id[y],k,1); 108 } 109 110 void dfs1(int u,int f,int deep){//x当前节点,f父亲,deep深度 111 dep[u]=deep;//标记每个点的深度 112 fa[u]=f;//标记每个点的父亲 113 siz[u]=1;//标记每个非叶子节点的子树大小 114 int maxson=-1;//记录重儿子的儿子数 115 for(int i=head[u];i!=-1;i=G[i].nxt){ 116 int v=G[i].v; 117 if(v==f)continue;//若为父亲则continue 118 dfs1(v,u,deep+1);//dfs其儿子 119 siz[u]+=siz[v];//把它的儿子数加到它身上 120 if(siz[v]>maxson)son[u]=v,maxson=siz[v];//标记每个非叶子节点的重儿子编号 121 } 122 } 123 124 void dfs2(int u,int topf){//x当前节点,topf当前链的最顶端的节点 125 id[u]=++cnt;//标记每个点的新编号 126 wt[cnt]=w[u];//把每个点的初始值赋到新编号上来 127 top[u]=topf;//这个点所在链的顶端 128 if(!son[u])return;//如果没有儿子则返回 129 dfs2(son[u],topf);//按先处理重儿子,再处理轻儿子的顺序递归处理 130 for(int i=head[u];i!=-1;i=G[i].nxt){ 131 int v=G[i].v; 132 if(v==fa[u]||v==son[u])continue; 133 dfs2(v,v);//对于每一个轻儿子都有一条从它自己开始的链 134 } 135 } 136 void init() 137 { 138 memset(head,-1,sizeof(head)); 139 num=-1; 140 } 141 int main(){ 142 init(); 143 int n; 144 scanf("%d",&n); 145 for(int i=1;i<=n;i++) scanf("%d",&b[i]); 146 for(int i=1;i<n;i++){ 147 int u,v; 148 scanf("%d%d",&u,&v); 149 add(u,v); 150 } 151 dfs1(1,0,1); 152 dfs2(1,1); 153 build(1,n,1); 154 for(int i=1;i<n;i++) { 155 int tmp1=b[i]; 156 int tmp2=b[i+1]; 157 updRange(tmp1,tmp2,1); 158 } 159 for(int i=1;i<=n;i++){ 160 ans[i]=qRange(i,i); 161 } 162 for(int i=2;i<=n;i++) ans[b[i]]--; 163 for(int i=1;i<=n;i++) 164 printf("%d\n",ans[i]); 165 return 0; 166 }