洛谷P5437/5442 约定(概率期望,拉格朗日插值,自然数幂)
题目大意:$n$ 个点的完全图,点 $i$ 和点 $j$ 的边权为 $(i+j)^k$。随机一个生成树,问这个生成树边权和的期望对 $998244353$ 取模的值。
对于P5437:$1\le n\le 998244352,1\le k\le 10^7$。
对于P5442:$1\le n\le 10^4,\le k\le 10^7$。
其实也是一道比较简单的题。(所以就应该把这题和上一道原题调个位置)
考虑一条边在生成树中出现的概率,由于一共有 $\dfrac{n(n-1)}{2}$ 条边,一个生成树有 $n-1$ 条边,而每条边的概率相等,所以为 $\dfrac{2}{n}$。
那么开始推式子:(注:第三步是枚举 $i+j$)
$$\dfrac{2}{n}\sum\limits_{i=1}^n\sum\limits_{j=i+1}^n(i+j)^k$$$$\dfrac{1}{n}(\sum\limits_{i=1}^n\sum\limits_{j=1}^n(i+j)^k-\sum\limits^n_{i=1}(i+i)^k)$$
$$\dfrac{1}{n}(\sum\limits_{s=1}^{2n}s^k\min(s-1,2n+1-s)-2^k\sum\limits^n_{i=1}i^k)$$
$$\dfrac{1}{n}(\sum\limits_{s=1}^{n}s^k(s-1)+\sum\limits_{s=n+1}^{2n}s^k(2n+1-s)-2^k\sum\limits^n_{i=1}i^k)$$
$$\dfrac{1}{n}(\sum\limits_{s=1}^{n}s^{k+1}-\sum\limits_{s=1}^{n}s^{k}+(2n+1)\sum\limits_{s=n+1}^{2n}s^k-\sum\limits_{s=n+1}^{2n}s^{k+1}-2^{k}\sum\limits^n_{i=1}i^k)$$
$$\dfrac{1}{n}(2\sum\limits_{i=1}^{n}i^{k+1}-(2n+2+2^k)\sum\limits_{i=1}^{n}i^{k}+(2n+1)\sum\limits_{i=1}^{2n}i^k-\sum\limits_{i=1}^{2n}i^{k+1})$$
现在问题就是求 $f(n)=\sum\limits_{i=1}^ni^k$ 了。
由于 $f(n)-f(n-1)=n^k$,$f$ 的差值是个 $k$ 次多项式,所以 $f$ 是个 $k+1$ 次多项式。
那么可以拉格朗日插值。(以下内容的代码实现细节比较多,注意要控制复杂度不带 $\log$)
取 $k+2$ 个点为 $1$ 到 $k+2$,发现点值 $y_i$ 可以 $O(k)$ 计算。($y_i=y_{i-1}+i^k$ 不能直接快速幂,不然带 $\log$。可以用欧拉筛筛出所有 $k$ 次方)
$$f(n)=\sum\limits_{i=1}^{k+2}y_i\dfrac{\prod\limits^{k+2}_{j=1,j\ne i}(n-x_j)}{\prod\limits^{k+2}_{j=1,j\ne i}(x_i-x_j)}$$
这样拉格朗日插值公式中的分母就是两个阶乘相乘的形式,可以 $O(1)$。(预处理要注意控制复杂度)
代入一个数算时,先特判 $n\ge mod$(因为会调用到 $f(2n)$),此时 $f(n)=\lfloor\dfrac{n}{mod}\rfloor f(mod-1)+f(n\%mod)$。
否则先算出 $fac=\prod\limits_{i=1}^{k+2}(n-i)$。同时预处理出所有 $n-i$ 的逆元 $inv_i$。(不要一个个快速幂算,复杂度错的。要用 $O(k+\log)$ 的方式)
此时就有:
$$f(n)=\sum\limits_{i=1}^{k+2}y_i\dfrac{fac\times inv_i}{(i-1)!(k-i+2)!(-1)^{k-i+2}}$$
已经可以 $O(k)$ 计算了。
时间复杂度 $O(k+\log)$。
#include<bits/stdc++.h> using namespace std; const int maxn=10001000,mod=998244353; #define FOR(i,a,b) for(int i=(a);i<=(b);i++) #define ROF(i,a,b) for(int i=(a);i>=(b);i--) #define MEM(x,v) memset(x,v,sizeof(x)) inline int read(){ char ch=getchar();int x=0,f=0; while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar(); while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar(); return f?-x:x; } int n,k,ans,kp[maxn],k1p[maxn],pr[maxn/10],pl,ky[maxn],k1y[maxn],fac[maxn],invfac[maxn],tfac[maxn],tinv[maxn]; bool vis[maxn]; int qpow(int a,int b){ int ans=1; for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) ans=1ll*ans*a%mod; return ans; } int kcalc(int x){ if(x<=k+2){ int ans=1ll*ky[x]*fac[x-1]%mod*fac[k-x+2]%mod; if((k-x)&1) return ans?mod-ans:0; else return ans; } if(x>=mod) return (kcalc(mod-1)+kcalc(x%mod))%mod; tfac[0]=1; FOR(i,1,k+2) tfac[i]=1ll*tfac[i-1]*(x-i)%mod; tinv[k+2]=qpow(tfac[k+2],mod-2); ROF(i,k+1,1) tinv[i]=1ll*tinv[i+1]*(x-i-1)%mod; FOR(i,2,k+2) tinv[i]=1ll*tinv[i]*tfac[i-1]%mod; int ans=0; FOR(i,1,k+2) ans=(ans+1ll*ky[i]*tfac[k+2]%mod*tinv[i])%mod; return ans; } int k1calc(int x){ if(x<=k+3){ int ans=1ll*k1y[x]*fac[x-1]%mod*fac[k-x+3]%mod; if((k-x)&1) return ans; else return ans?mod-ans:0; } if(x>=mod) return (k1calc(mod-1)+k1calc(x%mod))%mod; tfac[0]=1; FOR(i,1,k+3) tfac[i]=1ll*tfac[i-1]*(x-i)%mod; tinv[k+3]=qpow(tfac[k+3],mod-2); ROF(i,k+2,1) tinv[i]=1ll*tinv[i+1]*(x-i-1)%mod; FOR(i,2,k+3) tinv[i]=1ll*tinv[i]*tfac[i-1]%mod; int ans=0; FOR(i,1,k+3) ans=(ans+1ll*k1y[i]*tfac[k+3]%mod*tinv[i])%mod; return ans; } int main(){ n=read();k=read(); fac[0]=1; FOR(i,1,k+3) fac[i]=1ll*fac[i-1]*i%mod; invfac[k+3]=qpow(fac[k+3],mod-2); ROF(i,k+2,0) invfac[i]=1ll*invfac[i+1]*(i+1)%mod; kp[1]=k1p[1]=1; FOR(i,2,k+3){ if(!vis[i]){ pr[++pl]=i; kp[i]=qpow(i,k); k1p[i]=qpow(i,k+1); } FOR(j,1,pl){ if(i*pr[j]>k+3) break; vis[i*pr[j]]=true; kp[i*pr[j]]=1ll*kp[i]*kp[pr[j]]%mod; k1p[i*pr[j]]=1ll*k1p[i]*k1p[pr[j]]%mod; if(i%pr[j]==0) break; } } ky[1]=k1y[1]=1; FOR(i,2,k+3) ky[i]=(ky[i-1]+kp[i])%mod,k1y[i]=(k1y[i-1]+k1p[i])%mod; FOR(i,1,k+3){ ky[i]=1ll*ky[i]*invfac[i-1]%mod*invfac[k-i+2]%mod; if((k-i)&1) ky[i]=ky[i]?mod-ky[i]:0; k1y[i]=1ll*k1y[i]*invfac[i-1]%mod*invfac[k-i+3]%mod; if(!((k-i)&1)) k1y[i]=k1y[i]?mod-k1y[i]:0; } ans=2*k1calc(n)%mod; ans=(ans-(2ll*n+2+qpow(2,k))*kcalc(n)%mod+mod)%mod; ans=(ans+1ll*(2*n+1)*kcalc(2*n)%mod)%mod; ans=(ans-k1calc(2*n)+mod)%mod; ans=1ll*ans*qpow(n,mod-2)%mod; printf("%d\n",ans); }
upd:
以上是原版的做法。
毒瘤的神鱼又加强了这题……然后就……让我来做……
受宠若惊呢QwQ
不过真没想到自己也能想到加强版正解。
先考虑 $n$ 不为 $998244353$(以下简写为 $p$)的倍数。(因为式子前面有个 $\dfrac{1}{n}$,可能没有逆元)
那么 $\sum\limits_{i=1}^ni^k=\lfloor\dfrac{n}{p}\rfloor\sum\limits_{i=1}^{p-1}i^k+\sum\limits_{i=1}^{n\bmod p}i^k$。
写个高精除低精就没了。
然后考虑 $n$ 是 $p$ 的倍数。
$\sum\limits_{i=1}^ni^k=\dfrac{n}{p}\sum\limits_{i=1}^pi^k$。
此时原式:
$$\dfrac{1}{n}(\dots\sum\limits^n_{i=1}i^k+\dots\sum\limits^{2n}_{i=1}i^k)$$
$$\dfrac{1}{n}(\dfrac{n}{p}\dots\sum\limits^{p-1}_{i=1}i^k+2\dfrac{n}{p}\dots\sum\limits^{p-1}_{i=1}i^k)$$
$$\dfrac{1}{p}\dots\sum\limits^{p-1}_{i=1}i^k+\dfrac{2}{p}\dots\sum\limits^{p-1}_{i=1}i^k$$
此处取点值为 $0$ 到 $k+1$ 而不是上文中的 $1$ 到 $k+2$。
观察拉格朗日插值的式子,分子是有个 $x-x_j$ 的。那么当 $i=0$ 时,$y_i=0$,是 $p$ 的倍数。当 $i\ne 0$ 时,$x-x_0=x=p$,是 $p$ 的倍数。
那么只需要在分子中不乘上那个 $x-x_0=p$ 即可。就相当于求出了 $\frac{\sum\limits_{i=1}^pi^k}{p}$。
那么就做完了。
时间复杂度 $O(k+\log n)$。
(神鱼太毒瘤了……一道数论题我码了3.8K……码出了数据结构题的感觉……)
#define Orz_NaCly_Fish #include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=10001000,mod=998244353; #define FOR(i,a,b) for(int i=(a);i<=(b);i++) #define ROF(i,a,b) for(int i=(a);i>=(b);i--) #define MEM(x,v) memset(x,v,sizeof(x)) inline int read(){ char ch=getchar();int x=0,f=0; while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar(); while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar(); return f?-x:x; } int ndiv,nmod,k,ans,len,kp[maxn],k1p[maxn],pr[maxn/10],pl,ky[maxn],k1y[maxn],fac[maxn],invfac[maxn],tfac[maxn],tinv[maxn]; char nstr[11111]; bool vis[maxn]; inline int add(int a,int b){return a+b<mod?a+b:a+b-mod;} inline int sub(int a,int b){return a<b?a-b+mod:a-b;} inline int mul(int a,int b){return 1ll*a*b%mod;} inline int qpow(int a,int b){ int ans=1; for(;b;b>>=1,a=mul(a,a)) if(b&1) ans=mul(ans,a); return ans; } int kcalc(int x){ if(x<=k+1){ int ans=mul(mul(ky[x],fac[x]),fac[k-x+1]); if(!((k-x)&1)) return ans?mod-ans:0; else return ans; } if(x==mod){ tfac[0]=1; FOR(i,1,k+1) tfac[i]=mul(tfac[i-1],x-i); tinv[k+1]=qpow(tfac[k+1],mod-2); ROF(i,k,1) tinv[i]=mul(tinv[i+1],x-i-1); FOR(i,2,k+1) tinv[i]=mul(tinv[i],tfac[i-1]); int ans=0; FOR(i,1,k+1) ans=add(ans,mul(mul(ky[i],tfac[k+1]),tinv[i])); return ans; } tfac[0]=x; FOR(i,1,k+1) tfac[i]=mul(tfac[i-1],x-i); tinv[k+1]=qpow(tfac[k+1],mod-2); ROF(i,k,0) tinv[i]=mul(tinv[i+1],x-i-1); FOR(i,1,k+1) tinv[i]=mul(tinv[i],tfac[i-1]); int ans=0; FOR(i,0,k+1) ans=add(ans,mul(mul(ky[i],tfac[k+1]),tinv[i])); return ans; } int k1calc(int x){ if(x<=k+2){ int ans=mul(mul(k1y[x],fac[x]),fac[k-x+2]); if((k-x)&1) return ans?mod-ans:0; else return ans; } if(x==mod){ tfac[0]=1; FOR(i,1,k+2) tfac[i]=mul(tfac[i-1],x-i); tinv[k+2]=qpow(tfac[k+2],mod-2); ROF(i,k+1,1) tinv[i]=mul(tinv[i+1],x-i-1); FOR(i,2,k+2) tinv[i]=mul(tinv[i],tfac[i-1]); int ans=0; FOR(i,1,k+2) ans=add(ans,mul(mul(k1y[i],tfac[k+2]),tinv[i])); return ans; } tfac[0]=x; FOR(i,1,k+2) tfac[i]=mul(tfac[i-1],x-i); tinv[k+2]=qpow(tfac[k+2],mod-2); ROF(i,k+1,0) tinv[i]=mul(tinv[i+1],x-i-1); FOR(i,1,k+2) tinv[i]=mul(tinv[i],tfac[i-1]); int ans=0; FOR(i,0,k+2) ans=add(ans,mul(mul(k1y[i],tfac[k+2]),tinv[i])); return ans; } void init(){ fac[0]=1; FOR(i,1,k+2) fac[i]=mul(fac[i-1],i); invfac[k+2]=qpow(fac[k+2],mod-2); ROF(i,k+1,0) invfac[i]=mul(invfac[i+1],i+1); kp[1]=k1p[1]=1; FOR(i,2,k+2){ if(!vis[i]){ pr[++pl]=i; kp[i]=qpow(i,k); k1p[i]=qpow(i,k+1); } FOR(j,1,pl){ if(i*pr[j]>k+2) break; vis[i*pr[j]]=true; kp[i*pr[j]]=mul(kp[i],kp[pr[j]]); k1p[i*pr[j]]=mul(k1p[i],k1p[pr[j]]); if(i%pr[j]==0) break; } } ky[0]=k1y[0]=0; FOR(i,1,k+2) ky[i]=add(ky[i-1],kp[i]),k1y[i]=add(k1y[i-1],k1p[i]); FOR(i,0,k+2){ ky[i]=mul(mul(ky[i],invfac[i]),invfac[k-i+1]); if(!((k-i)&1)) ky[i]=ky[i]?mod-ky[i]:0; k1y[i]=mul(mul(k1y[i],invfac[i]),invfac[k-i+2]); if((k-i)&1) k1y[i]=k1y[i]?mod-k1y[i]:0; } } int main(){ scanf("%s",nstr+1);k=read(); len=strlen(nstr+1); init(); FOR(i,1,len){ ll tmp=10ll*nmod+nstr[i]-'0'; ndiv=add(mul(ndiv,10),tmp/mod); nmod=tmp%mod; } if(nmod){ int ktmp=kcalc(mod-1),k1tmp=k1calc(mod-1); int krem=kcalc(nmod),k1rem=k1calc(nmod); int kans=add(mul(ndiv,ktmp),krem),k1ans=add(mul(ndiv,k1tmp),k1rem); ans=mul(2,k1ans); ans=sub(ans,mul(add(add(mul(nmod,2),2),qpow(2,k)),kans)); krem=kcalc(mul(nmod,2)),k1rem=k1calc(mul(nmod,2)); kans=add(mul(add(mul(ndiv,2),nmod>mod/2),ktmp),krem),k1ans=add(mul(add(mul(ndiv,2),nmod>mod/2),k1tmp),k1rem); ans=add(ans,mul(2*nmod+1,kans)); ans=sub(ans,k1ans); ans=mul(ans,qpow(nmod,mod-2)); printf("%d\n",ans); } else{ int kans=kcalc(mod),k1ans=k1calc(mod); ans=mul(2,k1ans); ans=sub(ans,mul(add(2,qpow(2,k)),kans)); ans=add(ans,mul(2,kans)); ans=sub(ans,mul(2,k1ans)); printf("%d\n",ans); } }