Luogu5291 [十二省联考2019]希望

Luogu5291 [十二省联考2019]希望

[十二省联考2019]绝望

长链剖分优化\(DP\)

\(O(nL)DP\):我们考虑单独计算每一个节点的贡献,但是这样做显然有问题,因为对于一个连通块集合来说,不仅仅有一个点是满足要求的。

那么我们可以考虑容斥计算答案,这就需要挖掘题目中的性质。显然满足题意的点集一定形成一个连通块,因为我们很容易发现对于两个满足条件的点\(u,v\)来说,\(u \rightarrow v\)的路径上的点一定符合题意。

树上的连通块依然是一棵树,那么必然满足边数\(=\)点数\(-1\),所以我们通过满足题意的点的贡献减去满足题意的边的贡献即可得到最终的答案。

单独统计点和边的贡献较容易计算。

由于与一个点有关的边要么来自它的子树,要么来自父亲,我们就可以记录这两方面的答案进行计算。

\(f_{u,i}\)表示\(u\)子树内,所有点与\(u\)距离不超过\(i\)的连通块个数(当然计算的连通块必然包括\(u\)节点)\(+1\)\(+1\)是为了便于计算)。

\[f_{u,i}=\prod_{v \in son_u} f_{v,i-1}+1 \]

\(g_{u,i}\)表示所有点与\(u\)距离不超过\(i\)的,包含\(u\)且不包含任何\(u\)子树内节点,向\(fa_u\)方向延伸出来的连通块个数。

\[g_{u,i}=(g_{fa_u,i-1} \prod_{v \in son_{fa_u},v \ne u} f_{v,i-2})+1 \]

考虑点的贡献,应该是:

\[\sum_{u}( (f_{u,L}-1)g_{u,L})^k \]

边的贡献:

\[\sum_{u} ((f_{u,L-1}-1)(g_{u,L}-1))^k \]

\[ans=\sum_{u}( (f_{u,L}-1)g_{u,L})^k-((f_{u,L-1}-1)(g_{u,L}-1))^k \]

\(36pts Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#define N 1000005
#define ll long long
using namespace std;
const int p=998244353;
int n,L,k,x,y;
int ans=0;
struct edge
{
    int nxt,v;
    edge (int Nxt=0,int V=0)
    {
        nxt=Nxt,v=V;
    }
}e[N << 1];
int tot,fr[N];
int r[N],z[N];
vector<int>f[N],g[N],s1[N],s2[N];
int ksm(int x,int y)
{
    int ans=1;
    while (y)
    {
        if (y & 1)
            ans=(ll)ans*x%p;
        x=(ll)x*x%p;
        y >>=1;
    }
    return ans;
}
int add(int x,int y)
{
    return (x+y)%p;
}
int del(int x,int y)
{
    return (x-y)%p;
}
int mul(int x,int y)
{
    return (ll)x*y%p;
}
void Add(int &x,int y)
{
    x=(x+y)%p;
}
void Del(int &x,int y)
{
    x=(x-y)%p;
}
void Mul(int &x,int y)
{
    x=(ll)x*y%p;
}
void link(int x,int y)
{
    ++tot;
    e[tot]=edge(fr[x],y),fr[x]=tot;
}
void dfs1(int u,int F)
{
    for (int i=0;i<=L;++i)
        f[u][i]=1;
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F)
            continue;
        dfs1(v,u);
        for (int j=1;j<=L;++j)
            Mul(f[u][j],f[v][j-1]);
    }
    for (int i=0;i<=L;++i)
        Add(f[u][i],1);
}
void dfs2(int u,int F)
{
    for (int i=0;i<=L;++i)
        Add(g[u][i],1);
    int cnt=0;
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F)
            continue;
        r[v]=++cnt;
        z[cnt]=v;
    }
    for (int i=0;i<=L;++i)
        s1[0][i]=s2[cnt+1][i]=1;
    for (int i=1;i<=cnt;++i)
        for (int j=2;j<=L;++j)
            s1[i][j]=mul(s1[i-1][j],f[z[i]][j-2]);
    for (int i=cnt;i;--i)
        for (int j=2;j<=L;++j)
            s2[i][j]=mul(s2[i+1][j],f[z[i]][j-2]);
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F)
            continue;
        g[v][1]=1;
        for (int j=2;j<=L;++j)
            g[v][j]=mul(g[u][j-1],mul(s1[r[v]-1][j],s2[r[v]+1][j]));
    }
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F)
            continue;
        dfs2(v,u);
    }
}
int main()
{
    scanf("%d%d%d",&n,&L,&k);
    for (int i=0;i<=n+2;++i)
        for (int j=0;j<=L+2;++j)
        {
            f[i].push_back(0);
            g[i].push_back(0);
            s1[i].push_back(0);
            s2[i].push_back(0);
        }
    for (int i=1;i<n;++i)
    {
        scanf("%d%d",&x,&y);
        link(x,y),link(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    for (int i=1;i<=n;++i)
    {
        Add(ans,ksm(mul(f[i][L]-1,g[i][L]),k));
        Del(ans,ksm(mul(f[i][L-1]-1,g[i][L]-1),k));
    }
    ans=(ans%p+p)%p;
    printf("%d\n",ans);
    return 0;
}

这道题的数据范围是\(n \le 10^6\),而且\(dp\)方程明显与深度有关,因此我们可以考虑长链剖分优化。

\(len_u\)表示\(u\)所在长链底部与\(u\)的之间路径中的点数。

比较容易的是\(f\)当然也不是那么容易)。

\[f_{u,i}=\prod_{v \in son_u} f_{v,i-1}+1 \]

\(problem1:f_{u,i}(i \ge len_u)\)同样有意义,而不是\(0\)

\(problem2:\)本题的式子中含有整体\(+1\)操作。

\(problem3:\)对于\(f_{v,i}(v \in son_u)\),它会对\(f_{u,t}(t \ge i)\)产生贡献,如果暴力转移,时间复杂度就会不正确。

对于\(problem1\),由于\(f_{u,i}(i \ge len_u)=f_{u,len_u}\)我们可以打一个标记\(pos\),记录到达哪一位的值是相同的,同时记录那一位的值。

对于\(problem2\),我们可以打上加法标记和乘法标记,实现\(O(1)\)修改。

对于\(problem3\),由于采用了长链剖分,显然轻子树深度不会超过\(f\)数组范围,那么对于\(i < len_v\)的情况暴力转移是符合线性复杂度的。因此,我们只需要考虑\(f_{v,len_v-1}\)的贡献(因为\(f_{v,t}(t \ge len_v)=f_{v,len_v-1}\)中),然后尝试对于后面的答案一次性合并。分两种情况,如果\(f_{v,len_v-1} \equiv 0 \pmod{998244353}\),那么\(f_{u,len_v}\)以后的值都会乘上它,那么它们的值将永远是\(0\),利用上面\(pos\)标记,我们可以记录到\(len_v\)位置之后,后面的值均为\(0\)。如果\(f_{v,len_v-1} \equiv a \pmod{998244353}(a \ne 0)\),那么我们可以对于整体乘上\(f_{v,len_v-1}\),再对\(f_{u,i}(0 \le i \le len_v)\)暴力乘上\(a\)的逆元即可。

然后\(f\)就基本解决了。

考虑\(g\)如何转移。

\[g_{u,i}=(g_{fa_u,i-1} \prod_{v \in son_{fa_u},v \ne u} f_{v,i-2})+1 \]

\(f\)不同,\(g\)是自顶向下进行转移的,我们采取的措施是让重儿子继承父亲\(g\)数组,轻儿子利用\(g_u\)暴力更新。

看起来不好解决,我们先笼统地解决大的问题,比如轻儿子的暴力转移。

我们必须满足时间复杂度保持线性。

可以得到结论,对于一个节点来说\(g_{u,i}(0 \le i \le L-len_u)\)并没有用处,因为子树内的点至少也可以跳到\(g_{u,L-len_u+1}\),那么\([0,L-len_u]\)一部分信息我们不需要继承。

那么每个节点继承的信息为\([L-len_u+1,L]\),长度与长链长度相同。

对于一条长链来说,它的信息只会在长链顶部被暴力更新一次,时间复杂度\(O(n)\)

对于\(g\)数组,我们用\(g_{u,0}\)来代替原来的\(g'_{u,L-len_u+1}\),需要注意下标的转化。

分别考虑重儿子\(w\)和轻儿子\(v\)如何继承。

重儿子:由于\(g_{u,i}\)的实际值为\(g'_{u,L-len_u+i+1}\),根据\(dp\)方程,\(w\)应该用\(L-len_u+i+2\)这一位的值来继承,也就是\(g'_{w,L-len_u+i+2}\),由于\(len_w=len_u-1\)\(g'_{w,L-len_u+i+2}=g'_{w,L-len_w+i+1}=g_{w,i}\)\(g_{u,i},g_{w,i}\)下标相同,这意味着我们可以直接把\(u\)的指针转移给\(w\)即可。

轻儿子:我们需要将\(L-len_v+1 \le i \le L\)\(g'_{u,i}\)值更新\(v\),我们可以先枚举\(0 \le i < len_v\),然后转化成\(g_u\)数组中的下标。

现在还有两个问题:

\(1:\)来自\(f\)数组的转移。

\(2:\)轻儿子向重儿子的合并。

考虑第一个问题,\(\prod_{v \in son_{fa_u},v \ne u} f_{v,i-2}\),不能直接用\(\frac{f_{fa_u,i-1}}{f_{v,i-2}}\),因为\(f_{v,i-2}\)经过\(+1\)操作,它的值在模\(998244353\)意义下可能为\(0\)

我们只好采用前缀积乘后缀积的方式解决,我们的方案是倒序进行\(f\)\(dp\),同时利用之前的值。

我们需要支持可撤销的数据结构,类似的方式,在第一次\(f\)\(dp\)时,把一次更新前原来的值压入栈中,然后进行还原,即可得到前缀积。

那么第一个问题解决了。

继续,第二个问题。

还是它\(\prod_{v \in son_{fa_u},v \ne u} f_{v,i-2}\),对于一个节点\(v\)\(f_{v,i-2}\)会对\(g_{u,i}\)产生贡献,同样,对于\(v\)来说,只有\(f_{v,i}(0 \le i < len_v-1)\)是需要存储的,\(f_{v,i}(i \ge len_v)=f_{v,len_v-1}\)

又来了,全局乘,前缀乘逆元,最后还有个\(+1\),加法标记继续上,和\(f\)完全一致。

所以现在我们的复杂度是\(O(n \log k)\)了?

不对,还有逆元需要一只\(\log\),这显然让人很不舒服。

然而我们发现,对于一个节点\(u\),有关\(u\)需要求解的逆元只有\(f_{u,len_u-1}\),因为我们只有在全局乘操作时才需要逆元,之所以全局乘,是因为\(len_u-1\)之后的值都等于\(f_{u,len_u-1}\)

对于\(f_{u,len_u-1}\),可以看做去掉了长度限制,我们不需要开第二维数组即可计算,可以用暴力\(dp\)实现\(O(n)\)统计。

然后是线性求\(m\)个数\(a_{1 \cdots m}\)的逆元的问题(不要忘了把\(0\)去掉)。

\[s=\prod_{i=1}^m a_i\\pre_k=\prod_{i=1}^k a_i\\succ_k=\prod_{i=k}^m a_i\\ {a_i}^{-1}=s^{-1} pre_{i-1} succ_{i+1} \]

预处理\(s^{-1},pre,succ\)即可线性求逆元。

总时间复杂度\(O(n \log k)\)这个\(\log\)可以忽略不计了吧,恐怕还没有我们\(dp\)转移时的常数大呢)。

奉上\(7k\)代码(不删调试信息\(11k\)

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<list>
#define N 1000005
#define ll long long
#define mp make_pair
using namespace std;
const int p=998244353;
const int INF=1000000007;
int n,L,k,x,y;
struct edge
{
    int nxt,v;
    edge (int Nxt=0,int V=0)
    {
        nxt=Nxt,v=V;
    }
}e[N << 1];
int tot,fr[N];
int dp[N],len[N],son[N];
int a0,an[N],iz[N],pre[N],inv[N];
void link(int x,int y)
{
    ++tot;
    e[tot]=edge(fr[x],y),fr[x]=tot;
}
void Add(int &x,int y)
{
    x=(x+y)%p;
}
void Del(int &x,int y)
{
    x=(x-y)%p;
}
void Mul(int &x,int y)
{
    x=(ll)x*y%p;
}
int add(int x,int y)
{
    return (x+y)%p;
}
int del(int x,int y)
{
    return (x-y)%p;
}
int mul(int x,int y)
{
    return (ll)x*y%p;
}
int ksm(int x,int y)
{
    int ans=1;
    while (y)
    {
        if (y & 1)
            Mul(ans,x);
        Mul(x,x);
        y >>=1;
    }
    return ans;
}
#define iv(x) ksm(x,p-2)
void dfs(int u,int F)
{
    len[u]=1,dp[u]=1;
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F)
            continue;
        dfs(v,u);
        Mul(dp[u],dp[v]);
        son[u]=(len[son[u]]<len[v])?v:son[u];
    }
    Add(dp[u],1);
    if (dp[u])
        an[++a0]=dp[u],iz[a0]=u;
    len[u]=len[son[u]]+1;
}
void getinv()
{
    pre[0]=1;
    for (int i=1;i<=a0;++i)
        pre[i]=mul(pre[i-1],an[i]);
    int sf=iv(pre[a0]);
    int succ=1;
    for (int i=a0;i;--i)
    {
        inv[iz[i]]=mul(mul(pre[i-1],succ),sf);
        Mul(succ,an[i]);
    }
}
struct node
{
    int ad=0,mu=1,inv=1,pos,num;
    node (int A=0,int B=1,int C=1,int D=INF,int E=0)
    {
        ad=A,mu=B,inv=C,pos=D,num=E;
    }
    void inc(int x)
    {
        Add(ad,x);
    }
    void dec(int x)
    {
        Del(ad,x);
    }
    void times(int x)
    {
        Mul(ad,x),Mul(mu,x);
    }
};
int poolF[N << 1],*f[N << 1],*xf=poolF;
int poolG[N],*g[N],*xg=poolG;
int ans=0;
namespace F
{
    list< pair <node , list < pair<int,int> > > >rf[N];
    node tag[N << 1];
    void put(int u,int i,int val)
    {
        f[u][i]=mul(del(val,tag[u].ad),tag[u].inv);
    }
    int get(int u,int i)
    {
        return add(mul((i>=tag[u].pos)?tag[u].num:f[u][i],tag[u].mu),tag[u].ad);
    }
    void combine(int u,int v,int l)
    {
        list< pair< int , int > >V;
        node rtag=tag[u];
        for (int i=1;i<=l;++i)
        {
            V.push_back(mp(i,f[u][i]));
            if (i==tag[u].pos)
                f[u][tag[u].pos++]=tag[u].num;
            put(u,i,mul(get(u,i),get(v,i-1)));
        }
        if (l<L)
        {
            int val=dp[v];
            if (!val)
                tag[u].pos=l+1,tag[u].num=mul(del(0,tag[u].ad),tag[u].inv); else
                {
                    int t=inv[v];
                    tag[u].times(val);
                    Mul(tag[u].inv,t);
                    V.push_back(mp(0,f[u][0]));
                    for (int i=0;i<=l;++i)
                        put(u,i,mul(get(u,i),t));
                }
        }
        if (u<=n)
            rf[u].push_back(mp(rtag,V));
    }
    void dfs(int u,int F)
    {
        if (son[u])
            f[son[u]]=f[u]+1,dfs(son[u],u),tag[u]=tag[son[u]]; else
            ++tag[u].ad,tag[u].pos=INF;
        put(u,0,1);
        for (int i=fr[u];i;i=e[i].nxt)
        {
            int v=e[i].v;
            if (v==F || v==son[u])
                continue;
            f[v]=xf,xf+=len[v];
            dfs(v,u);
            combine(u,v,min(L,len[v]-1));
        }
        ++tag[u].ad;
    }
    void back(int u)
    {
        tag[u]=rf[u].back().first;
        for (list < pair< int , int> > :: iterator it=rf[u].back().second.begin();it!=rf[u].back().second.end();++it)
            f[u][it->first]=it->second;
        rf[u].pop_back();
    }
};
namespace G
{
    node tag[N];
    void put(int u,int i,int val)
    {
        if (i<0)
            return;
        g[u][i]=mul(del(val,tag[u].ad),tag[u].inv);
    }
    int get(int u,int i)
    {
        return add(mul((i>=tag[u].pos)?tag[u].num:g[u][i],tag[u].mu),tag[u].ad);
    }
    int iget(int u,int i)
    {
        return i-(L-len[u]+1);
    }
    int rget(int u,int i)
    {
        return i+(L-len[u]+1);
    }
    void dfs(int u,int F)
    {
        Add(ans,ksm(mul(F::get(u,min(len[u]-1,L))-1,get(u,iget(u,L))),k));
        if (u!=1)
            Del(ans,ksm(mul(F::get(u,min(len[u]-1,L-1))-1,get(u,iget(u,L))-1),k));
        if (!son[u])
            return;
        int mxlen=-1;
        list<int>sn;
        for (int i=fr[u];i;i=e[i].nxt)
        {
            int v=e[i].v;
            if (v==F || v==son[u])
                continue;
            sn.push_back(v);
            mxlen=max(mxlen,len[v]);
        }
        mxlen=min(mxlen,L);
        sn.reverse();
        f[u+n]=xf,xf+=mxlen+1;
        F::tag[u+n]=node(1,1,1,INF,0);
        for (list<int> :: iterator it=sn.begin();it!=sn.end();++it)
        {
            int v=(*it);
            F::back(u);
            g[v]=xg;
            xg+=len[v];
            for (int i=0;i<len[v];++i)
            {
                int vl=rget(v,i);
                if (vl<1)
                    continue;
                if (vl==1)
                {
                    put(v,i,1);
                    continue;
                }
                int ul=iget(u,vl-1);
                put(v,i,mul(get(u,ul),mul(F::get(u,min(len[u]-1,vl-1)),F::get(u+n,min(mxlen,vl-1)))));
            }
            ++tag[v].ad;
            F::combine(u+n,v,min(L,len[v]-1));
        }
        int w=son[u];
        g[w]=g[u],tag[w]=tag[u];
        for (list<int> :: iterator it=sn.begin();it!=sn.end();++it)
        {
            int v=(*it);
            for (int i=0;i<len[v];++i)
            {
                int wl=iget(w,i+2);
                if (wl<0)
                    continue;
                if (wl>len[w]-1)
                    break;
                if (wl==tag[w].pos)
                    g[w][tag[w].pos++]=tag[w].num;
                put(w,wl,mul(get(w,wl),F::get(v,min(len[v]-1,i))));
            }
            int wl=iget(w,len[v]+1);
            if (wl<len[w]-1)
            {
                int val=dp[v];
                if (!val)
                    tag[w].pos=wl+1,tag[w].num=mul(del(0,tag[w].ad),tag[w].inv); else
                    {
                        int t=inv[v];
                        tag[w].times(val);
                        Mul(tag[w].inv,t);
                        for (int i=iget(w,max(L-len[w]+1,0));i<=wl;++i)
                            put(w,i,mul(get(w,i),t));
                    }
            }
        }
        ++tag[w].ad;
        put(w,iget(w,0),1);
        put(w,iget(w,1),2);
        for (list<int> :: iterator it=sn.begin();it!=sn.end();++it)
            dfs((*it),u);
        dfs(w,u);
    }
};
int main()
{
    scanf("%d%d%d",&n,&L,&k);
    if (!L)
    {
        printf("%d\n",n);
        return 0;
    }
    for (int i=1;i<n;++i)
    {
        scanf("%d%d",&x,&y);
        link(x,y),link(y,x);
    }
    dfs(1,0);
    getinv();
    f[1]=xf,xf+=len[1];
    F::dfs(1,0);
    g[1]=xg,xg+=len[1];
    G::tag[1]=node(1,1,1,INF,0);
    G::dfs(1,0);
    ans=(ans%p+p)%p;
    printf("%d\n",ans);
    return 0;
}
posted @ 2020-11-22 20:15  GK0328  阅读(42)  评论(0编辑  收藏  举报