HMM(hidden markov model)可以用于模式识别,李开复老师就是采用了HMM完成了语音识别。
一下的例子来自于《统计学习方法》
一个HMM由初始概率分布,状态转移概率分布,观测概率分布确定。并且基于两个假设:
1 假设任意时刻t的状态只依赖于前一个时刻的状态,与其他时刻的状态和观测序列无关
2 假设任意时刻的观测只依赖与该市可的马尔科夫的状态,与其他观测,状态无关。
基于此,HMM有三个基本问题:
1 概率计算问题,给定模型和观测序列,计算在模型下的观测序列出现的概率
2 预测问题,已知模型和观测序列,求最有可能的状态序列
3 学习问题,给出若干个观测序列,估计模型的参数,使得该模型下观测序列概率最大
由于第三个问题涉及到EM算法,而今天还没有看,所以这里只解决了两个,明天写第三个。
给出上面两个问题的实际例子
有三个盒子,每个盒子有红球白球
盒子 红球 白球
1 5 5
2 4 6
3 7 3
第一次从这三个盒子中随机取一个盒子的概率为0.2,0.4,0.4
并且如果上一次抽取的是盒子1那么下一次抽取盒子1的概率为0.5,抽取盒子2的概率为0.2,盒子3的概率为0.3,我们通过一个状态转移矩阵来描述
0.5 0.2 0.3
0.3 0.5 0.2
0.2 0.3 0.5 Aij表示从状态i转移到状态j的概率
通过以上描述,我们能得到该HMM的模型参数
状态转移矩阵:
0.5 0.2 0.3
0.3 0.5 0.2
0.2 0.3 0.5
观测概率分布:
0.5 0.5
0.4 0.6
0.7 0.3 Bij表示第i个状态下观测值为j的概率,这里就是抽到红球和白球的概率
初始概率:
0.2,0.4,0.4表示一开始到各个状态的概率
对于问题1:
现在我们抽取三次,结果为:红白红,求其出现的概率。
解决方法:
采用前向算法
就是我们从时刻1开始,先计算所有状态下观测为红的概率,接下来再求t2时刻会转移到某个状态的概率和,以此类推
具体的可以看《统计学习方法》,http://www.cnblogs.com/tornadomeet/archive/2012/03/24/2415583.html这个说的也比较详细
对于问题2:
抽三次后,结果为红白红,求被抽到最有可能的盒子的序列
解决方法:
这里采用了维特比算法,其实就是很常见的动态规划的算法,和求最短路径一样。如果说t+1时刻的状态序列概率最大,那么t时刻的状态序列也应该是最大的。
具体可以看《统计学习方法》
1 import java.io.BufferedReader; 2 import java.io.FileInputStream; 3 import java.io.IOException; 4 import java.io.InputStreamReader; 5 import java.util.ArrayList; 6 import java.util.HashMap; 7 import java.util.Map; 8 9 class Alaph{//Alaph和delta两个一样。。。一开始的时候delta思路错了,后来就不改了 10 double pro;//用于存放概率 11 int state;//存放状态值 12 public String toString(){ 13 return "pro:"+pro+" state:"+state; 14 } 15 } 16 17 class Delta{ 18 public double pro; 19 public int pos; 20 public String toString(){ 21 return "pro is "+pro+" pos is "+pos; 22 } 23 } 24 25 class Utils{ 26 public static ArrayList<ArrayList<Double>> loadMatrix(String filename) throws IOException{//读取数据 27 ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>(); 28 FileInputStream fis=new FileInputStream(filename); 29 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 30 BufferedReader br=new BufferedReader(isr); 31 String line=""; 32 33 while((line=br.readLine())!=null){ 34 ArrayList<Double> data=new ArrayList<Double>(); 35 String[] s=line.split(" "); 36 37 for(int i=0;i<s.length;i++){ 38 data.add(Double.parseDouble(s[i])); 39 } 40 dataSet.add(data); 41 } 42 return dataSet; 43 } 44 45 public static ArrayList<Double> loadState(String filename)throws IOException{//读取数据,这个和上面那个很像, 46 FileInputStream fis=new FileInputStream(filename); 47 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 48 BufferedReader br=new BufferedReader(isr); 49 String line=""; 50 ArrayList<Double> data=new ArrayList<Double>(); 51 while((line=br.readLine())!=null){ 52 53 String[] s=line.split(" "); 54 55 for(int i=0;i<s.length;i++){ 56 data.add(Double.parseDouble(s[i])); 57 } 58 59 } 60 return data; 61 } 62 63 64 public static ArrayList<Double> getColData(ArrayList<ArrayList<Double>> A,int index){//根据index值,获取相应的列的数据,后来好像没什么用到。。。囧 65 ArrayList<Double> col=new ArrayList<Double>(); 66 for(int i=0;i<A.size();i++){ 67 col.add(A.get(i).get(index)); 68 } 69 return col; 70 } 71 72 73 public static void showData(ArrayList<ArrayList<Double>> data){//debug的时候用的,打印 74 for(ArrayList<Double> a:data){ 75 System.out.println(a); 76 } 77 } 78 79 public static void showAlaph(ArrayList<Alaph> list){ 80 for(Alaph a:list){ 81 System.out.println(a); 82 } 83 } 84 85 public static ArrayList<Alaph> copy(ArrayList<Alaph> list){//复制 86 ArrayList<Alaph> temp=new ArrayList<Alaph>(); 87 for(Alaph a:list){ 88 Alaph b=new Alaph(); 89 b.pro=a.pro; 90 b.state=a.state; 91 temp.add(b); 92 } 93 return temp; 94 } 95 96 public static Delta copyDelta(Delta src){//和上面一样,没有什么用 97 Delta d=new Delta(); 98 d.pro=src.pro; 99 d.pos=src.pos; 100 return d; 101 } 102 103 public static ArrayList<Delta> copyDeltaList(Delta[] list){//复制 104 ArrayList<Delta> deltaList=new ArrayList<Delta>(); 105 for(Delta delta:list){ 106 Delta temp=copyDelta(delta); 107 deltaList.add(temp); 108 } 109 return deltaList; 110 } 111 112 public static void showDeltaList(ArrayList<Delta> list){//debug 113 for(Delta d:list){ 114 System.out.println(d); 115 } 116 } 117 118 public static int getMaxIndex(ArrayList<Delta> list){//求list中值最大的下标 119 double max=-1.0; 120 int index=-1; 121 for(int i=0;i<list.size();i++){ 122 if(list.get(i).pro>max){ 123 max=list.get(i).pro; 124 index=i; 125 } 126 } 127 return index; 128 } 129 130 } 131 132 133 134 public class HMM { 135 public static ArrayList<Alaph> getInitAlaph(ArrayList<Double> initState,ArrayList<ArrayList<Double>> B,int index){//第一步的时候,用于求各个状态下的初始情况 136 ArrayList<Double> col=Utils.getColData(B,index); 137 ArrayList<Alaph> alaphSet=new ArrayList<Alaph>(); 138 for(int i=0;i<col.size();i++){ 139 Alaph a=new Alaph(); 140 a.pro=col.get(i)*initState.get(i);//初始情况为初始状态*对应的观测概率矩阵的值 141 a.state=i; 142 alaphSet.add(a); 143 } 144 return alaphSet; 145 } 146 public static ArrayList<Delta> getInitDelta(ArrayList<Double> initState,ArrayList<ArrayList<Double>> B,int index){//和上面一样 147 ArrayList<Double> col=Utils.getColData(B,index); 148 ArrayList<Delta> alaphSet=new ArrayList<Delta>(); 149 for(int i=0;i<col.size();i++){ 150 Delta d=new Delta(); 151 d.pro=col.get(i)*initState.get(i); 152 d.pos=i; 153 alaphSet.add(d); 154 } 155 return alaphSet; 156 } 157 158 //用于求给定模型和观测序列下求,该模型下的观测序列出现的概率 159 public static double calProb(ArrayList<ArrayList<Double>> A,ArrayList<ArrayList<Double>> B,ArrayList<Double> initState,String[] observe,Map<String,Integer> map){ 160 int index=map.get(observe[0]); 161 ArrayList<Alaph> alaphList=getInitAlaph(initState,B,index);//先求第一步的状态概率 162 for(int i=1;i<observe.length;i++){//对各个观测值进行求解 163 String s=observe[i]; 164 int tag=map.get(s); 165 ArrayList<Alaph> temp=Utils.copy(alaphList); 166 for(Alaph alaph:alaphList){ 167 int destState=alaph.state; 168 double pro=0; 169 for(Alaph a:temp){ 170 int srcState=a.state; 171 pro+=a.pro*A.get(srcState).get(destState); 172 } 173 pro=pro*B.get(destState).get(tag); 174 alaph.pro=pro; 175 } 176 } 177 double result=0; 178 for(Alaph alaph:alaphList){ 179 result+=alaph.pro; 180 } 181 return result; 182 } 183 184 //用于求给定模型和观测序列下,求其最大可能性的状态序列 185 public static void decoding(ArrayList<ArrayList<Double>> A,ArrayList<ArrayList<Double>> B,ArrayList<Double> initState,String[] observe,Map<String,Integer> map){ 186 int index=map.get(observe[0]); 187 188 ArrayList<Delta> deltaList=getInitDelta(initState,B,index); 189 int length=B.size(); 190 Delta maxDeltaList[]=new Delta[B.size()];//用于存放各个状态下的最大概率对应的delta值 191 ArrayList<ArrayList<Integer>> posList=new ArrayList<ArrayList<Integer>>();//用于存放各个状态下的最佳状态值 192 193 for(int i=0;i<B.size();i++){ 194 ArrayList<Integer> a=new ArrayList<Integer>(); 195 a.add(i); 196 posList.add(a); 197 } 198 199 for(int j=1;j<3;j++){ 200 ArrayList<Delta> maxList=new ArrayList<Delta>(); 201 String s=observe[j]; 202 int tag=map.get(s); 203 for(int i=0;i<B.size();i++){ 204 Delta max=new Delta(); 205 double maxPro=-1.0; 206 int maxPos=-1; 207 int maxIndex=-1; 208 for(int k=0;k<deltaList.size();k++){ 209 Delta delta=deltaList.get(k); 210 double pro=delta.pro*A.get(delta.pos).get(i)*B.get(i).get(tag); 211 if(pro>maxPro){ 212 maxPro=pro; 213 maxPos=i; 214 maxIndex=k; 215 } 216 } 217 max.pro=maxPro; 218 max.pos=maxPos; 219 maxDeltaList[i]=max; 220 posList.get(i).add(maxIndex); 221 } 222 223 deltaList=Utils.copyDeltaList(maxDeltaList); 224 System.out.println(" "); 225 } 226 227 System.out.println(posList.get(Utils.getMaxIndex(deltaList))); 228 229 } 230 231 /** 232 * @param args 233 * @throws IOException 234 */ 235 public static void main(String[] args) throws IOException { 236 String dataA="C:/Users/Administrator/Desktop/upload/HMM/A.txt"; 237 String dataB="C:/Users/Administrator/Desktop/upload/HMM/B.txt"; 238 String state="C:/Users/Administrator/Desktop/upload/HMM/init.txt"; 239 ArrayList<ArrayList<Double>> A=Utils.loadMatrix(dataA); 240 ArrayList<ArrayList<Double>> B=Utils.loadMatrix(dataB); 241 ArrayList<Double> initState=Utils.loadState(state); 242 String[] s={"Red","White","Red"}; 243 Map<String,Integer> map=new HashMap(); 244 map.put("Red",0); 245 map.put("White",1); 246 double pro=calProb(A,B,initState,s,map); 247 // System.out.println("pro is "+pro); 248 decoding(A,B,initState,s,map); 249 } 250 251 }