1
0
mirror of https://github.com/aidygus/LinVAM.git synced 2024-11-27 02:58:06 +11:00
LinVAM/set_kws_threshold.py

302 lines
11 KiB
Python
Raw Normal View History

"""Script for auto tuning keyword spotting thresholds in pocketsphinx"""
from __future__ import print_function
import contextlib
import os
import re
import select
import subprocess
import sys
import termios
import time

import numpy as np
from pocketsphinx.pocketsphinx import *
from sphinxbase.sphinxbase import *
# Module-level state shared by the functions below:
#   WORDS        - keyphrases read from the kwlist file, without their
#                  trailing /threshold/ field
#   TEST_CASE    - shuffled training script: every keyphrase twice plus
#                  four '[RANDOM]' slots where a non-keyphrase is spoken
#   FREQUENCY    - threshold exponent per keyphrase; entry i is written to
#                  the kwlist as /1e-FREQUENCY[i]/
#   NO_OF_FRAMES - cumulative time (1/100 s units) at which each prompt was
#                  finished; element 0 marks the start of the recording
WORDS, TEST_CASE, FREQUENCY, NO_OF_FRAMES = [], [], [], []
# WAV file the training speech is recorded to and later analysed from.
OUTPUT_FILENAME = 'testing_audio.wav'
def _keyphrase(kwlist_line):
    """
    Return the keyphrase of a stripped kwlist line.

    A trailing ``/threshold/`` field is dropped if present; a line without
    one is returned unchanged (instead of losing its last character).
    """
    fields = kwlist_line.rsplit(None, 1)
    if len(fields) == 2 and fields[1].startswith('/') and fields[1].endswith('/'):
        return fields[0]
    return kwlist_line


def _count_phones(word, dictionary):
    """
    Return the number of phones listed for *word* in the dictionary.

    *dictionary* is the list of stripped lines of the .dic file, each of the
    form ``WORD<whitespace>PHONE PHONE ...``. Only an exact match on the first
    field counts (so ``go`` does not match the entry for ``ago``); the first
    matching entry wins.

    Raises ValueError if *word* has no entry.
    """
    for entry in dictionary:
        fields = entry.split(None, 1)
        if len(fields) == 2 and fields[0] == word:
            return len(fields[1].split())
    raise ValueError('%r not found in pronunciation dictionary' % (word,))


def preprocess_files(dic_path, kwlist_path):
    """
    Generate the required lists, record training speech and run the tuning.

    Reads keyphrases from *kwlist_path* into WORDS and derives an initial
    threshold exponent for each (stored in FREQUENCY) from its phone count in
    *dic_path*. It then builds and shuffles TEST_CASE, records the user via
    record(), and runs the false-alarm pass followed by the missed-detection
    pass. *kwlist_path* is overwritten with the tuned thresholds.

    Raises ValueError if a word of a keyphrase is missing from the dictionary.
    """
    global WORDS, TEST_CASE, FREQUENCY
    with open(dic_path) as _f:
        dictionary = [line.strip() for line in _f]
    with open(kwlist_path) as _f:
        # Blank lines would otherwise turn into empty "keyphrases".
        WORDS = [_keyphrase(line.strip()) for line in _f if line.strip()]
    print(WORDS)
    # Start from an empty list so a repeated call does not accumulate
    # stale thresholds (other functions index FREQUENCY like WORDS).
    del FREQUENCY[:]
    # Initial exponent from the total phone count of the phrase: small
    # counts are used as-is, larger ones are doubled.
    for phrase in WORDS:
        phones = sum(_count_phones(word, dictionary) for word in phrase.split())
        FREQUENCY.append(phones if phones <= 3 else phones * 2)
    # Every keyphrase appears twice, mixed with '[RANDOM]' slots whose
    # detections count as false alarms.
    TEST_CASE = ['[RANDOM]', '[RANDOM]'] + WORDS + ['[RANDOM]', '[RANDOM]'] + WORDS
    np.random.shuffle(TEST_CASE)
    print("HERE IS YOUR TRAINING SET")
    print(TEST_CASE)
    record()
    write_frequency_to_file(kwlist_path)
    # Analysis begins
    actual_tuning(dic_path, kwlist_path, 1)
    print("Removed many false alarms. New frequency: ")
    print(FREQUENCY)
    print('Moving on to missed detections')
    actual_tuning(dic_path, kwlist_path, 0)
    print("Frequency tuned to the best of the script's ability. New frequency: ")
    print(FREQUENCY)
    # Final run for its report only; process_threshold prints as it goes.
    process_threshold(kws_analysis(dic_path, kwlist_path))
def write_frequency_to_file(kwlist_path):
    """
    Update the kwlist file with the current thresholds.

    Writes one line per FREQUENCY entry: ``<WORDS[i]> /1e-<FREQUENCY[i]>/``.
    All lines are built before the file is opened, so a WORDS/FREQUENCY
    mismatch (IndexError) leaves the existing file intact instead of
    truncated. The file is closed even if writing fails.
    """
    lines = [WORDS[i] + ' /1e-' + str(val) + '/\n' for i, val in enumerate(FREQUENCY)]
    with open(kwlist_path, 'w') as kwlist_file:
        kwlist_file.writelines(lines)
@contextlib.contextmanager
def raw_mode(_file):
    """
    Temporarily switch off echo and canonical (line-buffered) input.

    While the context is active, keystrokes on the terminal behind ``_file``
    become readable without waiting for a complete line and are not echoed;
    record() relies on this to notice each Enter press. The saved terminal
    attributes are restored on exit, even when the body raises.
    """
    descriptor = _file.fileno()
    saved = termios.tcgetattr(descriptor)
    modified = list(saved)
    # Element 3 of the tcgetattr() list holds the local-mode flags (lflag).
    modified[3] &= ~(termios.ECHO | termios.ICANON)
    try:
        termios.tcsetattr(descriptor, termios.TCSADRAIN, modified)
        yield
    finally:
        termios.tcsetattr(descriptor, termios.TCSADRAIN, saved)
def record():
    """
    Record the user's speech with a timestamp for each spoken TEST_CASE entry.

    Starts ``rec`` (mono, 16 kHz, 16-bit) writing to OUTPUT_FILENAME, shows
    one TEST_CASE entry at a time and waits for Enter after each. Every Enter
    appends the cumulative elapsed time, in 1/100 s units, to NO_OF_FRAMES.
    So NO_OF_FRAMES[0] == 0 marks the start of recording and
    NO_OF_FRAMES[k] the end of TEST_CASE[k-1].

    The recorder is stopped on every exit path, including KeyboardInterrupt.

    Raises EOFError if stdin closes before the whole training set was read.
    """
    prompt = "-----SAY THE FOLLOWING OUT LOUD AND PRESS ENTER-----"
    # Fresh timeline, so indices stay aligned with TEST_CASE on a second call.
    del NO_OF_FRAMES[:]
    print(prompt)
    print(TEST_CASE[0])
    # Run rec directly (no shell, no '&') so it can be stopped by its own
    # PID; 'pkill rec' would also kill unrelated processes matching "rec".
    recorder = subprocess.Popen(['rec', '-q', '-c', '1', '-r', '16000', '-b', '16',
                                 OUTPUT_FILENAME])
    NO_OF_FRAMES.append(0)
    previous = time.time()
    i = 0
    try:
        with raw_mode(sys.stdin):
            while True:
                # Blocking read (instead of polling select() with a zero
                # timeout); in non-canonical mode it returns per keystroke.
                char = sys.stdin.read(1)
                if not char:
                    raise EOFError('stdin closed before the training set was finished')
                if char != '\n':
                    continue
                current = time.time()
                NO_OF_FRAMES.append(NO_OF_FRAMES[-1] + (current - previous) * 100)
                previous = current
                if i == len(TEST_CASE) - 1:
                    print("STOPPING RECORDING")
                    # Let rec capture the tail of the last utterance.
                    time.sleep(2)
                    break
                i += 1
                print(prompt)
                print(TEST_CASE[i])
    finally:
        # Stop recording
        recorder.terminate()
        recorder.wait()
    print(NO_OF_FRAMES)
def actual_tuning(dic_path, kwlist_path, _z):
    """
    Process false alarms or missed detections to tune FREQUENCY.

    Keyphrase i is written to the kwlist as /1e-FREQUENCY[i]/, so lowering
    the exponent raises the threshold value and vice versa.

    _z == 1 -- false-alarm pass: while a word still has false alarms and its
        exponent is above 1, lower the exponent by 2. If that step increases
        the word's missed detections, undo it and stop tuning that word.
    otherwise -- missed-detection pass: while a word still has misses and
        its exponent is below 49, raise it by 1. If that step increases the
        word's false alarms, undo it and stop tuning that word.

    At the end each exponent is reset to its value at the most recent
    iteration where the targeted error count dropped (or to its starting
    value if it never did), and the kwlist file is rewritten.
    """
    fa_pass = _z == 1
    # Per-iteration change of the exponent for this pass.
    step = -2 if fa_pass else 1
    _least_negative_threshold = 1
    _most_negative_threshold = 49
    # Best exponent seen so far for each word.
    minimum_inflection = [FREQUENCY[i] for i in range(len(WORDS))]
    # 1 once a word's assessment has finished.
    processed = [0] * len(WORDS)
    _missed, _fa = process_threshold(kws_analysis(dic_path, kwlist_path))
    while 0 in processed:
        targeted = _fa if fa_pass else _missed
        for i, val in enumerate(targeted):
            if fa_pass:
                in_range = FREQUENCY[i] > _least_negative_threshold
            else:
                in_range = FREQUENCY[i] < _most_negative_threshold
            if in_range and processed[i] == 0 and val[1] > 0:
                FREQUENCY[i] += step
            else:
                processed[i] = 1
        write_frequency_to_file(kwlist_path)
        print('UPDATED FREQUENCY:')
        print(FREQUENCY)
        _previous_missed = list(_missed)
        _previous_fa = list(_fa)
        _missed, _fa = process_threshold(kws_analysis(dic_path, kwlist_path))
        if fa_pass:
            targeted, previous_targeted = _fa, _previous_fa
            opposite, previous_opposite = _missed, _previous_missed
        else:
            targeted, previous_targeted = _missed, _previous_missed
            opposite, previous_opposite = _fa, _previous_fa
        # A step that made the other error kind worse is undone, and the
        # word is frozen at its previous exponent.
        for i, val in enumerate(opposite):
            if val[1] > previous_opposite[i][1] and processed[i] == 0:
                processed[i] = 1
                FREQUENCY[i] -= step
        # Save exponents that reduced the targeted error count. Compare
        # per-word counts (previous_targeted[i][1]) in both passes.
        for i, val in enumerate(targeted):
            if val[1] < previous_targeted[i][1]:
                minimum_inflection[i] = FREQUENCY[i]
    for i, best in enumerate(minimum_inflection):
        FREQUENCY[i] = best
    write_frequency_to_file(kwlist_path)
def kws_analysis(dic, kwlist, modeldir="/usr/local/share/pocketsphinx/model/"):
    """
    Run keyword spotting on the recorded speech with the current thresholds.

    Arguments:
        dic: path of the pronunciation dictionary.
        kwlist: path of the keyphrase list (with thresholds).
        modeldir: directory containing the ``en-us/en-us`` acoustic model;
            defaults to the previously hard-coded location.

    Returns a list of ``[keyphrase, time]`` pairs, one per detection.
    ``time`` is the number of bytes of OUTPUT_FILENAME read before the chunk
    that triggered the detection, divided by 320. For 16 kHz 16-bit mono
    audio that is roughly 1/100 s units (the WAV header is included in the
    byte count).
    """
    analysis_result = []
    hmm_dir = os.path.join(modeldir, 'en-us/en-us')
    # Create a decoder with certain model
    config = Decoder.default_config()
    config.set_string('-hmm', hmm_dir)
    config.set_string('-dict', dic)
    config.set_string('-kws', kwlist)
    config.set_string('-dither', "no")
    config.set_string('-logfn', '/dev/null')
    config.set_string('-featparams', os.path.join(hmm_dir, "feat.params"))
    with open(OUTPUT_FILENAME, "rb") as stream:
        decoder = Decoder(config)
        decoder.start_utt()
        timer = 0
        # Process audio chunk by chunk. On a detection, record it and
        # restart the utterance so the search begins afresh.
        for buf in iter(lambda: stream.read(1024), b''):
            decoder.process_raw(buf, False, False)
            hypothesis = decoder.hyp()
            if hypothesis is not None:
                segments = list(decoder.seg())
                # Last segment as before; fall back to the hypothesis string
                # rather than reusing a stale/unbound loop variable.
                word = segments[-1].word if segments else hypothesis.hypstr
                analysis_result.append([word.rstrip(), timer / 320])
                decoder.end_utt()
                decoder.start_utt()
            timer += 1024
        decoder.end_utt()
    return analysis_result
def process_threshold(analysis_result):
    """
    Calculate missed detections and false alarms per keyphrase.

    Argument: analysis_result = list of [keyphrase, time] pairs from
    kws_analysis(). Each time must use the same scale as NO_OF_FRAMES.

    Each detection is attributed to the TEST_CASE entry whose end time
    (NO_OF_FRAMES[k] for TEST_CASE[k-1]) is closest:
      * a '[RANDOM]' slot  -> false alarm for the detected keyphrase;
      * the same keyphrase -> correct detection;
      * another keyphrase  -> miss for the spoken one and false alarm for the
        detected one.
    Every keyphrase slot with no detection attributed to it is a miss.

    Returns (missed, false_alarms), each a list of [keyphrase, count] in
    WORDS order.
    """
    missed = [[word, 0] for word in WORDS]
    false_alarms = [[word, 0] for word in WORDS]
    # NO_OF_FRAMES indices that received at least one detection.
    _indices = set()
    for val in analysis_result:
        detected, when = val[0], val[1]
        # Timestamp in speech closest to the timestamp of the detection.
        _index = min(range(len(NO_OF_FRAMES)), key=lambda k: abs(NO_OF_FRAMES[k] - when))
        # NO_OF_FRAMES[0] is the start of the recording, not the end of a
        # prompt; without clamping, TEST_CASE[-1] would blame the last prompt.
        _index = max(_index, 1)
        _indices.add(_index)
        spoken = TEST_CASE[_index - 1]
        if spoken == '[RANDOM]':
            false_alarms[WORDS.index(detected)][1] += 1
            print('FA Found', detected, ' in place of RANDOM TEXT')
        elif spoken == detected:
            print('DETECTED CORRECTLY', detected)
        else:
            print('FA Found', detected, ' in place of ', spoken)
            missed[WORDS.index(spoken)][1] += 1
            false_alarms[WORDS.index(detected)][1] += 1
    # A keyphrase slot with no detection attributed to it was missed.
    for i, val in enumerate(TEST_CASE):
        if i + 1 not in _indices and val != '[RANDOM]':
            missed[WORDS.index(val)][1] += 1
            print('Missed ', val)
    return missed, false_alarms
if __name__ == '__main__':
    # Usage: set_kws_threshold.py [DICTIONARY KWLIST]
    # With no arguments, the hard-coded (developer-specific) paths are used.
    DIC_FILE = "/home/pankaj/catkin_ws/src/pocketsphinx/demo/voice_cmd.dic"
    KWLIST_FILE = "/home/pankaj/catkin_ws/src/pocketsphinx/demo/automated.kwlist"
    if len(sys.argv) == 3:
        DIC_FILE = sys.argv[1]
        KWLIST_FILE = sys.argv[2]
    elif len(sys.argv) != 1:
        # Reject a wrong argument count instead of silently ignoring the
        # given arguments and falling back to the defaults.
        sys.stderr.write('usage: %s [DICTIONARY KWLIST]\n' % sys.argv[0])
        sys.exit(2)
    preprocess_files(DIC_FILE, KWLIST_FILE)