Wake word cleanup (#98652)

* Make arguments for async_pipeline_from_audio_stream keyword-only after hass

* Use a bytearray ring buffer

* Move generator outside

* Move stt stream generator outside

* Clean up execute

* Refactor VAD to use bytearray

* More tests

* Refactor chunk_samples to be more correct and robust

* Change AudioBuffer to use append instead of setitem

* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen
2023-08-25 12:28:48 -05:00
committed by GitHub
parent 49897341ba
commit 8768c39021
9 changed files with 458 additions and 163 deletions

View File

@@ -52,6 +52,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
*,
context: Context,
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,

View File

@@ -49,6 +49,7 @@ from .error import (
WakeWordDetectionError,
WakeWordTimeoutError,
)
from .ring_buffer import RingBuffer
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
_LOGGER = logging.getLogger(__name__)
@@ -425,7 +426,6 @@ class PipelineRun:
async def prepare_wake_word_detection(self) -> None:
"""Prepare wake-word-detection."""
# Need to add to pipeline store
engine = wake_word.async_default_engine(self.hass)
if engine is None:
raise WakeWordDetectionError(
@@ -448,7 +448,7 @@ class PipelineRun:
async def wake_word_detection(
self,
stream: AsyncIterable[bytes],
audio_buffer: list[bytes],
audio_chunks_for_stt: list[bytes],
) -> wake_word.DetectionResult | None:
"""Run wake-word-detection portion of pipeline. Returns detection result."""
metadata_dict = asdict(
@@ -484,46 +484,29 @@ class PipelineRun:
# Use VAD to determine timeout
wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
# Audio chunk buffer.
audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2
# Audio chunk buffer. This audio will be forwarded to speech-to-text
# after wake-word-detection.
num_audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2 # 16-bit @ 16Khz
)
audio_ring_buffer = b""
async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio with timestamps (milliseconds since start of stream)."""
nonlocal audio_ring_buffer
timestamp_ms = 0
async for chunk in stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Keeping audio right before wake word detection allows the
# voice command to be spoken immediately after the wake word.
if audio_bytes_to_buffer > 0:
audio_ring_buffer += chunk
if len(audio_ring_buffer) > audio_bytes_to_buffer:
# A proper ring buffer would be far more efficient
audio_ring_buffer = audio_ring_buffer[
len(audio_ring_buffer) - audio_bytes_to_buffer :
]
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
stt_audio_buffer: RingBuffer | None = None
if num_audio_bytes_to_buffer > 0:
stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
try:
# Detect wake word(s)
result = await self.wake_word_provider.async_process_audio_stream(
timestamped_stream()
_wake_word_audio_stream(
audio_stream=stream,
stt_audio_buffer=stt_audio_buffer,
wake_word_vad=wake_word_vad,
)
)
if audio_ring_buffer:
if stt_audio_buffer is not None:
# All audio kept from right before the wake word was detected as
# a single chunk.
audio_buffer.append(audio_ring_buffer)
audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
except WakeWordTimeoutError:
_LOGGER.debug("Timeout during wake word detection")
raise
@@ -540,9 +523,14 @@ class PipelineRun:
wake_word_output: dict[str, Any] = {}
else:
if result.queued_audio:
# Add audio that was pending at detection
# Add audio that was pending at detection.
#
# Because detection occurs *after* the wake word was actually
# spoken, we need to make sure pending audio is forwarded to
# speech-to-text so the user does not have to pause before
# speaking the voice command.
for chunk_ts in result.queued_audio:
audio_buffer.append(chunk_ts[0])
audio_chunks_for_stt.append(chunk_ts[0])
wake_word_output = asdict(result)
@@ -608,41 +596,12 @@ class PipelineRun:
)
try:
segmenter = VoiceCommandSegmenter()
async def segment_stream(
stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]:
"""Stop stream when voice command is finished."""
sent_vad_start = False
timestamp_ms = 0
async for chunk in stream:
if not segmenter.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break
if segmenter.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream(
metadata, segment_stream(stream)
metadata,
self._speech_to_text_stream(
audio_stream=stream, stt_vad=VoiceCommandSegmenter()
),
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
@@ -677,6 +636,42 @@ class PipelineRun:
return result.text
async def _speech_to_text_stream(
self,
audio_stream: AsyncIterable[bytes],
stt_vad: VoiceCommandSegmenter | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[bytes, None]:
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
ms_per_sample = sample_rate // 1000
sent_vad_start = False
timestamp_ms = 0
async for chunk in audio_stream:
if stt_vad is not None:
if not stt_vad.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break
if stt_vad.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
async def prepare_recognize_intent(self) -> None:
"""Prepare recognizing an intent."""
agent_info = conversation.async_get_agent_info(
@@ -861,13 +856,14 @@ class PipelineInput:
"""Run pipeline."""
self.run.start()
current_stage: PipelineStage | None = self.run.start_stage
audio_buffer: list[bytes] = []
stt_audio_buffer: list[bytes] = []
try:
if current_stage == PipelineStage.WAKE_WORD:
# wake-word-detection
assert self.stt_stream is not None
detect_result = await self.run.wake_word_detection(
self.stt_stream, audio_buffer
self.stt_stream, stt_audio_buffer
)
if detect_result is None:
# No wake word. Abort the rest of the pipeline.
@@ -882,19 +878,22 @@ class PipelineInput:
assert self.stt_metadata is not None
assert self.stt_stream is not None
if audio_buffer:
stt_stream = self.stt_stream
async def buffered_stream() -> AsyncGenerator[bytes, None]:
for chunk in audio_buffer:
if stt_audio_buffer:
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
# This is basically an async itertools.chain.
async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
# Buffered audio
for chunk in stt_audio_buffer:
yield chunk
# Streamed audio
assert self.stt_stream is not None
async for chunk in self.stt_stream:
yield chunk
stt_stream = cast(AsyncIterable[bytes], buffered_stream())
else:
stt_stream = self.stt_stream
stt_stream = buffer_then_audio_stream()
intent_input = await self.run.speech_to_text(
self.stt_metadata,
@@ -906,6 +905,7 @@ class PipelineInput:
tts_input = self.tts_input
if current_stage == PipelineStage.INTENT:
# intent-recognition
assert intent_input is not None
tts_input = await self.run.recognize_intent(
intent_input,
@@ -915,6 +915,7 @@ class PipelineInput:
current_stage = PipelineStage.TTS
if self.run.end_stage != PipelineStage.INTENT:
# text-to-speech
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await self.run.text_to_speech(tts_input)
@@ -999,6 +1000,36 @@ class PipelineInput:
await asyncio.gather(*prepare_tasks)
async def _wake_word_audio_stream(
audio_stream: AsyncIterable[bytes],
stt_audio_buffer: RingBuffer | None,
wake_word_vad: VoiceActivityTimeout | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio chunks with timestamps (milliseconds since start of stream).
Adds audio to a ring buffer that will be forwarded to speech-to-text after
detection. Times out if VAD detects enough silence.
"""
ms_per_sample = sample_rate // 1000
timestamp_ms = 0
async for chunk in audio_stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
# Wake-word-detection occurs *after* the wake word was actually
# spoken. Keeping audio right before detection allows the voice
# command to be spoken immediately after the wake word.
if stt_audio_buffer is not None:
stt_audio_buffer.put(chunk)
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
class PipelinePreferred(CollectionError):
"""Raised when attempting to delete the preferred pipelen."""

View File

@@ -0,0 +1,57 @@
"""Implementation of a ring buffer using bytearray."""
class RingBuffer:
"""Basic ring buffer using a bytearray.
Not threadsafe.
"""
def __init__(self, maxlen: int) -> None:
"""Initialize empty buffer."""
self._buffer = bytearray(maxlen)
self._pos = 0
self._length = 0
self._maxlen = maxlen
@property
def maxlen(self) -> int:
"""Return the maximum size of the buffer."""
return self._maxlen
@property
def pos(self) -> int:
"""Return the current put position."""
return self._pos
def __len__(self) -> int:
"""Return the length of data stored in the buffer."""
return self._length
def put(self, data: bytes) -> None:
"""Put a chunk of data into the buffer, possibly wrapping around."""
data_len = len(data)
new_pos = self._pos + data_len
if new_pos >= self._maxlen:
# Split into two chunks
num_bytes_1 = self._maxlen - self._pos
num_bytes_2 = new_pos - self._maxlen
self._buffer[self._pos : self._maxlen] = data[:num_bytes_1]
self._buffer[:num_bytes_2] = data[num_bytes_1:]
new_pos = new_pos - self._maxlen
else:
# Entire chunk fits at current position
self._buffer[self._pos : self._pos + data_len] = data
self._pos = new_pos
self._length = min(self._maxlen, self._length + data_len)
def getvalue(self) -> bytes:
"""Get bytes written to the buffer."""
if (self._pos + self._length) <= self._maxlen:
# Single chunk
return bytes(self._buffer[: self._length])
# Two chunks
return bytes(self._buffer[self._pos :] + self._buffer[: self._pos])

View File

@@ -1,12 +1,15 @@
"""Voice activity detection."""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Final
import webrtcvad
_SAMPLE_RATE = 16000
_SAMPLE_RATE: Final = 16000 # Hz
_SAMPLE_WIDTH: Final = 2 # bytes
class VadSensitivity(StrEnum):
@@ -29,6 +32,45 @@ class VadSensitivity(StrEnum):
return 1.0
class AudioBuffer:
"""Fixed-sized audio buffer with variable internal length."""
def __init__(self, maxlen: int) -> None:
"""Initialize buffer."""
self._buffer = bytearray(maxlen)
self._length = 0
@property
def length(self) -> int:
"""Get number of bytes currently in the buffer."""
return self._length
def clear(self) -> None:
"""Clear the buffer."""
self._length = 0
def append(self, data: bytes) -> None:
"""Append bytes to the buffer, increasing the internal length."""
data_len = len(data)
if (self._length + data_len) > len(self._buffer):
raise ValueError("Length cannot be greater than buffer size")
self._buffer[self._length : self._length + data_len] = data
self._length += data_len
def bytes(self) -> bytes:
"""Convert written portion of buffer to bytes."""
return bytes(self._buffer[: self._length])
def __len__(self) -> int:
"""Get the number of bytes currently in the buffer."""
return self._length
def __bool__(self) -> bool:
"""Return True if there are bytes in the buffer."""
return self._length > 0
@dataclass
class VoiceCommandSegmenter:
"""Segments an audio stream into voice commands using webrtcvad."""
@@ -36,7 +78,7 @@ class VoiceCommandSegmenter:
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_frames: int = 480 # 30 ms
vad_samples_per_chunk: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
speech_seconds: float = 0.3
@@ -67,20 +109,23 @@ class VoiceCommandSegmenter:
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_audio_buffer: bytes = field(default_factory=bytes)
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms
_leftover_chunk_buffer: AudioBuffer = field(init=False)
_bytes_per_chunk: int = field(init=False)
_seconds_per_chunk: float = field(init=False)
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
self._leftover_chunk_buffer = AudioBuffer(
self.vad_samples_per_chunk * _SAMPLE_WIDTH
)
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._leftover_chunk_buffer.clear()
self._speech_seconds_left = self.speech_seconds
self._silence_seconds_left = self.silence_seconds
self._timeout_seconds_left = self.timeout_seconds
@@ -92,27 +137,20 @@ class VoiceCommandSegmenter:
Returns False when command is done.
"""
self._audio_buffer += samples
# Process in 10, 20, or 30 ms chunks.
num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
for chunk_idx in range(num_chunks):
chunk_offset = chunk_idx * self._bytes_per_chunk
chunk = self._audio_buffer[
chunk_offset : chunk_offset + self._bytes_per_chunk
]
for chunk in chunk_samples(
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
):
if not self._process_chunk(chunk):
self.reset()
return False
if num_chunks > 0:
# Remove from buffer
self._audio_buffer = self._audio_buffer[
num_chunks * self._bytes_per_chunk :
]
return True
@property
def audio_buffer(self) -> bytes:
"""Get partial chunk in the audio buffer."""
return self._leftover_chunk_buffer.bytes()
def _process_chunk(self, chunk: bytes) -> bool:
"""Process a single chunk of 16-bit 16Khz mono audio.
@@ -163,7 +201,7 @@ class VoiceActivityTimeout:
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_frames: int = 480 # 30 ms
vad_samples_per_chunk: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
_silence_seconds_left: float = 0.0
@@ -173,20 +211,23 @@ class VoiceActivityTimeout:
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_audio_buffer: bytes = field(default_factory=bytes)
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms
_leftover_chunk_buffer: AudioBuffer = field(init=False)
_bytes_per_chunk: int = field(init=False)
_seconds_per_chunk: float = field(init=False)
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
self._leftover_chunk_buffer = AudioBuffer(
self.vad_samples_per_chunk * _SAMPLE_WIDTH
)
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._leftover_chunk_buffer.clear()
self._silence_seconds_left = self.silence_seconds
self._reset_seconds_left = self.reset_seconds
@@ -195,24 +236,12 @@ class VoiceActivityTimeout:
Returns False when timeout is reached.
"""
self._audio_buffer += samples
# Process in 10, 20, or 30 ms chunks.
num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
for chunk_idx in range(num_chunks):
chunk_offset = chunk_idx * self._bytes_per_chunk
chunk = self._audio_buffer[
chunk_offset : chunk_offset + self._bytes_per_chunk
]
for chunk in chunk_samples(
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
):
if not self._process_chunk(chunk):
return False
if num_chunks > 0:
# Remove from buffer
self._audio_buffer = self._audio_buffer[
num_chunks * self._bytes_per_chunk :
]
return True
def _process_chunk(self, chunk: bytes) -> bool:
@@ -239,3 +268,37 @@ class VoiceActivityTimeout:
)
return True
def chunk_samples(
samples: bytes,
bytes_per_chunk: int,
leftover_chunk_buffer: AudioBuffer,
) -> Iterable[bytes]:
"""Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s)."""
if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
# Extend leftover chunk, but not enough samples to complete it
leftover_chunk_buffer.append(samples)
return
next_chunk_idx = 0
if leftover_chunk_buffer:
# Add to leftover chunk from previous call(s).
bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
leftover_chunk_buffer.append(samples[:bytes_to_copy])
next_chunk_idx = bytes_to_copy
# Process full chunk in buffer
yield leftover_chunk_buffer.bytes()
leftover_chunk_buffer.clear()
while next_chunk_idx < len(samples) - bytes_per_chunk + 1:
# Process full chunk
yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
next_chunk_idx += bytes_per_chunk
# Capture leftover chunks
if rest_samples := samples[next_chunk_idx:]:
leftover_chunk_buffer.append(rest_samples)

View File

@@ -79,8 +79,6 @@ class WakeWordDetectionEntity(RestoreEntity):
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_detected is None:
return None
return self.__last_detected
@property