题解 糖果
是个矩阵快速幂,但因为没写出来基础DP根本无从快速幂……
首先基础DP(Yubai优秀写法):
令 \(dp[i][j]\) 为考虑到第 \(i\) 种糖果,已经确定了 \(j\) 个糖果顺序关系的方案数
考虑原来(前 \(i-1\) 种糖果共 \(k\) 个)的顺序关系方案数是 \(\frac{k!}{\prod x_t!}\),我们想让它成为 \(\frac{j!}{x_i!\times\prod x_t!}\)(其中 \(x_i=j-k\)),那乘个 \(\frac{j^{\underline{j-k}}}{x_i!}=\frac{j!}{k!\,(j-k)!}\) 即可
于是转移就显然了
然后发现这个 \(a_i\) 是有循环节的,但循环节很长
考虑转移中 \(k\) 的上下界 \(\max(j-a_i,0)\le k\le j\):因为 \(j\le m\),与0取max后大于 \(m\) 的 \(a_i\) 都和 \(m\) 等价,所以有用的 \(a_i\)(即 \(\min(a_i,m)\))只有大约 \(m\) 种取值
而且这些 \(a_i\) 之间并没有先后顺序关系
于是可以对这 \(m\) 种值分别开桶记录要转移多少次
每个固定的值的转移矩阵是一样的,所以乘这么多次就好了
注意开头和结尾的循环节可能是不完整的
Code:
#include <bits/stdc++.h>
using namespace std;
// Large sentinel constant; it is not referenced by any code visible in this file.
#define INF 0x3f3f3f3f
// Capacity of the factorial / inverse tables (fac, inv, inv2) and of force::buc.
#define N 100010
// Shorthand for the 64-bit signed integer type used for modular arithmetic.
#define ll long long
// From here on every `int` in the file is really `long long`;
// this is why the entry point has to be spelled `signed main`.
#define int long long
// Block-buffered replacement for getchar(): refills `buf` with fread() when it is
// exhausted and yields EOF once stdin has no more data. Bytes are returned as
// unsigned char values so that a 0xFF byte can never be mistaken for EOF.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads the next (optionally negative) decimal integer from stdin.
// Every non-digit character before the number is skipped; each '-' among them
// flips the sign. If the input ends before any digit is found the function
// returns 0 instead of looping forever.
inline int read() {
    int ans=0, f=1;
    // Keep the character in an integer so EOF stays distinguishable and
    // isdigit() is never called with a negative char value.
    int c=getchar();
    while (c!=EOF && !isdigit(c)) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (c!=EOF && isdigit(c)) {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// n: number of sequence elements a[1..n] processed by the solvers;
// m: the target total whose count is printed (dp[n][m] / mat[1][m]). Both are read in main.
int n, m;
// a[1] is read from input; later elements follow a[i] = (a[i-1]*A + B) % P + 1.
// A, B, P are the generator parameters read in main.
int a[10000010], A, B, P;
// fac[i] = i!, inv[i] = modular inverse of i, inv2[i] = inverse of i! (all mod `mod`).
// main fills indices 0..10000 only.
ll fac[N], inv[N], inv2[N];
// Prime modulus for all arithmetic.
const ll mod=998244353;
// In-place modular accumulate: adds b to a, then subtracts mod once if the sum
// reached it. Presumably both operands are already in [0, mod) — a single
// subtraction is all that is performed.
inline void md(ll& a, ll b) {
    a += b;
    if (a >= mod) a -= mod;
}
// Binary exponentiation: returns a^b modulo mod, consuming b one bit at a time
// from the least significant end.
inline ll qpow(ll a, ll b) {
    ll result = 1ll;
    while (b) {
        if (b & 1) result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}
namespace force{
int buc[N], ans, rest;
void dfs(int u) {
// cout<<"dfs "<<u<<endl;
if (u>n) {
if (rest) return ;
ll tem=fac[m];
for (int i=1; i<=n; ++i) tem=tem*inv2[buc[i]]%mod;
ans=(ans+tem)%mod;
return ;
}
for (int i=0; i<=min(a[u], rest); ++i) {
buc[u]=i; rest-=i;
dfs(u+1);
rest+=i; buc[u]=0;
}
}
void solve() {
for (int i=2; i<=n; ++i) a[i]=(a[i-1]*A+B)%P+1;
// cout<<"a: "; for (int i=1; i<=n; ++i) cout<<a[i]<<' '; cout<<endl;
rest=m;
dfs(1);
printf("%lld\n", ans);
exit(0);
}
}
namespace task1{
ll dp[1010][1010];
void solve() {
dp[0][0]=1;
for (int i=2; i<=n; ++i) a[i]=(a[i-1]*A+B)%P+1;
for (int i=1; i<=n; ++i) {
for (int j=0; j<=m; ++j) {
for (int k=max(j-a[i], 0ll); k<=j; ++k) {
dp[i][j] = (dp[i][j]+dp[i-1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
}
}
}
printf("%lld\n", dp[n][m]);
exit(0);
}
}
namespace task{
int cnt[120], now;
bool vis[10000010];
ll dp2[2][120];
struct matrix{
int n, m;
ll a[105][105];
matrix(){memset(a, 0, sizeof(a));}
matrix(int x, int y){n=x; m=y; memset(a, 0, sizeof(a));}
inline void resize(int x, int y) {n=x; m=y;}
inline void clear() {memset(a, 0, sizeof(a));}
inline void put() {for (int i=0; i<=n; ++i) {for (int j=0; j<=m; ++j) cout<<a[i][j]<<' '; cout<<endl;}}
inline ll* operator [] (int t) {return a[t];}
inline matrix operator * (matrix b) {
matrix ans(n, b.m);
for (int i=0; i<=n; ++i)
for (int k=0; k<=m; ++k)
for (int j=0; j<=b.m; ++j)
md(ans[i][j], a[i][k]*b[k][j]%mod);
return ans;
}
}mat, t;
matrix qpow(matrix a, ll b) {
matrix ans=a; --b;
while (b) {
if (b&1) ans=ans*a;
a=a*a; b>>=1;
}
return ans;
}
void solve() {
int st=1, ed=1;
vis[a[1]]=1;
for (int i=2; ; ++i,++ed) {
a[i]=(a[i-1]*A+B)%P+1;
if (vis[a[i]]) break;
vis[a[i]]=1;
}
// cout<<"a: "; for (int i=1; i<=ed; ++i) cout<<a[i]<<' '; cout<<endl;
ll a1=a[1]; dp2[now^1][0]=1;
for (int i=1; i<=n&&a1!=a[ed+1]; ++i,++st,now^=1,a1=(a1*A+B)%P+1) {
// cout<<"extra: "<<i<<endl;
for (int j=0; j<=m; ++j) {
dp2[now][j]=0;
for (int k=max(j-a1, 0ll); k<=j; ++k) {
dp2[now][j] = (dp2[now][j]+dp2[now^1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
}
}
}
// cout<<"now: "<<now<<endl;
// printf("%lld\n", dp2[now^1][m]);
for (int i=0; i<=m; ++i) mat[1][i]=dp2[now^1][i];
int len=ed-st+1;
// cout<<"len: "<<len<<endl;
// cout<<"s&e: "<<st<<' '<<ed<<endl;
if (n<=len) {
for (int i=st; i<=n; ++i,now^=1,a1=(a1*A+B)%P+1) {
// cout<<"i: "<<i<<endl;
for (int j=0; j<=m; ++j) {
dp2[now][j]=0;
for (int k=max(j-a1, 0ll); k<=j; ++k) {
dp2[now][j] = (dp2[now][j]+dp2[now^1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
}
}
}
// cout<<"pos1"<<endl;
printf("%lld\n", dp2[now^1][m]);
exit(0);
}
for (int i=st; i<=ed; ++i) cnt[min(a[i], m)]+=(n-st+1)/len; //, cout<<"cnt: "<<(n-st+1)/len<<endl;
n=(n-st+1)%len;
// cout<<"mod: "<<n<<endl;
for (int i=1; i<=n; ++i) ++cnt[min(a[i+st-1], m)];
#if 0
for (int i=1; i<=m; ++i) {
for (int s=1; s<=cnt[i]; ++s,now^=1) {
for (int j=0; j<=m; ++j) {
dp2[now][j]=0;
for (int k=max(j-i, 0ll); k<=j; ++k) {
dp2[now][j] = (dp2[now][j]+dp2[now^1][k]*fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod)%mod;
}
}
}
}
// printf("%lld\n", dp2[now^1][m]);
#endif
mat.resize(1, m); t.resize(m, m);
for (int i=1; i<=m; ++i) if (cnt[i]) {
// cout<<"cnt: "<<i<<' '<<cnt[i]<<endl;
t.clear();
for (int j=0; j<=m; ++j) {
for (int k=max(j-i, 0ll); k<=j; ++k) {
t[k][j]=fac[j]%mod*inv2[k]%mod*inv2[j-k]%mod;
}
}
// cout<<"at "<<i<<' '<<"t: "<<endl;
// t.put(); cout<<endl;
t=qpow(t, cnt[i]);
mat=mat*t;
}
printf("%lld\n", mat[1][m]);
exit(0);
}
}
// Entry point: redirects stdin/stdout to the task files, reads the six input
// integers, precomputes the factorial tables and runs the full solver
// (which prints the answer and exits on its own).
signed main()
{
    freopen("sugar.in", "r", stdin);
    freopen("sugar.out", "w", stdout);

    n=read();
    m=read();
    a[1]=read();
    A=read();
    B=read();
    P=read();

    // Base cases for i = 0 and i = 1.
    for (int i=0; i<2; ++i) {
        fac[i]=1;
        inv[i]=1;
        inv2[i]=1;
    }
    // fac[i] = i!, inv[i] = i^{-1}, inv2[i] = (i!)^{-1}, all modulo mod.
    // inv[i] only needs inv[mod%i] with mod%i < i, so one pass suffices.
    for (int i=2; i<=10000; ++i) {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
        inv2[i]=inv2[i-1]*inv[i]%mod;
    }

    // Alternative solvers kept for cross-checking:
    // force::solve();
    // task1::solve();
    task::solve();
    return 0;
}