Loading

CF1042E Vasya and Magic Matrix 题解

题目链接

思路分析

看到题目中 \(n,m \leq 1000\) ,故直接考虑 \(O(n^2)\) 级别做法。
我们先把所有的点按照 \(val\) 值从小到大排序,这样的话二维问题变成序列问题。
\(f_i\) 表示走到第 \(i\) 个点的价值的期望。
先列出裸的 \(dp\) 方程:(\(Num\) 表示符合条件的点的个数)

\[f_i =\frac{1}{Num} \sum_{a_i > a_j}(x_i-x_j)^2+(y_i-y_j)^2+f_j \]

但是这个好像是 \(O(n^2m^2)\) 的优秀算法……
不要担心,我们把式子化简一下:

\[f_i=\frac{1}{Num}\sum_{a_i>a_j}x_i^2+2x_ix_j+x_j^2+y_i^2+2y_iy_j+y_j^2+f_j \\ f_i=\frac{1}{Num}\sum_{a_i>a_j}x_i^2+x_j^2+y_i^2+y_j^2+2(x_ix_j+y_iy_j)+f_j \]

我们惊喜地发现,可以用前缀和优化。
设:

  • \(suma_i=\sum_{j=1}^ia_j\quad sumb_i=\sum_{j=1}^ib_j\)
  • \(sumpa_i=\sum_{j=1}^ia_j^2 \quad sumpb_i=\sum_{j=1}^ib_j^2\)
  • \(sumf_i=\sum_{j=1}^if_j\)

直接按照上述前缀和替换即可,在这里不写了。

Code

#include <bits/stdc++.h>

#define file(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)

#define Enter putchar('\n')
#define quad putchar(' ')

#define int long long 

#define N 1005
#define mod 998244353

int n, m, a[N][N], tot, x, y, f[N * N], sumf[N * N];
int suma[N * N], sumb[N * N], sumpa[N * N], sumpb[N * N];

struct Node {
  int x, y, num;
  friend bool operator <(const Node &p, const Node &q) {
    return p.num < q.num;
  } 
}node[N * N];

inline int ksm (int a, int n) {
  int ret = 1;
  while (n) {
    if (n % 2 == 1) ret = (ret * a) % mod;
    a = (a * a) % mod;
    n /= 2;
  }
  return ret;
}

inline void init() {
  for (int i = 1; i <= tot; i++) {
    sumpa[i] = (sumpa[i - 1] + node[i].x * node[i].x) % mod;
    sumpb[i] = (sumpb[i - 1] + node[i].y * node[i].y) % mod;
    suma[i] = (suma[i - 1] + node[i].x) % mod;
    sumb[i] = (sumb[i - 1] + node[i].y) % mod;
  }
  return ;
}

inline void solve(int pos, int id) {
  // std::cout << pos <<  " " << id << std::endl;
  f[pos] = (f[pos] + sumpa[id] + sumpb[id]) % mod;
  f[pos] = f[pos] - 2 * node[pos].x * suma[id] - 2 * node[pos].y * sumb[id];
  f[pos] = (f[pos] % mod + mod) % mod;
  int px = node[pos].x * node[pos].x;
  int py = node[pos].y * node[pos].y;
  f[pos] = (f[pos] + id * px + id * py) % mod;
  f[pos] = (f[pos] + sumf[id]) % mod;
  f[pos] *= ksm(id, mod - 2); f[pos] %= mod;
  sumf[pos] = (sumf[pos - 1] + f[pos]) % mod;
  return ;
}

signed main(void) {
  std::cin >> n >> m;
  for (int i = 1; i <= n; i++)
    for (int j = 1; j <= m; j++)
      scanf("%d", &a[i][j]);
  for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= m; j++) {
      node[++tot].x = i;
      node[tot].y = j;
      node[tot].num = a[i][j];
    }
  }
  std::cin >> x >> y;
  std::sort(node + 1, node + 1 + tot);
  init();
  int last;
  node[0].num = -114514;
  for (int i = 1; i <= tot; i++) {
    if (node[i].num != node[i - 1].num) last = i - 1;
    solve(i, last);
    if (node[i].x == x && node[i].y == y) {
      std::cout << f[i] << std::endl;
      return 0;
    }
  }
}
posted @ 2022-03-26 17:25  Aonynation  阅读(52)  评论(0编辑  收藏  举报