feature/file-stack #1
84
src/app.py
84
src/app.py
@ -1,67 +1,59 @@
|
||||
from flask import Flask, abort, request
|
||||
from tempfile import NamedTemporaryFile
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import whisper
|
||||
import torch
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
||||
load_dotenv()
|
||||
import config
|
||||
|
||||
HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар'
|
||||
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
|
||||
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
|
||||
from queue_stack import QueueStack
|
||||
from queue_stack.strategies import RecognizeAndSendStrategy
|
||||
|
||||
# Check if NVIDIA GPU is available
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
from recognizer import Recognizer
|
||||
from recognizer.strategies import WhisperStrategy, FastWhisperStrategy
|
||||
|
||||
# Load the Whisper model:
|
||||
model = whisper.load_model(HARPYIA_MODEL, device=DEVICE)
|
||||
from message import MessageService
|
||||
from message.strategies import SosMessageStrategy, NumberMessageStrategy
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
whisper_recognizer = Recognizer(WhisperStrategy())
|
||||
fast_whisper_recognizer = Recognizer(FastWhisperStrategy())
|
||||
|
||||
sos_message_service = MessageService(SosMessageStrategy())
|
||||
number_message_service = MessageService(NumberMessageStrategy())
|
||||
|
||||
queue_stack = QueueStack(RecognizeAndSendStrategy())
|
||||
queue_stack.start_loop_in_thread()
|
||||
|
||||
@app.route("/")
|
||||
def hello():
|
||||
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize_number' route."
|
||||
return "To recognize an audio file, upload it using a POST request with '/recognize' or '/recognize-number' route."
|
||||
|
||||
def recognize_files(handler_fn):
|
||||
if not request.files:
|
||||
abort(400)
|
||||
def recognize_files(message_service: MessageService):
|
||||
if not request.files:
|
||||
abort(400)
|
||||
|
||||
results = []
|
||||
results = []
|
||||
|
||||
for filename, handle in request.files.items():
|
||||
temp = NamedTemporaryFile()
|
||||
handle.save(temp)
|
||||
result = model.transcribe(temp.name, language=HARPYIA_LANGUAGE, initial_prompt=HARPYIA_PROMPT)
|
||||
results.append({
|
||||
'filename': filename,
|
||||
'transcript': handler_fn(result['text']),
|
||||
})
|
||||
for filename, handle in request.files.items():
|
||||
temp = NamedTemporaryFile()
|
||||
handle.save(temp)
|
||||
|
||||
print(results, file=sys.stderr)
|
||||
return {'results': results}
|
||||
results.append(queue_stack.append_and_await((
|
||||
temp,
|
||||
whisper_recognizer,
|
||||
message_service,
|
||||
config.HARPYIA_LANGUAGE,
|
||||
message_service.get_prompt()
|
||||
)))
|
||||
|
||||
print(results, file=sys.stderr)
|
||||
return {'results': results}
|
||||
|
||||
@app.route('/recognize', methods=['POST'])
|
||||
def recognize():
|
||||
return recognize_files(lambda text: text)
|
||||
return recognize_files(sos_message_service)
|
||||
|
||||
@app.route('/recognize_number', methods=['POST'])
|
||||
@app.route('/recognize-number', methods=['POST'])
|
||||
def recognize_number():
|
||||
return recognize_files(transfer_and_clean)
|
||||
|
||||
def transfer_and_clean(input_string):
|
||||
number_mapping = {
|
||||
"один": "1",
|
||||
"два": "2",
|
||||
"три": "3"
|
||||
}
|
||||
|
||||
for word, number in number_mapping.items():
|
||||
input_string = input_string.replace(word, number)
|
||||
|
||||
input_string = re.sub(r'[^\d]+', '', input_string)
|
||||
|
||||
return input_string
|
||||
|
||||
return recognize_files(number_message_service)
|
@ -4,9 +4,18 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HARPYIA_PROMPT = os.getenv('HARPYIA_PROMPT') or 'спасите помогите на помощь пожар'
|
||||
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'medium'
|
||||
HARPYIA_MODEL = os.getenv('HARPYIA_MODEL') or 'small'
|
||||
HARPYIA_LANGUAGE = os.getenv('HARPYIA_LANGUAGE') or 'ru'
|
||||
HARPYIA_SAMPLE_RATE = os.getenv('HARPYIA_SAMPLE_RATE') or 160000
|
||||
|
||||
WHISPER_NUM_WORKERS = os.getenv('WHISPER_NUM_WORKERS') or 6
|
||||
WHISPER_CPU_THREADS = os.getenv('WHISPER_CPU_THREADS') or 10
|
||||
WHISPER_BEAM_SIZE = os.getenv('WHISPER_BEAM_SIZE') or 5
|
||||
|
||||
SOS_PROMPT = os.getenv('SOS_PROMPT') or 'спасите помогите помощь пожар караул кирилл'
|
||||
NUMBER_PROMPT = os.getenv('NUMBER_PROMPT') or 'один два три четыре пять шесть семь восемь девять десять одинадцать двенадцать тринадцать сто сот'
|
||||
|
||||
RAT_URL = os.getenv('RAT_URL') or 'localhost:8081'
|
||||
|
||||
# Check if NVIDIA GPU is available
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
@ -1,32 +0,0 @@
|
||||
from threading import Thread
|
||||
|
||||
from recognizer import Recognizer
|
||||
from message import MessageComposer
|
||||
|
||||
class WavStack:
|
||||
def __init__(self, recognizer: Recognizer, message_composer: MessageComposer):
|
||||
self._stack = []
|
||||
self._recognizer = recognizer
|
||||
self._message_composer = message_composer
|
||||
self._running = False
|
||||
|
||||
def append(self, file):
|
||||
self._stack.append(file)
|
||||
|
||||
def loop(self):
|
||||
self._running = True
|
||||
while self._running:
|
||||
if self._stack:
|
||||
file = self._stack.pop(0)
|
||||
recognized_text = self._recognizer.recognize(file)
|
||||
message = self._message_composer.compose(recognized_text)
|
||||
|
||||
if message.has_prompt():
|
||||
message.send()
|
||||
|
||||
def start_loop_in_thread(self):
|
||||
thread = Thread(target=self.loop)
|
||||
thread.start()
|
||||
|
||||
def stop_loop(self):
|
||||
self._running = False
|
@ -1,6 +1 @@
|
||||
from message.prompt_service import PromptService
|
||||
from message.message_sender.message_sender import MessageSender
|
||||
from message.message_composer import MessageComposer
|
||||
from message.message import Message
|
||||
|
||||
import message.message_sender as message_sender
|
||||
from message.message_service import MessageService
|
@ -1,17 +0,0 @@
|
||||
from message import PromptService, MessageSender
|
||||
|
||||
class Message:
|
||||
def __init__(self, prompt_service: PromptService, message_sender: MessageSender, \
|
||||
recognized_text: str):
|
||||
self._prompt_service = prompt_service
|
||||
self._message_sender = message_sender
|
||||
self._recognized_text = recognized_text
|
||||
|
||||
def has_prompt(self) -> bool:
|
||||
return self._prompt_service.has_prompt(self._recognized_text)
|
||||
|
||||
def send(self) -> None:
|
||||
self._message_sender.send(self._generate_response())
|
||||
|
||||
def _generate_response(self) -> str:
|
||||
return self._prompt_service.filter_words_with_prompt(self._recognized_text)
|
@ -1,9 +0,0 @@
|
||||
from message import Message, PromptService, MessageSender
|
||||
|
||||
class MessageComposer:
|
||||
def __init__(self, prompt_service: PromptService, message_sender: MessageSender):
|
||||
self._prompt_service = prompt_service
|
||||
self._message_sender = message_sender
|
||||
|
||||
def compose(self, recognized_text) -> Message:
|
||||
return Message(self._prompt_service, self._message_sender, recognized_text)
|
@ -1,3 +0,0 @@
|
||||
from message.message_sender.message_sender import MessageSender
|
||||
from message.message_sender.message_sender_strategy import MessageSenderStrategy
|
||||
from message.message_sender.rat_strategy import RatStrategy
|
@ -1,8 +0,0 @@
|
||||
from message.message_sender import MessageSenderStrategy
|
||||
|
||||
class MessageSender:
|
||||
def __init__(self, strategy: MessageSenderStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
def send(self, message) -> None:
|
||||
self._strategy.send(message)
|
@ -1,6 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class MessageSenderStrategy(ABC):
|
||||
@abstractmethod
|
||||
def send(self, message) -> None:
|
||||
pass
|
@ -1,12 +0,0 @@
|
||||
import requests
|
||||
|
||||
from message.message_sender import MessageSenderStrategy
|
||||
|
||||
MESSAGE_ENDPOINT = '/message'
|
||||
|
||||
class RatStrategy(MessageSenderStrategy):
|
||||
def __init__(self, url):
|
||||
self._url = url
|
||||
|
||||
def send(self, message):
|
||||
requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})
|
24
src/message/message_service.py
Normal file
24
src/message/message_service.py
Normal file
@ -0,0 +1,24 @@
|
||||
import sys
|
||||
from message.strategies import BaseMessageStrategy
|
||||
|
||||
class MessageService:
|
||||
def __init__(self, strategy: BaseMessageStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
self._strategy.get_prompt()
|
||||
|
||||
def transfer(self, text: str) -> any:
|
||||
return self._strategy.transfer(text)
|
||||
|
||||
def send(self, message: str) -> any:
|
||||
self._strategy.send(message)
|
||||
|
||||
def transfer_and_send(self, recognized_result: any) -> any:
|
||||
message = self.transfer(recognized_result)
|
||||
|
||||
if message:
|
||||
self.send(message)
|
||||
|
||||
print('Sending message:', recognized_result, file=sys.stderr)
|
||||
return message
|
@ -1,16 +0,0 @@
|
||||
class PromptService:
|
||||
def __init__(self, prompt):
|
||||
self._prompt = prompt
|
||||
|
||||
def has_prompt(self, text: str) -> bool:
|
||||
for part in text.split(' '):
|
||||
return part in self._prompt.split(' ')
|
||||
|
||||
def filter_words_with_prompt(self, text: str) -> str:
|
||||
words = []
|
||||
|
||||
for part in text.split(' '):
|
||||
if part in self._prompt.split(' '):
|
||||
words.append(part)
|
||||
|
||||
return words
|
3
src/message/strategies/__init__.py
Normal file
3
src/message/strategies/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from message.strategies.base_message_strategy import BaseMessageStrategy
|
||||
from message.strategies.sos_message_strategy import SosMessageStrategy
|
||||
from message.strategies.number_message_strategy import NumberMessageStrategy
|
14
src/message/strategies/base_message_strategy.py
Normal file
14
src/message/strategies/base_message_strategy.py
Normal file
@ -0,0 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseMessageStrategy(ABC):
|
||||
@abstractmethod
|
||||
def get_prompt() -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transfer(self, text: str) -> any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send(self, message: str) -> any:
|
||||
pass
|
31
src/message/strategies/number_message_strategy.py
Normal file
31
src/message/strategies/number_message_strategy.py
Normal file
@ -0,0 +1,31 @@
|
||||
import re
|
||||
|
||||
import config
|
||||
from message.strategies import BaseMessageStrategy
|
||||
|
||||
class NumberMessageStrategy(BaseMessageStrategy):
|
||||
def __init__(self, prompt=config.NUMBER_PROMPT) -> None:
|
||||
self._prompt = prompt
|
||||
|
||||
def get_prompt(self):
|
||||
return self._prompt
|
||||
|
||||
def transfer(self, recognized_result: any) -> str:
|
||||
return self._transfer_and_clean(recognized_result['text'])
|
||||
|
||||
def _transfer_and_clean(self, text: str) -> str:
|
||||
number_mapping = {
|
||||
"один": "1",
|
||||
"два": "2",
|
||||
"три": "3"
|
||||
}
|
||||
|
||||
for word, number in number_mapping.items():
|
||||
transfered_text = text.replace(word, number)
|
||||
|
||||
transfered_text = re.sub(r'[^\d]+', '', transfered_text)
|
||||
|
||||
return {'recognized': transfered_text}
|
||||
|
||||
def send(self, message: str) -> None:
|
||||
pass
|
35
src/message/strategies/sos_message_strategy.py
Normal file
35
src/message/strategies/sos_message_strategy.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import List
|
||||
import requests
|
||||
|
||||
import config
|
||||
from message.strategies import BaseMessageStrategy
|
||||
|
||||
MESSAGE_ENDPOINT = '/message'
|
||||
|
||||
class SosMessageStrategy(BaseMessageStrategy):
|
||||
def __init__(self, prompt=config.SOS_PROMPT, url=config.RAT_URL) -> None:
|
||||
self._prompt = prompt
|
||||
self._url = url
|
||||
|
||||
def get_prompt(self):
|
||||
return self._prompt
|
||||
|
||||
def transfer(self, recognized_result: any) -> str:
|
||||
return {
|
||||
'transcript': recognized_result['text'],
|
||||
'results': self._filter_words_with_prompt(recognized_result['text']),
|
||||
'segments': recognized_result['segments']
|
||||
}
|
||||
|
||||
def _filter_words_with_prompt(self, text: str) -> str:
|
||||
words = []
|
||||
|
||||
for prompt in self._prompt.split(' '):
|
||||
if prompt in text.lower():
|
||||
words.append(prompt)
|
||||
|
||||
return words
|
||||
|
||||
def send(self, message) -> any:
|
||||
pass
|
||||
#return requests.post(self._url + MESSAGE_ENDPOINT, json={'message': message})
|
1
src/queue_stack/__init__.py
Normal file
1
src/queue_stack/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from queue_stack.queue_stack import QueueStack
|
51
src/queue_stack/queue_stack.py
Normal file
51
src/queue_stack/queue_stack.py
Normal file
@ -0,0 +1,51 @@
|
||||
import sys
|
||||
|
||||
from threading import Thread, Event, Lock
|
||||
|
||||
from queue_stack.strategies import BaseProcessStrategy
|
||||
|
||||
class QueueStack:
|
||||
def __init__(self, strategy: BaseProcessStrategy) -> None:
|
||||
self._stack = []
|
||||
self._strategy = strategy
|
||||
|
||||
self._lock = Lock()
|
||||
self._running = False
|
||||
|
||||
self._last_response = None
|
||||
|
||||
def append(self, args, event=None) -> None:
|
||||
with self._lock:
|
||||
self._stack.append((args, event))
|
||||
|
||||
def append_and_await(self, args) -> any:
|
||||
event = Event()
|
||||
self.append(args, event=event)
|
||||
|
||||
event.wait()
|
||||
event.clear()
|
||||
|
||||
return self._last_response
|
||||
|
||||
def loop(self) -> None:
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
with self._lock:
|
||||
if self._stack:
|
||||
print('Stack length:', len(self._stack), file=sys.stderr)
|
||||
(args, event) = self._stack.pop(0)
|
||||
self._last_response = self._process(*args)
|
||||
|
||||
if event:
|
||||
event.set()
|
||||
|
||||
def _process(self, *args, **kwargs) -> any:
|
||||
return self._strategy.process(*args, **kwargs)
|
||||
|
||||
def start_loop_in_thread(self) -> None:
|
||||
thread = Thread(target=self.loop)
|
||||
thread.start()
|
||||
|
||||
def stop_loop(self) -> None:
|
||||
self._running = False
|
2
src/queue_stack/strategies/__init__.py
Normal file
2
src/queue_stack/strategies/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from queue_stack.strategies.base_process_strategy import BaseProcessStrategy
|
||||
from queue_stack.strategies.recognize_and_send_strategy import RecognizeAndSendStrategy
|
6
src/queue_stack/strategies/base_process_strategy.py
Normal file
6
src/queue_stack/strategies/base_process_strategy.py
Normal file
@ -0,0 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseProcessStrategy(ABC):
|
||||
@abstractmethod
|
||||
def process(self, *args, **kwargs) -> any:
|
||||
pass
|
14
src/queue_stack/strategies/recognize_and_send_strategy.py
Normal file
14
src/queue_stack/strategies/recognize_and_send_strategy.py
Normal file
@ -0,0 +1,14 @@
|
||||
import sys
|
||||
|
||||
from queue_stack.strategies import BaseProcessStrategy
|
||||
from message import MessageService
|
||||
from recognizer import Recognizer
|
||||
|
||||
class RecognizeAndSendStrategy(BaseProcessStrategy):
|
||||
def process(self, file, recognizer: Recognizer, message_service: MessageService, language, prompt) -> any:
|
||||
|
||||
result = recognizer.recognize(file, language=language, prompt=prompt)
|
||||
message = message_service.transfer_and_send(result)
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
return message
|
@ -1,4 +1 @@
|
||||
from recognizer.recognizer import Recognizer
|
||||
from recognizer.recognizer_strategy import RecognizerStrategy
|
||||
from recognizer.whisper_strategy import WhisperStrategy
|
||||
from recognizer.fast_whisper_strategy import FastWhisperStrategy
|
||||
from recognizer.recognizer import Recognizer
|
@ -1,37 +0,0 @@
|
||||
import whisper
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import config
|
||||
from recognizer import RecognizerStrategy
|
||||
|
||||
class FastWhisperStrategy(RecognizerStrategy):
|
||||
def __init__(self) -> None:
|
||||
self._model = WhisperModel(
|
||||
model_size=config.HARPYIA_MODEL,
|
||||
device=config.DEVICE,
|
||||
num_workers=6,
|
||||
cpu_threads=10,
|
||||
# in_memory=True,
|
||||
)
|
||||
|
||||
def recognize(self, file) -> str:
|
||||
audio = self._prepare_file(file.name)
|
||||
return self._transcribe(audio)
|
||||
|
||||
def _prepare_file(self, filename: str):
|
||||
audio = whisper.load_audio(filename, sr=16000)
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
return audio
|
||||
|
||||
|
||||
def _transcribe(self, audio):
|
||||
segments, _ = self._model.transcribe(
|
||||
audio,
|
||||
language=config.HARPYIA_LANGUAGE,
|
||||
initial_prompt=config.HARPYIA_PROMPT,
|
||||
condition_on_previous_text=False,
|
||||
vad_filter=True,
|
||||
beam_size=5,
|
||||
)
|
||||
|
||||
return segments
|
@ -1,8 +1,14 @@
|
||||
from recognizer import RecognizerStrategy
|
||||
import sys
|
||||
|
||||
import config
|
||||
from recognizer.strategies import BaseRecognizerStrategy
|
||||
|
||||
class Recognizer:
|
||||
def __init__(self, strategy: RecognizerStrategy):
|
||||
def __init__(self, strategy: BaseRecognizerStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
def recognize(self, file) -> str:
|
||||
self._strategy.recognize(file)
|
||||
def recognize(self, file, language, prompt) -> str:
|
||||
result = self._strategy.recognize(file, language=language, prompt=prompt)
|
||||
|
||||
print(f'Result: {result}', file=sys.stderr)
|
||||
return result
|
@ -1,6 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class RecognizerStrategy(ABC):
|
||||
@abstractmethod
|
||||
def recognize(self, file) -> str:
|
||||
pass
|
3
src/recognizer/strategies/__init__.py
Normal file
3
src/recognizer/strategies/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from recognizer.strategies.base_recognizer_strategy import BaseRecognizerStrategy
|
||||
from recognizer.strategies.whisper_strategy import WhisperStrategy
|
||||
from recognizer.strategies.fast_whisper_strategy import FastWhisperStrategy
|
6
src/recognizer/strategies/base_recognizer_strategy.py
Normal file
6
src/recognizer/strategies/base_recognizer_strategy.py
Normal file
@ -0,0 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseRecognizerStrategy(ABC):
|
||||
@abstractmethod
|
||||
def recognize(self, file, language, prompt) -> any:
|
||||
pass
|
59
src/recognizer/strategies/fast_whisper_strategy.py
Normal file
59
src/recognizer/strategies/fast_whisper_strategy.py
Normal file
@ -0,0 +1,59 @@
|
||||
import sys
|
||||
|
||||
import whisper
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import config
|
||||
from recognizer.strategies import BaseRecognizerStrategy
|
||||
|
||||
class FastWhisperStrategy(BaseRecognizerStrategy):
|
||||
def __init__(self) -> None:
|
||||
self._model = WhisperModel(
|
||||
model_size_or_path=config.HARPYIA_MODEL,
|
||||
device=config.DEVICE,
|
||||
num_workers=config.WHISPER_NUM_WORKERS,
|
||||
cpu_threads=config.WHISPER_CPU_THREADS
|
||||
)
|
||||
|
||||
def recognize(self, file, language, prompt) -> any:
|
||||
audio = self._prepare_file(file.name)
|
||||
return self._transcribe(audio, language, prompt)
|
||||
|
||||
def _prepare_file(self, filename: str):
|
||||
audio = whisper.load_audio(filename, sr=config.HARPYIA_SAMPLE_RATE)
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
return audio
|
||||
|
||||
def _transcribe(self, audio, language, prompt):
|
||||
segments, _ = self._model.transcribe(
|
||||
audio,
|
||||
language=language,
|
||||
initial_prompt=prompt,
|
||||
condition_on_previous_text=False,
|
||||
vad_filter=True,
|
||||
beam_size=config.WHISPER_BEAM_SIZE,
|
||||
)
|
||||
|
||||
print('Segments:', file=sys.stderr)
|
||||
for i in segments:
|
||||
print(i, file=sys.stderr)
|
||||
|
||||
words = []
|
||||
for segment in list(segments):
|
||||
words.append(segment.text)
|
||||
|
||||
return {
|
||||
'text': ' '.join(words),
|
||||
'segments': {
|
||||
'id': None,
|
||||
'seek': None,
|
||||
'start': None,
|
||||
'end': None,
|
||||
'text': None,
|
||||
'tokens': None,
|
||||
'temperature': None,
|
||||
'avg_logprob': None,
|
||||
'compression_ratio': None,
|
||||
'no_speech_prob': None,
|
||||
}
|
||||
}
|
12
src/recognizer/strategies/whisper_strategy.py
Normal file
12
src/recognizer/strategies/whisper_strategy.py
Normal file
@ -0,0 +1,12 @@
|
||||
import whisper
|
||||
|
||||
import config
|
||||
from recognizer.strategies import BaseRecognizerStrategy
|
||||
|
||||
class WhisperStrategy(BaseRecognizerStrategy):
|
||||
def __init__(self) -> None:
|
||||
self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
|
||||
|
||||
def recognize(self, file, language, prompt) -> any:
|
||||
return self._model.transcribe(file.name, \
|
||||
language=language, initial_prompt=prompt)
|
@ -1,12 +0,0 @@
|
||||
import whisper
|
||||
|
||||
import config
|
||||
from recognizer import RecognizerStrategy
|
||||
|
||||
class WhisperStrategy(RecognizerStrategy):
|
||||
def __init__(self) -> None:
|
||||
self._model = whisper.load_model(config.HARPYIA_MODEL, device=config.DEVICE)
|
||||
|
||||
def recognize(self, file) -> str:
|
||||
return self._model.transcribe(file.name, \
|
||||
language=config.HARPYIA_LANGUAGE, initial_prompt=config.HARPYIA_PROMPT)
|
Loading…
x
Reference in New Issue
Block a user