题解 签到 / 序列计数问题 / [UNR #2] 梦中的题面
怎么都搞这个
打 nm 过 nm 爷就是要打锤子
那么首先有一个经典容斥:钦定一个桶的集合 \(s\),强制 \(s\) 中每个桶都超过上限
那么就有
\[ans=\sum\limits_s(-1)^{|s|}\binom{n+m+(c-1)|s|-\sum\limits_{i\in s}b^i}{m}
\]
然后这个式子就只会 \(O(2^m)\) 算,于是寄了
- 包含未知数的组合数 \(\dbinom{ax+b}{c}\) 在保证 \(ax+b\geqslant 0\) 的情况下可以拆成关于 \(x\) 的多项式
这样的好处是不必对每个 \(x\) 分别计算,可以直接代入 \(\sum x_i\) 求出 \(\sum\dbinom{ax_i+b}{c}\)
那么枚举钦定 \(|s|\),\(A=n+m+(c-1)|s|\) 就确定了
那么就是要对所有满足 \(\sum\limits_{i\in s}b^i\leqslant A\) 的 \(s\),求出 \(\sum\limits_{i\in s}b^i\) 的各次幂之和(再代入上面的多项式)
发现这里都是 \(b^i\),那将 \(A\) 写成 \(b\) 进制的话可以很方便地钦定 \(A\) 与 \(\sum\limits_{i\in s}b^i\) 的 lcp 长度
那么问题转化为求 \(g_{i, j, k}\):考虑前 \(i\) 个元素 \(b^1,\dots,b^i\),选出 \(j\) 个时,所选元素之和的 \(k\) 次方之和
这里要 DP \(k\) 次方和是为了方便代入多项式求组合数
那么枚举 lcp,高位方案是固定的(记其和为 \(h\)),低位任意选,用二项式定理 \((h+x)^k=\sum_t\binom{k}{t}h^tx^{k-t}\) 和低位的 DP 合并一下就好了
在 UOJ 上被 hack 后 upd:特别注意一点:考虑 \(n=b^{m+1}\) 的情况,所以 pw 之类的东西要预处理到 \(m+1\) 而不是 \(m\)
复杂度 \(O(m^4)\)(不计高精度进制转换的开销)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used throughout the file.
#define INF 0x3f3f3f3f
// Capacity of the fixed-size global arrays (inv[N], up[N], fac[N], ...).
#define N 100010
#define pb push_back
#define ll long long
// NOTE: from here on every `int` is really `long long` (which is why main is `signed main`).
#define int long long
// Input buffer: stdin is read in chunks of up to 2^21 bytes into buf; [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Replaces getchar(): refills buf with fread when it is exhausted and yields EOF once fread returns 0 bytes.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads one decimal integer from the buffered input. Every '-' seen before the first
// digit flips the sign. If the input ends before a digit appears, returns 0 instead of
// spinning forever on EOF.
inline int read() {
	int ans=0, f=1, ch=getchar();
	while (ch!=EOF && !isdigit(static_cast<unsigned char>(ch))) {if (ch=='-') f=-f; ch=getchar();}
	while (ch!=EOF && isdigit(static_cast<unsigned char>(ch))) {ans=ans*10+(ch-'0'); ch=getchar();}
	return ans*f;
}
const long long mod=998244353;
// Binary exponentiation: returns a^b modulo `mod` for b >= 0.
inline long long qpow(long long a, long long b) {
	long long res=1;
	while (b) {
		if (b&1) res=res*a%mod;  // this bit of the exponent is set
		a=a*a%mod;               // square the base for the next bit
		b>>=1;
	}
	return res;
}
namespace force{
// Brute-force cross-check: inclusion-exclusion over all 2^m subsets of violated
// bounds, with n held in a machine integer (small inputs only).
int n, m, b, c;
ll inv[N], up[N], ans;
// C(n, k) mod p as n(n-1)...(n-k+1) * inv[k]; inv[k] must already hold (k!)^{-1}.
inline ll C(ll n, ll k) {
	ll res=inv[k];
	int v=n;
	while (v>n-k) res=res*v%mod, --v;
	return res;
}
void solve() {
	m=read(); b=read(); c=read(); n=read();
	// inv[i] <- i^{-1} mod p, then prefix products turn it into (i!)^{-1}.
	inv[0]=inv[1]=1;
	for (int i=2; i<=m; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for (int i=2; i<=m; ++i) inv[i]=inv[i-1]*inv[i]%mod;
	for (int i=1; i<=m; ++i) up[i]=(qpow(b, i)-c)%mod;
	--n;
	for (int s=0; s<(1<<m); ++s) {
		int rest=n;
		bool ok=true;
		for (int i=1; i<=m && ok; ++i) {
			if (!((s>>(i-1))&1)) continue;
			// Forcing x_i above its bound consumes b^i - c + 1 units.
			int cost=qpow(b, i)-c+1;
			if (rest>=cost) rest-=cost;
			else ok=false;
		}
		if (!ok) continue;
		int sign=(__builtin_popcount(s)&1)?-1:1;
		ans=(ans+sign*C(rest+m, m))%mod;
	}
	cout<<(ans%mod+mod)%mod<<endl;
}
}
namespace force2{
int n, m, b, c;
ll x[N], up[N], ans;
void dfs(int u) {
if (u>m) {
ll sum=0;
for (int i=1; i<=m; ++i) sum+=x[i];
if (sum<n) ++ans;
return ;
}
for (int i=0; i<=up[u]; ++i) x[u]=i, dfs(u+1);
}
void solve() {
m=read(); b=read(); c=read(); n=read();
for (int i=1; i<=m; ++i) up[i]=(qpow(b, i)-c)%mod;
dfs(1);
cout<<ans<<endl;
}
}
namespace task1{
// Reference solver for small m: inclusion-exclusion over all 2^m subsets s of the
// bounds, keeping n as a decimal big integer.
int m, b, c;
ll inv[N], up[N], ans;
// Non-negative decimal big integer; a[0] is the least significant digit.
struct Int{
	vector<int> a;
	Int(){}
	// Build from a NON-NEGATIVE machine integer (a negative value would yield negative digits).
	Int(int t){do {a.pb(t%10); t/=10;} while (t);}
	int len() {return a.size();}
	inline int& operator [] (int t) {return a[t];}
	// Strip leading zeros, keeping at least one digit.
	void adjust() {while (a.size()>1&&!a.back()) a.pop_back();}
	void put() {for (int i=a.size()-1; ~i; --i) printf("%lld", a[i]); printf("\n");}
	// Read a decimal number from input. Leading zeros are stripped because the
	// comparison operator compares lengths first.
	void get() {
		char c=getchar();
		while (!isdigit(c)) c=getchar();
		while (isdigit(c)) {a.pb(c-'0'); c=getchar();}
		reverse(a.begin(), a.end());
		adjust();
	}
	inline Int operator + (Int b) {
		Int ans;
		int lim=max(len(), b.len())+2;
		// lim+1 slots: the last loop iteration writes its carry into ans[lim].
		ans.a.resize(lim+1);
		for (int i=0; i<lim; ++i) {
			if (i<len()) ans[i]+=a[i];
			if (i<b.len()) ans[i]+=b[i];
			ans[i+1]+=ans[i]/10;
			ans[i]%=10;
		}
		ans.adjust();
		return ans;
	}
	inline Int operator * (Int b) {
		Int ans; ans.a.resize(len()+b.len()+1);
		for (int i=0; i<len(); ++i)
			for (int j=0; j<b.len(); ++j) {
				ans[i+j]=ans[i+j]+a[i]*b[j];
				ans[i+j+1]+=ans[i+j]/10;
				ans[i+j]%=10;
			}
		ans.adjust();
		return ans;
	}
	// Requires *this >= b (there is no sign handling).
	inline Int operator - (Int b) {
		Int ans=*this;
		for (int i=0; i<b.len(); ++i) ans[i]-=b[i];
		for (int i=0; i<len(); ++i)
			if (ans[i]<0) --ans[i+1], ans[i]+=10;
		ans.adjust();
		return ans;
	}
	// Value modulo `mod`.
	inline ll toint() {
		ll ans=0;
		for (int i=a.size()-1; ~i; --i) ans=(ans*10+a[i])%mod;
		return ans;
	}
	// Assumes both operands are normalized (no leading zeros).
	inline bool operator <= (Int b) {
		if (len()!=b.len()) return len()<b.len();
		for (int i=len()-1; ~i; --i)
			if (a[i]!=b[i]) return a[i]<b[i];
		return 1;
	}
}n, b2, pw[100];
// C(n, k) mod p as n(n-1)...(n-k+1) * inv[k]; inv[k] must already hold (k!)^{-1}.
inline ll C(ll n, ll k) {
	ll ans=inv[k];
	for (int i=n; i>n-k; --i) ans=ans*i%mod;
	return ans;
}
void solve() {
	m=read(); b=read(); c=read();
	n.get();
	// inv[i] <- i^{-1} mod p, then prefix products turn it into (i!)^{-1}.
	inv[0]=inv[1]=1;
	for (int i=2; i<=m; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for (int i=2; i<=m; ++i) inv[i]=inv[i-1]*inv[i]%mod;
	for (int i=1; i<=m; ++i) up[i]=(qpow(b, i)-c)%mod;
	// pw[i] = b^i as a big integer.
	pw[1]=Int(b);
	for (int i=2; i<=m; ++i) pw[i]=pw[i-1]*Int(b);
	int lim=1<<m;
	n=n-Int(1);
	for (int s=0; s<lim; ++s) {
		Int sum=n;
		for (int i=1; i<=m; ++i) if (s&(1<<(i-1))) {
			// Forcing x_i above its bound costs b^i - c + 1 = b^i - (c - 1).
			// Int() only represents non-negative values, so add 1 - c when c < 1.
			Int t=pw[i];
			if (c>=1) t=t-Int(c-1);
			else t=t+Int(1-c);
			if (t<=sum) sum=sum-t;
			else goto jump;
		}
		ans=(ans+(__builtin_popcount(s)&1?-1:1)*C(sum.toint()+m, m))%mod;
		jump: ;
	}
	cout<<(ans%mod+mod)%mod<<endl;
}
}
namespace task{
int m, b, c;
ll fac[N], inv[N], pw[100], pw2[100][100], hpw[100], f[100][100], g[100][100][100], F[100], ans;
inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
inline ll qval(int x) {
int now=1, ans=0;
for (int i=0; i<=m; ++i,now=now*x%mod) ans=(ans+F[i]*now)%mod;
return ans;
}
struct Int{
int base;
vector<int> a;
Int(){base=b;}
Int(int t){base=b; do {a.pb(t%base); t/=base;} while (t);}
int len() {return a.size();}
inline int& operator [] (int t) {return a[t];}
void adjust() {while (a.size()>1&&!a.back()) a.pop_back();}
void print() {for (int i=a.size()-1; ~i; --i) printf("%lld", a[i]); printf("\n");}
void scan() {
// cout<<"scan"<<endl;
a.clear();
vector<int> tem[2];
int now=0; char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) tem[now].pb(c-'0'), c=getchar();
// cout<<"tem: "; for (auto it:tem[now]) cout<<it<<' '; cout<<endl;
for (; ; now^=1) {
// cout<<"div: "; for (auto it:tem[now]) cout<<it<<' '; cout<<endl;
for (auto it:tem[now]) if (it) goto jump; break; jump: ;
tem[now^1].clear();
int rest=0;
for (auto it:tem[now]) {
rest=rest*10+it;
tem[now^1].pb(rest/base);
rest%=base;
}
a.pb(rest);
// cout<<"rest: "<<rest<<endl;
}
adjust();
// cout<<"ans: "; for (auto it:a) cout<<it<<' '; cout<<endl;
}
inline Int operator + (Int b) {
Int ans;
int lim=max(len(), b.len())+2;
ans.a.resize(lim);
for (int i=0; i<lim; ++i) {
if (i<len()) ans[i]+=a[i];
if (i<b.len()) ans[i]+=b[i];
ans[i+1]+=ans[i]/base;
ans[i]%=base;
}
ans.adjust();
return ans;
}
inline Int operator * (Int b) {
Int ans; ans.a.resize(len()+b.len()+1);
for (int i=0; i<len(); ++i)
for (int j=0; j<b.len(); ++j) {
ans[i+j]=ans[i+j]+a[i]*b[j];
ans[i+j+1]+=ans[i+j]/base;
ans[i+j]%=base;
}
ans.adjust();
return ans;
}
inline Int operator - (Int b) {
Int ans=*this;
for (int i=0; i<b.len(); ++i) ans[i]-=b[i];
for (int i=0; i<len(); ++i)
if (ans[i]<0) --ans[i+1], ans[i]+=base;
ans.adjust();
return ans;
}
inline ll toint() {
ll ans=0;
for (int i=a.size()-1; ~i; --i) ans=(ans*base+a[i])%mod;
return ans;
}
inline bool operator <= (Int b) {
if (len()!=b.len()) return len()<b.len();
for (int i=len()-1; ~i; --i)
if (a[i]!=b[i]) return a[i]<b[i];
return 1;
}
inline bool operator < (Int b) {
if (len()!=b.len()) return len()<b.len();
for (int i=len()-1; ~i; --i)
if (a[i]!=b[i]) return a[i]<b[i];
return 0;
}
};
void solve() {
m=read(); b=read(); c=read();
Int n; n.scan(); //n.print();
fac[0]=fac[1]=1; inv[0]=inv[1]=1; pw[0]=1;
for (int i=2; i<=m+1000; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=m+1000; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=m+1000; ++i) inv[i]=inv[i-1]*inv[i]%mod;
for (int i=1; i<=m; ++i) pw[i]=pw[i-1]*b%mod;
for (int i=0; i<=m; ++i) {pw2[i][0]=1; for (int j=1; j<=m; ++j) pw2[i][j]=pw2[i][j-1]*pw[i]%mod;}
n=n-Int(1);
g[0][0][0]=1;
for (int i=1; i<=m; ++i)
for (int j=0; j<=m; ++j)
for (int k=0; k<=m; ++k) {
g[i][j][k]=g[i-1][j][k];
for (int t=0; t<=k; ++t)
g[i][j][k]=(g[i][j][k]+C(k, t)*g[i-1][j-1][t]%mod*pw2[i][k-t])%mod, assert(pw2[i][k-t]==qpow(qpow(b, i), k-t));
}
for (int len=0; len<=m; ++len) {
Int A;
if (c<=0 && n+Int(m)<Int(1-c)*Int(len)) continue;
if (c<=0) A=n+Int(m)-Int(1-c)*Int(len);
else A=n+Int(m)+Int(c-1)*Int(len);
ll a=A.toint(); int siz=A.len()-1;
memset(f, 0, sizeof(f));
f[0][1]=-1, f[0][0]=a;
for (int i=1; i<m; ++i)
for (int j=0; j<=m; ++j)
f[i][j]=((a-i)*f[i-1][j]-f[i-1][j-1])%mod;
for (int i=0; i<=m; ++i) F[i]=f[m-1][i]*inv[m]%mod;
// cout<<"F(x): "; for (int i=0; i<=m; ++i) cout<<F[i]<<' '; cout<<endl;
// cout<<"A: "; A.print();
// for (int i=0; i<=m; ++i) assert((qval(i)%mod+mod)%mod==C(a-i, m));
hpw[0]=1;
ll high=0, now_high, cnt=0, x, sum, now;
// cout<<"siz: "<<siz<<endl;
for (int i=siz; ~i&&cnt<=len; --i) {
// cout<<"i: "<<i<<' '<<cnt<<' '<<len<<endl;
if (!i) {if (cnt==len) ans=(ans+(len&1?-1:1)*qval(high))%mod; break;}
for (int j=0; j<min(A[i], 2ll); ++j) {
if (j==1) ++cnt, high=(high+pw[i])%mod;
if (cnt>len) break;
// cout<<"j: "<<j<<endl;
sum=0;
for (int k=1; k<=m; ++k) hpw[k]=hpw[k-1]*high%mod;
for (int k=0; k<=m; ++k) {
x=0;
for (int t=0; t<=k; ++t) x=(x+C(k, t)*hpw[t]%mod*g[i-1][len-cnt][k-t])%mod; //, cout<<g[i-1][len-cnt][k-t]<<endl;
sum=(sum+F[k]*x)%mod;
}
// cout<<"len: "<<len<<' '<<i<<' '<<sum<<endl;
// cout<<cnt<<' '<<len<<endl;
ans=(ans+(len&1?-1:1)*sum)%mod;
}
if (A[i]==1) ++cnt, high=(high+pw[i])%mod;
if (A[i]>1) break;
}
}
// cout<<qval(0)<<endl;
printf("%lld\n", (ans%mod+mod)%mod);
}
}
signed main()
{
	// Contest-style file I/O. Stop if either file cannot be opened; otherwise the
	// solver would run against a closed input stream.
	if (!freopen("checkin.in", "r", stdin)) {perror("checkin.in"); return 1;}
	if (!freopen("checkin.out", "w", stdout)) {perror("checkin.out"); return 1;}
	// Alternative solvers, kept for cross-checking on small inputs:
	// force::solve();
	// task1::solve();
	// force2::solve();
	task::solve();
	return 0;
}