CF1463F Max Correct Set
一、题目
二、解法
首先我们考虑值域序列上决策,每个位置放 \(0/1\),要求任意两个 \(1\) 之间的距离不能是 \(x/y\),由于 \(n\) 很大但是 \(x,y\) 很小,可以猜测 \(x+y\) 是原序列的一段循环节,也就是这一段的最优解可以通过复制得到 \(n\) 的最优解:
证明:若 \((x+y)|n\),因为在 \(x+y\) 这一段中已经是最优解了,而相邻两段之间有没有更多的限制,所以把它复制到 \(n\) 显然达到了答案上界,故得到了最优解。
若 \((x+y)\not| \ \ n\),设 \(r=n\%(x+y)\),那么我们把原序列再次分段,奇数段的长度是 \(r\),偶数段的长度是 \(x+y-r\),因为我们要证明存在最优解满足奇数段和偶数段都可以全一样。那么等价于证明对于任何一种数列,可以调整成上述形式之后使得答案不降。
设 \(d_i\) 为把和 \(i\) 奇偶性相同的段全部调整成 \(i\) 段的代价,因为代价会抵消所以 \(\sum d_{2i}=\sum d_{2i+1}=\sum d_i=0\),然后考虑反证法,考虑对于所有 \(i\) 有 \(d_i+d_{i+1}<0\),那么:
- 因为 \(d_2+d_3<0,d_4+d_5<0...\),所以 \(\sum_{i\not=1}d_i<0\),那么 \(d_1>0\)
- 因为 \(d_1+d_2<0,d_4+d_5<0...\),所以 \(\sum_{i\not=3} d_i<0\),那么 \(d_3>0\)
- 类似可得 \(d_{2k+1}>0\),因为至少存在一个 \(d_{2k}\geq0\),所以存在 \(d_{2k}+d_{2k+1}>0\),与原命题矛盾。
现在找 \(x+y\) 段内的最优解即可,无脑状压可以做到 \(O(2^{\max(x,y)}(x+y))\),足以通过此题
由于只有考虑 \(i\) 和 \(i+x\) 或者 \(i+y\) 的限制,那么我们把 \(i\) 向 \(p=i\pm x\bmod(x+y)\) 连边,根据数论知识如果 \(\gcd(x,y)=1\) 就会构成一个环,环上相邻两个点才会有限制。这种情况可以用 \(O(x+y)\) 的简单 \(dp\) 解决。
如果 \(\gcd(x,y)\not=1\) 呢?考虑除去它们的最大公因数 \(g\),因为如果我们按模 \(g\) 的剩余系分类,每个剩余系中是互不干扰的,那么就可以分开求最优解最后直接相加,这就转化成了 \(\gcd(x,y)=1\) 的情况。
三、总结
\(n\) 很大的时候可以考虑:找规律,矩阵乘法,循环节。
证明循环节的时候用到的调整法,就考虑调整成相同的时候代价怎么变化。
没想到数论知识还可以用来优化 \(dp\)
#include <cstdio>
#include <iostream>
using namespace std;
const int M = 50;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,x,y,cnt[M],f[M][2][2];
int gcd(int a,int b)
{
return !b?a:gcd(b,a%b);
}
int solve(int n,int x,int y)
{
int g=gcd(x,y);
if(g>1) return (n%g)*solve(n/g+1,x/g,y/g)+(g-n%g)*solve(n/g,x/g,y/g);
for(int i=0;i<(x+y);i++)
cnt[i]=(n-1)/(x+y)+((n-1)%(x+y)>=i);
f[0][1][1]=cnt[0];
for(int i=1;i<(x+y);i++)
{
int c=cnt[(i*x)%(x+y)];
if(i!=x+y-1) f[i][1][1]=f[i-1][0][1]+c;
f[i][0][1]=max(f[i-1][0][1],f[i-1][1][1]);
f[i][1][0]=f[i-1][0][0]+c;
f[i][0][0]=max(f[i-1][0][0],f[i-1][1][0]);
}
int t=x+y-1;
return max(max(f[t][0][0],f[t][1][0]),max(f[t][0][1],f[t][1][1]));
}
signed main()
{
n=read();x=read();y=read();
printf("%d\n",solve(n,x,y));
}