// Segment tree merging, O(n log n).
// Reads a binary tree given in pre-order (0 = internal node followed by its two
// subtrees, non-zero = leaf carrying that value) and prints the sum, over all
// internal nodes, of min(cnt1, cnt2), where cnt1/cnt2 are the cross-subtree pair
// counts gathered while merging the children's segment trees (see merge()).
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 400500
using namespace std;

// Input tree: children, label (0 for internal nodes) and segment-tree root per node.
int n,l[maxn],r[maxn],val[maxn],roots,root[maxn];
// Dynamically allocated segment-tree nodes; index 0 is the shared empty node.
int ls[maxn*20],rs[maxn*20],sum[maxn*20],tot=0,cnt=0;
long long ans=0,cnt1,cnt2;

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. Returns -1 if EOF is reached before a
// digit is found.
int read()
{
    // getchar() returns int so that EOF stays distinguishable from real bytes;
    // the character is fetched before it is tested, so no indeterminate value
    // is ever inspected or folded into the result.
    int ch=getchar();
    while (ch!=EOF && (ch<'0' || ch>'9')) ch=getchar();
    if (ch==EOF) return -1;

    int data=0;
    while (ch>='0' && ch<='9')
    {
        data=data*10+(ch-'0');
        ch=getchar();
    }
    return data;
}

// Recursively reads one subtree in pre-order and stores its node id in `now`.
// `now` is set to 0 (no node) when the input ends early or the node arrays
// are full, so malformed input cannot recurse without bound.
void get_tree(int &now)
{
    const int label=read();
    if (label<0 || cnt+1>=maxn) { now=0; return; }

    now=++cnt;
    val[now]=label;
    if (label) return;          // non-zero label: leaf, nothing more to read
    get_tree(l[now]);
    get_tree(r[now]);
}

// Recomputes the element count of a segment-tree node from its two children.
void pushup(int now)
{
    sum[now]=sum[ls[now]]+sum[rs[now]];
}

// Creates the single root-to-leaf path for position `pos` inside the range
// [left, right]; every node on that path ends up with sum == 1.
void build(int &now,int left,int right,int pos)
{
    now=++tot;
    if (left==right) { sum[now]=1; return; }

    const int mid=(left+right)>>1;
    if (pos<=mid) build(ls[now],left,mid,pos);
    else build(rs[now],mid+1,right,pos);
    pushup(now);
}

// Merges segment tree `b` into `a` and returns the root of the result.
// At every position where both trees have a node it adds
//   cnt1 += (#elements in a's upper half) * (#elements in b's lower half)
//   cnt2 += (#elements in a's lower half) * (#elements in b's upper half)
// NOTE(review): with distinct leaf values these totals are the numbers of
// (a-element, b-element) pairs with the a-element larger / smaller — confirm
// that the input guarantees distinct values.
int merge(int a,int b)
{
    if (!a) return b;
    if (!b) return a;

    cnt1+=(long long)sum[rs[a]]*sum[ls[b]];
    cnt2+=(long long)sum[ls[a]]*sum[rs[b]];
    ls[a]=merge(ls[a],ls[b]);
    rs[a]=merge(rs[a],rs[b]);
    pushup(a);
    return a;
}

// Post-order walk of the input tree: each internal node merges its children's
// segment trees and adds the smaller of the two pair counts to the answer.
void dfs(int node)
{
    if (!node) return;
    dfs(l[node]);
    dfs(r[node]);
    if (val[node]) return;      // leaves already own a one-element tree

    cnt1=cnt2=0;
    root[node]=merge(root[l[node]],root[r[node]]);
    ans+=min(cnt1,cnt2);
}

int main()
{
    n=read();
    if (n<1)
    {
        // Empty or degenerate input: build() on an empty range [1, n] would
        // never reach left == right, so report the zero answer directly.
        printf("%lld\n",ans);
        return 0;
    }

    get_tree(roots);
    for (int i=1;i<=cnt;i++)
        if (val[i]) build(root[i],1,n,val[i]);
    dfs(roots);
    printf("%lld\n",ans);
    return 0;
}