多模式匹配AC算法(kotlin)2

目的

之前写了一篇文章多模式匹配AC算法Java(kotlin)实现,可建模中文,里面通过建模char(unicode)来实现跳转,使用的是map。但是通过私下的实验,其实这样做性能并不高,而且代码复杂难懂。更通用的做法是将unicode字符串转换为bytes,每个byte256种情况,也就是为每个节点维护一个256维的数组表示子节点。通过实验,建模bytes的方法比建模char的性能高不少。本文贴出建模byte的kotlin代码。

要点

  1. 根节点比较特殊,其子节点除了模式中的首byte节点之外,剩下的bytes的节点的跳转表要指向根节点。
  2. 失败跳转表构建思路比较绕,其实也比较简单。假如当前节点在byte_x失配了,那么就查其父节点的失配节点是否有byte_x的子节点,如果有那么恭喜,直接到那个节点,继续在树中跳转;如果没有,则查其父节点的失配节点的失配节点是否有byte_x的子节点,如果有,直接跳转过去。由于最差也会跳转到根节点,所以算法收敛。

代码

package com.davezhao.utils

val BYTE_SIZE = Byte.MAX_VALUE - Byte.MIN_VALUE + 1

data class NodeByte(
	var finish: Boolean = false, // 当前节点是否为某一模式终止节点
	var label: Int = 0, // 当前节点编号,默认为0,即根节点,其余的节点大于0
	var pattern: String = "", // 如果当前节点为某一模式的终止节点,则该字段保存终止的模式字符串
	val transitionTable: MutableList<Int> = MutableList(BYTE_SIZE, { -1 }) // 当前节点的子节点编号,-1表示不是子节点
)


class AcMatchByte {
	private val startNode = NodeByte() // trie的根节点
	private var labelCount = 1 //内部有效的节点个数
	private val nodes = mutableListOf<NodeByte>(startNode) // 保存所有的节点,位置为节点的编号
	private var fail: MutableList<Int> = mutableListOf() // 失败跳转表

	/**
	 * 添加模式,构建trie树,可一次放入多个模式字符串,或者多次调用该函数放入。
	 * @param patterns
	 * @return
	 */
	fun addPatterns(vararg patterns: String) {
		var latestLabel: Int = labelCount
		for (pattern in patterns) {
			var pNode = startNode //从根节点开始构建
			for (b in pattern.toByteArray()) { //将当前模式字符串转换为byte
				val i = b - Byte.MIN_VALUE // 取当前byte的位置
				var nxtNodeLabel = pNode.transitionTable[i] //查看当前节点是否包含i的子节点
				if (nxtNodeLabel == -1) { //如果不包含,则需要为其创建子节点
					val nxtNode = NodeByte()
					nxtNode.label = latestLabel
					nodes.add(nxtNode)
					pNode.transitionTable[i] = latestLabel // 为i创建指向子节点的跳转表
					nxtNodeLabel = latestLabel++ // 全局的节点编号要增加1
				}
				pNode = nodes[nxtNodeLabel] // 令pNode指向i的子节点,继续下一个byte的构建
			}
			pNode.finish = true // 一个模式完成后,为模式的最后一个节点设置finish=true
			pNode.pattern = pattern // 一个模式完成后,为模式的最后一个节点pattern设为当前pattern
		}

		labelCount = latestLabel // 构建完所有的模式后,将最新编号赋值给labelCount,以备下次构建
	}

	fun build() {
		// 在构建失败跳转规则之前,需要有一个保底的设置。根节点有256个子节点,哪些非匹配模式的节点需要跳转到根节点本身,以便自动机跳转
		for (i in (0 until BYTE_SIZE)) {
			if (startNode.transitionTable[i] == -1) {
				startNode.transitionTable[i] = 0
			}
		}

		val q = mutableListOf<Int>() // 创建一个队列,用于存储待创建其子节点失败跳转表的节点
		fail = MutableList(labelCount, { -1 }) // 失败跳转表是每个节点都有跳转,所以size为state_count个。
		startNode.transitionTable.filter { it > 0 }.forEach {
			// 将startNode节点中非指向根节点的节点挑出来,设置它们的失败跳转为根节点,并且加入队列,以便创建其子节点的失败跳转表
			fail[it] = 0
			q.add(it)
		}

		while (!q.isEmpty()) { // 如果队列为空,则说明所有节点失败跳转构建完毕,退出
			val known = q.removeAt(0) // 从队列中取出队头的节点
			(0 until BYTE_SIZE).filter { nodes[known].transitionTable[it] > 0 }.forEach { i ->
				// 取出当前节点known的所有模式子节点
				val nxt = nodes[known].transitionTable[i] // 对于nxt子节点
				var p = fail[known] // 首先先得到其父节点的跳转节点
				while (!(p != -1 && nodes[p].transitionTable[i] != -1)) { // 然后判断,如果nxt节点父节点known跳转节点p存在,且p也有子节点i,则退出跳转,否则继续寻找p的跳转节点赋值给p继续判断。由于我们有根节点的保底设置,所以最差也会到根节点。
					p = fail[p]
				}
				fail[nxt] = nodes[p].transitionTable[i] // 子节点nxt的失败跳转节点即为其广义父节点的N(大于等于1)次跳转节点的同字符i的子节点
				q.add(nxt) // 将nxt放入q,为其子节点设置失败跳转
			}
		}
	}

	fun match(str: String): List<String> {
		val strB = str.toByteArray()
		var pNode = startNode
		var i = 0 // 遍历带搜索字符串str的下标
		val res = mutableListOf<String>() // 保存搜索到的字符串列表
		while (i < strB.size) {
			val trans = strB[i] - Byte.MIN_VALUE
			if (pNode.transitionTable[trans] != -1) { // 如果当前字符跳转成功,则跳转到当前字符的下一个字符
				pNode = nodes[pNode.transitionTable[trans]]
			} else { // 否则回退下标,并且使用失败跳转转到下一个节点继续
				--i
				pNode = nodes[fail[pNode.label]]
			}
			if (pNode.finish) { // 如果当前节点已经是某一模式终点,则保存该模式
				res.add(pNode.pattern)
			}
			++i // 正常的下标增长
		}

		return res
	}
}

fun main(args: Array<String>) {
	val ac = AcMatchByte()
	val patterns = arrayOf("his", "hers", "she", "he", "中国")
	ac.addPatterns(patterns = *patterns)
	ac.addPatterns("国中", "中国中")
	ac.build()

	val str = "hishers中国人民中国中国"

	val res = ac.match(str)
	println(res)
}
posted @ 2021-07-31 22:57  ledao  阅读(61)  评论(0编辑  收藏  举报