【题解】Luogu P5291 [十二省联考2019]希望

ytq鸽鸽出的题真是毒瘤

原题传送门

题目大意:

有一棵有\(n\)个点的树,求有多少方案选\(k\)个联通块使得存在一个中心点\(p\),所有\(k\)个联通块中所有点到\(p\)的距离都\(\leq L\)

我们先从部分分考虑:

算法1

\(2^{nk}\)枚举联通快判断是否可行

期望得分:8pts

代码就不给了

引理1

对于任意一个联通块,可能成为它的中心点的点组成的集合也是一个联通块。正确性显然

  • 推论

这里作为\(k\)个联通块中心点的点组成的集合是一个联通块

算法2

枚举推论中中心点组成的集合,判断其他是否满足

期望得分:16pts

代码也不给了

引理2

注意到树上联通块中,点集为\(S\),边集为\(E\),有\(|S|-|E|=1\)。正确性显然

这样我们就珂以对每个点、边单独考虑,通过容斥算出答案。设\(f(x)\)表示\(x\)控制的联通块个数,\(g(e)\)表示\(e\)两端点都控制的联通块个数,易知答案为

\[\sum_{v \in V}f(v)^k-\sum_{e \in E}g(e)^k \]

算法3

我们考虑dp,设\(f_{i,j}\)表示\(i\)点子树中距离\(i\)不超过\(j\)且包含\(i\)的个数,\(g_{i,j}\)表示\(i\)点子树外距离\(i\)不超过\(j\)且包含\(i\)的个数,转移如下:

\[f_{i,j}=\prod_{v \in son(i)}(f_{v,j-1}+1) \]

\[g_{i,j}=g_{fa_i,j-1}\prod_{v \in son(fa_i),v \neq i}(f_{v,j-2}+1)+1 \]

对于每个点\(i\),对答案的贡献为:

\[(f_{i,L}·g_{i,L})^k \]

对于每条边\(e\)(设\(v\)为深度更深的点),对答案的贡献为:

\[-(f_{v,L-1}·(g_{v,L}-1))^k \]

时间复杂度:\(O(nL)\)

期望得分:36pts

算法4:

对于\(L=n\)的测试点,第二维的限制没用了,去掉即可

时间复杂度:\(O(n)\)

期望得分:结合算法3可得48pts

算法5:

对于链的情况,珂以手推一下dp的贡献,发现就是距离,算一下即可

时间复杂度:\(O(n)\)

期望得分:结合算法4可得52pts

这是暴力52pts的做法,代码如下,subtask1是算法3,subtask2是算法4,subtask3是算法5

#include <bits/stdc++.h>
#define mod 998244353
#define N 100005
#define getchar nc
using namespace std;
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
    register int x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x*f;
}
inline void write(register int x)
{
    if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    static int sta[20];register int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
inline int power(register int a,register int b)
{
    int res=1;
    while(b)
    {
        if(b&1)
            res=1ll*res*a%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return res;
}
struct egde{
    int to,next;
}e[N<<2];
int head[N<<1],cnt=0;
inline void add(register int u,register int v)
{
    e[++cnt]=(egde){v,head[u]};
    head[u]=cnt;
}
int n,m,k,ans;
namespace subtask1{
    vector<int> f[N],g[N];
    inline void dfs1(register int x,register int fa)
    {
        for(register int i=head[x];i;i=e[i].next)
        {
            int v=e[i].to;
            if(v==fa)
                continue;
            dfs1(v,x);
            for(register int j=1;j<=m;++j)
                f[x][j]=1ll*f[x][j]*(f[v][j-1]+1)%mod;
        }
    }
    inline void dfs2(register int x,register int fa)
    {
        static int son[N],pre[N],suf[N];
        int cnt=0;
        for(register int i=head[x];i;i=e[i].next)
            if(e[i].to!=fa)
                son[++cnt]=e[i].to;
        for(register int i=pre[0]=suf[cnt+1]=1;i<=m;++i)
        {
            if(i>=2)
            {
                for(register int j=1;j<=cnt;++j)
                    pre[j]=1ll*pre[j-1]*(f[son[j]][i-2]+1)%mod;
                for(register int j=cnt;j>=1;--j)
                    suf[j]=1ll*suf[j+1]*(f[son[j]][i-2]+1)%mod;
            }
            else
            {
                for(register int j=1;j<=cnt;++j)
                    pre[j]=suf[j]=1;
            }
            for(register int j=1;j<=cnt;++j)
                g[son[j]][i]=(1ll*g[x][i-1]*pre[j-1]%mod*suf[j+1]+1)%mod;
        }
        for(register int i=head[x];i;i=e[i].next)
        {
            int v=e[i].to;
            if(v==fa)
                continue;
            ans=(ans+power(1ll*(g[v][m]-1)*f[v][m-1]%mod,k))%mod;
            dfs2(v,x);
        }
    }
    inline void solve()
    {
        for(register int i=1;i<=n;++i)
            f[i].resize(m+1,1),g[i].resize(m+1,1);
        dfs1(1,0);
        dfs2(1,0);
        ans=mod-ans;
        for(register int i=1;i<=n;++i)
            ans=(0ll+ans+power(1ll*f[i][m]*g[i][m]%mod,k))%mod;
        write(ans);
    }
}
namespace subtask2{
    int a[200005],b[200005];
    inline void dfs1(register int x,register int fa)
    {
        for(register int i=head[x];i;i=e[i].next)
        {
            int v=e[i].to;
            if(v==fa)
                continue;
            dfs1(v,x);
            a[x]=1ll*a[x]*(a[v]+1)%mod;
        }
    }
    inline void dfs2(register int x,register int fa)
    {
        static int son[200005],pre[200005],suf[200005];
        int cnt=0;
        for(register int i=head[x];i;i=e[i].next)
            if(e[i].to!=fa)
                son[++cnt]=e[i].to;
        pre[0]=suf[cnt+1]=1;
        for(register int j=1;j<=cnt;++j)
            pre[j]=1ll*pre[j-1]*(a[son[j]]+1)%mod;
        for(register int j=cnt;j>=1;--j)
            suf[j]=1ll*suf[j+1]*(a[son[j]]+1)%mod;
        for(register int j=1;j<=cnt;++j)
            b[son[j]]=(1ll*b[x]*pre[j-1]%mod*suf[j+1]+1)%mod;
        for(register int i=head[x];i;i=e[i].next)
        {
            int v=e[i].to;
            if(v==fa)
                continue;
            ans=(ans+power(1ll*(b[v]-1)*a[v]%mod,k))%mod;
            dfs2(v,x);
        }
    }
    inline void solve()
    {
        for(register int i=0;i<=n;++i)
            a[i]=b[i]=1;
        dfs1(1,0);
        dfs2(1,0);
        ans=mod-ans;
        for(register int i=1;i<=n;++i)
            ans=(ans+power(1ll*a[i]*b[i]%mod,k))%mod;
        write(ans);
    }
}
namespace subtask3{
    inline void solve()
    {
        for(register int i=1;i<=n;++i)
        {
            int l=max(1,i-m),r=min(n,i+m);
            ans=(ans+power(1ll*(i-l+1)*(r-i+1)%mod,k))%mod;
        }
        for(register int i=2;i<=n;++i)
        {
            int l=max(1,i-m),r=min(n,i+m-1);
            ans=(ans+mod-power(1ll*(i-l+1-1)*(r-i+1)%mod,k))%mod;
        }
        write(ans);
    }
}
int main()
{
    n=read(),m=read(),k=read();
    if(n<=200000)
    {
        for(register int i=1;i<n;++i)
        {
            int u=read(),v=read();
            add(u,v),add(v,u);
        }
    }
    if(m==n&&n<=200000)
        subtask2::solve();
    else if(n<=100000)
        subtask1::solve();
    else
        subtask3::solve();
	return 0;
}

算法6

发现\(f,g\)的转移珂以用长链剖分优化,套上一个可持久化数据结构,需要资瓷区间加、区间乘

时间复杂度:\(O(n\log n)\)

期望得分:结合算法5可得84pts

算法7

发现是全局加,后缀乘,这样可以像SDOI2019Day1T1 快速查询

珂以用数组线性维护,为了去掉求逆元的\(log\),我们珂以用类似于求阶乘逆元的方法\(O(n)\)预处理要用的逆元

至此,我们得到了一个\(O(n)\)的做法

期望得分:100pts

#include <bits/stdc++.h>
#define mod 998244353
#define N 1000005
#define getchar nc
using namespace std;
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
    register int x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x*f;
}
inline void write(register int x)
{
    if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    static int sta[20];register int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
inline int power(register int a,register int b)
{
    int res=1;
    while(b)
    {
        if(b&1)
            res=1ll*res*a%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return res;
}
struct edge{
    int to,next;
}e[N<<1];
int head[N],cnt=0;
inline void add(register int u,register int v)
{
    e[++cnt]=(edge){v,head[u]};
    head[u]=cnt;
}
int n,m,k;
int dp[N],son[N],len[N],val[N],idx[N],tot,inv[N];
int arr[N<<3],*pos=arr,*f[N<<1],*g[N],ans[2][N],Ans;
inline void dfs1(register int x,register int fa)
{
    dp[x]=1;
    for(register int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)
            continue;
        dfs1(v,x);
        son[x]=len[v]>len[son[x]]?v:son[x];
        dp[x]=1ll*dp[x]*dp[v]%mod;
    }
    if(++dp[x]<mod)
        val[idx[x]=++tot]=dp[x];
    len[x]=len[son[x]]+1;
}
inline void init_inv()
{
    static int pre[N];
    for(register int i=pre[0]=1;i<=tot;++i)
        pre[i]=1ll*pre[i-1]*val[i]%mod;
    int pre_inv=power(pre[tot],mod-2);
    for(register int i=tot;i;--i)
        inv[i]=1ll*pre[i-1]*pre_inv%mod,pre_inv=1ll*pre_inv*val[i]%mod;
}
struct node{
    int a,b,inv,pos,num;
};
namespace F{
    node tag[N<<1];
    vector<pair<node,vector<pair<int,int> > > >backup[N];
    inline int get(register int u,register int i)
    {
        int res=i<tag[u].pos?f[u][i]:tag[u].num;
        return (1ll*res*tag[u].a+tag[u].b)%mod;
    }
    inline void put(register int u,register int i,register int v)
    {
        f[u][i]=1ll*(v+mod-tag[u].b)*tag[u].inv%mod;
    }
    inline void merge(register int u,register int v,register int l)
    {
        node tmp=tag[u];
        vector<pair<int,int> >vec;
        for(register int i=1;i<=l;++i)
        {
            vec.push_back(make_pair(i,f[u][i]));
            int val=get(v,i-1);
            if(i==tag[u].pos)
                f[u][tag[u].pos++]=tag[u].num;
            put(u,i,1ll*get(u,i)*val%mod);
        }
        if(l<m)
        {
            int val=get(v,l);
            if(!val)
                tag[u].pos=l+1,tag[u].num=mod-1ll*tag[u].b*tag[u].inv%mod;
            else
            {
                int t=inv[idx[v]];
                vec.push_back(make_pair(0,f[u][0]));
                for(register int i=0;i<=l;++i)
                    put(u,i,1ll*get(u,i)*t%mod);
                tag[u].a=1ll*tag[u].a*val%mod;
                tag[u].b=1ll*tag[u].b*val%mod;
                tag[u].inv=1ll*tag[u].inv*t%mod;
            }
        }
        if(u<=n)
            backup[u].push_back(make_pair(tmp,vec));
    }
    inline void dfs2(register int x,register int fa)
    {
        if(son[x])
            f[son[x]]=f[x]+1,dfs2(son[x],x),tag[x]=tag[son[x]];
        else
            tag[x]=(node){1,1,1,n,0};
        put(x,0,1);
        for(register int i=head[x];i;i=e[i].next)
        {
            int v=e[i].to;
            if(v==fa||v==son[x])
                continue;
            f[v]=pos;
            pos+=len[v];
            dfs2(v,x);
            merge(x,v,min(len[v]-1,m));
        }
        ans[0][x]=get(x,min(len[x]-1,m-1));
        ans[1][x]=get(x,min(len[x]-1,m));
        ++tag[x].b;
    }
    inline void rollback(register int u)
    {
        tag[u]=backup[u].back().first;
        for(register int i=0;i<(backup[u].back().second).size();++i)
            f[u][(backup[u].back().second)[i].first]=(backup[u].back().second)[i].second;
        backup[u].pop_back();
    }
}
namespace G{
    node tag[N];
    inline int get(register int u,register int i)
    {
        int res=i<tag[u].pos?g[u][i]:tag[u].num;
        return (1ll*res*tag[u].a+tag[u].b)%mod;
    }
    inline void put(register int u,register int i,register int v)
    {
        g[u][i]=1ll*(v+mod-tag[u].b)*tag[u].inv%mod;
    }
    inline void dfs3(register int x,register int fa)
    {
        if(len[x]-m-1>=0)
            put(x,len[x]-m-1,1);
        Ans=(Ans+power(1ll*ans[1][x]*get(x,len[x]-1)%mod,k))%mod;
        if(fa)
            Ans=(Ans-power(1ll*ans[0][x]*(get(x,len[x]-1)+mod-1)%mod,k)+mod)%mod;
        if(!son[x])
            return;
        vector<int> ch;
        int maxlen=0;
        for(register int i=head[x];i;i=e[i].next)
        {
            int v=e[i].to;
            if(v==fa||v==son[x])
                continue;
            ch.push_back(v);
            maxlen=max(maxlen,len[v]);
        }
        maxlen=min(maxlen,m);
        reverse(ch.begin(),ch.end());
        f[x+n]=pos;
        pos+=maxlen+1;
        F::tag[x+n]=(node){1,1,1,n,0};
        F::put(x+n,0,1);
        for(register int id=0;id<ch.size();++id)
        {
            int v=ch[id];
            F::rollback(x);
            g[v]=pos;
            pos+=len[v];
            for(register int i=max(len[v]-m-1,0);i<len[v];++i)
            {
                if(m-len[v]+i==-1)
                    g[v][i]=get(x,len[son[x]]-len[v]+i);
                else
                    g[v][i]=1ll*get(x,len[son[x]]-len[v]+i)*F::get(x,min(len[x]-1,m-len[v]+i))%mod*F::get(x+n,min(maxlen,m-len[v]+i))%mod;
            }
            tag[v]=(node){1,1,1,n,0};
            F::merge(x+n,v,min(len[v]-1,m));
            dfs3(v,x);
        }
        int v=son[x];
        g[v]=g[x];
        tag[v]=tag[x];
        for(register int i=max(len[v]-m,0);i<=len[v]+maxlen-m-1;++i)
        {
            if(tag[v].pos==i)
                g[v][tag[v].pos++]=tag[v].num;
            put(v,i,1ll*get(v,i)*F::get(x+n,m-len[v]+i)%mod);
        }
        if(maxlen<m)
        {
            int vv=1,tt=1;
            for(register int i=0;i<ch.size();++i)
                vv=1ll*vv*val[idx[ch[i]]]%mod,tt=1ll*tt*inv[idx[ch[i]]]%mod;
            if(!vv)
                tag[v].pos=len[v]+maxlen-m,tag[v].num=mod-1ll*tag[v].b*tag[v].inv%mod;
            else
            {
                for(register int i=max(len[v]-m-1,0);i<=len[v]+maxlen-m-1;++i)
                    put(v,i,1ll*get(v,i)*tt%mod);
                tag[v].a=1ll*tag[v].a*vv%mod;
                tag[v].b=1ll*tag[v].b*vv%mod;
                tag[v].inv=1ll*tag[v].inv*tt%mod;
            }
        }
        ++tag[v].b;
        dfs3(v,x);
    }
}
int main()
{
    n=read(),m=read(),k=read();
    for(register int i=1;i<n;++i)
    {
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs1(1,0);
    init_inv();
    f[1]=pos;
    pos+=len[1];
    F::dfs2(1,0);
    g[1]=pos;
    pos+=len[1];
    G::tag[1]=(node){1,1,1,n,0};
    G::dfs3(1,0);
    write(Ans);
	return 0;
}
posted @ 2019-08-04 21:55  JSOI爆零珂学家yzhang  阅读(577)  评论(3编辑  收藏  举报