
Add llm support for 0.2 modules along with tests
Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>
rylev committed Sep 7, 2023
1 parent fea6fc1 commit 5858dc2
Showing 13 changed files with 348 additions and 5 deletions.
12 changes: 12 additions & 0 deletions abi-conformance/src/lib.rs
@@ -28,6 +28,7 @@ use std::{future::Future, str};
use test_config::Config;
use test_http::Http;
use test_key_value::KeyValue;
use test_llm::Llm;
use test_mysql::Mysql;
use test_postgres::Postgres;
use test_redis::Redis;
@@ -38,6 +39,7 @@ use wasmtime::{
use wasmtime_wasi::preview2::{pipe::WritePipe, Table, WasiCtx, WasiCtxBuilder, WasiView};

pub use test_key_value::KeyValueReport;
pub use test_llm::LlmReport;
pub use test_mysql::MysqlReport;
pub use test_postgres::PostgresReport;
pub use test_redis::RedisReport;
@@ -48,6 +50,7 @@ mod test_http;
mod test_inbound_http;
mod test_inbound_redis;
mod test_key_value;
mod test_llm;
mod test_mysql;
mod test_postgres;
mod test_redis;
@@ -137,6 +140,11 @@ pub struct Report {
/// See [`KeyValueReport`] for details.
pub key_value: KeyValueReport,

/// Results of the Spin llm tests
///
/// See [`LlmReport`] for details.
pub llm: LlmReport,

/// Results of the WASI tests
///
/// See [`WasiReport`] for details.
@@ -159,6 +167,7 @@ pub async fn test(
postgres::add_to_linker(&mut linker, |context| &mut context.postgres)?;
mysql::add_to_linker(&mut linker, |context| &mut context.mysql)?;
key_value::add_to_linker(&mut linker, |context| &mut context.key_value)?;
llm::add_to_linker(&mut linker, |context| &mut context.llm)?;
config::add_to_linker(&mut linker, |context| &mut context.config)?;

let pre = linker.instantiate_pre(component)?;
@@ -172,6 +181,7 @@ pub async fn test(
postgres: test_postgres::test(engine, test_config.clone(), &pre).await?,
mysql: test_mysql::test(engine, test_config.clone(), &pre).await?,
key_value: test_key_value::test(engine, test_config.clone(), &pre).await?,
llm: test_llm::test(engine, test_config.clone(), &pre).await?,
wasi: test_wasi::test(engine, test_config, &pre).await?,
})
}
@@ -220,6 +230,7 @@ struct Context {
postgres: Postgres,
mysql: Mysql,
key_value: KeyValue,
llm: Llm,
config: Config,
}

@@ -241,6 +252,7 @@ impl Context {
postgres: Default::default(),
mysql: Default::default(),
key_value: Default::default(),
llm: Default::default(),
config: Default::default(),
}
}
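Taken together, these lib.rs changes register the mock llm host in the linker, run the new llm test, and surface its outcome as a new `llm` field on `Report`. A minimal sketch of how a consumer of the conformance crate might observe that field (hedged: the `test` signature and `TestConfig` shape are assumed from this repository's own tests further below):

```rust
use anyhow::Result;
use spin_abi_conformance::{test, InvocationStyle, TestConfig};
use wasmtime::{component::Component, Engine};

async fn check_llm_conformance(component: &Component, engine: &Engine) -> Result<()> {
    let report = test(
        component,
        engine,
        TestConfig {
            invocation_style: InvocationStyle::InboundHttp,
        },
    )
    .await?;

    // `infer` is Ok(()) only when the module called `llm::infer`
    // exactly once with the expected model and prompt.
    assert_eq!(report.llm.infer, Ok(()));
    Ok(())
}
```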
103 changes: 103 additions & 0 deletions abi-conformance/src/test_llm.rs
@@ -0,0 +1,103 @@
use std::collections::HashMap;

use anyhow::{ensure, Result};
use async_trait::async_trait;
use serde::Serialize;

use crate::llm;

/// Report of which llm functions a module successfully used, if any
#[derive(Serialize, PartialEq, Eq, Debug)]
pub struct LlmReport {
pub infer: Result<(), String>,
}

#[derive(Default)]
pub struct Llm {
inferences: HashMap<(String, String), String>,
embeddings: HashMap<(String, Vec<String>), Vec<Vec<f32>>>,
}

#[async_trait]
impl llm::Host for Llm {
async fn infer(
&mut self,
model: llm::InferencingModel,
prompt: String,
_params: Option<llm::InferencingParams>,
) -> wasmtime::Result<Result<llm::InferencingResult, llm::Error>> {
Ok(self
.inferences
.remove(&(model, prompt.clone()))
.map(|r| llm::InferencingResult {
text: r,
usage: llm::InferencingUsage {
prompt_token_count: 0,
generated_token_count: 0,
},
})
.ok_or_else(|| {
llm::Error::RuntimeError(format!(
"expected {:?}, got {:?}",
self.inferences.keys(),
prompt
))
}))
}

async fn generate_embeddings(
&mut self,
model: llm::EmbeddingModel,
text: Vec<String>,
) -> wasmtime::Result<Result<llm::EmbeddingsResult, llm::Error>> {
Ok(self
.embeddings
.remove(&(model, text.clone()))
.map(|r| llm::EmbeddingsResult {
embeddings: r,
usage: llm::EmbeddingsUsage {
prompt_token_count: 0,
},
})
.ok_or_else(|| {
llm::Error::RuntimeError(format!(
"expected {:?}, got {:?}",
self.embeddings.keys(),
text
))
}))
}
}

pub(crate) async fn test(
engine: &wasmtime::Engine,
test_config: crate::TestConfig,
pre: &wasmtime::component::InstancePre<crate::Context>,
) -> Result<LlmReport> {
Ok(LlmReport {
infer: {
let mut store =
crate::create_store_with_context(engine, test_config.clone(), |context| {
context
.llm
.inferences
.insert(("model".into(), "Say hello".into()), "hello".into());
});

crate::run_command(
&mut store,
pre,
&["llm-infer", "model", "Say hello"],
|store| {
ensure!(
store.data().llm.inferences.is_empty(),
"expected module to call `llm::infer` exactly once"
);

Ok(())
},
)
.await
},
})
}
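The mock is deliberately consume-on-use: each expected `(model, prompt)` pair is removed from the map when the guest calls `infer`, so an empty map afterward proves the call happened exactly once, and any repeated or mismatched call surfaces a `runtime-error`. A hypothetical in-crate sketch of that behavior (direct access to the private `inferences` field only compiles inside this crate, and a tokio test runtime is assumed):

```rust
use crate::llm::{self, Host as _};
use crate::test_llm::Llm;

#[tokio::test]
async fn mock_consumes_expectations() -> anyhow::Result<()> {
    let mut mock = Llm::default();
    mock.inferences
        .insert(("model".into(), "Say hello".into()), "hello".into());

    // First call consumes the expectation and returns the canned text.
    match mock.infer("model".into(), "Say hello".into(), None).await? {
        Ok(result) => assert_eq!(result.text, "hello"),
        Err(_) => panic!("expected the canned inference result"),
    }

    // The map is now empty -- exactly the post-condition the harness checks.
    assert!(mock.inferences.is_empty());

    // A second identical call now fails with a runtime error.
    let second = mock.infer("model".into(), "Say hello".into(), None).await?;
    assert!(matches!(second, Err(llm::Error::RuntimeError(_))));
    Ok(())
}
```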
2 changes: 1 addition & 1 deletion adapters/README.md
@@ -5,4 +5,4 @@ The componentize process uses adapters to adapt plain wasm modules to wasi previ
* The upstream wasi preview1 adapters for both commands and reactors for use with newer versions of wit-bindgen (v0.5 and above).
* These are currently the [v10.0.1 release](https://github.com/bytecodealliance/wasmtime/releases/tag/v10.0.1).
* A modified adapter that has knowledge of Spin APIs, for use with v0.2 of wit-bindgen, which has a different ABI than newer wit-bindgen-based modules.
* This is currently built using commit [8e261ac4](https://github.com/rylev/wasmtime/commit/8e261ac452ff54031efe2fde804cdf63fded3e55) on the github.com/rylev/wasmtime fork of wasmtime.
* This is currently built using commit [4536277](https://github.com/rylev/wasmtime/commit/4536277317c02443936253a029c3db765147102f) on the github.com/rylev/wasmtime fork of wasmtime.
Binary file modified adapters/wasi_snapshot_preview1.spin.wasm
15 changes: 11 additions & 4 deletions src/lib.rs
@@ -226,13 +226,14 @@ fn add_custom_section(name: &str, data: &[u8], module: &[u8]) -> Result<Vec<u8>>

#[cfg(test)]
mod tests {
use anyhow::Context;
use wasmtime_wasi::preview2::{wasi::command::Command, Table, WasiView};

use {
anyhow::{anyhow, Result},
spin_abi_conformance::{
InvocationStyle, KeyValueReport, MysqlReport, PostgresReport, RedisReport, Report,
TestConfig, WasiReport,
InvocationStyle, KeyValueReport, LlmReport, MysqlReport, PostgresReport, RedisReport,
Report, TestConfig, WasiReport,
},
std::io::Cursor,
tokio::fs,
@@ -251,7 +252,11 @@ mod tests {

let engine = Engine::new(&config)?;

let component = Component::new(&engine, crate::componentize(module)?)?;
let component = Component::new(
&engine,
crate::componentize(module).context("could not componentize")?,
)
.context("failed to instantiate componentized bytes")?;

let report = spin_abi_conformance::test(
&component,
@@ -260,7 +265,8 @@ mod tests {
invocation_style: InvocationStyle::InboundHttp,
},
)
.await?;
.await
.context("abi conformance test failed")?;

let expected = Report {
inbound_http: Ok(()),
@@ -295,6 +301,7 @@ mod tests {
get_keys: Ok(()),
close: Ok(()),
},
llm: LlmReport { infer: Ok(()) },
wasi: WasiReport {
env: Ok(()),
epoch: Ok(()),
4 changes: 4 additions & 0 deletions tests/case-helper/src/lib.rs
@@ -109,6 +109,10 @@ pub enum Command {
KeyValueClose {
store: u32,
},
LlmInfer {
model: String,
prompt: String,
},
WasiEnv {
key: String,
},
5 changes: 5 additions & 0 deletions tests/rust-case-0.2/src/lib.rs
@@ -152,6 +152,8 @@ impl fmt::Display for key_value::Error {

impl error::Error for key_value::Error {}

wit_bindgen_rust::import!("../wit/llm.wit");

fn dispatch(body: Option<Vec<u8>>) -> Response {
match execute(body) {
Ok(()) => {
@@ -352,6 +354,9 @@ fn execute(body: Option<Vec<u8>>) -> Result<()> {
Command::KeyValueClose { store } => {
key_value::close(*store);
}
Command::LlmInfer { model, prompt } => {
llm::infer(model, prompt, None);
}

Command::WasiEnv { key } => Command::env(key.clone())?,
Command::WasiEpoch => Command::epoch()?,
3 changes: 3 additions & 0 deletions tests/rust-case-0.8/src/lib.rs
@@ -260,6 +260,9 @@ fn execute(body: Option<Vec<u8>>) -> Result<()> {
Command::KeyValueClose { store } => {
spin::key_value::close(store);
}
Command::LlmInfer { model, prompt } => {
spin::llm::infer(&model, &prompt, None);
}

Command::WasiEnv { key } => Command::env(key)?,
Command::WasiEpoch => Command::epoch()?,
70 changes: 70 additions & 0 deletions tests/wit-0.8/llm.wit
@@ -0,0 +1,70 @@
// A WASI interface dedicated to performing inferencing for Large Language Models.
interface llm {
/// A Large Language Model.
type inferencing-model = string

/// Inference request parameters
record inferencing-params {
/// The maximum number of tokens that should be inferred.
///
/// Note: the backing implementation may return fewer tokens.
max-tokens: u32,
/// How strongly the model should avoid repeating tokens.
repeat-penalty: float32,
/// The number of tokens the model should apply the repeat penalty to.
repeat-penalty-last-n-token-count: u32,
/// The randomness with which the next token is selected.
temperature: float32,
/// The number of possible next tokens the model will choose from.
top-k: u32,
/// The cumulative probability of the next tokens the model will choose from.
top-p: float32
}

/// The set of errors which may be raised by functions in this interface
variant error {
model-not-supported,
runtime-error(string),
invalid-input(string)
}

/// An inferencing result
record inferencing-result {
/// The text generated by the model
// TODO: this should be a stream
text: string,
/// Usage information about the inferencing request
usage: inferencing-usage
}

/// Usage information related to the inferencing result
record inferencing-usage {
/// Number of tokens in the prompt
prompt-token-count: u32,
/// Number of tokens generated by the inferencing operation
generated-token-count: u32
}

/// Perform inferencing using the provided model and prompt with the given optional params
infer: func(model: inferencing-model, prompt: string, params: option<inferencing-params>) -> result<inferencing-result, error>

/// The model used for generating embeddings
type embedding-model = string

/// Generate embeddings for the supplied list of text
generate-embeddings: func(model: embedding-model, text: list<string>) -> result<embeddings-result, error>

/// Result of generating embeddings
record embeddings-result {
/// The embeddings generated by the request
embeddings: list<list<float32>>,
/// Usage related to the embeddings generation request
usage: embeddings-usage
}

/// Usage related to an embeddings generation request
record embeddings-usage {
/// Number of tokens in the prompt
prompt-token-count: u32,
}
}
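For a concrete sense of how this interface reads from the guest side, here is a hedged sketch of a v0.2-style Rust module importing the WIT above, mirroring the `wit_bindgen_rust::import!` usage in tests/rust-case-0.2. The model name and parameter values are purely illustrative, and the generated item names assume wit-bindgen's kebab-case to snake_case mapping:

```rust
wit_bindgen_rust::import!("../wit/llm.wit");

fn greet(prompt: &str) -> Result<String, llm::Error> {
    // Illustrative tuning values; the fields mirror the
    // `inferencing-params` record defined above.
    let params = llm::InferencingParams {
        max_tokens: 100,
        repeat_penalty: 1.1,
        repeat_penalty_last_n_token_count: 64,
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
    };

    // "model" is a placeholder, as in the conformance test; passing
    // `None` for params also works (see tests/rust-case-0.2).
    let result = llm::infer("model", prompt, Some(params))?;
    Ok(result.text)
}
```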
1 change: 1 addition & 0 deletions tests/wit-0.8/reactor.wit
@@ -8,6 +8,7 @@ world reactor {
import redis
import key-value
import http
import llm
export inbound-http
export inbound-redis
}