From 37adf7474588ebd6041e4e995c754a0747ef2d55 Mon Sep 17 00:00:00 2001 From: Sviatoslav Tsariov Date: Fri, 22 Mar 2024 18:59:42 +0300 Subject: [PATCH] Implemented strategies for sos and number recognition Deleted message sender and promt service Implemented fast whisper, but it is not working WavStack refactored into QueueStack, which can use different strategies for proccessing --- src/app.py | 84 +++++++++---------- src/config.py | 13 ++- src/file_stack.py | 32 ------- src/message/__init__.py | 7 +- src/message/message.py | 17 ---- src/message/message_composer.py | 9 -- src/message/message_sender/__init__.py | 3 - src/message/message_sender/message_sender.py | 8 -- .../message_sender/message_sender_strategy.py | 6 -- src/message/message_sender/rat_strategy.py | 12 --- src/message/message_service.py | 24 ++++++ src/message/prompt_service.py | 16 ---- src/message/strategies/__init__.py | 3 + .../strategies/base_message_strategy.py | 14 ++++ .../strategies/number_message_strategy.py | 31 +++++++ .../strategies/sos_message_strategy.py | 35 ++++++++ src/queue_stack/__init__.py | 1 + src/queue_stack/queue_stack.py | 51 +++++++++++ src/queue_stack/strategies/__init__.py | 2 + .../strategies/base_process_strategy.py | 6 ++ .../strategies/recognize_and_send_strategy.py | 14 ++++ src/recognizer/__init__.py | 5 +- src/recognizer/fast_whisper_strategy.py | 37 -------- src/recognizer/recognizer.py | 14 +++- src/recognizer/recognizer_strategy.py | 6 -- src/recognizer/strategies/__init__.py | 3 + .../strategies/base_recognizer_strategy.py | 6 ++ .../strategies/fast_whisper_strategy.py | 59 +++++++++++++ src/recognizer/strategies/whisper_strategy.py | 12 +++ src/recognizer/whisper_strategy.py | 12 --- 30 files changed, 322 insertions(+), 220 deletions(-) delete mode 100644 src/file_stack.py delete mode 100644 src/message/message.py delete mode 100644 src/message/message_composer.py delete mode 100644 src/message/message_sender/__init__.py delete mode 100644 src/message/message_sender/message_sender.py delete mode 100644 src/message/message_sender/message_sender_strategy.py delete mode 100644 src/message/message_sender/rat_strategy.py create mode 100644 src/message/message_service.py delete mode 100644 src/message/prompt_service.py create mode 100644 src/message/strategies/__init__.py create mode 100644 src/message/strategies/base_message_strategy.py create mode 100644 src/message/strategies/number_message_strategy.py create mode 100644 src/message/strategies/sos_message_strategy.py create mode 100644 src/queue_stack/__init__.py create mode 100644 src/queue_stack/queue_stack.py create mode 100644 src/queue_stack/strategies/__init__.py create mode 100644 src/queue_stack/strategies/base_process_strategy.py create mode 100644 src/queue_stack/strategies/recognize_and_send_strategy.py delete mode 100644 src/recognizer/fast_whisper_strategy.py delete mode 100644 src/recognizer/recognizer_strategy.py create mode 100644 src/recognizer/strategies/__init__.py create mode 100644 src/recognizer/strategies/base_recognizer_strategy.py create mode 100644 src/recognizer/strategies/fast_whisper_strategy.py create mode 100644 src/recognizer/strategies/whisper_strategy.py delete mode 100644 src/recognizer/whisper_strategy.py diff --git a/src/app.py b/src/app.py index 25ac26a..fc612fe 100644 --- a/src/app.py +++ b/src/app.py @@ -1,67 +1,59 @@ from flask import Flask, abort, request from tempfile import NamedTemporaryFile -from dotenv import load_dotenv -import os -import whisper -import torch + import sys -import re -load_dotenv() +import config -HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар' -HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium' -HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru' +from queue_stack import QueueStack +from queue_stack.strategies import RecognizeAndSendStrategy -# Check if NVIDIA GPU is available -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +from recognizer import Recognizer +from recognizer.strategies import WhisperStrategy, FastWhisperStrategy -# Load the Whisper model: -model = whisper.load_model(HARPYIA_MODEL, device=DEVICE) +from message import MessageService +from message.strategies import SosMessageStrategy, NumberMessageStrategy app = Flask(__name__) +whisper_recognizer = Recognizer(WhisperStrategy()) +fast_whisper_recognizer = Recognizer(FastWhisperStrategy()) + +sos_message_service = MessageService(SosMessageStrategy()) +number_message_service = MessageService(NumberMessageStrategy()) + +queue_stack = QueueStack(RecognizeAndSendStrategy()) +queue_stack.start_loop_in_thread() + @app.route("/") def hello(): - return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize_number' route." + return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize-number' route." -def recognize_files(handler_fn): - if not request.files: - abort(400) +def recognize_files(message_service: MessageService): + if not request.files: + abort(400) - results = [] + results = [] - for filename, handle in request.files.items(): - temp = NamedTemporaryFile() - handle.save(temp) - result = model.transcribe(temp.name, language=HARPYIA_LANGUAGE, initial_prompt=HARPYIA_PROMPT) - results.append({ - 'filename': filename, - 'transcript': handler_fn(result['text']), - }) + for filename, handle in request.files.items(): + temp = NamedTemporaryFile() + handle.save(temp) - print(results, file=sys.stderr) - return {'results': results} + results.append(queue_stack.append_and_await(( + temp, + whisper_recognizer, + message_service, + config.HARPYIA_LANGUAGE, + message_service.get_prompt() + ))) + + print(results, file=sys.stderr) + return {'results': results} @app.route('/recognize', methods=['POST']) def recognize(): - return recognize_files(lambda text: text) + return recognize_files(sos_message_service) -@app.route('/recognize_number', methods=['POST']) +@app.route('/recognize-number', methods=['POST']) def recognize_number(): - return recognize_files(transfer_and_clean) - -def transfer_and_clean(input_string): - number_mapping = { - "один": "1", - "два": "2", - "три": "3" - } - - for word, number in number_mapping.items(): - input_string = input_string.replace(word, number) - - input_string = re.sub(r'[^\d]+', '', input_string) - - return input_string - + return recognize_files(number_message_service) \ No newline at end of file diff --git a/src/config.py b/src/config.py index e016740..9f9cba9 100644 --- a/src/config.py +++ b/src/config.py @@ -4,9 +4,18 @@ from dotenv import load_dotenv load_dotenv() -HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар' -HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium' +HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'small' HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru' +HARPYIA_SAMPLE_RATE = os.getenv('HARPYIA_SAMPLE_RATE') or 160000 + +WHISPER_NUM_WORKERS = os.getenv('WHISPER_NUM_WORKERS') or 6 +WHISPER_CPU_THREADS = os.getenv('WHISPER_CPU_THREADS') or 10 +WHISPER_BEAM_SIZE = os.getenv('WHISPER_BEAM_SIZE') or 5 + +SOS_PROMPT = os.getenv('SOS_PROMPT') or 'спасите помогите помощь пожар караул кирилл' +NUMBER_PROMPT = os.getenv('NUMBER_PROMPT') or 'один два три четыре пять шесть семь восемь девять десять одинадцать двенадцать тринадцать сто сот' + +RAT_URL = os.getenv('RAT_URL') or 'localhost:8081' # Check if NVIDIA GPU is available DEVICE = "cuda" if torch.cuda.is_available() else "cpu" \ No newline at end of file diff --git a/src/file_stack.py b/src/file_stack.py deleted file mode 100644 index 28b223f..0000000 --- a/src/file_stack.py +++ /dev/null @@ -1,32 +0,0 @@ -from threading import Thread - -from recognizer import Recognizer -from message import MessageComposer - -class WavStack: - def __init__(self, recognizer: Recognizer, message_composer: MessageComposer): - self._stack = [] - self._recognizer = recognizer - self._message_composer = message_composer - self._running = False - - def append(self, file): - self._stack.append(file) - - def loop(self): - self._running = True - while self._running: - if self._stack: - file = self._stack.pop(0) - recognized_text = self._recognizer.recognize(file) - message = self._message_composer.compose(recognized_text) - - if message.has_prompt(): - message.send() - - def start_loop_in_thread(self): - thread = Thread(target=self.loop) - thread.start() - - def stop_loop(self): - self._running = False \ No newline at end of file diff --git a/src/message/__init__.py b/src/message/__init__.py index 52ede6f..92b98bd 100644 --- a/src/message/__init__.py +++ b/src/message/__init__.py @@ -1,6 +1 @@ -from message.prompt_service import PromptService -from message.message_sender.message_sender import MessageSender -from message.message_composer import MessageComposer -from message.message import Message - -import message.message_sender as message_sender +from message.message_service import MessageService \ No newline at end of file diff --git a/src/message/message.py b/src/message/message.py deleted file mode 100644 index 6eece7e..0000000 --- a/src/message/message.py +++ /dev/null @@ -1,17 +0,0 @@ -from message import PromptService, MessageSender - -class Message: - def __init__(self, prompt_service: PromptService, message_sender: MessageSender, \ - recognized_text: str): - self._prompt_service = prompt_service - self._message_sender = message_sender - self._recognized_text = recognized_text - - def has_prompt(self) -> bool: - return self._prompt_service.has_prompt(self._recognized_text) - - def send(self) -> None: - self._message_sender.send(self._generate_response()) - - def _generate_response(self) -> str: - return self._prompt_service.filter_words_with_prompt(self._recognized_text) \ No newline at end of file diff --git a/src/message/message_composer.py b/src/message/message_composer.py deleted file mode 100644 index 8dca533..0000000 --- a/src/message/message_composer.py +++ /dev/null @@ -1,9 +0,0 @@ -from message import Message, PromptService, MessageSender - -class MessageComposer: - def __init__(self, prompt_service: PromptService, message_sender: MessageSender): - self._prompt_service = prompt_service - self._message_sender = message_sender - - def compose(self, recognized_text) -> Message: - return Message(self._prompt_service, self._message_sender, recognized_text) diff --git a/src/message/message_sender/__init__.py b/src/message/message_sender/__init__.py deleted file mode 100644 index 34dfeb2..0000000 --- a/src/message/message_sender/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from message.message_sender.message_sender import MessageSender -from message.message_sender.message_sender_strategy import MessageSenderStrategy -from message.message_sender.rat_strategy import RatStrategy diff --git a/src/message/message_sender/message_sender.py b/src/message/message_sender/message_sender.py deleted file mode 100644 index 4705e60..0000000 --- a/src/message/message_sender/message_sender.py +++ /dev/null @@ -1,8 +0,0 @@ -from message.message_sender import MessageSenderStrategy - -class MessageSender: - def __init__(self, strategy: MessageSenderStrategy) -> None: - self._strategy = strategy - - def send(self, message) -> None: - self._strategy.send(message) diff --git a/src/message/message_sender/message_sender_strategy.py b/src/message/message_sender/message_sender_strategy.py deleted file mode 100644 index 7855a58..0000000 --- a/src/message/message_sender/message_sender_strategy.py +++ /dev/null @@ -1,6 +0,0 @@ -from abc import ABC, abstractmethod - -class MessageSenderStrategy(ABC): - @abstractmethod - def send(self, message) -> None: - pass diff --git a/src/message/message_sender/rat_strategy.py b/src/message/message_sender/rat_strategy.py deleted file mode 100644 index 3954229..0000000 --- a/src/message/message_sender/rat_strategy.py +++ /dev/null @@ -1,12 +0,0 @@ -import requests - -from message.message_sender import MessageSenderStrategy - -MESSAGE_ENDPOINT = '/message' - -class RatStrategy(MessageSenderStrategy): - def __init__(self, url): - self._url = url - - def send(self, message): - requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message}) diff --git a/src/message/message_service.py b/src/message/message_service.py new file mode 100644 index 0000000..c78e8ea --- /dev/null +++ b/src/message/message_service.py @@ -0,0 +1,24 @@ +import sys +from message.strategies import BaseMessageStrategy + +class MessageService: + def __init__(self, strategy: BaseMessageStrategy) -> None: + self._strategy = strategy + + def get_prompt(self) -> str: + self._strategy.get_prompt() + + def transfer(self, text: str) -> any: + return self._strategy.transfer(text) + + def send(self, message: str) -> any: + self._strategy.send(message) + + def transfer_and_send(self, recognized_result: any) -> any: + message = self.transfer(recognized_result) + + if message: + self.send(message) + + print('Sending message:', recognized_result, file=sys.stderr) + return message \ No newline at end of file diff --git a/src/message/prompt_service.py b/src/message/prompt_service.py deleted file mode 100644 index b443e5b..0000000 --- a/src/message/prompt_service.py +++ /dev/null @@ -1,16 +0,0 @@ -class PromptService: - def __init__(self, prompt): - self._prompt = prompt - - def has_prompt(self, text: str) -> bool: - for part in text.split(' '): - return part in self._prompt.split(' ') - - def filter_words_with_prompt(self, text: str) -> str: - words = [] - - for part in text.split(' '): - if part in self._prompt.split(' '): - words.append(part) - - return words diff --git a/src/message/strategies/__init__.py b/src/message/strategies/__init__.py new file mode 100644 index 0000000..8991a4f --- /dev/null +++ b/src/message/strategies/__init__.py @@ -0,0 +1,3 @@ +from message.strategies.base_message_strategy import BaseMessageStrategy +from message.strategies.sos_message_strategy import SosMessageStrategy +from message.strategies.number_message_strategy import NumberMessageStrategy \ No newline at end of file diff --git a/src/message/strategies/base_message_strategy.py b/src/message/strategies/base_message_strategy.py new file mode 100644 index 0000000..fd45ed0 --- /dev/null +++ b/src/message/strategies/base_message_strategy.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod + +class BaseMessageStrategy(ABC): + @abstractmethod + def get_prompt() -> str: + pass + + @abstractmethod + def transfer(self, text: str) -> any: + pass + + @abstractmethod + def send(self, message: str) -> any: + pass \ No newline at end of file diff --git a/src/message/strategies/number_message_strategy.py b/src/message/strategies/number_message_strategy.py new file mode 100644 index 0000000..054bd2d --- /dev/null +++ b/src/message/strategies/number_message_strategy.py @@ -0,0 +1,31 @@ +import re + +import config +from message.strategies import BaseMessageStrategy + +class NumberMessageStrategy(BaseMessageStrategy): + def __init__(self, prompt=config.NUMBER_PROMPT) -> None: + self._prompt = prompt + + def get_prompt(self): + return self._prompt + + def transfer(self, recognized_result: any) -> str: + return self._transfer_and_clean(recognized_result['text']) + + def _transfer_and_clean(self, text: str) -> str: + number_mapping = { + "один": "1", + "два": "2", + "три": "3" + } + + for word, number in number_mapping.items(): + transfered_text = text.replace(word, number) + + transfered_text = re.sub(r'[^\d]+', '', transfered_text) + + return {'recognized': transfered_text} + + def send(self, message: str) -> None: + pass \ No newline at end of file diff --git a/src/message/strategies/sos_message_strategy.py b/src/message/strategies/sos_message_strategy.py new file mode 100644 index 0000000..39dee28 --- /dev/null +++ b/src/message/strategies/sos_message_strategy.py @@ -0,0 +1,35 @@ +from typing import List +import requests + +import config +from message.strategies import BaseMessageStrategy + +MESSAGE_ENDPOINT = '/message' + +class SosMessageStrategy(BaseMessageStrategy): + def __init__(self, prompt=config.SOS_PROMPT, url=config.RAT_URL) -> None: + self._prompt = prompt + self._url = url + + def get_prompt(self): + return self._prompt + + def transfer(self, recognized_result: any) -> str: + return { + 'transcript': recognized_result['text'], + 'results': self._filter_words_with_prompt(recognized_result['text']), + 'segments': recognized_result['segments'] + } + + def _filter_words_with_prompt(self, text: str) -> str: + words = [] + + for prompt in self._prompt.split(' '): + if prompt in text.lower(): + words.append(prompt) + + return words + + def send(self, message) -> any: + pass + #return requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message}) diff --git a/src/queue_stack/__init__.py b/src/queue_stack/__init__.py new file mode 100644 index 0000000..b065d1b --- /dev/null +++ b/src/queue_stack/__init__.py @@ -0,0 +1 @@ +from queue_stack.queue_stack import QueueStack diff --git a/src/queue_stack/queue_stack.py b/src/queue_stack/queue_stack.py new file mode 100644 index 0000000..6c38e20 --- /dev/null +++ b/src/queue_stack/queue_stack.py @@ -0,0 +1,51 @@ +import sys + +from threading import Thread, Event, Lock + +from queue_stack.strategies import BaseProcessStrategy + +class QueueStack: + def __init__(self, strategy: BaseProcessStrategy) -> None: + self._stack = [] + self._strategy = strategy + + self._lock = Lock() + self._running = False + + self._last_response = None + + def append(self, args, event=None) -> None: + with self._lock: + self._stack.append((args, event)) + + def append_and_await(self, args) -> any: + event = Event() + self.append(args, event=event) + + event.wait() + event.clear() + + return self._last_response + + def loop(self) -> None: + self._running = True + + while self._running: + with self._lock: + if self._stack: + print('Stack length:', len(self._stack), file=sys.stderr) + (args, event) = self._stack.pop(0) + self._last_response = self._process(*args) + + if event: + event.set() + + def _process(self, *args, **kwargs) -> any: + return self._strategy.process(*args, **kwargs) + + def start_loop_in_thread(self) -> None: + thread = Thread(target=self.loop) + thread.start() + + def stop_loop(self) -> None: + self._running = False \ No newline at end of file diff --git a/src/queue_stack/strategies/__init__.py b/src/queue_stack/strategies/__init__.py new file mode 100644 index 0000000..e274e9b --- /dev/null +++ b/src/queue_stack/strategies/__init__.py @@ -0,0 +1,2 @@ +from queue_stack.strategies.base_process_strategy import BaseProcessStrategy +from queue_stack.strategies.recognize_and_send_strategy import RecognizeAndSendStrategy \ No newline at end of file diff --git a/src/queue_stack/strategies/base_process_strategy.py b/src/queue_stack/strategies/base_process_strategy.py new file mode 100644 index 0000000..00d7792 --- /dev/null +++ b/src/queue_stack/strategies/base_process_strategy.py @@ -0,0 +1,6 @@ +from abc import ABC, abstractmethod + +class BaseProcessStrategy(ABC): + @abstractmethod + def process(self, *args, **kwargs) -> any: + pass diff --git a/src/queue_stack/strategies/recognize_and_send_strategy.py b/src/queue_stack/strategies/recognize_and_send_strategy.py new file mode 100644 index 0000000..1269ff5 --- /dev/null +++ b/src/queue_stack/strategies/recognize_and_send_strategy.py @@ -0,0 +1,14 @@ +import sys + +from queue_stack.strategies import BaseProcessStrategy +from message import MessageService +from recognizer import Recognizer + +class RecognizeAndSendStrategy(BaseProcessStrategy): + def process(self, file, recognizer: Recognizer, message_service: MessageService, language, prompt) -> any: + + result = recognizer.recognize(file, language=language, prompt=prompt) + message = message_service.transfer_and_send(result) + print(message, file=sys.stderr) + + return message \ No newline at end of file diff --git a/src/recognizer/__init__.py b/src/recognizer/__init__.py index f892ae5..402b39d 100644 --- a/src/recognizer/__init__.py +++ b/src/recognizer/__init__.py @@ -1,4 +1 @@ -from recognizer.recognizer import Recognizer -from recognizer.recognizer_strategy import RecognizerStrategy -from recognizer.whisper_strategy import WhisperStrategy -from recognizer.fast_whisper_strategy import FastWhisperStrategy +from recognizer.recognizer import Recognizer \ No newline at end of file diff --git a/src/recognizer/fast_whisper_strategy.py b/src/recognizer/fast_whisper_strategy.py deleted file mode 100644 index 458f932..0000000 --- a/src/recognizer/fast_whisper_strategy.py +++ /dev/null @@ -1,37 +0,0 @@ -import whisper -from faster_whisper import WhisperModel - -import config -from recognizer import RecognizerStrategy - -class FastWhisperStrategy(RecognizerStrategy): - def __init__(self) -> None: - self._model = WhisperModel( - model_size=config.HARPYIA_MODEL, - device=config.DEVICE, - num_workers=6, - cpu_threads=10, - # in_memory=True, - ) - - def recognize(self, file) -> str: - audio = self._prepare_file(file.name) - return self._transcribe(audio) - - def _prepare_file(self, filename: str): - audio = whisper.load_audio(filename, sr=16000) - audio = whisper.pad_or_trim(audio) - return audio - - - def _transcribe(self, audio): - segments, _ = self._model.transcribe( - audio, - language=config.HARPYIA_LANGUAGE, - initial_prompt=config.HARPYIA_PROMPT, - condition_on_previous_text=False, - vad_filter=True, - beam_size=5, - ) - - return segments \ No newline at end of file diff --git a/src/recognizer/recognizer.py b/src/recognizer/recognizer.py index cb7d704..d78135d 100644 --- a/src/recognizer/recognizer.py +++ b/src/recognizer/recognizer.py @@ -1,8 +1,14 @@ -from recognizer import RecognizerStrategy +import sys + +import config +from recognizer.strategies import BaseRecognizerStrategy class Recognizer: - def __init__(self, strategy: RecognizerStrategy): + def __init__(self, strategy: BaseRecognizerStrategy) -> None: self._strategy = strategy - def recognize(self, file) -> str: - self._strategy.recognize(file) + def recognize(self, file, language, prompt) -> str: + result = self._strategy.recognize(file, language=language, prompt=prompt) + + print(f'Result: {result}', file=sys.stderr) + return result \ No newline at end of file diff --git a/src/recognizer/recognizer_strategy.py b/src/recognizer/recognizer_strategy.py deleted file mode 100644 index 16b718a..0000000 --- a/src/recognizer/recognizer_strategy.py +++ /dev/null @@ -1,6 +0,0 @@ -from abc import ABC, abstractmethod - -class RecognizerStrategy(ABC): - @abstractmethod - def recognize(self, file) -> str: - pass diff --git a/src/recognizer/strategies/__init__.py b/src/recognizer/strategies/__init__.py new file mode 100644 index 0000000..d73c755 --- /dev/null +++ b/src/recognizer/strategies/__init__.py @@ -0,0 +1,3 @@ +from recognizer.strategies.base_recognizer_strategy import BaseRecognizerStrategy +from recognizer.strategies.whisper_strategy import WhisperStrategy +from recognizer.strategies.fast_whisper_strategy import FastWhisperStrategy diff --git a/src/recognizer/strategies/base_recognizer_strategy.py b/src/recognizer/strategies/base_recognizer_strategy.py new file mode 100644 index 0000000..0bdd03a --- /dev/null +++ b/src/recognizer/strategies/base_recognizer_strategy.py @@ -0,0 +1,6 @@ +from abc import ABC, abstractmethod + +class BaseRecognizerStrategy(ABC): + @abstractmethod + def recognize(self, file, language, prompt) -> any: + pass diff --git a/src/recognizer/strategies/fast_whisper_strategy.py b/src/recognizer/strategies/fast_whisper_strategy.py new file mode 100644 index 0000000..5d0cdac --- /dev/null +++ b/src/recognizer/strategies/fast_whisper_strategy.py @@ -0,0 +1,59 @@ +import sys + +import whisper +from faster_whisper import WhisperModel + +import config +from recognizer.strategies import BaseRecognizerStrategy + +class FastWhisperStrategy(BaseRecognizerStrategy): + def __init__(self) -> None: + self._model = WhisperModel( + model_size_or_path=config.HARPYIA_MODEL, + device=config.DEVICE, + num_workers=config.WHISPER_NUM_WORKERS, + cpu_threads=config.WHISPER_CPU_THREADS + ) + + def recognize(self, file, language, prompt) -> any: + audio = self._prepare_file(file.name) + return self._transcribe(audio, language, prompt) + + def _prepare_file(self, filename: str): + audio = whisper.load_audio(filename, sr=config.HARPYIA_SAMPLE_RATE) + audio = whisper.pad_or_trim(audio) + return audio + + def _transcribe(self, audio, language, prompt): + segments, _ = self._model.transcribe( + audio, + language=language, + initial_prompt=prompt, + condition_on_previous_text=False, + vad_filter=True, + beam_size=config.WHISPER_BEAM_SIZE, + ) + + print('Segments:', file=sys.stderr) + for i in segments: + print(i, file=sys.stderr) + + words = [] + for segment in list(segments): + words.append(segment.text) + + return { + 'text': ' '.join(words), + 'segments': { + 'id': None, + 'seek': None, + 'start': None, + 'end': None, + 'text': None, + 'tokens': None, + 'temperature': None, + 'avg_logprob': None, + 'compression_ratio': None, + 'no_speech_prob': None, + } + } \ No newline at end of file diff --git a/src/recognizer/strategies/whisper_strategy.py b/src/recognizer/strategies/whisper_strategy.py new file mode 100644 index 0000000..05cdf27 --- /dev/null +++ b/src/recognizer/strategies/whisper_strategy.py @@ -0,0 +1,12 @@ +import whisper + +import config +from recognizer.strategies import BaseRecognizerStrategy + +class WhisperStrategy(BaseRecognizerStrategy): + def __init__(self) -> None: + self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE) + + def recognize(self, file, language, prompt) -> any: + return self._model.transcribe(file.name, \ + language=language, initial_prompt=prompt) \ No newline at end of file diff --git a/src/recognizer/whisper_strategy.py b/src/recognizer/whisper_strategy.py deleted file mode 100644 index 2472085..0000000 --- a/src/recognizer/whisper_strategy.py +++ /dev/null @@ -1,12 +0,0 @@ -import whisper - -import config -from recognizer import RecognizerStrategy - -class WhisperStrategy(RecognizerStrategy): - def __init__(self) -> None: - self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE) - - def recognize(self, file) -> str: - return self._model.transcribe(file.name, \ - language=config.HARPYIA_LANGUAGE, initial_prompt=config.HARPYIA_PROMPT) \ No newline at end of file