Compare commits

..

1 Commit

Author SHA1 Message Date
e4fee65c38 Works in 700ms 2024-03-01 13:57:11 +03:00
24 changed files with 158 additions and 340 deletions

View File

@@ -140,3 +140,8 @@ ENV/
# mypy
.mypy_cache/
#
**/*.wav
**/*.mp3

View File

@@ -1,59 +1,131 @@
from ctranslate2.extensions import asyncio
from flask import Flask, abort, request
from tempfile import NamedTemporaryFile
from dotenv import load_dotenv
import os
from torch.functional import Tensor
import whisper
import torch
import sys
import re
from faster_whisper import WhisperModel
from test_utils import elapsed_time
from whisper_timestamped import transcribe_timestamped
from multiprocessing import Process
import config
load_dotenv()
model_size = "small"
# tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large-v3, large, distil-large-v2, distil-medium.en, distil-small.en
HARPYIA_PROMPT = os.getenv("HARPYIA_PROMPT") or "спасите помогите на помощь пожар"
HARPYIA_MODEL = os.getenv("HARPYIA_MODEL") or "medium"
HARPYIA_LANGUAGE = os.getenv("HARPYIA_LANGUAGE") or "ru"
from process_stack import ProcessQueue
from process_stack.strategies import RecognizeAndSendStrategy
# Check if NVIDIA GPU is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu"
from recognizer import Recognizer
from recognizer.strategies import WhisperStrategy, FastWhisperStrategy
from message import MessageService
from message.strategies import SosMessageStrategy, NumberMessageStrategy
# Load the Whisper model:
model = WhisperModel(
    model_size,
    device=DEVICE,
    num_workers=6,
    cpu_threads=10,
    # in_memory=True,
)
app = Flask(__name__)
whisper_recognizer = Recognizer(WhisperStrategy())
fast_whisper_recognizer = Recognizer(FastWhisperStrategy())
sos_message_service = MessageService(SosMessageStrategy())
number_message_service = MessageService(NumberMessageStrategy())
process_stack = ProcessQueue(RecognizeAndSendStrategy())
process_stack.start_loop_in_thread()
@app.route("/")
def hello():
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize-number' route."
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize_number' route."
def recognize_files(message_service: MessageService):
if not request.files:
abort(400)
results = []
def recognize_files(handler_fn):
if not request.files:
abort(400)
for filename, handle in request.files.items():
temp = NamedTemporaryFile()
handle.save(temp)
results = []
asyncio.get_running_loop()
for filename, handle in request.files.items():
temp = NamedTemporaryFile()
handle.save(temp)
results.append(process_stack.append_and_await((
temp,
whisper_recognizer,
message_service,
config.HARPYIA_LANGUAGE,
message_service.get_prompt()
)))
audio = prepare_file(temp.name)
res = trans(audio)
print(results, file=sys.stderr)
return {'results': results}
results.append(
{
"filename": filename,
"transcript": res,
}
)
@app.route('/recognize', methods=['POST'])
def recognize():
return recognize_files(sos_message_service)
print(results, file=sys.stderr)
return {"results": results}
@app.route('/recognize-number', methods=['POST'])
initprompt = [
    "один",
    "два",
    "три",
    "четыре",
    "пять",
    "шесть",
    "семь",
    "восемь",
    "девять",
    "десять",
    "одинадцать",
    "двенадцать",
    "тренадцать",
    "сто",
    "сот",
]
@elapsed_time
def trans(audio):
    segments, _ = model.transcribe(
        audio,
        language=HARPYIA_LANGUAGE,
        initial_prompt="семь сот сто",
        condition_on_previous_text=False,
        vad_filter=True,
        beam_size=5,
    )
    words = []
    for e in list(segments):
        words.append(e.text)
    return " ".join(words)
@elapsed_time
def prepare_file(filename: str):
    audio = whisper.load_audio(filename, sr=16000)
    audio = whisper.pad_or_trim(audio)
    return audio
@app.route("/recognize", methods=["POST"])
async def recognize():
    return recognize_files(lambda text: text)
@app.route("/recognize_number", methods=["POST"])
def recognize_number():
    return recognize_files(number_message_service)
    return recognize_files(transfer_and_clean)
def transfer_and_clean(input_string):
    number_mapping = {"один": "1", "два": "2", "три": "3"}
    for word, number in number_mapping.items():
        input_string = input_string.replace(word, number)
    input_string = re.sub(r"[^\d]+", "", input_string)
    return input_string
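For orientation, a minimal client call against these routes might look like the following sketch (the host, port, field name, and file name are assumptions, not part of the commit; the service itself only defines the /recognize and /recognize_number POST routes):

# Illustrative client sketch, not part of the repository.
# Assumes the Flask app listens on localhost:5000 and that alarm.wav exists locally.
import requests

with open("alarm.wav", "rb") as audio_file:
    response = requests.post(
        "http://localhost:5000/recognize",
        files={"audio": audio_file},  # any field name works; the server iterates request.files
    )
print(response.json())  # e.g. {"results": [...]}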

View File

@@ -1,21 +0,0 @@
import os
import torch
from dotenv import load_dotenv
load_dotenv()
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
HARPYIA_SAMPLE_RATE = os.getenv('HARPYIA_SAMPLE_RATE') or 160000
WHISPER_NUM_WORKERS = os.getenv('WHISPER_NUM_WORKERS') or 6
WHISPER_CPU_THREADS = os.getenv('WHISPER_CPU_THREADS') or 10
WHISPER_BEAM_SIZE = os.getenv('WHISPER_BEAM_SIZE') or 5
SOS_PROMPT = os.getenv('SOS_PROMPT') or 'спасите помогите помощь пожар караул гит'
NUMBER_PROMPT = os.getenv('NUMBER_PROMPT') or 'один два три четыре пять шесть семь восемь девять десять одинадцать двенадцать тринадцать сто сот'
RAT_URL = os.getenv('RAT_URL') or 'localhost:8081'
# Check if NVIDIA GPU is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

View File

@@ -1 +0,0 @@
from message.message_service import MessageService

View File

@@ -1,24 +0,0 @@
import sys
from message.strategies import BaseMessageStrategy
class MessageService:
    def __init__(self, strategy: BaseMessageStrategy) -> None:
        self._strategy = strategy
    def get_prompt(self) -> str:
        self._strategy.get_prompt()
    def transfer(self, text: str) -> any:
        return self._strategy.transfer(text)
    def send(self, message: str) -> any:
        self._strategy.send(message)
    def transfer_and_send(self, recognized_result: any) -> any:
        message = self.transfer(recognized_result)
        if message:
            self.send(message)
            print('Sending message:', recognized_result, file=sys.stderr)
        return message
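As a reading aid, a rough usage sketch of this service with one of its strategies (the recognized dict is an illustrative stand-in for real recognizer output):

# Illustrative sketch, not part of the repository.
from message import MessageService
from message.strategies import SosMessageStrategy

service = MessageService(SosMessageStrategy())
recognized = {"text": "помогите пожар", "segments": []}  # hypothetical recognizer result
message = service.transfer_and_send(recognized)           # filters prompt words, then sends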

View File

@@ -1,3 +0,0 @@
from message.strategies.base_message_strategy import BaseMessageStrategy
from message.strategies.sos_message_strategy import SosMessageStrategy
from message.strategies.number_message_strategy import NumberMessageStrategy

View File

@@ -1,14 +0,0 @@
from abc import ABC, abstractmethod
class BaseMessageStrategy(ABC):
    @abstractmethod
    def get_prompt() -> str:
        pass
    @abstractmethod
    def transfer(self, text: str) -> any:
        pass
    @abstractmethod
    def send(self, message: str) -> any:
        pass

View File

@@ -1,31 +0,0 @@
import re
import config
from message.strategies import BaseMessageStrategy
class NumberMessageStrategy(BaseMessageStrategy):
    def __init__(self, prompt=config.NUMBER_PROMPT) -> None:
        self._prompt = prompt
    def get_prompt(self):
        return self._prompt
    def transfer(self, recognized_result: any) -> str:
        return self._transfer_and_clean(recognized_result['text'])
    def _transfer_and_clean(self, text: str) -> str:
        number_mapping = {
            "один": "1",
            "два": "2",
            "три": "3"
        }
        for word, number in number_mapping.items():
            transfered_text = text.replace(word, number)
        transfered_text = re.sub(r'[^\d]+', '', transfered_text)
        return {'recognized': transfered_text}
    def send(self, message: str) -> None:
        pass

View File

@@ -1,35 +0,0 @@
from typing import List
import requests
import config
from message.strategies import BaseMessageStrategy
MESSAGE_ENDPOINT = '/message'
class SosMessageStrategy(BaseMessageStrategy):
    def __init__(self, prompt=config.SOS_PROMPT, url=config.RAT_URL) -> None:
        self._prompt = prompt
        self._url = url
    def get_prompt(self):
        return self._prompt
    def transfer(self, recognized_result: any) -> str:
        return {
            'transcript': recognized_result['text'],
            'results': self._filter_words_with_prompt(recognized_result['text']),
            'segments': recognized_result['segments']
        }
    def _filter_words_with_prompt(self, text: str) -> str:
        words = []
        for prompt in self._prompt.split(' '):
            if prompt in text.lower():
                words.append(prompt)
        return words
    def send(self, message) -> any:
        pass
        #return requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})

View File

@@ -1 +0,0 @@
from process_stack.process_stack import ProcessQueue

View File

@@ -1,51 +0,0 @@
import sys
from threading import Thread, Event, Lock
from process_stack.strategies import BaseProcessStrategy
class ProcessQueue:
    def __init__(self, strategy: BaseProcessStrategy) -> None:
        self._stack = []
        self._strategy = strategy
        self._lock = Lock()
        self._running = False
        self._last_response = None
    def append(self, args, event=None) -> None:
        with self._lock:
            self._stack.append((args, event))
    def append_and_await(self, args) -> any:
        event = Event()
        self.append(args, event=event)
        event.wait()
        event.clear()
        return self._last_response
    def loop(self) -> None:
        self._running = True
        while self._running:
            with self._lock:
                if self._stack:
                    print('Stack length:', len(self._stack), file=sys.stderr)
                    (args, event) = self._stack.pop(0)
                    self._last_response = self._process(*args)
                    if event:
                        event.set()
    def _process(self, *args, **kwargs) -> any:
        return self._strategy.process(*args, **kwargs)
    def start_loop_in_thread(self) -> None:
        thread = Thread(target=self.loop)
        thread.start()
    def stop_loop(self) -> None:
        self._running = False
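For reference, this queue serializes work from concurrent Flask requests onto a single worker thread; a minimal self-contained sketch of that pattern (EchoStrategy is hypothetical, used only to keep the example runnable):

# Illustrative sketch, not part of the repository.
from process_stack import ProcessQueue
from process_stack.strategies import BaseProcessStrategy

class EchoStrategy(BaseProcessStrategy):
    def process(self, *args, **kwargs) -> any:
        return args  # stand-in for the real recognize-and-send work

queue = ProcessQueue(EchoStrategy())
queue.start_loop_in_thread()                 # worker thread drains the queue
result = queue.append_and_await(("job-1",))  # blocks until the item has been processed
queue.stop_loop()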

View File

@@ -1,2 +0,0 @@
from process_stack.strategies.base_process_strategy import BaseProcessStrategy
from process_stack.strategies.recognize_and_send_strategy import RecognizeAndSendStrategy

View File

@@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class BaseProcessStrategy(ABC):
    @abstractmethod
    def process(self, *args, **kwargs) -> any:
        pass

View File

@@ -1,14 +0,0 @@
import sys
from process_stack.strategies import BaseProcessStrategy
from message import MessageService
from recognizer import Recognizer
class RecognizeAndSendStrategy(BaseProcessStrategy):
    def process(self, file, recognizer: Recognizer, message_service: MessageService, language, prompt) -> any:
        result = recognizer.recognize(file, language=language, prompt=prompt)
        message = message_service.transfer_and_send(result)
        print(message, file=sys.stderr)
        return message

View File

@@ -1 +0,0 @@
from recognizer.recognizer import Recognizer

View File

@@ -1,14 +0,0 @@
import sys
import config
from recognizer.strategies import BaseRecognizerStrategy
class Recognizer:
    def __init__(self, strategy: BaseRecognizerStrategy) -> None:
        self._strategy = strategy
    def recognize(self, file, language, prompt) -> str:
        result = self._strategy.recognize(file, language=language, prompt=prompt)
        print(f'Result: {result}', file=sys.stderr)
        return result

View File

@@ -1,3 +0,0 @@
from recognizer.strategies.base_recognizer_strategy import BaseRecognizerStrategy
from recognizer.strategies.whisper_strategy import WhisperStrategy
from recognizer.strategies.fast_whisper_strategy import FastWhisperStrategy

View File

@@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class BaseRecognizerStrategy(ABC):
    @abstractmethod
    def recognize(self, file, language, prompt) -> any:
        pass

View File

@@ -1,59 +0,0 @@
import sys
import whisper
from faster_whisper import WhisperModel
import config
from recognizer.strategies import BaseRecognizerStrategy
class FastWhisperStrategy(BaseRecognizerStrategy):
    def __init__(self) -> None:
        self._model = WhisperModel(
            model_size_or_path=config.HARPYIA_MODEL,
            device=config.DEVICE,
            num_workers=config.WHISPER_NUM_WORKERS,
            cpu_threads=config.WHISPER_CPU_THREADS
        )
    def recognize(self, file, language, prompt) -> any:
        audio = self._prepare_file(file.name)
        return self._transcribe(audio, language, prompt)
    def _prepare_file(self, filename: str):
        audio = whisper.load_audio(filename, sr=config.HARPYIA_SAMPLE_RATE)
        audio = whisper.pad_or_trim(audio)
        return audio
    def _transcribe(self, audio, language, prompt):
        segments, _ = self._model.transcribe(
            audio,
            language=language,
            initial_prompt=prompt,
            condition_on_previous_text=False,
            vad_filter=True,
            beam_size=config.WHISPER_BEAM_SIZE,
        )
        print('Segments:', file=sys.stderr)
        for i in segments:
            print(i, file=sys.stderr)
        words = []
        for segment in list(segments):
            words.append(segment.text)
        return {
            'text': ' '.join(words),
            'segments': {
                'id': None,
                'seek': None,
                'start': None,
                'end': None,
                'text': None,
                'tokens': None,
                'temperature': None,
                'avg_logprob': None,
                'compression_ratio': None,
                'no_speech_prob': None,
            }
        }
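Worth noting when reading _transcribe: faster-whisper returns segments as a lazy generator, so it is exhausted after a single pass. A minimal stand-alone sketch of the same transcription call, iterating the generator once (model size, file path, and prompt are assumptions):

# Illustrative sketch of the faster-whisper call, not part of the repository.
from faster_whisper import WhisperModel

model = WhisperModel("small", device="cpu")  # assumed model size
segments, info = model.transcribe(
    "sample.wav",                            # hypothetical input file
    language="ru",
    initial_prompt="спасите помогите на помощь пожар",
    condition_on_previous_text=False,
    vad_filter=True,
    beam_size=5,
)
text = " ".join(segment.text for segment in segments)  # single pass over the generator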

View File

@@ -1,12 +0,0 @@
import whisper
import config
from recognizer.strategies import BaseRecognizerStrategy
class WhisperStrategy(BaseRecognizerStrategy):
    def __init__(self) -> None:
        self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
    def recognize(self, file, language, prompt) -> any:
        return self._model.transcribe(file.name, \
            language=language, initial_prompt=prompt)

39
src/test_utils.py Normal file
View File

@@ -0,0 +1,39 @@
import time
import sys
def elapsed_time_wrapper(unique_id: str = ""):
    def decorator(func):
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            execution_time = end_time - start_time
            if not unique_id == "":
                print(
                    f"[{unique_id}] Executed in {execution_time} seconds",
                    file=sys.stderr,
                )
            else:
                print(f"Executed in {execution_time} seconds", file=sys.stderr)
            return result
        return wrapper
    return decorator
def elapsed_time(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        execution_time = end_time - start_time
        print(
            f"Executed in {execution_time} seconds",
            sep="\n",
        )
        return result
    return wrapper
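A short usage sketch of these timing decorators (the decorated functions are hypothetical):

# Illustrative usage, not part of the repository.
from test_utils import elapsed_time, elapsed_time_wrapper

@elapsed_time
def transcribe_once():
    return "ok"  # stand-in for real work; prints "Executed in ... seconds"

@elapsed_time_wrapper("warmup")
def warm_up_model():
    return "ok"  # prints "[warmup] Executed in ... seconds" to stderr

transcribe_once()
warm_up_model()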