2021杭电多校2 1006/HDU 6966 I love sequences
题目链接:https://acm.hdu.edu.cn/showproblem.php?pid=6966
题目大意:计算\(\sum_{p=1}^n\sum_{k=1}^{+\infty }d_{p,k}*c_{p}^k\),其中\(d_{p,k}=\sum_{k=i\oplus j}a_i*b_j\ (1\leq i\leq \lfloor\frac{n}{p}\rfloor,1\leq j\leq \lfloor\frac{n}{p}\rfloor)\),\(\oplus\)运算为三进制下按位取\(gcd\),即\(k_{t}=gcd(i_{t},j_{t})\)(约定\(gcd(0,x)=x\))
题目思路:对每个\(p\),直接对长度为\(\lfloor\frac{n}{p}\rfloor\)的序列做一次FWT来计算\(d_{p,k}\),总时间复杂度为\(\sum_{p=1}^nO({n\over p}\log{n\over p})=O(n\log^2 n)\)
难点在于构造\(3\times3\)的变换矩阵
\[\begin{pmatrix}
c(0,0) &c(0,1) &c(0,2) \\
c(1,0) &c(1,1) &c(1,2) \\
c(2,0) &c(2,1) &c(2,2)
\end{pmatrix}
\]
使其满足\(c(x,y)c(x,z) = c(x,y\oplus z) = c (x,gcd(y,z))\)。这样变换后逐点相乘就恰好对应\(\oplus\)卷积;又因为\(\oplus\)在各个三进制位上独立,只需构造单个三进制位上的\(3\times3\)矩阵。
由于\(gcd(y,y)=y\),有\(c(x,y)^2=c(x,y)\),所以矩阵中每个元素只能取\(0\)或\(1\)。约束条件与\(x\)无关,各行可以独立考虑,不妨取\(x = 0\)。除了自己乘自己,其他需要满足的情况有
\[\left\{\begin{matrix}
c(0,0)c(0,1)=c(0,1)\\
c(0,0)c(0,2)=c(0,2)\\
c(0,1)c(0,2)=c(0,1)
\end{matrix}\right.
\]
直接暴力枚举\((c(0,0),c(0,1),c(0,2))\in\{0,1\}^3\),满足条件的有四组:\((0,0,0)\)、\((1,0,0)\)、\((1,0,1)\)、\((1,1,1)\)。矩阵的每一行都必须取其中一组;要使矩阵可逆,三行必须互不相同且不能是全零行,所以取矩阵
\[\begin{pmatrix}
1&0&0\\
1&1&1\\
1&0&1
\end{pmatrix}
\]
及其逆
\[\begin{pmatrix}
1&0&0\\
0&1&-1\\
-1&0&1
\end{pmatrix}
\]
AC代码:
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
// Shorthand type aliases shared by the solution (several are unused here).
using PII = std::pair<int, int>;
using PDI = std::pair<double, int>;
// using int128 = __int128;  // disabled
using ll = long long;
using ull = unsigned long long;

// Contest constants. Of these, only N (array bound) and mod (the modulus
// for all arithmetic) are used by this solution.
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 4e6 + 10, M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
constexpr int mod = 1e9 + 7;
constexpr double eps = 1e-12;
const double PI = std::acos(-1.0);

// a, b, c: input sequences, read 1-indexed by main (index 0 stays zero).
// aa, bb, cc: scratch buffers overwritten by every call to cal().
int a[N], b[N], c[N];
int aa[N], bb[N], cc[N];
// In-place transform for the digit-wise base-3 gcd convolution.
//   a    : array of n entries; n must be a power of three (cal() passes such
//          a length), otherwise the 2*len offset below runs past the end.
//   n    : transform length.
//   flag : 1 = forward transform, any other value = inverse transform.
// For each ternary digit the forward step applies the matrix
//   [1 0 0; 1 1 1; 1 0 1],   i.e. (x, y, z) -> (x, x+y+z, x+z),
// and the inverse step applies its inverse
//   [1 0 0; 0 1 -1; -1 0 1], i.e. (x, y, z) -> (x, y-z, z-x).
// Sums/differences are formed in long long so unreduced (non-negative) int
// inputs cannot overflow; every slot that is rewritten ends up in [0, mod).
void FWT(int a[], int n, int flag)
{
    for (int len = 1; len < n; len *= 3)
        for (int i = 0; i < n; i += len * 3)
            for (int j = 0; j < len; ++j)
            {
                // Slots whose current ternary digit is 0, 1 and 2 respectively.
                int &s0 = a[i + j], &s1 = a[i + j + len], &s2 = a[i + j + 2 * len];
                const long long x = s0, y = s1, z = s2;
                if (flag == 1)
                {
                    // Row 0 is (1 0 0): s0 keeps its value x.
                    s1 = (int)((x + y + z) % mod);
                    s2 = (int)((x + z) % mod);
                }
                else
                {
                    // Row 0 is (1 0 0): s0 keeps its value x.
                    // "% mod + mod) % mod" keeps the result non-negative.
                    s1 = (int)(((y - z) % mod + mod) % mod);
                    s2 = (int)(((z - x) % mod + mod) % mod);
                }
            }
}
int cal(int n, int c)
{
int tot = 1;
while (tot <= n)
tot *= 3;
for (int i = 0; i <= n; ++i)
{
aa[i] = a[i];
bb[i] = b[i];
}
for (int i = n + 1; i < tot; ++i)
{
aa[i] = 0;
bb[i] = 0;
}
FWT(aa, tot, 1), FWT(bb, tot, 1);
for (int i = 0; i < tot; ++i)
cc[i] = (ll)aa[i] * bb[i] % mod;
FWT(cc, tot, -1);
int res = 0, ck = c;
for (int k = 1; k < tot; ++k, ck = (ll)ck * c % mod)
res = (res + (ll)cc[k] * ck) % mod;
return res;
}
// Reads n, then the sequences a, b, c (each 1-indexed, n values), and prints
// sum_{p=1}^{n} cal(floor(n / p), c[p]) modulo mod.
int main()
{
    int n;
    scanf("%d", &n);

    // Reads n integers into dst[1..n]; dst[0] is left untouched (zero).
    auto read_array = [n](int *dst) {
        for (int i = 1; i <= n; ++i)
            scanf("%d", &dst[i]);
    };
    read_array(a);
    read_array(b);
    read_array(c);

    // Each p contributes cal() on the prefix of length floor(n / p), weighted by c[p].
    long long total = 0;
    for (int p = 1; p <= n; ++p)
        total = (total + cal(n / p, c[p])) % mod;

    printf("%d\n", (int)total);
    return 0;
}