Modifying the data path in convert_to_records.py

This article describes how to convert the MNIST dataset into TensorFlow's TFRecords format. The conversion is automated with a Python script, which makes the data easy to load for subsequent training.


# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Converts MNIST data to TFRecords file format with Example protos."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import tensorflow as tf

from tensorflow.contrib.learn.python.learn.datasets import mnist

FLAGS = None


def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(data_set, name):
  """Converts a dataset to tfrecords."""
  images = data_set.images
  labels = data_set.labels
  num_examples = data_set.num_examples

  if images.shape[0] != num_examples:
    raise ValueError('Images size %d does not match label size %d.' %
                     (images.shape[0], num_examples))
  rows = images.shape[1]
  cols = images.shape[2]
  depth = images.shape[3]

  filename = os.path.join(FLAGS.directory, name + '.tfrecords')
  print('Writing', filename)
  writer = tf.python_io.TFRecordWriter(filename)
  for index in range(num_examples):
    image_raw = images[index].tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(rows),
        'width': _int64_feature(cols),
        'depth': _int64_feature(depth),
        'label': _int64_feature(int(labels[index])),
        'image_raw': _bytes_feature(image_raw)}))
    writer.write(example.SerializeToString())
  writer.close()


def main(unused_argv):
  # Get the data.
  data_sets = mnist.read_data_sets(FLAGS.directory,
                                   dtype=tf.uint8,
                                   reshape=False,
                                   validation_size=FLAGS.validation_size)

  # Convert to Examples and write the result to TFRecords.
  convert_to(data_sets.train, 'train')
  convert_to(data_sets.validation, 'validation')
  convert_to(data_sets.test, 'test')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--directory',
      type=str,
      default='/tmp/data',
      help='Directory to download data files and write the converted result'
  )
  parser.add_argument(
      '--validation_size',
      type=int,
      default=5000,
      help="""\
      Number of examples to separate from the training data for the validation
      set.\
      """
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

This script converts MNIST into TensorFlow's standard TFRecords format. The default data directory is hard-coded to /tmp/data; I simply changed it to MNIST_data/ and placed the four MNIST .gz files in that folder, and that is all the path modification amounts to. A sketch of the change is shown below.
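A minimal sketch of the edit, assuming nothing else in the script changes: only the default of the --directory flag near the bottom of the file is pointed at the local folder. The same effect can also be achieved by passing the flag on the command line instead of editing the file.

  parser.add_argument(
      '--directory',
      type=str,
      default='MNIST_data/',  # was '/tmp/data'
      help='Directory to download data files and write the converted result'
  )

  # Equivalent without touching the script:
  #   python convert_to_records.py --directory MNIST_data/

With the four .gz files already sitting in MNIST_data/, running the script prints: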


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Writing MNIST_data/train.tfrecords
Writing MNIST_data/validation.tfrecords
Writing MNIST_data/test.tfrecords
The three generated TFRecords files then show up under MNIST_data/. A quick way to inspect one of them is sketched below.
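To spot-check what was written, a single record can be parsed back out of one of the files. This is a minimal sketch using the TF 1.x tf.python_io.tf_record_iterator API; the feature names are the ones written by convert_to() above:

import tensorflow as tf

for record in tf.python_io.tf_record_iterator('MNIST_data/train.tfrecords'):
  example = tf.train.Example()
  example.ParseFromString(record)
  feature = example.features.feature
  print('height:', feature['height'].int64_list.value[0])
  print('width: ', feature['width'].int64_list.value[0])
  print('depth: ', feature['depth'].int64_list.value[0])
  print('label: ', feature['label'].int64_list.value[0])
  print('image bytes:', len(feature['image_raw'].bytes_list.value[0]))
  break  # only inspect the first record

For MNIST the expected values are height 28, width 28, depth 1, a label in 0-9, and 784 raw image bytes per record.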


The final goal is to train on my own images. My initial suspicion is that the .gz files, once unpacked, are just the raw image data, so tomorrow I can try packaging my own images the same way and generating TFRecords from them. After that, the next step is to read the data back from a queue, comparing this tutorial against the official one; a sketch of that queue-based reader follows.
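As a starting point for that next step, here is a minimal sketch of queue-based reading with the TF 1.x input pipeline (tf.train.string_input_producer, tf.TFRecordReader, tf.parse_single_example). The feature keys and the 28x28x1 image size match what convert_to() writes; the batch size and queue capacities are placeholder values.

import tensorflow as tf

def read_and_decode(filename_queue):
  # Parse one serialized Example written by convert_to().
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([28 * 28 * 1])
  image = tf.cast(image, tf.float32) * (1. / 255)  # scale pixels to [0, 1]
  label = tf.cast(features['label'], tf.int32)
  return image, label

filename_queue = tf.train.string_input_producer(['MNIST_data/train.tfrecords'])
image, label = read_and_decode(filename_queue)
images, labels = tf.train.shuffle_batch(
    [image, label], batch_size=100, capacity=2000, min_after_dequeue=1000)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  image_batch, label_batch = sess.run([images, labels])
  print(image_batch.shape, label_batch.shape)  # (100, 784) (100,)
  coord.request_stop()
  coord.join(threads)

In newer TensorFlow versions the tf.data.TFRecordDataset API replaces this queue machinery, but the queue version matches the tutorials being compared here.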




                