Compare commits

..

5 Commits

Author SHA1 Message Date
3801925b72 Fixed typo
QueueStack -> ProcessStack
2024-03-29 10:39:10 +03:00
70e6e6ca90 Merge pull request 'feature/file-stack' (#1) from feature/file-stack into main
Reviewed-on: #1
2024-03-22 16:00:50 +00:00
37adf74745 Implemented strategies for sos and number recognition
Deleted message sender and promt service

Implemented fast whisper, but it is not working

WavStack refactored into QueueStack, which can use different strategies for proccessing
2024-03-22 18:59:42 +03:00
dbbf845e56 Added whisper and fast whisper implementation 2024-03-20 12:51:14 +03:00
e89122cb76 Implemented new architecture
Created message service responsible for searching the prompts inside the recognized text and sending it to the client.

Created recognizer with two strategies: whisper and Dany's fast whisper.

Implemented file stack which works in the separated thread, sends the file to the recognizer and after that sends the message to the client (Rat, for example).
2024-03-19 19:01:36 +03:00
24 changed files with 340 additions and 158 deletions

View File

@ -140,8 +140,3 @@ ENV/
# mypy # mypy
.mypy_cache/ .mypy_cache/
#
**/*.wav
**/*.mp3

View File

@ -1,131 +1,59 @@
from ctranslate2.extensions import asyncio
from flask import Flask, abort, request from flask import Flask, abort, request
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from dotenv import load_dotenv
import os
from torch.functional import Tensor
import whisper
import torch
import sys import sys
import re
from faster_whisper import WhisperModel
from test_utils import elapsed_time
from whisper_timestamped import transcribe_timestamped
from multiprocessing import Process
load_dotenv() import config
model_size = "small"
# tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large-v3, large, distil-large-v2, distil-medium.en, distil-small.en
HARPYIA_PROMPT = os.getenv("HARPYIA_PROMPT") or "спасите помогите на помощь пожар"
HARPYIA_MODEL = os.getenv("HARPYIA_MODEL") or "medium"
HARPYIA_LANGUAGE = os.getenv("HARPYIA_LANGUAGE") or "ru"
# Check if NVIDIA GPU is available from process_stack import ProcessQueue
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" from process_stack.strategies import RecognizeAndSendStrategy
DEVICE = "cpu"
# Load the Whisper model: from recognizer import Recognizer
model = WhisperModel( from recognizer.strategies import WhisperStrategy, FastWhisperStrategy
model_size,
device=DEVICE, from message import MessageService
num_workers=6, from message.strategies import SosMessageStrategy, NumberMessageStrategy
cpu_threads=10,
# in_memory=True,
)
app = Flask(__name__) app = Flask(__name__)
whisper_recognizer = Recognizer(WhisperStrategy())
fast_whisper_recognizer = Recognizer(FastWhisperStrategy())
sos_message_service = MessageService(SosMessageStrategy())
number_message_service = MessageService(NumberMessageStrategy())
process_stack = ProcessQueue(RecognizeAndSendStrategy())
process_stack.start_loop_in_thread()
@app.route("/") @app.route("/")
def hello(): def hello():
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize_number' route." return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize-number' route."
def recognize_files(message_service: MessageService):
def recognize_files(handler_fn):
if not request.files: if not request.files:
abort(400) abort(400)
results = [] results = []
asyncio.get_running_loop()
for filename, handle in request.files.items(): for filename, handle in request.files.items():
temp = NamedTemporaryFile() temp = NamedTemporaryFile()
handle.save(temp) handle.save(temp)
audio = prepare_file(temp.name) results.append(process_stack.append_and_await((
res = trans(audio) temp,
whisper_recognizer,
results.append( message_service,
{ config.HARPYIA_LANGUAGE,
"filename": filename, message_service.get_prompt()
"transcript": res, )))
}
)
print(results, file=sys.stderr) print(results, file=sys.stderr)
return {"results": results} return {'results': results}
@app.route('/recognize', methods=['POST'])
def recognize():
return recognize_files(sos_message_service)
initprompt = [ @app.route('/recognize-number', methods=['POST'])
"один",
"два",
"три",
"четыре",
"пять",
"шесть",
"семь",
"восемь",
"девять",
"десять",
"одинадцать",
"двенадцать",
"тренадцать",
"сто",
"сот",
]
@elapsed_time
def trans(audio):
segments, _ = model.transcribe(
audio,
language=HARPYIA_LANGUAGE,
initial_prompt="семь сот сто",
condition_on_previous_text=False,
vad_filter=True,
beam_size=5,
)
words = []
for e in list(segments):
words.append(e.text)
return " ".join(words)
@elapsed_time
def prepare_file(filename: str):
audio = whisper.load_audio(filename, sr=16000)
audio = whisper.pad_or_trim(audio)
return audio
@app.route("/recognize", methods=["POST"])
async def recognize():
return recognize_files(lambda text: text)
@app.route("/recognize_number", methods=["POST"])
def recognize_number(): def recognize_number():
return recognize_files(transfer_and_clean) return recognize_files(number_message_service)
def transfer_and_clean(input_string):
number_mapping = {"один": "1", "два": "2", "три": "3"}
for word, number in number_mapping.items():
input_string = input_string.replace(word, number)
input_string = re.sub(r"[^\d]+", "", input_string)
return input_string

21
src/config.py Normal file
View File

@ -0,0 +1,21 @@
import os
import torch
from dotenv import load_dotenv
load_dotenv()
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
HARPYIA_SAMPLE_RATE = os.getenv('HARPYIA_SAMPLE_RATE') or 160000
WHISPER_NUM_WORKERS = os.getenv('WHISPER_NUM_WORKERS') or 6
WHISPER_CPU_THREADS = os.getenv('WHISPER_CPU_THREADS') or 10
WHISPER_BEAM_SIZE = os.getenv('WHISPER_BEAM_SIZE') or 5
SOS_PROMPT = os.getenv('SOS_PROMPT') or 'спасите помогите помощь пожар караул гит'
NUMBER_PROMPT = os.getenv('NUMBER_PROMPT') or 'один два три четыре пять шесть семь восемь девять десять одинадцать двенадцать тринадцать сто сот'
RAT_URL = os.getenv('RAT_URL') or 'localhost:8081'
# Check if NVIDIA GPU is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

1
src/message/__init__.py Normal file
View File

@ -0,0 +1 @@
from message.message_service import MessageService

View File

@ -0,0 +1,24 @@
import sys
from message.strategies import BaseMessageStrategy
class MessageService:
def __init__(self, strategy: BaseMessageStrategy) -> None:
self._strategy = strategy
def get_prompt(self) -> str:
self._strategy.get_prompt()
def transfer(self, text: str) -> any:
return self._strategy.transfer(text)
def send(self, message: str) -> any:
self._strategy.send(message)
def transfer_and_send(self, recognized_result: any) -> any:
message = self.transfer(recognized_result)
if message:
self.send(message)
print('Sending message:', recognized_result, file=sys.stderr)
return message

View File

@ -0,0 +1,3 @@
from message.strategies.base_message_strategy import BaseMessageStrategy
from message.strategies.sos_message_strategy import SosMessageStrategy
from message.strategies.number_message_strategy import NumberMessageStrategy

View File

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
class BaseMessageStrategy(ABC):
@abstractmethod
def get_prompt() -> str:
pass
@abstractmethod
def transfer(self, text: str) -> any:
pass
@abstractmethod
def send(self, message: str) -> any:
pass

View File

@ -0,0 +1,31 @@
import re
import config
from message.strategies import BaseMessageStrategy
class NumberMessageStrategy(BaseMessageStrategy):
def __init__(self, prompt=config.NUMBER_PROMPT) -> None:
self._prompt = prompt
def get_prompt(self):
return self._prompt
def transfer(self, recognized_result: any) -> str:
return self._transfer_and_clean(recognized_result['text'])
def _transfer_and_clean(self, text: str) -> str:
number_mapping = {
"один": "1",
"два": "2",
"три": "3"
}
for word, number in number_mapping.items():
transfered_text = text.replace(word, number)
transfered_text = re.sub(r'[^\d]+', '', transfered_text)
return {'recognized': transfered_text}
def send(self, message: str) -> None:
pass

View File

@ -0,0 +1,35 @@
from typing import List
import requests
import config
from message.strategies import BaseMessageStrategy
MESSAGE_ENDPOINT = '/message'
class SosMessageStrategy(BaseMessageStrategy):
def __init__(self, prompt=config.SOS_PROMPT, url=config.RAT_URL) -> None:
self._prompt = prompt
self._url = url
def get_prompt(self):
return self._prompt
def transfer(self, recognized_result: any) -> str:
return {
'transcript': recognized_result['text'],
'results': self._filter_words_with_prompt(recognized_result['text']),
'segments': recognized_result['segments']
}
def _filter_words_with_prompt(self, text: str) -> str:
words = []
for prompt in self._prompt.split(' '):
if prompt in text.lower():
words.append(prompt)
return words
def send(self, message) -> any:
pass
#return requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})

View File

@ -0,0 +1 @@
from process_stack.process_stack import ProcessQueue

View File

@ -0,0 +1,51 @@
import sys
from threading import Thread, Event, Lock
from process_stack.strategies import BaseProcessStrategy
class ProcessQueue:
def __init__(self, strategy: BaseProcessStrategy) -> None:
self._stack = []
self._strategy = strategy
self._lock = Lock()
self._running = False
self._last_response = None
def append(self, args, event=None) -> None:
with self._lock:
self._stack.append((args, event))
def append_and_await(self, args) -> any:
event = Event()
self.append(args, event=event)
event.wait()
event.clear()
return self._last_response
def loop(self) -> None:
self._running = True
while self._running:
with self._lock:
if self._stack:
print('Stack length:', len(self._stack), file=sys.stderr)
(args, event) = self._stack.pop(0)
self._last_response = self._process(*args)
if event:
event.set()
def _process(self, *args, **kwargs) -> any:
return self._strategy.process(*args, **kwargs)
def start_loop_in_thread(self) -> None:
thread = Thread(target=self.loop)
thread.start()
def stop_loop(self) -> None:
self._running = False

View File

@ -0,0 +1,2 @@
from process_stack.strategies.base_process_strategy import BaseProcessStrategy
from process_stack.strategies.recognize_and_send_strategy import RecognizeAndSendStrategy

View File

@ -0,0 +1,6 @@
from abc import ABC, abstractmethod
class BaseProcessStrategy(ABC):
@abstractmethod
def process(self, *args, **kwargs) -> any:
pass

View File

@ -0,0 +1,14 @@
import sys
from process_stack.strategies import BaseProcessStrategy
from message import MessageService
from recognizer import Recognizer
class RecognizeAndSendStrategy(BaseProcessStrategy):
def process(self, file, recognizer: Recognizer, message_service: MessageService, language, prompt) -> any:
result = recognizer.recognize(file, language=language, prompt=prompt)
message = message_service.transfer_and_send(result)
print(message, file=sys.stderr)
return message

View File

@ -0,0 +1 @@
from recognizer.recognizer import Recognizer

View File

@ -0,0 +1,14 @@
import sys
import config
from recognizer.strategies import BaseRecognizerStrategy
class Recognizer:
def __init__(self, strategy: BaseRecognizerStrategy) -> None:
self._strategy = strategy
def recognize(self, file, language, prompt) -> str:
result = self._strategy.recognize(file, language=language, prompt=prompt)
print(f'Result: {result}', file=sys.stderr)
return result

View File

@ -0,0 +1,3 @@
from recognizer.strategies.base_recognizer_strategy import BaseRecognizerStrategy
from recognizer.strategies.whisper_strategy import WhisperStrategy
from recognizer.strategies.fast_whisper_strategy import FastWhisperStrategy

View File

@ -0,0 +1,6 @@
from abc import ABC, abstractmethod
class BaseRecognizerStrategy(ABC):
@abstractmethod
def recognize(self, file, language, prompt) -> any:
pass

View File

@ -0,0 +1,59 @@
import sys
import whisper
from faster_whisper import WhisperModel
import config
from recognizer.strategies import BaseRecognizerStrategy
class FastWhisperStrategy(BaseRecognizerStrategy):
def __init__(self) -> None:
self._model = WhisperModel(
model_size_or_path=config.HARPYIA_MODEL,
device=config.DEVICE,
num_workers=config.WHISPER_NUM_WORKERS,
cpu_threads=config.WHISPER_CPU_THREADS
)
def recognize(self, file, language, prompt) -> any:
audio = self._prepare_file(file.name)
return self._transcribe(audio, language, prompt)
def _prepare_file(self, filename: str):
audio = whisper.load_audio(filename, sr=config.HARPYIA_SAMPLE_RATE)
audio = whisper.pad_or_trim(audio)
return audio
def _transcribe(self, audio, language, prompt):
segments, _ = self._model.transcribe(
audio,
language=language,
initial_prompt=prompt,
condition_on_previous_text=False,
vad_filter=True,
beam_size=config.WHISPER_BEAM_SIZE,
)
print('Segments:', file=sys.stderr)
for i in segments:
print(i, file=sys.stderr)
words = []
for segment in list(segments):
words.append(segment.text)
return {
'text': ' '.join(words),
'segments': {
'id': None,
'seek': None,
'start': None,
'end': None,
'text': None,
'tokens': None,
'temperature': None,
'avg_logprob': None,
'compression_ratio': None,
'no_speech_prob': None,
}
}

View File

@ -0,0 +1,12 @@
import whisper
import config
from recognizer.strategies import BaseRecognizerStrategy
class WhisperStrategy(BaseRecognizerStrategy):
def __init__(self) -> None:
self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
def recognize(self, file, language, prompt) -> any:
return self._model.transcribe(file.name, \
language=language, initial_prompt=prompt)

View File

@ -1,39 +0,0 @@
import time
import sys
def elapsed_time_wrapper(unique_id: str = ""):
def decorator(func):
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
execution_time = end_time - start_time
if not unique_id == "":
print(
f"[{unique_id}] Executed in {execution_time} seconds",
file=sys.stderr,
)
else:
print(f"Executed in {execution_time} seconds", file=sys.stderr)
return result
return wrapper
return decorator
def elapsed_time(func):
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
execution_time = end_time - start_time
print(
f"Executed in {execution_time} seconds",
sep="\n",
)
return result
return wrapper