UOJ#374. 【ZJOI2018】历史 贪心,LCT
原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ374.html
题解
想出正解有点小激动。
不过因为傻逼错误调到自闭。不如贺题
首先我们考虑如何 $O(n)$ 求一个答案。
首先,计算两条路径的贡献时,由于两国连续交战数次只算一次,所以我们可以只看这两条路径的交的最深点。
也就是说,我们可以对于每一个点分开考虑,假装他的同一个子树的所有点颜色相同,不同子树的点颜色不同,它本身也当作一个子树看。
假设 $x$ 是当前节点,$y$ 是 $x$ 的子树。
设 $size[v]$ 表示 $v$ 子树的所有节点的 $a[v]$ 之和。
那么我们容易推出两个断论:
1. $x$ 节点对答案的贡献最多不超过 $size[x] - 1$ 。
2. 节点 $x$ 的贡献具体是什么?我们要尝试将每个子树中的节点颜色集合混合成一个序列,最大化相邻不同色对数。设 $\max(size[y])$ 表示 $x$ 的所有子树中 $size$ 最大的子树的 $size$ ,当 $\max(size[y]) - 1 \leq size[x] - \max(size[y])$ 时,都有使 $x$ 的贡献为 $size[x] - 1$ 的方案;否则, $x$ 节点对答案的贡献最大为 $\max(size[y]) - 1 - (size[x] - \max(size[y])) = 2\max(size[y]) - 1 - size[x]$
所以贡献为
$$min(size[x] - 1, 2max(size[y]) - 1 - size[x])$$
设 $val[x] = size[x] - 1$ ,可以证明 $\sum_{y} val[y] \leq \sum_{y} (size[y] - 1) \leq size[x] - 1 = val[x]$
则这个式子会更加好看(把常数消掉了,然并卵):
$$min(val[x],2max(val[y])-val[x])$$
现在已经可以轻松拿到 30 分了。
考虑 100 分怎么做。
我们可以发现好像操作的时候所有的 $\max(val[y])$ 的 $y$ 的变化次数不多啊!
于是我们可以想到 LCT 维护这个东西。
这里的 LCT 不是传统的 LCT 。
如果 $val[x] \geq val[fa[x]]$ 那么我们将 $x$ 作为 $fa[x]$ 的重儿子。我们可以发现每一个节点只有一个重儿子:由于 $\sum_{y} val[y] \leq val[x]$ ,而且两个子树的特殊情况特殊考虑一下发现也是对的。
这样的话,可能会有节点没有重儿子。
但是,从任意一个节点到根走过的轻边条数是 $O(\log \sum a[i])$ 的,因为每走过一条轻边,子树权值和至少翻一倍。
然后你发现修改一个点的时候只要修改它到根路径上的所有点权$(val[x])$,而且对于重链,它对答案的贡献是不变的!
所以只要对 $O(\log\sum a[i])$ 个轻边处理就好了。
由于要链上修改点权,所以每一段重链都要预先下传标记。
总的来说,这样做要跳过 $O(\log \sum a[i])$ 段重链,每段重链 splay 需要花费 $O(\log n)$ 的时间复杂度,所以看上去复杂度是 $O(n\log^2 n)$ 的。80分很开心了吧!更开心的是如果交上去的话它能 AC 。
这是为什么呢?我们考虑势能分析,定义势函数为 $\sum_{ LCT 上所有节点 }\ \ \ \ \ \log (该节点在splay结构上的size + 它的虚子树的size)$ ,类似于 splay 复杂度的证明,可以证明这个东西是均摊 $O(\log \sum a[i] + \log n)$ 的。
这里不把证明写出来了。懒得写了。
最终时间复杂度为 $O(n\log(n\sum a[i]))$ 。
注意在写代码的时候要注意一些细节。对于节点本身的贡献我们可以把每一个点拆成两个点,第一个点先连原先所有子树,再新建第二个点,让他们连起来,并使第一个点是第二个点的父亲,第二个点的权值为 $a[x] - 1$ 。这样可以减掉几个 if 。
注意链上修改的时候,不是直接给根打标记就完事了,因为这里的 LCT 比较奇怪,所以直接打标记会多给一段后缀重链带来修改,所以我们还要再在这个后缀重链上打个标记来抵消根上的标记。
代码
#pragma GCC optimize("Ofast","inline") #include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) #define For(i,a,b) for (int i=a;i<=b;i++) #define Fod(i,b,a) for (int i=b;i>=a;i--) #define pb push_back #define mp make_pair #define fi first #define se second #define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I') #define outval(x) printf(#x" = %d\n",x) #define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("") #define outtag(x) puts("----------"#x"----------") using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=400005*2; int n,m; LL a[N],s[N],v[N],f[N]; vector <int> e[N]; LL ans=0; void dfs(int x,int pre){ f[x]=pre; s[x]=a[x]; LL Mx=a[x]-1; for (auto y : e[x]) if (y!=pre){ dfs(y,x); s[x]+=s[y]; Mx=max(Mx,v[y]); } v[x]=s[x]-1; ans+=min(v[x],(v[x]-Mx)*2); } int fa[N],son[N][2]; LL val[N],Add[N],Mxv[N]; void LCT_build(){ clr(son),clr(val),clr(Add); For(i,1,n){ fa[i]=i+n,val[i]=a[i]-1; fa[i+n]=f[i]?f[i]+n:0,val[i+n]=v[i]; } For(i,1,n*2){ Mxv[i]=val[i]; if (fa[i]&&val[i]*2>=val[fa[i]]) son[fa[i]][1]=i; } } #define ls son[x][0] #define rs son[x][1] int isroot(int x){ return son[fa[x]][0]!=x&&son[fa[x]][1]!=x; } int wson(int x){ return son[fa[x]][1]==x; } void pushup(int x){ Mxv[x]=max(val[x],max(Mxv[ls],Mxv[rs])); } void pushdown(int x){ if (Add[x]){ if (ls) val[ls]+=Add[x],Add[ls]+=Add[x],Mxv[ls]+=Add[x]; if (rs) val[rs]+=Add[x],Add[rs]+=Add[x],Mxv[rs]+=Add[x]; Add[x]=0; } } void pushadd(int x){ if (!isroot(x)) pushadd(fa[x]); pushdown(x); } void rotate(int x){ if (isroot(x)) return; int y=fa[x],z=fa[y],L=wson(x),R=L^1; if (!isroot(y)) son[z][wson(y)]=x; fa[x]=z,fa[y]=x,fa[son[x][R]]=y; son[y][L]=son[x][R],son[x][R]=y; pushup(y),pushup(x); } void splay(int x){ pushadd(x); for (int y=fa[x];!isroot(x);rotate(x),y=fa[x]) if (!isroot(y)) rotate(wson(x)==wson(y)?y:x); } void False_Access(int x){//pushdown the tags while (x) splay(x),x=fa[x]; } void update(int x,LL w){ False_Access(x); if (rs) val[rs]-=w,Add[rs]-=w,Mxv[rs]-=w; while (fa[x]){ int y=fa[x]; if (son[y][1]){ if (val[y]+w>Mxv[son[y][1]]*2){ ans+=val[y]+w-(val[y]-Mxv[son[y][1]])*2; son[y][1]=0; } else ans+=w*2; } else ans+=w; if ((Mxv[x]+w)*2>val[y]+w){ ans+=(val[y]+w-(Mxv[x]+w))*2-(val[y]+w); son[y][1]=x; } else { val[x]+=w,Add[x]+=w,Mxv[x]+=w; if (son[y][1]) val[son[y][1]]-=w,Add[son[y][1]]-=w,Mxv[son[y][1]]-=w; } x=y; } val[x]+=w,Add[x]+=w,Mxv[x]+=w; } #undef ls #undef rs int main(){ n=read(),m=read(); For(i,1,n) a[i]=read(); For(i,1,n-1){ int x=read(),y=read(); e[x].pb(y),e[y].pb(x); } dfs(1,0); printf("%lld\n",ans); LCT_build(); For(i,1,m){ int x=read(),w=read(); update(x,w); printf("%lld\n",ans); } return 0; }