20190803

信仰圣光

题意简述

求对于有 $n$ 个点的 $e$ 个简单环。有 $k$ 个守卫,每个环至少要有一个守卫的方案数。

$1\leq k\leq n\leq 152501$

$solution:$

考虑对于朴素 $O(n^2)\space dp$ 的优化,简单思考后发现 $dp$ 的过程其实是一个背包卷积的过程。

考虑对每个简单环构造生成函数 $A$ ,则 $A_i=C_{num}^i$ , $num$ 表示其环中节点个数。

$B=\prod_{i=1}^e A$ ,答案则为 $B_k$。

现在的问题变成了求 $e$ 个多项式的卷积,而暴力卷积时间复杂度为 $O(n^2\log n)$ ,考虑分治优化即可。

时间复杂度 $O(n\log^2 n)$ 。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<vector>
#define int long long
#define mod 998244353
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int MAXN=2402502;
int T,n,k,ff[MAXN],M[MAXN],Num[MAXN];
vector<int> ve[MAXN];
int fac[MAXN],inv[MAXN],infac[MAXN];
inline void init(){
    fac[0]=fac[1]=1;
    for(int i=2;i<=152501;i++) fac[i]=fac[i-1]*i,fac[i]%=mod;
    inv[1]=1;for(int i=2;i<=152501;i++) inv[i]=((mod-mod/i)*inv[mod%i])%mod;
    infac[0]=1;
    for(int i=1;i<=152501;i++) infac[i]=infac[i-1]*inv[i],infac[i]%=mod;
    return;
}
int find(int x){
    if(ff[x]==x) return x;
    return ff[x]=find(ff[x]);
}
int merge(int x,int y){
    int t1=find(x),t2=find(y);
    ff[t2]=t1;
}
int ksm(int a,int b){
    int ans=1;
    while(b){
        if(b&1) ans*=a,ans%=mod;
        a*=a,a%=mod;
        b>>=1;
    }return ans;
}
inline int C(int a,int b){return (((fac[a]*infac[b])%mod)*infac[a-b])%mod;}
int f[MAXN],g[MAXN],N,Lim,flip[MAXN];
inline void NTT(int *f,int opt){
    for(int i=0;i<N;i++) if(i<flip[i]) swap(f[i],f[flip[i]]);
    for(int p=2;p<=N;p<<=1){
        int len=p>>1,buf=ksm(3,(mod-1)/p);
        if(opt==-1) buf=ksm(buf,mod-2);
        for(int be=0;be<N;be+=p){
            int tmp=1;
            for(int l=be;l<be+len;l++){
                int t=(f[l+len]*tmp)%mod;
                f[l+len]=(f[l]-t+mod)%mod,f[l]=(f[l]+t)%mod;
                tmp*=buf,tmp%=mod;
            }
        }
    }if(opt==-1){
        int Inv=ksm(N,mod-2);
        for(int i=0;i<N;i++) f[i]*=Inv,f[i]%=mod;
    }return;
}
inline void _NTT(vector<int> &F,vector<int> G){
    int sizf=F.size()-1,sizg=G.size()-1;
    Lim=sizf+sizg;
    for(N=1;N<=Lim;N<<=1);
    for(int i=0;i<N;i++) flip[i]=((flip[i>>1]>>1)|(i&1?N>>1:0));
    for(int i=0;i<=sizf;i++) f[i]=F[i];
    for(int i=0;i<=sizg;i++) g[i]=G[i];
    for(int i=sizf+1;i<=N;i++) f[i]=0;
    for(int i=sizg+1;i<=N;i++) g[i]=0;
    NTT(f,1),NTT(g,1);
    for(int i=0;i<N;i++) f[i]*=g[i],f[i]%=mod;
    NTT(f,-1);
    F.clear();
    for(int i=0;i<=Lim;i++) F.push_back(f[i]);return;
}
inline void cdq(int l,int r){
    if(l==r) return;
    int mid=l+r>>1;
    cdq(l,mid),cdq(mid+1,r);
    _NTT(ve[l],ve[mid+1]);
    return;
}
void solve(){
    memset(M,0,sizeof(M)),memset(Num,0,sizeof(Num));
    n=read(),k=read();
    for(int i=1;i<=n;i++) ff[i]=i;
    for(int i=1;i<=n;i++) merge(i,read());
    for(int i=1;i<=n;i++) ff[i]=find(ff[i]);
    for(int i=1;i<=n;i++){
        if(!M[ff[i]]) M[ff[i]]=++M[0];
        Num[M[ff[i]]]++;
    }
    for(int i=1;i<=M[0];i++)
        for(int j=0;j<=Num[i];j++) {
            if(j) ve[i].push_back(C(Num[i],j));
            else ve[i].push_back(0);
        }
    cdq(1,M[0]);
    int Ans1=ve[1][k],Ans2=C(n,k);
    printf("%lld\n",(Ans1*ksm(Ans2,mod-2))%mod);
    for(int i=1;i<=M[0];i++) ve[i].clear();
    return;
}
signed main(){
    freopen("bishop.in","r",stdin);
    freopen("bishop.out","w",stdout);
    T=read();init();
    while(T--) solve();
    return 0;
}
View Code

 

灵大会议

题意简述

给定一棵有 $n$ 个节点的有根带权,第 $i$ 号点有 $val_i$ 表示 $i$ 号点有多少人。

$q$ 次询问,每次询问 $(u,v)$ 简单路径上的人走到哪个点的总距离最少,求其总距离或更改 $val$ 。

$n,q\leq 152501$

$solution:$

降智好题,忘记了有中位数这个东西,一直在想边对答案的贡献。

考虑若我们将 $(u,v)$ 这条链摘出来,则其选择的点为其中位数(也可以说是让两边 $val$ 值最少),问题就变成了求 $(u,v)$ 到 $x$ 的总距离。

我们设 $F_i=\sum_{lca(i,j)=j} dis_j\times val_j$ ,$dis_i$ 表示 $i$ 号点到跟的距离。

我们设 $x$ 与 $u$ 在同侧,将路径拆为 $(u,x),(x,lca),(lca,v)$ 。

$$Ans=W(u,x)+W(x,lca)+W(lca,v)\\=(F_u-F_{fath_x}-dis_{fath}\times \sum_{i\in (u,x)} C_i)+(F_v-F_{fath_{lca}}-dis_{lca}\times \sum_{i\in{lca,v}} C_i)+(dis_{lca}\times \sum _{i\in {x,lca}}C_i-(F_{x}-F_{fath_{lca}}))$$

直接用线段树或者树状数组加 $dfs$ 序(因为我们发现信息都只有一条与根相连的链),维护 $F$ 与 $C$ 的信息即可。

而求中位数直接比较后倍增或者二分维护即可。

而对于树链剖分时间复杂度 $O(n\log ^3 n)$ ,面对 $n,q\leq 152501$ 的数据会 $T $ 。

利用 $dfs$ 序优化即可,时间复杂度 $O(n\log ^2 n)$ 。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define int long long
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int MAXN=200001;
int n;
int in[MAXN],out[MAXN],fa[MAXN][21],dep[MAXN],cnt,head[MAXN],val[MAXN],tot,dis[MAXN];
struct node{
    int u,v,w,nex;
}x[MAXN<<1];
struct BIT{
    int sum[MAXN];
    int lowbit(int x){return x&-x;}
    void Modify(int x,int w){
        for(;x<=n;x+=lowbit(x)) sum[x]+=w;
        return;
    }
    inline int Query(int x){
        int ans=0;
        for(;x;x-=lowbit(x)) ans+=sum[x];
        return ans;
    }
    inline void Add(int u,int w){
        Modify(in[u],w),Modify(out[u]+1,-w);
        return;
    }
    inline int Que(int u){return Query(in[u]);}
}t1,t2;
inline void add(int u,int v,int w){
    x[cnt].u=u,x[cnt].v=v,x[cnt].w=w,x[cnt].nex=head[u],head[u]=cnt++;
}
inline void dfs(int u,int fath){
    fa[u][0]=fath;dep[u]=dep[fath]+1;
    in[u]=++tot;
    for(int i=1;(1<<i)<=dep[u];i++) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int i=head[u];i!=-1;i=x[i].nex){
        if(x[i].v==fath) continue;
        dis[x[i].v]=dis[u]+x[i].w;
        dfs(x[i].v,u);
    }out[u]=tot;return;
}
inline int Lca(int u,int v){
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=20;i>=0;i--)
        if(dep[u]-(1<<i)>=dep[v]) u=fa[u][i];
    if(u==v) return u;
    for(int i=20;i>=0;i--){
        if(fa[u][i]==fa[v][i]) continue;
        u=fa[u][i],v=fa[v][i];
    }return fa[u][0];
}
void Modify(int u,int w){
    t1.Add(u,w-val[u]);t2.Add(u,(w-val[u])*dis[u]);
    val[u]=w;return;
}
inline int qcnt(int u,int v){
    int lca=Lca(u,v);
    return t1.Que(u)+t1.Que(v)-2*t1.Que(lca)+val[lca];
}
inline int Q1(int u,int v){
    return t2.Que(u)-t2.Que(fa[v][0])-dis[v]*qcnt(u,v);
}
inline int Q2(int u,int v){
    return dis[u]*qcnt(u,v)-(t2.Que(u)-t2.Que(fa[v][0]));
}
inline int Query(int u,int v){
    int lca=Lca(u,v),Num=qcnt(u,v);
    int res=(Num+1)/2,tmp;
    if(val[u]>=res)    tmp=u;
    else if(val[v]>=res) tmp=v,swap(u,v);
    else{
        if(qcnt(v,lca)>=res) swap(u,v);
        tmp=u;
        for(int i=20;i>=0;i--)
            if(qcnt(fa[tmp][i],u)<res&&dep[tmp]-(1<<i)>=dep[lca]) tmp=fa[tmp][i];
        tmp=fa[tmp][0];
    }
    int Ans=0;
    Ans+=Q1(u,tmp);
    Ans+=Q1(v,lca);
    Ans+=Q2(tmp,lca);
    int G=qcnt(v,lca)-val[lca];
    Ans+=G*(dis[tmp]-dis[lca]);
    return Ans;
}
int q;
signed main(){
    freopen("conference.in","r",stdin);
    freopen("conference.out","w",stdout);
    memset(head,-1,sizeof(head));
    n=read();
    for(int i=1;i<=n;i++) val[i]=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read(),w=read();
        add(u,v,w),add(v,u,w);
    }
    dfs(1,0);
    for(int i=1;i<=n;i++){
        int t=val[i];val[i]=0;
        Modify(i,t);
    }
    q=read();
    for(int i=1;i<=q;i++){
        int opt=read();
        if(opt==1){
            int u=read(),v=read();
            printf("%lld\n",Query(u,v));
        }else{
            int u=read(),w=read();
            Modify(u,w);
        }
    }return 0;
}
View Code

 

posted @ 2019-08-03 23:16  siruiyang_sry  阅读(143)  评论(2编辑  收藏  举报