use crate::{instance::Instance, WasmEdgeResult};
use wasmedge_sys::{self as sys};
#[cfg(feature = "wasi_nn")]
use wasmedge_types::error::WasmEdgeError;
pub mod ffi {
pub use wasmedge_sys::ffi::{
WasmEdge_ModuleDescriptor, WasmEdge_ModuleInstanceContext, WasmEdge_PluginDescriptor,
};
}
/// A single wasi-nn model preload specification.
///
/// Its textual form, produced by `Display` and accepted by `FromStr`, is
/// `alias:backend:target:path`.
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NNPreload {
    /// Alias for the model (first segment of the preload string).
    alias: String,
    /// Graph encoding / backend used to load the model.
    backend: GraphEncoding,
    /// Execution target for the model.
    target: ExecutionTarget,
    /// Path to the model file.
    path: std::path::PathBuf,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl NNPreload {
    /// Builds a preload specification, copying `alias` and `path` into owned storage.
    pub fn new(
        alias: impl AsRef<str>,
        backend: GraphEncoding,
        target: ExecutionTarget,
        path: impl AsRef<std::path::Path>,
    ) -> Self {
        let alias = String::from(alias.as_ref());
        let path = path.as_ref().to_path_buf();
        Self {
            alias,
            backend,
            target,
            path,
        }
    }
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::fmt::Display for NNPreload {
    /// Renders the preload as `alias:backend:target:path`.
    ///
    /// Non-UTF-8 path bytes are replaced lossily (via `to_string_lossy`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            alias,
            backend,
            target,
            path,
        } = self;
        write!(
            f,
            "{}:{}:{}:{}",
            alias,
            backend,
            target,
            path.to_string_lossy()
        )
    }
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::str::FromStr for NNPreload {
    type Err = WasmEdgeError;

    /// Parses a preload string of the form `alias:backend:target:path`.
    ///
    /// Only the first three `:` characters act as separators. Everything after the third one
    /// is taken verbatim as the model path. This means paths that themselves contain `:`
    /// (e.g. a Windows drive prefix such as `C:\models\model.gguf`) are accepted. It also means
    /// the `Display` output for such paths can be parsed back.
    ///
    /// # Errors
    ///
    /// Returns `WasmEdgeError::Operation` in three cases:
    /// - the string has fewer than four `:`-separated segments;
    /// - the backend segment is not a known `GraphEncoding`;
    /// - the target segment is not a known `ExecutionTarget`.
    fn from_str(preload: &str) -> std::result::Result<Self, Self::Err> {
        // `splitn(4, ..)` stops splitting after the third separator, so the path keeps any `:`.
        let segments: Vec<&str> = preload.splitn(4, ':').collect();
        let (alias, backend, target, path) = match segments.as_slice() {
            [alias, backend, target, path] => (*alias, *backend, *target, *path),
            _ => {
                return Err(WasmEdgeError::Operation(format!(
                    "Failed to convert to NNPreload value. Invalid preload string: {}. The correct format is: 'alias:backend:target:path'",
                    preload
                )))
            }
        };
        let backend = backend
            .parse::<GraphEncoding>()
            .map_err(|err| WasmEdgeError::Operation(err.to_string()))?;
        let target = target
            .parse::<ExecutionTarget>()
            .map_err(|err| WasmEdgeError::Operation(err.to_string()))?;
        Ok(Self::new(alias, backend, target, path))
    }
}
#[cfg(feature = "wasi_nn")]
#[test]
fn test_generate_nnpreload_from_str() {
    // Parses `input`, requires it to fail, and checks the exact error value.
    fn assert_rejected(input: &str, expected_message: &str) {
        let err = input
            .parse::<NNPreload>()
            .expect_err("preload string should have been rejected");
        assert_eq!(WasmEdgeError::Operation(expected_message.to_string()), err);
    }

    // Well-formed string: every segment lands in the matching field.
    let preload: NNPreload = "default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf"
        .parse()
        .expect("well-formed preload string should parse");
    assert_eq!(preload.alias, "default");
    assert_eq!(preload.backend, GraphEncoding::GGML);
    assert_eq!(preload.target, ExecutionTarget::CPU);
    assert_eq!(
        preload.path,
        std::path::PathBuf::from("llama-2-7b-chat.Q5_K_M.gguf")
    );

    // Backend and target swapped: "CPU" is not a backend.
    assert_rejected(
        "default:CPU:GGML:llama-2-7b-chat.Q5_K_M.gguf",
        "Failed to convert to NNBackend value. Unknown NNBackend type: CPU",
    );

    // Unknown execution target.
    assert_rejected(
        "default:GGML:NPU:llama-2-7b-chat.Q5_K_M.gguf",
        "Failed to convert to ExecutionTarget value. Unknown ExecutionTarget type: NPU",
    );

    // Missing path segment.
    assert_rejected(
        "default:GGML:CPU",
        "Failed to convert to NNPreload value. Invalid preload string: default:GGML:CPU. The correct format is: 'alias:backend:target:path'",
    );
}
/// Graph encoding (backend) of a wasi-nn model.
///
/// Used as the backend segment of an [`NNPreload`] string. It is a plain fieldless enum, so it
/// is `Copy` and `Hash` and can be passed by value or used as a map key.
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum GraphEncoding {
    OpenVINO,
    ONNX,
    TensorFlow,
    PyTorch,
    TensorFlowLite,
    Autodetect,
    GGML,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::fmt::Display for GraphEncoding {
    /// Writes the backend name used in preload strings.
    ///
    /// The TensorFlow variants are rendered as `Tensorflow` / `TensorflowLite` (lower-case
    /// `f`), unlike their variant names.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            GraphEncoding::OpenVINO => "OpenVINO",
            GraphEncoding::ONNX => "ONNX",
            GraphEncoding::TensorFlow => "Tensorflow",
            GraphEncoding::PyTorch => "PyTorch",
            GraphEncoding::TensorFlowLite => "TensorflowLite",
            GraphEncoding::Autodetect => "Autodetect",
            GraphEncoding::GGML => "GGML",
        };
        f.write_str(name)
    }
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::str::FromStr for GraphEncoding {
    type Err = WasmEdgeError;

    /// Parses a backend name, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns `WasmEdgeError::Operation` naming the input when it is not a known backend.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Lower-case spellings of every accepted backend name.
        const BACKENDS: [(&str, GraphEncoding); 7] = [
            ("openvino", GraphEncoding::OpenVINO),
            ("onnx", GraphEncoding::ONNX),
            ("tensorflow", GraphEncoding::TensorFlow),
            ("pytorch", GraphEncoding::PyTorch),
            ("tensorflowlite", GraphEncoding::TensorFlowLite),
            ("autodetect", GraphEncoding::Autodetect),
            ("ggml", GraphEncoding::GGML),
        ];
        BACKENDS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|(_, encoding)| encoding)
            .cloned()
            .ok_or_else(|| {
                WasmEdgeError::Operation(format!(
                    "Failed to convert to NNBackend value. Unknown NNBackend type: {}",
                    s
                ))
            })
    }
}
/// Execution target of a wasi-nn model.
///
/// Used as the target segment of an [`NNPreload`] string. It is a plain fieldless enum, so it
/// is `Copy` and `Hash` and can be passed by value or used as a map key.
// NOTE(review): `AUTO` presumably lets the backend choose the device — confirm.
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ExecutionTarget {
    CPU,
    GPU,
    TPU,
    AUTO,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::fmt::Display for ExecutionTarget {
    /// Writes the upper-case target name used in preload strings.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ExecutionTarget::CPU => "CPU",
            ExecutionTarget::GPU => "GPU",
            ExecutionTarget::TPU => "TPU",
            ExecutionTarget::AUTO => "AUTO",
        })
    }
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::str::FromStr for ExecutionTarget {
    type Err = WasmEdgeError;

    /// Parses a target name, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns `WasmEdgeError::Operation` naming the input when it is not a known target.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Upper-case spellings of every accepted target name.
        const TARGETS: [(&str, ExecutionTarget); 4] = [
            ("CPU", ExecutionTarget::CPU),
            ("GPU", ExecutionTarget::GPU),
            ("TPU", ExecutionTarget::TPU),
            ("AUTO", ExecutionTarget::AUTO),
        ];
        TARGETS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|(_, target)| target)
            .cloned()
            .ok_or_else(|| {
                WasmEdgeError::Operation(format!(
                    "Failed to convert to ExecutionTarget value. Unknown ExecutionTarget type: {}",
                    s
                ))
            })
    }
}
/// Stateless entry point for loading WasmEdge plugins and instantiating their modules.
///
/// Every operation is an associated function that delegates to
/// `wasmedge_sys::plugin::PluginManager`.
#[derive(Debug)]
pub struct PluginManager {}
impl PluginManager {
    /// Loads plugins.
    ///
    /// With `Some(path)`, plugins are loaded from `path` and any error from that load is
    /// returned. With `None`, plugins are loaded from WasmEdge's default plugin paths and
    /// `Ok(())` is always returned.
    pub fn load(path: Option<&std::path::Path>) -> WasmEdgeResult<()> {
        if let Some(location) = path {
            return sys::plugin::PluginManager::load_plugins(location);
        }
        sys::plugin::PluginManager::load_plugins_from_default_paths();
        Ok(())
    }

    /// Hands wasi-nn preload specifications to the underlying plugin manager.
    ///
    /// Each [`NNPreload`] is passed in its `Display` form.
    #[cfg(feature = "wasi_nn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
    pub fn nn_preload(preloads: Vec<NNPreload>) {
        // NOTE(review): presumably this must run before the wasi_nn module is instantiated — confirm.
        let rendered: Vec<String> = preloads.iter().map(ToString::to_string).collect();
        let borrowed: Vec<&str> = rendered.iter().map(String::as_str).collect();
        sys::plugin::PluginManager::nn_preload(borrowed);
    }

    /// Returns the number of plugins reported by the underlying plugin manager.
    pub fn count() -> u32 {
        sys::plugin::PluginManager::count()
    }

    /// Returns the names of the plugins reported by the underlying plugin manager.
    pub fn names() -> Vec<String> {
        sys::plugin::PluginManager::names()
    }

    /// Looks up a plugin by name.
    pub fn find(name: impl AsRef<str>) -> WasmEdgeResult<Plugin> {
        sys::plugin::PluginManager::find(name.as_ref()).map(|inner| Plugin { inner })
    }

    /// Creates the module named `mname` from the plugin named `pname`.
    pub fn create_plugin_instance(
        pname: impl AsRef<str>,
        mname: impl AsRef<str>,
    ) -> WasmEdgeResult<PluginInstance> {
        let instance = sys::plugin::PluginManager::create_plugin_instance(pname, mname)?;
        Ok(instance)
    }

    /// Initializes the `wasmedge_process` plugin.
    ///
    /// `allowed_cmds` is an optional list of permitted commands. `allowed` is forwarded
    /// unchanged to `wasmedge_sys`; presumably `true` permits every command — confirm
    /// there before relying on it.
    #[cfg(all(
        target_os = "linux",
        feature = "wasmedge_process",
        not(feature = "static")
    ))]
    #[cfg_attr(
        docsrs,
        doc(cfg(all(
            target_os = "linux",
            feature = "wasmedge_process",
            not(feature = "static")
        )))
    )]
    pub fn init_wasmedge_process(allowed_cmds: Option<Vec<&str>>, allowed: bool) {
        sys::plugin::PluginManager::init_wasmedge_process(allowed_cmds, allowed);
    }

    /// Instantiates every module of every known plugin, in plugin order then module order.
    ///
    /// Plugins that cannot be found and modules that fail to instantiate are skipped, so this
    /// currently never returns `Err`.
    pub fn auto_detect_plugins() -> WasmEdgeResult<Vec<Instance>> {
        let instances = PluginManager::names()
            .iter()
            .filter_map(|plugin_name| PluginManager::find(plugin_name).ok())
            .flat_map(|plugin| {
                plugin
                    .mod_names()
                    .iter()
                    .filter_map(|mod_name| plugin.mod_instance(mod_name).ok())
                    .collect::<Vec<_>>()
            })
            .collect();
        Ok(instances)
    }
}
impl PluginManager {
pub fn load_plugin_wasi_nn() -> WasmEdgeResult<Instance> {
Self::create_plugin_instance("wasi_nn", "wasi_nn")
}
pub fn load_wasi_crypto_common() -> WasmEdgeResult<Instance> {
Self::create_plugin_instance("wasi_crypto", "wasi_crypto_common")
}
pub fn load_wasi_crypto_asymmetric_common() -> WasmEdgeResult<Instance> {
Self::create_plugin_instance("wasi_crypto", "wasi_crypto_asymmetric_common")
}
pub fn load_wasi_crypto_kx() -> WasmEdgeResult<Instance> {
Self::create_plugin_instance("wasi_crypto", "wasi_crypto_kx")
}
pub fn load_wasi_crypto_signatures() -> WasmEdgeResult<Instance> {
Self::create_plugin_instance("wasi_crypto", "wasi_crypto_signatures")
}
pub fn load_wasi_crypto_symmetric() -> WasmEdgeResult<Instance> {
Self::create_plugin_instance("wasi_crypto", "wasi_crypto_symmetric")
}
}
/// Handle to a plugin, as returned by [`PluginManager::find`].
#[derive(Debug)]
pub struct Plugin {
    inner: sys::plugin::Plugin,
}
impl Plugin {
    /// Returns the plugin's name.
    pub fn name(&self) -> String {
        self.inner.name()
    }

    /// Returns how many modules this plugin provides.
    pub fn mod_count(&self) -> u32 {
        self.inner.mod_count()
    }

    /// Returns the names of the modules this plugin provides.
    pub fn mod_names(&self) -> Vec<String> {
        self.inner.mod_names()
    }

    /// Creates an instance of the module called `name` from this plugin.
    pub fn mod_instance(&self, name: impl AsRef<str>) -> WasmEdgeResult<PluginInstance> {
        let module_name: &str = name.as_ref();
        self.inner.mod_instance(module_name)
    }
}
/// A module instance created from a plugin; an alias of [`Instance`].
pub type PluginInstance = Instance;