洛谷P5437/5442 约定(概率期望,拉格朗日插值,自然数幂)

题目大意:$n$ 个点的完全图,点 $i$ 和点 $j$ 的边权为 $(i+j)^k$。随机一个生成树,问这个生成树边权和的期望对 $998244353$ 取模的值。

对于P5437:$1\le n\le 998244352,1\le k\le 10^7$。

对于P5442:$1\le n\le 10^4,\le k\le 10^7$。


其实也是一道比较简单的题。(所以就应该把这题和上一道原题调个位置)

考虑一条边在生成树中出现的概率,由于一共有 $\dfrac{n(n-1)}{2}$ 条边,一个生成树有 $n-1$ 条边,而每条边的概率相等,所以为 $\dfrac{2}{n}$。

那么开始推式子:(注:第三步是枚举 $i+j$)

$$\dfrac{2}{n}\sum\limits_{i=1}^n\sum\limits_{j=i+1}^n(i+j)^k$$$$\dfrac{1}{n}(\sum\limits_{i=1}^n\sum\limits_{j=1}^n(i+j)^k-\sum\limits^n_{i=1}(i+i)^k)$$

$$\dfrac{1}{n}(\sum\limits_{s=1}^{2n}s^k\min(s-1,2n+1-s)-2^k\sum\limits^n_{i=1}i^k)$$

$$\dfrac{1}{n}(\sum\limits_{s=1}^{n}s^k(s-1)+\sum\limits_{s=n+1}^{2n}s^k(2n+1-s)-2^k\sum\limits^n_{i=1}i^k)$$

$$\dfrac{1}{n}(\sum\limits_{s=1}^{n}s^{k+1}-\sum\limits_{s=1}^{n}s^{k}+(2n+1)\sum\limits_{s=n+1}^{2n}s^k-\sum\limits_{s=n+1}^{2n}s^{k+1}-2^{k}\sum\limits^n_{i=1}i^k)$$

$$\dfrac{1}{n}(2\sum\limits_{i=1}^{n}i^{k+1}-(2n+2+2^k)\sum\limits_{i=1}^{n}i^{k}+(2n+1)\sum\limits_{i=1}^{2n}i^k-\sum\limits_{i=1}^{2n}i^{k+1})$$

现在问题就是求 $f(n)=\sum\limits_{i=1}^ni^k$ 了。

由于 $f(n)-f(n-1)=n^k$,$f$ 的差值是个 $k$ 次多项式,所以 $f$ 是个 $k+1$ 次多项式。

那么可以拉格朗日插值。(以下内容的代码实现细节比较多,注意要控制复杂度不带 $\log$)

取 $k+2$ 个点为 $1$ 到 $k+2$,发现点值 $y_i$ 可以 $O(k)$ 计算。($y_i=y_{i-1}+i^k$ 不能直接快速幂,不然带 $\log$。可以用欧拉筛筛出所有 $k$ 次方)

$$f(n)=\sum\limits_{i=1}^{k+2}y_i\dfrac{\prod\limits^{k+2}_{j=1,j\ne i}(n-x_j)}{\prod\limits^{k+2}_{j=1,j\ne i}(x_i-x_j)}$$

这样拉格朗日插值公式中的分母就是两个阶乘相乘的形式,可以 $O(1)$。(预处理要注意控制复杂度)

代入一个数算时,先特判 $n\ge mod$(因为会调用到 $f(2n)$),此时 $f(n)=\lfloor\dfrac{n}{mod}\rfloor f(mod-1)+f(n\%mod)$。

否则先算出 $fac=\prod\limits_{i=1}^{k+2}(n-i)$。同时预处理出所有 $n-i$ 的逆元 $inv_i$。(不要一个个快速幂算,复杂度错的。要用 $O(k+\log)$ 的方式)

此时就有:

$$f(n)=\sum\limits_{i=1}^{k+2}y_i\dfrac{fac\times inv_i}{(i-1)!(k-i+2)!(-1)^{k-i+2}}$$

已经可以 $O(k)$ 计算了。

时间复杂度 $O(k+\log)$。

#include<bits/stdc++.h>
using namespace std;
const int maxn=10001000,mod=998244353;
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
#define ROF(i,a,b) for(int i=(a);i>=(b);i--)
#define MEM(x,v) memset(x,v,sizeof(x))
inline int read(){
    char ch=getchar();int x=0,f=0;
    while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
    return f?-x:x;
}
int n,k,ans,kp[maxn],k1p[maxn],pr[maxn/10],pl,ky[maxn],k1y[maxn],fac[maxn],invfac[maxn],tfac[maxn],tinv[maxn];
bool vis[maxn];
int qpow(int a,int b){
    int ans=1;
    for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) ans=1ll*ans*a%mod;
    return ans;
}
int kcalc(int x){
    if(x<=k+2){
        int ans=1ll*ky[x]*fac[x-1]%mod*fac[k-x+2]%mod;
        if((k-x)&1) return ans?mod-ans:0;
        else return ans;
    }
    if(x>=mod) return (kcalc(mod-1)+kcalc(x%mod))%mod;
    tfac[0]=1;
    FOR(i,1,k+2) tfac[i]=1ll*tfac[i-1]*(x-i)%mod;
    tinv[k+2]=qpow(tfac[k+2],mod-2);
    ROF(i,k+1,1) tinv[i]=1ll*tinv[i+1]*(x-i-1)%mod;
    FOR(i,2,k+2) tinv[i]=1ll*tinv[i]*tfac[i-1]%mod;
    int ans=0;
    FOR(i,1,k+2) ans=(ans+1ll*ky[i]*tfac[k+2]%mod*tinv[i])%mod;
    return ans;
}
int k1calc(int x){
    if(x<=k+3){
        int ans=1ll*k1y[x]*fac[x-1]%mod*fac[k-x+3]%mod;
        if((k-x)&1) return ans;
        else return ans?mod-ans:0;
    }
    if(x>=mod) return (k1calc(mod-1)+k1calc(x%mod))%mod;
    tfac[0]=1;
    FOR(i,1,k+3) tfac[i]=1ll*tfac[i-1]*(x-i)%mod;
    tinv[k+3]=qpow(tfac[k+3],mod-2);
    ROF(i,k+2,1) tinv[i]=1ll*tinv[i+1]*(x-i-1)%mod;
    FOR(i,2,k+3) tinv[i]=1ll*tinv[i]*tfac[i-1]%mod;
    int ans=0;
    FOR(i,1,k+3) ans=(ans+1ll*k1y[i]*tfac[k+3]%mod*tinv[i])%mod;
    return ans;
}
int main(){
    n=read();k=read();
    fac[0]=1;
    FOR(i,1,k+3) fac[i]=1ll*fac[i-1]*i%mod;
    invfac[k+3]=qpow(fac[k+3],mod-2);
    ROF(i,k+2,0) invfac[i]=1ll*invfac[i+1]*(i+1)%mod;
    kp[1]=k1p[1]=1;
    FOR(i,2,k+3){
        if(!vis[i]){
            pr[++pl]=i;
            kp[i]=qpow(i,k);
            k1p[i]=qpow(i,k+1);
        }
        FOR(j,1,pl){
            if(i*pr[j]>k+3) break;
            vis[i*pr[j]]=true;
            kp[i*pr[j]]=1ll*kp[i]*kp[pr[j]]%mod;
            k1p[i*pr[j]]=1ll*k1p[i]*k1p[pr[j]]%mod;
            if(i%pr[j]==0) break;
        }
    }
    ky[1]=k1y[1]=1;
    FOR(i,2,k+3) ky[i]=(ky[i-1]+kp[i])%mod,k1y[i]=(k1y[i-1]+k1p[i])%mod;
    FOR(i,1,k+3){
        ky[i]=1ll*ky[i]*invfac[i-1]%mod*invfac[k-i+2]%mod;
        if((k-i)&1) ky[i]=ky[i]?mod-ky[i]:0;
        k1y[i]=1ll*k1y[i]*invfac[i-1]%mod*invfac[k-i+3]%mod;
        if(!((k-i)&1)) k1y[i]=k1y[i]?mod-k1y[i]:0;
    }
    ans=2*k1calc(n)%mod;
    ans=(ans-(2ll*n+2+qpow(2,k))*kcalc(n)%mod+mod)%mod;
    ans=(ans+1ll*(2*n+1)*kcalc(2*n)%mod)%mod;
    ans=(ans-k1calc(2*n)+mod)%mod;
    ans=1ll*ans*qpow(n,mod-2)%mod;
    printf("%d\n",ans);
}
View Code

upd:

以上是原版的做法。

毒瘤的神鱼又加强了这题……然后就……让我来做……

受宠若惊呢QwQ

不过真没想到自己也能想到加强版正解。


先考虑 $n$ 不为 $998244353$(以下简写为 $p$)的倍数。(因为式子前面有个 $\dfrac{1}{n}$,可能没有逆元)

那么 $\sum\limits_{i=1}^ni^k=\lfloor\dfrac{n}{p}\rfloor\sum\limits_{i=1}^{p-1}i^k+\sum\limits_{i=1}^{n\bmod p}i^k$。

写个高精除低精就没了。

然后考虑 $n$ 是 $p$ 的倍数。

$\sum\limits_{i=1}^ni^k=\dfrac{n}{p}\sum\limits_{i=1}^pi^k$。

此时原式:

$$\dfrac{1}{n}(\dots\sum\limits^n_{i=1}i^k+\dots\sum\limits^{2n}_{i=1}i^k)$$

$$\dfrac{1}{n}(\dfrac{n}{p}\dots\sum\limits^{p-1}_{i=1}i^k+2\dfrac{n}{p}\dots\sum\limits^{p-1}_{i=1}i^k)$$

$$\dfrac{1}{p}\dots\sum\limits^{p-1}_{i=1}i^k+\dfrac{2}{p}\dots\sum\limits^{p-1}_{i=1}i^k$$

此处取点值为 $0$ 到 $k+1$ 而不是上文中的 $1$ 到 $k+2$。

观察拉格朗日插值的式子,分子是有个 $x-x_j$ 的。那么当 $i=0$ 时,$y_i=0$,是 $p$ 的倍数。当 $i\ne 0$ 时,$x-x_0=x=p$,是 $p$ 的倍数。

那么只需要在分子中不乘上那个 $x-x_0=p$ 即可。就相当于求出了 $\frac{\sum\limits_{i=1}^pi^k}{p}$。

那么就做完了。

时间复杂度 $O(k+\log n)$。

(神鱼太毒瘤了……一道数论题我码了3.8K……码出了数据结构题的感觉……)

#define Orz_NaCly_Fish
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=10001000,mod=998244353;
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
#define ROF(i,a,b) for(int i=(a);i>=(b);i--)
#define MEM(x,v) memset(x,v,sizeof(x))
inline int read(){
    char ch=getchar();int x=0,f=0;
    while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
    return f?-x:x;
}
int ndiv,nmod,k,ans,len,kp[maxn],k1p[maxn],pr[maxn/10],pl,ky[maxn],k1y[maxn],fac[maxn],invfac[maxn],tfac[maxn],tinv[maxn];
char nstr[11111];
bool vis[maxn];
inline int add(int a,int b){return a+b<mod?a+b:a+b-mod;}
inline int sub(int a,int b){return a<b?a-b+mod:a-b;}
inline int mul(int a,int b){return 1ll*a*b%mod;}
inline int qpow(int a,int b){
    int ans=1;
    for(;b;b>>=1,a=mul(a,a)) if(b&1) ans=mul(ans,a);
    return ans;
}
int kcalc(int x){
    if(x<=k+1){
        int ans=mul(mul(ky[x],fac[x]),fac[k-x+1]);
        if(!((k-x)&1)) return ans?mod-ans:0;
        else return ans;
    }
    if(x==mod){
        tfac[0]=1;
        FOR(i,1,k+1) tfac[i]=mul(tfac[i-1],x-i);
        tinv[k+1]=qpow(tfac[k+1],mod-2);
        ROF(i,k,1) tinv[i]=mul(tinv[i+1],x-i-1);
        FOR(i,2,k+1) tinv[i]=mul(tinv[i],tfac[i-1]);
        int ans=0;
        FOR(i,1,k+1) ans=add(ans,mul(mul(ky[i],tfac[k+1]),tinv[i]));
        return ans;
    }
    tfac[0]=x;
    FOR(i,1,k+1) tfac[i]=mul(tfac[i-1],x-i);
    tinv[k+1]=qpow(tfac[k+1],mod-2);
    ROF(i,k,0) tinv[i]=mul(tinv[i+1],x-i-1);
    FOR(i,1,k+1) tinv[i]=mul(tinv[i],tfac[i-1]);
    int ans=0;
    FOR(i,0,k+1) ans=add(ans,mul(mul(ky[i],tfac[k+1]),tinv[i]));
    return ans;
}
int k1calc(int x){
    if(x<=k+2){
        int ans=mul(mul(k1y[x],fac[x]),fac[k-x+2]);
        if((k-x)&1) return ans?mod-ans:0;
        else return ans;
    }
    if(x==mod){
        tfac[0]=1;
        FOR(i,1,k+2) tfac[i]=mul(tfac[i-1],x-i);
        tinv[k+2]=qpow(tfac[k+2],mod-2);
        ROF(i,k+1,1) tinv[i]=mul(tinv[i+1],x-i-1);
        FOR(i,2,k+2) tinv[i]=mul(tinv[i],tfac[i-1]);
        int ans=0;
        FOR(i,1,k+2) ans=add(ans,mul(mul(k1y[i],tfac[k+2]),tinv[i]));
        return ans;
    }
    tfac[0]=x;
    FOR(i,1,k+2) tfac[i]=mul(tfac[i-1],x-i);
    tinv[k+2]=qpow(tfac[k+2],mod-2);
    ROF(i,k+1,0) tinv[i]=mul(tinv[i+1],x-i-1);
    FOR(i,1,k+2) tinv[i]=mul(tinv[i],tfac[i-1]);
    int ans=0;
    FOR(i,0,k+2) ans=add(ans,mul(mul(k1y[i],tfac[k+2]),tinv[i]));
    return ans;
}
void init(){
    fac[0]=1;
    FOR(i,1,k+2) fac[i]=mul(fac[i-1],i);
    invfac[k+2]=qpow(fac[k+2],mod-2);
    ROF(i,k+1,0) invfac[i]=mul(invfac[i+1],i+1);
    kp[1]=k1p[1]=1;
    FOR(i,2,k+2){
        if(!vis[i]){
            pr[++pl]=i;
            kp[i]=qpow(i,k);
            k1p[i]=qpow(i,k+1);
        }
        FOR(j,1,pl){
            if(i*pr[j]>k+2) break;
            vis[i*pr[j]]=true;
            kp[i*pr[j]]=mul(kp[i],kp[pr[j]]);
            k1p[i*pr[j]]=mul(k1p[i],k1p[pr[j]]);
            if(i%pr[j]==0) break;
        }
    }
    ky[0]=k1y[0]=0;
    FOR(i,1,k+2) ky[i]=add(ky[i-1],kp[i]),k1y[i]=add(k1y[i-1],k1p[i]);
    FOR(i,0,k+2){
        ky[i]=mul(mul(ky[i],invfac[i]),invfac[k-i+1]);
        if(!((k-i)&1)) ky[i]=ky[i]?mod-ky[i]:0;
        k1y[i]=mul(mul(k1y[i],invfac[i]),invfac[k-i+2]);
        if((k-i)&1) k1y[i]=k1y[i]?mod-k1y[i]:0;
    }
}
int main(){
    scanf("%s",nstr+1);k=read();
    len=strlen(nstr+1);
    init();
    FOR(i,1,len){
        ll tmp=10ll*nmod+nstr[i]-'0';
        ndiv=add(mul(ndiv,10),tmp/mod);
        nmod=tmp%mod;
    }
    if(nmod){
        int ktmp=kcalc(mod-1),k1tmp=k1calc(mod-1);
        int krem=kcalc(nmod),k1rem=k1calc(nmod);
        int kans=add(mul(ndiv,ktmp),krem),k1ans=add(mul(ndiv,k1tmp),k1rem);
        ans=mul(2,k1ans);
        ans=sub(ans,mul(add(add(mul(nmod,2),2),qpow(2,k)),kans));
        krem=kcalc(mul(nmod,2)),k1rem=k1calc(mul(nmod,2));
        kans=add(mul(add(mul(ndiv,2),nmod>mod/2),ktmp),krem),k1ans=add(mul(add(mul(ndiv,2),nmod>mod/2),k1tmp),k1rem);
        ans=add(ans,mul(2*nmod+1,kans));
        ans=sub(ans,k1ans);
        ans=mul(ans,qpow(nmod,mod-2));
        printf("%d\n",ans);
    }
    else{
        int kans=kcalc(mod),k1ans=k1calc(mod);
        ans=mul(2,k1ans);
        ans=sub(ans,mul(add(2,qpow(2,k)),kans));
        ans=add(ans,mul(2,kans));
        ans=sub(ans,mul(2,k1ans));
        printf("%d\n",ans);
    }
}
View Code

 

posted @ 2019-06-30 13:35  ATS_nantf  阅读(320)  评论(0编辑  收藏  举报