Implemented strategies for sos and number recognition

Deleted message sender and promt service

Implemented fast whisper, but it is not working

WavStack refactored into QueueStack, which can use different strategies for proccessing
This commit is contained in:
Sviatoslav Tsariov Yurievich 2024-03-22 18:59:42 +03:00
parent dbbf845e56
commit 37adf74745
30 changed files with 322 additions and 220 deletions

View File

@ -1,67 +1,59 @@
from flask import Flask, abort, request from flask import Flask, abort, request
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from dotenv import load_dotenv
import os
import whisper
import torch
import sys import sys
import re
load_dotenv() import config
HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар' from queue_stack import QueueStack
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium' from queue_stack.strategies import RecognizeAndSendStrategy
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
# Check if NVIDIA GPU is available from recognizer import Recognizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" from recognizer.strategies import WhisperStrategy, FastWhisperStrategy
# Load the Whisper model: from message import MessageService
model = whisper.load_model(HARPYIA_MODEL, device=DEVICE) from message.strategies import SosMessageStrategy, NumberMessageStrategy
app = Flask(__name__) app = Flask(__name__)
whisper_recognizer = Recognizer(WhisperStrategy())
fast_whisper_recognizer = Recognizer(FastWhisperStrategy())
sos_message_service = MessageService(SosMessageStrategy())
number_message_service = MessageService(NumberMessageStrategy())
queue_stack = QueueStack(RecognizeAndSendStrategy())
queue_stack.start_loop_in_thread()
@app.route("/") @app.route("/")
def hello(): def hello():
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize_number' route." return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize-number' route."
def recognize_files(handler_fn): def recognize_files(message_service: MessageService):
if not request.files: if not request.files:
abort(400) abort(400)
results = [] results = []
for filename, handle in request.files.items(): for filename, handle in request.files.items():
temp = NamedTemporaryFile() temp = NamedTemporaryFile()
handle.save(temp) handle.save(temp)
result = model.transcribe(temp.name, language=HARPYIA_LANGUAGE, initial_prompt=HARPYIA_PROMPT)
results.append({
'filename': filename,
'transcript': handler_fn(result['text']),
})
print(results, file=sys.stderr) results.append(queue_stack.append_and_await((
return {'results': results} temp,
whisper_recognizer,
message_service,
config.HARPYIA_LANGUAGE,
message_service.get_prompt()
)))
print(results, file=sys.stderr)
return {'results': results}
@app.route('/recognize', methods=['POST']) @app.route('/recognize', methods=['POST'])
def recognize(): def recognize():
return recognize_files(lambda text: text) return recognize_files(sos_message_service)
@app.route('/recognize_number', methods=['POST']) @app.route('/recognize-number', methods=['POST'])
def recognize_number(): def recognize_number():
return recognize_files(transfer_and_clean) return recognize_files(number_message_service)
def transfer_and_clean(input_string):
number_mapping = {
"один": "1",
"два": "2",
"три": "3"
}
for word, number in number_mapping.items():
input_string = input_string.replace(word, number)
input_string = re.sub(r'[^\d]+', '', input_string)
return input_string

View File

@ -4,9 +4,18 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар' HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'small'
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru' HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
HARPYIA_SAMPLE_RATE = os.getenv('HARPYIA_SAMPLE_RATE') or 160000
WHISPER_NUM_WORKERS = os.getenv('WHISPER_NUM_WORKERS') or 6
WHISPER_CPU_THREADS = os.getenv('WHISPER_CPU_THREADS') or 10
WHISPER_BEAM_SIZE = os.getenv('WHISPER_BEAM_SIZE') or 5
SOS_PROMPT = os.getenv('SOS_PROMPT') or 'спасите помогите помощь пожар караул кирилл'
NUMBER_PROMPT = os.getenv('NUMBER_PROMPT') or 'один два три четыре пять шесть семь восемь девять десять одинадцать двенадцать тринадцать сто сот'
RAT_URL = os.getenv('RAT_URL') or 'localhost:8081'
# Check if NVIDIA GPU is available # Check if NVIDIA GPU is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

View File

@ -1,32 +0,0 @@
from threading import Thread
from recognizer import Recognizer
from message import MessageComposer
class WavStack:
def __init__(self, recognizer: Recognizer, message_composer: MessageComposer):
self._stack = []
self._recognizer = recognizer
self._message_composer = message_composer
self._running = False
def append(self, file):
self._stack.append(file)
def loop(self):
self._running = True
while self._running:
if self._stack:
file = self._stack.pop(0)
recognized_text = self._recognizer.recognize(file)
message = self._message_composer.compose(recognized_text)
if message.has_prompt():
message.send()
def start_loop_in_thread(self):
thread = Thread(target=self.loop)
thread.start()
def stop_loop(self):
self._running = False

View File

@ -1,6 +1 @@
from message.prompt_service import PromptService from message.message_service import MessageService
from message.message_sender.message_sender import MessageSender
from message.message_composer import MessageComposer
from message.message import Message
import message.message_sender as message_sender

View File

@ -1,17 +0,0 @@
from message import PromptService, MessageSender
class Message:
def __init__(self, prompt_service: PromptService, message_sender: MessageSender, \
recognized_text: str):
self._prompt_service = prompt_service
self._message_sender = message_sender
self._recognized_text = recognized_text
def has_prompt(self) -> bool:
return self._prompt_service.has_prompt(self._recognized_text)
def send(self) -> None:
self._message_sender.send(self._generate_response())
def _generate_response(self) -> str:
return self._prompt_service.filter_words_with_prompt(self._recognized_text)

View File

@ -1,9 +0,0 @@
from message import Message, PromptService, MessageSender
class MessageComposer:
def __init__(self, prompt_service: PromptService, message_sender: MessageSender):
self._prompt_service = prompt_service
self._message_sender = message_sender
def compose(self, recognized_text) -> Message:
return Message(self._prompt_service, self._message_sender, recognized_text)

View File

@ -1,3 +0,0 @@
from message.message_sender.message_sender import MessageSender
from message.message_sender.message_sender_strategy import MessageSenderStrategy
from message.message_sender.rat_strategy import RatStrategy

View File

@ -1,8 +0,0 @@
from message.message_sender import MessageSenderStrategy
class MessageSender:
def __init__(self, strategy: MessageSenderStrategy) -> None:
self._strategy = strategy
def send(self, message) -> None:
self._strategy.send(message)

View File

@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class MessageSenderStrategy(ABC):
@abstractmethod
def send(self, message) -> None:
pass

View File

@ -1,12 +0,0 @@
import requests
from message.message_sender import MessageSenderStrategy
MESSAGE_ENDPOINT = '/message'
class RatStrategy(MessageSenderStrategy):
def __init__(self, url):
self._url = url
def send(self, message):
requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})

View File

@ -0,0 +1,24 @@
import sys
from message.strategies import BaseMessageStrategy
class MessageService:
def __init__(self, strategy: BaseMessageStrategy) -> None:
self._strategy = strategy
def get_prompt(self) -> str:
self._strategy.get_prompt()
def transfer(self, text: str) -> any:
return self._strategy.transfer(text)
def send(self, message: str) -> any:
self._strategy.send(message)
def transfer_and_send(self, recognized_result: any) -> any:
message = self.transfer(recognized_result)
if message:
self.send(message)
print('Sending message:', recognized_result, file=sys.stderr)
return message

View File

@ -1,16 +0,0 @@
class PromptService:
def __init__(self, prompt):
self._prompt = prompt
def has_prompt(self, text: str) -> bool:
for part in text.split(' '):
return part in self._prompt.split(' ')
def filter_words_with_prompt(self, text: str) -> str:
words = []
for part in text.split(' '):
if part in self._prompt.split(' '):
words.append(part)
return words

View File

@ -0,0 +1,3 @@
from message.strategies.base_message_strategy import BaseMessageStrategy
from message.strategies.sos_message_strategy import SosMessageStrategy
from message.strategies.number_message_strategy import NumberMessageStrategy

View File

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
class BaseMessageStrategy(ABC):
@abstractmethod
def get_prompt() -> str:
pass
@abstractmethod
def transfer(self, text: str) -> any:
pass
@abstractmethod
def send(self, message: str) -> any:
pass

View File

@ -0,0 +1,31 @@
import re
import config
from message.strategies import BaseMessageStrategy
class NumberMessageStrategy(BaseMessageStrategy):
def __init__(self, prompt=config.NUMBER_PROMPT) -> None:
self._prompt = prompt
def get_prompt(self):
return self._prompt
def transfer(self, recognized_result: any) -> str:
return self._transfer_and_clean(recognized_result['text'])
def _transfer_and_clean(self, text: str) -> str:
number_mapping = {
"один": "1",
"два": "2",
"три": "3"
}
for word, number in number_mapping.items():
transfered_text = text.replace(word, number)
transfered_text = re.sub(r'[^\d]+', '', transfered_text)
return {'recognized': transfered_text}
def send(self, message: str) -> None:
pass

View File

@ -0,0 +1,35 @@
from typing import List
import requests
import config
from message.strategies import BaseMessageStrategy
MESSAGE_ENDPOINT = '/message'
class SosMessageStrategy(BaseMessageStrategy):
def __init__(self, prompt=config.SOS_PROMPT, url=config.RAT_URL) -> None:
self._prompt = prompt
self._url = url
def get_prompt(self):
return self._prompt
def transfer(self, recognized_result: any) -> str:
return {
'transcript': recognized_result['text'],
'results': self._filter_words_with_prompt(recognized_result['text']),
'segments': recognized_result['segments']
}
def _filter_words_with_prompt(self, text: str) -> str:
words = []
for prompt in self._prompt.split(' '):
if prompt in text.lower():
words.append(prompt)
return words
def send(self, message) -> any:
pass
#return requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})

View File

@ -0,0 +1 @@
from queue_stack.queue_stack import QueueStack

View File

@ -0,0 +1,51 @@
import sys
from threading import Thread, Event, Lock
from queue_stack.strategies import BaseProcessStrategy
class QueueStack:
def __init__(self, strategy: BaseProcessStrategy) -> None:
self._stack = []
self._strategy = strategy
self._lock = Lock()
self._running = False
self._last_response = None
def append(self, args, event=None) -> None:
with self._lock:
self._stack.append((args, event))
def append_and_await(self, args) -> any:
event = Event()
self.append(args, event=event)
event.wait()
event.clear()
return self._last_response
def loop(self) -> None:
self._running = True
while self._running:
with self._lock:
if self._stack:
print('Stack length:', len(self._stack), file=sys.stderr)
(args, event) = self._stack.pop(0)
self._last_response = self._process(*args)
if event:
event.set()
def _process(self, *args, **kwargs) -> any:
return self._strategy.process(*args, **kwargs)
def start_loop_in_thread(self) -> None:
thread = Thread(target=self.loop)
thread.start()
def stop_loop(self) -> None:
self._running = False

View File

@ -0,0 +1,2 @@
from queue_stack.strategies.base_process_strategy import BaseProcessStrategy
from queue_stack.strategies.recognize_and_send_strategy import RecognizeAndSendStrategy

View File

@ -0,0 +1,6 @@
from abc import ABC, abstractmethod
class BaseProcessStrategy(ABC):
@abstractmethod
def process(self, *args, **kwargs) -> any:
pass

View File

@ -0,0 +1,14 @@
import sys
from queue_stack.strategies import BaseProcessStrategy
from message import MessageService
from recognizer import Recognizer
class RecognizeAndSendStrategy(BaseProcessStrategy):
def process(self, file, recognizer: Recognizer, message_service: MessageService, language, prompt) -> any:
result = recognizer.recognize(file, language=language, prompt=prompt)
message = message_service.transfer_and_send(result)
print(message, file=sys.stderr)
return message

View File

@ -1,4 +1 @@
from recognizer.recognizer import Recognizer from recognizer.recognizer import Recognizer
from recognizer.recognizer_strategy import RecognizerStrategy
from recognizer.whisper_strategy import WhisperStrategy
from recognizer.fast_whisper_strategy import FastWhisperStrategy

View File

@ -1,37 +0,0 @@
import whisper
from faster_whisper import WhisperModel
import config
from recognizer import RecognizerStrategy
class FastWhisperStrategy(RecognizerStrategy):
def __init__(self) -> None:
self._model = WhisperModel(
model_size=config.HARPYIA_MODEL,
device=config.DEVICE,
num_workers=6,
cpu_threads=10,
# in_memory=True,
)
def recognize(self, file) -> str:
audio = self._prepare_file(file.name)
return self._transcribe(audio)
def _prepare_file(self, filename: str):
audio = whisper.load_audio(filename, sr=16000)
audio = whisper.pad_or_trim(audio)
return audio
def _transcribe(self, audio):
segments, _ = self._model.transcribe(
audio,
language=config.HARPYIA_LANGUAGE,
initial_prompt=config.HARPYIA_PROMPT,
condition_on_previous_text=False,
vad_filter=True,
beam_size=5,
)
return segments

View File

@ -1,8 +1,14 @@
from recognizer import RecognizerStrategy import sys
import config
from recognizer.strategies import BaseRecognizerStrategy
class Recognizer: class Recognizer:
def __init__(self, strategy: RecognizerStrategy): def __init__(self, strategy: BaseRecognizerStrategy) -> None:
self._strategy = strategy self._strategy = strategy
def recognize(self, file) -> str: def recognize(self, file, language, prompt) -> str:
self._strategy.recognize(file) result = self._strategy.recognize(file, language=language, prompt=prompt)
print(f'Result: {result}', file=sys.stderr)
return result

View File

@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class RecognizerStrategy(ABC):
@abstractmethod
def recognize(self, file) -> str:
pass

View File

@ -0,0 +1,3 @@
from recognizer.strategies.base_recognizer_strategy import BaseRecognizerStrategy
from recognizer.strategies.whisper_strategy import WhisperStrategy
from recognizer.strategies.fast_whisper_strategy import FastWhisperStrategy

View File

@ -0,0 +1,6 @@
from abc import ABC, abstractmethod
class BaseRecognizerStrategy(ABC):
@abstractmethod
def recognize(self, file, language, prompt) -> any:
pass

View File

@ -0,0 +1,59 @@
import sys
import whisper
from faster_whisper import WhisperModel
import config
from recognizer.strategies import BaseRecognizerStrategy
class FastWhisperStrategy(BaseRecognizerStrategy):
def __init__(self) -> None:
self._model = WhisperModel(
model_size_or_path=config.HARPYIA_MODEL,
device=config.DEVICE,
num_workers=config.WHISPER_NUM_WORKERS,
cpu_threads=config.WHISPER_CPU_THREADS
)
def recognize(self, file, language, prompt) -> any:
audio = self._prepare_file(file.name)
return self._transcribe(audio, language, prompt)
def _prepare_file(self, filename: str):
audio = whisper.load_audio(filename, sr=config.HARPYIA_SAMPLE_RATE)
audio = whisper.pad_or_trim(audio)
return audio
def _transcribe(self, audio, language, prompt):
segments, _ = self._model.transcribe(
audio,
language=language,
initial_prompt=prompt,
condition_on_previous_text=False,
vad_filter=True,
beam_size=config.WHISPER_BEAM_SIZE,
)
print('Segments:', file=sys.stderr)
for i in segments:
print(i, file=sys.stderr)
words = []
for segment in list(segments):
words.append(segment.text)
return {
'text': ' '.join(words),
'segments': {
'id': None,
'seek': None,
'start': None,
'end': None,
'text': None,
'tokens': None,
'temperature': None,
'avg_logprob': None,
'compression_ratio': None,
'no_speech_prob': None,
}
}

View File

@ -0,0 +1,12 @@
import whisper
import config
from recognizer.strategies import BaseRecognizerStrategy
class WhisperStrategy(BaseRecognizerStrategy):
def __init__(self) -> None:
self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
def recognize(self, file, language, prompt) -> any:
return self._model.transcribe(file.name, \
language=language, initial_prompt=prompt)

View File

@ -1,12 +0,0 @@
import whisper
import config
from recognizer import RecognizerStrategy
class WhisperStrategy(RecognizerStrategy):
def __init__(self) -> None:
self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
def recognize(self, file) -> str:
return self._model.transcribe(file.name, \
language=config.HARPYIA_LANGUAGE, initial_prompt=config.HARPYIA_PROMPT)