专项测试 数据结构2
A. 5a2m5Yab6aKY
咕咕咕
B. Q0Y1NzFE
Q0YzMTAwLOW8oOWPo+aUvlQxCg== 5bGF54S25rKh5pS+VDHvvIHlt67or4TvvIHvvIHvvIEK
把冰茶姬的合并拍成序列,映射每个数的位置,使得修改变得连续
两类都这样搞一下
然后再开两个线段树分别维护区间加法,和最近的一次赋 \(0\) 的时间即可
然后对于每一次询问,查询他在这次询问时最近的一次赋值,然后在那个时间点给他赋个 \(0\) 就行
复杂度 \(O(n\log{n})\)
Code
#include<bits/stdc++.h>
#define int long long
#define lson rt<<1
#define rson rt<<1|1
#define rint signed
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
int n,m;
char st[3];
int pos1[100010],pos2[100010],p1,p2;
int ans[100010];
vector<int> u0[500010];
struct DSU{
int fa[100010],siz[100010];
vector<int>vec[100010];
int getfa(int x){return fa[x]==x?x:fa[x]=getfa(fa[x]);}
inline void premerge(int x,int y){
x=getfa(fa[x]),y=getfa(fa[y]);
if(x==y) return ;
if(vec[x].size()>vec[y].size()) swap(x,y);
fa[x]=y;for(auto v:vec[x]) vec[y].emplace_back(v);
}
inline void merge(int x,int y){
x=getfa(fa[x]),y=getfa(fa[y]);
if(x==y) return ;
if(siz[x]>siz[y]) swap(x,y);
fa[x]=y,siz[y]+=siz[x];
}
inline void init(){for(int i=1;i<=n;i++) fa[i]=i,siz[i]=1;}
}s1,s2;
struct OPT{int op,x,y;}L[500010];
namespace Seg1{
struct Seg{int atag;}t[100010*4];
void upd(int rt,int l,int r,int L,int R,int k){
if(L<=l&&r<=R) return t[rt].atag+=k,void();
int mid=(l+r)>>1;
if(L<=mid) upd(lson,l,mid,L,R,k);
if(R>mid) upd(rson,mid+1,r,L,R,k);
}
int query(int rt,int l,int r,int pos){
if(l==r) return t[rt].atag;
int mid=(l+r)>>1;
if(pos<=mid) return t[rt].atag+query(lson,l,mid,pos);
else return t[rt].atag+query(rson,mid+1,r,pos);
}
}
namespace Seg2{
struct Seg{int mx;}t[100010*4];
void upd(int rt,int l,int r,int L,int R,int k){
if(L<=l&&r<=R) return t[rt].mx=k,void();
int mid=(l+r)>>1;
if(L<=mid) upd(lson,l,mid,L,R,k);
if(R>mid) upd(rson,mid+1,r,L,R,k);
}
int query(int rt,int l,int r,int pos){
if(l==r) return t[rt].mx;
int mid=(l+r)>>1;
if(pos<=mid) return max(t[rt].mx,query(lson,l,mid,pos));
else return max(t[rt].mx,query(rson,mid+1,r,pos));
}
}
signed main(){
#ifdef LOCAL
freopen("in","r",stdin);
freopen("out","w",stdout);
#endif
n=read(),m=read();
s1.init(),s2.init();
for(int i=1;i<=n;i++) s1.vec[i].emplace_back(i),s2.vec[i].emplace_back(i);
for(int i=1;i<=m;i++){
scanf("%s",st+1);
if(st[1]=='U'){L[i].op=1,L[i].x=read(),L[i].y=read();s1.premerge(L[i].x,L[i].y);}
if(st[1]=='M'){L[i].op=2,L[i].x=read(),L[i].y=read();s2.premerge(L[i].x,L[i].y);}
if(st[1]=='A') L[i].op=3,L[i].x=read();
if(st[1]=='Z') L[i].op=4,L[i].x=read();
if(st[1]=='Q') L[i].op=5,L[i].x=read();
}
for(int i=1;i<=n;i++) if(s1.getfa(i)==i) for(auto v:s1.vec[i]) pos1[v]=++p1;
for(int i=1;i<=n;i++) if(s2.getfa(i)==i) for(auto v:s2.vec[i]) pos2[v]=++p2;
s1.init(),s2.init();
for(int i=1;i<=m;i++){
if(L[i].op==2) s2.merge(L[i].x,L[i].y);
if(L[i].op==4){L[i].x=s2.getfa(L[i].x);Seg2::upd(1,1,n,pos2[L[i].x],pos2[L[i].x]+s2.siz[L[i].x]-1,i);}
if(L[i].op==5) u0[Seg2::query(1,1,n,pos2[L[i].x])].emplace_back(L[i].x);
}
for(int i=1,val;i<=m;i++){
for(auto v:u0[i]){val=Seg1::query(1,1,n,pos1[v]);Seg1::upd(1,1,n,pos1[v],pos1[v],-val);}
if(L[i].op==1) s1.merge(L[i].x,L[i].y);
if(L[i].op==3){L[i].x=s1.getfa(L[i].x);Seg1::upd(1,1,n,pos1[L[i].x],pos1[L[i].x]+s1.siz[L[i].x]-1,s1.siz[L[i].x]);}
if(L[i].op==5) printf("%lld\n",Seg1::query(1,1,n,pos1[L[i].x]));
}
return 0;
}
C. U1BPSiBDT1Q2
先看 \(O(n^2)\) 的做法
对于每一个点都找到他子树内的一个点和他匹配时的答案,然后取 \(\min\) 为这个点的最小贡献
匹配的答案就是路径上所有点的不在路径上的儿子的贡献和路径权值的平方
设 \(sum_u\) 为从目前子树的根到子树内的点的路径上不在路径上的儿子的贡献和
\(s_u\) 为从根到 \(u\) 的权值和
最后的答案为 \(w\)
那么贡献的式子为 \(w=\min\{sum_v+(s_v-s_u)^2\}\)
拆开平方变成 \(w=sum_v+s_v^2+s_u^2-2\times s_v \times s_u\)
发现可以斜率优化变成
\(sum_v+s_v^2=2\times s_v \times s_u + w -s_u^2\)
\(y=sum_v+s_v^2 , x=s_v,k=2\times s_u,b=w-s_u^2\)
最小化截距即可,维护下凸包,用 \(set\)
再对于每一个 \(set\) 维护一个 \(tag\) 表示整体平移的距离
Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define lson rt<<1
#define rson rt<<1|1
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
int n;
int a[1000010],sum[1000010];
int f[1000010],fa[1000010];
int head[1000010],ver[1000010],to[1000010],tot;
int b[1000010];
set<pair<int,int>>S[1000010];//x,y
set<pair<int,int>>::iterator iter,iterr;
inline void add(int x,int y){
ver[++tot]=y;
to[tot]=head[x];
head[x]=tot;
}
inline double slope(pair<int,int> x,pair<int,int> y){
return 1.0*(y.second-x.second)/(y.first-x.first);
}
inline bool cmp(pair<int,int> x,pair<int,int> y,pair<int,int> z){
return slope(x,y)<slope(y,z);
}
inline void insert(int x,pair<int,int> k){
k.second-=b[x];
iter=S[x].lower_bound(k);
if(iter!=S[x].end()&&iter->first==k.first){
if(iter->second<=k.second) return;
else S[x].erase(iter);
}
while(!S[x].empty()){
iter=S[x].lower_bound(k);
if(iter==S[x].begin()) break;
iter--;
if(iter==S[x].begin()) break;
iterr=iter;
iter--;
if(!cmp(*iter,*iterr,k)) S[x].erase(iterr);
else break;
}
while(!S[x].empty()){
iter=S[x].lower_bound(k);
if(iter==S[x].end()) break;
iterr=iter;iter++;
if(iter==S[x].end()) break;
if(!cmp(k,*iterr,*iter)) S[x].erase(iterr);
else break;
}
iter=S[x].lower_bound(k);
if(iter==S[x].begin()||iter==S[x].end()) return S[x].insert(k),void();
iterr=iter;iter--;
if(cmp(*iter,k,*iterr)) S[x].insert(k);
}
inline int getmn(int x){
if(S[x].empty()) return inf;
pair<int,int> A,L,R;
if(S[x].size()==1) A=*S[x].begin();
else{
int l=S[x].begin()->first,r=(--S[x].end())->first;
while(l<r){
int mid=(l+r+1)>>1;
iter=S[x].lower_bound(make_pair(mid,-inf));
if(iter==S[x].begin()) l=mid;
else{
R=*iter;iter--;L=*iter;L.second+=b[x],R.second+=b[x];
if(slope(L,R)<2*sum[fa[x]]) l=mid;
else r=mid-1;
}
}
A=*S[x].lower_bound(make_pair(l,-inf));
}
A.second+=b[x];
return (sum[fa[x]]*sum[fa[x]])-2*sum[fa[x]]*A.first+A.second;
}
void dfs(int x,int fath){
fa[x]=fath,sum[x]=sum[fath]+a[x];
int maxson=0,SSS=0;
for(int i=head[x];i;i=to[i]){
int y=ver[i];
dfs(y,x);SSS+=f[y];
if(S[y].size()>=S[maxson].size()) maxson=y;
}
swap(S[x],S[maxson]);b[x]=b[maxson]+SSS-f[maxson];
//printf("SSS : %lld\n",SSS);
//printf("maxson : %lld\n",maxson);
//printf("b[%lld] : %lld\n",x,b[x]);
for(int i=head[x];i;i=to[i]){
int y=ver[i];
if(y==maxson) continue;
//printf("x : %lld y : %lld SSS : %lld f[y] : %lld\n",x,y,SSS,f[y]);
b[y]+=SSS-f[y];
for(auto VVV:S[y]) VVV.second+=b[y],insert(x,VVV);
}
//printf("b[%lld] : %lld\n",x,b[x]);
f[x]=min(getmn(x),SSS+(a[x]*a[x]));
//printf("f[%lld] : %lld\n",x,f[x]);
insert(x,make_pair(sum[x],SSS+sum[x]*sum[x]));
}
signed main(){
#ifdef LOCAL
freopen("in","r",stdin);
freopen("out","w",stdout);
#endif
n=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=2;i<=n;i++) add(read(),i);
dfs(1,0);
printf("%lld\n",f[1]);
return 0;
}