Codeforces Round #448 (Div. 2) C. Square Subsets [数论II][dp II][bitmasks]

题目:http://codeforces.com/contest/895/problem/C

题意:在1e5个数字中选一些数字,使他们的乘积为平方数。

题解:最大数字只有70,如果把每个数字分解为质因子,最多也只有19个可能的数字。乘积为平方数则此数字分解为的质因数个数都为偶数,可用异或转移表示为是否都为0。例如 2 为第一个素数,它的状态表示为1 如果再加入一个2可以转移到1^1=0。3为第二个素数,表示为10,以此类推。为了节约内存,使用两个数组互相滚动dp。为了简化转移,初始化dp[0][0]=1代表了空集,但是要记住在最后答案里减去这个情况。

令num[i]为第i个数字所表示的二进制质数码,转移时对每个要加入的数字i,可以手动模拟一次多个数字转移,会发现第一次会把两个状态都转移为他们的加和,后面的操作因为两个数字情况数相同故只是在做乘积,所以一次处理多个相同的数字。

对i数字个数大于1的每个状态j可以转移到 dp[j][i+1]=dp[j^num[i]][i+1]=pow(2,cnt[i]-1)*(dp[j][i]+dp[num[i]^j][i]),但是要注意例如4已经是平方数,0^j=j,他的转移条件为dp[j][i+1]=dp[j^num[i]][i+1]=pow(2,cnt[i]-1)*(dp[num[i]^j][i])。因为1无法分解,把1作为系数在最后乘答案再加上1的所有贡献即可,ans=((ans[0]-1ll)*xi+xi-1LL+mod)%mod。

官方题解有更加通用简洁的dp方法,预处理出了每种数字个数的转移状态,貌似也可以用线性筛的方法解决。

#include<bits/stdc++.h>
#define pii pair<int, int>
#define mod 1000000007
#define mp make_pair
#define pi acos(-1)
#define eps 0.00000001
#define mst(a,i) memset(a,i,sizeof(a))
#define all(n) n.begin(),n.end()
#define lson(x) ((x<<1))
#define rson(x) ((x<<1)|1)
#define inf 0x3f3f3f3f
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int maxn = 1e5 + 5;
vector<int>prime;
int isprime[maxn];
void getprimelist(int t)
{
    mst(isprime, 1);
    isprime[1] = 0;
    for (int i = 2; i <= t; ++i)
    {
        if (isprime[i])prime.push_back(i);
        for (int j = 0; j < prime.size() && prime[j] * i <= t; ++j)
        {
            isprime[prime[j] * i] = 0;
            if (i%prime[j] == 0)break;
        }
    }
}
int num[80];
ll change[524288];
int pos[80];
int cnt[80];
ll ans[524288];
ll initpow[100005];
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    int i, j, k, m, n, T;
    cin >> n;
    getprimelist(70);
    for (int i = 0; i < prime.size(); ++i)
        pos[prime[i]] = i;
    initpow[0]=1;
    for(int i = 1;i<=100000;++i)
        initpow[i]=initpow[i-1]*2ll%mod;
    for (int i = 2; i <= 70; ++i)
    {
        int tp = i;
        while (tp > 1)
            for (auto it : prime)
                if (tp%it == 0)
                {
                    tp /= it;
                    num[i] ^= (1 << pos[it]);
                    break;
                }
    }
    for (int i = 1; i <= n; ++i)
    {
        cin >> k;
        cnt[k]++;
    }
    ll xi=initpow[cnt[1]];
    ans[0]=1;
    for (int i = 2; i <= 70; ++i)
    {
        mst(change, 0);
        if (!cnt[i])continue;
        for (int j = 0; j < 524288; ++j)
        {
            if((j^num[i])<j)continue;
            ll ta = ans[j]+ans[j^num[i]];
            if((j^num[i])==j)ta/=2ll;
            ll tb = initpow[cnt[i]-1];
            tb=tb*ta%mod;
            change[j]=(change[j]+tb)%mod,change[num[i]^j]=(tb+change[num[i]^j])%mod;
        }
        memcpy(ans,change,sizeof(ans));
    }
    cout<<((ans[0]-1ll)*xi+xi-1LL+mod)%mod<<endl;
    return 0;
}

官方代码:

#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <fstream>
#include <string>
#include <iomanip>
#include <iterator>
#include <bitset>
#include <vector>
#include <math.h>
#include <queue>
#include <map>
#include <set>
#include <list>
#include <time.h>
#include <algorithm>
#define mkp make_pair
#define inf 1000000000
#define MOD 1000000007
#define eps 1e-7
 
using namespace std;
typedef long long ll;
 
int n;
int mask[72];
ll f[2][72];
ll dp[2][1 << 20];
 
bool prime(int x)
{
    for (int i = 2; i*i <= x; i++)
        if (x%i == 0)
            return 0;
    return 1;
}
 
void init()
{
    for (int i = 0; i < 72; i++)
        f[0][i] = 1;
    int cnt = 0;
    for (int i = 2; i < 72; i++)
    {
        if (!prime(i))
            continue;
        for (int j = 1; j < 72; j++)
        {
            int x = j;
            while (x%i == 0)
            {
                x /= i;
                mask[j] ^= (1 << cnt);
            }
        }
        cnt++;
    }
}
 
int main()
{
    ios_base::sync_with_stdio(0);
    init();
    cin >> n;
    for (int i = 0; i < n; i++)
    {
        int x;
        cin >> x;
        f[0][x] = f[1][x] = (f[0][x] + f[1][x]) % MOD;
    }
    dp[0][0] = 1;
    for (int i = 0; i <= 70; i++)
    {
        int nxt = (i + 1) % 2;
        int cur = i % 2;
        for (int msk = 0; msk < (1<<20); msk++)
        {
            dp[nxt][msk^mask[i]] = dp[nxt][msk^mask[i]] + dp[cur][msk] * f[1][i];
            dp[nxt][msk] = dp[nxt][msk] + dp[cur][msk] * f[0][i];
            if (dp[nxt][msk^mask[i]] >= MOD)
                dp[nxt][msk^mask[i]] %= MOD;
            if (dp[nxt][msk] >= MOD)
                dp[nxt][msk] %= MOD;
        }
        for (int msk = 0; msk < (1<<20); msk++)
            dp[cur][msk] = 0;
    }
    cout << (dp[1][0] - 1 + MOD)%MOD << endl;
}

 

posted @ 2017-11-28 17:09  Meternal  阅读(348)  评论(0编辑  收藏  举报