Codeforces Round #448 (Div. 2) C. Square Subsets [数论II][dp II][bitmasks]
题目:http://codeforces.com/contest/895/problem/C
题意:在1e5个数字中选一些数字,使他们的乘积为平方数。
题解:最大数字只有70,如果把每个数字分解为质因子,最多也只有19个可能的数字。乘积为平方数则此数字分解为的质因数个数都为偶数,可用异或转移表示为是否都为0。例如 2 为第一个素数,它的状态表示为1 如果再加入一个2可以转移到1^1=0。3为第二个素数,表示为10,以此类推。为了节约内存,使用两个数组互相滚动dp。为了简化转移,初始化dp[0][0]=1代表了空集,但是要记住在最后答案里减去这个情况。
令num[i]为第i个数字所表示的二进制质数码,转移时对每个要加入的数字i,可以手动模拟一次多个数字转移,会发现第一次会把两个状态都转移为他们的加和,后面的操作因为两个数字情况数相同故只是在做乘积,所以一次处理多个相同的数字。
对i数字个数大于1的每个状态j可以转移到 dp[j][i+1]=dp[j^num[i]][i+1]=pow(2,cnt[i]-1)*(dp[j][i]+dp[num[i]^j][i]),但是要注意例如4已经是平方数,0^j=j,他的转移条件为dp[j][i+1]=dp[j^num[i]][i+1]=pow(2,cnt[i]-1)*(dp[num[i]^j][i])。因为1无法分解,把1作为系数在最后乘答案再加上1的所有贡献即可,ans=((ans[0]-1ll)*xi+xi-1LL+mod)%mod。
官方题解有更加通用简洁的dp方法,预处理出了每种数字个数的转移状态,貌似也可以用线性筛的方法解决。
#include<bits/stdc++.h> #define pii pair<int, int> #define mod 1000000007 #define mp make_pair #define pi acos(-1) #define eps 0.00000001 #define mst(a,i) memset(a,i,sizeof(a)) #define all(n) n.begin(),n.end() #define lson(x) ((x<<1)) #define rson(x) ((x<<1)|1) #define inf 0x3f3f3f3f typedef long long ll; typedef unsigned long long ull; using namespace std; const int maxn = 1e5 + 5; vector<int>prime; int isprime[maxn]; void getprimelist(int t) { mst(isprime, 1); isprime[1] = 0; for (int i = 2; i <= t; ++i) { if (isprime[i])prime.push_back(i); for (int j = 0; j < prime.size() && prime[j] * i <= t; ++j) { isprime[prime[j] * i] = 0; if (i%prime[j] == 0)break; } } } int num[80]; ll change[524288]; int pos[80]; int cnt[80]; ll ans[524288]; ll initpow[100005]; int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); int i, j, k, m, n, T; cin >> n; getprimelist(70); for (int i = 0; i < prime.size(); ++i) pos[prime[i]] = i; initpow[0]=1; for(int i = 1;i<=100000;++i) initpow[i]=initpow[i-1]*2ll%mod; for (int i = 2; i <= 70; ++i) { int tp = i; while (tp > 1) for (auto it : prime) if (tp%it == 0) { tp /= it; num[i] ^= (1 << pos[it]); break; } } for (int i = 1; i <= n; ++i) { cin >> k; cnt[k]++; } ll xi=initpow[cnt[1]]; ans[0]=1; for (int i = 2; i <= 70; ++i) { mst(change, 0); if (!cnt[i])continue; for (int j = 0; j < 524288; ++j) { if((j^num[i])<j)continue; ll ta = ans[j]+ans[j^num[i]]; if((j^num[i])==j)ta/=2ll; ll tb = initpow[cnt[i]-1]; tb=tb*ta%mod; change[j]=(change[j]+tb)%mod,change[num[i]^j]=(tb+change[num[i]^j])%mod; } memcpy(ans,change,sizeof(ans)); } cout<<((ans[0]-1ll)*xi+xi-1LL+mod)%mod<<endl; return 0; }
官方代码:
#define _CRT_SECURE_NO_WARNINGS #include <iostream> #include <fstream> #include <string> #include <iomanip> #include <iterator> #include <bitset> #include <vector> #include <math.h> #include <queue> #include <map> #include <set> #include <list> #include <time.h> #include <algorithm> #define mkp make_pair #define inf 1000000000 #define MOD 1000000007 #define eps 1e-7 using namespace std; typedef long long ll; int n; int mask[72]; ll f[2][72]; ll dp[2][1 << 20]; bool prime(int x) { for (int i = 2; i*i <= x; i++) if (x%i == 0) return 0; return 1; } void init() { for (int i = 0; i < 72; i++) f[0][i] = 1; int cnt = 0; for (int i = 2; i < 72; i++) { if (!prime(i)) continue; for (int j = 1; j < 72; j++) { int x = j; while (x%i == 0) { x /= i; mask[j] ^= (1 << cnt); } } cnt++; } } int main() { ios_base::sync_with_stdio(0); init(); cin >> n; for (int i = 0; i < n; i++) { int x; cin >> x; f[0][x] = f[1][x] = (f[0][x] + f[1][x]) % MOD; } dp[0][0] = 1; for (int i = 0; i <= 70; i++) { int nxt = (i + 1) % 2; int cur = i % 2; for (int msk = 0; msk < (1<<20); msk++) { dp[nxt][msk^mask[i]] = dp[nxt][msk^mask[i]] + dp[cur][msk] * f[1][i]; dp[nxt][msk] = dp[nxt][msk] + dp[cur][msk] * f[0][i]; if (dp[nxt][msk^mask[i]] >= MOD) dp[nxt][msk^mask[i]] %= MOD; if (dp[nxt][msk] >= MOD) dp[nxt][msk] %= MOD; } for (int msk = 0; msk < (1<<20); msk++) dp[cur][msk] = 0; } cout << (dp[1][0] - 1 + MOD)%MOD << endl; }