Switch `TTS` to use `Arc<RwLock<Box<dyn Backend>>>` to address soundness issues.

This commit is contained in:
Nolan Darilek 2022-03-07 17:54:26 -06:00
parent 90f7dae6a1
commit 00bd5e62ff
1 changed files with 37 additions and 31 deletions

View File

@ -12,12 +12,12 @@
* * WebAssembly * * WebAssembly
*/ */
use std::boxed::Box;
use std::collections::HashMap; use std::collections::HashMap;
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
use std::ffi::CStr; use std::ffi::CStr;
use std::fmt; use std::fmt;
use std::sync::Mutex; use std::sync::{Arc, Mutex};
use std::{boxed::Box, sync::RwLock};
#[cfg(any(target_os = "macos", target_os = "ios"))] #[cfg(any(target_os = "macos", target_os = "ios"))]
use cocoa_foundation::base::id; use cocoa_foundation::base::id;
@ -262,7 +262,7 @@ lazy_static! {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Tts(Box<dyn Backend>); pub struct Tts(Arc<RwLock<Box<dyn Backend>>>);
unsafe impl Send for Tts {} unsafe impl Send for Tts {}
@ -296,7 +296,7 @@ impl Tts {
#[cfg(windows)] #[cfg(windows)]
Backends::WinRt => { Backends::WinRt => {
let tts = backends::WinRt::new()?; let tts = backends::WinRt::new()?;
Ok(Tts(Box::new(tts))) Ok(Tts(Arc::new(RwLock::new(Box::new(tts)))))
} }
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
Backends::AppKit => Ok(Tts(Box::new(backends::AppKit::new()))), Backends::AppKit => Ok(Tts(Box::new(backends::AppKit::new()))),
@ -309,7 +309,7 @@ impl Tts {
} }
}; };
if let Ok(backend) = backend { if let Ok(backend) = backend {
if let Some(id) = backend.0.id() { if let Some(id) = backend.0.read().unwrap().id() {
let mut callbacks = CALLBACKS.lock().unwrap(); let mut callbacks = CALLBACKS.lock().unwrap();
callbacks.insert(id, Callbacks::default()); callbacks.insert(id, Callbacks::default());
} }
@ -362,7 +362,7 @@ impl Tts {
* Returns the features supported by this TTS engine * Returns the features supported by this TTS engine
*/ */
pub fn supported_features(&self) -> Features { pub fn supported_features(&self) -> Features {
self.0.supported_features() self.0.read().unwrap().supported_features()
} }
/** /**
@ -373,7 +373,10 @@ impl Tts {
text: S, text: S,
interrupt: bool, interrupt: bool,
) -> Result<Option<UtteranceId>, Error> { ) -> Result<Option<UtteranceId>, Error> {
self.0.speak(text.into().as_str(), interrupt) self.0
.write()
.unwrap()
.speak(text.into().as_str(), interrupt)
} }
/** /**
@ -382,7 +385,7 @@ impl Tts {
pub fn stop(&mut self) -> Result<&Self, Error> { pub fn stop(&mut self) -> Result<&Self, Error> {
let Features { stop, .. } = self.supported_features(); let Features { stop, .. } = self.supported_features();
if stop { if stop {
self.0.stop()?; self.0.write().unwrap().stop()?;
Ok(self) Ok(self)
} else { } else {
Err(Error::UnsupportedFeature) Err(Error::UnsupportedFeature)
@ -393,21 +396,21 @@ impl Tts {
* Returns the minimum rate for this speech synthesizer. * Returns the minimum rate for this speech synthesizer.
*/ */
pub fn min_rate(&self) -> f32 { pub fn min_rate(&self) -> f32 {
self.0.min_rate() self.0.read().unwrap().min_rate()
} }
/** /**
* Returns the maximum rate for this speech synthesizer. * Returns the maximum rate for this speech synthesizer.
*/ */
pub fn max_rate(&self) -> f32 { pub fn max_rate(&self) -> f32 {
self.0.max_rate() self.0.read().unwrap().max_rate()
} }
/** /**
* Returns the normal rate for this speech synthesizer. * Returns the normal rate for this speech synthesizer.
*/ */
pub fn normal_rate(&self) -> f32 { pub fn normal_rate(&self) -> f32 {
self.0.normal_rate() self.0.read().unwrap().normal_rate()
} }
/** /**
@ -416,7 +419,7 @@ impl Tts {
pub fn get_rate(&self) -> Result<f32, Error> { pub fn get_rate(&self) -> Result<f32, Error> {
let Features { rate, .. } = self.supported_features(); let Features { rate, .. } = self.supported_features();
if rate { if rate {
self.0.get_rate() self.0.read().unwrap().get_rate()
} else { } else {
Err(Error::UnsupportedFeature) Err(Error::UnsupportedFeature)
} }
@ -430,10 +433,11 @@ impl Tts {
rate: rate_feature, .. rate: rate_feature, ..
} = self.supported_features(); } = self.supported_features();
if rate_feature { if rate_feature {
if rate < self.0.min_rate() || rate > self.0.max_rate() { let mut backend = self.0.write().unwrap();
if rate < backend.min_rate() || rate > backend.max_rate() {
Err(Error::OutOfRange) Err(Error::OutOfRange)
} else { } else {
self.0.set_rate(rate)?; backend.set_rate(rate)?;
Ok(self) Ok(self)
} }
} else { } else {
@ -445,21 +449,21 @@ impl Tts {
* Returns the minimum pitch for this speech synthesizer. * Returns the minimum pitch for this speech synthesizer.
*/ */
pub fn min_pitch(&self) -> f32 { pub fn min_pitch(&self) -> f32 {
self.0.min_pitch() self.0.read().unwrap().min_pitch()
} }
/** /**
* Returns the maximum pitch for this speech synthesizer. * Returns the maximum pitch for this speech synthesizer.
*/ */
pub fn max_pitch(&self) -> f32 { pub fn max_pitch(&self) -> f32 {
self.0.max_pitch() self.0.read().unwrap().max_pitch()
} }
/** /**
* Returns the normal pitch for this speech synthesizer. * Returns the normal pitch for this speech synthesizer.
*/ */
pub fn normal_pitch(&self) -> f32 { pub fn normal_pitch(&self) -> f32 {
self.0.normal_pitch() self.0.read().unwrap().normal_pitch()
} }
/** /**
@ -468,7 +472,7 @@ impl Tts {
pub fn get_pitch(&self) -> Result<f32, Error> { pub fn get_pitch(&self) -> Result<f32, Error> {
let Features { pitch, .. } = self.supported_features(); let Features { pitch, .. } = self.supported_features();
if pitch { if pitch {
self.0.get_pitch() self.0.read().unwrap().get_pitch()
} else { } else {
Err(Error::UnsupportedFeature) Err(Error::UnsupportedFeature)
} }
@ -483,10 +487,11 @@ impl Tts {
.. ..
} = self.supported_features(); } = self.supported_features();
if pitch_feature { if pitch_feature {
if pitch < self.0.min_pitch() || pitch > self.0.max_pitch() { let mut backend = self.0.write().unwrap();
if pitch < backend.min_pitch() || pitch > backend.max_pitch() {
Err(Error::OutOfRange) Err(Error::OutOfRange)
} else { } else {
self.0.set_pitch(pitch)?; backend.set_pitch(pitch)?;
Ok(self) Ok(self)
} }
} else { } else {
@ -498,21 +503,21 @@ impl Tts {
* Returns the minimum volume for this speech synthesizer. * Returns the minimum volume for this speech synthesizer.
*/ */
pub fn min_volume(&self) -> f32 { pub fn min_volume(&self) -> f32 {
self.0.min_volume() self.0.read().unwrap().min_volume()
} }
/** /**
* Returns the maximum volume for this speech synthesizer. * Returns the maximum volume for this speech synthesizer.
*/ */
pub fn max_volume(&self) -> f32 { pub fn max_volume(&self) -> f32 {
self.0.max_volume() self.0.read().unwrap().max_volume()
} }
/** /**
* Returns the normal volume for this speech synthesizer. * Returns the normal volume for this speech synthesizer.
*/ */
pub fn normal_volume(&self) -> f32 { pub fn normal_volume(&self) -> f32 {
self.0.normal_volume() self.0.read().unwrap().normal_volume()
} }
/** /**
@ -521,7 +526,7 @@ impl Tts {
pub fn get_volume(&self) -> Result<f32, Error> { pub fn get_volume(&self) -> Result<f32, Error> {
let Features { volume, .. } = self.supported_features(); let Features { volume, .. } = self.supported_features();
if volume { if volume {
self.0.get_volume() self.0.read().unwrap().get_volume()
} else { } else {
Err(Error::UnsupportedFeature) Err(Error::UnsupportedFeature)
} }
@ -536,10 +541,11 @@ impl Tts {
.. ..
} = self.supported_features(); } = self.supported_features();
if volume_feature { if volume_feature {
if volume < self.0.min_volume() || volume > self.0.max_volume() { let mut backend = self.0.write().unwrap();
if volume < backend.min_volume() || volume > backend.max_volume() {
Err(Error::OutOfRange) Err(Error::OutOfRange)
} else { } else {
self.0.set_volume(volume)?; backend.set_volume(volume)?;
Ok(self) Ok(self)
} }
} else { } else {
@ -553,7 +559,7 @@ impl Tts {
pub fn is_speaking(&self) -> Result<bool, Error> { pub fn is_speaking(&self) -> Result<bool, Error> {
let Features { is_speaking, .. } = self.supported_features(); let Features { is_speaking, .. } = self.supported_features();
if is_speaking { if is_speaking {
self.0.is_speaking() self.0.read().unwrap().is_speaking()
} else { } else {
Err(Error::UnsupportedFeature) Err(Error::UnsupportedFeature)
} }
@ -572,7 +578,7 @@ impl Tts {
} = self.supported_features(); } = self.supported_features();
if utterance_callbacks { if utterance_callbacks {
let mut callbacks = CALLBACKS.lock().unwrap(); let mut callbacks = CALLBACKS.lock().unwrap();
let id = self.0.id().unwrap(); let id = self.0.read().unwrap().id().unwrap();
let mut callbacks = callbacks.get_mut(&id).unwrap(); let mut callbacks = callbacks.get_mut(&id).unwrap();
callbacks.utterance_begin = callback; callbacks.utterance_begin = callback;
Ok(()) Ok(())
@ -594,7 +600,7 @@ impl Tts {
} = self.supported_features(); } = self.supported_features();
if utterance_callbacks { if utterance_callbacks {
let mut callbacks = CALLBACKS.lock().unwrap(); let mut callbacks = CALLBACKS.lock().unwrap();
let id = self.0.id().unwrap(); let id = self.0.read().unwrap().id().unwrap();
let mut callbacks = callbacks.get_mut(&id).unwrap(); let mut callbacks = callbacks.get_mut(&id).unwrap();
callbacks.utterance_end = callback; callbacks.utterance_end = callback;
Ok(()) Ok(())
@ -616,7 +622,7 @@ impl Tts {
} = self.supported_features(); } = self.supported_features();
if utterance_callbacks { if utterance_callbacks {
let mut callbacks = CALLBACKS.lock().unwrap(); let mut callbacks = CALLBACKS.lock().unwrap();
let id = self.0.id().unwrap(); let id = self.0.read().unwrap().id().unwrap();
let mut callbacks = callbacks.get_mut(&id).unwrap(); let mut callbacks = callbacks.get_mut(&id).unwrap();
callbacks.utterance_stop = callback; callbacks.utterance_stop = callback;
Ok(()) Ok(())
@ -646,7 +652,7 @@ impl Tts {
impl Drop for Tts { impl Drop for Tts {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(id) = self.0.id() { if let Some(id) = self.0.read().unwrap().id() {
let mut callbacks = CALLBACKS.lock().unwrap(); let mut callbacks = CALLBACKS.lock().unwrap();
callbacks.remove(&id); callbacks.remove(&id);
} }