2021牛客多校4 H/nowcoder 11255 H Convolution
题目链接:https://ac.nowcoder.com/acm/contest/11255/H
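题目大意:对 \(x=\prod p_i^{a_i}\),\(y=\prod p_i^{b_i}\)(\(p_i\) 为质数),定义运算符 \(\otimes\):\(x\otimes y=\prod p_i^{\lvert a_i-b_i\rvert}\)。给定 \(n\)、\(c\) 和序列 \(a_1,\dots,a_n\),定义 \(b_i=\sum_{1\leq j,k\leq n,\ j\otimes k=i}a_j\,k^c\),求 \((b_1 \bmod 998244353)\ \mathrm{xor}\ (b_2 \bmod 998244353)\ \mathrm{xor}\ \cdots\ \mathrm{xor}\ (b_n \bmod 998244353)\),即先把每个 \(b_i\) 对 998244353 取模,再把它们全部异或起来。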
题目思路:首先化简运算符\(\otimes\)
\[x\otimes y=\prod p_i^{\lvert a_i-b_i\rvert} = \prod p_i^{\max(a_i,b_i)-\min(a_i,b_i)} = \frac{\operatorname{lcm}(x,y)}{\gcd(x,y)}=\frac{xy}{\gcd^2(x,y)}
\]
按定义直接写出的 \(b_i\) 不便计算,先把条件 \(j\otimes k=i\) 换成上面化简后的形式:
\[b_i=\sum_{\substack{1\le j,k\le n\\ jk/\gcd^2(j,k)=i}}a_j\,k^c
\]
枚举 \(g=\gcd(j,k)\)(显然 \(1\le g\le n\)):
\[b_i=\sum_{g=1}^{n}\ \sum_{\substack{1\le j,k\le n\\ \gcd(j,k)=g,\ jk/g^2=i}}a_j\,k^c
\]
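设 \(x=\frac{j}{g}\),\(y=\frac{k}{g}\),则 \(j=xg\),\(k=yg\),且 \(\gcd(x,y)=1\)。由 \(xg\le n\)、\(yg\le n\) 得 \(g\le\min(\lfloor n/x\rfloor,\lfloor n/y\rfloor)\),于是
\[b_i=\sum_{\substack{xy=i\\ \gcd(x,y)=1}}\ \sum_{g=1}^{\min(\lfloor n/x\rfloor,\lfloor n/y\rfloor)}a_{xg}\,(yg)^c
\]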
由于 \((yg)^c=y^c g^c\),其中 \(y^c\) 与 \(g\) 无关,可以提到内层求和之外:
\[b_i=\sum_{\substack{xy=i\\ \gcd(x,y)=1}}y^c\sum_{g=1}^{\min(\lfloor n/x\rfloor,\lfloor n/y\rfloor)}a_{xg}\,g^c
\]
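具体做法:对每个 \(x\),先求前缀和 \(s_x(m)=\sum_{g=1}^{m}a_{xg}\,g^c\)(\(1\le m\le\lfloor n/x\rfloor\))。再枚举满足 \(xy\le n\) 且 \(\gcd(x,y)=1\) 的 \(y\),把 \(y^c\,s_x(\min(\lfloor n/x\rfloor,\lfloor n/y\rfloor))\) 累加到 \(b_{xy}\) 上。由调和级数,枚举到的 \((x,g)\) 对与 \((x,y)\) 对各有 \(O(n\log n)\) 个,再乘上求 \(\gcd\) 的对数开销,即可得出答案。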
AC代码:
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;

// Convenience type aliases from the author's competitive-programming template.
using PII = pair<int, int>;
using PDI = pair<double, int>;
// using int128 = __int128;
using ll = long long;
using ull = unsigned long long;

// Template constants. This solution references only N (array capacity) and
// mod (the answer modulus); the others are template leftovers.
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1000010;       // 1e6 + 10
constexpr int M = 40000010;      // 4e7 + 10
constexpr int base = 1000000000; // 1e9
constexpr int P = 131;
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);

// Global working arrays (zero-initialized because they have static storage),
// indexed 1..n by main():
//   a  - the input sequence a_1..a_n
//   b  - accumulated b_1..b_n, kept reduced modulo `mod`
//   dp - per-x prefix sums; index 0 is never written and stays 0
//   p  - p[i] = i^c modulo `mod`
ll a[N];
ll b[N];
ll dp[N];
ll p[N];
// Greatest common divisor via the iterative Euclidean algorithm.
// gcd(u, 0) == u, and gcd(0, v) == v.
// Parameters are named u/v so they do not shadow the global arrays a[]/b[].
int gcd(int u, int v)
{
    while (v != 0)
    {
        const int r = u % v;
        u = v;
        v = r;
    }
    return u;
}
// Fast modular exponentiation (square-and-multiply): returns x^e mod `mod`
// using O(log e) multiplications.
// Precondition: e >= 0. With a negative exponent, `e >>= 1` does not reach 0
// on arithmetic-shift implementations, so the loop would not terminate.
// Parameters are named x/e rather than a/b so they do not shadow the global
// arrays a[] and b[].
ll ksm(ll x, ll e)
{
    // Reduce the base into [0, mod) first. Every product below then multiplies
    // two values < mod (< 2^30), so it cannot overflow long long, even when a
    // caller passes a base >= mod or a negative base.
    x %= mod;
    if (x < 0)
        x += mod;
    ll res = 1 % mod;
    while (e)
    {
        if (e & 1)
            res = res * x % mod;
        x = x * x % mod;
        e >>= 1;
    }
    return res;
}
int main()
{
int n, c;
scanf("%d%d", &n, &c);
for (int i = 1; i <= n; ++i)
scanf("%lld", &a[i]);
for (int i = 1; i <= n; ++i)
p[i] = ksm(i, c);
for (int x = 1; x <= n; ++x)
{
for (int g = 1; g * x <= n; ++g)
dp[g] = (dp[g - 1] + p[g] * a[x * g]) % mod;
for (int y = 1; y * x <= n; ++y)
if (gcd(x, y) == 1)
b[x * y] = (b[x * y] + dp[min(n / x, n / y)] * p[y]) % mod;
}
ll ans = 0;
for (int i = 1; i <= n; ++i)
ans ^= b[i];
printf("%lld\n", ans);
return 0;
}