题目链接

(luogu) https://www.luogu.org/problem/P4091
(bzoj) https://www.lydsy.com/JudgeOnline/problem.php?id=4555

题解

终于不是神仙题了啊。。。
首先\(O(n\log n)\)的FFT做法非常明显,直接用容斥展开,这里不再赘述了。发现最后就是要求一个\(\sum^{n}_{k=0}\sum^{n}_{j=k}(-1)^{j-k}{j\choose k}2^j(\sum^{n}_{i=0}k^i)\).
注意容斥的公式在上指标小于下指标时依然成立,此时其给出的值恒为\(0\).

然后去膜了一波大佬的线性做法:
依然是要求那个式子,但是那个式子可以不用卷积算。
我们发现,假设某序列\(A\)的生成函数为\(A(x)=\sum^{n}_{i=0}a_ix^i\), 那么\(\sum^{n}_{j=0}\sum^n_{i=j}a_i{i\choose j}q^{i-j}x^j=\sum^{n}_{j=0}a_j(x+q)^j=A(x+q)\).
所以要求的就相当于令\(A_i=2^i,q=-1\), 那么\(A(x)=\frac{(2x)^{n+1}}{2x-1}, A(x-1)=\frac{(2x-2)^{n+1}}{2x-3}\), 这个东西直接把上面的二项式展开然后除以一次式\(O(n)\)解决。
这样求出了多项式的每一项,再和\(\sum^{n}_{i=0}k^i\)乘起来求和即可。
对于快速幂,可以用线性筛先求出所有质数的\((n+1)\)次幂,然后按积性函数乘起来,假设质数密度是\(O(\frac{1}{\log n})\)那么复杂度就是\(O(\frac{n}{\log n}\times \log n)=O(n)\).

现在思考一个问题: 这个做法先把二项式展开之后的组合数合了起来,又把合起来之后的式子二项式展开了,那么实际上应该相当于啥也没干,凭什么就优化复杂度了呢?
其实是因为,原来的二项式合并有\(n\)项(做了\(n\)次合并),合并完了之后它变成了一个非常好看的等比数列求和的形式,那么可以直接表示成\((n+1)\)次减\(1\)除以\(1\)\(-1\), 那么再做二项式展开就只要做一次了,于是复杂度成功降了下来!
真神奇。

代码

由于我懒得写线性筛所以用了快速幂复杂度依然是\(O(n\log n)\).
但是线性做法的思想还是非常值得借鉴的。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<iostream>
#define llong long long
using namespace std;

inline int read()
{
	int x=0; bool f=1; char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
	for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
	if(f) return x;
	return -x;
}

const int N = 1e5+1;
const int P = 998244353;
const llong INV2 = 499122177;
llong fact[N+3],finv[N+3];
int n;

llong quickpow(llong x,llong y)
{
	llong cur = x,ret = 1ll;
	for(int i=0; y; i++)
	{
		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
		cur = cur*cur%P;
	}
	return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
llong comb(llong x,llong y) {return x<0||y<0||x<y ? 0ll : fact[x]*finv[y]%P*finv[x-y]%P;}

llong a[N+3];
llong aa[N+3];
llong b[N+3];

int main()
{
	fact[0] = 1ll; for(int i=1; i<=N; i++) fact[i] = fact[i-1]*i%P;
	finv[N] = quickpow(fact[N],P-2); for(int i=N-1; i>=0; i--) finv[i] = finv[i+1]*(i+1)%P;
	scanf("%d",&n);
	for(int i=0; i<=n+1; i++)
	{
		a[i] = quickpow(2ll,n+1)*comb(n+1,i);
		if((n+1-i)&1) {a[i] = P-a[i];}
	}
	a[0]--;
	for(int i=n; i>=0; i--)
	{
		aa[i] = a[i+1]*INV2%P;
		a[i] = (a[i]+3ll*aa[i])%P;
	}
	b[0] = 1ll; b[1] = n+1;
	for(int i=2; i<=n; i++)
	{
		b[i] = (quickpow(i,n+1)-1)*mulinv(i-1)%P;
	}
	llong ans = 0ll;
	for(int i=0; i<=n; i++) {ans = (ans+aa[i]*b[i])%P;}
	printf("%lld\n",ans);
	return 0;
}