从头开始编写基于隐含马尔可夫模型HMM的中文分词器之二 - 模型训练与使用

这篇博客介绍了如何从头开始编写基于隐含马尔可夫模型(HMM)的中文分词器,重点在于模型训练和Viterbi算法的实现。使用四类标签(B,E,M,S)对字符进行标注,通过统计方法构建1阶HMM,求解状态转移矩阵A和混合矩阵B。文章还讨论了数据平滑技术以及在Java中处理精度问题的方法。" 41410755,958136,理解CI框架的自动加载机制,"['PHP', '框架', 'CodeIgniter', '自动加载', 'CI框架']

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我们使用/icwb2-data.rar/training/msr_training.utf8  用以训练模型,这个词库里包含已分词汇约2000000个。

使用经典的字符标注模型,首先需要确定标注集,在前面的介绍中,我们使用的是{B,E}的二元集合,研究表明基于四类标签的字符标注模型明显优于两类标签,原因是两类标签过于简单而损失了部分信息。四类标签的集合是 {B,E,M,S},其含义如下:
B:一个词的开始
E:一个词的结束
M:一个词的中间
S:单字成词
举例:你S现B在E应B该E去S幼B儿M园E了S

用四类标签为msr_training.utf8做好标记后,就可以开始用统计的方法构建一个HMM。我们打算构建一个2-gram(bigram)语言模型,也即一个1阶HMM,每个字符的标签分类只受前一个字符分类的影响。现在,我们需要求得HMM的状态转移矩阵 A 以及混合矩阵 B。其中:
                     Aij = P(Cj|Ci)  =  P(Ci,Cj) / P(Ci) = Count(Ci,Cj) / Count(Ci)
                     Bij = P(Oj|Ci)  =  P(Oj,Ci) / P(Ci) = Count(Oj,Ci) / Count(Ci)
公式中C = {B,E,M,S},O = {字符集合},Count代表频率。在计算Bij时,由于数据的稀疏性,很多字符未出现在训练集中,这导致概率为0的结果出现在B中,为了修补这个问题,我们采用加1的数据平滑技术,即:
                 Bij = P(Oj|Ci)  =  (Count(Oj,Ci) + 1)/ Count(Ci)
         这并不是一种最好的处理技术,因为这有可能低估或高估真实概率,更加科学的方法是使用复杂一点的Good—Turing技术,这项技术的的原始版本是图灵当年和他的助手Good在破解德国密码机时发明的。

我完成的分词器暂时没有用这个技术,只是简单的为没出现的词赋一个很小的值来实现简单的情况模拟,后续会尝试使用Good-Turing技术,试下效果怎么样。

求得的PI跟矩阵A如下:

    private static  double[] PI = new double[] {0.529918835192331, 0.0, 0.0, 0.470081164807669}; //B, M, E, S
    private static double[][] A = new double[][] {
        {0.0, 0.17142595344427136, 0.8285740465557286, 0.0}, //B
        {0.0, 0.49616531193870117, 0.5038346880612988, 0.0}, //M
        {0.4741680643477776, 0.0, 0.0, 0.5258319356522224}, //E
        {0.5927662448712697, 0.0, 0.0, 0.4072328569272888} //S
    };
相关训练代码:

public static void buildPiAndMatrixA() {
        /**
         * count matrix:
         *   ALL B M E S
         * B *   * * * *
         * M *   * * * *
         * E *   * * * *
         * S *   * * * *
         * 
         * NOTE:
         *  count[2][0] is the total number of complex words
         *  count[3][0] is the total number of single words
         */
        long[][] count = new long[4][5];
        
        try {
            BufferedReader br=new BufferedReader(new InputStreamReader(new FileInputStream("icwb2-data/training/msr_training.utf8"),"UTF-8"));
            String line = null;
            String last = null;
            while ((line = br.readLine()) != null) {
                String[] words = line.split(" ");
                for (int i=0; i<words.length; i++) {
                    String word = words[i].trim();
                    int length = word.length();
                    if (length < 1)
                        continue;
                    if (length == 1) {
                        count[3][0]++;
                        if (last != null) {
                            if (last.length() == 1)
                                count[3][4]++;
                            else
                                count[2][4]++;
                        }
                    } else {
                        count[2][0]++;
                        count[0][0]++;
                        if (length > 2) {
                            count[1][0] += length-2;
                            count[0][2]++;
                            if (length-2 > 1) {
                                count[1][2] += length-3;
                            }
                            count[1][3]++;
                        } else {
                            count[0][3]++;
                        }
                        
                        if (last != null) {
                            if (last.length() == 1) {
                                count[3][1]++;
                            } else {
                                count[2][1]++;
                            }
                        }
                    }
                    last = word;
                }
                //System.out.println("Finish " + words.length + " words ...");
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        
        for (int i=0; i<count.length; i++)
            System.out.println(Arrays.toString(count[i]));
        System.out.println(" ===== So Pi array is: ===== ");
        double[] pi = new double[4];
        long allWordCount = count[2][0] + count[3][0];
        pi[0] = (double)count[2][0] / allWordCount;
        pi[3] = (double)count[3][0] / allWordCount;
        System.out.println(Arrays.toString(pi));
        System.out.println(" ===== And A matrix is: ===== ");
        double[][] A = new double[4][4];
        for (int i=0; i<A.length; i++)
            for (int j=0; j<A[i].length; j++)
                A[i][j] = (double)count[i][j+1]/ count[i][0];
        for (int i=0; i<A.length; i++)
            System.out.println(Arrays.toString(A[i]));
    }

矩阵中出现的概率为0的元素表明B-B, B-S, M-B, M-S, E-M, E-E, S-M, S-E这8种组合是不可能出现的。这是合乎逻辑的。

矩阵B内容比较大,格式如下:

1.0 3.174041419653506E-6 0.0016687522763828306 1.0633038755839245E-4 4.7610621294802586E-6 
1.0 2.313791818894887E-6 5.738203710859319E-4 4.8589628196792625E-5 2.313791818894887E-6 
1.0 7.935103549133765E-7 0.005299062150111528 6.348082839307012E-6 7.935103549133765E-7 
1.0 1.7881026800082968E-6 0.005061224635763484 4.470256700020742E-6 8.940513400041484E-7 

矩阵B训练代码:

public static void buildMatrixB(String charMapFile, String charMapCharset, String matrixBFileName) {
        /**
         * Chinese Character count => 5167
         * 
         * count matrix:
         *   ALL C1 C2 C3 CN C5168
         * B  *  *  *  *  *  1/ALL+5168
         * M  *  *  *  *  *  1/ALL+5168
         * E  *  *  *  *  *  1/ALL+5168
         * S  *  *  *  *  *  1/ALL+5168
         * 
         * NOTE:
         *  count[0][0] is the total number of begin count
         *  count[0][0] is the total number of middle count
         *  count[2][0] is the total number of end count
         *  count[3][0] is the total number of single cound
         *  
         *  B row -> 4
         *  B col -> 5169
         */
        long[][] matrixBCount = new long[4][5169];
        for (int row = 0; row < matrixBCount.length; row++) {
            Arrays.fill(matrixBCount[row], 1);
            matrixBCount[row][0] = 5168;
        }
        
        Map<Character, Integer> dict = new HashMap<Character, Integer>();
        Utils.readDict(charMapFile, charMapCharset, dict, null);
        
        try {
            BufferedReader br=new BufferedReader(new InputStreamReader(new FileInputStream("icwb2-data/training/msr_training.utf8"),"UTF-8"));
            String line = null;
            while ((line = br.readLine()) != null) {
                String[] words = line.split(" ");
                for (int i=0; i<words.length; i++) {
                    String word = words[i].trim();
                    if (word.length() < 1)
                        continue;
                    
                    if (word.length() == 1) {
                        int index = dict.get(word.charAt(0));
                        matrixBCount[3][0]++;
                        matrixBCount[3][index]++;
                    } else {
                        for (int j=0; j<word.length(); j++) {
                            int index = dict.get(word.charAt(j));
                            if (j == 0) {
                                matrixBCount[0][0]++;
                                matrixBCount[0][index]++;
                            } else if (j == word.length()-1) {
                                matrixBCount[2][0]++;
                                matrixBCount[2][index]++;
                            } else {
                                matrixBCount[1][0]++;
                                matrixBCount[1][index]++;
                            }
                        }
                    }
                    
                }
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        
        System.out.println(" ===== matrixBCount ===== ");
        for (int i=0; i<matrixBCount.length; i++)
            System.out.println(Arrays.toString(matrixBCount[i]));
        
        System.out.println(" ========= B matrix =========");
        double[][] B = new double[matrixBCount.length][matrixBCount[0].length];
        for (int row = 0; row < B.length; row++) {
            for (int col = 0; col < B[row].length; col++) {
                B[row][col] = (double) matrixBCount[row][col] / matrixBCount[row][0];
                if (col < 50) {
                    System.out.print(B[row][col] + " ");
                }
            }
            System.out.println("");
        }
        
        try {
            PrintWriter bOut = new PrintWriter(new File(matrixBFileName));
            for (int row = 0; row < B.length; row++) {
                for (int col = 0; col < B[row].length; col++) {
                    bOut.print(B[row][col] + " ");
                }
                bOut.println("");
                bOut.flush();
            }
            bOut.close();
            System.out.println("Finish write B to file " + matrixBFileName);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }

有了矩阵PI,A跟B,我们就可以写入一个观察序列,用隐含马尔可夫模型跟Viterbi算法获得一个隐藏序列(分词结果)了。如果你对隐含马尔可夫模型还有什么疑问,请参考52nlp的博文:

http://www.52nlp.cn/hmm-learn-best-practices-one-introduction

以下是我用Java自己实现的Viterbi算法,有一点需要注意的是,Java的double能表示最小的值约为1E-350,如果一个文本串很长,例如大于200个字符,算得的结果值很有可能会小于double的最小值,而由于精度问题变为0,这样最终的计算结果就失去意义了。当然,如果能保证输入串的长度比较短,可以不care这个,但为了程序的健壮性,我这里在计算某一列的结果值小于1E-250时,将停止使用double而改用Java提供的高精度类BigDecimal,虽然计算速度会比double慢很多(尤其随着串越来越长),但总比变为0,结果失去意义要强一些。但即便是这样,这个函数也期望输入串不长于200,否则就有可能在1S内得不到最终计算结果。

public static String Viterbi(double[] PI, double[][] A, double[][] B, int[] sentences) {
        StringBuilder ret = new StringBuilder();
        double[][] matrix = new double[PI.length][sentences.length];
        int[][] past = new int[PI.length][sentences.length];
        
        int supplementStartColumn = -1;
        BigDecimal[][] supplement = null; //new BigDecimal[][];
        
        for (int row=0; row<matrix.length; row++)
            matrix[row][0] = PI[row] * B[row][sentences[0]];
        
        for (int col=1; col<sentences.length; col++) {
            if (supplementStartColumn > -1) { //Use supplement BigDecimal matrix
                for (int row=0; row<matrix.length; row++) {
                    BigDecimal max = new BigDecimal(0d);
                    int last = -1;
                    for (int r=0; r<matrix.length; r++) {
                        BigDecimal value = supplement[r][col-1-supplementStartColumn].multiply(new BigDecimal(A[r][row])).multiply(new BigDecimal(B[row][sentences[col]]));
                        if (value.compareTo(max) > 0) {
                            max = value;
                            last = r;
                        }
                    }
                    supplement[row][col-supplementStartColumn] = max;
                    past[row][col] = last;
                }
            } else {
                boolean switchSupplement = false;
                for (int row=0; row<matrix.length; row++) {
                    double max = 0;
                    int last = -1;
                    for (int r=0; r<matrix.length; r++) {
                        double value = matrix[r][col-1] * A[r][row] * B[row][sentences[col]];
                        if (value > max) {
                            max = value;
                            last = r;
                        }
                    }
                    matrix[row][col] = max;
                    past[row][col] = last;
                    if (max < 1E-250)
                        switchSupplement = true;
                }
                
                //Really small data, should switch to supplement BigDecimal matrix now, or we will loose accuracy soon
                if (switchSupplement) {
                    supplementStartColumn = col;
                    supplement = new BigDecimal[PI.length][sentences.length - supplementStartColumn];
                    for (int row=0; row<matrix.length; row++) {
                        supplement[row][col - supplementStartColumn] = new BigDecimal(matrix[row][col]);
                    }
                }
            }
        }
        
        int index = -1;
        if (supplementStartColumn > -1) {
            BigDecimal max = new BigDecimal(0d);
            int column = supplement[0].length-1;
            for (int row=0; row<supplement.length; row++) {
                if (supplement[row][column].compareTo(max) > 0) {
                    max = supplement[row][column];
                    index = row;
                }
            }
        } else {
            double max = 0;
            for (int row=0; row<matrix.length; row++)
                if (matrix[row][sentences.length-1] > max) {
                    max = matrix[row][sentences.length-1];
                    index = row;
                }
        }
        
        /*for (int i=0; i<matrix.length; i++)
        System.out.println(Arrays.toString(matrix[i]));*/
        
        ret.append(getDesc(index));
        for (int col=sentences.length-1; col>=1; col--) {
            index = past[index][col];
            ret.append(getDesc(index));
        }
        
        return ret.reverse().toString();
    }

测试一下,对如下串进行分词:

String str2 = "这并不是一种最好的处理技术,因为这有可能低估或高估真实概率,更加科学的方法是使用复杂一点的Good—Turing技术,这项技术的原始版本是图灵当年和他的助手Good在破解德国密码机时发明的。";
        
结果为:

Switch to BigDecimal at length = 79
words.length=95   timecost:63
这/并/不是/一种/最好/的/处理/技术/,/因为/这有/可能/低估/或/高估/真实/概率/,/更加/科学/的/方法/是/使用/复杂/一点/的/Good—Turing技术/,/这项/技术/的/原始/版本/是/图灵/当年/和/他/的/助手/Good/在/破解/德国/密码/机时/发明/的/。/
timeCost: 63

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值