bzoj2164 采矿
题目描述:
题解:
树链剖分 + 线段树。
考虑在线段树上对每个区间维护两个数组:只在区间内某一个点上放$i$个人的最大收益,以及在区间内任意分配$i$个人的最大收益。
前者合并$O(m)$,后者合并$O(m^2)$。
所以总复杂度为$O(nm^2 + C(m\log^2 n + m^2\log n))$(其中$C$为操作数:建树$O(nm^2)$,每次修改$O(m^2\log n)$,每次询问的链部分$O(m\log^2 n)$、子树部分$O(m^2\log n)$),可过。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;

const int N = 20050;      // capacity for tree nodes (n plus slack)
const int M = 55;         // capacity for the worker count m (plus slack)
const int X = (1<<16);    // generator constant (2^16)
const int Y = 2147483647; // generator mask (2^31 - 1)

// Reads a decimal integer (optionally preceded by '-') from stdin into x.
// Non-digit characters before the number are skipped; reaching EOF stops
// the scan instead of looping forever on truncated input.
template<typename T> inline void read(T&x)
{
    T f = 1, c = 0;
    int ch = getchar(); // int, not char, so EOF is distinguishable
    while (ch < '0' || ch > '9')
    {
        if (ch == EOF) break;
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        c = c*10 + ch - '0';
        ch = getchar();
    }
    x = f*c;
}

int n, m, A, B, Q, k[N][M], hed[N], cnt;

// Pseudo-random generator used to fill the k tables: advances the state
// (A, B) and returns (A ^ B) % Q.
// The products B*X / A*X exceed INT_MAX, and signed overflow is undefined
// behaviour, so the sums are formed in unsigned arithmetic (wrapping mod
// 2^32) before masking with Y. The division keeps its original signed form.
inline int read()
{
    A = (int)(((unsigned)(A ^ B) + (unsigned)(B / X) + (unsigned)B * (unsigned)X) & (unsigned)Y);
    B = (int)(((unsigned)(A ^ B) + (unsigned)(A / X) + (unsigned)A * (unsigned)X) & (unsigned)Y);
    return (A ^ B) % Q;
}

// Fills k[1..m] with m generated values and sorts them ascending,
// so k[j] is non-decreasing in j.
void get_k(int*k)
{
    for (int i = 1; i <= m; i++) k[i] = read();
    sort(k + 1, k + 1 + m);
}

// Adjacency list (forward-star): edges parent -> child only.
struct EG
{
    int to, nxt;
} e[N];

// Adds a directed edge f -> t.
void ae(int f, int t)
{
    e[++cnt].to = t;
    e[cnt].nxt = hed[f];
    hed[f] = cnt;
}

int dep[N], fa[N], son[N], top[N], siz[N], tin[N], tout[N], pla[N], tim;

// First heavy-light pass: parent, depth, subtree size and heavy child.
void dfs0(int u, int f)
{
    fa[u] = f, siz[u] = 1, dep[u] = dep[f] + 1;
    for (int j = hed[u]; j; j = e[j].nxt)
    {
        int to = e[j].to;
        dfs0(to, u);
        siz[u] += siz[to];
        if (siz[to] > siz[son[u]]) son[u] = to;
    }
}

// Second heavy-light pass: chain tops and DFS order. The heavy child is
// visited first so each chain is contiguous in [tin]; the subtree of u
// occupies positions tin[u]..tout[u]. pla maps a position back to a node.
void dfs1(int u, int Top)
{
    top[u] = Top, tin[u] = ++tim, pla[tim] = u;
    if (son[u]) dfs1(son[u], Top);
    for (int j = hed[u]; j; j = e[j].nxt)
    {
        int to = e[j].to;
        if (to != son[u]) dfs1(to, to);
    }
    tout[u] = tim;
}

// s[i] = best value obtainable with i workers (s[0] is always 0).
struct node
{
    ll s[M];

    // Zero the node; for i != 0, load tree node i's table k[i][1..m].
    void reset(int i)
    {
        memset(s, 0, sizeof(s));
        if (!i) return;
        for (int j = 1; j <= m; j++) s[j] = k[i][j];
    }

    // Pointwise max: "put all i workers on a single node from either side".
    node operator + (const node&a) const
    {
        node ret; ret.reset(0);
        for (int i = 1; i <= m; i++) ret.s[i] = max(s[i], a.s[i]);
        return ret;
    }

    // Max-plus convolution: split i workers freely between the two sides.
    node operator * (const node&a) const
    {
        node ret; ret.reset(0);
        for (int i = 1; i <= m; i++)
            for (int j = 0; j <= i; j++)
                ret.s[i] = max(ret.s[i], a.s[j] + s[i - j]);
        return ret;
    }
};

// Segment tree over DFS positions.
// s1: combine with '+' (workers on one node of the range);
// s2: combine with '*' (workers distributed over the whole range).
struct segtree
{
    node s1[N<<2], s2[N<<2];

    void update(int u)
    {
        s1[u] = s1[u<<1] + s1[u<<1|1];
        s2[u] = s2[u<<1] * s2[u<<1|1];
    }

    void build(int l, int r, int u)
    {
        if (l == r) { s1[u].reset(pla[l]), s2[u].reset(pla[l]); return; }
        int mid = (l + r) >> 1;
        build(l, mid, u<<1), build(mid + 1, r, u<<1|1);
        update(u);
    }

    // Reload the leaf at position qx from k[] and refresh its ancestors.
    void insert(int l, int r, int u, int qx)
    {
        if (l == r) { s1[u].reset(pla[l]), s2[u].reset(pla[l]); return; }
        int mid = (l + r) >> 1;
        if (qx <= mid) insert(l, mid, u<<1, qx);
        else insert(mid + 1, r, u<<1|1, qx);
        update(u);
    }

    // '+'-combination over positions [ql, qr].
    node qs1(int l, int r, int u, int ql, int qr)
    {
        if (l == ql && r == qr) return s1[u];
        int mid = (l + r) >> 1;
        if (qr <= mid) return qs1(l, mid, u<<1, ql, qr);
        else if (ql > mid) return qs1(mid + 1, r, u<<1|1, ql, qr);
        else return qs1(l, mid, u<<1, ql, mid) + qs1(mid + 1, r, u<<1|1, mid + 1, qr);
    }

    // '*'-combination over positions [ql, qr].
    node qs2(int l, int r, int u, int ql, int qr)
    {
        if (l == ql && r == qr) return s2[u];
        int mid = (l + r) >> 1;
        if (qr <= mid) return qs2(l, mid, u<<1, ql, qr);
        else if (ql > mid) return qs2(mid + 1, r, u<<1|1, ql, qr);
        else return qs2(l, mid, u<<1, ql, mid) * qs2(mid + 1, r, u<<1|1, mid + 1, qr);
    }
} tr;

int main()
{
    // freopen("tt.in","r",stdin);
    read(n), read(m), read(A), read(B), read(Q);
    for (int i = 1; i <= n; i++) get_k(k[i]);
    for (int i = 2, f; i <= n; i++) read(f), ae(f, i);
    dfs0(1, 0), dfs1(1, 1);
    tr.build(1, n, 1);

    int C; read(C);
    for (int op, u, v, p, i = 1; i <= C; i++)
    {
        read(op);
        if (!op)
        {
            // Regenerate node p's table and update its leaf.
            read(p);
            get_k(k[p]);
            tr.insert(1, n, 1, tin[p]);
        }
        else
        {
            // NOTE(review): the chain walk below assumes v is u or an
            // ancestor of u; confirm against the problem statement.
            read(u), read(v);
            node ans;
            if (u == v)
            {
                ans = tr.qs2(1, n, 1, tin[u], tout[u]);
            }
            else
            {
                node k1, k2; k1.reset(0), k2.reset(0);
                int j;
                // Path fa[u] .. v: workers may go to only one node on it.
                for (j = fa[u]; top[j] != top[v]; j = fa[top[j]])
                    k1 = k1 + tr.qs1(1, n, 1, tin[top[j]], tin[j]);
                k1 = k1 + tr.qs1(1, n, 1, tin[v], tin[j]);
                // Subtree of u: workers may be distributed freely.
                k2 = tr.qs2(1, n, 1, tin[u], tout[u]);
                ans = k1 * k2;
            }
            printf("%lld\n", ans.s[m]);
        }
    }
    return 0;
}