[AGC043-D]Merge Triplets
题目
点这里看题目。
分析
我们不妨来考虑一下生成的序列有什么性质。
为了方便表示,我们将序列\(S\)的第\(i\)项写为\(S[i]\)。
首先考虑如果所有的\(A\)序列都是递增的,那么我们得到的序列肯定是递增的。如果存在递减的情况,例如其中某个序列\(B\in\{A_1,A_2,\dots,A_n\}\),存在\(B[1]>B[2]\)。那么按照取数规则,我们一旦取出了\(B[1]\),我们就一定会取出\(B[2]\)。这个比较显然。这是因为\(B[1]\)被取出的时候,其他所有序列的第一个元素肯定都大于\(B[1]\),因此也肯定大于\(B[2]\)。
我们发现只要序列中存在\(B[1]>B[2]\)或者\(B[2]>B[3]\),这两个数就会被相邻地取出;如果存在\(B[1]>B[2]\)且\(B[1]>B[3]\),这三个数也会被相邻取出。我们将这种必然相邻取出的情况分进一个组里面。
注意到组只会有长度为 1 ,长度为 2 ,长度为 3 三种,而且由于一个长度为 2 的组一定和一个长度为 1 的组成对出现,因此长度为 2 的组的数量一定不超过长度为 1 的组的数量。
我们可以发现,这样的组在构造的过程中,一定会按照组的第一个元素的大小进行排序构造出一个排列来。因此,一些组如果合法,就可以唯一确定一个排列。因此,我们可以通过计算组的合法构造方案来计算可生成的排列方案数。
我们有两种方法来解决这个问题:
1.DP
我们需要将\([1,3n]\)划分成若干组,限制如下:
1. 每一组的长度不超过3。
2. 每一组的第一个数一定是这一组中最大的。
3. 长度为 2 的组的数量不超过长度为 1 的组的数量。
因此我们可以设计如下的 DP 方案:
\(f(i,j)\):前\(i\)个数分组,满足长度为 1 的组的数量减去长度为 2 的组的数量为\(j\)的方案数。
转移实际上是考虑最后一个数会怎样分组。转移如下:
答案是\(\sum_{i=0}^{3n}f(3n,i)\)。DP 的时间是\(O(n^2)\)。
2.枚举
这其实是我自己口胡的。
建议写第一种方法。
由于\(n\)很小,我们可以直接枚举长度为 2 的组的数量和长度为 3 的组的数量(需要满足长度为 2 的组的数量不超过长度为 1 的数量这一前提)。设\(f(n)\)为\(2n\)个数全部分为长度为 2 的组的方案数。转移大概如下:
\(g(n)\)为\(3n\)个数全部分为长度为 3 的组的方案数,转移类似。这样预处理完之后就可以枚举组的数量了。答案:
时间是\(O(n^2)\)。
如果有问题请轻喷,我也没有试过这个方法。
代码
#include <cstdio>
const int MAXN = 6005;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int f[MAXN][MAXN << 1];
int N, M;
void add( int &x, const int v ) { x += v; if( x >= M ) x -= M; }
int main()
{
read( N ), read( M );
int t = N * 3;
f[0][t] = 1;
for( int i = 1 ; i <= t ; i ++ )
for( int j = - t ; j <= t ; j ++ )
{
if( j > -t ) add( f[i][j + t], f[i - 1][j + t - 1] ); //长度为1
if( j < t && i >= 2 ) add( f[i][j + t], 1ll * f[i - 2][j + t + 1] * ( i - 1 ) % M ); //长度为2
if( i >= 3 ) add( f[i][j + t], 1ll * f[i - 3][j + t] * ( i - 1 ) % M * ( i - 2 ) % M ); //长度为3
}
int ans = 0;
for( int i = 0 ; i <= t ; i ++ ) add( ans, f[t][i + t] );
write( ans ), putchar( '\n' );
return 0;
}