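// 【比赛】NOIP2018 保卫王国 (contest: NOIP 2018, "Defense of the Kingdom")
// DDP模板题: a template problem for dynamic DP (DDP) on trees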
#include<bits/stdc++.h>
// Short type aliases and container/loop helpers used throughout the solution.
#define ui unsigned int
#define ll long long
#define db double
#define ld long double
#define ull unsigned long long
#define ft first
#define sd second
#define pb(a) push_back(a)
#define mp(a,b) std::make_pair(a,b)
#define ITR(a,b) for(auto a:b)
// Inclusive ascending / descending integer loops; the bound c is evaluated once into a##end.
// No `register` specifier: it was removed from the language in C++17 and is ill-formed there.
#define REP(a,b,c) for(int a=(b),a##end=(c);a<=a##end;++a)
#define DEP(a,b,c) for(int a=(b),a##end=(c);a>=a##end;--a)
// Capacity of the per-vertex arrays (vertices are indexed from 1); the +10 is slack.
const int MAXN=100000+10;
// inf fills "impossible" entries of the max-plus matrices.
// vinf is the bonus/penalty main() adds to a vertex weight to force it into / out of
// the independent set; value() removes the bonus again.
const ll inf=1e18;
const ll vinf=1e12;
// Vertex count, query count, and the edge counter advanced by insert().
int n,m,e;
// Forward-star adjacency lists; each undirected edge is stored twice, hence MAXN<<1.
int beg[MAXN],nex[MAXN<<1],to[MAXN<<1];
// Heavy-light decomposition data filled by dfs1 / dfs2.
int size[MAXN],hson[MAXN],st[MAXN],ed[MAXN],top[MAXN],fa[MAXN];
// Vertex weights, and the DFS-position counter used by dfs2.
int w[MAXN],cnt;
// f[x][0] / f[x][1]: best independent-set weight in x's subtree with x excluded / included.
// all: sum of every vertex weight.
ll f[MAXN][2],all;
// Data-type tag read from the input header (not used by the algorithm in this file).
char type[5];
// If y is strictly smaller than x, store y into x. Returns whether x was changed.
template<typename T> inline bool chkmin(T &x,T y)
{
    if(!(y<x))return false;
    x=y;
    return true;
}
// If y is strictly larger than x, store y into x. Returns whether x was changed.
template<typename T> inline bool chkmax(T &x,T y)
{
    if(!(y>x))return false;
    x=y;
    return true;
}
struct Matrix{
ll a[2][2];
Matrix(){
REP(i,0,1)REP(j,0,1)a[i][j]=-inf;
};
inline Matrix operator * (const Matrix &A) const {
Matrix B;
REP(i,0,1)REP(k,0,1)REP(j,0,1)chkmax(B.a[i][j],a[i][k]+A.a[k][j]);
return B;
};
};
// val[p]: transition matrix of the vertex whose DFS position (st) is p.
// Filled by dfs2 and copied into the segment tree leaves by Segment_Tree::Build.
Matrix val[MAXN];
// Segment-tree helper macros; they expect rt, l and r to be in scope at the use site.
// ls/rs are parenthesized so they stay correct inside larger expressions:
// unparenthesized, `ls+1` would expand to rt<<(1+1).
#define Mid ((l+r)>>1)
#define ls (rt<<1)
#define rs (rt<<1|1)
#define lson ls,l,Mid
#define rson rs,Mid+1,r
// Segment tree over DFS positions 1..n. Node rt covering [l, r] stores the ordered
// max-plus product val[l] * val[l+1] * ... * val[r]. Left-to-right order matters,
// because the matrix product is not commutative.
struct Segment_Tree{
    Matrix sum[MAXN<<2];
    // Recompute node rt as (left child's product) * (right child's product).
    void PushUp(int rt)
    {
        sum[rt]=sum[2*rt]*sum[2*rt+1];
    }
    // Build node rt over [l, r] from the global val array.
    void Build(int rt,int l,int r)
    {
        if(l==r)
        {
            sum[rt]=val[l];
            return;
        }
        const int mid=(l+r)>>1;
        Build(2*rt,l,mid);
        Build(2*rt+1,mid+1,r);
        PushUp(rt);
    }
    // Point assignment: set position ps to matrix k, then refresh the ancestors on the way back.
    void Update(int rt,int l,int r,int ps,Matrix k)
    {
        if(l==r)
        {
            sum[rt]=k;
            return;
        }
        const int mid=(l+r)>>1;
        if(ps>mid)Update(2*rt+1,mid+1,r,ps,k);
        else Update(2*rt,l,mid,ps,k);
        PushUp(rt);
    }
    // Ordered product of positions [L, R]. Expects a non-empty range L <= R inside [l, r].
    Matrix Query(int rt,int l,int r,int L,int R)
    {
        if(L<=l&&r<=R)return sum[rt];
        const int mid=(l+r)>>1;
        if(R<=mid)return Query(2*rt,l,mid,L,R);
        if(L>mid)return Query(2*rt+1,mid+1,r,L,R);
        return Query(2*rt,l,mid,L,R)*Query(2*rt+1,mid+1,r,L,R);
    }
};
// The single segment-tree instance used by init(), solve() and value().
Segment_Tree T;
// Drop the short segment-tree helper macros so they cannot leak into later code.
#undef rson
#undef lson
#undef rs
#undef ls
#undef Mid
// Read the next integer (optionally preceded by '-') from stdin into x using getchar().
// Characters before the first '-' or digit are skipped. If input ends before any
// number is found, x becomes 0 instead of the loop spinning forever on EOF.
template<typename T> inline void read(T &x)
{
    T data=0,w=1;
    // int, not char: getchar() reports end of input as EOF, which only an int can
    // reliably distinguish from real characters.
    int ch=getchar();
    while(ch!=EOF&&ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
    if(ch=='-')w=-1,ch=getchar();
    while(ch>='0'&&ch<='9')data=data*10+(ch-'0'),ch=getchar();
    x=data*w;
}
// Print integer x in decimal to stdout via putchar(). If ch is not '\0', print ch afterwards.
// Digits come from the signed remainder of x, which rounds toward zero since C++11, so
// x is never negated. That keeps the most negative value of T well defined
// (negating it would overflow).
template<typename T> inline void write(T x,char ch='\0')
{
    char digits[48]; // enough for 128-bit integers (at most 39 decimal digits)
    int len=0;
    if(x<0)putchar('-');
    do
    {
        int d=static_cast<int>(x%10);
        if(d<0)d=-d;
        digits[len++]=static_cast<char>('0'+d);
        x/=10;
    }while(x!=0);
    while(len>0)putchar(digits[--len]);
    if(ch!='\0')putchar(ch);
}
// Smaller of x and y; when x < y is false (ties included), y is returned.
template<typename T> inline T min(T x,T y)
{
    if(x<y)return x;
    return y;
}
// Larger of x and y; when x > y is false (ties included), y is returned.
template<typename T> inline T max(T x,T y)
{
    if(x>y)return x;
    return y;
}
// Prepend the directed edge x -> y to x's forward-star list:
// beg[x] is x's newest edge id, nex[id] the next edge with the same source, to[id] its target.
inline void insert(int x,int y)
{
    const int id=++e;
    to[id]=y;
    nex[id]=beg[x];
    beg[x]=id;
}
// First heavy-light DFS. For x: record its parent fa[x] = p (0 for the root), its
// subtree size size[x], and its heavy child hson[x] (the first child with the largest
// subtree; stays 0 for a leaf).
inline void dfs1(int x,int p)
{
    int res=0; // largest child-subtree size seen so far
    size[x]=1;fa[x]=p;
    // Plain `int` loop variable: `register` is ill-formed since C++17.
    for(int i=beg[x];i;i=nex[i])
    {
        if(to[i]==p)continue;
        dfs1(to[i],x);
        size[x]+=size[to[i]];
        if(chkmax(res,size[to[i]]))hson[x]=to[i];
    }
}
// Second heavy-light DFS. Assigns the DFS position st[x] and chain head top[x] = tp.
// Sets ed[x] to the position of the last vertex on x's heavy chain. The heavy child is
// visited first, so every heavy chain occupies a contiguous range of positions.
// Also fills val[st[x]] with x's max-plus transition matrix built from
// light-children-only values (h = heavy child; a[1][1] stays -inf from Matrix()):
//   a[0][0] = a[0][1] = f[x][0] - max(f[h][0], f[h][1])
//   a[1][0]           = f[x][1] - f[h][0]
// For a leaf the heavy-child terms are absent. Requires dfs1 and dfs to have run.
inline void dfs2(int x,int tp)
{
    top[x]=tp;st[x]=++cnt;
    val[cnt].a[0][0]=val[cnt].a[0][1]=f[x][0];
    val[cnt].a[1][0]=f[x][1];
    if(hson[x])
    {
        // Remove the heavy child's contribution; it is multiplied back in by the
        // segment-tree product along the chain.
        val[cnt].a[0][0]-=max(f[hson[x]][0],f[hson[x]][1]);
        val[cnt].a[0][1]=val[cnt].a[0][0];
        val[cnt].a[1][0]-=f[hson[x]][0];
        dfs2(hson[x],tp);ed[x]=ed[hson[x]];
    }
    else ed[x]=cnt; // leaf: the chain ends at x itself
    // Plain `int` loop variable: `register` is ill-formed since C++17.
    for(int i=beg[x];i;i=nex[i])
    {
        if(to[i]==fa[x]||to[i]==hson[x])continue;
        dfs2(to[i],to[i]); // every light child starts a new chain
    }
}
// Bottom-up maximum-weight-independent-set DP on the tree rooted at vertex 1
// (uses fa[] from dfs1):
//   f[x][1] = w[x] + sum over children c of f[c][0]              (x taken)
//   f[x][0] =        sum over children c of max(f[c][0], f[c][1]) (x not taken)
inline void dfs(int x)
{
    f[x][1]=w[x];
    f[x][0]=0; // explicit reset; don't rely on one-shot zero-initialisation of globals
    // Plain `int` loop variable: `register` is ill-formed since C++17.
    for(int i=beg[x];i;i=nex[i])
    {
        if(to[i]==fa[x])continue;
        dfs(to[i]);
        f[x][1]+=f[to[i]][0];
        f[x][0]+=max(f[to[i]][0],f[to[i]][1]);
    }
}
// One-time preprocessing after the tree is read. Steps, in order:
// 1. dfs1: parents, subtree sizes and heavy children.
// 2. dfs: subtree DP values.
// 3. dfs2: chain layout and per-vertex matrices.
// 4. Build the segment tree over val[1..n].
inline void init()
{
    dfs1(1,0);
    dfs(1);
    dfs2(1,1);
    T.Build(1,1,n);
}
inline void solve(int u,ll v)
{
Matrix A,B,C;
B=T.Query(1,1,n,st[u],st[u]);
A=T.Query(1,1,n,st[top[u]],ed[u]);
B.a[1][0]+=v;
T.Update(1,1,n,st[u],B);
while(u)
{
B=T.Query(1,1,n,st[top[u]],ed[u]);
u=fa[top[u]];
if(!u)break;
C=T.Query(1,1,n,st[u],st[u]);
C.a[0][0]+=max(B.a[0][0],B.a[1][0])-max(A.a[0][0],A.a[1][0]);
C.a[0][1]=C.a[0][0];
C.a[1][0]+=B.a[0][0]-A.a[0][0];
A=T.Query(1,1,n,st[top[u]],ed[u]);
T.Update(1,1,n,st[u],C);
}
}
// Current maximum independent-set weight. Column 0 of the root chain's product is the DP
// pair for vertex 1, and the better of the two entries is taken. main() adds a +vinf bonus
// to each queried vertex whose flag is 0, so one vinf is removed here per such flag.
inline ll value(int ot1,int ot2)
{
    const Matrix root=T.Query(1,1,n,st[1],ed[1]);
    ll best=max(root.a[0][0],root.a[1][0]);
    if(!ot1)best-=vinf;
    if(!ot2)best-=vinf;
    return best;
}
int main()
{
freopen("defense.in","r",stdin);
freopen("defense.out","w",stdout);
read(n);read(m);scanf("%s",type);
REP(i,1,n)read(w[i]),all+=w[i];
REP(i,1,n-1)
{
int u,v;read(u);read(v);
insert(u,v);insert(v,u);
}
init();
while(m--)
{
int a,x,b,y;read(a);read(x);read(b);read(y);
if((fa[a]==b||fa[b]==a)&&!x&&!y)
{
puts("-1");
continue;
}
solve(a,x?-vinf:vinf);
solve(b,y?-vinf:vinf);
printf("%lld\n",all-value(x,y));
solve(a,x?vinf:-vinf);
solve(b,y?vinf:-vinf);
}
return 0;
}