『NOIP 2019Day2 T3』 保卫王国(defense)
重温NOIP2018的试题,发现只要好好想想还是能想出一些东西的。
比如说本题是一个DDP的模板题,硬是做成了倍增优化DP的题目。
对于给出的$n$个节点的树,每个点都有点权$v_i$,共$Q$次询问。
每次询问指定两个点的状态取或者不取,询问树中最小权覆盖集。
如果最小权覆盖集不存在,输出$-1$
对于$100\%$保证$1 \leq n,m \leq 10^5 , 1 \leq v_i \leq 10^9$
Solution :
我们设$g[u][0/1]$表示节点$u$是否选择,$u$的子树最小权覆盖集。
一个显然的转移是$g[u][0] = \sum\limits_{v \in u_{son}} g[v][1] , g[u][1] = val[u] + \sum\limits_{v \in u_{son}} min(g[v][0],g[v][1])$
为了解决$Q$个询问,我们需要设倍增数组来优化上述转移。
设$其中f[u][i][p][q] ( 其中p,q =0,1)$表示从$father(u)$到$u$向上跳$2^i$步的节点,其中$u$状态是$p $,$u$向上跳$2^i$得到的节点状态是$q$这段贡献最小值。
我们显然可以通过一次$dfs$,计算出$f[u][0][0/1][0/1] , g[u][0/1]$的值
- $f[u][0][0][0] = inf$
- $f[u][0][0][1] = val[father(u)] + \sum\limits_{v \in father(u)_{son} v \neq u}min\{g[v][0],g[v][1]\}$
- $f[u][0][1][0] =\sum\limits_{v \in father(u)_{son}} g[v][1]$
- $f[u][0][1][1]= val[father(u)] + \sum\limits_{v \in father(u)_{son} v \neq u}min\{g[v][0],g[v][1]\}$
然后我们也能通过$O(n \ log_2 \ n)$的复杂度预处理上述的倍增的数组,和$lca$的数组
处理每个询问的时候,把$u,v$在同一条链上的情况和$u,v$不在一条链上的情况进行讨论
(主要是向上跳的位置及贡献计算不同)
在向上跳的时候记录两个变量$ret0,ret1$表示当前节点取或不取当前的子树最小权覆盖集。
由于向上跳的步数倍增预处理完毕,一次询问最多只会向上跳$log_2n$步,
所以本题的时间复杂度就是$O((n+m) log_2 n )$
# include <bits/stdc++.h> # define int long long # define inf (1e12) using namespace std; const int N=5e5+10; struct rec{ int pre,to;}a[N<<1]; int n,m,tot; char type[10]; int f[N][22][2][2],g[N][2],d[N][22],dep[N]; int sum1[N],sum2[N],head[N],val[N]; inline int read() { int X=0,w=0; char c=0; while(c<'0'||c>'9') {w|=c=='-';c=getchar();} while(c>='0'&&c<='9') X=(X<<3)+(X<<1)+(c^48),c=getchar(); return w?-X:X; } void write(int x) { if (x<0) putchar('-'),x=-x; if (x>9) write(x/10); putchar('0'+x%10); } void adde(int u,int v) { a[++tot].pre=head[u]; a[tot].to=v; head[u]=tot; } void dfs1(int u,int fa) { dep[u]=dep[fa]+1; g[u][0]=0; g[u][1]=val[u]; d[u][0]=fa; sum1[u]=0; sum2[u]=0; for (int i=head[u];i;i=a[i].pre) { int v=a[i].to; if (v==fa) continue; dfs1(v,u); g[u][0]+=g[v][1]; g[u][1]+=min(g[v][0],g[v][1]); sum1[u]+=g[v][1]; sum2[u]+=min(g[v][0],g[v][1]); } } void dfs2(int u,int fa) { for (int i=head[u];i;i=a[i].pre) { int v=a[i].to; if (v==fa) continue; dfs2(v,u); } if (u != 1) { f[u][0][0][0] = inf; f[u][0][0][1] = val[fa] + sum2[fa] - min(g[u][0],g[u][1]); f[u][0][1][0] = sum1[fa] - g[u][1]; f[u][0][1][1] = val[fa] + sum2[fa] - min(g[u][0],g[u][1]); } } int lca(int u,int v) { if (dep[u]<dep[v]) swap(u,v); for (int i=21;i>=0;i--) if (dep[d[u][i]]>=dep[v]) u=d[u][i]; if (u==v) return u; for (int i=21;i>=0;i--) if (d[u][i]!=d[v][i]) u=d[u][i],v=d[v][i]; return d[u][0]; } int work2(int u,int op1,int v,int op2) { int ret0=g[u][0],ret1=g[u][1]; if (op1 == 1) ret0=inf; else ret1=inf; bool flag = true; int tmp0,tmp1; for (int i=21;i>=0;i--) { if (dep[d[u][i]]<=dep[v]) continue; tmp0=ret0,tmp1=ret1; ret0 = min(tmp0+f[u][i][0][0],tmp1+f[u][i][1][0]); ret1 = min(tmp0+f[u][i][0][1],tmp1+f[u][i][1][1]); u = d[u][i]; } tmp0=ret0,tmp1=ret1; ret0 = tmp1 + sum1[v] - g[u][1]; ret1 = min(tmp1,tmp0) + sum2[v] - min(g[u][0],g[u][1]) + val[v]; if (op2 == 0) ret1=inf; else ret0=inf; u = v; for (int i=21;i>=0;i--) { if (dep[d[u][i]]<dep[1]) continue; tmp0=ret0,tmp1=ret1; ret0 = min(tmp0+f[u][i][0][0],tmp1+f[u][i][1][0]); ret1 = min(tmp0+f[u][i][0][1],tmp1+f[u][i][1][1]); u=d[u][i]; } int ans = min(ret0,ret1); if (ans>=inf) return -1; else return ans; } int work(int u,int op1,int v,int op2) { if (dep[u]<dep[v]) swap(u,v),swap(op1,op2); int l=lca(u,v); if (l == v) return work2(u,op1,v,op2); int ret0=g[u][0],ret1=g[u][1]; if (op1 == 1) ret0=inf; else ret1=inf; bool flag = true; int tmp0,tmp1; for (int i=21;i>=0;i--) { if (dep[d[u][i]]<=dep[l]) continue; tmp0=ret0,tmp1=ret1; ret0 = min(tmp0+f[u][i][0][0],tmp1+f[u][i][1][0]); ret1 = min(tmp0+f[u][i][0][1],tmp1+f[u][i][1][1]); u = d[u][i]; } int val1[2]; val1[0] = ret0; val1[1] = ret1; ret0=g[v][0],ret1=g[v][1]; if (op2 == 1) ret0=inf; else ret1=inf; flag = true; for (int i=21;i>=0;i--) { if (dep[d[v][i]]<=dep[l]) continue; tmp0=ret0,tmp1=ret1; ret0 = min(tmp0+f[v][i][0][0],tmp1+f[v][i][1][0]); ret1 = min(tmp0+f[v][i][0][1],tmp1+f[v][i][1][1]); v = d[v][i]; } int val2[2]; val2[0] = ret0; val2[1] = ret1; ret0 = val1[1]+val2[1]+sum1[l]-g[u][1]-g[v][1]; ret1 = val[l] + min(val1[0],val1[1]) + min(val2[0],val2[1]) + sum2[l] - min(g[u][0],g[u][1]) - min(g[v][0],g[v][1]); u = l; for (int i=21;i>=0;i--) { if (dep[d[u][i]]<dep[1]) continue; tmp0=ret0,tmp1=ret1; ret0 = min(tmp0+f[u][i][0][0],tmp1+f[u][i][1][0]); ret1 = min(tmp0+f[u][i][0][1],tmp1+f[u][i][1][1]); u=d[u][i]; } int ans = min(ret0,ret1); if (ans>=inf) return -1; else return ans; } signed main() { n=read();m=read(); scanf("%s",type); for (int i=1;i<=n;i++) val[i]=read(); for (int i=2;i<=n;i++) { int u=read(),v=read(); adde(u,v); adde(v,u); } dfs1(1,0); dfs2(1,0); for (int i=1;i<=21;i++) for (int j=1;j<=n;j++) d[j][i]=d[d[j][i-1]][i-1]; for (int i=1;i<=21;i++) for (int u=1;u<=n;u++) { f[u][i][0][0]=f[u][i][0][1]=f[u][i][1][0]=f[u][i][1][1]=inf; for (int p=0;p<=1;p++) { f[u][i][0][0] = min(f[u][i][0][0],f[u][i-1][0][p] + f[d[u][i-1]][i-1][p][0]); f[u][i][0][1] = min(f[u][i][0][1],f[u][i-1][0][p] + f[d[u][i-1]][i-1][p][1]); f[u][i][1][0] = min(f[u][i][1][0],f[u][i-1][1][p] + f[d[u][i-1]][i-1][p][0]); f[u][i][1][1] = min(f[u][i][1][1],f[u][i-1][1][p] + f[d[u][i-1]][i-1][p][1]); } } while (m--) { int u=read(),op1=read(),v=read(),op2=read(); int ans = work(u,op1,v,op2); write(ans); putchar('\n'); } return 0; }