题解 小 H 爱染色

传送门

不要用百度搜这个题目名字

首先可以发现，答案是关于 \(n\) 的、次数约为 \(3m\) 的多项式，所以部分分可以直接暴力插值。
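然后考虑正解。设输入给出的 \(F(0),F(1),\dots,F(m)\) 确定了 \(m\) 次多项式 \(F\)，则答案为
\[
\sum_{x=0}^{n-1}F(x)\left[\binom{n-x}{m}^2-\binom{n-x-1}{m}^2\right].
\]
先把 \(F\) 换成组合数（牛顿）基：\(F(x)=\sum_{t=0}^{m}\Delta^tF(0)\binom{x}{t}\)，其中 \(\Delta^tF(0)=\sum_{i=0}^{t}(-1)^{t-i}\binom{t}{i}F(i)\)，用一次卷积即可求出全部 \(\Delta^tF(0)\)。再利用
\[
\binom{y}{m}^2=\sum_{k=m}^{2m}\binom{k}{m}\binom{m}{k-m}\binom{y}{k},\qquad
\binom{y}{k}-\binom{y-1}{k}=\binom{y-1}{k-1},\qquad
\sum_{x=0}^{n-1}\binom{x}{t}\binom{n-1-x}{k-1}=\binom{n}{t+k}\ (k\ge 1),
\]
可得（\(m\ge 1\) 时；\(m=0\) 时每一项的权值都是 \(0\)，答案为 \(0\)）
\[
\text{答案}=\sum_{t=0}^{m}\sum_{k=m}^{2m}\Delta^tF(0)\binom{k}{m}\binom{m}{k-m}\binom{n}{t+k},
\]
按 \(t+k\) 再做一次卷积即可，总复杂度 \(O(m\log m)\)。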

  • 处理涉及多项式的式子时，有时可以把多项式按幂次（或按组合数 \(\binom{x}{t}\) 为基）拆开

    注意寻找能凑出插值的式子

点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int INF = 0x3f3f3f3f;  // conventional "infinity" sentinel (not referenced below)
constexpr int N = 3000020;       // capacity of every global table in this file
using ll = long long;            // 64-bit working type for all modular arithmetic
// #define int long long

// Buffered stdin: `buf` is refilled in 2 MiB chunks and handed out one byte at a
// time. Bytes are yielded as unsigned char (0..255) so <cctype> classification is
// well-defined for every byte; EOF (-1) signals that the input is exhausted.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)

/// Reads the next decimal integer from stdin.
/// Non-digit bytes before the number are skipped; each '-' among them flips the sign.
/// @return the parsed value, or 0 when the input ends before any digit appears.
inline int read() {
	int ans=0, f=1;
	int c=getchar();  // int, not char: keeps EOF distinct from the byte 0xFF
	while (!isdigit(c)) {
		if (c==EOF) return 0;  // truncated input: stop instead of spinning forever
		if (c=='-') f=-f;
		c=getchar();
	}
	while (isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Problem parameters (read in main) and the largest index of the factorial tables.
int n;
int m;
int lim;

// NTT-friendly prime modulus, the root used by the NTT, and phi(mod) = mod - 1.
constexpr ll mod = 998244353;
constexpr ll rt = 3;
constexpr ll phi = mod - 1;

// fac[i] = i! and inv[i] = 1/i! (mod `mod`) once main() has filled them;
// pre/suf are scratch space for lagrange(); y holds the input samples
// (task1/task2 later reuse it for prefix sums).
ll fac[N];
ll inv[N];
ll pre[N];
ll suf[N];
ll y[N];
/// Modular exponentiation a^b mod `mod` by binary exponentiation, O(log |b|).
/// @param a any 64-bit base; it is reduced into [0, mod) first, so negative or
///          large bases neither overflow nor produce a negative result.
/// @param b exponent. A negative b is reduced modulo phi = mod - 1 (Fermat's little
///          theorem, mod is prime), giving a power of the modular inverse; a must
///          then be nonzero mod `mod`. Previously a negative b looped forever.
/// @return a^b mod `mod`, in [0, mod).
inline ll qpow(ll a, ll b) {
	a%=mod;
	if (a<0) a+=mod;
	if (b<0) b=b%phi+phi;  // now 0 < b <= phi, so the loop below terminates
	ll ans=1;
	for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod;
	return ans;
}
/// Binomial coefficient C(n, k) mod `mod`, read from the tables fac[] (i!) and
/// inv[] (1/i!) that main() fills up to index lim; n must not exceed lim.
/// @return 0 when k < 0 or k > n (which also covers every negative n), so the
///         tables are never indexed with a negative subscript.
inline ll C(int n, int k) {return (k<0||n<k)?0:fac[n]*inv[k]%mod*inv[n-k]%mod;}
/// Returns x*x reduced mod `mod`. Callers pass values already reduced mod `mod`,
/// so the 64-bit product cannot overflow.
inline ll sqr(ll x) {
	const ll squared = x * x;
	return squared % mod;
}
/// Binomial coefficient C(n, k) mod `mod` for a possibly huge n, computed in O(k)
/// as n (n-1) ... (n-k+1) / k!. Only k has to be covered by the inverse-factorial
/// table inv[]; n itself may be far larger than lim.
/// @return 0 when k < 0 or n < k.
inline ll C2(ll n, ll k) {
	if (k<0||n<k) return 0;
	ll ans=inv[k];
	// Reduce each factor first so ans*factor stays within 64 bits for any n.
	for (ll i=n; i>n-k; --i) ans=ans*(i%mod)%mod;
	return ans;
}

/// Lagrange interpolation: given samples y[1..n] of the polynomial P of degree < n
/// taken at x = 1..n, returns P(k) mod `mod` in O(n).
/// Overwrites the global scratch arrays pre[0..n] and suf[1..n+1] and reads the
/// inverse factorials inv[0..n-1]. The result lies in (-mod, mod); callers normalise it.
/// @param n number of samples
/// @param k evaluation point; any 64-bit value is accepted.
ll lagrange(int n, ll k) {
	k%=mod;  // keeps pre[i-1]*(k-i) within 64 bits even for huge k
	pre[0]=suf[n+1]=1;
	for (int i=1; i<=n; ++i) pre[i]=pre[i-1]*(k-i)%mod;  // prod_{j<=i} (k-j)
	for (int i=n; i; --i) suf[i]=suf[i+1]*(k-i)%mod;     // prod_{j>=i} (k-j)
	ll ans=0;
	for (int i=1; i<=n; ++i) {
		// Denominator prod_{j!=i} (i-j) = (i-1)! * (n-i)! * (-1)^(n-i).
		const ll sign=((n-i)&1)?-1:1;
		const ll term=sign*y[i]*pre[i-1]%mod*suf[i+1]%mod*inv[i-1]%mod*inv[n-i]%mod;
		ans=(ans+term)%mod;
	}
	return ans;
}

namespace force{
	ll e[N], ans=0;
	void solve() {
		// ll inv_cnm=qpow(C(n, m), mod-2), cnm=C(n, m);
		// for (int i=0; i<n; ++i) e[i]=sqr(C(n-i+1, m)*inv_cnm%mod)*(1-sqr(C(n-i, m)*inv_cnm%mod))%mod;
		for (int i=1; i<=n; ++i) e[i]=(sqr(C(n-i+1, m))-sqr(C(n-i, m)))%mod;
		for (int i=0; i<n; ++i) ans=(ans+lagrange(m+1, i+1)*e[i+1])%mod;
		printf("%lld\n", (ans%mod+mod)%mod);
		exit(0);
	}
}

namespace task1{
	ll f[N], sum;
	void solve() {
		int t=3*m+10;
		for (int i=1; i<=t; ++i) f[i]=lagrange(m+1, i); //, cout<<f[i]<<' '; cout<<endl;
		for (int i=1; i<=t; ++i) f[i]=f[i]*(sqr(C2(n-i+1, m))-sqr(C2(n-i, m)))%mod; //, cout<<(sqr(C2(n-i+1, m))-sqr(C2(n-i, m)))<<endl;
		y[0]=0;
		for (int i=1; i<=t; ++i) y[i]=(y[i-1]+f[i])%mod;
		printf("%lld\n", (lagrange(t, n)%mod+mod)%mod);
		exit(0);
	}
}

namespace task2{
	ll f[N];
	ll sum;  // not referenced
	// Special case for when every input sample equals 1 (the polynomial is the
	// constant 1). The prefix sums
	//   S(x) = sum_{i<=x} (C(n-i+1, m)^2 - C(n-i, m)^2)
	// are sampled at x = 1..2m+10 and interpolated at x = n.
	void solve() {
		const int points=2*m+10;
		y[0]=0;
		for (int i=1; i<=points; ++i) {
			f[i]=(sqr(C2(n-i+1, m))-sqr(C2(n-i, m)))%mod;
			y[i]=(y[i-1]+f[i])%mod;
		}
		printf("%lld\n", (lagrange(points, n)%mod+mod)%mod);
		exit(0);
	}
}

namespace task{
	int rev[N], bln, bct;
	ll f[N], g[N], h[N], ans, chk[N], chk2[N];
	void ntt(ll* a, int len, int op) {
		for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
		ll w, wn, t;
		for (int i=1; i<len; i<<=1) {
			wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
			for (int j=0,step=i<<1; j<len; j+=step) {
				w=1;
				for (int k=j; k<j+i; ++k,w=w*wn%mod) {
					t=w*a[k+i];
					a[k+i]=(a[k]-t)%mod;
					a[k]=(a[k]+t)%mod;
				}
			}
		}
		if (op==-1) {
			ll inv=qpow(len, mod-2);
			for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
		}
	}
	void solve() {
		for (int i=0; i<=m; ++i) f[i]=C(i+m, m)*C(m, m-i)%mod;
		for (int i=0; i<=m; ++i) g[i]=((i&1)?-1:1)*inv[i];
		for (int i=0; i<=m; ++i) h[i]=y[i+1]*inv[i]%mod;
		for (bln=1; bln<=(m+1)*2; bln<<=1,++bct) ;
		for (int i=0; i<bln; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
		ntt(g, bln, 1); ntt(h, bln, 1);
		for (int i=0; i<bln; ++i) g[i]=g[i]*h[i]%mod;
		ntt(g, bln, -1);
		for (int i=0; i<=m; ++i) g[i]=g[i]*fac[i]%mod;
		for (int i=m+1; i<bln; ++i) g[i]=0;
		// cout<<"g: "; for (int i=0; i<=m; ++i) cout<<(g[i]%mod+mod)%mod<<' '; cout<<endl;
		// cout<<"chk: ";
		// for (int t=0; t<=m; ++t) {
		// 	for (int i=0; i<=t; ++i) chk[t]=(chk[t]+(i&1?-1:1)*C(t, i)*y[t-i+1])%mod;
		// 	cout<<(chk[t]+mod)%mod<<' ';
		// } cout<<endl;
		ntt(f, bln, 1); ntt(g, bln, 1);
		for (int i=0; i<bln; ++i) f[i]=f[i]*g[i]%mod;
		ntt(f, bln, -1);
		// ntt(g, bln, -1);
		// cout<<"f: "; for (int i=0; i<=m; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
		// cout<<"chk: ";
		// for (int t=0; t<=3*m; ++t) {
		// 	for (int i=m; i<=2*m; ++i) if (t-i>=0&&t-i<=m) chk[t]=(chk[t]+C(i, m)*C(m, 2*m-i)%mod*g[t-i])%mod;
		// 	cout<<(chk[t]+mod)%mod<<' ';
		// } cout<<endl;
		// for (int i=m; i<=3*m; ++i) ans=(ans+C(n, i)*chk[i])%mod;
		ll tem=1, now=n-m;
		for (int i=n; i>now; --i) tem=tem*i%mod;
		for (int i=0; i<=2*m; ++i) ans=(ans+tem*inv[i+m]%mod*f[i])%mod, tem=tem*(now--)%mod;
		// for (int t=0; t<=m; ++t) 
		// 	for (int k=m; k<=2*m; ++k)
		// 		ans=(ans+C(n, t+k)*C(k, m)%mod*C(m, 2*m-k)%mod*g[t])%mod;
		printf("%lld\n", (ans%mod+mod)%mod);
		exit(0);
	}
}

signed main()
{
	n=read();
	m=read();
	// The factorial tables must reach 3m, because task::solve reads inv[i+m] for i <= 2m.
	// The brute force additionally needs them up to n when n <= 1000.
	lim=max(n<=1000?n:0, 3*m+15);

	// y[1..m+1] receive the m+1 input samples; all_one records whether each equals 1.
	bool all_one=true;
	for (int idx=1; idx<=m+1; ++idx) {
		y[idx]=read();
		all_one=all_one&&y[idx]==1;
	}
	(void)all_one;  // only consulted by the disabled dispatch below

	fac[0]=fac[1]=1;
	inv[0]=inv[1]=1;
	// Factorials, plus modular inverses of 1..lim via i^{-1} = -(mod/i) * (mod%i)^{-1}.
	for (int i=2; i<=lim; ++i) {
		fac[i]=fac[i-1]*i%mod;
		inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	}
	// Prefix products turn inv[i] into 1/i!. This must be a separate pass: the loop
	// above reads inv[mod%i] as a plain inverse.
	for (int i=2; i<=lim; ++i) inv[i]=inv[i-1]*inv[i]%mod;

	// Earlier dispatch between the partial-score solutions (disabled):
	// if (n<=1000) force::solve();
	// else if (all_one) task2::solve();
	// else task1::solve();
	task::solve();

	return 0;
}
posted @ 2022-01-06 21:42  Administrator-09  阅读(0)  评论(0编辑  收藏  举报