Implemented strategies for SOS and number recognition
Deleted the message sender and prompt service. Implemented fast whisper, but it is not working yet. WavStack was refactored into QueueStack, which can use different strategies for processing.
This commit is contained in:
parent
dbbf845e56
commit
37adf74745
68
src/app.py
68
src/app.py
@ -1,31 +1,35 @@
|
||||
from flask import Flask, abort, request
|
||||
from tempfile import NamedTemporaryFile
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import whisper
|
||||
import torch
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
||||
load_dotenv()
|
||||
import config
|
||||
|
||||
HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар'
|
||||
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
|
||||
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
|
||||
from queue_stack import QueueStack
|
||||
from queue_stack.strategies import RecognizeAndSendStrategy
|
||||
|
||||
# Check if NVIDIA GPU is available
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
from recognizer import Recognizer
|
||||
from recognizer.strategies import WhisperStrategy, FastWhisperStrategy
|
||||
|
||||
# Load the Whisper model:
|
||||
model = whisper.load_model(HARPYIA_MODEL, device=DEVICE)
|
||||
from message import MessageService
|
||||
from message.strategies import SosMessageStrategy, NumberMessageStrategy
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
whisper_recognizer = Recognizer(WhisperStrategy())
|
||||
fast_whisper_recognizer = Recognizer(FastWhisperStrategy())
|
||||
|
||||
sos_message_service = MessageService(SosMessageStrategy())
|
||||
number_message_service = MessageService(NumberMessageStrategy())
|
||||
|
||||
queue_stack = QueueStack(RecognizeAndSendStrategy())
|
||||
queue_stack.start_loop_in_thread()
|
||||
|
||||
@app.route("/")
|
||||
def hello():
|
||||
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize_number' route."
|
||||
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize-number' route."
|
||||
|
||||
def recognize_files(handler_fn):
|
||||
def recognize_files(message_service: MessageService):
|
||||
if not request.files:
|
||||
abort(400)
|
||||
|
||||
@ -34,34 +38,22 @@ def recognize_files(handler_fn):
|
||||
for filename, handle in request.files.items():
|
||||
temp = NamedTemporaryFile()
|
||||
handle.save(temp)
|
||||
result = model.transcribe(temp.name, language=HARPYIA_LANGUAGE, initial_prompt=HARPYIA_PROMPT)
|
||||
results.append({
|
||||
'filename': filename,
|
||||
'transcript': handler_fn(result['text']),
|
||||
})
|
||||
|
||||
results.append(queue_stack.append_and_await((
|
||||
temp,
|
||||
whisper_recognizer,
|
||||
message_service,
|
||||
config.HARPYIA_LANGUAGE,
|
||||
message_service.get_prompt()
|
||||
)))
|
||||
|
||||
print(results, file=sys.stderr)
|
||||
return {'results': results}
|
||||
|
||||
@app.route('/recognize', methods=['POST'])
|
||||
def recognize():
|
||||
return recognize_files(lambda text: text)
|
||||
return recognize_files(sos_message_service)
|
||||
|
||||
@app.route('/recognize_number', methods=['POST'])
|
||||
@app.route('/recognize-number', methods=['POST'])
|
||||
def recognize_number():
|
||||
return recognize_files(transfer_and_clean)
|
||||
|
||||
def transfer_and_clean(input_string):
|
||||
number_mapping = {
|
||||
"один": "1",
|
||||
"два": "2",
|
||||
"три": "3"
|
||||
}
|
||||
|
||||
for word, number in number_mapping.items():
|
||||
input_string = input_string.replace(word, number)
|
||||
|
||||
input_string = re.sub(r'[^\d]+', '', input_string)
|
||||
|
||||
return input_string
|
||||
|
||||
return recognize_files(number_message_service)
|
@ -4,9 +4,18 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар'
|
||||
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
|
||||
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'small'
|
||||
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
|
||||
# Sample rate (Hz) used when loading audio for recognition. The previous
# implementation hard-coded 16000; the earlier default of 160000 was ten
# times that value. Environment values are strings, hence the int().
HARPYIA_SAMPLE_RATE = int(os.getenv('HARPYIA_SAMPLE_RATE') or 16000)

# faster-whisper tuning knobs. os.getenv returns a str whenever the variable
# is set, so each value is converted to int to match its numeric default.
WHISPER_NUM_WORKERS = int(os.getenv('WHISPER_NUM_WORKERS') or 6)
WHISPER_CPU_THREADS = int(os.getenv('WHISPER_CPU_THREADS') or 10)
WHISPER_BEAM_SIZE = int(os.getenv('WHISPER_BEAM_SIZE') or 5)
|
||||
|
||||
SOS_PROMPT = os.getenv('SOS_PROMPT') or 'спасите помогите помощь пожар караул кирилл'
|
||||
NUMBER_PROMPT = os.getenv('NUMBER_PROMPT') or 'один два три четыре пять шесть семь восемь девять десять одинадцать двенадцать тринадцать сто сот'
|
||||
|
||||
RAT_URL = os.getenv('RAT_URL') or 'localhost:8081'
|
||||
|
||||
# Check if NVIDIA GPU is available
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
@ -1,32 +0,0 @@
|
||||
from threading import Thread
|
||||
|
||||
from recognizer import Recognizer
|
||||
from message import MessageComposer
|
||||
|
||||
class WavStack:
|
||||
def __init__(self, recognizer: Recognizer, message_composer: MessageComposer):
|
||||
self._stack = []
|
||||
self._recognizer = recognizer
|
||||
self._message_composer = message_composer
|
||||
self._running = False
|
||||
|
||||
def append(self, file):
|
||||
self._stack.append(file)
|
||||
|
||||
def loop(self):
|
||||
self._running = True
|
||||
while self._running:
|
||||
if self._stack:
|
||||
file = self._stack.pop(0)
|
||||
recognized_text = self._recognizer.recognize(file)
|
||||
message = self._message_composer.compose(recognized_text)
|
||||
|
||||
if message.has_prompt():
|
||||
message.send()
|
||||
|
||||
def start_loop_in_thread(self):
|
||||
thread = Thread(target=self.loop)
|
||||
thread.start()
|
||||
|
||||
def stop_loop(self):
|
||||
self._running = False
|
@ -1,6 +1 @@
|
||||
from message.prompt_service import PromptService
|
||||
from message.message_sender.message_sender import MessageSender
|
||||
from message.message_composer import MessageComposer
|
||||
from message.message import Message
|
||||
|
||||
import message.message_sender as message_sender
|
||||
from message.message_service import MessageService
|
@ -1,17 +0,0 @@
|
||||
from message import PromptService, MessageSender
|
||||
|
||||
class Message:
|
||||
def __init__(self, prompt_service: PromptService, message_sender: MessageSender, \
|
||||
recognized_text: str):
|
||||
self._prompt_service = prompt_service
|
||||
self._message_sender = message_sender
|
||||
self._recognized_text = recognized_text
|
||||
|
||||
def has_prompt(self) -> bool:
|
||||
return self._prompt_service.has_prompt(self._recognized_text)
|
||||
|
||||
def send(self) -> None:
|
||||
self._message_sender.send(self._generate_response())
|
||||
|
||||
def _generate_response(self) -> str:
|
||||
return self._prompt_service.filter_words_with_prompt(self._recognized_text)
|
@ -1,9 +0,0 @@
|
||||
from message import Message, PromptService, MessageSender
|
||||
|
||||
class MessageComposer:
|
||||
def __init__(self, prompt_service: PromptService, message_sender: MessageSender):
|
||||
self._prompt_service = prompt_service
|
||||
self._message_sender = message_sender
|
||||
|
||||
def compose(self, recognized_text) -> Message:
|
||||
return Message(self._prompt_service, self._message_sender, recognized_text)
|
@ -1,3 +0,0 @@
|
||||
from message.message_sender.message_sender import MessageSender
|
||||
from message.message_sender.message_sender_strategy import MessageSenderStrategy
|
||||
from message.message_sender.rat_strategy import RatStrategy
|
@ -1,8 +0,0 @@
|
||||
from message.message_sender import MessageSenderStrategy
|
||||
|
||||
class MessageSender:
|
||||
def __init__(self, strategy: MessageSenderStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
def send(self, message) -> None:
|
||||
self._strategy.send(message)
|
@ -1,6 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class MessageSenderStrategy(ABC):
|
||||
@abstractmethod
|
||||
def send(self, message) -> None:
|
||||
pass
|
@ -1,12 +0,0 @@
|
||||
import requests
|
||||
|
||||
from message.message_sender import MessageSenderStrategy
|
||||
|
||||
MESSAGE_ENDPOINT = '/message'
|
||||
|
||||
class RatStrategy(MessageSenderStrategy):
|
||||
def __init__(self, url):
|
||||
self._url = url
|
||||
|
||||
def send(self, message):
|
||||
requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})
|
24
src/message/message_service.py
Normal file
24
src/message/message_service.py
Normal file
@ -0,0 +1,24 @@
|
||||
import sys
|
||||
from message.strategies import BaseMessageStrategy
|
||||
|
||||
class MessageService:
    """Facade that delegates prompt lookup, transformation and delivery to a strategy."""

    def __init__(self, strategy: "BaseMessageStrategy") -> None:
        """Store the strategy that every other method delegates to."""
        self._strategy = strategy

    def get_prompt(self) -> str:
        """Return the strategy's prompt.

        The value must be returned (not just computed): callers pass it on
        to the recognizer as the prompt argument.
        """
        return self._strategy.get_prompt()

    def transfer(self, text: str) -> any:
        """Convert a recognition result into a message via the strategy."""
        return self._strategy.transfer(text)

    def send(self, message: str) -> any:
        """Hand ``message`` to the strategy for delivery and return its result."""
        return self._strategy.send(message)

    def transfer_and_send(self, recognized_result: any) -> any:
        """Transform ``recognized_result``, send the message if truthy, and return it."""
        message = self.transfer(recognized_result)

        # Falsy messages (empty string, empty collection, None) are not sent.
        if message:
            self.send(message)

        print('Sending message:', recognized_result, file=sys.stderr)
        return message
|
@ -1,16 +0,0 @@
|
||||
class PromptService:
|
||||
def __init__(self, prompt):
|
||||
self._prompt = prompt
|
||||
|
||||
def has_prompt(self, text: str) -> bool:
|
||||
for part in text.split(' '):
|
||||
return part in self._prompt.split(' ')
|
||||
|
||||
def filter_words_with_prompt(self, text: str) -> str:
|
||||
words = []
|
||||
|
||||
for part in text.split(' '):
|
||||
if part in self._prompt.split(' '):
|
||||
words.append(part)
|
||||
|
||||
return words
|
3
src/message/strategies/__init__.py
Normal file
3
src/message/strategies/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from message.strategies.base_message_strategy import BaseMessageStrategy
|
||||
from message.strategies.sos_message_strategy import SosMessageStrategy
|
||||
from message.strategies.number_message_strategy import NumberMessageStrategy
|
14
src/message/strategies/base_message_strategy.py
Normal file
14
src/message/strategies/base_message_strategy.py
Normal file
@ -0,0 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseMessageStrategy(ABC):
    """Interface for turning a recognition result into a message and sending it."""

    @abstractmethod
    def get_prompt(self) -> str:
        """Return the prompt associated with this strategy."""
        pass

    @abstractmethod
    def transfer(self, text: str) -> any:
        """Convert a recognition result into the message to report."""
        pass

    @abstractmethod
    def send(self, message: str) -> any:
        """Deliver ``message``."""
        pass
|
31
src/message/strategies/number_message_strategy.py
Normal file
31
src/message/strategies/number_message_strategy.py
Normal file
@ -0,0 +1,31 @@
|
||||
import re
|
||||
|
||||
import config
|
||||
from message.strategies import BaseMessageStrategy
|
||||
|
||||
class NumberMessageStrategy(BaseMessageStrategy):
    """Message strategy that reduces recognized speech to a string of digits."""

    # Spoken Russian numerals and the digit each one is replaced with.
    _NUMBER_MAPPING = {
        "один": "1",
        "два": "2",
        "три": "3",
    }

    def __init__(self, prompt=config.NUMBER_PROMPT) -> None:
        """Remember the prompt; defaults to ``config.NUMBER_PROMPT``."""
        self._prompt = prompt

    def get_prompt(self):
        """Return the prompt given at construction."""
        return self._prompt

    def transfer(self, recognized_result: any) -> dict:
        """Extract digits from ``recognized_result['text']``.

        Returns a dict of the form ``{'recognized': '<digits>'}``.
        """
        return self._transfer_and_clean(recognized_result['text'])

    def _transfer_and_clean(self, text: str) -> dict:
        """Replace known number words with digits, then drop every non-digit."""
        transfered_text = text

        # Accumulate the replacements: each one must start from the previous
        # result, otherwise only the last mapping entry takes effect.
        for word, number in self._NUMBER_MAPPING.items():
            transfered_text = transfered_text.replace(word, number)

        transfered_text = re.sub(r'[^\d]+', '', transfered_text)

        return {'recognized': transfered_text}

    def send(self, message: str) -> None:
        """Do nothing: this strategy has no delivery step."""
        pass
|
35
src/message/strategies/sos_message_strategy.py
Normal file
35
src/message/strategies/sos_message_strategy.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import List
|
||||
import requests
|
||||
|
||||
import config
|
||||
from message.strategies import BaseMessageStrategy
|
||||
|
||||
MESSAGE_ENDPOINT = '/message'
|
||||
|
||||
class SosMessageStrategy(BaseMessageStrategy):
    """Message strategy that reports which prompt keywords occur in the transcript."""

    def __init__(self, prompt=config.SOS_PROMPT, url=config.RAT_URL) -> None:
        """Remember the keyword prompt and the target URL.

        Both default to the values in ``config``.
        """
        self._prompt = prompt
        self._url = url

    def get_prompt(self):
        """Return the whitespace-separated keyword prompt."""
        return self._prompt

    def transfer(self, recognized_result: any) -> dict:
        """Build the response for a recognition result.

        Reads the ``'text'`` and ``'segments'`` keys of ``recognized_result``
        and returns a dict with the transcript, the matched keywords and the
        segments passed through unchanged.
        """
        return {
            'transcript': recognized_result['text'],
            'results': self._filter_words_with_prompt(recognized_result['text']),
            'segments': recognized_result['segments']
        }

    def _filter_words_with_prompt(self, text: str) -> List[str]:
        """Return the prompt keywords found (as substrings) in lower-cased ``text``."""
        lowered = text.lower()  # hoisted: identical for every keyword

        # split() without an argument never yields empty strings. An empty
        # keyword would be "contained" in every text and reported as a match.
        return [keyword for keyword in self._prompt.split() if keyword in lowered]

    def send(self, message) -> any:
        """Do nothing: delivery is currently disabled (see the call below)."""
        pass
        #return requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})
|
1
src/queue_stack/__init__.py
Normal file
1
src/queue_stack/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from queue_stack.queue_stack import QueueStack
|
51
src/queue_stack/queue_stack.py
Normal file
51
src/queue_stack/queue_stack.py
Normal file
@ -0,0 +1,51 @@
|
||||
import sys
|
||||
|
||||
from threading import Thread, Event, Lock
|
||||
|
||||
from queue_stack.strategies import BaseProcessStrategy
|
||||
|
||||
class QueueStack:
    """FIFO work queue drained by a single background worker.

    Each queued item is a tuple of positional arguments that is unpacked into
    the strategy's ``process`` method. Items are processed one at a time, in
    the order they were appended.
    """

    def __init__(self, strategy: "BaseProcessStrategy") -> None:
        """Create an empty queue that processes items with ``strategy``."""
        self._stack = []
        self._strategy = strategy

        self._lock = Lock()
        self._running = False
        # Set while there may be queued work (or a pending stop request) so
        # the worker can sleep instead of spinning on an empty queue.
        self._wakeup = Event()

        # Result of the most recently processed item. Kept for compatibility;
        # awaiting callers get their own result, not this shared slot.
        self._last_response = None

    def append(self, args, event=None) -> None:
        """Queue ``args`` for processing.

        If ``event`` is given it is set once the item has been processed.
        """
        self._enqueue(args, event, None)

    def append_and_await(self, args) -> any:
        """Queue ``args``, block until processed, and return that item's result.

        If the strategy raised while processing this item, the same exception
        is re-raised here instead of leaving the caller blocked.
        """
        event = Event()
        outcome = {}
        self._enqueue(args, event, outcome)

        event.wait()

        if 'error' in outcome:
            raise outcome['error']
        return outcome['response']

    def _enqueue(self, args, event, outcome) -> None:
        """Store one item and wake the worker."""
        with self._lock:
            self._stack.append((args, event, outcome))
            self._wakeup.set()

    def loop(self) -> None:
        """Process queued items until ``stop_loop`` is called."""
        self._running = True

        while self._running:
            self._wakeup.wait()

            with self._lock:
                if not self._stack:
                    # Cleared under the lock so a concurrent append cannot
                    # have its wake-up lost.
                    self._wakeup.clear()
                    continue

                print('Stack length:', len(self._stack), file=sys.stderr)
                args, event, outcome = self._stack.pop(0)

            # Run the strategy outside the lock so producers are never
            # blocked for the duration of a (potentially long) processing step.
            try:
                response = self._process(*args)
            except Exception as error:
                # Worker boundary: keep the loop alive and hand the failure
                # to the awaiting caller, if there is one.
                print('Processing failed:', repr(error), file=sys.stderr)
                if outcome is not None:
                    outcome['error'] = error
            else:
                self._last_response = response
                if outcome is not None:
                    outcome['response'] = response
            finally:
                if event is not None:
                    event.set()

    def _process(self, *args, **kwargs) -> any:
        """Delegate one item to the configured strategy."""
        return self._strategy.process(*args, **kwargs)

    def start_loop_in_thread(self) -> None:
        """Run ``loop`` in a background thread.

        The thread is a daemon so a never-stopped worker does not keep the
        interpreter alive at shutdown.
        """
        thread = Thread(target=self.loop, daemon=True)
        thread.start()

    def stop_loop(self) -> None:
        """Ask the worker to exit after the item it is currently handling."""
        self._running = False
        # Wake the worker if it is idle so it can observe the stop request.
        self._wakeup.set()
|
2
src/queue_stack/strategies/__init__.py
Normal file
2
src/queue_stack/strategies/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from queue_stack.strategies.base_process_strategy import BaseProcessStrategy
|
||||
from queue_stack.strategies.recognize_and_send_strategy import RecognizeAndSendStrategy
|
6
src/queue_stack/strategies/base_process_strategy.py
Normal file
6
src/queue_stack/strategies/base_process_strategy.py
Normal file
@ -0,0 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseProcessStrategy(ABC):
    """Interface for the unit of work applied to each queued item."""

    @abstractmethod
    def process(self, *args, **kwargs) -> any:
        """Handle one item; implementations define the arguments and the result."""
|
14
src/queue_stack/strategies/recognize_and_send_strategy.py
Normal file
14
src/queue_stack/strategies/recognize_and_send_strategy.py
Normal file
@ -0,0 +1,14 @@
|
||||
import sys
|
||||
|
||||
from queue_stack.strategies import BaseProcessStrategy
|
||||
from message import MessageService
|
||||
from recognizer import Recognizer
|
||||
|
||||
class RecognizeAndSendStrategy(BaseProcessStrategy):
    """Process strategy that recognizes a file and forwards the resulting message."""

    def process(self, file, recognizer: Recognizer, message_service: MessageService, language, prompt) -> any:
        """Recognize ``file``, pass the result to ``message_service``, return the message.

        ``language`` and ``prompt`` are forwarded to ``recognizer.recognize``
        as keyword arguments. The message is also printed to stderr.
        """
        recognition = recognizer.recognize(file, language=language, prompt=prompt)
        outgoing = message_service.transfer_and_send(recognition)

        print(outgoing, file=sys.stderr)
        return outgoing
|
@ -1,4 +1 @@
|
||||
from recognizer.recognizer import Recognizer
|
||||
from recognizer.recognizer_strategy import RecognizerStrategy
|
||||
from recognizer.whisper_strategy import WhisperStrategy
|
||||
from recognizer.fast_whisper_strategy import FastWhisperStrategy
|
||||
|
@ -1,37 +0,0 @@
|
||||
import whisper
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import config
|
||||
from recognizer import RecognizerStrategy
|
||||
|
||||
class FastWhisperStrategy(RecognizerStrategy):
|
||||
def __init__(self) -> None:
|
||||
self._model = WhisperModel(
|
||||
model_size=config.HARPYIA_MODEL,
|
||||
device=config.DEVICE,
|
||||
num_workers=6,
|
||||
cpu_threads=10,
|
||||
# in_memory=True,
|
||||
)
|
||||
|
||||
def recognize(self, file) -> str:
|
||||
audio = self._prepare_file(file.name)
|
||||
return self._transcribe(audio)
|
||||
|
||||
def _prepare_file(self, filename: str):
|
||||
audio = whisper.load_audio(filename, sr=16000)
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
return audio
|
||||
|
||||
|
||||
def _transcribe(self, audio):
|
||||
segments, _ = self._model.transcribe(
|
||||
audio,
|
||||
language=config.HARPYIA_LANGUAGE,
|
||||
initial_prompt=config.HARPYIA_PROMPT,
|
||||
condition_on_previous_text=False,
|
||||
vad_filter=True,
|
||||
beam_size=5,
|
||||
)
|
||||
|
||||
return segments
|
@ -1,8 +1,14 @@
|
||||
from recognizer import RecognizerStrategy
|
||||
import sys
|
||||
|
||||
import config
|
||||
from recognizer.strategies import BaseRecognizerStrategy
|
||||
|
||||
class Recognizer:
|
||||
def __init__(self, strategy: RecognizerStrategy):
|
||||
def __init__(self, strategy: BaseRecognizerStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
def recognize(self, file) -> str:
|
||||
self._strategy.recognize(file)
|
||||
def recognize(self, file, language, prompt) -> str:
|
||||
result = self._strategy.recognize(file, language=language, prompt=prompt)
|
||||
|
||||
print(f'Result: {result}', file=sys.stderr)
|
||||
return result
|
@ -1,6 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class RecognizerStrategy(ABC):
|
||||
@abstractmethod
|
||||
def recognize(self, file) -> str:
|
||||
pass
|
3
src/recognizer/strategies/__init__.py
Normal file
3
src/recognizer/strategies/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from recognizer.strategies.base_recognizer_strategy import BaseRecognizerStrategy
|
||||
from recognizer.strategies.whisper_strategy import WhisperStrategy
|
||||
from recognizer.strategies.fast_whisper_strategy import FastWhisperStrategy
|
6
src/recognizer/strategies/base_recognizer_strategy.py
Normal file
6
src/recognizer/strategies/base_recognizer_strategy.py
Normal file
@ -0,0 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseRecognizerStrategy(ABC):
    """Interface for speech-recognition backends."""

    @abstractmethod
    def recognize(self, file, language, prompt) -> any:
        """Recognize ``file`` using the given ``language`` and ``prompt``."""
|
59
src/recognizer/strategies/fast_whisper_strategy.py
Normal file
59
src/recognizer/strategies/fast_whisper_strategy.py
Normal file
@ -0,0 +1,59 @@
|
||||
import sys
|
||||
|
||||
import whisper
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import config
|
||||
from recognizer.strategies import BaseRecognizerStrategy
|
||||
|
||||
class FastWhisperStrategy(BaseRecognizerStrategy):
    """Recognizer strategy backed by faster-whisper's ``WhisperModel``."""

    # Segment attributes copied into each entry of the returned 'segments' list.
    _SEGMENT_FIELDS = (
        'id',
        'seek',
        'start',
        'end',
        'text',
        'tokens',
        'temperature',
        'avg_logprob',
        'compression_ratio',
        'no_speech_prob',
    )

    def __init__(self) -> None:
        """Load the model named by ``config.HARPYIA_MODEL`` on ``config.DEVICE``."""
        self._model = WhisperModel(
            model_size_or_path=config.HARPYIA_MODEL,
            device=config.DEVICE,
            # Config values come from the environment and may be strings.
            num_workers=int(config.WHISPER_NUM_WORKERS),
            cpu_threads=int(config.WHISPER_CPU_THREADS),
        )

    def recognize(self, file, language, prompt) -> any:
        """Transcribe the audio at ``file.name``.

        Returns a dict with ``'text'`` (the joined segment texts) and
        ``'segments'`` (one dict per segment).
        """
        audio = self._prepare_file(file.name)
        return self._transcribe(audio, language, prompt)

    def _prepare_file(self, filename: str):
        """Load the audio at the configured sample rate and pad/trim it."""
        audio = whisper.load_audio(filename, sr=int(config.HARPYIA_SAMPLE_RATE))
        audio = whisper.pad_or_trim(audio)
        return audio

    def _transcribe(self, audio, language, prompt):
        """Run the model on ``audio`` and collect its segments into a result dict."""
        segments, _ = self._model.transcribe(
            audio,
            language=language,
            initial_prompt=prompt,
            condition_on_previous_text=False,
            vad_filter=True,
            beam_size=int(config.WHISPER_BEAM_SIZE),
        )

        # transcribe() yields segments lazily and they can be iterated only
        # once. Materialise them before logging, otherwise the logging loop
        # exhausts the iterator and the text below is always empty.
        segments = list(segments)

        print('Segments:', file=sys.stderr)
        for segment in segments:
            print(segment, file=sys.stderr)

        return {
            'text': ' '.join(segment.text for segment in segments),
            'segments': [
                {field: getattr(segment, field, None) for field in self._SEGMENT_FIELDS}
                for segment in segments
            ],
        }
|
12
src/recognizer/strategies/whisper_strategy.py
Normal file
12
src/recognizer/strategies/whisper_strategy.py
Normal file
@ -0,0 +1,12 @@
|
||||
import whisper
|
||||
|
||||
import config
|
||||
from recognizer.strategies import BaseRecognizerStrategy
|
||||
|
||||
class WhisperStrategy(BaseRecognizerStrategy):
    """Recognizer strategy backed by the ``whisper`` package."""

    def __init__(self) -> None:
        """Load the model named by ``config.HARPYIA_MODEL`` on ``config.DEVICE``."""
        model_name = config.HARPYIA_MODEL
        self._model = whisper.load_model(model_name, device=config.DEVICE)

    def recognize(self, file, language, prompt) -> any:
        """Transcribe the audio at ``file.name`` and return the model's result."""
        options = {'language': language, 'initial_prompt': prompt}
        return self._model.transcribe(file.name, **options)
|
@ -1,12 +0,0 @@
|
||||
import whisper
|
||||
|
||||
import config
|
||||
from recognizer import RecognizerStrategy
|
||||
|
||||
class WhisperStrategy(RecognizerStrategy):
|
||||
def __init__(self) -> None:
|
||||
self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
|
||||
|
||||
def recognize(self, file) -> str:
|
||||
return self._model.transcribe(file.name, \
|
||||
language=config.HARPYIA_LANGUAGE, initial_prompt=config.HARPYIA_PROMPT)
|
Loading…
x
Reference in New Issue
Block a user