题解 UOJ #62. 【UR #5】怎样跑得更快
首先先把 \(\text{lcm}\) 拆出来:
\[\sum_{j=1}^n\gcd(i,j)^{c-d}j^dx_j=\frac{b_i}{i^d}
\]
此后我们令 \(x_i\) 为 \(i^dx_i\),\(b_i\) 为 \(\frac{b_i}{i^d}\),于是有:
\[\sum_{j=1}^n\gcd(i,j)^{c-d}x_j=b_i
\]
现在先考虑约尔当函数:
\[J_k(n)=id_k(n)\otimes \mu
\]
它满足性质:
\[J_k(n)\otimes 1=id_k
\]
然后拿着它拆一下:
\[\begin{aligned}
b_i&=\sum_{j=1}^n\gcd(i,j)^{c-d}x_j\\
&=\sum_{j=1}^nx_j\sum_{k\mid\gcd(i,j)}J_{c-d}(k)\\
&=\sum_{k\mid i}J_{c-d}(k)\sum_{j=1}^{n/k}x_{jk}
\end{aligned}
\]
现在令
\[G(k)=\sum_{j=1}^{n/k}x_{jk}
\]
于是我们有:
\[\sum_{k\mid i}J_{c-d}(k)G(k)=b_i\Leftrightarrow J_{c-d}\cdot G=b\otimes\mu
\]
于是我们现在可以得到 \(G\) 的各项值了,然后做一个枚举超集的莫比乌斯反演:
\[f(d)=\sum_{d\mid n}g(n)\Rightarrow g(d)=\sum_{d\mid n}\mu(\frac nd)f(n)
\]
就可以得到最后的答案了。
无解的情况就是非零数除以了零,判掉就行了。
#include<cstdio>
typedef long long ll;
const ll mod = 998244353;
const int maxn = 1E+5 + 5;
int n, c, d, q;
ll mu[maxn], J[maxn], idk[maxn];
ll b[maxn], g[maxn];
int cnt, prime[maxn];
bool nprime[maxn];
inline ll fsp(ll a, ll b, const ll &mod = mod, ll res = 1) {
for(a %= mod, b %= mod - 1; b; a = a * a % mod, b >>= 1)
if(b & 1) res = res * a % mod; return res;
}
inline void pre(int N) {
mu[1] = 1, idk[1] = 1;
for(int i = 2; i <= N; ++i) {
if(!nprime[i]) prime[++cnt] = i, mu[i] = -1, idk[i] = fsp(i, c - d + mod - 1);
for(int j = 1; j <= cnt && i * prime[j] <= N; ++j) {
nprime[i * prime[j]] = 1, idk[i * prime[j]] = idk[i] * idk[prime[j]] % mod;
if(i % prime[j]) mu[i * prime[j]] = -mu[i];
else { mu[i * prime[j]] = 0; break; }
}
}
for(int d = 1; d <= N; ++d)
for(int i = 1; i * d <= N; ++i)
(J[i * d] += idk[d] * mu[i]) %= mod;
}
int main() {
scanf("%d%d%d%d", &n, &c, &d, &q);
c %= mod - 1, d %= mod - 1, pre(n);
while(q --> 0) {
for(int i = 1; i <= n; ++i) {
scanf("%lld", &b[i]), g[i] = 0;
b[i] = b[i] * fsp(fsp(i, mod - 2), d) % mod;
}
for(int d = 1; d <= n; ++d)
for(int i = 1; i * d <= n; ++i)
(g[i * d] += b[d] * mu[i]) %= mod;
bool flag = 1;
for(int i = 1; i <= n; ++i) {
if(g[i] && !J[i]) { flag = 0; break; }
g[i] = g[i] * fsp(J[i], mod - 2) % mod;
}
if(!flag) printf("-1");
else for(int i = 1; i <= n; ++i) {
ll res = 0;
for(int k = 1; k * i <= n; ++k) (res += g[k * i] * mu[k]) %= mod;
printf("%lld ", (res * fsp(fsp(i, mod - 2), d) % mod + mod) % mod);
} putchar('\n');
}
}