HMM介绍
Hidden Markov Models是一种统计信号处理方法,模型中包含2个序列和3个矩阵:状态序列S、观察序列O、初始状态矩阵P、状态转移矩阵A、混淆矩阵B。举个例子来说明。
你一个异地的朋友只做三种活动:散步、看书、做清洁。每天只做一种活动。假设天气只有两种状态:晴和兩。每天只有一种天气。你的朋友每天告诉你他做了什么,但是不告诉你他那里的天气。
某一周从周一到周五每天的活动分别是{读书,做清洁,散步,做清洁,散步}----这就是观察序列O,因为你可以观察得到。
从周一到周五的天气依次是{晴,兩,晴,晴,晴}----这就是状态序列S,状态序列是隐藏的,你不知道。
根据长期统计,某天晴的概率是0.6,兩的概率是0.4。则$\pi=[0.6,0.4]$。
从晴转晴的概率是0.7,从晴转兩的概率是0.3。从兩转晴的概率是0.4,从兩转兩的概率是0.6。则$A=\left[\begin{array}{cc}0.7&0.3\\0.4&0.6\end{array}\right]$。
天气晴时,散步的概率是0.4,看书的概率是0.3,做清洁的概率是0.3。天气兩时,散步的概率是0.1,看书的概率是0.4,做清洁的概率是0.5。则$B=\left[\begin{array}{cc}0.4&0.3&0.3\\0.4&0.1&0.5\end{array}\right]$。
该模型和实际情况有明显不符的地方:用一简单的状态转移矩阵A来表示状态的转移概率的前提是t时刻的状态只跟t-1时刻的状态有关,而实际上今天的天气跟过去几天的天气都有关系,而且跟过去几天的晴朗程度、兩量大小都有关系;混淆矩阵B认为今天的活动只跟今天的天气有关系,实际上今天的活动跟过去几天的活动也有关系,比如过去一周都没有打扫房间,那今天做清洁的概率就大大增加。
模型介绍完了。
评估问题
隐马尔可夫模型中包含一个评估问题:已知模型参数,计算某一特定输出序列的概率。通常使用forward算法解决。
比如计算活动序列{读书,做清洁,散步,做清洁,散步}出现的概率,就属于评估问题。
如果穷举的话,观察序列会有2^5种,需要分别计算它们出现的概率,然后找出概率最大的。
穷举法中有很多重复计算,向前算法就是利用已有的结果,减少重复计算。
算法借助于一个矩阵Q[LEN][M],其中M是所有状态的种数,Q[i][j]表示从第0天到第i天,满足观察序列,且第i天隐藏状态为Sj的所有可能的隐藏序列的概率之和。最终所求结果为Q[LEN-1][0]+...+Q[LEN-1][M-1],即最后一天,所有隐藏状态下,分别满足观察序列的概率值之和。
比如Q[0][0]=p(第一天做卫生 且 第一天晴)=p(天晴)*p(做卫生|天晴)=P[0]*B[0][2]=0.6*0.3=0.18
Q[0][1]=p(第一天做卫生 且 第一天下雨)=p(下雨)*p(做卫生|下雨)=P[1]*B[1][2]=0.4*0.5=0.2
Q[1][0]=p(第一天做卫生 且 第二天晴 且 第二天做卫生)
=p(第一天做卫生 且 第二天晴)*p(天晴的情况下做卫生)
=p{ p(第一天做卫生 且 第一天晴)*p(从天晴转天晴)+p(第一天做卫生 且 第一天下雨)*p(从下雨转天晴) }*p(天晴的情况下做卫生)
={ Q[0][0]*A[0][0] + Q[0][1]*A[1][0] } * B[0][2]
Q[1][1]= ……
…… ……
可以看到计算Q矩阵的每i行时都用到了第i-1行的结果。
解码问题
解码问题是:已知模型参数,寻找最可能的能产生某一特定输出序列O(LEN)的隐含状态的序列。通常使用Viterbi算法解决。
观察序列长度为LEN,则隐藏状态序列长度也是LEN,如果采用穷举法,就有M^LEN种可能的隐藏状态序列,我们要计算每一种隐藏状态到指定观察序列的概率,最终选择概率最大的。
穷举法中有很多重复计算,Viterbi算法就是利用已有的结果,减少重复计算。
跟评估问题非常相似,不同点在于评估算的是和,解码算的是最大值。
Viterbi算法主要就是在计算一个矩阵Q[LEN][M],其中Q[i][j]表示从第0天到第i天,满足观察序列,且第i天隐藏状态为Sj的所有可能的隐藏序列的概率的最大值。另外还要建立一个矩阵Path[LEN][M],用来记录状态序列中某一状态之前最可能的状态。
举个例子,假如指定观察序列是{读书,做卫生,散步,做卫生,散步},求出现此观察序列最可能的状态序列是什么。
Q[0][0]=p(第一天读书 且 第一天晴)=p(天晴)*p(读书|天晴)
Path[0][0]=-1;
Q[0][1]=p(第一天读书 且 第一天下雨)=p(下雨)*p(读书|下雨)
Path[0][1]=-1;
关键是从第二天开始,Q[1][0]表示:满足“第一天读书 且 第二做卫生 且 第二天晴”的所有可能的隐藏序列的概率的最大值。那么满足“第一天读书 且 第二做清洁 且 第二天晴”的所有可能的隐藏序列有哪些呢?第二天是必须满足晴天的,第二天之前的状态可以任意变。则所有可能的隐藏序列就是“晴 晴”和“雨 晴”。实际上考虑第三天(及第三天以后)时,并不需要考虑“所有”可能的隐藏序列,而只需要考虑第二天的不同状态取值,这是因为马氏过程有无后效性--tm时刻所处状态的概率只和tm-1时刻的状态有关,而与tm-1时刻之前的状态无关。
Q[1][0]=
max{ p(第一天晴 且 第一天读书 且 第二天晴 且第二天做卫生) ,p(第一天下雨 且 第一天读书 且 第二天晴 且 第二天做卫生) }
=max{ p(第一天读书 且 第一天晴)*p(天晴转天晴),p(第一天读书 且 第一天下雨)*p(下雨转天晴) } * p(做卫生|天晴)
=max{ Q[0][0]*A[0][0],Q[0][1]*A[1][0] } * B[0][2]
假如Q[0][0]*A[0][0] < Q[0][1]*A[1][0],则Path[1][0]=1;假如Q[0][0]*A[0][0] > Q[0][1]*A[1][0],则Path[1][0]=0。
Q[1][1]= ……
…… ……
可以看到计算Q矩阵的每i行时都用到了第i-1行的结果。
下面给出两种算法的Java代码:
1 import java.io.BufferedReader; 2 import java.io.File; 3 import java.io.FileNotFoundException; 4 import java.io.FileReader; 5 import java.io.IOException; 6 import java.util.Collections; 7 import java.util.HashMap; 8 import java.util.LinkedList; 9 import java.util.List; 10 import java.util.Map; 11 import java.util.Map.Entry; 12 13 import xxzl.dm.utility.FileUtil; 14 import xxzl.dm.utility.Pair; 15 import xxzl.dm.utility.math.Smooth; 16 17 /** 18 * 隐马尔可夫推断,包括评估问题和解码问题 19 * 20 * @Author:zhangchaoyang 21 * @Since:2015年3月29日 22 * @Version:1.0 23 */ 24 public class HmmInference { 25 26 /** 27 * HMM的模型参数 28 */ 29 private List<String> stateSet = new LinkedList<String>();// 状态值集合 30 private List<String> observeSet = new LinkedList<String>();// 观察值集合 31 private double[] stateProb;// 初始状态概率矩阵 32 private double[][] stateTrans;// 状态转移矩阵 33 private double[][] emission;// 发射矩阵 34 private double[] minEmission;// 发射矩阵每一行的极小值(用于零值平滑) 35 36 /** 37 * 使用已标记好的训练样本,经过简单统计使用极大似然确定HMM的参数 38 * 39 * @param tagFile 40 * 文件格式:每行2列,第1列是观察值,第2列是状态值,不用序列之间用空行隔开。<br> 41 * /test/resources/corpus/wordcut.train是一个示例文件。 42 */ 43 public void initParam(String tagFile) { 44 Map<String, Integer> stateIndexMap = new HashMap<String, Integer>();// 状态值及其编号 45 Map<String, Integer> observeIndexMap = new HashMap<String, Integer>();// 观察值及其编号 46 int[] stateCount;// 状态值及其计数 47 int[][] stateTransCount;// 状态转移计数矩阵 48 int[][] confusionCount;// 混淆计数矩阵 49 50 try { 51 BufferedReader br = new BufferedReader(new FileReader(new File( 52 tagFile))); 53 if (br.markSupported()) { 54 br.mark(1024 * 1024 * 100); 55 } 56 String line = null; 57 int stateTotal = 0; 58 int observeTotal = 0; 59 // 第一趟扫描文件,给stateIndexMap、observeIndexMap赋值 60 while ((line = br.readLine()) != null) { 61 String[] arr = line.split("\\s+");// 每行存储一个观察值、一个状态值,用空白符隔开。用空行隔开不同的序列。 62 if (arr.length >= 2) { 63 String observe = arr[0]; 64 String state = arr[1]; 65 if (!observeIndexMap.containsKey(observe)) { 66 observeIndexMap.put(observe, observeTotal++); 67 } 68 if (!stateIndexMap.containsKey(state)) { 69 stateIndexMap.put(state, stateTotal++); 70 } 71 } 72 } 73 if (br.markSupported()) { 74 br.reset(); 75 } else { 76 br.close(); 77 br = new BufferedReader(new FileReader(new File(tagFile))); 78 } 79 // System.out.println("state set:"); 80 // for (Entry<String, Integer> entry : stateIndexMap.entrySet()) { 81 // System.out.println(entry.getValue() + ":" + entry.getKey()); 82 // } 83 // 第二趟扫描文件,给stateTransCount、confusionCount、stateCount赋值 84 for (int i = 0; i < stateTotal; i++) { 85 stateSet.add(""); 86 } 87 for (int i = 0; i < observeTotal; i++) { 88 observeSet.add(""); 89 } 90 stateTransCount = new int[stateTotal][]; 91 for (int i = 0; i < stateIndexMap.size(); i++) { 92 stateTransCount[i] = new int[stateTotal]; 93 } 94 confusionCount = new int[stateTotal][]; 95 for (int i = 0; i < stateTotal; i++) { 96 confusionCount[i] = new int[observeTotal]; 97 } 98 stateCount = new int[stateTotal]; 99 String preState = null; 100 while ((line = br.readLine()) != null) { 101 String[] arr = line.split("\\s+"); 102 if (arr.length >= 2) { 103 String observe = arr[0]; 104 String state = arr[1]; 105 int row = stateIndexMap.get(state); 106 int col = 0; 107 int oldCount = 0; 108 // if (observeIndexMap.containsKey(observe)) { 109 col = observeIndexMap.get(observe); 110 oldCount = confusionCount[row][col]; 111 confusionCount[row][col] = oldCount + 1; 112 // } 113 stateCount[row] = stateCount[row] + 1; 114 if (preState == null) { 115 preState = state; 116 } else { 117 row = stateIndexMap.get(preState); 118 col = stateIndexMap.get(state); 119 oldCount = stateTransCount[row][col]; 120 stateTransCount[row][col] = oldCount + 1; 121 preState = state; 122 } 123 } else { 124 preState = null; 125 } 126 } 127 br.close(); 128 // 给HMM基本参数赋值 129 for (Entry<String, Integer> entry : stateIndexMap.entrySet()) { 130 String state = entry.getKey(); 131 int index = entry.getValue(); 132 stateSet.set(index, state); 133 } 134 for (Entry<String, Integer> entry : observeIndexMap.entrySet()) { 135 String observe = entry.getKey(); 136 int index = entry.getValue(); 137 observeSet.set(index, observe); 138 } 139 stateProb = calProbByCount(Smooth.GoodTuring(stateCount)); 140 // System.out.println("state initial prob:"); 141 // System.out.println(Arrays.toString(stateProb)); 142 stateTrans = new double[stateTransCount.length][]; 143 for (int i = 0; i < stateTransCount.length; i++) { 144 stateTrans[i] = calProbByCount(stateTransCount[i]);// 计算状态转移概率时不作平滑,因为有些状态之间转移的概率就应该是0,如果平滑就会变成非0 145 } 146 // System.out.println("state transiation prob:"); 147 // for (int i = 0; i < stateTransProb.length; i++) { 148 // System.out.println(Arrays.toString(stateTransProb[i])); 149 // } 150 emission = new double[confusionCount.length][]; 151 minEmission = new double[confusionCount.length]; 152 for (int i = 0; i < confusionCount.length; i++) { 153 emission[i] = calProbByCount(Smooth 154 .GoodTuring(confusionCount[i])); 155 double min = Double.MAX_VALUE; 156 for (double ele : emission[i]) { 157 if (ele < min) { 158 min = ele; 159 } 160 } 161 minEmission[i] = min; 162 } 163 } catch (FileNotFoundException e) { 164 e.printStackTrace(); 165 } catch (IOException e) { 166 e.printStackTrace(); 167 } 168 169 } 170 171 /** 172 * 采用向前算法解决评估问题:给定HMM的所有参数,评估一个观察序列出现的概率。 173 * 174 * @param obs_seq 175 * @return 176 */ 177 public double estimate(List<String> obs_seq) { 178 double rect = 0.0; 179 int LEN = obs_seq.size(); 180 double[][] Q = new double[LEN][]; 181 // 状态的初始概率,乘上隐藏状态到观察状态的条件概率。 182 Q[0] = new double[stateSet.size()]; 183 for (int j = 0; j < stateSet.size(); j++) { 184 if (observeSet.contains(obs_seq.get(0))) { 185 Q[0][j] = stateProb[j] 186 * emission[j][observeSet.indexOf(obs_seq.get(0))]; 187 } else { 188 Q[0][j] = stateProb[j] * minEmission[j]; 189 System.err.println("观察值'" + obs_seq.get(0) + "'在已标记样本中未出现过"); 190 } 191 } 192 // 首先从前一时刻的每个状态,转移到当前状态的概率求和,然后乘上隐藏状态到观察状态的条件概率。 193 for (int i = 1; i < LEN; i++) { 194 Q[i] = new double[stateSet.size()]; 195 for (int j = 0; j < stateSet.size(); j++) { 196 double sum = 0.0; 197 for (int k = 0; k < stateSet.size(); k++) { 198 sum += Q[i - 1][k] * stateTrans[k][j]; 199 } 200 if (observeSet.contains(obs_seq.get(i))) { 201 Q[i][j] = sum 202 * emission[j][observeSet.indexOf(obs_seq.get(i))]; 203 } else { 204 Q[i][j] = sum * minEmission[j]; 205 System.err 206 .println("观察值'" + obs_seq.get(0) + "'在已标记样本中未出现过"); 207 } 208 } 209 } 210 for (int i = 0; i < stateSet.size(); i++) 211 rect += Q[LEN - 1][i]; 212 return rect; 213 } 214 215 /** 216 * 采用viterbi进行解码:给定HMM的所有参数,给一个观察序列,评估最可能的状态序列是什么。 217 * 218 * @param observe 219 * @return 220 */ 221 public Pair<Double, LinkedList<String>> viterbi(List<String> observe) { 222 LinkedList<String> sta = new LinkedList<String>(); 223 int LEN = observe.size(); 224 int M = stateSet.size(); 225 double[][] Q = new double[LEN][]; 226 int[][] Path = new int[LEN][]; 227 Q[0] = new double[M]; 228 Path[0] = new int[M]; 229 for (int j = 0; j < M; j++) { 230 if (observeSet.contains(observe.get(0))) {// 观察值在训练样本中未出现过,则概率设为0 231 Q[0][j] = stateProb[j] 232 * emission[j][observeSet.indexOf(observe.get(0))]; 233 } else { 234 Q[0][j] = stateProb[j] * minEmission[j] / 2; 235 System.err.println("观察值'" + observe.get(0) + "'在已标记样本中未出现过"); 236 } 237 Path[0][j] = -1; 238 } 239 for (int i = 1; i < LEN; i++) { 240 Q[i] = new double[M]; 241 Path[i] = new int[M]; 242 for (int j = 0; j < M; j++) { 243 double max = 0.0; 244 int index = 0; 245 for (int k = 0; k < M; k++) { 246 if (Q[i - 1][k] * stateTrans[k][j] > max) { 247 max = Q[i - 1][k] * stateTrans[k][j]; 248 index = k; 249 } 250 } 251 if (observeSet.contains(observe.get(i))) { 252 Q[i][j] = max 253 * emission[j][observeSet.indexOf(observe.get(i))]; 254 } else { 255 Q[i][j] = max * minEmission[j] / 2; 256 System.err 257 .println("观察值'" + observe.get(0) + "'在已标记样本中未出现过"); 258 } 259 Path[i][j] = index; 260 } 261 } 262 // 找到最后一个时刻呈现哪种状态的概率最大 263 double max = 0; 264 int index = 0; 265 for (int i = 0; i < M; i++) { 266 if (Q[LEN - 1][i] > max) { 267 max = Q[LEN - 1][i]; 268 index = i; 269 } 270 } 271 sta.add(stateSet.get(index)); 272 // 动态规划,逆推回去各个时刻出现什么状态概率最大 273 for (int i = LEN - 1; i > 0; i--) { 274 index = Path[i][index]; 275 sta.add(stateSet.get(index)); 276 } 277 // 把状态序列再顺过来 278 Collections.reverse(sta); 279 return Pair.of(max, sta); 280 } 281 282 public void baumWelch() { 283 284 } 285 286 public void initStateSet(String infile) { 287 FileUtil.readLines(infile, stateSet); 288 } 289 290 public void initObserveSet(String infile) { 291 FileUtil.readLines(infile, observeSet); 292 } 293 294 public void initStateProb(String infile) { 295 assert stateSet != null; 296 int stateCount = stateSet.size(); 297 assert stateCount > 0; 298 stateProb = new double[stateCount]; 299 List<String> lines = new LinkedList<String>(); 300 FileUtil.readLines(infile, lines); 301 assert lines.size() >= stateCount; 302 for (int i = 0; i < stateCount; i++) { 303 stateProb[i] = Double.parseDouble(lines.get(i)); 304 } 305 } 306 307 public void initStateTrans(String infile) { 308 List<String> lines = new LinkedList<String>(); 309 FileUtil.readLines(infile, lines); 310 stateTrans = new double[lines.size()][]; 311 for (int i = 0; i < lines.size(); i++) { 312 String[] conts = lines.get(i).split("\\s+"); 313 stateTrans[i] = new double[conts.length]; 314 for (int j = 0; j < stateTrans[i].length; j++) { 315 stateTrans[i][j] = Double.parseDouble(conts[j]); 316 } 317 } 318 } 319 320 public void initConfusion(String infile) { 321 List<String> lines = new LinkedList<String>(); 322 FileUtil.readLines(infile, lines); 323 emission = new double[lines.size()][]; 324 minEmission = new double[lines.size()]; 325 for (int i = 0; i < lines.size(); i++) { 326 String[] conts = lines.get(i).split("\\s+"); 327 emission[i] = new double[conts.length]; 328 double min = Double.MAX_VALUE; 329 for (int j = 0; j < emission[i].length; j++) { 330 double ele = Double.parseDouble(conts[j]); 331 emission[i][j] = ele; 332 if (ele < min) { 333 min = ele; 334 } 335 } 336 minEmission[i] = min; 337 } 338 } 339 340 public List<String> getStateSet() { 341 return stateSet; 342 } 343 344 public void setStateSet(List<String> stateSet) { 345 this.stateSet = stateSet; 346 } 347 348 public List<String> getObserveSet() { 349 return observeSet; 350 } 351 352 public void setObserveSet(List<String> observeSet) { 353 this.observeSet = observeSet; 354 } 355 356 public double[] getStateProb() { 357 return stateProb; 358 } 359 360 public void setStateProb(double[] stateProb) { 361 this.stateProb = stateProb; 362 } 363 364 public double[][] getStateTrans() { 365 return stateTrans; 366 } 367 368 public void setStateTrans(double[][] stateTrans) { 369 this.stateTrans = stateTrans; 370 } 371 372 public double[][] getEmission() { 373 return emission; 374 } 375 376 public void setEmission(double[][] emission) { 377 this.emission = emission; 378 minEmission = new double[emission.length]; 379 for (int i = 0; i < emission.length; i++) { 380 double min = Double.MAX_VALUE; 381 for (double ele : emission[i]) { 382 if (ele < min) { 383 min = ele; 384 } 385 } 386 minEmission[i] = min; 387 } 388 } 389 390 /** 391 * 通过一组计数计算概率 392 * 393 * @param countArr 394 * @return 395 */ 396 private double[] calProbByCount(double[] countArr) { 397 double sum = 0.0; 398 for (double count : countArr) { 399 sum += count; 400 } 401 double[] prob = new double[countArr.length]; 402 for (int i = 0; i < countArr.length; i++) { 403 prob[i] = countArr[i] / sum; 404 } 405 return prob; 406 } 407 408 /** 409 * 通过一组计数计算概率 410 * 411 * @param countArr 412 * @return 413 */ 414 private double[] calProbByCount(int[] countArr) { 415 double sum = 0.0; 416 for (double count : countArr) { 417 sum += count; 418 } 419 double[] prob = new double[countArr.length]; 420 for (int i = 0; i < countArr.length; i++) { 421 prob[i] = countArr[i] / sum; 422 } 423 return prob; 424 } 425 }
测试代码:
1 import java.util.ArrayList; 2 import java.util.LinkedList; 3 import java.util.List; 4 5 import org.junit.BeforeClass; 6 import org.junit.Test; 7 8 import xxzl.dm.core.sequence.HmmInference; 9 import xxzl.dm.utility.Pair; 10 11 public class TestHmmInference { 12 13 private static HmmInference hmm = new HmmInference(); 14 15 @BeforeClass 16 public static void setup() { 17 hmm.initParam("/Users/zhangchaoyang/OneDrive/msr_train.corp"); 18 } 19 20 @Test 21 public void testForward() { 22 String sentence = "公布这些事实将对日记主人及其他被涉及者带来何种影响"; 23 List<String> obs_seq = str2List(sentence); 24 double ratio = hmm.estimate(obs_seq); 25 System.out.println("观察序列出现的概率:" + ratio); 26 } 27 28 @Test 29 public void testViterbi() { 30 String sentence = "公布这些事实将对日记主人及其他被涉及者带来何种影响"; 31 List<String> obs_seq = str2List(sentence); 32 Pair<Double, LinkedList<String>> states = hmm.viterbi(obs_seq); 33 System.out.println("最可能的状态序列:" + states.second + ",其概率为:" 34 + states.first); 35 System.out.println("分词结果为:" + wordSeg(sentence, states.second)); 36 } 37 38 private List<String> str2List(String sentence) { 39 List<String> obs_seq = new ArrayList<String>(); 40 int i = 0; 41 while (i < sentence.length()) { 42 obs_seq.add(sentence.substring(i, i + 1)); 43 i++; 44 } 45 return obs_seq; 46 } 47 48 private String wordSeg(String sentence, List<String> tag) { 49 StringBuilder sb = new StringBuilder(); 50 int len = sentence.length() <= tag.size() ? sentence.length() : tag 51 .size(); 52 for (int i = 0; i < len; i++) { 53 String word = sentence.substring(i, i + 1); 54 sb.append(word); 55 if (tag.get(i).equals("E") || tag.get(i).equals("S")) { 56 sb.append("\t"); 57 } 58 } 59 return sb.toString(); 60 } 61 }
本文来自博客园,作者:高性能golang,转载请注明原文链接:https://www.cnblogs.com/zhangchaoyang/articles/2219571.html