diff --git a/pyproject.toml b/pyproject.toml
index bea402a5..5c48d69b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,10 @@ cloud_llm = [
     "langchain-openai >=0.3.0, <0.4.0",
     "langchain-google-genai >=2.1.0, <2.2.0",
 ]
+cloud_asr = [
+    "websocket-client",
+    "google-cloud-speech>=2.27.0",
+]
 
 all = [
     "arduino_app_bricks[dbstorage_influx]",
@@ -90,6 +94,7 @@ all = [
     "arduino_app_bricks[stream]",
     "arduino_app_bricks[arduino_cloud]",
     "arduino_app_bricks[cloud_llm]",
+    "arduino_app_bricks[cloud_asr]",
 ]
 
 [project.urls]
diff --git a/src/arduino/app_bricks/cloud_asr/__init__.py b/src/arduino/app_bricks/cloud_asr/__init__.py
new file mode 100644
index 00000000..1e62c423
--- /dev/null
+++ b/src/arduino/app_bricks/cloud_asr/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from .cloud_asr import CloudASR
+from .providers import ASREvent, CloudProvider
+
+__all__ = ["CloudASR", "ASREvent", "CloudProvider"]
diff --git a/src/arduino/app_bricks/cloud_asr/brick_config.yaml b/src/arduino/app_bricks/cloud_asr/brick_config.yaml
new file mode 100644
index 00000000..1a07f644
--- /dev/null
+++ b/src/arduino/app_bricks/cloud_asr/brick_config.yaml
@@ -0,0 +1,16 @@
+id: arduino:cloud_asr
+name: Cloud ASR
+description: |
+  Cloud ASR Brick provides a unified and flexible way to connect to cloud-based Automatic Speech Recognition (ASR) services and transform spoken audio into text.
+  It enables real-time, streaming transcription from a connected microphone, leveraging leading cloud providers to deliver low-latency speech-to-text processing.
+category: audio
+mount_devices_into_container: true
+required_devices:
+  - microphone
+requires_container: false
+requires_model: false
+variables:
+  - name: API_KEY
+    description: API Key for the cloud-based Speech to Text service
+  - name: LANGUAGE
+    description: "Language code for transcription (e.g., en, it). Default: en"
diff --git a/src/arduino/app_bricks/cloud_asr/cloud_asr.py b/src/arduino/app_bricks/cloud_asr/cloud_asr.py
new file mode 100644
index 00000000..970cc708
--- /dev/null
+++ b/src/arduino/app_bricks/cloud_asr/cloud_asr.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from __future__ import annotations
+
+import os
+import queue
+import threading
+from typing import Iterator, Callable, Optional
+
+import numpy as np
+
+from arduino.app_peripherals.microphone import Microphone
+from arduino.app_utils import Logger, brick
+
+from .providers import ASRProvider, CloudProvider, DEFAULT_PROVIDER, provider_factory
+
+logger = Logger(__name__)
+
+DEFAULT_LANGUAGE = "en"
+
+
+@brick
+class CloudASR:
+    """
+    Cloud-based speech-to-text with pluggable cloud providers.
+    It captures audio from a microphone and streams it to the selected cloud ASR provider for transcription.
+    The recognized text is yielded as events in real-time.
+ """ + + def __init__( + self, + api_key: str = os.getenv("API_KEY", ""), + provider: CloudProvider = DEFAULT_PROVIDER, + mic: Optional[Microphone] = None, + language: str = os.getenv("LANGUAGE", ""), + ): + if mic: + logger.info(f"[{self.__class__.__name__}] Using provided microphone: {mic}") + self._mic = mic + else: + self._mic = Microphone() + + self._language = language + self._mic_lock = threading.Lock() + self._provider: ASRProvider = provider_factory( + api_key=api_key, + name=provider, + language=self._language, + sample_rate=self._mic.sample_rate, + ) + + self.detect_handlers: list[Callable[[dict], None]] = [] + self.detect_handlers_lock = threading.Lock() + self.partial_handlers: list[Callable[[dict], None]] = [] + self.partial_handlers_lock = threading.Lock() + + def start(self): + with self._mic_lock: + if not self._mic.is_recording.is_set(): + self._mic.start() + logger.info(f"[{self.__class__.__name__}] Microphone started.") + + def stop(self): + with self._mic_lock: + if self._mic.is_recording.is_set(): + self._mic.stop() + logger.info(f"[{self.__class__.__name__}] Microphone stopped.") + + def on_detect(self, handler): + """Register a callback to be invoked when speech is detected.""" + with self.detect_handlers_lock: + self.detect_handlers.append(handler) + + @brick.loop + def _detect_loop(self): + """Continuously listen for speech and invoke handlers when final text is detected.""" + for resp in self.transcribe(): + match resp["event"]: + case "error": + logger.error(f"ASR error: {resp['data']}") + case "text": + with self.detect_handlers_lock: + for handler in self.detect_handlers: + try: + handler(resp["data"]) + except Exception as exc: + logger.error(f"Error in speech detected handler: {exc}") + + def on_update(self, handler): + """Register a callback to be invoked for each partial speech update.""" + with self.partial_handlers_lock: + self.partial_handlers.append(handler) + + @brick.loop + def _update_loop(self): + """Continuously listen for partial speech and invoke handlers.""" + for resp in self.transcribe(): + with self.partial_handlers_lock: + for handler in self.partial_handlers: + try: + handler(resp) + except Exception as exc: + logger.error(f"Error in partial speech handler: {exc}") + + def transcribe(self) -> Iterator[dict]: + """Perform speech-to-text recognition. + + Returns: + Iterator[dict]: Generator yielding + {"event": ("speech_start|partial_text|text|error|speech_stop"), "data": ""} + messages. + """ + + provider = self._provider + messages: queue.Queue[dict] = queue.Queue() + stop_event = threading.Event() + + def _send(): + try: + for chunk in self._mic.stream(): + if stop_event.is_set(): + break + if chunk is None: + continue + pcm_chunk_np = np.asarray(chunk, dtype=np.int16) + provider.send_audio(pcm_chunk_np.tobytes()) + except KeyboardInterrupt: + logger.info("Recognition interrupted by user. 
Exiting...") + except Exception as exc: + logger.error("Error while streaming microphone audio: %s", exc) + messages.put({"event": "error", "data": str(exc)}) + finally: + stop_event.set() + + partial_buffer = "" + + def _recv(): + nonlocal partial_buffer + try: + while not stop_event.is_set(): + result = provider.recv() + if result is None: + continue + + data = result.data + if result.event == "partial_text": + if self._provider.partial_mode == "replace": + partial_buffer = str(data) + else: + partial_buffer += str(data) + elif result.event == "text": + data = data or partial_buffer + partial_buffer = "" + messages.put({"event": result.event, "data": data}) + + except Exception as exc: + logger.error("Error receiving transcription events: %s", exc) + messages.put({"event": "error", "data": str(exc)}) + stop_event.set() + + send_thread = threading.Thread(target=_send, daemon=True) + recv_thread = threading.Thread(target=_recv, daemon=True) + send_thread.start() + recv_thread.start() + + try: + while recv_thread.is_alive() or send_thread.is_alive() or not messages.empty(): + try: + msg = messages.get(timeout=0.1) + yield msg + except queue.Empty: + continue + finally: + stop_event.set() + send_thread.join(timeout=1) + recv_thread.join(timeout=1) + provider.stop() diff --git a/src/arduino/app_bricks/cloud_asr/examples/1_asr.py b/src/arduino/app_bricks/cloud_asr/examples/1_asr.py new file mode 100644 index 00000000..11937a70 --- /dev/null +++ b/src/arduino/app_bricks/cloud_asr/examples/1_asr.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + +# EXAMPLE_NAME = "Detect speech from microphone" +# EXAMPLE_REQUIRES = "Requires an USB microphone connected to the Arduino board." +from arduino.app_bricks.cloud_asr import CloudASR +from arduino.app_utils import App + +cloud_asr = CloudASR( + api_key="YOUR_API_KEY", # Replace with your actual API key +) +cloud_asr.on_detect(lambda text: print(f"Detected speech: {text}")) + +App.run() diff --git a/src/arduino/app_bricks/cloud_asr/examples/2_multilingual_asr.py b/src/arduino/app_bricks/cloud_asr/examples/2_multilingual_asr.py new file mode 100644 index 00000000..6254c863 --- /dev/null +++ b/src/arduino/app_bricks/cloud_asr/examples/2_multilingual_asr.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + +# EXAMPLE_NAME = "Detect speech from microphone in Italian" +# EXAMPLE_REQUIRES = "Requires an USB microphone connected to the Arduino board." +from arduino.app_bricks.cloud_asr import CloudASR +from arduino.app_utils import App + +cloud_asr = CloudASR( + api_key="YOUR_API_KEY", # Replace with your actual API key + language="it", # Set language to Italian +) +cloud_asr.on_detect(lambda text: print(f"Detected speech: {text}")) + +App.run() diff --git a/src/arduino/app_bricks/cloud_asr/examples/3_event_stream.py b/src/arduino/app_bricks/cloud_asr/examples/3_event_stream.py new file mode 100644 index 00000000..9c2d044c --- /dev/null +++ b/src/arduino/app_bricks/cloud_asr/examples/3_event_stream.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + +# EXAMPLE_NAME = "Sends audio from microphone and receives all the streaming events" +# EXAMPLE_REQUIRES = "Requires an USB microphone connected to the Arduino board." 
+from arduino.app_bricks.cloud_asr import CloudASR
+from arduino.app_utils import App
+
+cloud_asr = CloudASR(
+    api_key="YOUR_API_KEY",  # Replace with your actual API key
+)
+cloud_asr.on_update(lambda resp: print(f"{resp['event']}: {resp['data']}"))
+
+App.run()
diff --git a/src/arduino/app_bricks/cloud_asr/providers/__init__.py b/src/arduino/app_bricks/cloud_asr/providers/__init__.py
new file mode 100644
index 00000000..eff54ca8
--- /dev/null
+++ b/src/arduino/app_bricks/cloud_asr/providers/__init__.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from enum import Enum
+
+from .openai import OpenAITranscribe
+from .google import GoogleSpeech
+from .types import ASREvent, ASRProvider
+
+
+class CloudProvider(str, Enum):
+    OPENAI_TRANSCRIBE = "openai-transcribe"
+    GOOGLE_SPEECH = "google-speech"
+
+
+DEFAULT_PROVIDER = CloudProvider.OPENAI_TRANSCRIBE
+
+
+def provider_factory(
+    api_key: str,
+    language: str,
+    sample_rate: int,
+    name: CloudProvider = DEFAULT_PROVIDER,
+) -> ASRProvider:
+    """Return the ASR cloud provider implementation."""
+    if name == CloudProvider.OPENAI_TRANSCRIBE:
+        return OpenAITranscribe(
+            api_key=api_key,
+            language=language,
+            sample_rate=sample_rate,
+        )
+    if name == CloudProvider.GOOGLE_SPEECH:
+        return GoogleSpeech(
+            api_key=api_key,
+            language=language,
+            sample_rate=sample_rate,
+        )
+    raise ValueError(f"Unsupported ASR cloud provider: {name}")
+
+
+__all__ = [
+    "ASREvent",
+    "ASRProvider",
+    "CloudProvider",
+    "DEFAULT_PROVIDER",
+    "GoogleSpeech",
+    "OpenAITranscribe",
+    "provider_factory",
+]
diff --git a/src/arduino/app_bricks/cloud_asr/providers/google.py b/src/arduino/app_bricks/cloud_asr/providers/google.py
new file mode 100644
index 00000000..1fad120f
--- /dev/null
+++ b/src/arduino/app_bricks/cloud_asr/providers/google.py
@@ -0,0 +1,208 @@
+# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from __future__ import annotations
+
+import queue
+import threading
+
+from arduino.app_utils import Logger
+from google.api_core.client_options import ClientOptions
+from google.cloud.speech import SpeechClient, StreamingRecognitionConfig, RecognitionConfig, StreamingRecognizeRequest, StreamingRecognizeResponse
+
+from .types import ASREvent
+
+logger = Logger(__name__)
+
+
+class GoogleSpeech:
+    """
+    Google ASR cloud provider implementation.
+
+    It uses the google-cloud-speech package to connect to the Google Speech-to-Text API
+    for streaming transcription.
+    For English locales it uses the default streaming model. For non-English
+    locales the standard model segments poorly, so `latest_short` is used to
+    get faster segmentation even though it emits a single utterance; when that
+    happens the stream is restarted transparently while preserving queued
+    audio so callers keep a continuous feed of events.
+ """ + + partial_mode = "replace" + provider_name = "google-speech" + + GOOGLE_LANG_MAP = { + "en": "en-US", + "it": "it-IT", + "es": "es-ES", + "fr": "fr-FR", + "de": "de-DE", + "pt": "pt-PT", + "pt-br": "pt-BR", + } + DEFAULT_LANGUAGE = "en" + + def __init__( + self, + api_key: str, + language: str = DEFAULT_LANGUAGE, + sample_rate: int = 16000, + ): + if not api_key: + raise RuntimeError("Google Speech requires an API key; set API_KEY for this cloud provider.") + self._api_key = api_key + + self._language = self._resolve_google_language(language) + if not self._language: + self._language = self.DEFAULT_LANGUAGE + + self._sample_rate = sample_rate + self._use_short_model = not (self._language.lower().startswith("en")) + + self._stop_event = threading.Event() + self._audio_q: queue.Queue[bytes | None] = queue.Queue() + self._resp_q: queue.Queue[ASREvent | None] = queue.Queue() + + self._client = SpeechClient(client_options=ClientOptions(api_key=self._api_key)) + self._config = self._build_config() + + self._thread = threading.Thread(target=self._asr_worker, daemon=True) + self._thread.start() + + def _resolve_google_language(self, language: str) -> str: + if not language: + return self.GOOGLE_LANG_MAP[self.DEFAULT_LANGUAGE] + key = language.strip().lower() + return self.GOOGLE_LANG_MAP.get(key, language) + + def _build_config(self) -> StreamingRecognitionConfig: + config_kwargs = dict( + encoding=RecognitionConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=self._sample_rate, + language_code=self._language, + enable_automatic_punctuation=True, + ) + if self._use_short_model: + config_kwargs["model"] = "latest_short" + + return StreamingRecognitionConfig( + config=RecognitionConfig(**config_kwargs), + interim_results=True, + enable_voice_activity_events=True, + single_utterance=self._use_short_model, + ) + + def _request_loop(self, session_end: threading.Event): + while not self._stop_event.is_set() and not session_end.is_set(): + try: + chunk = self._audio_q.get(timeout=0.1) + except queue.Empty: + continue + + # When stop() is called, a None value is pushed into the audio queue + # to unblock audio_q.get() and force this generator to exit immediately. + # This allows streaming_recognize to terminate cleanly even if no audio + # chunks are currently being produced. + if chunk is None: + return + + yield StreamingRecognizeRequest(audio_content=chunk) + + def _asr_worker(self): + """ + ASR worker thread that streams audio to Google Speech and emits transcription events. + + The worker runs streaming_recognize in a loop to support both standard and short models. + For short models, Google ends the stream at utterance boundaries; when + END_OF_SINGLE_UTTERANCE is received, `session_end` stops audio consumption in + `_request_loop`. The outer loop then restarts streaming_recognize to continue + processing subsequent audio. + For standard models, this event never occurs, so the stream remains open until + explicitly stopped or an error occurs. 
+ """ + try: + while not self._stop_event.is_set(): + session_end = threading.Event() + try: + for response in self._client.streaming_recognize( + config=self._config, + requests=self._request_loop(session_end), + ): + ev = self._format_event(response) + if ev is None: + continue + if ev.event == "utterance_end": + session_end.set() + continue + + self._resp_q.put(ev) + + except Exception as exc: + if not self._stop_event.is_set(): + self._resp_q.put(ASREvent(event="error", data=str(exc))) + break + + finally: + self._resp_q.put(None) + + def _format_event(self, message: object) -> ASREvent | None: + match getattr(message, "speech_event_type", None): + case StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN: + return ASREvent(event="speech_start") + case StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END: + return ASREvent(event="speech_stop") + case StreamingRecognizeResponse.SpeechEventType.END_OF_SINGLE_UTTERANCE: + return ASREvent(event="utterance_end") + + results = getattr(message, "results", None) + if not results: + return None + + final_text: str | None = None + best_partial_text: str | None = None + best_stability: float = -1.0 + + for result in results: + alternatives = getattr(result, "alternatives", []) + if not alternatives: + continue + + # First alternative is the most probable one + transcript = (alternatives[0].transcript or "").strip() + if not transcript: + continue + + if getattr(result, "is_final", False): + final_text = transcript + else: + stability = float(getattr(result, "stability", 0.0)) + if stability > best_stability: + best_stability = stability + best_partial_text = transcript + + if final_text: + return ASREvent(event="text", data=final_text) + if best_partial_text: + return ASREvent(event="partial_text", data=best_partial_text) + return None + + def send_audio(self, pcm_chunk: bytes) -> None: + self._audio_q.put(pcm_chunk) + + def recv(self) -> ASREvent | None: + try: + return self._resp_q.get(timeout=0.1) + except queue.Empty: + return None + + def stop(self) -> None: + self._stop_event.set() + self._audio_q.put(None) + try: + self._thread.join(timeout=1) + except Exception: + pass + + +__all__ = ["GoogleSpeech"] diff --git a/src/arduino/app_bricks/cloud_asr/providers/openai.py b/src/arduino/app_bricks/cloud_asr/providers/openai.py new file mode 100644 index 00000000..2f401fcf --- /dev/null +++ b/src/arduino/app_bricks/cloud_asr/providers/openai.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + +from __future__ import annotations + +import base64 +import json + +import websocket + +from arduino.app_utils import Logger + +from .types import ASREvent + +logger = Logger(__name__) + + +class OpenAITranscribe: + """ + OpenAI ASR cloud provider implementation. + It leverages the Realtime API to enable streaming transcription powered by a GPT-based model. + Audio is transmitted and received over WebSockets, while voice activity detection (VAD) server-side + is used to segment utterances. + If custom VAD behavior is desired, the VoiceActivityDetector class can be used client-side to + trigger commits based on local audio analysis. In that case, track the audio with the vad process_chunk method and + register the vad commit() method to send a `{"type": "input_audio_buffer.commit"}` message to the server. 
+ """ + + provider_name = "openai-transcribe" + partial_mode = "append" + + REALTIME_MODEL = "gpt-realtime" + TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe" + BASE_URL = "wss://api.openai.com/v1/realtime" + IGNORED_COMMIT_CODES = { + "input_audio_buffer_commit_empty", + "input_audio_buffer_commit_short", + } + VAD_MIN_BUFFER_MS = 120.0 + DEFAULT_LANGUAGE = "en" + + def __init__( + self, + api_key: str, + language: str = DEFAULT_LANGUAGE, + sample_rate: int = 16000, + ): + if not api_key: + raise ValueError("API key is required for OpenAI Realtime client.") + + self._api_key = api_key + self._language = language + if not self._language: + self._language = self.DEFAULT_LANGUAGE + + self._url = f"{self.BASE_URL}?model={self.REALTIME_MODEL}" + self._headers = [ + f"Authorization: Bearer {self._api_key}", + "OpenAI-Beta: realtime=v1", + ] + + self._ws = self._connect() + self._sample_rate = sample_rate + + def _connect(self) -> websocket.WebSocket: + logger.info("Connecting to realtime ASR endpoint: %s", self._url) + ws = websocket.WebSocket() + ws.connect(self._url, header=self._headers, ping_interval=20, ping_timeout=20) + self._send_session_update(ws) + return ws + + def _send_session_update(self, ws: websocket.WebSocket) -> None: + ws.send( + json.dumps({ + "type": "session.update", + "session": { + "modalities": ["text"], + "input_audio_format": "pcm16", + "turn_detection": {"type": "server_vad"}, + "input_audio_transcription": { + "model": self.TRANSCRIPTION_MODEL, + "language": self._language, + }, + "instructions": "You are a transcription engine. Only return verbatim transcripts and do not chat or respond.", + }, + }) + ) + + def _decode_message(self, raw: object) -> object: + if isinstance(raw, (str, bytes, bytearray)): + try: + return json.loads(raw) + except json.JSONDecodeError: + return raw + return raw + + def _extract_error_code(self, message: object) -> str | None: + """Try to find an error code either nested under 'error' or at top-level.""" + if not isinstance(message, dict): + return None + + err = message.get("error") + if isinstance(err, dict): + code = err.get("code") + if isinstance(code, str): + return code + + code = message.get("code") + return code if isinstance(code, str) else None + + def _extract_error_payload(self, message: object) -> object: + """Prefer nested 'error' payload if present, otherwise return message.""" + if isinstance(message, dict) and "error" in message: + return message.get("error") + return message + + def _format_event(self, message: dict) -> ASREvent | None: + match message.get("type"): + case "input_audio_buffer.speech_start": + return ASREvent(event="speech_start", data=None) + case "input_audio_buffer.speech_stop": + return ASREvent(event="speech_stop", data=None) + case "conversation.item.input_audio_transcription.delta": + delta_text = message.get("delta", "") + if delta_text: + return ASREvent(event="partial_text", data=delta_text) + + case "conversation.item.input_audio_transcription.completed": + text = message.get("transcript", "") + if text: + return ASREvent(event="text", data=text) + return ASREvent(event="error", data="Transcription completed with no text.") + + case "error" | "invalid_request_error": + code = self._extract_error_code(message) + if code in self.IGNORED_COMMIT_CODES: + logger.debug("Ignoring empty commit warning from server.") + return None + return ASREvent(event="error", data=self._extract_error_payload(message)) + + return None + + def recv(self) -> ASREvent | None: + try: + raw = self._ws.recv() + except 
+            return ASREvent(event="error", data=str(exc))
+
+        message = self._decode_message(raw)
+        if not isinstance(message, dict):
+            return None
+
+        try:
+            return self._format_event(message)
+        except Exception as exc:  # pragma: no cover
+            return ASREvent(event="error", data=str(exc))
+
+    def send_audio(self, pcm_chunk: bytes) -> None:
+        if not pcm_chunk:
+            return
+
+        audio_payload = base64.b64encode(pcm_chunk).decode("ascii")
+        self._ws.send(json.dumps({"type": "input_audio_buffer.append", "audio": audio_payload}))
+
+    def stop(self) -> None:
+        try:
+            self._ws.close()
+        except Exception:
+            pass
+
+
+__all__ = ["OpenAITranscribe"]
diff --git a/src/arduino/app_bricks/cloud_asr/providers/types.py b/src/arduino/app_bricks/cloud_asr/providers/types.py
new file mode 100644
index 00000000..43df235a
--- /dev/null
+++ b/src/arduino/app_bricks/cloud_asr/providers/types.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Protocol, runtime_checkable
+
+
+@dataclass(frozen=True)
+class ASREvent:
+    event: str
+    data: object | None = None
+
+
+@runtime_checkable
+class ASRProvider(Protocol):
+    """Minimal interface for realtime ASR cloud providers."""
+
+    @property
+    def provider_name(self) -> str: ...
+
+    @property
+    def partial_mode(self) -> str: ...
+
+    def send_audio(self, pcm_chunk: bytes) -> None: ...
+
+    def recv(self) -> ASREvent | None: ...
+
+    def stop(self) -> None: ...
diff --git a/src/arduino/app_utils/vad.py b/src/arduino/app_utils/vad.py
new file mode 100644
index 00000000..bbb026cf
--- /dev/null
+++ b/src/arduino/app_utils/vad.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable
+
+import numpy as np
+
+ENERGY_THRESHOLD = 80.0
+SILENCE_MS = 1800.0
+MAX_BUFFER_MS = 12000.0
+
+
+@dataclass
+class VADState:
+    buffered_ms: float = 0.0
+    silence_ms: float = 0.0
+    speaking: bool = False
+
+
+class VoiceActivityDetector:
+    """
+    This class analyzes incoming PCM16 audio chunks by estimating their signal
+    energy to determine whether speech is present. Chunks with energy above the
+    configured threshold are classified as speech, while lower-energy chunks
+    contribute to silence accumulation.
+
+    Audio duration is buffered while speech is active and a commit callback is
+    triggered when one of the following conditions is met:
+
+    - A period of silence longer than the configured silence threshold occurs
+      after speech has started.
+    - The maximum allowed buffered audio duration is reached.
+
+    The detector is stateful and must be fed sequential audio chunks from a
+    continuous audio stream.
+
+    Args:
+        commit_callback (Callable[[], None]):
+            Function invoked when the buffered audio should be committed.
+
+        min_buffer_ms (float):
+            Minimum amount of buffered audio (in milliseconds) required to
+            trigger a commit. Shorter segments are discarded.
+
+        energy_threshold (float, optional):
+            Energy threshold used to classify a chunk as speech.
+            Higher values make the detector less sensitive to quiet speech.
+            Defaults to `ENERGY_THRESHOLD`.
+
+        silence_ms (float, optional):
+            Amount of consecutive silence (in milliseconds) required to
+            consider speech ended and trigger a commit.
+            Defaults to `SILENCE_MS`.
+
+        max_buffer_ms (float, optional):
+            Maximum amount of audio (in milliseconds) that can be buffered
+            before forcing a commit, even if speech has not ended.
+            Defaults to `MAX_BUFFER_MS`.
+    """
+
+    def __init__(
+        self,
+        commit_callback: Callable[[], None],
+        min_buffer_ms: float,
+        energy_threshold: float = ENERGY_THRESHOLD,
+        silence_ms: float = SILENCE_MS,
+        max_buffer_ms: float = MAX_BUFFER_MS,
+    ):
+        self._commit_callback = commit_callback
+        self._min_buffer_ms = min_buffer_ms
+        self._energy_threshold = energy_threshold
+        self._silence_ms_threshold = silence_ms
+        self._max_buffer_ms = max_buffer_ms
+        self._state = VADState()
+
+    def process_chunk(self, pcm_chunk: bytes, sample_rate: int) -> None:
+        """Update VAD state using raw PCM16 bytes and commit buffered audio when thresholds are met."""
+        chunk_ms = chunk_duration_ms(pcm_chunk, sample_rate)
+        if chunk_ms <= 0:
+            return
+
+        pcm_chunk_np = np.frombuffer(pcm_chunk, dtype=np.int16)
+        if self._should_commit(pcm_chunk_np, chunk_ms):
+            self.commit_buffer()
+
+    def commit_buffer(self) -> None:
+        if self._state.buffered_ms >= self._min_buffer_ms:
+            self._commit_callback()
+        self._state = VADState()
+
+    def flush(self) -> None:
+        self.commit_buffer()
+
+    def _chunk_energy(self, pcm_chunk_np: np.ndarray) -> float:
+        return float(np.abs(pcm_chunk_np).mean())
+
+    def _should_commit(self, pcm_chunk_np: np.ndarray, chunk_ms: float) -> bool:
+        energy = self._chunk_energy(pcm_chunk_np)
+        state = self._state
+        state.buffered_ms += chunk_ms
+
+        if energy > self._energy_threshold:
+            state.speaking = True
+            state.silence_ms = 0.0
+        elif state.speaking:
+            state.silence_ms += chunk_ms
+            if state.silence_ms >= self._silence_ms_threshold:
+                return True
+
+        if state.buffered_ms >= self._max_buffer_ms:
+            return True
+
+        return False
+
+
+def chunk_duration_ms(pcm_chunk: bytes, sample_rate: int) -> float:
+    if sample_rate <= 0:
+        return 0.0
+    samples = len(pcm_chunk) / 2  # 2 bytes per int16 sample
+    return (samples / sample_rate) * 1000.0
+
+
+__all__ = [
+    "MAX_BUFFER_MS",
+    "SILENCE_MS",
+    "ENERGY_THRESHOLD",
+    "VADState",
+    "VoiceActivityDetector",
+    "chunk_duration_ms",
+]
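For reference, a minimal usage sketch (not part of this diff) that follows the pattern of the bundled examples but selects the Google provider explicitly instead of the default OpenAI one; the API key and language values are placeholders.

from arduino.app_bricks.cloud_asr import CloudASR, CloudProvider
from arduino.app_utils import App

# Same wiring as examples/1_asr.py, but with an explicit provider choice.
cloud_asr = CloudASR(
    api_key="YOUR_API_KEY",  # placeholder: Google Speech-to-Text API key
    provider=CloudProvider.GOOGLE_SPEECH,
    language="en",  # optional; defaults to the LANGUAGE environment variable when omitted
)
cloud_asr.on_detect(lambda text: print(f"Detected speech: {text}"))

App.run()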