图论-堆-并查集-2503. 矩阵查询可获得的最大分数

2503. 矩阵查询可获得的最大分数

Description

Difficulty: 困难

Related Topics:

给你一个大小为 m x n 的整数矩阵 grid 和一个大小为 k 的数组 queries

找出一个大小为 k 的数组 answer ,且满足对于每个整数 queres[i] ,你从矩阵 左上角 单元格开始,重复以下过程:

  • 如果 queries[i] 严格 大于你当前所处位置单元格,如果该单元格是第一次访问,则获得 1 分,并且你可以移动到所有 4 个方向(上、下、左、右)上任一 相邻 单元格。
  • 否则,你不能获得任何分,并且结束这一过程。

在过程结束后,answer[i] 是你可以获得的最大分数。注意,对于每个查询,你可以访问同一个单元格 多次

返回结果数组 answer

示例 1:

输入:grid = [[1,2,3],[2,5,7],[3,5,1]], queries = [5,6,2]
输出:[5,8,1]
解释:上图展示了每个查询中访问并获得分数的单元格。

示例 2:

输入:grid = [[5,2,1],[1,1,2]], queries = [3]
输出:[0]
解释:无法获得分数,因为左上角单元格的值大于等于 3 。

提示:

  • m == grid.length
  • n == grid[i].length
  • 2 <= m, n <= 1000
  • 4 <= m * n <= 105
  • k == queries.length
  • 1 <= k <= 104
  • 1 <= grid[i][j], queries[i] <= 106

Solution

  • 解法一:堆
    用堆维护一下即可。
    Language: Python3
class Solution:
    def maxPoints(self, grid: List[List[int]], queries: List[int]) -> List[int]:
        m, n = len(grid), len(grid[0])
        res = [0] * len(queries)
        h = [(grid[0][0], 0, 0)]
        grid[0][0] = 0
        cnt = 0 
        for idx, q in sorted(enumerate(queries), key=lambda x: x[1]):
            while h and h[0][0] < q:
                cnt += 1
                _, i, j = heappop(h)
                for x, y in (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1):
                    if 0 <= x < m and 0 <= y < n and grid[x][y]:
                        heappush(h, (grid[x][y], x, y))
                        grid[x][y] = 0
            res[idx] = cnt
        return res


  • 并查集
    将原问题转为图论问题,需要定义好节点和边。
    节点:每个位置就是一个节点,节点编号\(i * n + j\)
    边:两个节点的max
    查询的过程就是不断merge节点的过程,用并查集维护一下连通图的大小即可。
class Solution:
    def maxPoints(self, grid: List[List[int]], queries: List[int]) -> List[int]:
        m, n = len(grid), len(grid[0])
        mn = m * n

        edges = []
        for i in range(m):
            for j in range(n):
                if i: edges.append((max(grid[i][j], grid[i - 1][j]), i * n + j, (i - 1) * n + j))
                if j: edges.append((max(grid[i][j], grid[i][j - 1]), i * n + j, i * n + j - 1))
        edges.sort()

        fa = list(range(mn))
        size = [1] * mn

        def find(x):
            if fa[x] != x: fa[x] = find(fa[x])
            return fa[x]
        
        def union(x, y):
            fx = find(x)
            fy = find(y)
            if fx != fy:
                fa[fx] = fy
                size[fy] += size[fx]
        
        res = [0] * len(queries)
        j = 0
        for i, q in sorted(enumerate(queries), key=lambda x: x[1]):
            while j < len(edges) and edges[j][0] < q:
                union(edges[j][1], edges[j][2])
                j += 1
            if grid[0][0] < q:
                res[i] = size[find(0)]
        return res
posted @ 2022-12-12 23:52  hyserendipity  阅读(39)  评论(1编辑  收藏  举报