2021牛客多校4 H/nowcoder 11255 H Convolution

题目链接:https://ac.nowcoder.com/acm/contest/11255/H

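题目大意:定义运算符\(\otimes\):若\(x=\prod p_i^{a_i}\),\(y=\prod p_i^{b_i}\),则\(x\otimes y=\prod p_i^{|a_i-b_i|}\)(这里的\(a_i,b_i\)只是质因子\(p_i\)的指数,与下面的数列无关)。给定数列\(a_1,\dots,a_n\)和常数\(c\),定义\(b_i=\sum_{1\leq j,k\leq n,\ j\otimes k=i}a_j\cdot k^c\),求\((b_1 \bmod 998244353)\oplus(b_2 \bmod 998244353)\oplus\cdots\oplus(b_n \bmod 998244353)\),其中\(\oplus\)表示按位异或。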

题目思路:首先化简运算符\(\otimes\)

\[x\otimes y=\prod p_i^{\mid a_i-b_i\mid } = \prod p_i^{max(a_i,b_i)-min(a_i,b_i) } = \frac{lcm(x,y)}{gcd(x,y)}=\frac{xy}{gcd^2(x,y)} \]

这个\(b_i\)的求和条件不便直接枚举,利用上式将其改写为

\[b_i=\sum_{j=1}^{n}\sum_{k=1}^{n}\left[\frac{jk}{\gcd^2(j,k)}=i\right]a_j\,k^c \](其中\([P]\)在条件\(P\)成立时取\(1\),否则取\(0\))

枚举\(g=gcd(j,k)\)

\[b_i=\sum_{g=1}^{n}\sum_{j=1}^{n}\sum_{k=1}^{n}[\gcd(j,k)=g]\left[\frac{jk}{g^2}=i\right]a_j\,k^c \]

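令\(x=\frac{j}{g},y=\frac{k}{g}\),则\(\gcd(x,y)=1\),\(\frac{jk}{g^2}=xy\),且\(j=xg\le n\),\(k=yg\le n\),即\(g\le\min(\lfloor\frac{n}{x}\rfloor,\lfloor\frac{n}{y}\rfloor)\)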

\[b_i=\sum_{x=1}^{n}\sum_{y=1}^{n}[\gcd(x,y)=1][xy=i]\sum_{g=1}^{\min(\lfloor\frac{n}{x}\rfloor,\lfloor\frac{n}{y}\rfloor)}a_{xg}(yg)^c \]

由于\((yg)^c=y^c g^c\),其中\(y^c\)与\(g\)无关,可以提到关于\(g\)的求和之外:

\[b_i=\sum_{xy=i,\ \gcd(x,y)=1}y^c\sum_{g=1}^{\min(\lfloor\frac{n}{x}\rfloor,\lfloor\frac{n}{y}\rfloor)}a_{xg}\,g^c \]

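对每个\(x\),先求前缀和\(S_x(m)=\sum_{g=1}^{m}a_{xg}\,g^c\)(\(m\le\lfloor\frac{n}{x}\rfloor\))。再枚举所有满足\(xy\le n\)且\(\gcd(x,y)=1\)的\(y\),令\(b_{xy}\mathrel{+}=y^c\cdot S_x(\min(\lfloor\frac{n}{x}\rfloor,\lfloor\frac{n}{y}\rfloor))\)。枚举总量为\(\sum_{x}\frac{n}{x}=O(n\log n)\),再乘上求\(\gcd\)的对数因子。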

AC代码:

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;

// Type aliases used by the solution.
using PII = pair<int, int>;
using PDI = pair<double, int>;
// using int128 = __int128;
using ll = long long;
using ull = unsigned long long;

// Contest-template constants (only N and mod are used by this solution).
const int INF = 0x3f3f3f3f;
const int N = 1000010;      // capacity of the global arrays; indices 1..n must be < N
const int M = 40000010;
const int base = 1000000000;
const int P = 131;
const int mod = 998244353;  // modulus applied to every accumulated b_i
const double eps = 1e-9;
const double PI = acos(-1.0);

// Zero-initialized global work arrays.
ll a[N];   // input sequence a_1..a_n
ll b[N];   // b[i], accumulated modulo `mod`
ll dp[N];  // per-x prefix sums built in main; dp[0] is never written and stays 0
ll p[N];   // p[i] = i^c mod `mod`
// Greatest common divisor via the iterative Euclidean algorithm.
// gcd(a, 0) == a; main() only calls it with positive arguments.
int gcd(int a, int b)
{
    while (b != 0)
    {
        const int remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
// Fast modular exponentiation (square-and-multiply): returns a^b mod `mod`.
//   a : base; any long long value is accepted and is first reduced into [0, mod).
//   b : exponent; expected to be non-negative. A negative exponent yields
//       1 % mod (no modular inverse is taken).
// Uses O(log b) multiplications; every intermediate product is < mod^2,
// which fits in long long for mod = 998244353.
ll ksm(ll a, ll b)
{
    // Reduce the base up front. Without this step, a * a overflows long long
    // (undefined behaviour) once |a| exceeds about 3.03e9, and a negative base
    // produces negative "residues".
    a %= mod;
    if (a < 0)
        a += mod;
    ll res = 1 % mod; // 1 % mod keeps the result correct even if mod were 1
    while (b > 0)     // `while (b)` would never terminate for a negative b
    {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
int main()
{
    int n, c;
    scanf("%d%d", &n, &c);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &a[i]);
    for (int i = 1; i <= n; ++i)
        p[i] = ksm(i, c);
    for (int x = 1; x <= n; ++x)
    {
        for (int g = 1; g * x <= n; ++g)
            dp[g] = (dp[g - 1] + p[g] * a[x * g]) % mod;
        for (int y = 1; y * x <= n; ++y)
            if (gcd(x, y) == 1)
                b[x * y] = (b[x * y] + dp[min(n / x, n / y)] * p[y]) % mod;
    }
    ll ans = 0;
    for (int i = 1; i <= n; ++i)
        ans ^= b[i];
    printf("%lld\n", ans);
    return 0;
}
posted @ 2021-08-01 09:02  xiaopangpang7  阅读(147)  评论(0编辑  收藏  举报