This repository has been archived by the owner on Feb 27, 2024. It is now read-only.

Add llm support for 0.2 modules along with tests #19

Merged · 1 commit · Sep 11, 2023
12 changes: 12 additions & 0 deletions abi-conformance/src/lib.rs
@@ -28,6 +28,7 @@ use std::{future::Future, str};
use test_config::Config;
use test_http::Http;
use test_key_value::KeyValue;
use test_llm::Llm;
use test_mysql::Mysql;
use test_postgres::Postgres;
use test_redis::Redis;
@@ -38,6 +39,7 @@ use wasmtime::{
use wasmtime_wasi::preview2::{pipe::WritePipe, Table, WasiCtx, WasiCtxBuilder, WasiView};

pub use test_key_value::KeyValueReport;
pub use test_llm::LlmReport;
pub use test_mysql::MysqlReport;
pub use test_postgres::PostgresReport;
pub use test_redis::RedisReport;
@@ -48,6 +50,7 @@ mod test_http;
mod test_inbound_http;
mod test_inbound_redis;
mod test_key_value;
mod test_llm;
mod test_mysql;
mod test_postgres;
mod test_redis;
@@ -137,6 +140,11 @@ pub struct Report {
/// See [`KeyValueReport`] for details.
pub key_value: KeyValueReport,

/// Results of the Spin llm tests
///
/// See [`LlmReport`] for details.
pub llm: LlmReport,

/// Results of the WASI tests
///
/// See [`WasiReport`] for details.
@@ -159,6 +167,7 @@ pub async fn test(
postgres::add_to_linker(&mut linker, |context| &mut context.postgres)?;
mysql::add_to_linker(&mut linker, |context| &mut context.mysql)?;
key_value::add_to_linker(&mut linker, |context| &mut context.key_value)?;
llm::add_to_linker(&mut linker, |context| &mut context.llm)?;
config::add_to_linker(&mut linker, |context| &mut context.config)?;

let pre = linker.instantiate_pre(component)?;
@@ -172,6 +181,7 @@
postgres: test_postgres::test(engine, test_config.clone(), &pre).await?,
mysql: test_mysql::test(engine, test_config.clone(), &pre).await?,
key_value: test_key_value::test(engine, test_config.clone(), &pre).await?,
llm: test_llm::test(engine, test_config.clone(), &pre).await?,
wasi: test_wasi::test(engine, test_config, &pre).await?,
})
}
@@ -220,6 +230,7 @@ struct Context {
postgres: Postgres,
mysql: Mysql,
key_value: KeyValue,
llm: Llm,
config: Config,
}

@@ -241,6 +252,7 @@ impl Context {
postgres: Default::default(),
mysql: Default::default(),
key_value: Default::default(),
llm: Default::default(),
config: Default::default(),
}
}
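The new host plumbing follows the same pattern as the other interfaces: each capability lives as a field on `Context`, and `add_to_linker` receives a closure that projects that field out of the per-store data. A minimal standalone sketch of that projection idea (the types here are illustrative stand-ins, not the crate's generated bindings):

```rust
use std::collections::HashMap;

// Illustrative stand-ins for the generated host types.
struct Llm {
    responses: HashMap<String, String>,
}

struct Context {
    llm: Llm,
    // ...one field per host interface (http, redis, key_value, ...)
}

// The linker stores a projection like this for each interface; every host
// call borrows only the field it needs out of the shared per-store Context.
fn with_llm<T>(store_data: &mut T, project: impl Fn(&mut T) -> &mut Llm) -> usize {
    let llm = project(store_data);
    llm.responses.len()
}

fn main() {
    let mut ctx = Context {
        llm: Llm { responses: HashMap::new() },
    };
    // Same shape as `llm::add_to_linker(&mut linker, |context| &mut context.llm)`.
    let n = with_llm(&mut ctx, |context| &mut context.llm);
    assert_eq!(n, 0);
}
```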
103 changes: 103 additions & 0 deletions abi-conformance/src/test_llm.rs
@@ -0,0 +1,103 @@
use std::collections::HashMap;

use anyhow::{ensure, Result};
use async_trait::async_trait;
use serde::Serialize;

use crate::llm;

/// Report of which llm functions a module successfully used, if any
#[derive(Serialize, PartialEq, Eq, Debug)]
pub struct LlmReport {
pub infer: Result<(), String>,
}

#[derive(Default)]
pub struct Llm {
inferences: HashMap<(String, String), String>,
embeddings: HashMap<(String, Vec<String>), Vec<Vec<f32>>>,
}

#[async_trait]
impl llm::Host for Llm {
async fn infer(
&mut self,
model: llm::InferencingModel,
prompt: String,
_params: Option<llm::InferencingParams>,
) -> wasmtime::Result<Result<llm::InferencingResult, llm::Error>> {
Ok(self
.inferences
.remove(&(model, prompt.clone()))
.map(|r| llm::InferencingResult {
text: r,
usage: llm::InferencingUsage {
prompt_token_count: 0,
generated_token_count: 0,
},
})
.ok_or_else(|| {
llm::Error::RuntimeError(format!(
"expected {:?}, got {:?}",
self.inferences.keys(),
prompt
))
}))
}

async fn generate_embeddings(
&mut self,
model: llm::EmbeddingModel,
text: Vec<String>,
) -> wasmtime::Result<Result<llm::EmbeddingsResult, llm::Error>> {
Ok(self
.embeddings
.remove(&(model, text.clone()))
.map(|r| llm::EmbeddingsResult {
embeddings: r,
usage: llm::EmbeddingsUsage {
prompt_token_count: 0,
},
})
.ok_or_else(|| {
llm::Error::RuntimeError(format!(
"expected {:?}, got {:?}",
self.embeddings.keys(),
text
))
}))
}
}

pub(crate) async fn test(
engine: &wasmtime::Engine,
test_config: crate::TestConfig,
pre: &wasmtime::component::InstancePre<crate::Context>,
) -> Result<LlmReport> {
Ok(LlmReport {
infer: {
let mut store =
crate::create_store_with_context(engine, test_config.clone(), |context| {
context
.llm
.inferences
.insert(("model".into(), "Say hello".into()), "hello".into());
});

crate::run_command(
&mut store,
pre,
&["llm-infer", "model", "Say hello"],
|store| {
ensure!(
store.data().llm.inferences.is_empty(),
"expected module to call `llm::infer` exactly once"
);

Ok(())
},
)
.await
},
})
}
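The mock host works by consuming pre-seeded expectations: `remove` returns the canned response on the first matching call and fails on any repeat or mismatch, so an empty map after the run proves the module called `llm::infer` exactly once with the expected arguments. A standalone sketch of that consume-once pattern (simplified types, no wasmtime involved):

```rust
use std::collections::HashMap;

// Consume-once lookup: each expected (model, prompt) pair is seeded ahead of
// time; `remove` yields the canned response exactly once.
fn infer_mock(
    inferences: &mut HashMap<(String, String), String>,
    model: &str,
    prompt: &str,
) -> Result<String, String> {
    inferences
        .remove(&(model.to_owned(), prompt.to_owned()))
        .ok_or_else(|| format!("unexpected inference request: {prompt:?}"))
}

fn main() {
    let mut inferences = HashMap::new();
    inferences.insert(
        ("model".to_owned(), "Say hello".to_owned()),
        "hello".to_owned(),
    );

    // The first call matches and consumes the expectation.
    assert_eq!(
        infer_mock(&mut inferences, "model", "Say hello").as_deref(),
        Ok("hello")
    );
    // A repeat (or any mismatch) now fails.
    assert!(infer_mock(&mut inferences, "model", "Say hello").is_err());
    // The conformance check asserts exactly this emptiness.
    assert!(inferences.is_empty());
}
```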
2 changes: 1 addition & 1 deletion adapters/README.md
@@ -5,6 +5,6 @@ The componentize process uses adapters to adapt plain wasm modules to wasi previ
* The upstream wasi preview1 adapters for both commands and reactors for use with newer versions of wit-bindgen (v0.5 and above).
* These are currently the [v10.0.1 release](https://github.com/bytecodealliance/wasmtime/releases/tag/v10.0.1).
* A modified adapter with knowledge of Spin APIs, for use with v0.2 of wit-bindgen, whose ABI differs from that of newer wit-bindgen-based modules.
- * This is currently built using commit [9555713](https://github.com/rylev/wasmtime/commit/955571392155e428a5f8be585c9a569f1f0b94c7) on the github.com/rylev/wasmtime fork of wasmtime.
+ * This is currently built using commit [b5d484](https://github.com/rylev/wasmtime/commit/b5d484c6abe040355add59ef3eb8ca1b4d9991e6) on the github.com/rylev/wasmtime fork of wasmtime.
* Changes are usually done on the `spin-adapter` branch of that fork.

Binary file modified adapters/wasi_snapshot_preview1.spin.wasm
Binary file not shown.
15 changes: 11 additions & 4 deletions src/lib.rs
@@ -226,13 +226,14 @@ fn add_custom_section(name: &str, data: &[u8], module: &[u8]) -> Result<Vec<u8>>

#[cfg(test)]
mod tests {
use anyhow::Context;
use wasmtime_wasi::preview2::{wasi::command::Command, Table, WasiView};

use {
anyhow::{anyhow, Result},
spin_abi_conformance::{
- InvocationStyle, KeyValueReport, MysqlReport, PostgresReport, RedisReport, Report,
- TestConfig, WasiReport,
+ InvocationStyle, KeyValueReport, LlmReport, MysqlReport, PostgresReport, RedisReport,
+ Report, TestConfig, WasiReport,
},
std::io::Cursor,
tokio::fs,
@@ -251,7 +252,11 @@ mod tests {

let engine = Engine::new(&config)?;

- let component = Component::new(&engine, crate::componentize(module)?)?;
+ let component = Component::new(
+ &engine,
+ crate::componentize(module).context("could not componentize")?,
+ )
+ .context("failed to instantiate componentized bytes")?;

let report = spin_abi_conformance::test(
&component,
@@ -260,7 +265,8 @@
invocation_style: InvocationStyle::InboundHttp,
},
)
- .await?;
+ .await
+ .context("abi conformance test failed")?;

let expected = Report {
inbound_http: Ok(()),
@@ -295,6 +301,7 @@ mod tests {
get_keys: Ok(()),
close: Ok(()),
},
llm: LlmReport { infer: Ok(()) },
wasi: WasiReport {
env: Ok(()),
epoch: Ok(()),
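The other change in this file is error reporting: wrapping each fallible stage with `anyhow::Context` means a failing test names the stage rather than surfacing only the root cause. A minimal sketch of the pattern, with a stand-in `componentize` function in place of the real one:

```rust
use anyhow::{ensure, Context, Result};

// Stand-in for the real `crate::componentize`; illustrative only.
fn componentize(module: &[u8]) -> Result<Vec<u8>> {
    ensure!(module.starts_with(b"\0asm"), "not a wasm module");
    Ok(module.to_vec())
}

fn main() -> Result<()> {
    // On failure this reports "could not componentize: not a wasm module",
    // pointing at the stage as well as the root cause.
    let component = componentize(b"\0asm\x01\0\0\0").context("could not componentize")?;
    println!("componentized {} bytes", component.len());
    Ok(())
}
```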
4 changes: 4 additions & 0 deletions tests/case-helper/src/lib.rs
@@ -109,6 +109,10 @@ pub enum Command {
KeyValueClose {
store: u32,
},
LlmInfer {
model: String,
prompt: String,
},
WasiEnv {
key: String,
},
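This variant is what the harness's argv-style command `&["llm-infer", "model", "Say hello"]` (see the conformance test above) maps onto in the guest. A hypothetical sketch of that mapping; the real dispatch lives elsewhere in the test helpers:

```rust
// Minimal stand-in for the relevant Command variant.
#[derive(Debug, PartialEq)]
enum Command {
    LlmInfer { model: String, prompt: String },
}

// Hypothetical parser from the argv-style command to the variant.
fn parse(args: &[&str]) -> Option<Command> {
    match args {
        ["llm-infer", model, prompt] => Some(Command::LlmInfer {
            model: (*model).to_owned(),
            prompt: (*prompt).to_owned(),
        }),
        _ => None,
    }
}

fn main() {
    assert_eq!(
        parse(&["llm-infer", "model", "Say hello"]),
        Some(Command::LlmInfer {
            model: "model".into(),
            prompt: "Say hello".into(),
        })
    );
}
```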
5 changes: 5 additions & 0 deletions tests/rust-case-0.2/src/lib.rs
@@ -152,6 +152,8 @@ impl fmt::Display for key_value::Error {

impl error::Error for key_value::Error {}

wit_bindgen_rust::import!("../wit/llm.wit");

fn dispatch(body: Option<Vec<u8>>) -> Response {
match execute(body) {
Ok(()) => {
@@ -352,6 +354,9 @@ fn execute(body: Option<Vec<u8>>) -> Result<()> {
Command::KeyValueClose { store } => {
key_value::close(*store);
}
Command::LlmInfer { model, prompt } => {
llm::infer(model, prompt, None);
}

Command::WasiEnv { key } => Command::env(key.clone())?,
Command::WasiEpoch => Command::epoch()?,
3 changes: 3 additions & 0 deletions tests/rust-case-0.8/src/lib.rs
@@ -260,6 +260,9 @@ fn execute(body: Option<Vec<u8>>) -> Result<()> {
Command::KeyValueClose { store } => {
spin::key_value::close(store);
}
Command::LlmInfer { model, prompt } => {
spin::llm::infer(&model, &prompt, None);
}

Command::WasiEnv { key } => Command::env(key)?,
Command::WasiEpoch => Command::epoch()?,
70 changes: 70 additions & 0 deletions tests/wit-0.8/llm.wit
@@ -0,0 +1,70 @@
// A WASI interface dedicated to performing inferencing for Large Language Models.
interface llm {
/// A Large Language Model.
type inferencing-model = string

/// Inference request parameters
record inferencing-params {
/// The maximum tokens that should be inferred.
///
/// Note: the backing implementation may return fewer tokens.
max-tokens: u32,
/// The amount the model should avoid repeating tokens.
repeat-penalty: float32,
/// The number of tokens the model should apply the repeat penalty to.
repeat-penalty-last-n-token-count: u32,
/// The randomness with which the next token is selected.
temperature: float32,
/// The number of possible next tokens the model will choose from.
top-k: u32,
/// The probability total of next tokens the model will choose from.
top-p: float32
}

/// The set of errors which may be raised by functions in this interface
variant error {
model-not-supported,
runtime-error(string),
invalid-input(string)
}

/// An inferencing result
record inferencing-result {
/// The text generated by the model
// TODO: this should be a stream
text: string,
/// Usage information about the inferencing request
usage: inferencing-usage
}

/// Usage information related to the inferencing result
record inferencing-usage {
/// Number of tokens in the prompt
prompt-token-count: u32,
/// Number of tokens generated by the inferencing operation
generated-token-count: u32
}

/// Perform inferencing using the provided model and prompt with the given optional params
infer: func(model: inferencing-model, prompt: string, params: option<inferencing-params>) -> result<inferencing-result, error>

/// The model used for generating embeddings
type embedding-model = string

/// Generate embeddings for the supplied list of text
generate-embeddings: func(model: embedding-model, text: list<string>) -> result<embeddings-result, error>

/// Result of generating embeddings
record embeddings-result {
/// The embeddings generated by the request
embeddings: list<list<float32>>,
/// Usage related to the embeddings generation request
usage: embeddings-usage
}

/// Usage related to an embeddings generation request
record embeddings-usage {
/// Number of tokens in the prompt
prompt-token-count: u32,
}
}
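For orientation, here is a rough guest-side sketch of calling this interface through generated Rust bindings. The module path (`llm`), the Rust-side type names, and the `"model"` string are assumptions; the actual paths and signatures depend on the wit-bindgen version and world setup:

```rust
// Hypothetical usage sketch; field names mirror the wit records above, but
// the module path and exact generated signatures are assumptions.
use llm::{infer, Error, InferencingParams};

fn greet() -> Result<String, Error> {
    let params = InferencingParams {
        max_tokens: 64, // upper bound; the backend may return fewer
        repeat_penalty: 1.1,
        repeat_penalty_last_n_token_count: 64,
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
    };
    // "model" is a placeholder inferencing-model name, as in the tests above.
    let result = infer("model", "Say hello", Some(params))?;
    Ok(result.text)
}
```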
1 change: 1 addition & 0 deletions tests/wit-0.8/reactor.wit
@@ -8,6 +8,7 @@ world reactor {
import redis
import key-value
import http
import llm
export inbound-http
export inbound-redis
}