题解 CF840C On the Bench
如果\(a\times b\)是完全平方数,\(a\times c\)是完全平方数,那么\(a^2\times b\times c\)也是完全平方数,又因为\(a^2\)是完全平方数,所以\(b\times c\)也是完全平方数。也就是说,完全平方数具有传递性。据此,我们可以把给定的序列划分为若干个集合,使得同一集合内的数两两的乘积都是完全平方数,不同集合内的任意两个数乘积都不是完全平方数。
于是,问题转化为,\(n\)个数分为\(m\)个集合。求有多少种\(n\)个数的排列方式,使得任意相邻两个数不来自同一集合。
设集合\(i\)的大小为\(sz[i]\),前\(i\)个集合的大小之和为\(sum[i]\)。
DP。逐个集合考虑。设\(dp[i][j]\),表示考虑了前\(i\)个集合,有\(j\)个不合法的位置的方案数。这里“不合法的位置”,是指某个位置上的数,和它前一个位置上的数,来自同一个集合。显然,有\(j\)个不合法的位置,也意味着序列被划分成了\(sum[i]-j\)个块,每个块内的元素来自相同的集合,相邻两块的元素来自不同的集合。
注意,在这个DP状态的定义下,我们认为,这\(sum[i]\)个数是紧贴着放置的,但是在转移时,可能会在上一状态下两个紧贴着的位置之间插入一些数。
转移时,我们先考虑把当前集合,分成多少个块。设这个块数为\(k\)(\(1\leq k\leq sz[i]\)),枚举\(k\)。显然,把\(x\)个元素划分为\(k\)个块并排在一起的方案数为\(x!\cdot {x-1\choose k-1}\)(插板法)。这样,当前集合就划分好了。考虑如何把这划分出的\(k\)个块插入前面已经排列好的序列中。
前面的序列里已经有了\(sum[i-1]-j\)个块。如果在不同块之间插入,或者在整个序列的最前面/最后面插入,都不会使不合法位置的数量(即DP数组的第二维)减少;否则,在原先某一块的内部插入时,会使一个本来不合法的数,和它前面的位置隔开,于是这个数就变得合法了。我们枚举新增的\(k\)个块里,有\(l\)块是第一种情况,剩下的块是第二种情况。显然,\(l\leq k\)且\(l\leq sum[i-1]-j+1\)。
这样,我们会转移到\(dp[i][j+(sz[i]-k)-(k-l)]\)。其中\((sz[i]-k)\)是新增的不合法的位置,\((k-l)\)是因为新增的数把原来的数隔开,从而使得不合法位置的数量减少了。
于是,可以写出转移式:
答案就是\(dp[m][0]\)。
下面来分析一下时间复杂度。我们一共枚举了\(i\),\(j\),\(k\),\(l\)四个东西,看起来是\(O(n^4)\)。但是注意到,\(i\)遍历的是集合数量,\(k\)遍历的是当前集合的大小,\(i\),\(k\)套起来是\(O(\sum_{t=1}^{m}sz[t])=O(n)\)的。所以实际的时间复杂度为\(O(n^3)\)。
参考代码:
//problem:CF840C
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
if(S==T){
T=(S=buf)+fread(buf,1,MAXN,stdin);
if(S==T)return EOF;
}
return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#endif
template<typename T>inline void read(T& x){
x=0;int f=1;
char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+(ch-'0'),ch=getchar();
x*=f;
}
/* ------ by:duyi ------ */ // myt天下第一
const int MAXN=300,MOD=1e9+7;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int &x,int y){x=mod1(x+y);}
inline void sub(int &x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}
int fac[MAXN+5],ifac[MAXN+5],n,a[MAXN+5],fa[MAXN+5],sz[MAXN+5],s[MAXN+5],cnt,dp[MAXN+5][MAXN+5];
inline int comb(int n,int k){
if(n<k)return 0;
return (ll)fac[n]*ifac[k]%MOD*ifac[n-k]%MOD;
}
int get_fa(int x){return (x==fa[x])?x:(fa[x]=get_fa(fa[x]));}
void union_s(int x,int y){
x=get_fa(x);y=get_fa(y);
if(x!=y){
if(sz[x]>sz[y])swap(x,y);
fa[x]=y;sz[y]+=sz[x];
}
}
bool is_square(ll x){ll t=sqrt(x);return t*t==x;}
int g(int n,int k){return (ll)fac[n]*comb(n-1,k-1)%MOD;}
int main() {
fac[0]=1;
for(int i=1;i<=MAXN;++i)fac[i]=(ll)fac[i-1]*i%MOD;
ifac[MAXN]=pow_mod(fac[MAXN],MOD-2);
for(int i=MAXN-1;i>=0;--i)ifac[i]=(ll)ifac[i+1]*(i+1)%MOD;
read(n);
for(int i=1;i<=n;++i)fa[i]=i,sz[i]=1;
for(int i=1;i<=n;++i){
read(a[i]);
for(int j=1;j<i;++j)if(is_square((ll)a[i]*a[j]))union_s(i,j);
}
for(int i=1;i<=n;++i)if(get_fa(i)==i)s[++cnt]=sz[i];
//for(int i=1;i<=cnt;++i)cout<<s[i]<<" ";cout<<endl;
dp[1][s[1]-1]=fac[s[1]];
for(int i=2,sum=s[1];i<=cnt;++i){
for(int j=0;j<=sum-1;++j)if(dp[i-1][j]){
//有j个不合法的元素: 被分成了sum-j个连续段
for(int k=1;k<=s[i]&&k<=sum+1;++k){
for(int l=max(0,k-j);l<=sum-j+1&&l<=k;++l){
assert(j+s[i]-k-(k-l)>=0);
add(dp[i][j+(s[i]-k)-(k-l)],(ll)dp[i-1][j]*g(s[i],k)%MOD*comb(sum-j+1,l)%MOD*comb(j,k-l)%MOD);
}
}
}
sum+=s[i];
}
cout<<dp[cnt][0]<<endl;
//int i,j;while(1){read(i);read(j);cout<<dp[i][j]<<endl;}
return 0;
}