CF1042E Vasya and Magic Matrix 题解
思路分析
看到题目中 \(n,m \leq 1000\) ,故直接考虑 \(O(n^2)\) 级别做法。
我们先把所有的点按照 \(val\) 值从小到大排序,这样的话二维问题变成序列问题。
设 \(f_i\) 表示走到第 \(i\) 个点的价值的期望。
先列出裸的 \(dp\) 方程:(\(Num\) 表示符合条件的点的个数)
\[f_i =\frac{1}{Num} \sum_{a_i > a_j}(x_i-x_j)^2+(y_i-y_j)^2+f_j
\]
但是这个好像是 \(O(n^2m^2)\) 的优秀算法……
不要担心,我们把式子化简一下:
\[f_i=\frac{1}{Num}\sum_{a_i>a_j}x_i^2+2x_ix_j+x_j^2+y_i^2+2y_iy_j+y_j^2+f_j \\
f_i=\frac{1}{Num}\sum_{a_i>a_j}x_i^2+x_j^2+y_i^2+y_j^2+2(x_ix_j+y_iy_j)+f_j
\]
我们惊喜地发现,可以用前缀和优化。
设:
- \(suma_i=\sum_{j=1}^ia_j\quad sumb_i=\sum_{j=1}^ib_j\)
- \(sumpa_i=\sum_{j=1}^ia_j^2 \quad sumpb_i=\sum_{j=1}^ib_j^2\)
- \(sumf_i=\sum_{j=1}^if_j\)
直接按照上述前缀和替换即可,在这里不写了。
Code
#include <bits/stdc++.h>
#define file(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)
#define Enter putchar('\n')
#define quad putchar(' ')
#define int long long
#define N 1005
#define mod 998244353
int n, m, a[N][N], tot, x, y, f[N * N], sumf[N * N];
int suma[N * N], sumb[N * N], sumpa[N * N], sumpb[N * N];
struct Node {
int x, y, num;
friend bool operator <(const Node &p, const Node &q) {
return p.num < q.num;
}
}node[N * N];
inline int ksm (int a, int n) {
int ret = 1;
while (n) {
if (n % 2 == 1) ret = (ret * a) % mod;
a = (a * a) % mod;
n /= 2;
}
return ret;
}
inline void init() {
for (int i = 1; i <= tot; i++) {
sumpa[i] = (sumpa[i - 1] + node[i].x * node[i].x) % mod;
sumpb[i] = (sumpb[i - 1] + node[i].y * node[i].y) % mod;
suma[i] = (suma[i - 1] + node[i].x) % mod;
sumb[i] = (sumb[i - 1] + node[i].y) % mod;
}
return ;
}
inline void solve(int pos, int id) {
// std::cout << pos << " " << id << std::endl;
f[pos] = (f[pos] + sumpa[id] + sumpb[id]) % mod;
f[pos] = f[pos] - 2 * node[pos].x * suma[id] - 2 * node[pos].y * sumb[id];
f[pos] = (f[pos] % mod + mod) % mod;
int px = node[pos].x * node[pos].x;
int py = node[pos].y * node[pos].y;
f[pos] = (f[pos] + id * px + id * py) % mod;
f[pos] = (f[pos] + sumf[id]) % mod;
f[pos] *= ksm(id, mod - 2); f[pos] %= mod;
sumf[pos] = (sumf[pos - 1] + f[pos]) % mod;
return ;
}
signed main(void) {
std::cin >> n >> m;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
scanf("%d", &a[i][j]);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
node[++tot].x = i;
node[tot].y = j;
node[tot].num = a[i][j];
}
}
std::cin >> x >> y;
std::sort(node + 1, node + 1 + tot);
init();
int last;
node[0].num = -114514;
for (int i = 1; i <= tot; i++) {
if (node[i].num != node[i - 1].num) last = i - 1;
solve(i, last);
if (node[i].x == x && node[i].y == y) {
std::cout << f[i] << std::endl;
return 0;
}
}
}