隐马尔可夫模型(hidden markov model 简称hmm)广泛应用于语音识别,机器翻译等领域。
隐马尔可夫模型的具体定义,请参考著名论文《A tutorial on Hidden Markov Models and selected applications in speech recognition》,在阅读以下内容之前,建议读者阅读这篇论文的第I II III 节,理论性的东西在此不做赘述。
hmm通常解决以下三类问题:
1.给定一个hmm和观察序列,判断生成这个观察序列的可能性;
2.给定一个hmm和观察序列,给出最可能生成这个观察序列的隐藏序列;
3.给定一个观察序列,训练一个hmm。
第1个问题,通常称为评估问题,可以用前向算法(forward algorithm)来解决,使用了动态规划技术,将该问题的时间复杂度降为O(N*N*T),其中N为隐藏状态的个数,T为给定的观察序列的长度,下面给出java代码:
隐马尔可夫模型的具体定义,请参考著名论文《A tutorial on Hidden Markov Models and selected applications in speech recognition》,在阅读以下内容之前,建议读者阅读这篇论文的第I II III 节,理论性的东西在此不做赘述。
hmm通常解决以下三类问题:
1.给定一个hmm和观察序列,判断生成这个观察序列的可能性;
2.给定一个hmm和观察序列,给出最可能生成这个观察序列的隐藏序列;
3.给定一个观察序列,训练一个hmm。
第1个问题,通常称为评估问题,可以用前向算法(forward algorithm)来解决,使用了动态规划技术,将该问题的时间复杂度降为O(N*N*T),其中N为隐藏状态的个数,T为给定的观察序列的长度,下面给出java代码:
package hmm;
import java.util.HashMap;
import java.util.Map;
/**
* 隐马尔可夫模型
* @author xuguanglv
*
*/
public class Hmm {
//初始概率向量
private static double[] pai = {0.63, 0.17, 0.20};
//状态转移矩阵
private static double[][] A = {{0.500, 0.375, 0.125},
{0.250, 0.125, 0.625},
{0.250, 0.375, 0.375}};
//混淆矩阵
private static double[][] B = {{0.60, 0.20, 0.15, 0.05},
{0.25, 0.25, 0.25, 0.25},
{0.05, 0.10, 0.35, 0.50}};
//隐藏状态索引
private static Map<String, Integer> hiddenStateIndex = new HashMap<String, Integer>();
static{
hiddenStateIndex.put("S(0)", 0);
hiddenStateIndex.put("S(1)", 1);
hiddenStateIndex.put("S(2)", 2);
}
//观察状态索引
private static Map<String, Integer> observableStateIndex = new HashMap<String, Integer>();
static{
observableStateIndex.put("O(0)", 0);
observableStateIndex.put("O(1)", 1);
observableStateIndex.put("O(2)", 2);
observableStateIndex.put("O(3)", 3);
}
//前向算法 根据观察序列和已知的隐马尔可夫模型 返回这个模型生成这个观察序列的概率
//alpha[t][j]表示t时刻由隐藏状态S(j)生成观察状态O(t)的概率
public static double forward(String[] observedSequence){
double[][] alpha = new double[observedSequence.length][A.length];
//利用动态规划计算出alpha数组
//初始化
for(int i = 0; i <= A.length - 1; i++){
int index = observableStateIndex.get(observedSequence[0]);
alpha[0][i] = pai[i] * B[i][index];
}
for(int t = 1; t <= observedSequence.length - 1; t++){
for(int j = 0; j <= A.length - 1; j++){
double sum = 0;
for(int i = 0; i <= A.length - 1; i++){
sum += (alpha[t - 1][i] * A[i][j]);
}
int index = observableStateIndex.get(observedSequence[t]);
alpha[t][j] = sum * B[j][index];
}
}
double prob = 0;
for(int i = 0; i <= A.length - 1; i++){
prob += alpha[observedSequence.length - 1][i];
}
return prob;
}
public static void main(String[] args){
String[] observedSequence = {"O(0)", "O(2)", "O(3)"};
System.out.println(forward(observedSequence));
}
}