class KeyWordSpotter(torch.nn.Module):
    """Streaming CTC keyword spotter.

    Data flow for every ``forward(wave_chunk)`` call:

    1. ``accept_wave``: raw PCM bytes (little-endian signed 16-bit) are
       decoded and prepended with ``wave_remained`` (samples left over from
       the previous chunk), then turned into fbank features. Samples that
       are not yet covered by a complete frame go back into
       ``wave_remained``.
    2. Optional context splicing (``context_expansion`` in the config):
       every frame is concatenated with ``left_context`` previous and
       ``right_context`` following frames. The frames needed to build the
       next chunk's first windows are kept in ``feature_remained``.
    3. Optional frame skipping (``frame_skip`` -> ``downsampling``): only
       every ``downsampling``-th frame of the whole stream is kept;
       ``feats_ctx_offset`` carries the skip phase across chunks.
    4. The model consumes the features together with its streaming cache
       ``in_cache`` and returns logits, converted to posteriors by softmax.
    5. For every output frame, ``decode_keywords`` advances the CTC prefix
       beam ``cur_hyps`` (the keyword token-id set is passed to
       ``ctc_prefix_beam_search``), and ``execute_detection`` searches the
       hypotheses for a complete keyword token sequence and decides whether
       to activate (score, duration and spacing checks).
    6. The beam is reset after an activation, or when the best hypothesis
       has been open for longer than ``max_frames``.

    ``forward`` returns ``{}`` while there is not enough audio, otherwise
    the ``self.result`` dict (state / keyword / start / end / score, with
    times in seconds).

    Call ``set_keywords`` before the first ``forward``. Call ``reset_all``
    before feeding an unrelated audio stream.
    """

    def __init__(
        self,
        ckpt_path,
        config_path,
        token_path,
        lexicon_path,
        threshold,
        min_frames=5,
        max_frames=250,
        interval_frames=50,
        score_beam=3,
        path_beam=20,
        gpu=-1,
        is_jit_model=False,
    ):
        """Load the config, the model, the token table and the lexicon.

        Args:
            ckpt_path: model checkpoint, or a TorchScript file when
                ``is_jit_model`` is True.
            config_path: YAML training config. ``dataset_conf`` provides the
                feature settings and ``model`` the architecture.
            token_path: token table, read with ``read_token``.
            lexicon_path: lexicon, read with ``read_lexicon``.
            threshold: minimum keyword score needed to activate.
            min_frames: minimum keyword duration (in frames) to activate.
            max_frames: maximum keyword duration (in frames). Hypotheses
                open for longer than this are dropped.
            interval_frames: minimum number of frames between the ends of
                two activations.
            score_beam: beam size passed to ``ctc_prefix_beam_search``.
            path_beam: number of prefix hypotheses kept after each frame.
            gpu: GPU index, or -1 for CPU. Also written to
                ``CUDA_VISIBLE_DEVICES``.
            is_jit_model: load ``ckpt_path`` with ``torch.jit.load`` (CPU
                only) instead of building the model from the config.
        """
        super().__init__()
        # NOTE(review): this only takes effect if CUDA has not been
        # initialised in this process yet.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
        with open(config_path, 'r', encoding='utf-8') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
        dataset_conf = configs['dataset_conf']
        fbank_conf = dataset_conf['feature_extraction_conf']

        # ---- feature extraction ----
        self.sample_rate = 16000
        # Samples received but not yet covered by a complete fbank frame.
        self.wave_remained = np.array([])
        self.num_mel_bins = fbank_conf['num_mel_bins']
        self.frame_length = fbank_conf['frame_length']  # in ms
        self.frame_shift = fbank_conf['frame_shift']  # in ms
        self.downsampling = dataset_conf.get('frame_skip', 1)
        self.resolution = self.frame_shift / 1000  # seconds per frame

        # ---- fsmn splice operation ----
        self.context_expansion = dataset_conf.get('context_expansion', False)
        self.left_context = 0
        self.right_context = 0
        if self.context_expansion:
            ctx_conf = dataset_conf['context_expansion_conf']
            self.left_context = ctx_conf['left']
            self.right_context = ctx_conf['right']
        # Last (left + right) padded frames, reused by the next chunk.
        self.feature_remained = None
        # Frame-skip phase: index of the first frame to keep in next chunk.
        self.feats_ctx_offset = 0

        # ---- model ----
        if is_jit_model:
            model = torch.jit.load(ckpt_path)
            # For script model, only cpu is supported.
            device = torch.device('cpu')
        else:
            model = init_model(configs['model'])
            load_checkpoint(model, ckpt_path)
            use_cuda = gpu >= 0 and torch.cuda.is_available()
            device = torch.device('cuda' if use_cuda else 'cpu')
        self.device = device
        self.model = model.to(device)
        self.model.eval()
        logging.info(f'model {ckpt_path} loaded.')
        self.token_table = read_token(token_path)
        logging.info(f'tokens {token_path} with '
                     f'{len(self.token_table)} units loaded.')
        self.lexicon_table = read_lexicon(lexicon_path)
        logging.info(f'lexicons {lexicon_path} with '
                     f'{len(self.lexicon_table)} units loaded.')
        # Empty streaming cache; the model returns the updated one.
        self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)

        # ---- decoding and detection ----
        self.score_beam = score_beam
        self.path_beam = path_beam
        self.threshold = threshold
        self.min_frames = min_frames
        self.max_frames = max_frames
        self.interval_frames = interval_frames
        # Each hyp: (prefix token ids, (score, score, per-token node list)).
        self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
        self.hit_score = 1.0
        self.hit_keyword = None
        self.activated = False
        self.total_frames = 0  # frame offset, for absolute time
        self.last_active_pos = -1  # the last frame of being activated
        self.result = {}

    def set_keywords(self, keywords):
        """Parse a comma-separated keyword string into token ids.

        Spaces are removed and empty entries (e.g. from a trailing comma) are
        ignored. Sets ``self.keywords_token`` (keyword -> token ids/string)
        and ``self.keywords_idxset`` (all keyword token ids plus blank 0).
        """
        message = ('at least one keyword is needed, '
                   'multiple keywords should be splitted with comma(,)')
        assert keywords is not None, message
        keywords_list = [
            k for k in keywords.strip().replace(' ', '').split(',') if k
        ]
        assert keywords_list, message

        keywords_token = {}
        keywords_idxset = {0}  # '<blk>' has id 0
        keywords_tokenmap = {'<blk>': 0}  # only used for logging below
        for keyword in keywords_list:
            strs, indexes = query_token_set(keyword, self.token_table,
                                            self.lexicon_table)
            keywords_token[keyword] = {
                'token_id': indexes,
                'token_str': ''.join('%s ' % str(i) for i in indexes),
            }
            keywords_idxset.update(indexes)
            for txt, idx in zip(strs, indexes):
                keywords_tokenmap.setdefault(txt, idx)

        token_print = ''.join(
            f'{txt}({idx}) ' for txt, idx in keywords_tokenmap.items())
        logging.info(f'Token set is: {token_print}')
        self.keywords_idxset = keywords_idxset
        self.keywords_token = keywords_token

    @staticmethod
    def _pcm16_to_samples(wave):
        """Decode little-endian signed 16-bit PCM bytes into float64 samples.

        Values are NOT divided by 32768.0, because kaldi.fbank accepts the
        original integer-scaled input.

        Raises:
            ValueError: if ``wave`` has an odd number of bytes.
        """
        if len(wave) % 2:
            raise ValueError(
                'raw PCM must contain whole 16-bit samples (even byte count)')
        return np.frombuffer(wave, dtype='<i2').astype(np.float64)

    def _splice_context(self, feats):
        """Concatenate each frame with its left/right context frames.

        The first chunk is left-padded by repeating its first frame
        ``left_context`` times. Later chunks are prefixed with the frames
        kept from the previous call. Output row ``i`` is
        ``cat(pad[i], ..., pad[i + left + right])``, so a chunk of ``n`` new
        frames yields one spliced frame per centre whose right context is
        already available.
        """
        assert feats.size(0) > self.right_context, \
            "make sure each chunk feat length is large than right context."
        if self.feature_remained is None:  # first chunk
            feats_pad = torch.cat(
                (feats[:1].repeat(self.left_context, 1), feats))
        else:
            feats_pad = torch.cat((self.feature_remained, feats))

        ctx_win = self.left_context + self.right_context + 1
        ctx_frm = feats_pad.size(0) - (self.left_context + self.right_context)
        # unfold -> (ctx_frm, dim, ctx_win); transpose so that flattening
        # concatenates whole frames in time order.
        feats_ctx = feats_pad.unfold(0, ctx_win, 1).transpose(1, 2).reshape(
            ctx_frm, ctx_win * feats.size(1))
        # The next window starts at index ctx_frm of this padded buffer.
        self.feature_remained = feats_pad[ctx_frm:]
        return feats_ctx.to(device=self.device, dtype=torch.float32)

    def _downsample(self, feats):
        """Keep every ``downsampling``-th frame of the whole stream.

        ``feats_ctx_offset`` is the index (in this chunk) of the first frame
        to keep. It is updated so that the skip pattern continues seamlessly
        into the next chunk.
        """
        kept = feats[self.feats_ctx_offset::self.downsampling, :]
        self.feats_ctx_offset = \
            (self.feats_ctx_offset - feats.size(0)) % self.downsampling
        return kept

    def accept_wave(self, wave):
        """Turn a chunk of raw PCM bytes into model-ready features.

        Returns None while the buffered audio is too short. Otherwise it
        returns the fbank features of this chunk (spliced and downsampled
        when configured). Audio that does not fill a complete frame is kept
        for the next call.
        """
        assert isinstance(wave, bytes), \
            "please make sure the input format is bytes(raw PCM)"
        wave = np.append(self.wave_remained, self._pcm16_to_samples(wave))

        window = int(self.frame_length / 1000 * self.sample_rate)
        frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
        # Enough samples for at least right_context + 1 frames (>= 1 frame).
        if wave.size < window + self.right_context * frame_shift:
            self.wave_remained = wave
            return None

        wave_tensor = torch.from_numpy(wave).float().to(self.device)
        wave_tensor = wave_tensor.unsqueeze(0)  # add a channel dimension
        feats = kaldi.fbank(wave_tensor,
                            num_mel_bins=self.num_mel_bins,
                            frame_length=self.frame_length,
                            frame_shift=self.frame_shift,
                            dither=0,
                            energy_floor=0.0,
                            sample_frequency=self.sample_rate)
        # Next chunk's first frame starts right after the last full shift.
        self.wave_remained = wave[feats.size(0) * frame_shift:]

        if self.context_expansion:
            feats = self._splice_context(feats)
        if self.downsampling > 1:
            feats = self._downsample(feats)
        return feats

    def decode_keywords(self, t, probs):
        """Advance the CTC prefix beam by one frame of posteriors ``probs``.

        ``t`` is relative to the current chunk; ``total_frames`` turns it into
        an absolute frame index. Hyps are sorted by path score, not by
        keyword probability, and only the best ``path_beam`` are kept.
        """
        absolute_time = t + self.total_frames
        next_hyps = ctc_prefix_beam_search(absolute_time, probs,
                                           self.cur_hyps,
                                           self.keywords_idxset,
                                           self.score_beam)
        self.cur_hyps = next_hyps[:self.path_beam]

    def execute_detection(self, t):
        """Search the current hyps for a keyword and update ``self.result``.

        The first hyp (in beam order) containing a keyword's full token
        sequence decides the hit. Its score is the square root of the product
        of the matched nodes' probabilities, computed fresh for this frame.
        Activation requires ``score >= threshold``,
        ``min_frames <= duration <= max_frames`` and at least
        ``interval_frames`` since the previous activation ended.
        """
        absolute_time = t + self.total_frames
        hit_keyword = None
        start = 0
        end = 0

        for prefix_ids, stats in self.cur_hyps:
            prefix_nodes = stats[2]
            assert len(prefix_ids) == len(prefix_nodes)
            for word, token_info in self.keywords_token.items():
                lab = token_info['token_id']
                offset = is_sublist(prefix_ids, lab)
                if offset == -1:
                    continue
                hit_keyword = word
                start = prefix_nodes[offset]['frame']
                end = prefix_nodes[offset + len(lab) - 1]['frame']
                score = 1.0
                for node in prefix_nodes[offset:offset + len(lab)]:
                    score *= node['prob']
                # Fresh per frame: never compounded with earlier frames.
                self.hit_score = math.sqrt(score)
                break
            if hit_keyword is not None:
                break

        duration = end - start
        if hit_keyword is not None:
            too_close = (self.last_active_pos != -1 and
                         end - self.last_active_pos < self.interval_frames)
            if self.hit_score >= self.threshold and \
                    self.min_frames <= duration <= self.max_frames \
                    and not too_close:
                self.activated = True
                self.last_active_pos = end
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"duration {duration}, score {self.hit_score}, Activated.")
            elif too_close:
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"but interval {end-self.last_active_pos} "
                    f"is lower than {self.interval_frames}, Deactivated. ")
            elif self.hit_score < self.threshold:
                logging.info(f"Frame {absolute_time} detect {hit_keyword} "
                             f"from {start} to {end} frame. "
                             f"but {self.hit_score} "
                             f"is lower than {self.threshold}, Deactivated. ")
            elif self.min_frames > duration or duration > self.max_frames:
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"but {duration} beyond range"
                    f"({self.min_frames}~{self.max_frames}), Deactivated. ")

        self.result = {
            "state": 1 if self.activated else 0,
            "keyword": hit_keyword if self.activated else None,
            "start": start * self.resolution if self.activated else None,
            "end": end * self.resolution if self.activated else None,
            "score": self.hit_score if self.activated else None
        }

    def forward(self, wave_chunk):
        """Process one chunk of raw PCM bytes.

        Returns ``{}`` while there is not enough audio for a feature frame,
        otherwise the latest ``self.result`` dict.
        """
        feature = self.accept_wave(wave_chunk)
        if feature is None or feature.size(0) < 1:
            return {}  # the feature is not enough to get result.
        feature = feature.unsqueeze(0)  # add a batch dimension
        # no_grad: otherwise every returned in_cache keeps the autograd graph
        # of all previous chunks alive.
        with torch.no_grad():
            logits, self.in_cache = self.model(feature, self.in_cache)
        probs = logits.softmax(2)  # (batch_size, maxlen, vocab_size)
        probs = probs[0].cpu()  # remove batch dimension

        for t, prob in enumerate(probs):
            t *= self.downsampling  # back to pre-skip frame units
            self.decode_keywords(t, prob)
            self.execute_detection(t)
            if self.activated:
                self.reset()
                # Once activated, skip the rest of this chunk.
                # TODO: give another method to update result, avoiding
                # self.result being cleared.
                break

        self.total_frames += len(probs) * self.downsampling

        # For streaming kws, the cur_hyps should be reset if the time of
        # a possible keyword last over the max_frames value you set.
        # see this issue:https://github.com/duj12/kws_demo/issues/2
        if self.cur_hyps and self.cur_hyps[0][0]:
            keyword_may_start = int(self.cur_hyps[0][1][2][0]['frame'])
            if self.total_frames - keyword_may_start > self.max_frames:
                self.reset()
        return self.result

    def reset(self):
        """Clear the decoding beam and the activation/score state."""
        self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
        self.activated = False
        self.hit_score = 1.0

    def reset_all(self):
        """Reset every piece of streaming state, ready for a new stream."""
        self.reset()
        self.wave_remained = np.array([])
        self.feature_remained = None
        self.feats_ctx_offset = 0  # after downsample, offset exist.
        self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
        self.total_frames = 0  # frame offset, for absolute time
        self.last_active_pos = -1  # the last frame of being activated
        self.result = {}