BZOJ 4712 洪水 动态dp(LCT+矩阵乘法)
把之前写的版本改了一下,这个版本的更好理解一些.
特地在一个链的最底端特判了一下.
code:
#include <bits/stdc++.h> #define N 200005 #define ll long long #define inf 10000000005 #define lson p[x].ch[0] #define rson p[x].ch[1] #define setIO(s) freopen(s".in","r",stdin) , freopen(s".out","w",stdout) using namespace std; int n,edges; ll f[N],g[N],val[N],son[N],sta[N]; int hd[N],to[N<<1],nex[N<<1]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } struct Node { int ch[2],f,rev; }p[N]; struct Matrix { ll a[2][2]; ll*operator[](int x) { return a[x]; } }t[N],tmp[N]; Matrix operator*(Matrix a,Matrix b) { Matrix c; c[0][0]=min(a[0][0]+b[0][0],a[0][1]+b[1][0]); c[0][1]=min(a[0][0]+b[0][1],a[0][1]+b[1][1]); c[1][0]=min(a[1][0]+b[0][0],a[1][1]+b[1][0]); c[1][1]=min(a[1][0]+b[0][1],a[1][1]+b[1][1]); return c; } int get(int x) { return p[p[x].f].ch[1]==x; } int isrt(int x) { return !(p[p[x].f].ch[0]==x||p[p[x].f].ch[1]==x); } void pushup(int x) { if(!rson) { t[x]=tmp[x]; t[x][0][0]=min(t[x][0][0],t[x][0][1]); if(lson) t[x]=t[lson]*t[x]; } else { t[x]=tmp[x]; if(lson) t[x]=t[lson]*t[x]; if(rson) t[x]=t[x]*t[rson]; } } void rotate(int x) { int old=p[x].f,fold=p[old].f,which=get(x); if(!isrt(old)) p[fold].ch[p[fold].ch[1]==old]=x; p[old].ch[which]=p[x].ch[which^1],p[p[old].ch[which]].f=old; p[x].ch[which^1]=old,p[old].f=x,p[x].f=fold; pushup(old),pushup(x); } void splay(int x) { int v=0,u=x,fa; for(u;!isrt(u);u=p[u].f); for(u=p[u].f;(fa=p[x].f)!=u;rotate(x)) if(p[fa].f!=u) rotate(get(fa)==get(x)?fa:x); } void Access(int x) { for(int y=0;x;y=x,x=p[x].f) { splay(x); if(rson) { tmp[x][0][0]+=t[rson][0][0]; } if(y) { tmp[x][0][0]-=t[y][0][0]; } rson=y; pushup(x); } } void dfs(int x,int ff) { p[x].f=ff; ll sum=0; for(int i=hd[x];i;i=nex[i]) { int v=to[i]; if(v!=ff) dfs(v,x), sum+=f[v]; } if(sum!=0) f[x]=min(sum, val[x]); else f[x]=val[x]; tmp[x][0][0]=(sum==0?inf:sum); // 子树信息 tmp[x][0][1]=val[x]; tmp[x][1][0]=tmp[x][1][1]=0; pushup(x); } int main() { int i,j,k,m; // setIO("input"); scanf("%d",&n); for(i=1;i<=n;++i) scanf("%lld",&val[i]); for(i=1;i<n;++i) { int u,v; scanf("%d%d",&u,&v),add(u,v),add(v,u); } dfs(1,0); scanf("%d",&m); for(i=1;i<=m;++i) { char ss[5]; scanf("%s",ss); if(ss[0]=='Q') { int x; scanf("%d",&x); Access(x); splay(x); printf("%lld\n",min(tmp[x][0][0],tmp[x][0][1])); } else { int x; ll d; scanf("%d%lld",&x,&d); Access(x), splay(x); tmp[x][0][1]+=d; val[x]+=d; pushup(x); } } return 0; }