题解 [AHOI2022] 山河重整
考虑什么时候一个数 \(i\) 不能被拼出来
发现肯定是要用 \([1, i-1]\) 中的数去拼,但不能被拼出来似乎有很多种情况欸?
那么考虑第一个不能被拼出来的数 \(i\)
这个数有很好的性质是 \([1, i-1]\) 中被选择的数的和是 \(i-1\) 且 \([1, i-1]\) 中的数都能被拼出来
对这个东西做一个 \(O(n^2)\) 的 DP 可以有 60 pts
然后再回来考虑怎么优化
发现这个 \([1, i-1]\) 中的数都能被拼出来的限制十分难搞
那么容斥掉这个限制
和是 \(i-1\) 就是简单背包,然后枚举第一个不能被拼出来的数
令 \(f_i\) 为在 \([1, i]\) 中选若干个两两不同的数和为 \(i\) 且 \([1, i]\) 中的数都可以被拼出的方案数
令 \(g_i\) 为在 \([1, i]\) 中选若干个两两不同的数和为 \(i\) 的方案数
有
\[f_i=g_i-\sum\limits_jf_j\times 在\ [j+2, i]\ 中选若干个数使其和为\ j-i\ 的方案数
\]
直接做还是 \(n^2\) 的
- 关于一类特殊背包问题的根号优化:
能优化的条件是选的数是根号级别的
考虑一个类似 Ferrers 图像的拆分
暴力背包是每次加入一列,但我们发现一行最多只有根号个格子,所以每次加入一行
转移是类似的,注意为了保证合法行要从大到小加入
然后回到这个题,发现后面还有一个系数我们不会算
但是这个系数是背包的形式,考虑套用上面的优化
上面的优化加入一种方案的方式是加入最上面那一行,并钦定这一行有 \(i\) 列
现在我们改为加入一个基座,钦定这个基座有一列高为 \(j\) 且有 \(i\) 列高为 \(j+2\)
发现计算一个 \(f_i\) 的时候要求 \(f_{1\cdots \frac{i}{2}}\) 都计算好了
那么弄一个类似分治 NTT 的东西就好了
复杂度 \(O(n\sqrt n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 500010
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
ll mod;
inline ll md(ll a) {return a>=mod?a-mod:a;}
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
namespace force{
int able[1<<20], ans;
void solve() {
int lim=1<<n;
for (int s=1; s<lim; ++s) {
int sum=0;
for (int i=1; i<=n; ++i) if (s&(1<<i-1)) sum+=i;
if (sum<=n) able[s]|=1<<sum-1;
for (int i=0; i<n; ++i) if (s&(1<<i))
able[s]|=able[s^(1<<i)];
if (able[s]==lim-1) ++ans;
}
cout<<ans%mod<<endl;
}
}
namespace task{
ll f[N], g[N], ans;
void solve(int n) {
// cout<<"solve: "<<n<<endl;
if (n==1) return ;
solve(n>>1);
for (int i=0; i<=n; ++i) g[i]=0;
for (int i=n; i; --i) if (1ll*i*(i+1)/2<=n) {
for (int j=n; j>=i; --j) g[j]=g[j-i];
for (int j=0; j<=n&&(j+(j+2ll)*i)<=n; ++j)
md(g[j+(j+2ll)*i], f[j]);
for (int j=i; j<=n; ++j) md(g[j], g[j-i]);
}
for (int i=n/2+1; i<=n; ++i) f[i]=md(f[i]+mod-g[i]);
}
void solve() {
for (int i=n; i; --i) if (1ll*i*(i+1)/2<=n) {
for (int j=n; j>=i; --j) f[j]=f[j-i];
md(f[i], 1);
for (int j=i; j<=n; ++j) md(f[j], f[j-i]);
}
f[0]=1; solve(n);
// cout<<"f: "; for (int i=1; i<=n; ++i) cout<<f[i]<<' '; cout<<endl;
for (int i=0; i<n; ++i) ans=(ans+f[i]*qpow(2, n-i-1))%mod;
printf("%lld\n", ((qpow(2, n)-ans)%mod+mod)%mod);
}
}
signed main()
{
freopen("rebuild.in", "r", stdin);
freopen("rebuild.out", "w", stdout);
n=read(); mod=read();
// force::solve();
task::solve();
return 0;
}