点分治板子
https://ac.nowcoder.com/acm/contest/11174/E
借这道题重新整理一份点分治（重心分治）板子。
这是一道裸的点分治题：求树上所有路径（包括只含一个点的路径）上“最大点权 × 最小点权”之和，对 998244353 取模。
主程序（main）里的代码要保留，尤其是初始化部分：vis[] 全部置 1、f[0]=1e9 作为求重心时的哨兵、sum=n 后先调用 gr(1,0) 找到第一个重心，再调用 solve(rt,0)。
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
// #include <immintrin.h>
// #include <emmintrin.h>
#include <bits/stdc++.h>
using namespace std;
#define rep(i,h,t) for (int i=h;i<=t;i++)
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define IL inline
// `register` was removed in C++17; keep the shorthand but expand to a plain int.
#define rint int

// Reads a signed integer from stdin via getchar().
// Returns 0 if input ends before any digit (instead of looping forever on EOF).
inline ll rd() {
    ll x=0;
    int c=getchar();  // int, not char: EOF must stay distinguishable and isdigit needs a non-negative value
    bool f=0;
    while (!isdigit(c)) {
        if (c==EOF) return 0;
        if (c=='-') f=1;
        c=getchar();
    }
    while (isdigit(c)) {
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return f?-x:x;
}

// 16 MiB input buffer used by gc(); [A,B) is the not-yet-consumed part.
char ss[1<<24],*A=ss,*B=ss;

// Buffered single-character read: refills the buffer with fread when it is
// exhausted and returns EOF once stdin has no more data.
IL int gc() {
    if (A==B) {
        A=ss;
        B=ss+fread(ss,1,sizeof(ss),stdin);
        if (A==B) return EOF;
    }
    return (unsigned char)*A++;
}

template<class T> void maxa(T &x,T y) { if (y>x) x=y; }
template<class T> void mina(T &x,T y) { if (y<x) x=y; }

// Reads a signed integer through gc(). Leaves x==0 if input ends before a digit.
template<class T> void read(T &x) {
    int f=1,c;
    x=0;
    while (c=gc(),c<'0'||c>'9') {
        if (c==EOF) return;
        if (c=='-') f=-1;
    }
    x=(c^48);
    while (c=gc(),c>='0'&&c<='9') x=x*10+(c^48);
    x*=f;
}

const int mo=998244353;

// x^y mod mo by iterative binary exponentiation; y>=0 (y==0 yields 1).
// The old recursive version never terminated for y==0.
ll fsp(ll x,ll y) {
    x%=mo;
    if (x<0) x+=mo;
    ll res=1;
    for (; y>0; y>>=1,x=x*x%mo)
        if (y&1) res=res*x%mo;
    return res;
}

// 2-D integer vector: +, -, cross product (*), and half() which is 1 for the
// lower half-plane (y<0, or y==0 with x<0) -- useful for polar-angle sorting.
struct cp {
    ll x,y;
    cp operator +(cp B) const { return cp{x+B.x,y+B.y}; }
    cp operator -(cp B) const { return cp{x-B.x,y-B.y}; }
    ll operator *(cp B) const { return x*B.y-y*B.x; }
    int half() const { return y<0||(y==0&&x<0); }
};

// Capacity for 1-based per-vertex arrays; the +5 keeps index n valid when n reaches 3e5.
const int N=3e5+5;

// A half-path from the current centroid: a = rank of its maximum weight,
// b = rank of its minimum weight (c is unused here).
struct re { int a,b,c; } a[N];

int gg[N];          // gg[i]: rank of av[i] among the distinct sorted weights c[1..nn]
int av[N];          // vertex weights as read
bool vis[N];        // 1 while the vertex has not yet been used as a centroid
ll c[N];            // c[k]: weight whose rank is k (ascending, deduplicated)
int n,m,rt,son[N],f[N],sum,d[N];
vector<int> pq[N];  // adjacency lists of the tree
ll ans=0;

bool cmp(const re &x,const re &y) { return x.a<y.a; }

#define mid ((h+t)>>1)
// Segment tree over ranks [1,n]. sum[] keeps the sum (mod mo) of inserted
// weights and v[] the number of inserted items. ve records every touched node,
// so clear() costs only as much as the updates since the last clear().
struct sgt {
    vector<int> ve;
    ll sum[N*4];
    int v[N*4];

    void clear() {
        for (auto u:ve) sum[u]=v[u]=0;
        ve.clear();
    }

    // Inserts one item of weight k at rank pos.
    void change(int x,int h,int t,int pos,ll k) {
        ve.push_back(x);
        sum[x]=(sum[x]+k)%mo;
        v[x]++;
        if (h==t) return;
        if (pos<=mid) change(x*2,h,mid,pos,k);
        else change(x*2+1,mid+1,t,pos,k);
    }

    // Sum (mod mo) of weights inserted at ranks in [h1,t1].
    ll q1(int x,int h,int t,int h1,int t1) {
        if (h1<=h&&t<=t1) return sum[x];
        ll res=0;
        if (h1<=mid) res+=q1(x*2,h,mid,h1,t1);
        if (mid<t1) res+=q1(x*2+1,mid+1,t,h1,t1);
        return res%mo;
    }

    // Number of items inserted at ranks in [h1,t1].
    ll q2(int x,int h,int t,int h1,int t1) {
        if (h1<=h&&t<=t1) return v[x];
        ll res=0;
        if (h1<=mid) res+=q2(x*2,h,mid,h1,t1);
        if (mid<t1) res+=q2(x*2+1,mid+1,t,h1,t1);
        return res;
    }
} S;

// Centroid search over the live component containing x (parent y).
// Fills son[] with subtree sizes and keeps in rt the vertex whose largest
// remaining piece f[] is smallest. Preconditions: sum holds the component
// size and rt==0 (f[0] is a large sentinel).
// NOTE: recursive -- a path-shaped tree with n~3e5 needs a deep stack.
void gr(int x,int y) {
    son[x]=1;
    f[x]=0;
    for (auto v:pq[x]) if (vis[v]&&v!=y) {
        gr(v,x);
        son[x]+=son[v];
        f[x]=max(f[x],son[v]);
    }
    f[x]=max(f[x],sum-son[x]);
    if (f[x]<f[rt]) rt=x;
}

vector<re> ve,an;

// Appends to ve one (max rank, min rank) entry per live vertex under x
// (parent fa). mx/mn already cover the path from the centroid to x.
void gd(int x,int fa,int mx,int mn) {
    ve.push_back(re{mx,mn,0});
    for (auto v:pq[x]) if (vis[v]&&v!=fa)
        gd(v,x,max(mx,gg[v]),min(mn,gg[v]));
}

// Sum (mod mo) over unordered pairs of entries in ps of
// c[max(a_i,a_j)] * c[min(b_i,b_j)], i.e. max*min of the joined path.
// Sorts ps in place by a. Ties in a give the same term in either order.
ll gao(vector<re> &ps) {
    S.clear();
    sort(ps.begin(),ps.end(),cmp);
    ll res=0;
    for (const auto &p:ps) {
        // Earlier entries with min rank <= p.b: the pair's min is theirs.
        res=(res+S.q1(1,1,n,1,p.b)%mo*c[p.a])%mo;
        // Earlier entries with min rank > p.b: the pair's min is p's own.
        if (p.b!=n) res=(res+(c[p.a]*c[p.b])%mo*S.q2(1,1,n,p.b+1,n))%mo;
        S.change(1,1,n,p.b,c[p.b]);
    }
    return res;
}

// Processes centroid x: adds every path through x, then recurses into each
// remaining component. y is the recursion depth (informational only).
void solve(int x,int y) {
    vis[x]=0;
    an.clear();
    // (neighbour, exact size of its component once x is removed)
    vector<pair<int,int>> parts;
    for (auto v:pq[x]) if (vis[v]) {
        gd(v,x,max(gg[v],gg[x]),min(gg[v],gg[x]));
        parts.emplace_back(v,(int)ve.size());
        // Pairs inside one subtree do not really pass through x: cancel them.
        ans=(ans-gao(ve))%mo;
        an.insert(an.end(),ve.begin(),ve.end());
        ve.clear();
    }
    an.push_back(re{gg[x],gg[x],0});  // the centroid as a one-vertex half-path
    ans=(ans+gao(an))%mo;
    for (const auto &pr:parts) {
        // Use the exact component size instead of the possibly stale son[v].
        rt=0;
        sum=pr.second;
        gr(pr.first,x);
        solve(rt,y+1);
    }
}

int main() {
#ifdef LOCAL
    // Local debugging only; on a judge the test comes through stdin/stdout.
    freopen("2.in","r",stdin);
    freopen("1.out","w",stdout);
#endif
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n;
    rep(i,1,n) cin>>av[i];
    // Coordinate-compress the weights: gg[i] is the rank of av[i] in c[1..nn].
    rep(i,1,n) c[i]=av[i];
    sort(c+1,c+n+1);
    f[0]=1e9;  // sentinel so the first candidate in gr() always replaces rt==0
    int nn=unique(c+1,c+n+1)-c-1;
    rep(i,1,n) gg[i]=lower_bound(c+1,c+nn+1,av[i])-c;
    rep(i,1,n) vis[i]=1;
    rep(i,1,n-1) {
        int x,y;
        cin>>x>>y;
        pq[x].push_back(y);
        pq[y].push_back(x);
    }
    sum=n;
    rt=0;
    gr(1,0);
    solve(rt,0);
    // Single-vertex paths contribute av[i]*av[i].
    rep(i,1,n) ans=(ans+1ll*av[i]*av[i])%mo;
    cout<<(ans+mo)%mo<<'\n';
    return 0;
}