HDU6368
http://acm.hdu.edu.cn/showproblem.php?pid=6368
题意: 构造最小方差生成树
首先我们先从方差的定义出发可知 方差必定是一段连续的值 方差最小
最暴力的方法是 我们枚举平均数 然后转化最小生成树check 明显的复杂度太高
我们考虑到每条边都有一个作用区间 因为当加入某一条边形成环时 为了维持稳定的方差 我们需要将这个环里面的最小边踢掉(至于为什么 可以手动模拟一下
这样的话 我们就把问题简化成插入或者加入某一条边 来维护Σwe^2和Σwe 很显然的 这是LCT加速即可
这个题 在场上并没有转化成区间 不然还可以搏一下 赛后补题遇到了三个问题:
1.在加边和删边维护LCT是取最小方差不应该在模的意义下
2.因为我们考虑的是枚举每一条边的最小方差 也就是说每条边都要构造一颗最小方差树 那也就是要离线将后面权值比较大但是又独一无二的边先加入进来 也就是当处理有环时应该同时加入边和最小边(不用分先后)
3.然后就是写戳了...还好抢救成功
#include <iostream> #include <algorithm> #include <cstdio> #include <cstring> #define ll long long const int MAXN=5e5+10; using namespace std; const int mod=998244353; const int inf=1e9; int pre[MAXN],ch[MAXN][2],res[MAXN],pos[MAXN]; ll key[MAXN]; int cnt; int st[MAXN],tp; pair<int,int>d[MAXN]; bool rt[MAXN];int n,m; typedef struct node{ int u,v,vul; friend bool operator<(node aa,node bb){ return aa.vul<bb.vul; } }node; node que[MAXN]; ll ans1,ans2; inline int newnode(ll vul){ int x;x=++cnt;pre[x]=ch[x][0]=ch[x][1]=res[x]=0;key[x]=vul; rt[x]=1;pos[x]=x; return x; } inline void reverse(int x){ if(!x)return ; swap(ch[x][0],ch[x][1]); res[x]^=1; } inline void push(int x){ if(res[x]){ reverse(ch[x][0]); reverse(ch[x][1]); res[x]^=1; } } inline void up(int r){ pos[r]=r; if(key[pos[ch[r][0]]]<key[pos[r]]) pos[r]=pos[ch[r][0]]; if(key[pos[ch[r][1]]]<key[pos[r]]) pos[r]=pos[ch[r][1]]; } inline bool pd1(int x){ return ch[pre[x]][0]!=x&&ch[pre[x]][1]!=x; } inline void P(int x){ int i;st[++tp]=x; for(i=x;!pd1(i);i=pre[i]) st[++tp]=pre[i]; for(;tp;tp--) push(st[tp]); } /*void P(int r){ if(!rt[r]) P(pre[r]); push(r); }*/ inline void rotate(int x,int kind){ int y=pre[x]; pre[ch[x][kind]]=y;ch[y][!kind]=ch[x][kind]; if(!rt[y])ch[pre[y]][ch[pre[y]][1]==y]=x; else rt[y]=0,rt[x]=1; pre[x]=pre[y];ch[x][kind]=y;pre[y]=x; up(y); } inline void splay(int x){ P(x); while(!rt[x]){ if(rt[pre[x]])rotate(x,ch[pre[x]][0]==x); else{ int y=pre[x];int kind=ch[pre[y]][0]==y; if(ch[y][kind]==x)rotate(x,!kind),rotate(x,kind); else rotate(y,kind),rotate(x,kind); } } up(x); } inline void access(int x){ int y=0; while(x){ splay(x); if(ch[x][1])rt[ch[x][1]]=1,pre[ch[x][1]]=x,ch[x][1]=0; if(y)rt[y]=0; ch[x][1]=y;up(x); y=x; x=pre[x]; } } inline void mroot(int u){ access(u); splay(u); reverse(u); } inline bool pd(int u,int v){ while(pre[u])u=pre[u]; while(pre[v])v=pre[v]; return u==v; } inline void Link(int u,int v,ll vul){ int t1=newnode(vul);d[t1].first=u;d[t1].second=v; mroot(u);mroot(v); pre[u]=t1; pre[v]=t1; } void destory1(int u,int v){ mroot(u);access(v);splay(v); rt[u]=rt[v]=1;pre[u]=pre[v]=0;ch[v][0]=0; up(u);up(v); } int vis[MAXN],id; inline void destory(int u,int v){ mroot(u);access(v);splay(v); int t1=pos[v];id=t1; destory1(t1,d[t1].first);destory1(t1,d[t1].second); } inline ll ksm(ll a,ll b,ll c){ ll ans=1; while(b){ if(b&1)ans=ans*a%c; a=a*a%c;b=b>>1; } return ans; } typedef struct Node{ int vul;ll sum1,sum2; friend bool operator<(Node aa,Node bb){ if(aa.vul==bb.vul)return aa.sum1<bb.sum1; return aa.vul<bb.vul; } }Node; Node p[MAXN]; int main(){ //freopen("1.in","r",stdin); int _;scanf("%d",&_); while(_--){ scanf("%d%d",&n,&m); cnt=0;key[0]=inf;ans1=0;ans2=0;ll ans3=ksm(n-1,mod-2,mod); for(int i=1;i<=m;i++)scanf("%d%d%lld",&que[i].u,&que[i].v,&que[i].vul),vis[i]=-1; for(int i=1;i<=n;i++)newnode(inf),vis[i]=-1; sort(que+1,que+m+1);int t; int cnt1=0;ll h1=1e18,h2=1e18; for(int i=1;i<=m;i++){ if(pd(que[i].u,que[i].v)==0)Link(que[i].u,que[i].v,que[i].vul),p[++cnt1].vul=-1*inf,p[cnt1].sum1=1LL*que[i].vul*1LL*que[i].vul,p[cnt1].sum2=que[i].vul; else{ destory(que[i].u,que[i].v);t=id; vis[t-n]=i;p[++cnt1].vul=key[t]+que[i].vul;p[cnt1].sum1=1ll*que[i].vul*que[i].vul; p[cnt1].sum2=que[i].vul; Link(que[i].u,que[i].v,que[i].vul); } } for(int i=1;i<=m;i++){ if(vis[i]==-1)p[++cnt1].vul=inf,p[cnt1].sum1=-1*1LL*que[i].vul*que[i].vul-1,p[cnt1].sum2=-1*que[i].vul; else p[++cnt1].vul=que[vis[i]].vul+que[i].vul,p[cnt1].sum1=-1*1LL*que[i].vul*que[i].vul-1,p[cnt1].sum2=-1*que[i].vul; } sort(p+1,p+cnt1+1); //for(int i=1;i<=cnt1;i++)cout<<p[i].vul<<" "<<p[i].sum1<<" "<<p[i].sum2<<endl; int cnt2=0; for(int i=1;i<=cnt1;i++){ if(p[i].sum1>=0)cnt2++; else cnt2--,p[i].sum1++; ans1+=p[i].sum1;ans2+=p[i].sum2; // cout<<ans1<<" "<<ans2<<endl; if(cnt2==n-1){ ll h = ans1-(ans2/(n - 1))*ans2-(ans2%(n-1))*ans2/(n - 1); ll l=-((ans2%(n-1))*ans2%(n - 1)); if(l < 0) l += n - 1 , h--; // cout<<h<<"====="<<l<<endl; if(make_pair(h1,h2)>make_pair(h,l)){ h1=h;h2=l; } } } printf("%lld\n",((h1+h2*ans3)%mod*ans3%mod)); } return 0; }