Skip to content

Commit dcfcb28

Browse files
committed
Introduce cloud_asr brick
1 parent 73aafb2 commit dcfcb28

File tree

9 files changed

+748
-0
lines changed

9 files changed

+748
-0
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ cloud_llm = [
7777
"langchain-openai >=0.3.0, <0.4.0",
7878
"langchain-google-genai >=2.1.0, <2.2.0",
7979
]
80+
cloud_asr = [
81+
"websocket-client",
82+
"google-cloud-speech>=2.27.0",
83+
]
8084

8185
all = [
8286
"arduino_app_bricks[dbstorage_influx]",
@@ -90,6 +94,7 @@ all = [
9094
"arduino_app_bricks[stream]",
9195
"arduino_app_bricks[arduino_cloud]",
9296
"arduino_app_bricks[cloud_llm]",
97+
"arduino_app_bricks[cloud_asr]",
9398
]
9499

95100
[project.urls]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
2+
#
3+
# SPDX-License-Identifier: MPL-2.0
4+
5+
from .cloud_asr import CloudASR
6+
from .providers import ASREvent, CloudProvider
7+
8+
__all__ = ["CloudASR", "ASREvent", "CloudProvider"]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
id: arduino:cloud_asr
2+
name: Cloud ASR
3+
description: |
4+
Cloud ASR Brick provides a unified and flexible way to connect cloud-based Automatic Speech Recognition (ASR) services and transform spoken audio into text.
5+
It enables real-time, streaming transcription from a connected microphone, leveraging leading cloud providers to deliver low-latency speech-to-text processing.
6+
category: audio
7+
mount_devices_into_container: true
8+
required_devices:
9+
- microphone
10+
requires_container: false
11+
requires_model: false
12+
variables:
13+
- name: API_KEY
14+
description: API Key for the cloud-based Speech to Text service
15+
- name: LANGUAGE
16+
description: "Language code for transcription (e.g., en, it). Default: en"
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
2+
#
3+
# SPDX-License-Identifier: MPL-2.0
4+
5+
from __future__ import annotations
6+
7+
import os
8+
import queue
9+
import threading
10+
from typing import Iterator, Callable, Optional
11+
12+
import numpy as np
13+
14+
from arduino.app_peripherals.microphone import Microphone
15+
from arduino.app_utils import Logger, brick
16+
17+
from .providers import ASRProvider, CloudProvider, DEFAULT_PROVIDER, provider_factory
18+
19+
logger = Logger(__name__)
20+
21+
DEFAULT_LANGUAGE = "en"
22+
23+
24+
@brick
class CloudASR:
    """
    Cloud-based speech-to-text with pluggable cloud providers.
    It captures audio from a microphone and streams it to the selected cloud ASR provider for transcription.
    The recognized text is yielded as events in real-time.
    """

    def __init__(
        self,
        api_key: str = os.getenv("API_KEY", ""),
        provider: CloudProvider = DEFAULT_PROVIDER,
        mic: Optional[Microphone] = None,
        language: str = os.getenv("LANGUAGE", DEFAULT_LANGUAGE),
        stream_partial: bool = True,
    ):
        """Create the brick and build the provider client.

        Args:
            api_key: Credential forwarded to the provider. Defaults to the
                ``API_KEY`` environment variable (read once, at import time).
            provider: Which cloud provider implementation to build.
            mic: Microphone to capture from. A default ``Microphone()`` is
                created when omitted.
            language: Language code forwarded to the provider. Defaults to the
                ``LANGUAGE`` environment variable, falling back to
                ``DEFAULT_LANGUAGE`` when that variable is not set.
            stream_partial: Forwarded to the provider as ``stream_partial``.
        """
        # Compare against None explicitly so a microphone object that happens
        # to be falsy is not silently replaced by a default one.
        if mic is not None:
            logger.info(f"[{self.__class__.__name__}] Using provided microphone: {mic}")
            self._mic = mic
        else:
            self._mic = Microphone()

        self._language = language
        self._stream_partial = stream_partial
        self._mic_lock = threading.Lock()
        self._provider: ASRProvider = provider_factory(
            api_key=api_key,
            name=provider,
            language=self._language,
            sample_rate=self._mic.sample_rate,
            stream_partial=self._stream_partial,
        )

        # Callbacks registered through on_detect(); guarded by handlers_lock.
        self.handlers: list[Callable[[dict], None]] = []
        self.handlers_lock = threading.Lock()

    def start(self):
        """Start the microphone if it is not already recording."""
        with self._mic_lock:
            if not self._mic.is_recording.is_set():
                self._mic.start()
                logger.info(f"[{self.__class__.__name__}] Microphone started.")

    def stop(self):
        """Stop the microphone if it is currently recording."""
        with self._mic_lock:
            if self._mic.is_recording.is_set():
                self._mic.stop()
                logger.info(f"[{self.__class__.__name__}] Microphone stopped.")

    def on_detect(self, handler: Callable[[dict], None]):
        """Register a callback to be invoked when speech is detected."""
        with self.handlers_lock:
            self.handlers.append(handler)

    @brick.loop
    def _detect_loop(self):
        """Continuously listen for speech and invoke handlers."""
        for resp in self.transcribe():
            # Snapshot under the lock, then call handlers outside it:
            # handlers_lock is a plain (non-reentrant) Lock, so invoking a
            # handler that calls on_detect() while holding it would deadlock.
            with self.handlers_lock:
                handlers = list(self.handlers)
            for handler in handlers:
                try:
                    handler(resp)
                except Exception as exc:
                    # One failing handler must not prevent the others from running.
                    logger.error(f"Error in speech detected handler: {exc}")

    def transcribe(self) -> Iterator[dict]:
        """Perform speech-to-text recognition.

        Two daemon threads are started: one forwards microphone chunks to the
        provider as int16 PCM bytes, the other reads provider results and
        queues them. This generator drains the queue until both threads have
        finished and the queue is empty. On exit (including early close by the
        consumer) the threads are signalled to stop and the provider is stopped.

        Returns:
            Iterator[dict]: Generator yielding {"event": ("partial_text"|"text"|"error"), "data": "<payload>"} messages.
        """

        provider = self._provider
        messages: queue.Queue[dict] = queue.Queue()
        stop_event = threading.Event()

        def _send():
            try:
                for chunk in self._mic.stream():
                    if stop_event.is_set():
                        break
                    if chunk is None:
                        continue
                    pcm_chunk_np = np.asarray(chunk, dtype=np.int16)
                    provider.send_audio(pcm_chunk_np.tobytes())
            except KeyboardInterrupt:
                logger.info("Recognition interrupted by user. Exiting...")
            except Exception as exc:
                logger.error("Error while streaming microphone audio: %s", exc)
                messages.put({"event": "error", "data": str(exc)})
            finally:
                # Whatever ended the audio stream, tell the receiver to stop too.
                stop_event.set()

        # Text accumulated from partial results since the last final "text" event.
        partial_buffer = ""

        def _recv():
            nonlocal partial_buffer
            try:
                while not stop_event.is_set():
                    result = provider.recv()
                    if result is None:
                        continue

                    data = result.data
                    if result.event == "partial_text":
                        if self._provider.partial_mode == "replace":
                            partial_buffer = str(data)
                        else:
                            partial_buffer += str(data)
                    elif result.event == "text":
                        # A final result with an empty payload falls back to
                        # the text accumulated from the partial results.
                        data = data or partial_buffer
                        partial_buffer = ""
                    messages.put({"event": result.event, "data": data})

            except Exception as exc:
                logger.error("Error receiving transcription events: %s", exc)
                messages.put({"event": "error", "data": str(exc)})
                stop_event.set()

        send_thread = threading.Thread(target=_send, daemon=True)
        recv_thread = threading.Thread(target=_recv, daemon=True)
        send_thread.start()
        recv_thread.start()

        try:
            while recv_thread.is_alive() or send_thread.is_alive() or not messages.empty():
                try:
                    # Short timeout so thread liveness is re-checked regularly.
                    msg = messages.get(timeout=0.1)
                    yield msg
                except queue.Empty:
                    continue
        finally:
            stop_event.set()
            send_thread.join(timeout=1)
            recv_thread.join(timeout=1)
            provider.stop()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc)
2+
#
3+
# SPDX-License-Identifier: MPL-2.0
4+
5+
from enum import Enum
6+
7+
from .openai import OpenAITranscribe
8+
from .google import GoogleSpeech
9+
from .types import ASREvent, ASRProvider
10+
11+
12+
class CloudProvider(str, Enum):
    """Identifiers of the cloud ASR providers this package can build.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    # Selects the OpenAITranscribe implementation in provider_factory().
    OPENAI_TRANSCRIBE = "openai-transcribe"
    # Selects the GoogleSpeech implementation in provider_factory().
    GOOGLE_SPEECH = "google-speech"


# Provider used when the caller does not pick one explicitly.
DEFAULT_PROVIDER = CloudProvider.OPENAI_TRANSCRIBE
18+
19+
20+
def provider_factory(
    api_key: str,
    language: str,
    sample_rate: int,
    stream_partial: bool,
    name: CloudProvider = DEFAULT_PROVIDER,
) -> ASRProvider:
    """Build the ASR provider implementation selected by ``name``.

    Args:
        api_key: Credential passed to the provider constructor.
        language: Language code passed to the provider constructor.
        sample_rate: Audio sample rate passed to the provider constructor.
        stream_partial: Passed to the provider constructor as ``stream_partial``.
        name: Which provider to build.

    Returns:
        An instance of the provider class matching ``name``.

    Raises:
        ValueError: If ``name`` does not match a known provider.
    """
    # Matching is done with ==, so a plain string equal to a member's value
    # selects the same provider as the member itself.
    if name == CloudProvider.OPENAI_TRANSCRIBE:
        implementation = OpenAITranscribe
    elif name == CloudProvider.GOOGLE_SPEECH:
        implementation = GoogleSpeech
    else:
        raise ValueError(f"Unsupported ASR cloud provider: {name}")

    # Every provider takes the same keyword arguments.
    return implementation(
        api_key=api_key,
        language=language,
        sample_rate=sample_rate,
        stream_partial=stream_partial,
    )
43+
44+
45+
__all__ = [
46+
"ASREvent",
47+
"ASRProvider",
48+
"CloudProvider",
49+
"DEFAULT_PROVIDER",
50+
"GoogleSpeech",
51+
"OpenAITranscribe",
52+
"provider_factory",
53+
]

0 commit comments

Comments
 (0)