题解 签到 / 序列计数问题 / [UNR #2] 梦中的题面

签到
序列计数问题
[UNR #2] 梦中的题面

怎么都搞这个

image

打 nm 过 nm 爷就是要打锤子

那么首先有一个经典容斥,钦定超过限制的桶
那么就有

\[ans=\sum\limits_s(-1)^{|s|}\binom{n+m+(c-1)|s|-\sum\limits_{i\in s}b^i}{m} \]

然后这个式子就只会 \(O(2^m)\) 算,于是寄了

  • 包含未知数的组合数 \(\dbinom{ax+b}{c}\) 在保证 \(ax+b\geqslant 0\) 的情况下可以拆成关于 \(x\) 的多项式
    这样的好处是不必对每个 \(x\) 分别计算,可以直接代入 \(\sum x_i\) 求出 \(\sum\dbinom{ax_i+b}{c}\)

那么枚举钦定的 \(|s|\),\(A=n+m+(c-1)|s|\) 就确定了
那么就是要对所有满足 \(\sum\limits_{i\in s}b^i\leqslant A\) 的 \(s\),算出 \(\sum\limits_{i\in s}b^i\)
发现这里都是 \(b^i\),那将 \(A\) 写成 \(b\) 进制的话可以很方便地钦定 \(A\) 与 \(\sum\limits_{i\in s}b^i\) 的 lcp 长度
那么问题转化为求 \(g_{i, j, k}\) 表示考虑前 \(i\) 个元素,选了 \(j\) 个的权值的 \(k\) 次方和
这里要 DP \(k\) 次方和是为了方便代入多项式求组合数
那么枚举 lcp,高位方案是固定的,和低位的 DP 合并一下就好了
在 UOJ 上被 hack 后 upd:特别注意一点:考虑 \(n=b^{m+1}\) 的情况,所以 pw 之类的东西要预处理到 \(m+1\) 而不是 \(m\)
复杂度 \(O(m^4)\)(这里的规模是变量个数 \(m\),而不是上界 \(n\))

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define pb push_back
#define ll long long
#define int long long

// Input buffer for the fast reader: [p1, p2) is the not-yet-consumed part of buf.
char buf[1<<21], *p1=buf, *p2=buf;
// getchar() served from buf. When the buffer is drained it is refilled with fread;
// if fread delivers nothing the macro yields EOF. The byte is returned as
// unsigned char so that no input byte can be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads the next decimal integer from the buffered input.
// Every '-' met while skipping non-digit characters flips the sign.
// Returns 0 when the input ends before any digit is found (instead of spinning
// forever on EOF).
inline int read() {
	int ans=0, f=1;
	// the character is kept in an int so that EOF stays distinguishable and
	// isdigit() never receives a negative char
	int c=getchar();
	while (c!=EOF&&!isdigit(c)) {
		if (c=='-') f=-f;
		c=getchar();
	}
	while (isdigit(c)) {	// isdigit(EOF) is false, so this also stops at end of input
		ans=ans*10+(c-'0');
		c=getchar();
	}
	return ans*f;
}

// Modulus used by every computation in this file.
const long long mod=998244353;
// Fast exponentiation: returns a^b modulo mod by repeated squaring.
// a may be any 64-bit value; it is reduced first so that the first squaring
// cannot overflow. b is expected to be non-negative; a negative exponent is
// treated like 0 (result 1) instead of looping forever.
inline long long qpow(long long a, long long b) {
	long long ans=1;
	a%=mod;
	for (; b>0; b>>=1) {
		if (b&1) ans=ans*a%mod;
		a=a*a%mod;
	}
	return ans;
}

namespace force{
	int n, m, b, c;
	ll inv[N], up[N], ans;
	inline ll C(ll n, ll k) {
		ll ans=inv[k];
		for (int i=n; i>n-k; --i) ans=ans*i%mod;
		return ans;
	}
	void solve() {
		m=read(); b=read(); c=read(); n=read();
		inv[0]=inv[1]=1;
		for (int i=2; i<=m; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=m; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		for (int i=1; i<=m; ++i) up[i]=(qpow(b, i)-c)%mod;
		int lim=1<<m;
		n=n-1;
		for (int s=0; s<lim; ++s) {
			int sum=n;
			for (int i=1; i<=m; ++i) if (s&(1<<(i-1))) {
				int t=qpow(b, i)-c+1;
				if (sum>=t) sum=sum-t;
				else goto jump;
			}
			// C((sum.toint()+m)%mod, m)
			ans=(ans+(__builtin_popcount(s)&1?-1:1)*C(sum+m, m))%mod;
			jump: ;
		}
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

namespace force2{
	int n, m, b, c;
	ll x[N], up[N], ans;
	void dfs(int u) {
		if (u>m) {
			ll sum=0;
			for (int i=1; i<=m; ++i) sum+=x[i];
			if (sum<n) ++ans;
			return ;
		}
		for (int i=0; i<=up[u]; ++i) x[u]=i, dfs(u+1);
	}
	void solve() {
		m=read(); b=read(); c=read(); n=read();
		for (int i=1; i<=m; ++i) up[i]=(qpow(b, i)-c)%mod;
		dfs(1);
		cout<<ans<<endl;
	}
}

namespace task1{
	int m, b, c;
	ll inv[N], up[N], ans;

	struct Int{
		vector<int> a;
		Int(){}
		Int(int t){do {a.pb(t%10); t/=10;} while (t);}
		int len() {return a.size();}
		inline int& operator [] (int t) {return a[t];}
		void adjust() {while (a.size()>1&&!a.back()) a.pop_back();}
		void put() {for (int i=a.size()-1; ~i; --i) printf("%lld", a[i]); printf("\n");}
		void get() {
			char c=getchar();
			while (!isdigit(c)) c=getchar();
			while (isdigit(c)) {a.pb(c-'0'); c=getchar();}
			reverse(a.begin(), a.end());
		}
		inline Int operator + (Int b) {
			Int ans;
			int lim=max(len(), b.len())+2;
			ans.a.resize(lim);
			for (int i=0; i<lim; ++i) {
				if (i<len()) ans[i]+=a[i];
				if (i<b.len()) ans[i]+=b[i];
				ans[i+1]+=ans[i]/10;
				ans[i]%=10;
			}
			ans.adjust();
			return ans;
		}
		inline Int operator * (Int b) {
			Int ans; ans.a.resize(len()+b.len()+1);
			for (int i=0; i<len(); ++i)
				for (int j=0; j<b.len(); ++j) {
					ans[i+j]=ans[i+j]+a[i]*b[j];
					ans[i+j+1]+=ans[i+j]/10;
					ans[i+j]%=10;
				}
			ans.adjust();
			return ans;
		}
		inline Int operator - (Int b) {
			Int ans=*this;
			for (int i=0; i<b.len(); ++i) ans[i]-=b[i];
			for (int i=0; i<len(); ++i)
				if (ans[i]<0) --ans[i+1], ans[i]+=10;
			ans.adjust();
			return ans;
		}
		inline ll toint() {
			ll ans=0;
			for (int i=a.size()-1; ~i; --i) ans=(ans*10+a[i])%mod;
			return ans;
		}
		inline bool operator <= (Int b) {
			if (len()!=b.len()) return len()<b.len();
			for (int i=len()-1; ~i; --i)
				if (a[i]!=b[i]) return a[i]<b[i];
			return 1;
		}
	}n, b2, pw[100];

	inline ll C(ll n, ll k) {
		ll ans=inv[k];
		for (int i=n; i>n-k; --i) ans=ans*i%mod;
		return ans;
	}
	void solve() {
		m=read(); b=read(); c=read();
		n.get();
		inv[0]=inv[1]=1;
		for (int i=2; i<=m; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=m; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		for (int i=1; i<=m; ++i) up[i]=(qpow(b, i)-c)%mod;
		pw[1]=Int(b);
		for (int i=2; i<=m; ++i) pw[i]=pw[i-1]*Int(b);
		int lim=1<<m;
		n=n-Int(1);
		for (int s=0; s<lim; ++s) {
			Int sum=n;
			for (int i=1; i<=m; ++i) if (s&(1<<(i-1))) {
				Int t=pw[i];
				if (c-1>=0) t=t-Int(c-1);
				else if (c-1<0) t=t+Int(c-1);
				if (t<=sum) sum=sum-t;
				else goto jump;
			}
			// C((sum.toint()+m)%mod, m)
			ans=(ans+(__builtin_popcount(s)&1?-1:1)*C(sum.toint()+m, m))%mod;
			jump: ;
		}
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

// Polynomial-time solution.
// For every subset size len the bound A=(n-1)+m+(c-1)*len is fixed, and
// C(A-x, m) is expanded into a polynomial F in x. The subsets of {b^1..b^m}
// whose sum does not exceed A are grouped by the length of the common prefix of
// their sum and A in base b; g supplies the power sums of the free lower part.
namespace task{
	int m, b, c;
	ll fac[N], inv[N], pw[100], pw2[100][100], hpw[100], f[100][100], g[100][100][100], F[100], ans;
	// Binomial coefficient from the factorial tables (0<=k<=n, both inside the filled range).
	inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
	// Evaluates the polynomial with coefficients F[0..m] at x, modulo mod.
	inline ll qval(int x) {
		int now=1, ans=0;
		for (int i=0; i<=m; ++i,now=now*x%mod) ans=(ans+F[i]*now)%mod;
		return ans;
	}
	// Non-negative big integer whose digits are stored in base `base`, least
	// significant first. `base` is copied from the namespace-level b when the
	// object is constructed, so b has to be read before any Int is created.
	struct Int{
		int base;
		vector<int> a;
		Int(){base=b;}
		// Digits of a non-negative machine integer.
		Int(int t){base=b; do {a.pb(t%base); t/=base;} while (t);}
		int len() {return a.size();}
		inline int& operator [] (int t) {return a[t];}
		// Removes leading zeros but always keeps one digit.
		void adjust() {while (a.size()>1&&!a.back()) a.pop_back();}
		// Prints the digits, most significant first, followed by a newline.
		void print() {for (int i=a.size()-1; ~i; --i) printf("%lld", a[i]); printf("\n");}
		// Reads a decimal number from the input and converts it to base `base` by
		// repeated long division; every division yields one digit (the remainder).
		void scan() {
			a.clear();
			vector<int> tem[2];	// decimal digits of the current quotient and of the next one
			int now=0; char c=getchar();
			while (!isdigit(c)) c=getchar();
			while (isdigit(c)) tem[now].pb(c-'0'), c=getchar();
			for (; ; now^=1) {
				// stop as soon as the quotient has become zero
				bool nonzero=false;
				for (auto it:tem[now]) if (it) {nonzero=true; break;}
				if (!nonzero) break;
				tem[now^1].clear();
				int rest=0;
				for (auto it:tem[now]) {
					rest=rest*10+it;
					tem[now^1].pb(rest/base);
					rest%=base;
				}
				a.pb(rest);
			}
			if (a.empty()) a.pb(0);	// the input was 0: keep one digit so indexing stays valid
			adjust();
		}
		// Schoolbook addition.
		inline Int operator + (Int b) {
			Int ans;
			int lim=max(len(), b.len())+2;
			ans.a.resize(lim);
			for (int i=0; i<lim; ++i) {
				if (i<len()) ans[i]+=a[i];
				if (i<b.len()) ans[i]+=b[i];
				// the two spare digits absorb every carry, so the last slot never
				// produces one; the guard keeps the write inside the vector
				if (i+1<lim) ans[i+1]+=ans[i]/base;
				ans[i]%=base;
			}
			ans.adjust();
			return ans;
		}
		// Schoolbook multiplication with the carry pushed forward after each digit product.
		inline Int operator * (Int b) {
			Int ans; ans.a.resize(len()+b.len()+1);
			for (int i=0; i<len(); ++i)
				for (int j=0; j<b.len(); ++j) {
					ans[i+j]=ans[i+j]+a[i]*b[j];
					ans[i+j+1]+=ans[i+j]/base;
					ans[i+j]%=base;
				}
			ans.adjust();
			return ans;
		}
		// Subtraction; the caller must guarantee *this >= b (no sign is stored).
		inline Int operator - (Int b) {
			Int ans=*this;
			for (int i=0; i<b.len(); ++i) ans[i]-=b[i];
			for (int i=0; i<len(); ++i)
				if (ans[i]<0) --ans[i+1], ans[i]+=base;
			ans.adjust();
			return ans;
		}
		// Value of the number modulo mod.
		inline ll toint() {
			ll ans=0;
			for (int i=a.size()-1; ~i; --i) ans=(ans*base+a[i])%mod;
			return ans;
		}
		// Comparisons; both operands must be free of leading zeros.
		inline bool operator <= (Int b) {
			if (len()!=b.len()) return len()<b.len();
			for (int i=len()-1; ~i; --i)
				if (a[i]!=b[i]) return a[i]<b[i];
			return 1;
		}
		inline bool operator < (Int b) {
			if (len()!=b.len()) return len()<b.len();
			for (int i=len()-1; ~i; --i)
				if (a[i]!=b[i]) return a[i]<b[i];
			return 0;
		}
	};
	// Reads m, b, c and the big number n, prints the answer modulo mod.
	// Assumes 1<=m<99 and b>=2 (array sizes and the base-b representation rely on it).
	void solve() {
		m=read(); b=read(); c=read();
		Int n; n.scan();
		// fac[i]=i!, inv[i]=1/i! (inv holds plain inverses after the second loop)
		fac[0]=fac[1]=1; inv[0]=inv[1]=1; pw[0]=1;
		for (int i=2; i<=m+1000; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=m+1000; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=m+1000; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		// pw[i]=b^i, pw2[i][j]=(b^i)^j, both modulo mod
		for (int i=1; i<=m; ++i) pw[i]=pw[i-1]*b%mod;
		for (int i=0; i<=m; ++i) {pw2[i][0]=1; for (int j=1; j<=m; ++j) pw2[i][j]=pw2[i][j-1]*pw[i]%mod;}
		n=n-Int(1);	// the sum has to be strictly below n, i.e. at most n-1
		// g[i][j][k]: over all ways to pick j of the values b^1..b^i, the sum of
		// (picked total)^k. Adding b^i to a total is expanded with the binomial theorem.
		g[0][0][0]=1;
		for (int i=1; i<=m; ++i)
			for (int j=0; j<=m; ++j)
				for (int k=0; k<=m; ++k) {
					g[i][j][k]=g[i-1][j][k];
					if (!j) continue;	// nothing can be picked into an empty selection (and j-1 would index -1)
					for (int t=0; t<=k; ++t)
						g[i][j][k]=(g[i][j][k]+C(k, t)*g[i-1][j-1][t]%mod*pw2[i][k-t])%mod;
				}
		for (int len=0; len<=m; ++len) {
			// A=(n-1)+m+(c-1)*len; Int is unsigned, so the sign of c-1 is handled by hand
			Int A;
			if (c<=0 && n+Int(m)<Int(1-c)*Int(len)) continue;	// A would be negative: no term
			if (c<=0) A=n+Int(m)-Int(1-c)*Int(len);
			else A=n+Int(m)+Int(c-1)*Int(len);
			ll a=A.toint(); int siz=A.len()-1;
			// f[i] holds the coefficients of (a-x)(a-1-x)...(a-i-x); F is that product
			// for i=m-1 divided by m!, i.e. C(a-x, m) as a polynomial in x
			memset(f, 0, sizeof(f));
			f[0][1]=-1, f[0][0]=a;
			for (int i=1; i<m; ++i)
				for (int j=0; j<=m; ++j)
					f[i][j]=((a-i)*f[i-1][j]-(j?f[i-1][j-1]:0))%mod;
			for (int i=0; i<=m; ++i) F[i]=f[m-1][i]*inv[m]%mod;
			ll sum, x;
			if (siz>m) {
				// A>=b^(m+1) is larger than b^1+...+b^m, so every subset of size len
				// qualifies. There is no power above b^m to select, hence the digit
				// walk below must not be entered with a position beyond m.
				sum=0;
				for (int k=0; k<=m; ++k) sum=(sum+F[k]*g[m][len][k])%mod;
				ans=(ans+(len&1?-1:1)*sum)%mod;
				continue;
			}
			hpw[0]=1;
			// high: value of the already fixed upper digits, cnt: how many powers they use
			ll high=0, cnt=0;
			for (int i=siz; ~i&&cnt<=len; --i) {
				// digit 0 of a subset sum is always 0, so an exact prefix match is itself valid
				if (!i) {if (cnt==len) ans=(ans+(len&1?-1:1)*qval(high))%mod; break;}
				// the subset digit at position i is 0 or 1; every choice below A[i]
				// leaves the lower positions completely free
				for (int j=0; j<min(A[i], 2ll); ++j) {
					if (j==1) ++cnt, high=(high+pw[i])%mod;
					if (cnt>len) break;
					sum=0;
					for (int k=1; k<=m; ++k) hpw[k]=hpw[k-1]*high%mod;
					for (int k=0; k<=m; ++k) {
						// sum of (high+low)^k over all admissible lower parts
						x=0;
						for (int t=0; t<=k; ++t) x=(x+C(k, t)*hpw[t]%mod*g[i-1][len-cnt][k-t])%mod;
						sum=(sum+F[k]*x)%mod;
					}
					ans=(ans+(len&1?-1:1)*sum)%mod;
				}
				// keep matching A only while its digit can be produced by a 0/1 choice
				if (A[i]==1) ++cnt, high=(high+pw[i])%mod;
				if (A[i]>1) break;
			}
		}
		printf("%lld\n", (ans%mod+mod)%mod);
	}
}

// Entry point: redirects the standard streams to the task's files and runs the
// polynomial-time solution.
signed main()
{
	// Stop right away when a stream cannot be redirected; otherwise the reader
	// would keep polling an unusable stdin and nothing would ever be written.
	if (!freopen("checkin.in", "r", stdin)) {
		fprintf(stderr, "cannot open checkin.in\n");
		return 1;
	}
	if (!freopen("checkin.out", "w", stdout)) {
		fprintf(stderr, "cannot open checkin.out\n");
		return 1;
	}

	// Reference implementations kept for cross-checking:
	// force::solve();
	// task1::solve();
	// force2::solve();
	task::solve();

	return 0;
}
posted @ 2022-03-23 18:55  Administrator-09  阅读(18)  评论(0编辑  收藏  举报