题解 糖果

传送门

是个矩阵快速幂,但因为没写出来基础DP根本无从快速幂……

首先基础DP(Yubai优秀写法):
\(dp[i][j]\) 为考虑到第 \(i\) 种糖果,已经确定了 \(j\) 个糖果顺序关系的方案数
考虑原来的顺序关系方案数是 \(\frac{k!}{\prod x_t!}\),我们想让它成为 \(\frac{j!}{x_i!\times\prod x_t!}\)(其中 \(x_i=j-k\)),那乘个 \(\frac{(k+1)^{\overline{j-k}}}{x_i!}=\frac{j!}{k!\,(j-k)!}=\binom{j}{k}\) 即可
于是转移就显然了

然后发现这个 \(a_i\) 是有循环节的,但循环节很长
考虑转移中 \(k\) 的下界 \(\max(j-a_i,0)\):由于 \(j\le m\),当 \(a_i\ge m\) 时下界恒为 0,所以只需关心 \(\min(a_i,m)\),有用的取值只有大约 \(m\) 种
而且这些 \(a_i\) 之间并没有先后顺序关系
于是可以对这 \(m\) 种值分别开桶记录要转移多少次
每个固定的值的转移矩阵是一样的,所以乘这么多次就好了
注意开头和结尾的循环节可能是不完整的

Code:
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" value (not referenced in the code shown here).
#define INF 0x3f3f3f3f
// Capacity of the factorial / inverse tables (fac, inv, inv2) and force::buc.
#define N 100010
// Shorthand for the 64-bit arithmetic type used for modular values.
#define ll long long
// Every later `int` in this file is rewritten to `long long`; this is why the
// entry point is declared `signed main()` instead of `int main()`.
#define int long long

// Fast input: stdin is pulled into a 2 MiB buffer with fread and served from memory.
// Each byte is widened through unsigned char, so a 0xFF byte can never be confused
// with EOF (-1) and no negative value reaches the digit tests.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads the next signed decimal integer from stdin.
// Every '-' skipped before the first digit flips the sign.
// Once the input is exhausted this returns 0 instead of looping forever on EOF.
inline int read() {
	int ans=0, f=1;
	int c=getchar();
	// Skip separators; stop at a digit or at end of input.
	while (c!=EOF && !(c>='0' && c<='9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0' && c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

int n, m;
int a[10000010], A, B, P;
ll fac[N], inv[N], inv2[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1ll; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	int buc[N], ans, rest;
	void dfs(int u) {
		// cout<<"dfs "<<u<<endl;
		if (u>n) {
			if (rest) return ;
			ll tem=fac[m];
			for (int i=1; i<=n; ++i) tem=tem*inv2[buc[i]]%mod;
			ans=(ans+tem)%mod;
			return ;
		}
		for (int i=0; i<=min(a[u], rest); ++i) {
			buc[u]=i; rest-=i;
			dfs(u+1);
			rest+=i; buc[u]=0;
		}
	}
	void solve() {
		for (int i=2; i<=n; ++i) a[i]=(a[i-1]*A+B)%P+1;
		// cout<<"a: "; for (int i=1; i<=n; ++i) cout<<a[i]<<' '; cout<<endl;
		rest=m;
		dfs(1);
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task1{
	ll dp[1010][1010];
	void solve() {
		dp[0][0]=1;
		for (int i=2; i<=n; ++i) a[i]=(a[i-1]*A+B)%P+1;
		for (int i=1; i<=n; ++i) {
			for (int j=0; j<=m; ++j) {
				for (int k=max(j-a[i], 0ll); k<=j; ++k) {
					dp[i][j] = (dp[i][j]+dp[i-1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
				}
			}
		}
		printf("%lld\n", dp[n][m]);
		exit(0);
	}
}

namespace task{
	int cnt[120], now;
	bool vis[10000010];
	ll dp2[2][120];
	struct matrix{
		int n, m;
		ll a[105][105];
		matrix(){memset(a, 0, sizeof(a));}
		matrix(int x, int y){n=x; m=y; memset(a, 0, sizeof(a));}
		inline void resize(int x, int y) {n=x; m=y;}
		inline void clear() {memset(a, 0, sizeof(a));}
		inline void put() {for (int i=0; i<=n; ++i) {for (int j=0; j<=m; ++j) cout<<a[i][j]<<' '; cout<<endl;}}
		inline ll* operator [] (int t) {return a[t];}
		inline matrix operator * (matrix b) {
			matrix ans(n, b.m);
			for (int i=0; i<=n; ++i)
				for (int k=0; k<=m; ++k)
					for (int j=0; j<=b.m; ++j)
						md(ans[i][j], a[i][k]*b[k][j]%mod);
			return ans;
		}
	}mat, t;
	matrix qpow(matrix a, ll b) {
		matrix ans=a; --b;
		while (b) {
			if (b&1) ans=ans*a;
			a=a*a; b>>=1;
		}
		return ans;
	}
	void solve() {
		int st=1, ed=1;
		vis[a[1]]=1;
		for (int i=2; ; ++i,++ed) {
			a[i]=(a[i-1]*A+B)%P+1;
			if (vis[a[i]]) break;
			vis[a[i]]=1;
		}
		// cout<<"a: "; for (int i=1; i<=ed; ++i) cout<<a[i]<<' '; cout<<endl;
		ll a1=a[1]; dp2[now^1][0]=1;
		for (int i=1; i<=n&&a1!=a[ed+1]; ++i,++st,now^=1,a1=(a1*A+B)%P+1) {
			// cout<<"extra: "<<i<<endl;
			for (int j=0; j<=m; ++j) {
				dp2[now][j]=0;
				for (int k=max(j-a1, 0ll); k<=j; ++k) {
					dp2[now][j] = (dp2[now][j]+dp2[now^1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
				}
			}
		}
		// cout<<"now: "<<now<<endl;
		// printf("%lld\n", dp2[now^1][m]);
		for (int i=0; i<=m; ++i) mat[1][i]=dp2[now^1][i];
		int len=ed-st+1;
		// cout<<"len: "<<len<<endl;
		// cout<<"s&e: "<<st<<' '<<ed<<endl;
		if (n<=len) {
			for (int i=st; i<=n; ++i,now^=1,a1=(a1*A+B)%P+1) {
				// cout<<"i: "<<i<<endl;
				for (int j=0; j<=m; ++j) {
					dp2[now][j]=0;
					for (int k=max(j-a1, 0ll); k<=j; ++k) {
						dp2[now][j] = (dp2[now][j]+dp2[now^1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
					}
				}
			}
			// cout<<"pos1"<<endl;
			printf("%lld\n", dp2[now^1][m]);
			exit(0);
		}
		for (int i=st; i<=ed; ++i) cnt[min(a[i], m)]+=(n-st+1)/len; //, cout<<"cnt: "<<(n-st+1)/len<<endl;
		n=(n-st+1)%len;
		// cout<<"mod: "<<n<<endl;
		for (int i=1; i<=n; ++i) ++cnt[min(a[i+st-1], m)];
		#if 0
		for (int i=1; i<=m; ++i) {
			for (int s=1; s<=cnt[i]; ++s,now^=1) {
				for (int j=0; j<=m; ++j) {
					dp2[now][j]=0;
					for (int k=max(j-i, 0ll); k<=j; ++k) {
						dp2[now][j] = (dp2[now][j]+dp2[now^1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
					}
				}
			}
		}
		// printf("%lld\n", dp2[now^1][m]);
		#endif
		mat.resize(1, m); t.resize(m, m);
		for (int i=1; i<=m; ++i) if (cnt[i]) {
			// cout<<"cnt: "<<i<<' '<<cnt[i]<<endl;
			t.clear();
			for (int j=0; j<=m; ++j) {
				for (int k=max(j-i, 0ll); k<=j; ++k) {
					t[k][j]=fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod;
				}
			}
			// cout<<"at "<<i<<' '<<"t: "<<endl;
			// t.put(); cout<<endl;
			t=qpow(t, cnt[i]);
			mat=mat*t;
		}
		printf("%lld\n", mat[1][m]);
		exit(0);
	}
}

// Entry point. `int` is macro-expanded to `long long`, hence `signed`.
// Reads n, m and the generator parameters (a[1], A, B, P), builds the
// combinatorial tables, then runs the matrix-power solver.
signed main()
{
	freopen("sugar.in", "r", stdin);
	freopen("sugar.out", "w", stdout);

	n=read(); m=read();
	a[1]=read(); A=read(); B=read(); P=read();

	// Tables for 0..10000, built in a single pass:
	//   fac[i]  = i!
	//   inv[i]  = i^{-1} via the linear recurrence (mod%i < i is already filled)
	//   inv2[i] = (i!)^{-1}
	fac[0]=fac[1]=1; inv[0]=inv[1]=1; inv2[0]=inv2[1]=1;
	for (int i=2; i<=10000; ++i) {
		fac[i]=fac[i-1]*i%mod;
		inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		inv2[i]=inv2[i-1]*inv[i]%mod;
	}

	// Alternatives kept for cross-checking: force::solve() (brute force) and
	// task1::solve() (direct O(n m^2) DP).
	task::solve();

	return 0;
}
posted @ 2021-09-23 21:02  Administrator-09  阅读(13)  评论(1)  编辑  收藏  举报