CodeForces 840C - On the Bench | Codeforces Round #429 (Div. 1)

思路来自FXXL中的某个链接

/*
CodeForces 840C - On the Bench [ DP ]  |  Codeforces Round #429 (Div. 1)
题意:
	给出一个数组,问有多少种下标排列,使得任意两个相邻元素的乘积不是完全平方数
分析:
	将数组分组,使得每组中的任意两个数之积为完全平方数
	由唯一分解定理可知,每个质因子的幂次的奇偶性相同的两个数之积为完全平方数
		即按每个质因子的幂次的奇偶性分组,故这样的分组唯一
	然后问题归结于每组中的数不能相邻的排列有几种
	设 dp[i][j]表示 前i组相邻的同组的数有j对
	考虑把第i+1组分段后插入前i组的空隙中
	枚举将下一组分成k段,每段相邻
	枚举k段中有l段插在前面j对同组的空隙中
	设前i组总个数为sum, 第i+1组个数为num
	则得到转移方程
		dp[i+1][j-l+num-k] += C(num-1, k-1) * C(j, l) * C(sum+1-j, k-l) * dp[i][j]
	组合数什么的仔细推导下,再最后乘上每组的排列数
*/
#include <bits/stdc++.h>
using namespace std;
#define LL long long
const int MOD = 1e9+7;
const int N = 305;
LL C[N][N], F[N];
void init() {
    C[0][0] = 1;
    for (int i = 1; i < N; i++) {
        C[i][0] = C[i][i] = 1;
        for (int j = 1; j < i; j++)
            C[i][j] = (C[i-1][j] + C[i-1][j-1]) % MOD;
    }
    F[0] = 1;
    for (int i = 1; i < N; i++) F[i] = i * F[i-1] % MOD;
}
bool check(LL a, LL b)
{
    LL l = 1, r = 1e10, mid;
    while (l <= r)
    {
        mid = (l+r) >> 1;
        if (mid*mid <= a*b) l = mid+1;
        else r = mid-1;
    }
    return r*r == a*b;
}
LL dp[N][N], ans;
int n, a[N], id[N], num[N], cnt;
void solve()
{
    dp[0][0] = 1;
    int sum = 0;
    for (int i = 1; i <= cnt; i++)//第i组
    {
        for (int j = 0; j <= sum; j++)//j处平方
            for (int k = 1; k <= num[i]; k++)//num[i]分成k段
                for (int l = 0; l <= j && l <= k; l++)//j 中 l 段
                {
                    LL tmp = dp[i-1][j];
                    tmp = tmp * C[num[i]-1][k-1] % MOD;
                    tmp = tmp * C[j][l] % MOD;
                    tmp = tmp * C[sum+1-j][k-l] % MOD;
                    dp[i][j-l+num[i]-k] += tmp;
                    dp[i][j-l+num[i]-k] %= MOD;
                }
        sum += num[i];
    }
    ans = dp[cnt][0];
    for (int i = 1; i <= cnt; i++) ans = ans * F[num[i]] % MOD;
}
int main()
{
    init();
    cnt = 0;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        bool flag = 0;
        for (int j = 1; j < i; j++)
        {
            if (check(a[i], a[j]))
            {
                num[id[j]]++;
                id[i] = id[j];
                flag = 1; break;
            }
        }
        if (!flag)
        {
            id[i] = ++cnt;
            num[cnt] = 1;
        }
    }
    solve();
    printf("%lld\n", ans);
}

  

posted @ 2017-08-23 17:33  nicetomeetu  阅读(278)  评论(0编辑  收藏  举报