Rank Transform of a Matrix

Rank Transform of a Matrix

Given an m x n matrix , return a new matrix answer where answer[row][col] is the rank of matrix[row][col] .

The rank is an integer that represents how large an element is compared to other elements. It is calculated using the following rules:

  • The rank is an integer starting from 1.
  • If two elements p and q are in the same row or column, then:
    • If p < q then rank(p) < rank(q)
    • If p == q then rank(p) == rank(q)
    • If p > q then rank(p) > rank(q)
  • The rank should be as small as possible.

The test cases are generated so that answer is unique under the given rules.

Example 1:

Input: matrix = [[1,2],[3,4]]
Output: [[1,2],[2,3]]
Explanation:
The rank of matrix[0][0] is 1 because it is the smallest integer in its row and column.
The rank of matrix[0][1] is 2 because matrix[0][1] > matrix[0][0] and matrix[0][0] is rank 1.
The rank of matrix[1][0] is 2 because matrix[1][0] > matrix[0][0] and matrix[0][0] is rank 1.
The rank of matrix[1][1] is 3 because matrix[1][1] > matrix[0][1], matrix[1][1] > matrix[1][0], and both matrix[0][1] and matrix[1][0] are rank 2.

Example 2:

Input: matrix = [[7,7],[7,7]]
Output: [[1,1],[1,1]]

Input: matrix = [[20,-21,14],[-19,4,19],[22,-47,24],[-19,4,19]]
Output: [[4,2,3],[1,3,4],[5,1,6],[1,3,4]]

 

解题思路

  这题难就难在有相同的值出现,还要对建图进行优化。如果矩阵中所有的值都不同,那么我们可以根据每个格子的大小关系建一张图,格子是图中的节点,边的关系则是由值小的格子指向值大的格子。可以发现这个图是没有环的,因为是按照严格的大小关系来连边的,否则就会出现矛盾。最后对这个图跑一遍拓扑排序就可以了。

  如果我们直接暴力建图那么时间复杂度是$O(n \cdot m \cdot (n + m))$,会超时。优化的方法是分别考虑行和列。先遍历每一行,对每一行的值进行排序,然后仅在相邻两个元素之间连一条边。然后再枚举每一列,用同样的方法连边。这样时间复杂度就变成了$O(n m\log{nm})$。这样子建图是对的是因为我们最终是跑拓扑排序,一个元素与第一个比它大的元素连一条边就够了,再与其他比它大的元素连一条边是多余的,因为第一个比它大的元素又会与更大的元素连一条边,图中自然就存在传递的关系。

  然后现在考虑有相同的值的情况。因为同一行或列值相同的格子的秩相同,因此想一下能不能把相同的格子都压缩成同一个格子来表示。考虑用并查集,枚举每一行和每一列把相同的值的格子进行合并,最后选择一个代表元素来表示这个值的格子。暴力枚举也是会超时,因此每一行或列分别开个哈希表记录出现过的值,遍历的过程中相同的值的格子就合并。

  AC代码如下:

 1 class Solution {
 2 public:
 3     vector<vector<int>> matrixRankTransform(vector<vector<int>>& matrix) {
 4         int n = matrix.size(), m = matrix[0].size();
 5         vector<int> fa(n * m);
 6         for (int i = 0; i < n * m; i++) {
 7             fa[i] = i;
 8         }
 9         function<int(int)> find = [&](int x) {
10             return fa[x] == x ? fa[x] : fa[x] = find(fa[x]);
11         };
12         for (int i = 0; i < n; i++) {   // 把每一行相同的值进行合并
13             unordered_map<int, int> mp;
14             for (int j = 0; j < m; j++) {
15                 if (mp.count(matrix[i][j])) fa[find(i * m + j)] = mp[matrix[i][j]]; // 同一行出现相同的值
16                 mp[matrix[i][j]] = find(i * m + j);
17             }
18         }
19         for (int j = 0; j < m; j++) {   // 把每一列相同的值进行合并
20             unordered_map<int, int> mp;
21             for (int i = 0; i < n; i++) {
22                 if (mp.count(matrix[i][j])) fa[find(i * m + j)] = mp[matrix[i][j]]; // 同一列出现相同的值
23                 mp[matrix[i][j]] = find(i * m + j);
24             }
25         }
26         vector<vector<int>> g(n * m);
27         for (int i = 0; i < n; i++) {   // 枚举每一行建图
28             vector<int> p;
29             for (int j = 0; j < m; j++) {
30                 p.push_back(j);
31             }
32             sort(p.begin(), p.end(), [&](int a, int b) {    // 每一行的元素从小到大排序
33                 return matrix[i][a] < matrix[i][b];
34             });
35             for (int j = 0, k = 0; j < m; j++) {    // k记录上一次枚举到的位置
36                 if (matrix[i][p[j]] != matrix[i][p[k]]) {   // 相同的值跳过
37                     g[find(i * m + p[k])].push_back(find(i * m + p[j]));
38                     k = j;
39                 }
40             }
41         }
42         for (int j = 0; j < m; j++) {   // 枚举每一列建图
43             vector<int> p;
44             for (int i = 0; i < n; i++) {
45                 p.push_back(i);
46             }
47             sort(p.begin(), p.end(), [&](int a, int b) {    // 每一列的元素从小到大排序
48                 return matrix[a][j] < matrix[b][j];
49             });
50             for (int i = 0, k = 0; i < n; i++) {
51                 if (matrix[p[i]][j] != matrix[p[k]][j]) {
52                     g[find(p[k] * m + j)].push_back(find(p[i] * m + j));
53                     k = i;
54                 }
55             }
56         }
57         vector<int> deg(n * m);
58         for (int i = 0; i < n * m; i++) {
59             if (fa[i] == i) {
60                 for (auto &j : g[i]) {
61                     deg[j]++;
62                 }
63             }
64         }
65         vector<vector<int>> ans(n, vector<int>(m, 1));  // 入度为0的节点的秩为1
66         queue<int> q;
67         for (int i = 0; i < n * m; i++) {
68             if (fa[i] == i && deg[i] == 0) q.push(i);   // 把入度为0,且是代表元素的节点加入队列
69         }
70         while (!q.empty()) {
71             int t = q.front();
72             q.pop();
73             for (auto &i : g[t]) {
74                 if (--deg[i] == 0) {
75                     q.push(i);
76                     ans[i / m][i % m] = ans[t / m][t % m] + 1;
77                 }
78             }
79         }
80         for (int i = 0; i < n * m; i++) {   // 把值相同的格子的秩改成其代表元素的秩
81             if (fa[i] != i) {   // 表示该节点不是代表元素
82                 int t = find(i);
83                 ans[i / m][i % m] = ans[t / m][t % m];
84             }
85         }
86         return ans;
87     }
88 };

 

参考资料

  Codeforces Round #345: editorial:https://codeforces.com/blog/entry/43677

posted @ 2023-01-25 16:47  onlyblues  阅读(25)  评论(0编辑  收藏  举报
Web Analytics