HMM(hidden markov model)可以用于模式识别,李开复老师就是采用了HMM完成了语音识别。

一下的例子来自于《统计学习方法》

一个HMM由初始概率分布,状态转移概率分布,观测概率分布确定。并且基于两个假设:

1 假设任意时刻t的状态只依赖于前一个时刻的状态,与其他时刻的状态和观测序列无关

2 假设任意时刻的观测只依赖与该市可的马尔科夫的状态,与其他观测,状态无关。

基于此,HMM有三个基本问题:

1 概率计算问题,给定模型和观测序列,计算在模型下的观测序列出现的概率

2 预测问题,已知模型和观测序列,求最有可能的状态序列

3 学习问题,给出若干个观测序列,估计模型的参数,使得该模型下观测序列概率最大

由于第三个问题涉及到EM算法,而今天还没有看,所以这里只解决了两个,明天写第三个。

给出上面两个问题的实际例子

有三个盒子,每个盒子有红球白球

盒子   红球  白球

1   5  5

2  4  6

3  7  3

第一次从这三个盒子中随机取一个盒子的概率为0.2,0.4,0.4

并且如果上一次抽取的是盒子1那么下一次抽取盒子1的概率为0.5,抽取盒子2的概率为0.2,盒子3的概率为0.3,我们通过一个状态转移矩阵来描述

0.5  0.2  0.3

0.3  0.5  0.2

0.2  0.3  0.5    Aij表示从状态i转移到状态j的概率

通过以上描述,我们能得到该HMM的模型参数

状态转移矩阵:

0.5  0.2  0.3

0.3  0.5  0.2

0.2  0.3  0.5

观测概率分布:

0.5  0.5

0.4  0.6

0.7  0.3 Bij表示第i个状态下观测值为j的概率,这里就是抽到红球和白球的概率

初始概率:

0.2,0.4,0.4表示一开始到各个状态的概率

对于问题1:

现在我们抽取三次,结果为:红白红,求其出现的概率。

解决方法:

采用前向算法

就是我们从时刻1开始,先计算所有状态下观测为红的概率,接下来再求t2时刻会转移到某个状态的概率和,以此类推

具体的可以看《统计学习方法》,http://www.cnblogs.com/tornadomeet/archive/2012/03/24/2415583.html这个说的也比较详细

对于问题2:

抽三次后,结果为红白红,求被抽到最有可能的盒子的序列

解决方法:

这里采用了维特比算法,其实就是很常见的动态规划的算法,和求最短路径一样。如果说t+1时刻的状态序列概率最大,那么t时刻的状态序列也应该是最大的。

具体可以看《统计学习方法》

  1 import java.io.BufferedReader;
  2 import java.io.FileInputStream;
  3 import java.io.IOException;
  4 import java.io.InputStreamReader;
  5 import java.util.ArrayList;
  6 import java.util.HashMap;
  7 import java.util.Map;
  8 
  9 class Alaph{//Alaph和delta两个一样。。。一开始的时候delta思路错了,后来就不改了
 10     double pro;//用于存放概率
 11     int state;//存放状态值
 12     public String toString(){
 13         return "pro:"+pro+" state:"+state;
 14     }
 15 }
 16 
 17 class Delta{
 18     public double pro;
 19     public int pos;
 20     public String toString(){
 21         return "pro is "+pro+" pos is "+pos;
 22     }
 23 }
 24 
 25 class Utils{
 26     public static ArrayList<ArrayList<Double>> loadMatrix(String filename) throws IOException{//读取数据
 27         ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
 28         FileInputStream fis=new FileInputStream(filename);
 29         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
 30         BufferedReader br=new BufferedReader(isr);
 31         String line="";
 32         
 33         while((line=br.readLine())!=null){
 34             ArrayList<Double> data=new ArrayList<Double>();
 35             String[] s=line.split(" ");
 36             
 37             for(int i=0;i<s.length;i++){
 38                 data.add(Double.parseDouble(s[i]));
 39             }
 40             dataSet.add(data);
 41         }
 42         return  dataSet;
 43     }
 44     
 45     public static ArrayList<Double> loadState(String filename)throws IOException{//读取数据,这个和上面那个很像,
 46         FileInputStream fis=new FileInputStream(filename);
 47         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
 48         BufferedReader br=new BufferedReader(isr);
 49         String line="";
 50         ArrayList<Double> data=new ArrayList<Double>();
 51         while((line=br.readLine())!=null){
 52             
 53             String[] s=line.split(" ");
 54             
 55             for(int i=0;i<s.length;i++){
 56                 data.add(Double.parseDouble(s[i]));
 57             }
 58             
 59         }
 60         return  data;
 61     }
 62     
 63     
 64     public static ArrayList<Double> getColData(ArrayList<ArrayList<Double>> A,int index){//根据index值,获取相应的列的数据,后来好像没什么用到。。。囧
 65         ArrayList<Double> col=new ArrayList<Double>();
 66         for(int i=0;i<A.size();i++){
 67             col.add(A.get(i).get(index));
 68         }
 69         return col;
 70     }
 71     
 72     
 73     public static void showData(ArrayList<ArrayList<Double>> data){//debug的时候用的,打印
 74         for(ArrayList<Double> a:data){
 75             System.out.println(a);
 76         }
 77     }
 78     
 79     public static void showAlaph(ArrayList<Alaph> list){
 80         for(Alaph a:list){
 81             System.out.println(a);
 82         }
 83     }
 84     
 85     public static ArrayList<Alaph> copy(ArrayList<Alaph> list){//复制
 86         ArrayList<Alaph> temp=new ArrayList<Alaph>();
 87         for(Alaph a:list){
 88             Alaph b=new Alaph();
 89             b.pro=a.pro;
 90             b.state=a.state;
 91             temp.add(b);
 92         }
 93         return temp;
 94     }
 95     
 96     public static Delta copyDelta(Delta src){//和上面一样,没有什么用
 97         Delta d=new Delta();
 98         d.pro=src.pro;
 99         d.pos=src.pos;
100         return d;
101     }
102     
103     public static ArrayList<Delta> copyDeltaList(Delta[] list){//复制
104         ArrayList<Delta> deltaList=new ArrayList<Delta>();
105         for(Delta delta:list){
106             Delta temp=copyDelta(delta);
107             deltaList.add(temp);
108         }
109         return deltaList;
110     }
111     
112     public static void showDeltaList(ArrayList<Delta> list){//debug
113         for(Delta d:list){
114             System.out.println(d);
115         }
116     }
117     
118     public static int getMaxIndex(ArrayList<Delta> list){//求list中值最大的下标
119         double max=-1.0;
120         int index=-1;
121         for(int i=0;i<list.size();i++){
122             if(list.get(i).pro>max){
123                 max=list.get(i).pro;
124                 index=i;
125             }
126         }
127         return index;
128     }
129     
130 }
131 
132 
133 
134 public class HMM {
135     public static ArrayList<Alaph> getInitAlaph(ArrayList<Double> initState,ArrayList<ArrayList<Double>> B,int index){//第一步的时候,用于求各个状态下的初始情况
136         ArrayList<Double> col=Utils.getColData(B,index);
137         ArrayList<Alaph> alaphSet=new ArrayList<Alaph>();
138         for(int i=0;i<col.size();i++){
139             Alaph a=new Alaph();
140             a.pro=col.get(i)*initState.get(i);//初始情况为初始状态*对应的观测概率矩阵的值
141             a.state=i;
142             alaphSet.add(a);
143         }
144         return alaphSet;
145     }
146     public static ArrayList<Delta> getInitDelta(ArrayList<Double> initState,ArrayList<ArrayList<Double>> B,int index){//和上面一样
147         ArrayList<Double> col=Utils.getColData(B,index);
148         ArrayList<Delta> alaphSet=new ArrayList<Delta>();
149         for(int i=0;i<col.size();i++){
150             Delta d=new Delta();
151             d.pro=col.get(i)*initState.get(i);
152             d.pos=i;
153             alaphSet.add(d);
154         }
155         return alaphSet;
156     }
157     
158     //用于求给定模型和观测序列下求,该模型下的观测序列出现的概率
159     public static double calProb(ArrayList<ArrayList<Double>> A,ArrayList<ArrayList<Double>> B,ArrayList<Double> initState,String[] observe,Map<String,Integer> map){
160         int index=map.get(observe[0]);
161         ArrayList<Alaph> alaphList=getInitAlaph(initState,B,index);//先求第一步的状态概率
162         for(int i=1;i<observe.length;i++){//对各个观测值进行求解
163             String s=observe[i];
164             int tag=map.get(s);
165             ArrayList<Alaph> temp=Utils.copy(alaphList);
166             for(Alaph alaph:alaphList){
167                 int destState=alaph.state;
168                 double pro=0;
169                 for(Alaph a:temp){
170                     int srcState=a.state;
171                     pro+=a.pro*A.get(srcState).get(destState);
172                 }
173                 pro=pro*B.get(destState).get(tag);
174                 alaph.pro=pro;
175             }
176         }
177         double result=0;
178         for(Alaph alaph:alaphList){
179             result+=alaph.pro;
180         }
181         return result;
182     }
183     
184     //用于求给定模型和观测序列下,求其最大可能性的状态序列
185     public static void  decoding(ArrayList<ArrayList<Double>> A,ArrayList<ArrayList<Double>> B,ArrayList<Double> initState,String[] observe,Map<String,Integer> map){
186         int index=map.get(observe[0]);
187         
188         ArrayList<Delta> deltaList=getInitDelta(initState,B,index);
189         int length=B.size();
190         Delta maxDeltaList[]=new Delta[B.size()];//用于存放各个状态下的最大概率对应的delta值
191         ArrayList<ArrayList<Integer>> posList=new ArrayList<ArrayList<Integer>>();//用于存放各个状态下的最佳状态值
192         
193         for(int i=0;i<B.size();i++){
194             ArrayList<Integer> a=new ArrayList<Integer>();
195             a.add(i);
196             posList.add(a);
197         }
198         
199         for(int j=1;j<3;j++){
200             ArrayList<Delta> maxList=new ArrayList<Delta>();
201             String s=observe[j];
202             int tag=map.get(s);
203             for(int i=0;i<B.size();i++){
204                 Delta max=new Delta();
205                 double maxPro=-1.0;
206                 int maxPos=-1;
207                 int maxIndex=-1;
208                 for(int k=0;k<deltaList.size();k++){
209                     Delta delta=deltaList.get(k);
210                     double pro=delta.pro*A.get(delta.pos).get(i)*B.get(i).get(tag);
211                     if(pro>maxPro){
212                         maxPro=pro;
213                         maxPos=i;
214                         maxIndex=k;
215                     }
216                 }
217                 max.pro=maxPro;
218                 max.pos=maxPos;
219                 maxDeltaList[i]=max;
220                 posList.get(i).add(maxIndex);
221             }
222             
223             deltaList=Utils.copyDeltaList(maxDeltaList);
224             System.out.println("  ");
225         }
226         
227         System.out.println(posList.get(Utils.getMaxIndex(deltaList)));
228         
229     }
230     
231     /**
232      * @param args
233      * @throws IOException 
234      */
235     public static void main(String[] args) throws IOException {
236         String dataA="C:/Users/Administrator/Desktop/upload/HMM/A.txt";
237         String dataB="C:/Users/Administrator/Desktop/upload/HMM/B.txt";
238         String state="C:/Users/Administrator/Desktop/upload/HMM/init.txt";
239         ArrayList<ArrayList<Double>> A=Utils.loadMatrix(dataA);
240         ArrayList<ArrayList<Double>> B=Utils.loadMatrix(dataB);
241         ArrayList<Double> initState=Utils.loadState(state);
242         String[] s={"Red","White","Red"};
243         Map<String,Integer> map=new HashMap();
244         map.put("Red",0);
245         map.put("White",1);
246         double pro=calProb(A,B,initState,s,map);
247 //        System.out.println("pro is "+pro);
248         decoding(A,B,initState,s,map);
249     }
250 
251 }

 

posted on 2015-06-15 21:22  sunrye  阅读(4537)  评论(0编辑  收藏  举报