https://www.hackerrank.com/challenges/the-grid-search/forum
今天碰见这题,看见难度是Moderate,觉得应该能半小时内搞定。
读完题目发现是纯粹的一道子矩阵匹配问题,想想自己以前没做过,肯定能学到新算法,于是就开搞了。
于是上网搜到了Rabin-Karp算法,一种基于hashing的模式匹配算法。尽管连一维的我也没写过,但看了思想以后觉得推广到二维应该也不会很难。
于是有了以下代码,原理就是计算子矩阵的hash key。以hash key的比较代替了子矩阵的比较,这样可以首先排除掉hash key不相等的子矩阵。
对于hash key相等的,再用朴素方法判断子矩阵是否相等。
为什么最后还是要判断子矩阵是否相等呢?因为hash key可能存在碰撞,即使概率不大,为了保证正确性也需要进行检查。
学习Rabin-Karp算法的资料在此:
http://blog.sina.com.cn/s/blog_6a09b5a70100nhnr.html
思路虽简单,代码写起来却各种bug,最终我花了不下两个钟头才搞定。Hackerrank果然是给hacker玩的,我这水平在上面真是举步维艰。
不过,这么搞下来,倒是有了实实在在的收获,如果学习算法能一直保持这种节奏就好了。
下面是AC的代码,时空复杂度均为O(N ^ 2):
1 # 2D Rabin-Karp Algorithm 2 import re 3 4 MOD = 10 ** 9 + 7 5 6 def get2DMatrix(n, m): 7 a = [[0 for j in xrange(m)] for i in xrange(n)] 8 return a 9 10 def calcHash(a, nn, mm): 11 n = len(a) 12 m = len(a[0]) 13 14 b = 1 15 for i in xrange(mm): 16 b = b * 10 % MOD 17 b2 = 1 18 for i in xrange(nn): 19 b2 = b2 * b % MOD 20 21 h = get2DMatrix(n, m) 22 for i in xrange(n): 23 val = 0 24 for j in xrange(m): 25 val = (val * 10 + a[i][j]) % MOD 26 if j >= mm: 27 val = (val + a[i][j - mm] * (MOD - b)) % MOD 28 h[i][j] = val 29 30 h2 = get2DMatrix(n, m) 31 h2[0] = h[0][:] 32 for i in xrange(1, n): 33 for j in xrange(m): 34 h2[i][j] = (h2[i - 1][j] * b + h[i][j]) % MOD 35 if i >= nn: 36 h2[i][j] = (h2[i][j] + h[i - nn][j] * (MOD - b2)) % MOD 37 return h, h2 38 39 def equal(a, p, ai, aj): 40 np = len(p) 41 mp = len(p[0]) 42 for i in xrange(np): 43 for j in xrange(mp): 44 if a[ai + i][aj + j] != p[i][j]: 45 return False 46 return True 47 48 def solve(): 49 na, ma = map(int, re.split('\s+', raw_input().strip())) 50 a = [] 51 for i in xrange(na): 52 a.append(map(int, list(raw_input().strip()))) 53 np, mp = map(int, re.split('\s+', raw_input().strip())) 54 p = [] 55 for i in xrange(np): 56 p.append(map(int, list(raw_input().strip()))) 57 ha, h2a = calcHash(a, np, mp) 58 hp, h2p = calcHash(p, np, mp) 59 60 for i in xrange(np - 1, na): 61 for j in xrange(mp - 1, ma): 62 if h2a[i][j] != h2p[np - 1][mp - 1]: 63 continue 64 if equal(a, p, i - np + 1, j - mp + 1): 65 print('YES') 66 return 67 print('NO') 68 69 if __name__ == '__main__': 70 t = int(raw_input()) 71 for ti in xrange(t): 72 solve() 73