P9838
数据点分治。
考虑这个式子事实上是什么,其实它是 \(\sum i(n+1-i) f(p_i)\)。
感性看,有相当多种排列的值相同。\(f(x)=i\) 有 \(\frac{n}{2^i}\) 组解,所以,本质不同但值相同的 \(p\) 非常多,至少是 \(\prod (\frac{n}{2^i})!\)。
题解告诉我们 \(n=29\) 时这个数值就大于 \(10^{18}\) 了,因此 \(n \geq 29\) 时等价于 \(k=1\)。
先考虑这个 \(k=1\)。
\(f\) 序列是定值,\(i(n+1-i)\) 序列也是定值,根据排序不等式,我们要把大的 \(f\) 放到边角。
一次放下一段同样的数字,那么我们需要快速求 \(\sum_{i=1}^r i(n+1-i)\)。这是一个三次函数。
接着考虑 \(n < 29\)。
注意到 \(f\) 只有 \(5\) 种值,往后填时我们不关心前面怎么填,只关心每种值用了几个。因此 dp:\(dp_{a, b, c, d, e, x}\) 表示前 \(a+b+c+d+e\) 位填了 \(a\) 个 \(1\),以此类推,原式的值为 \(x\),有几种方法。这是可以贪心地倒推的。
做完了,感觉挺神秘。
#include <bits/stdc++.h>
const int mod = 998244353, iv2 = (mod + 1) / 2, iv6 = (mod + 1) / 6;
using LL = long long;
LL cnt[63];
inline int sum(LL r, LL n) {
r %= mod, n %= mod;
return (r * (r + 1) % mod * iv2 % mod * (n + 1) % mod
- r * (r + 1) % mod * (2 * r + 1) % mod * iv6 + mod) % mod;
}
inline int sum(LL l, LL r, LL n) {
return (sum(r, n) - sum(l - 1, n) + mod) % mod;
}
inline int f(LL n, LL m) {
return m % mod * ((n - m + 1) % mod) % mod % mod;
}
int calc(LL n) {
for (int i = 0; i < 61; i++)
cnt[i + 1] = n >> i;
for (int i = 1; i < 61; i++)
cnt[i] -= cnt[i + 1];
int ans = 0;
LL m = n >> 1; // 已经使用 [m+1, n-m-1]
if (n & 1) {
ans += f(n, (n + 1) / 2), --cnt[1], --n;
}
for (int i = 1; i <= 60; i++) {
int t = cnt[i] / 2;
(ans += 2ll * sum(m - t + 1, m, n) * i % mod) %= mod;
m -= t;
if (cnt[i] & 1) {
(ans += 1ll * f(n, m) * (i + i + 1) % mod) %= mod;
--cnt[i + 1], --m;
}
}
return ans;
}
LL fac[17];
LL n;
int dp[15][9][5][3][2][5983];
LL dfs(int a, int b, int c, int d, int e, int x) {
if (a < 0 || b < 0 || c < 0 || d < 0 || e < 0 || x < 0) return 0;
if (!a && !b && !c && !d && !e) return x == 0;
if (dp[a][b][c][d][e][x] != -1) return dp[a][b][c][d][e][x];
LL ans = 0;
int p = f(n, a + b + c + d + e + 1);
ans += dfs(a - 1, b, c, d, e, x - p);
ans += dfs(a, b - 1, c, d, e, x - 2 * p);
ans += dfs(a, b, c - 1, d, e, x - 3 * p);
ans += dfs(a, b, c, d - 1, e, x - 4 * p);
ans += dfs(a, b, c, d, e - 1, x - 5 * p);
return dp[a][b][c][d][e][x] = ans;
}
int main() {
fac[0] = 1;
for (int i = 1; i < 17; i++)
fac[i] = 1ll * fac[i - 1] * i % mod;
int T; scanf("%d", &T); while (T--) {
LL k; scanf("%lld %lld", &n, &k);
int t = calc(n);
if (n >= 29) {
printf("%d\n", t); continue;
}
}
}
本文来自博客园,作者:purplevine,转载请注明原文链接:https://www.cnblogs.com/purplevine/p/solution-p9838.html