题解:
就是按照常规的合并
期望有一点麻烦
首先计算全部的和
再减去有多少种
具体看看http://blog.csdn.net/PoPoQQQ/article/category/2542261这个博客吧
代码:
#include<bits/stdc++.h> using namespace std; #define pa t[x].fa #define lc t[x].ch[0] #define rc t[x].ch[1] const int N=5e4+5; typedef long long ll; int read() { char c=getchar();int x=0,f=1; while(c<'0'||c>'9'){if (c=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();} return x*f; } struct node { int ch[2],fa,rev; ll add,lsum,rsum,sum,exp,w,size; }t[N]; int wh(int x){return t[pa].ch[1]==x;} int isRoot(int x){return t[pa].ch[0]!=x&&t[pa].ch[1]!=x;} void update(int x) { t[x].size=t[lc].size+t[rc].size+1; t[x].sum=t[lc].sum+t[rc].sum+t[x].w; t[x].lsum=t[lc].lsum+t[x].w*(t[lc].size+1)+t[rc].lsum+t[rc].sum*(t[lc].size+1); t[x].rsum=t[rc].rsum+t[x].w*(t[rc].size+1)+t[lc].rsum+t[lc].sum*(t[rc].size+1); t[x].exp=t[lc].exp+t[rc].exp +t[lc].lsum*(t[rc].size+1)+t[rc].rsum*(t[lc].size+1) +t[x].w*(t[lc].size+1)*(t[rc].size+1); } ll cal1(ll x){return x*(x+1)/2;} ll cal2(ll x){return x*(x+1)*(x+2)/6;} void paint(int x,ll d) { t[x].w+=d; t[x].add+=d; t[x].sum+=d*t[x].size; t[x].lsum+=d*cal1(t[x].size); t[x].rsum+=d*cal1(t[x].size); t[x].exp+=d*cal2(t[x].size); } void rever(int x) { swap(lc,rc); swap(t[x].lsum,t[x].rsum); t[x].rev^=1; } void pushDown(int x) { if (t[x].rev) { rever(lc); rever(rc); t[x].rev=0; } if (t[x].add) { paint(lc,t[x].add); paint(rc,t[x].add); t[x].add=0; } } void rotate(int x) { int f=t[x].fa,g=t[f].fa,c=wh(x); if (!isRoot(f)) t[g].ch[wh(f)]=x;t[x].fa=g; t[f].ch[c]=t[x].ch[c^1];t[t[f].ch[c]].fa=f; t[x].ch[c^1]=f;t[f].fa=x; update(f);update(x); } int st[N],top; void splay(int x) { top=0;st[++top]=x; for (int i=x;!isRoot(i);i=t[i].fa) st[++top]=t[i].fa; for (int i=top;i>=1;i--) pushDown(st[i]); for (;!isRoot(x);rotate(x)) if (!isRoot(pa)) rotate(wh(x)==wh(pa)?pa:x); } void Access(int x) { for (int y=0;x;y=x,x=pa) { splay(x); rc=y; update(x); } } void MakeR(int x){Access(x);splay(x);rever(x);} int FindR(int x){Access(x);splay(x);while(lc) x=lc;return x;} void Link(int x,int y){MakeR(x);t[x].fa=y;} void Cut(int x,int y) { MakeR(x);Access(y);splay(y); t[y].ch[0]=t[x].fa=0; update(y); } void Add(int x,int y,int d) { if (FindR(x)!=FindR(y)) return; MakeR(x);Access(y);splay(y); paint(y,d); } ll gcd(ll a,ll b){return b==0?a:gcd(b,a%b);} void Que(int x,int y) { if (FindR(x)!=FindR(y)){puts("-1");return;} MakeR(x);Access(y);splay(y); ll a=t[y].exp,b=t[y].size*(t[y].size+1)/2; ll g=gcd(a,b); printf("%lld/%lld\n",a/g,b/g); } int n,Q,a,op,x,y,d; int main() { n=read();Q=read(); for (int i=1;i<=n;i++) { a=read(); t[i].size=1; t[i].w=t[i].lsum=t[i].rsum=t[i].sum=t[i].exp=a; } for (int i=1;i<=n-1;i++) x=read(),y=read(),Link(x,y); while(Q--) { op=read();x=read();y=read(); if (op==1) if (FindR(x)==FindR(y)) Cut(x,y); if (op==2) if (FindR(x)!=FindR(y)) Link(x,y); if (op==3) d=read(),Add(x,y,d); if (op==4) Que(x,y); } }