动态dp

题解:

首先这类题目本身是一个dp/树形dp

然后带上了修改(ddp)

为了权衡查询和修改的时间,我们采用树剖来维护

假设我们能够对每个点维护除了重儿子之外的信息

那么我们的修改只需要修改log个节点,查询只需要把从当前点到最后一个重儿子的信息合并即可

观察这一题的转移方程

$$f[i][0]=\sum{max(f[v][1],f[v][0])},f[i][1]=MAX(v[i],0)+\sum{f[v][0]}$$

如果要维护一条链的信息,我们肯定要维护它上端选不选,下端选不选

我们定义一种矩阵乘法,满足用$+$替代原来的$*$,用$max$代替原来的$+$

而普通的矩阵乘法满足结合律 $$a*b*c=a*(b*c) a*b+a*c=a*(b+c)$$

不满足交换律 $a*b不一定等于b*a$

而可以证明现在这个运算仍然满足这个性质

于是我们维护这么一个矩阵 

$\left\{ \begin{array} {lcr}
父亲一定不取,儿子随意 & 父亲一定取,儿子随意 \\
父亲一定不取,儿子不取 & 父亲一定取 儿子不取
\end{array} \right\}$


那么现在将后面位置的矩阵作为$A$,上面的作为$B$
设中间的两个节点为$k1,k2$,上面的为$fa$,下面的为$son$
即$A$对应$k1,son$ $B$对应$fa,k2$
运算

$$A*B的(0,0)=
MAX(k1不取+son随意+fa不取+k2随意 ,k1取+son随意+fa取+k2取)
$$


同理另外4个位置也可以发现包含了所有合法情况
而某个位置的初值
$\left\{\begin{array} {lcr}
f[x][0] & f[x][1] \\
f[x][0] & -INF
\end{array} \right\}$
于是这样子我们修改的时候修改当前点以及每条重链的父亲
另外注意由于维护重链父亲的信息的时候要减去以前的当前点的信息(因为不是重儿子所以算入了父亲点)
所以开个数组记录一下修改之前这个点产生的贡献

 

代码:

// luogu-judger-enable-o2
#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for (int i=h;i<=t;i++)
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define me(x) memset(x,0,sizeof(x))
#define mid ((h+t)>>1)
namespace IO{
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T>void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f;
    }
    char sr[1<<24],z[20]; int Z,CC=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++CC]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++CC]=z[Z],--Z); 
    }
    IL void wer1()
    {
        sr[++CC]=' ';
    }
    IL void wer2()
    {
        sr[++CC]='\n';
    }
    template<class T>IL void mina(T &x,T y) { if (x>y) x=y;}
    template<class T>IL void maxa(T &x,T y) { if (x<y) x=y;}
    template<class T>IL T MIN(T x,T y){return x<y?x:y;}
    template<class T>IL T MAX(T x,T y){return x>y?x:y;}
};
using namespace IO;
const int INF=1e9;
const int N=2.1e5;
int head[N],l;
struct re{
    int a,b;
}e[N*2];
IL void arr(int x,int y)
{
    e[++l].a=head[x];
    e[l].b=y;
    head[x]=l;
}
int n,num[N],fa[N],son[N],top[N],f[N][2],g[N][2],v[N],cnt,dfn[N],en[N];
void dfs(int x,int y)
{
    num[x]=1; fa[x]=y;
    for (rint u=head[x];u;u=e[u].a)
    {
        int v=e[u].b;
        if (v!=y)
        {
            dfs(v,x);
            num[x]+=num[v];
            if (num[v]>num[son[x]]) son[x]=v;
        }
    }
}
void dfs1(int x,int y,int z)
{
    top[x]=y; dfn[x]=++cnt;
    if (son[x]) dfs1(son[x],y,x); else en[y]=x;
    for (rint u=head[x];u;u=e[u].a)
    {
        int v=e[u].b;
        if (v!=z&&v!=son[x])
        {
            dfs1(v,v,x);
        }
    }
}
struct re2{
  int a[2][2];
  re2() {a[0][0]=a[0][1]=a[1][0]=a[1][1]=0;}
  re2 operator *(const re2 x)
  {
      re2 c;
      rep(i,0,1)
        rep(j,0,1)
          rep(k,0,1)
            maxa(c.a[i][j],a[i][k]+x.a[k][j]);    
      return c;
  }
}s[N];
#define updata(x) sum[x]=sum[x*2+1]*sum[x*2];
struct sgt{
    re2 sum[N*4];
  void change(int x,int h,int t,int pos,re2 k)
  {
      if (h==t)
       {
          sum[x]=k; return;
      }
      if (pos<=mid) change(x*2,h,mid,pos,k);
      else change(x*2+1,mid+1,t,pos,k);
      updata(x);
  }
  re2 query(int x,int h,int t,int h1,int t1)
  {
      if (h1<=h&&t<=t1) return sum[x];
      if (h1<=mid&&mid<t1)
        return query(x*2+1,mid+1,t,h1,t1)*query(x*2,h,mid,h1,t1);
      else if (h1<=mid) return query(x*2,h,mid,h1,t1);
        else if (mid<t1) return query(x*2+1,mid+1,t,h1,t1);  
  }
}S;

void dfs2(int x,int y)
{
    f[x][1]=MAX(v[x],0); f[x][0]=0;
    g[x][1]=f[x][1]; g[x][0]=0;
    for (rint u=head[x];u;u=e[u].a)
    {
        int v=e[u].b;
        if (v!=y)
        {
          dfs2(v,x);
            if (v!=son[x])
            { 
            f[x][1]+=g[v][0];
            f[x][0]+=MAX(g[v][0],g[v][1]);
          }
          g[x][1]+=g[v][0];
          g[x][0]+=MAX(g[v][0],g[v][1]);
        }
    }
    re2 q;
    q.a[0][0]=q.a[1][0]=f[x][0];
    q.a[0][1]=f[x][1]; q.a[1][1]=-INF;
    S.change(1,1,n,dfn[x],q);
}
void change(int x,int y)
{
      re2 k;
      k=S.query(1,1,n,dfn[x],dfn[x]);
      k.a[0][1]+=-MAX(0,v[x])+MAX(y,0);
      v[x]=y;
      S.change(1,1,n,dfn[x],k);
      int f1=fa[top[x]];
      while (f1)
      {
          k=S.query(1,1,n,dfn[f1],dfn[f1]);
          k.a[0][0]-=MAX(s[top[x]].a[0][1],s[top[x]].a[0][0]);
          k.a[0][1]-=MAX(0,s[top[x]].a[0][0]);
          re2 k1=S.query(1,1,n,dfn[top[x]],dfn[en[top[x]]]);
          k.a[0][0]+=MAX(k1.a[0][0],k1.a[0][1]); 
            k.a[0][1]+=MAX(k1.a[0][0],0);
            k.a[1][0]=k.a[0][0];
            S.change(1,1,n,dfn[f1],k);
          s[top[x]]=k1;
          x=fa[top[x]];
          f1=fa[top[x]];
      }
}
int main()
{
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
    int m;
    read(n); read(m);
    rep(i,1,n) read(v[i]);
    rep(i,1,n-1)
    {
        int x,y;
        read(x); read(y); arr(x,y); arr(y,x);
    }
    dfs(1,0); 
    dfs1(1,1,0);
    dfs2(1,0);
    rep(x,1,n)
       s[x]=S.query(1,1,n,dfn[x],dfn[en[top[x]]]);
    rep(i,1,m)
    {
        int x,y;
        read(x); read(y);
        change(x,y);
        re2 k=S.query(1,1,n,dfn[1],dfn[en[1]]);
        wer(MAX(k.a[0][0],k.a[0][1])); wer2();
    }
    fwrite(sr,1,CC+1,stdout);
    return 0;
}

 

posted @ 2018-12-14 16:04  尹吴潇  阅读(433)  评论(0编辑  收藏  举报