diff --git a/abi-conformance/src/lib.rs b/abi-conformance/src/lib.rs index ce19c81..67207e2 100644 --- a/abi-conformance/src/lib.rs +++ b/abi-conformance/src/lib.rs @@ -28,6 +28,7 @@ use std::{future::Future, str}; use test_config::Config; use test_http::Http; use test_key_value::KeyValue; +use test_llm::Llm; use test_mysql::Mysql; use test_postgres::Postgres; use test_redis::Redis; @@ -38,6 +39,7 @@ use wasmtime::{ use wasmtime_wasi::preview2::{pipe::WritePipe, Table, WasiCtx, WasiCtxBuilder, WasiView}; pub use test_key_value::KeyValueReport; +pub use test_llm::LlmReport; pub use test_mysql::MysqlReport; pub use test_postgres::PostgresReport; pub use test_redis::RedisReport; @@ -48,6 +50,7 @@ mod test_http; mod test_inbound_http; mod test_inbound_redis; mod test_key_value; +mod test_llm; mod test_mysql; mod test_postgres; mod test_redis; @@ -137,6 +140,11 @@ pub struct Report { /// See [`KeyValueReport`] for details. pub key_value: KeyValueReport, + /// Results of the Spin llm tests + /// + /// See [`LlmReport`] for details. + pub llm: LlmReport, + /// Results of the WASI tests /// /// See [`WasiReport`] for details. @@ -159,6 +167,7 @@ pub async fn test( postgres::add_to_linker(&mut linker, |context| &mut context.postgres)?; mysql::add_to_linker(&mut linker, |context| &mut context.mysql)?; key_value::add_to_linker(&mut linker, |context| &mut context.key_value)?; + llm::add_to_linker(&mut linker, |context| &mut context.llm)?; config::add_to_linker(&mut linker, |context| &mut context.config)?; let pre = linker.instantiate_pre(component)?; @@ -172,6 +181,7 @@ pub async fn test( postgres: test_postgres::test(engine, test_config.clone(), &pre).await?, mysql: test_mysql::test(engine, test_config.clone(), &pre).await?, key_value: test_key_value::test(engine, test_config.clone(), &pre).await?, + llm: test_llm::test(engine, test_config.clone(), &pre).await?, wasi: test_wasi::test(engine, test_config, &pre).await?, }) } @@ -220,6 +230,7 @@ struct Context { postgres: Postgres, mysql: Mysql, key_value: KeyValue, + llm: Llm, config: Config, } @@ -241,6 +252,7 @@ impl Context { postgres: Default::default(), mysql: Default::default(), key_value: Default::default(), + llm: Default::default(), config: Default::default(), } } diff --git a/abi-conformance/src/test_llm.rs b/abi-conformance/src/test_llm.rs new file mode 100644 index 0000000..0c38df5 --- /dev/null +++ b/abi-conformance/src/test_llm.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; + +use anyhow::{ensure, Result}; +use async_trait::async_trait; +use serde::Serialize; + +use crate::llm; + +/// Report of which key-value functions a module successfully used, if any +#[derive(Serialize, PartialEq, Eq, Debug)] +pub struct LlmReport { + pub infer: Result<(), String>, +} + +#[derive(Default)] +pub struct Llm { + inferences: HashMap<(String, String), String>, + embeddings: HashMap<(String, Vec), Vec>>, +} + +#[async_trait] +impl llm::Host for Llm { + async fn infer( + &mut self, + model: llm::InferencingModel, + prompt: String, + _params: Option, + ) -> wasmtime::Result> { + Ok(self + .inferences + .remove(&(model, prompt.clone())) + .map(|r| llm::InferencingResult { + text: r, + usage: llm::InferencingUsage { + prompt_token_count: 0, + generated_token_count: 0, + }, + }) + .ok_or_else(|| { + llm::Error::RuntimeError(format!( + "expected {:?}, got {:?}", + self.inferences.keys(), + prompt + )) + })) + } + + async fn generate_embeddings( + &mut self, + model: llm::EmbeddingModel, + text: Vec, + ) -> wasmtime::Result> { + Ok(self + .embeddings + .remove(&(model, text.clone())) + .map(|r| llm::EmbeddingsResult { + embeddings: r, + usage: llm::EmbeddingsUsage { + prompt_token_count: 0, + }, + }) + .ok_or_else(|| { + llm::Error::RuntimeError(format!( + "expected {:?}, got {:?}", + self.embeddings.keys(), + text + )) + })) + } +} + +pub(crate) async fn test( + engine: &wasmtime::Engine, + test_config: crate::TestConfig, + pre: &wasmtime::component::InstancePre, +) -> Result { + Ok(LlmReport { + infer: { + let mut store = + crate::create_store_with_context(engine, test_config.clone(), |context| { + context + .llm + .inferences + .insert(("model".into(), "Say hello".into()), "hello".into()); + }); + + crate::run_command( + &mut store, + pre, + &["llm-infer", "model", "Say hello"], + |store| { + ensure!( + store.data().llm.inferences.is_empty(), + "expected module to call `llm::infer` exactly once" + ); + + Ok(()) + }, + ) + .await + }, + }) +} diff --git a/adapters/README.md b/adapters/README.md index 777c0aa..f78dc4d 100644 --- a/adapters/README.md +++ b/adapters/README.md @@ -5,4 +5,4 @@ The componentize process uses adapters to adapt plain wasm modules to wasi previ * The upstream wasi preview1 adapters for both commands and reactors for use with newer versions of wit-bindgen (v0.5 and above). * These are currently the [v10.0.1 release](https://github.com/bytecodealliance/wasmtime/releases/tag/v10.0.1). * A modified adapter that has knowledge of Spin APIs for use with v0.2 of wit-bindgen which has a different ABI than newer wit-bindgen based modules. - * This is currently built using commit [8e261ac4](https://github.com/rylev/wasmtime/commit/8e261ac452ff54031efe2fde804cdf63fded3e55) on the github.com/rylev/wasmtime fork of wasmtime. + * This is currently built using commit [4536277](https://github.com/rylev/wasmtime/commit/4536277317c02443936253a029c3db765147102f) on the github.com/rylev/wasmtime fork of wasmtime. diff --git a/adapters/wasi_snapshot_preview1.spin.wasm b/adapters/wasi_snapshot_preview1.spin.wasm index e2a6c30..87b0ab3 100755 Binary files a/adapters/wasi_snapshot_preview1.spin.wasm and b/adapters/wasi_snapshot_preview1.spin.wasm differ diff --git a/src/lib.rs b/src/lib.rs index f7ea0d6..d0126fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -226,13 +226,14 @@ fn add_custom_section(name: &str, data: &[u8], module: &[u8]) -> Result> #[cfg(test)] mod tests { + use anyhow::Context; use wasmtime_wasi::preview2::{wasi::command::Command, Table, WasiView}; use { anyhow::{anyhow, Result}, spin_abi_conformance::{ - InvocationStyle, KeyValueReport, MysqlReport, PostgresReport, RedisReport, Report, - TestConfig, WasiReport, + InvocationStyle, KeyValueReport, LlmReport, MysqlReport, PostgresReport, RedisReport, + Report, TestConfig, WasiReport, }, std::io::Cursor, tokio::fs, @@ -251,7 +252,11 @@ mod tests { let engine = Engine::new(&config)?; - let component = Component::new(&engine, crate::componentize(module)?)?; + let component = Component::new( + &engine, + crate::componentize(module).context("could not componentize")?, + ) + .context("failed to instantiate componentized bytes")?; let report = spin_abi_conformance::test( &component, @@ -260,7 +265,8 @@ mod tests { invocation_style: InvocationStyle::InboundHttp, }, ) - .await?; + .await + .context("abi conformance test failed")?; let expected = Report { inbound_http: Ok(()), @@ -295,6 +301,7 @@ mod tests { get_keys: Ok(()), close: Ok(()), }, + llm: LlmReport { infer: Ok(()) }, wasi: WasiReport { env: Ok(()), epoch: Ok(()), diff --git a/tests/case-helper/src/lib.rs b/tests/case-helper/src/lib.rs index ed9cbdb..b5f7100 100644 --- a/tests/case-helper/src/lib.rs +++ b/tests/case-helper/src/lib.rs @@ -109,6 +109,10 @@ pub enum Command { KeyValueClose { store: u32, }, + LlmInfer { + model: String, + prompt: String, + }, WasiEnv { key: String, }, diff --git a/tests/rust-case-0.2/src/lib.rs b/tests/rust-case-0.2/src/lib.rs index 9ec8a5c..944f3ab 100644 --- a/tests/rust-case-0.2/src/lib.rs +++ b/tests/rust-case-0.2/src/lib.rs @@ -152,6 +152,8 @@ impl fmt::Display for key_value::Error { impl error::Error for key_value::Error {} +wit_bindgen_rust::import!("../wit/llm.wit"); + fn dispatch(body: Option>) -> Response { match execute(body) { Ok(()) => { @@ -352,6 +354,9 @@ fn execute(body: Option>) -> Result<()> { Command::KeyValueClose { store } => { key_value::close(*store); } + Command::LlmInfer { model, prompt } => { + llm::infer(model, prompt, None); + } Command::WasiEnv { key } => Command::env(key.clone())?, Command::WasiEpoch => Command::epoch()?, diff --git a/tests/rust-case-0.8/src/lib.rs b/tests/rust-case-0.8/src/lib.rs index 65cc457..26aff63 100644 --- a/tests/rust-case-0.8/src/lib.rs +++ b/tests/rust-case-0.8/src/lib.rs @@ -260,6 +260,9 @@ fn execute(body: Option>) -> Result<()> { Command::KeyValueClose { store } => { spin::key_value::close(store); } + Command::LlmInfer { model, prompt } => { + spin::llm::infer(&model, &prompt, None); + } Command::WasiEnv { key } => Command::env(key)?, Command::WasiEpoch => Command::epoch()?, diff --git a/tests/wit-0.8/llm.wit b/tests/wit-0.8/llm.wit new file mode 100644 index 0000000..2b1533c --- /dev/null +++ b/tests/wit-0.8/llm.wit @@ -0,0 +1,70 @@ +// A WASI interface dedicated to performing inferencing for Large Language Models. +interface llm { + /// A Large Language Model. + type inferencing-model = string + + /// Inference request parameters + record inferencing-params { + /// The maximum tokens that should be inferred. + /// + /// Note: the backing implementation may return less tokens. + max-tokens: u32, + /// The amount the model should avoid repeating tokens. + repeat-penalty: float32, + /// The number of tokens the model should apply the repeat penalty to. + repeat-penalty-last-n-token-count: u32, + /// The randomness with which the next token is selected. + temperature: float32, + /// The number of possible next tokens the model will choose from. + top-k: u32, + /// The probability total of next tokens the model will choose from. + top-p: float32 + } + + /// The set of errors which may be raised by functions in this interface + variant error { + model-not-supported, + runtime-error(string), + invalid-input(string) + } + + /// An inferencing result + record inferencing-result { + /// The text generated by the model + // TODO: this should be a stream + text: string, + /// Usage information about the inferencing request + usage: inferencing-usage + } + + /// Usage information related to the inferencing result + record inferencing-usage { + /// Number of tokens in the prompt + prompt-token-count: u32, + /// Number of tokens generated by the inferencing operation + generated-token-count: u32 + } + + /// Perform inferencing using the provided model and prompt with the given optional params + infer: func(model: inferencing-model, prompt: string, params: option) -> result + + /// The model used for generating embeddings + type embedding-model = string + + /// Generate embeddings for the supplied list of text + generate-embeddings: func(model: embedding-model, text: list) -> result + + /// Result of generating embeddings + record embeddings-result { + /// The embeddings generated by the request + embeddings: list>, + /// Usage related to the embeddings generation request + usage: embeddings-usage + } + + /// Usage related to an embeddings generation request + record embeddings-usage { + /// Number of tokens in the prompt + prompt-token-count: u32, + } +} diff --git a/tests/wit-0.8/reactor.wit b/tests/wit-0.8/reactor.wit index c06a4bd..c5898ac 100644 --- a/tests/wit-0.8/reactor.wit +++ b/tests/wit-0.8/reactor.wit @@ -8,6 +8,7 @@ world reactor { import redis import key-value import http + import llm export inbound-http export inbound-redis } diff --git a/tests/wit/llm.wit b/tests/wit/llm.wit new file mode 100644 index 0000000..66ca408 --- /dev/null +++ b/tests/wit/llm.wit @@ -0,0 +1,67 @@ +/// A Large Language Model. +type inferencing-model = string + +/// Inference request parameters +record inferencing-params { + /// The maximum tokens that should be inferred. + /// + /// Note: the backing implementation may return less tokens. + max-tokens: u32, + /// The amount the model should avoid repeating tokens. + repeat-penalty: float32, + /// The number of tokens the model should apply the repeat penalty to. + repeat-penalty-last-n-token-count: u32, + /// The randomness with which the next token is selected. + temperature: float32, + /// The number of possible next tokens the model will choose from. + top-k: u32, + /// The probability total of next tokens the model will choose from. + top-p: float32 +} + +/// The set of errors which may be raised by functions in this interface +variant error { + model-not-supported, + runtime-error(string), + invalid-input(string) +} + +/// An inferencing result +record inferencing-result { + /// The text generated by the model + // TODO: this should be a stream + text: string, + /// Usage information about the inferencing request + usage: inferencing-usage +} + +/// Usage information related to the inferencing result +record inferencing-usage { + /// Number of tokens in the prompt + prompt-token-count: u32, + /// Number of tokens generated by the inferencing operation + generated-token-count: u32 +} + +/// Perform inferencing using the provided model and prompt with the given optional params +infer: func(model: inferencing-model, prompt: string, params: option) -> expected + +/// The model used for generating embeddings +type embedding-model = string + +/// Generate embeddings for the supplied list of text +generate-embeddings: func(model: embedding-model, text: list) -> expected + +/// Result of generating embeddings +record embeddings-result { + /// The embeddings generated by the request + embeddings: list>, + /// Usage related to the embeddings generation request + usage: embeddings-usage +} + +/// Usage related to an embeddings generation request +record embeddings-usage { + /// Number of tokens in the prompt + prompt-token-count: u32, +} \ No newline at end of file diff --git a/wit/llm.wit b/wit/llm.wit new file mode 100644 index 0000000..2b1533c --- /dev/null +++ b/wit/llm.wit @@ -0,0 +1,70 @@ +// A WASI interface dedicated to performing inferencing for Large Language Models. +interface llm { + /// A Large Language Model. + type inferencing-model = string + + /// Inference request parameters + record inferencing-params { + /// The maximum tokens that should be inferred. + /// + /// Note: the backing implementation may return less tokens. + max-tokens: u32, + /// The amount the model should avoid repeating tokens. + repeat-penalty: float32, + /// The number of tokens the model should apply the repeat penalty to. + repeat-penalty-last-n-token-count: u32, + /// The randomness with which the next token is selected. + temperature: float32, + /// The number of possible next tokens the model will choose from. + top-k: u32, + /// The probability total of next tokens the model will choose from. + top-p: float32 + } + + /// The set of errors which may be raised by functions in this interface + variant error { + model-not-supported, + runtime-error(string), + invalid-input(string) + } + + /// An inferencing result + record inferencing-result { + /// The text generated by the model + // TODO: this should be a stream + text: string, + /// Usage information about the inferencing request + usage: inferencing-usage + } + + /// Usage information related to the inferencing result + record inferencing-usage { + /// Number of tokens in the prompt + prompt-token-count: u32, + /// Number of tokens generated by the inferencing operation + generated-token-count: u32 + } + + /// Perform inferencing using the provided model and prompt with the given optional params + infer: func(model: inferencing-model, prompt: string, params: option) -> result + + /// The model used for generating embeddings + type embedding-model = string + + /// Generate embeddings for the supplied list of text + generate-embeddings: func(model: embedding-model, text: list) -> result + + /// Result of generating embeddings + record embeddings-result { + /// The embeddings generated by the request + embeddings: list>, + /// Usage related to the embeddings generation request + usage: embeddings-usage + } + + /// Usage related to an embeddings generation request + record embeddings-usage { + /// Number of tokens in the prompt + prompt-token-count: u32, + } +} diff --git a/wit/spin.wit b/wit/spin.wit index 0b2a509..d48de89 100644 --- a/wit/spin.wit +++ b/wit/spin.wit @@ -31,6 +31,7 @@ world reactor { import redis import key-value import http + import llm export inbound-http export inbound-redis