# Mirror of https://github.com/aidygus/LinVAM.git
# (synced 2024-11-30 12:28:06 +11:00; 302 lines, 11 KiB, Python)

"""Script for auto-tuning keyword-spotting thresholds in pocketsphinx."""

from __future__ import print_function

import contextlib
import os
import re
import select
import subprocess
import sys
import termios
import time

import numpy as np

from pocketsphinx.pocketsphinx import *
from sphinxbase.sphinxbase import *

# ---------------------------------------------------------------------------
# Module-level state shared by the functions below.
# ---------------------------------------------------------------------------

# Keyphrases parsed from the kwlist file (threshold suffix removed).
WORDS = []

# Prompt sequence the user reads aloud while recording: every keyphrase
# twice, interleaved with '[RANDOM]' slots for arbitrary non-keyword speech.
TEST_CASE = []

# Per-keyphrase threshold exponents; each is written to the kwlist as
# "/1e-<value>/", so a larger value means a smaller threshold.
FREQUENCY = []

# Cumulative frame count at which each prompt ended; entry 0 is the start of
# the recording.  Elapsed seconds are converted at 100 frames per second.
NO_OF_FRAMES = []

# WAV file the training speech is recorded to and later decoded from.
OUTPUT_FILENAME = 'testing_audio.wav'

def _phone_count(phrase, dictionary):
    """Return the total number of phones in *phrase*.

    *dictionary* is the list of stripped lines of a pronunciation dictionary,
    each expected to look like ``word<TAB>PH1 PH2 ...``.  Every word of a
    multi-word phrase is looked up separately and the counts are summed.

    Raises:
        ValueError: if a word of *phrase* has no dictionary entry.
    """
    total = 0
    for word in phrase.split():
        prefix = word + '\t'
        # Match the whole headword at the start of the line; a plain
        # substring test would let "go" match an entry like "ago\tAH G OW".
        entry = next((line for line in dictionary if line.startswith(prefix)), None)
        if entry is None:
            raise ValueError('%r (from keyphrase %r) not found in dictionary'
                             % (word, phrase))
        # Phones are space separated, so N phones contain N - 1 spaces.
        total += entry.count(' ') + 1
    return total


def preprocess_files(dic_path, kwlist_path):
    """Build the tuning state from the dictionary and kwlist, then tune.

    Reads the keyphrases from *kwlist_path* and derives an initial threshold
    exponent for each from its phone count in *dic_path*.  It then builds a
    shuffled prompt list, records the user reading it, and runs the
    false-alarm pass followed by the missed-detection pass.  The kwlist file
    is rewritten with the resulting thresholds.

    Raises:
        ValueError: if a keyphrase word is missing from the dictionary.
    """
    global WORDS, TEST_CASE, FREQUENCY

    # Words found in the dictionary, one stripped line per entry.
    with open(dic_path) as _f:
        dictionary = [line.strip() for line in _f]

    # Each kwlist line is "<keyphrase> /1e-N/": keep only the keyphrase and
    # skip blank lines, which would otherwise become empty keyphrases.
    with open(kwlist_path) as _f:
        lines = [line.strip() for line in _f]
    WORDS = [line[:line.rfind(' ')] for line in lines if line]
    print(WORDS)

    # Initial thresholds from the phonetic length of each keyphrase.  Clear
    # first so a repeated call does not append to stale values.
    del FREQUENCY[:]
    for phrase in WORDS:
        phones = _phone_count(phrase, dictionary)
        # Normalizing: longer phrases get a doubled exponent.
        FREQUENCY.append(phones if phones <= 3 else phones * 2)

    # Every keyphrase twice plus four '[RANDOM]' slots for arbitrary speech,
    # shuffled so the order cannot be anticipated.
    TEST_CASE = ['[RANDOM]'] * 4 + WORDS * 2
    np.random.shuffle(TEST_CASE)
    print("HERE IS YOUR TRAINING SET")
    print(TEST_CASE)

    record()
    write_frequency_to_file(kwlist_path)

    # Analysis begins: first suppress false alarms, then recover misses.
    actual_tuning(dic_path, kwlist_path, 1)
    print("Removed many false alarms. New frequency: ")
    print(FREQUENCY)
    print('Moving on to missed detections')
    actual_tuning(dic_path, kwlist_path, 0)
    print("Frequency tuned to the best of the script's ability. New frequency: ")
    print(FREQUENCY)
    # Final scoring run; process_threshold prints a per-detection report.
    process_threshold(kws_analysis(dic_path, kwlist_path))

def write_frequency_to_file(kwlist_path):
    """Overwrite *kwlist_path* with the current thresholds.

    Writes one ``<keyphrase> /1e-<exponent>/`` line per entry of FREQUENCY,
    pairing FREQUENCY[i] with WORDS[i].
    """
    # Build every line first so a WORDS/FREQUENCY mismatch raises before the
    # existing file is truncated.
    lines = [WORDS[i] + ' /1e-' + str(val) + '/\n' for i, val in enumerate(FREQUENCY)]
    # The context manager closes the handle even if a write fails.
    with open(kwlist_path, 'w') as _f:
        _f.writelines(lines)

@contextlib.contextmanager
def raw_mode(_file):
    """Temporarily turn off echo and canonical (line-buffered) input.

    Used to read single key presses from *_file*'s terminal.  The ECHO and
    ICANON local-mode flags are cleared for the duration of the ``with``
    block.  The original terminal attributes are restored afterwards, also
    when the block raises.
    """
    fd = _file.fileno()
    saved = termios.tcgetattr(fd)
    lflag_index = 3  # position of the local-mode flags in the attribute list
    patched = list(saved)
    patched[lflag_index] &= ~(termios.ECHO | termios.ICANON)
    try:
        termios.tcsetattr(fd, termios.TCSADRAIN, patched)
        yield
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, saved)

def _wait_for_enter():
    """Block until a newline is read from stdin.

    Raises:
        EOFError: if stdin reaches end of file before a newline arrives.
    """
    while True:
        key = sys.stdin.read(1)
        if not key:
            # Without this check an EOF would loop forever.
            raise EOFError('stdin closed before the recording finished')
        if key == '\n':
            return


def record():
    """Record the user reading TEST_CASE aloud and note when each prompt ends.

    Starts ``rec`` writing 16 kHz, 16-bit, mono audio to OUTPUT_FILENAME and
    shows the prompts one at a time.  Each Enter press closes the current
    prompt and appends its cumulative end time to NO_OF_FRAMES, measured in
    frames at 100 per second.  NO_OF_FRAMES is reset to ``[0]`` first.
    After the last prompt, recording continues for 2 seconds and the
    recorder is then stopped and waited for.

    Raises:
        EOFError: if stdin is closed before every prompt has been confirmed.
    """
    prompt = "-----SAY THE FOLLOWING OUT LOUD AND PRESS ENTER-----"
    print(prompt)
    print(TEST_CASE[0])
    # Equivalent to: rec -q -c 1 -r 16000 -b 16 <OUTPUT_FILENAME>
    # Holding the Popen handle lets us stop exactly this recorder; the old
    # `pkill rec` killed every process whose name merely contains "rec".
    recorder = subprocess.Popen(
        ['rec', '-q', '-c', '1', '-r', '16000', '-b', '16', OUTPUT_FILENAME])
    # Start from a single 0 so a repeated call does not append to stale data.
    NO_OF_FRAMES[:] = [0]
    previous = time.time()
    try:
        with raw_mode(sys.stdin):
            for i in range(len(TEST_CASE)):
                if i:
                    print(prompt)
                    print(TEST_CASE[i])
                # Blocking read instead of the old zero-timeout select() spin.
                _wait_for_enter()
                current = time.time()
                # Elapsed seconds -> frames at 100 frames per second.
                NO_OF_FRAMES.append(NO_OF_FRAMES[-1] + (current - previous) * 100)
                previous = current
        print("STOPPING RECORDING")
        # Keep recording for 2 more seconds after the final Enter.
        time.sleep(2)
    finally:
        # Stop our recorder even on Ctrl-C / EOF, and wait for it to exit
        # so the WAV file is complete before kws_analysis reads it.
        recorder.terminate()
        recorder.wait()
    print(NO_OF_FRAMES)

def actual_tuning(dic_path, kwlist_path, _z):
    """Iteratively adjust FREQUENCY to reduce one kind of detection error.

    Args:
        dic_path: pronunciation dictionary passed to the decoder.
        kwlist_path: kwlist file, rewritten with the thresholds under test.
        _z: 1 for the false-alarm pass, 0 for the missed-detection pass.
            In the false-alarm pass the exponent drops by 2 (higher
            threshold) while a keyphrase still has false alarms.  In the
            missed-detection pass it rises by 1 (lower threshold) while a
            keyphrase still has misses.

    A keyphrase stops being adjusted in three cases: it has no errors of
    the targeted kind, its exponent reaches the bound (<= 1 or >= 49), or
    the last step increased the other kind of error (that step is then
    undone).  At the end every keyphrase gets the exponent recorded when
    its targeted error count last decreased, or its starting value if it
    never decreased.  The kwlist is then rewritten.
    """
    fa_pass = _z == 1

    # Thresholds with the fewest targeted errors seen so far.
    minimum_inflection = [FREQUENCY[i] for i in range(len(WORDS))]
    # 1 once a keyphrase's assessment has finished.
    processed = [0] * len(WORDS)
    # Current missed-detection and false-alarm counts.
    _missed, _fa = process_threshold(kws_analysis(dic_path, kwlist_path))

    _least_negative_threshold = 1
    _most_negative_threshold = 49

    # Loop while at least one keyphrase is still being assessed.
    while 0 in processed:
        if fa_pass:
            # False alarm present: raise the threshold (smaller exponent).
            for i, val in enumerate(_fa):
                if (FREQUENCY[i] > _least_negative_threshold and processed[i] == 0
                        and val[1] > 0):
                    FREQUENCY[i] -= 2
                else:
                    processed[i] = 1
        else:
            # Missed detection present: lower the threshold (larger exponent).
            for i, val in enumerate(_missed):
                if (FREQUENCY[i] < _most_negative_threshold and processed[i] == 0
                        and val[1] > 0):
                    FREQUENCY[i] += 1
                else:
                    processed[i] = 1

        write_frequency_to_file(kwlist_path)
        print('UPDATED FREQUENCY:')
        print(FREQUENCY)

        _previous_missed = list(_missed)
        _previous_fa = list(_fa)
        _missed, _fa = process_threshold(kws_analysis(dic_path, kwlist_path))

        # If the step made the other kind of error worse, undo it and stop
        # tuning that keyphrase.
        if fa_pass:
            for i, val in enumerate(_missed):
                if val[1] > _previous_missed[i][1] and processed[i] == 0:
                    processed[i] = 1
                    FREQUENCY[i] += 2
        else:
            for i, val in enumerate(_fa):
                if val[1] > _previous_fa[i][1] and processed[i] == 0:
                    processed[i] = 1
                    FREQUENCY[i] -= 1

        # If the updated thresholds reduced the targeted errors, remember
        # them.  (The old code compared against the whole _previous_missed
        # list here, which raises TypeError on Python 3.)
        if fa_pass:
            current, previous = _fa, _previous_fa
        else:
            current, previous = _missed, _previous_missed
        for i, val in enumerate(current):
            if val[1] < previous[i][1]:
                minimum_inflection[i] = FREQUENCY[i]

    for i, best in enumerate(minimum_inflection):
        FREQUENCY[i] = best
    write_frequency_to_file(kwlist_path)

def kws_analysis(dic, kwlist, modeldir="/usr/local/share/pocketsphinx/model/"):
    """Run keyword spotting over the recorded audio and list the detections.

    Args:
        dic: path of the pronunciation dictionary.
        kwlist: path of the kwlist file holding the thresholds to evaluate.
        modeldir: pocketsphinx model directory containing ``en-us/en-us``.
            Defaults to the previously hard-coded location.

    Returns:
        A list of ``[keyphrase, frame]`` pairs, one per detection.
        ``frame`` is the byte offset of the 1024-byte chunk in which the
        keyphrase was reported, divided by 320.  That matches the
        100-frames-per-second scale of NO_OF_FRAMES for the 16 kHz 16-bit
        mono audio written by record().
    """
    hmm_dir = os.path.join(modeldir, 'en-us/en-us')

    config = Decoder.default_config()
    config.set_string('-hmm', hmm_dir)
    config.set_string('-dict', dic)
    config.set_string('-kws', kwlist)
    config.set_string('-dither', "no")
    config.set_string('-logfn', '/dev/null')
    config.set_string('-featparams', os.path.join(hmm_dir, "feat.params"))

    analysis_result = []
    # Opened on every tuning step, so close it deterministically.
    with open(OUTPUT_FILENAME, "rb") as stream:
        # Process audio chunk by chunk; on a detection record it and
        # restart the search.
        decoder = Decoder(config)
        decoder.start_utt()
        timer = 0
        while True:
            buf = stream.read(1024)
            if not buf:
                break
            decoder.process_raw(buf, False, False)
            if decoder.hyp() is not None:
                # Take the last segment explicitly rather than relying on a
                # leaked loop variable (NameError/stale value if empty).
                last_seg = None
                for last_seg in decoder.seg():
                    pass
                if last_seg is not None:
                    analysis_result.append([last_seg.word.rstrip(), timer / 320])
                decoder.end_utt()
                decoder.start_utt()
            timer += 1024
    return analysis_result

def process_threshold(analysis_result):
    """Count missed detections and false alarms against the read prompts.

    Each detection is attributed to the prompt whose end frame in
    NO_OF_FRAMES is nearest to the detection's frame.
    - During a '[RANDOM]' prompt, it is a false alarm for the detected
      keyphrase.
    - During a different keyphrase, it is a false alarm for the detected
      keyphrase and a miss for the prompted one.
    - Every keyphrase prompt with no detection attributed to it also counts
      as a miss.

    Args:
        analysis_result: ``[keyphrase, frame]`` pairs, as returned by
            kws_analysis.

    Returns:
        ``(missed, false_alarms)``: two lists of ``[keyphrase, count]``
        in WORDS order.
    """
    # 1-based prompt positions (matching NO_OF_FRAMES) that got a detection.
    matched = set()
    missed = [[word, 0] for word in WORDS]
    false_alarms = [[word, 0] for word in WORDS]

    for word, frame in analysis_result:
        # Prompt boundary closest to the frame at which kws reported the word.
        nearest = min(range(len(NO_OF_FRAMES)),
                      key=lambda k: abs(NO_OF_FRAMES[k] - frame))
        # NO_OF_FRAMES[0] is the start of the recording, not the end of a
        # prompt; unclamped, index 0 would wrap around to TEST_CASE[-1].
        position = max(nearest, 1)
        matched.add(position)
        expected = TEST_CASE[position - 1]

        if expected == '[RANDOM]':
            false_alarms[WORDS.index(word)][1] += 1
            print('FA Found', word, ' in place of RANDOM TEXT')
        elif expected == word:
            print('DETECTED CORRECTLY', word)
        else:
            print('FA Found', word, ' in place of ', expected)
            missed[WORDS.index(expected)][1] += 1
            false_alarms[WORDS.index(word)][1] += 1

    # A keyphrase prompt with no detection attributed to it was missed.
    for position, expected in enumerate(TEST_CASE, 1):
        if position not in matched and expected != '[RANDOM]':
            missed[WORDS.index(expected)][1] += 1
            print('Missed ', expected)
    return missed, false_alarms

if __name__ == '__main__':
    # Developer defaults, used only when no paths are given.
    DIC_FILE = "/home/pankaj/catkin_ws/src/pocketsphinx/demo/voice_cmd.dic"
    KWLIST_FILE = "/home/pankaj/catkin_ws/src/pocketsphinx/demo/automated.kwlist"
    if len(sys.argv) == 3:
        DIC_FILE = sys.argv[1]
        KWLIST_FILE = sys.argv[2]
    elif len(sys.argv) != 1:
        # A wrong argument count used to fall back silently to the defaults
        # above, and then overwrite that kwlist file.
        sys.exit('usage: %s DICTIONARY KWLIST' % sys.argv[0])
    preprocess_files(DIC_FILE, KWLIST_FILE)