【uoj58】 WC2013—糖果公园
http://uoj.ac/problem/58 (题目链接)
题意
给定一棵树,每个点有一个颜色,提供两种操作:
1.询问两点间路径上的${\sum{v[a[i]]*w[k]}}$,其中${a[i]}$代表这个点的颜色,${k}$表示这个点是这种颜色第${k}$次出现
2.修改某个点的颜色
Solution
带修改树上莫队。%%%vfleaking
按左端点所在块为第一关键字,右端点所在块为第二关键字,时间为第三关键字,排序。可能会有疑问可不可以以右端点dfs序为第二关键字?这里我们为了突出第三关键字的作用,选择以右端点所在块为第二关键字。每个节点的dfs序都不同,如果以dfs序为第二关键字的话,第三关键字就没用了。当然这样写也不是不行,但时间会略长。
然后进行树上莫队,将询问和修改分开记录每次询问经过修改或逆修改来使时间倒流或前进。
2个月后重新来看自己的程序,纠结了好久为何这样在树上分块就能保证复杂度的问题,自己总结了几点:
- 分块只是用来保证莫队算法复杂的的工具,并无实际用处。
- 分块的方法并不是固定的,只要满足3点:相邻两块的距离是${\sqrt{n}}$级别的;每块的大小是${\sqrt{n}}$级别的;块的个数是${\sqrt{n}}$级别的。
UPD 2017.1.3:完了完了不会证复杂度了。。。
考虑先后2次询问${(u_i,v_i),(u_{i+1},v_{i+1})}$,块的大小为${s}$,询问以及长度为${n}$。我们分几种情况讨论。
1.${u_i,u_{i+1};v_i,v_{i+1}}$在同一块
块内的询问是无序的,所以块内转移一次最坏是${s}$,有${n}$组询问。那么复杂度是${O(ns)}$。
考虑时间:因为块内的时间是有序的,所以这种情况时间的转移就是按顺序扫过去${O(ns)}$
2.${u_i,u_{i+1};v_i,v_{i+1}}$不在同一块
思考,这样的转移最多${\frac{n}{s}^2}$次,每次最坏${n}$,那么时间复杂度${O(\frac{n^2}{s^2}*n)}$
考虑时间:同理,因为每次转移时间的移动可能是${O(n)}$,所以时间转移的复杂度${O(\frac{n^2}{s^2}*n)}$
于是我们为了平衡这两个复杂度,使用均值不等式,可以解得这两个复杂度相加的最小值,当${s=n^{\frac{2}{3}}}$时取到。
代码
// uoj58 #include<algorithm> #include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #include<cmath> #define MOD 1000000007 #define inf 2147483640 #define LL long long #define free(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout); using namespace std; inline LL getint() { LL x=0,f=1;char ch=getchar(); while (ch>'9' || ch<'0') {if (ch=='-') f=-1;ch=getchar();} while (ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();} return x*f; } const int maxn=100010; struct edge {int to,next;}e[maxn<<2]; struct ask {int u,v,id,pre,t;}a1[maxn],a2[maxn]; LL res[maxn],ans; int pos[maxn],v[maxn],w[maxn],dfn[maxn],bin[20],fa[maxn][20],deep[maxn],st[maxn],p[maxn],vis[maxn],c[maxn],pre[maxn],head[maxn]; int block,blonum,n,m,q,cnt,cnt1,cnt2,top; void link(int u,int v) { e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt; e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt; } bool cmp(ask a,ask b) { if (pos[a.u]==pos[b.u] && pos[a.v]==pos[b.v]) return a.t<b.t; if (pos[a.u]==pos[b.u]) return pos[a.v]<pos[b.v]; return pos[a.u]<pos[b.u]; } int dfs(int x) { //预处理+分块 int size=0; dfn[x]=++cnt; for (int i=1;i<20;i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0]) { deep[e[i].to]=deep[x]+1; fa[e[i].to][0]=x; size+=dfs(e[i].to); if (size>=block) { blonum++; while (size--) pos[st[top--]]=blonum; size=0; } } st[++top]=x; return size+1; } void work(int x) { if (!vis[x]) {vis[x]=1;p[c[x]]++;ans+=(LL)w[p[c[x]]]*v[c[x]];} else {vis[x]=0;ans-=(LL)w[p[c[x]]]*v[c[x]];p[c[x]]--;} } void modify(int x,int y) { //修改操作 if (vis[x]) {work(x);c[x]=y;work(x);} else c[x]=y; } void solve(int x,int y) { while (x!=y) { if (deep[x]<deep[y]) work(y),y=fa[y][0]; else work(x),x=fa[x][0]; } } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int t=deep[x]-deep[y]; for (int i=0;bin[i]<=t;i++) if (bin[i]&t) x=fa[x][i]; for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return x==y?x:fa[x][0]; } int main() { bin[0]=1;for (int i=1;i<20;i++) bin[i]=bin[i-1]<<1; scanf("%d%d%d",&n,&m,&q); for (int i=1;i<=m;i++) scanf("%d",&v[i]); for (int i=1;i<=n;i++) scanf("%d",&w[i]); for (int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); link(u,v); } for (int i=1;i<=n;i++) scanf("%d",&c[i]),pre[i]=c[i]; block=(int)pow(n,0.6); cnt=0;dfs(1); cnt1=0,cnt2=0; for (int i=1;i<=q;i++) { int x,u,v; scanf("%d%d%d",&x,&u,&v); if (x) { //查询操作 cnt1++; if (dfn[u]>dfn[v]) swap(u,v); a1[cnt1].u=u;a1[cnt1].v=v;a1[cnt1].id=cnt1;a1[cnt1].t=cnt2; //查询时间 } else { //修改操作 cnt2++; a2[cnt2].u=u;a2[cnt2].v=v;a2[cnt2].pre=pre[u];pre[u]=v; //邻接表存一个节点的修改顺序 } } sort(a1+1,a1+cnt1+1,cmp); for (int i=1;i<=a1[1].t;i++) modify(a2[i].u,a2[i].v); solve(a1[1].u,a1[1].v); int t=lca(a1[1].u,a1[1].v); work(t); res[a1[1].id]=ans; work(t); for (int i=2;i<=cnt1;i++) { for (int j=a1[i-1].t+1;j<=a1[i].t;j++) modify(a2[j].u,a2[j].v); //时间流逝,修改 for (int j=a1[i-1].t;j>a1[i].t;j--) modify(a2[j].u,a2[j].pre); //时间倒流,逆修改 solve(a1[i-1].u,a1[i].u); solve(a1[i-1].v,a1[i].v); t=lca(a1[i].u,a1[i].v); work(t); res[a1[i].id]=ans; work(t); } for (int i=1;i<=cnt1;i++) printf("%lld\n",res[i]); return 0; }