题目链接

https://atcoder.jp/contests/agc035/tasks/agc035_e

题解

没想出来最后一步DP宛如智障……
考虑一个数\(x\notin S\)的条件是\(x\)被删除了且在\(x\)最后一次被删除之后不能再对\(x+2\)\(x-K\)进行删除操作。也就是说\(x+2\)\(x-K\)的最晚删除时间要比\(x\)晚。那么我们从\(x\)\(x+2\)\(x-K\)连边,形成的图如果有环那么这个方案就不合法,否则合法。
如果\(K\)是偶数,显然很好算;如果\(K\)是奇数,有环的充要条件是存在一个环从某点\(a\)出发先后经过\(1\)\(+K\)边、若干条\(-2\)边、\(1\)\(+K\)边、若干条\(-2\)边,且经过的\(-2\)边的总数是\(K\). 这也就意味着我们需要让选出的点中不存在这样的一个环。
如果我们把图换一种方式看,把图排成一个每行\(2\)列的形状,让左边的奇数\(x\)和右边的偶数\(x+K\)在一行,从上往下每一层的两个(或一个)数是它上面一层对应的数\(+2\), 那么刚才提到的环就相当于若干条向上边、\(1\)条向右边、若干条向上边,向上边的总数是\(K\),最终从一个向左下的边下来。于是我们的目标就是要让“若干条向上边、\(1\)条向右边、若干条向上边”构成的最长的链长度不超过\((K+1)\).
这个可以从上往下DP: 设\(f[i][j][k]\)表示前\(i\)层,“从该层左边开始经过若干条向上边、\(1\)条向右边、若干条向上边的最长链”长度为\(j\),“从该层右边开始经过若干条向上边的最长链”长度为\(k\). 转移讨论一下两边都不选、选左、选右、左右都选即可。
时间复杂度\(O(n^3)\).

代码

#include<bits/stdc++.h>
#define llong long long
#define pii pair<int,int>
#define riterator reverse_iterator
using namespace std;

inline int read()
{
	int x = 0,f = 1; char ch = getchar();
	for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
	for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
	return x*f;
}

const int N = 150;
llong P;
int n,m;

void updsum(llong &x,llong y) {x = x+y>=P?x+y-P:x+y;}

namespace Solve1
{
	llong f[N+3];
	void solve()
	{
		m>>=1; f[0] = 1ll;
		for(int i=1; i<=n; i++)
		{
			for(int j=max(0,i-m-1); j<i; j++)
			{
				updsum(f[i],f[j]);
			}
		}
		llong ans1 = 0ll,ans2 = 0ll;
		for(int i=max(0,(n>>1)-m); i<=(n>>1); i++) updsum(ans1,f[i]);
		for(int i=max(0,(n+1>>1)-m); i<=(n+1>>1); i++) updsum(ans2,f[i]);
//		printf("ans1=%lld ans2=%lld\n",ans1,ans2);
		printf("%lld\n",ans1*ans2%P);
	}
}

namespace Solve2
{
	llong f[N+3][N+3][N+3];
	void solve()
	{
		f[0][0][0] = 1ll;
		for(int i=1; i+i-m<=n; i++)
		{
			for(int j=0; j<=m+1; j++) for(int k=0; k<=n; k++)
			{
				updsum(f[i][0][0],f[i-1][j][k]);
			}
			if(i+i<=n)
			{
				for(int j=0; j<=m+1; j++) for(int k=0; k<=n; k++)
				{
					updsum(f[i][0][k+1],f[i-1][j][k]);
				}
			}
			if(i+i-m>=1)
			{
				for(int j=0; j<=m; j++) for(int k=0; k<=n; k++)
				{
					updsum(f[i][j+(j>0)][0],f[i-1][j][k]);
				}
			}
			if(i+i<=n&&i+i-m>=1)
			{
				for(int j=0; j<=m; j++) for(int k=0; k<=n; k++)
				{
					int jj = max(j+1,k+2);
					if(jj<=m+1)
					{
						updsum(f[i][jj][k+1],f[i-1][j][k]);
					}
				}
			}
		}
		llong ans = 0ll;
		for(int j=0; j<=m+1; j++) for(int k=0; k<=n; k++)
		{
			updsum(ans,f[(n+m)>>1][j][k]);
		}
		printf("%lld\n",ans);
	}
}

int main()
{
	scanf("%d%d%lld",&n,&m,&P);
	if(!(m&1)) {Solve1::solve();}
	else {Solve2::solve();}
	return 0;
}