From e0b154306bb3d219f48a57385664b57344043a38 Mon Sep 17 00:00:00 2001
From: Himanshu Neema
Date: Fri, 16 Jun 2023 12:31:55 -0700
Subject: [PATCH] feat: Support for Microsoft Azure endpoints (#67)

* initial implementation to support Azure Endpoints
* update existing tests
* update lib.rs example
* use doc suggestion
* documentation
* test fix
* test update: remove deprecated codex example; run all tests using 'cargo make all'
* add example for azure openai service
* update example
* update example
* back to Client::new
* updates
* test fix
* update doc
---
 async-openai/README.md                    |   7 +-
 async-openai/src/audio.rs                 |   9 +-
 async-openai/src/chat.rs                  |   9 +-
 async-openai/src/client.rs                | 209 +++++++++-------------
 async-openai/src/completion.rs            |   9 +-
 async-openai/src/config.rs                | 176 ++++++++++++++++++
 async-openai/src/edit.rs                  |   9 +-
 async-openai/src/embedding.rs             |   9 +-
 async-openai/src/file.rs                  |   9 +-
 async-openai/src/fine_tune.rs             |   9 +-
 async-openai/src/image.rs                 |   9 +-
 async-openai/src/lib.rs                   |  38 +++-
 async-openai/src/model.rs                 |   9 +-
 async-openai/src/moderation.rs            |   9 +-
 examples/Cargo.toml                       |   1 +
 examples/Makefile.toml                    |  20 +++
 examples/azure-openai-service/Cargo.toml  |   9 +
 examples/azure-openai-service/README.md   |   3 +
 examples/azure-openai-service/src/main.rs | 106 +++++++++++
 examples/completions/Cargo.toml           |   2 +-
 examples/create-image/Cargo.toml          |   2 +-
 21 files changed, 484 insertions(+), 179 deletions(-)
 create mode 100644 async-openai/src/config.rs
 create mode 100644 examples/Makefile.toml
 create mode 100644 examples/azure-openai-service/Cargo.toml
 create mode 100644 examples/azure-openai-service/README.md
 create mode 100644 examples/azure-openai-service/src/main.rs

diff --git a/async-openai/README.md b/async-openai/README.md
index 1862db70..5927b9e2 100644
--- a/async-openai/README.md
+++ b/async-openai/README.md
@@ -31,14 +31,13 @@
   - [x] Files
   - [x] Fine-Tuning (including SSE streaming)
   - [x] Images
-  - [ ] Microsoft Azure Endpoints / AD Authentication (see [issue](https://github.com/64bit/async-openai/issues/32))
+  - [x] Microsoft Azure Endpoints
   - [x] Models
   - [x] Moderations
 - Non-streaming requests are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
 - Ergonomic Rust library with builder pattern for all request objects.
 
-_Being a young project there could be rough edges._
-
+**Note on Azure OpenAI Service**: `async-openai` primarily implements the OpenAI APIs and exposes the same interface for the Azure OpenAI Service. Keep in mind that the Azure OpenAI Service provides only a subset of the OpenAI APIs.
 
 ## Usage
 
 The library reads [API key](https://platform.openai.com/account/api-keys) from the environment variable `OPENAI_API_KEY`.
@@ -104,7 +103,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
 Thank you for your time to contribute and improve the project, I'd be happy to have you!
 
-A good starting point would be an [open issue](https://github.com/64bit/async-openai/issues).
+A good starting point would be one of the existing [open issues](https://github.com/64bit/async-openai/issues).
 ## License
diff --git a/async-openai/src/audio.rs b/async-openai/src/audio.rs
index 9fcb13a6..ccacd148 100644
--- a/async-openai/src/audio.rs
+++ b/async-openai/src/audio.rs
@@ -1,4 +1,5 @@
 use crate::{
+    config::Config,
     error::OpenAIError,
     types::{
         CreateTranscriptionRequest, CreateTranscriptionResponse, CreateTranslationRequest,
@@ -10,12 +11,12 @@ use crate::{
 
 /// Turn audio into text
 /// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
-pub struct Audio<'c> {
-    client: &'c Client,
+pub struct Audio<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> Audio<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> Audio<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/async-openai/src/chat.rs b/async-openai/src/chat.rs
index 4f7b5f5e..7efa83b3 100644
--- a/async-openai/src/chat.rs
+++ b/async-openai/src/chat.rs
@@ -1,4 +1,5 @@
 use crate::{
+    config::Config,
     error::OpenAIError,
     types::{
         ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
@@ -7,12 +8,12 @@ use crate::{
 };
 
 /// Given a chat conversation, the model will return a chat completion response.
-pub struct Chat<'c> {
-    client: &'c Client,
+pub struct Chat<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> Chat<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> Chat<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs
index 3006a416..e4061b76 100644
--- a/async-openai/src/client.rs
+++ b/async-openai/src/client.rs
@@ -1,11 +1,11 @@
 use std::pin::Pin;
 
 use futures::{stream::StreamExt, Stream};
-use reqwest::header::HeaderMap;
 use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
 use serde::{de::DeserializeOwned, Serialize};
 
 use crate::{
+    config::{Config, OpenAIConfig},
     edit::Edits,
     error::{map_deserialization_error, OpenAIError, WrappedError},
     file::Files,
@@ -15,38 +15,33 @@ use crate::{
 };
 
 #[derive(Debug, Clone)]
-/// Client is a container for api key, base url, organization id, and backoff
-/// configuration used to make API calls.
-pub struct Client {
+/// Client is a container for config, backoff and http_client
+/// used to make API calls.
+pub struct Client<C: Config> {
     http_client: reqwest::Client,
-    api_key: String,
-    api_base: String,
-    org_id: String,
+    config: C,
     backoff: backoff::ExponentialBackoff,
 }
 
-/// Default v1 API base url
-pub const API_BASE: &str = "https://api.openai.com/v1";
-/// Name for organization header
-pub const ORGANIZATION_HEADER: &str = "OpenAI-Organization";
-
-impl Default for Client {
-    /// Create client with default [API_BASE] url and default API key from OPENAI_API_KEY env var
-    fn default() -> Self {
+impl Client<OpenAIConfig> {
+    /// Client with default [OpenAIConfig]
+    pub fn new() -> Self {
         Self {
             http_client: reqwest::Client::new(),
-            api_base: API_BASE.to_string(),
-            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "".to_string()),
-            org_id: Default::default(),
+            config: OpenAIConfig::default(),
             backoff: Default::default(),
         }
     }
 }
 
-impl Client {
-    /// Create client with default [API_BASE] url and default API key from OPENAI_API_KEY env var
-    pub fn new() -> Self {
-        Default::default()
+impl<C: Config> Client<C> {
+    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
+    pub fn with_config(config: C) -> Self {
+        Self {
+            http_client: reqwest::Client::new(),
+            config,
+            backoff: Default::default(),
+        }
     }
 
     /// Provide your own [client] to make HTTP requests with.
@@ -57,24 +52,6 @@ impl Client {
         self
     }
 
-    /// To use a different API key different from default OPENAI_API_KEY env var
-    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
-        self.api_key = api_key.into();
-        self
-    }
-
-    /// To use a different organization id other than default
-    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
-        self.org_id = org_id.into();
-        self
-    }
-
-    /// To use a API base url different from default [API_BASE]
-    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
-        self.api_base = api_base.into();
-        self
-    }
-
     /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
     /// Form submissions are not retried.
     pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
@@ -82,74 +59,58 @@ impl Client {
         self
     }
 
-    pub fn api_base(&self) -> &str {
-        &self.api_base
-    }
-
-    pub fn api_key(&self) -> &str {
-        &self.api_key
-    }
-
     // API groups
 
     /// To call [Models] group related APIs using this client.
-    pub fn models(&self) -> Models {
+    pub fn models(&self) -> Models<C> {
         Models::new(self)
     }
 
     /// To call [Completions] group related APIs using this client.
-    pub fn completions(&self) -> Completions {
+    pub fn completions(&self) -> Completions<C> {
         Completions::new(self)
     }
 
     /// To call [Chat] group related APIs using this client.
-    pub fn chat(&self) -> Chat {
+    pub fn chat(&self) -> Chat<C> {
         Chat::new(self)
     }
 
     /// To call [Edits] group related APIs using this client.
-    pub fn edits(&self) -> Edits {
+    pub fn edits(&self) -> Edits<C> {
        Edits::new(self)
     }
 
     /// To call [Images] group related APIs using this client.
-    pub fn images(&self) -> Images {
+    pub fn images(&self) -> Images<C> {
         Images::new(self)
     }
 
     /// To call [Moderations] group related APIs using this client.
-    pub fn moderations(&self) -> Moderations {
+    pub fn moderations(&self) -> Moderations<C> {
         Moderations::new(self)
     }
 
     /// To call [Files] group related APIs using this client.
-    pub fn files(&self) -> Files {
+    pub fn files(&self) -> Files<C> {
         Files::new(self)
     }
 
     /// To call [FineTunes] group related APIs using this client.
-    pub fn fine_tunes(&self) -> FineTunes {
+    pub fn fine_tunes(&self) -> FineTunes<C> {
         FineTunes::new(self)
     }
 
     /// To call [Embeddings] group related APIs using this client.
-    pub fn embeddings(&self) -> Embeddings {
+    pub fn embeddings(&self) -> Embeddings<C> {
         Embeddings::new(self)
     }
 
     /// To call [Audio] group related APIs using this client.
-    pub fn audio(&self) -> Audio {
+    pub fn audio(&self) -> Audio<C> {
         Audio::new(self)
     }
 
-    fn headers(&self) -> HeaderMap {
-        let mut headers = HeaderMap::new();
-        if !self.org_id.is_empty() {
-            headers.insert(ORGANIZATION_HEADER, self.org_id.as_str().parse().unwrap());
-        }
-        headers
-    }
-
     /// Make a GET request to {path} and deserialize the response body
     pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
     where
@@ -157,9 +118,9 @@ impl Client {
     {
         let request = self
             .http_client
-            .get(format!("{}{path}", self.api_base()))
-            .bearer_auth(self.api_key())
-            .headers(self.headers())
+            .get(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
             .build()?;
 
         self.execute(request).await
@@ -172,9 +133,9 @@ impl Client {
     {
         let request = self
             .http_client
-            .delete(format!("{}{path}", self.api_base()))
-            .bearer_auth(self.api_key())
-            .headers(self.headers())
+            .delete(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
             .build()?;
 
         self.execute(request).await
@@ -188,9 +149,9 @@ impl Client {
     {
         let request = self
             .http_client
-            .post(format!("{}{path}", self.api_base()))
-            .bearer_auth(self.api_key())
-            .headers(self.headers())
+            .post(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
             .json(&request)
             .build()?;
 
@@ -208,9 +169,9 @@ impl Client {
     {
         let request = self
             .http_client
-            .post(format!("{}{path}", self.api_base()))
-            .bearer_auth(self.api_key())
-            .headers(self.headers())
+            .post(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
             .multipart(form)
             .build()?;
 
@@ -311,14 +272,14 @@ impl Client {
     {
         let event_source = self
             .http_client
-            .post(format!("{}{path}", self.api_base()))
-            .headers(self.headers())
-            .bearer_auth(self.api_key())
+            .post(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
             .json(&request)
             .eventsource()
             .unwrap();
 
-        Client::stream(event_source).await
+        stream(event_source).await
     }
 
     /// Make HTTP GET request to receive SSE
@@ -333,61 +294,59 @@ impl Client {
     {
         let event_source = self
             .http_client
-            .get(format!("{}{path}", self.api_base()))
+            .get(self.config.url(path))
             .query(query)
-            .headers(self.headers())
-            .bearer_auth(self.api_key())
+            .query(&self.config.query())
+            .headers(self.config.headers())
             .eventsource()
             .unwrap();
 
-        Client::stream(event_source).await
+        stream(event_source).await
     }
+}
 
-    /// Request which responds with SSE.
-    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
-    pub(crate) async fn stream<O>(
-        mut event_source: EventSource,
-    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
-    where
-        O: DeserializeOwned + std::marker::Send + 'static,
-    {
-        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
-
-        tokio::spawn(async move {
-            while let Some(ev) = event_source.next().await {
-                match ev {
-                    Err(e) => {
-                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
-                            // rx dropped
-                            break;
-                        }
-                    }
-                    Ok(event) => match event {
-                        Event::Message(message) => {
-                            if message.data == "[DONE]" {
-                                break;
-                            }
-
-                            let response = match serde_json::from_str::<O>(&message.data) {
-                                Err(e) => {
-                                    Err(map_deserialization_error(e, &message.data.as_bytes()))
-                                }
-                                Ok(output) => Ok(output),
-                            };
-
-                            if let Err(_e) = tx.send(response) {
-                                // rx dropped
-                                break;
-                            }
-                        }
-                        Event::Open => continue,
-                    },
-                }
-            }
-
-            event_source.close();
-        });
-
-        Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
-    }
+
+/// Request which responds with SSE.
+/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
+pub(crate) async fn stream<O>(
+    mut event_source: EventSource,
+) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
+where
+    O: DeserializeOwned + std::marker::Send + 'static,
+{
+    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+
+    tokio::spawn(async move {
+        while let Some(ev) = event_source.next().await {
+            match ev {
+                Err(e) => {
+                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
+                        // rx dropped
+                        break;
+                    }
+                }
+                Ok(event) => match event {
+                    Event::Message(message) => {
+                        if message.data == "[DONE]" {
+                            break;
+                        }
+
+                        let response = match serde_json::from_str::<O>(&message.data) {
+                            Err(e) => Err(map_deserialization_error(e, &message.data.as_bytes())),
+                            Ok(output) => Ok(output),
+                        };
+
+                        if let Err(_e) = tx.send(response) {
+                            // rx dropped
+                            break;
+                        }
+                    }
+                    Event::Open => continue,
+                },
+            }
+        }
+
+        event_source.close();
+    });
+
+    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
+}
diff --git a/async-openai/src/completion.rs b/async-openai/src/completion.rs
index 05b2223a..d0775292 100644
--- a/async-openai/src/completion.rs
+++ b/async-openai/src/completion.rs
@@ -1,5 +1,6 @@
 use crate::{
     client::Client,
+    config::Config,
     error::OpenAIError,
     types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse},
 };
@@ -7,12 +8,12 @@ use crate::{
 /// Given a prompt, the model will return one or more predicted
 /// completions, and can also return the probabilities of alternative
 /// tokens at each position.
-pub struct Completions<'c> {
-    client: &'c Client,
+pub struct Completions<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> Completions<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> Completions<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs
new file mode 100644
index 00000000..21c05ff4
--- /dev/null
+++ b/async-openai/src/config.rs
@@ -0,0 +1,176 @@
+//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
+use reqwest::header::{HeaderMap, AUTHORIZATION};
+
+/// Default v1 API base url
+pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
+/// Name for organization header
+pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
+
+/// [crate::Client] relies on this for every API call on OpenAI
+/// or Azure OpenAI service
+pub trait Config {
+    fn headers(&self) -> HeaderMap;
+    fn url(&self, path: &str) -> String;
+    fn query(&self) -> Vec<(&str, &str)>;
+
+    fn api_base(&self) -> &str;
+
+    fn api_key(&self) -> &str;
+}
+
+/// Configuration for OpenAI API
+pub struct OpenAIConfig {
+    api_base: String,
+    api_key: String,
+    org_id: String,
+}
+
+impl Default for OpenAIConfig {
+    fn default() -> Self {
+        Self {
+            api_base: OPENAI_API_BASE.to_string(),
+            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "".to_string()),
+            org_id: Default::default(),
+        }
+    }
+}
+
+impl OpenAIConfig {
+    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// To use an organization id other than the default
+    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
+        self.org_id = org_id.into();
+        self
+    }
+
+    /// To use an API key different from the default OPENAI_API_KEY env var
+    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
+        self.api_key = api_key.into();
+        self
+    }
+
+    /// To use an API base url different from the default [OPENAI_API_BASE]
+    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
+        self.api_base = api_base.into();
+        self
+    }
+
+    pub fn org_id(&self) -> &str {
+        &self.org_id
+    }
+}
+
+impl Config for OpenAIConfig {
+    fn headers(&self) -> HeaderMap {
+        let mut headers = HeaderMap::new();
+        if !self.org_id.is_empty() {
+            headers.insert(
+                OPENAI_ORGANIZATION_HEADER,
+                self.org_id.as_str().parse().unwrap(),
+            );
+        }
+
+        headers.insert(
+            AUTHORIZATION,
+            format!("Bearer {}", self.api_key).as_str().parse().unwrap(),
+        );
+
+        headers
+    }
+
+    fn url(&self, path: &str) -> String {
+        format!("{}{}", self.api_base, path)
+    }
+
+    fn api_base(&self) -> &str {
+        &self.api_base
+    }
+
+    fn api_key(&self) -> &str {
+        &self.api_key
+    }
+
+    fn query(&self) -> Vec<(&str, &str)> {
+        vec![]
+    }
+}
+
+/// Configuration for Azure OpenAI Service
+pub struct AzureConfig {
+    api_version: String,
+    deployment_id: String,
+    api_base: String,
+    api_key: String,
+}
+
+impl Default for AzureConfig {
+    fn default() -> Self {
+        Self {
+            api_base: Default::default(),
+            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "".to_string()),
+            deployment_id: Default::default(),
+            api_version: Default::default(),
+        }
+    }
+}
+
+impl AzureConfig {
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
+        self.api_version = api_version.into();
+        self
+    }
+
+    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
+        self.deployment_id = deployment_id.into();
+        self
+    }
+
+    /// To use an API key different from the default OPENAI_API_KEY env var
+    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
+        self.api_key = api_key.into();
+        self
+    }
+
+    /// API base url in the form <https://your-resource-name.openai.azure.com>
+    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
+        self.api_base = api_base.into();
+        self
+    }
+}
+
+impl Config for AzureConfig {
+    fn headers(&self) -> HeaderMap {
+        let mut headers = HeaderMap::new();
+
+        headers.insert("api-key", self.api_key.as_str().parse().unwrap());
+
+        headers
+    }
+
"{}/openai/deployments/{}{}", + self.api_base, self.deployment_id, path + ) + } + + fn api_base(&self) -> &str { + &self.api_base + } + + fn api_key(&self) -> &str { + &self.api_key + } + + fn query(&self) -> Vec<(&str, &str)> { + vec![("api-version", &self.api_version)] + } +} diff --git a/async-openai/src/edit.rs b/async-openai/src/edit.rs index 78e2349f..daca6b73 100644 --- a/async-openai/src/edit.rs +++ b/async-openai/src/edit.rs @@ -1,4 +1,5 @@ use crate::{ + config::Config, error::OpenAIError, types::{CreateEditRequest, CreateEditResponse}, Client, @@ -6,12 +7,12 @@ use crate::{ /// Given a prompt and an instruction, the model will return /// an edited version of the prompt. -pub struct Edits<'c> { - client: &'c Client, +pub struct Edits<'c, C: Config> { + client: &'c Client, } -impl<'c> Edits<'c> { - pub fn new(client: &'c Client) -> Self { +impl<'c, C: Config> Edits<'c, C> { + pub fn new(client: &'c Client) -> Self { Self { client } } diff --git a/async-openai/src/embedding.rs b/async-openai/src/embedding.rs index 3ef3b805..b6a9ab86 100644 --- a/async-openai/src/embedding.rs +++ b/async-openai/src/embedding.rs @@ -1,4 +1,5 @@ use crate::{ + config::Config, error::OpenAIError, types::{CreateEmbeddingRequest, CreateEmbeddingResponse}, Client, @@ -8,12 +9,12 @@ use crate::{ /// consumed by machine learning models and algorithms. /// /// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) -pub struct Embeddings<'c> { - client: &'c Client, +pub struct Embeddings<'c, C: Config> { + client: &'c Client, } -impl<'c> Embeddings<'c> { - pub fn new(client: &'c Client) -> Self { +impl<'c, C: Config> Embeddings<'c, C> { + pub fn new(client: &'c Client) -> Self { Self { client } } diff --git a/async-openai/src/file.rs b/async-openai/src/file.rs index a2beba53..3ebfa5ac 100644 --- a/async-openai/src/file.rs +++ b/async-openai/src/file.rs @@ -1,4 +1,5 @@ use crate::{ + config::Config, error::OpenAIError, types::{CreateFileRequest, DeleteFileResponse, ListFilesResponse, OpenAIFile}, util::create_file_part, @@ -6,12 +7,12 @@ use crate::{ }; /// Files are used to upload documents that can be used with features like [Fine-tuning](https://platform.openai.com/docs/api-reference/fine-tunes). -pub struct Files<'c> { - client: &'c Client, +pub struct Files<'c, C: Config> { + client: &'c Client, } -impl<'c> Files<'c> { - pub fn new(client: &'c Client) -> Self { +impl<'c, C: Config> Files<'c, C> { + pub fn new(client: &'c Client) -> Self { Self { client } } diff --git a/async-openai/src/fine_tune.rs b/async-openai/src/fine_tune.rs index a4a771b9..04ce042b 100644 --- a/async-openai/src/fine_tune.rs +++ b/async-openai/src/fine_tune.rs @@ -1,4 +1,5 @@ use crate::{ + config::Config, error::OpenAIError, types::{ CreateFineTuneRequest, FineTune, FineTuneEventsResponseStream, ListFineTuneEventsResponse, @@ -10,12 +11,12 @@ use crate::{ /// Manage fine-tuning jobs to tailor a model to your specific training data. 
 ///
 /// Related guide: [Fine-tune models](https://platform.openai.com/docs/guides/fine-tuning)
-pub struct FineTunes<'c> {
-    client: &'c Client,
+pub struct FineTunes<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> FineTunes<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> FineTunes<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/async-openai/src/image.rs b/async-openai/src/image.rs
index 2d322dec..65e91502 100644
--- a/async-openai/src/image.rs
+++ b/async-openai/src/image.rs
@@ -1,4 +1,5 @@
 use crate::{
+    config::Config,
     error::OpenAIError,
     types::{
         CreateImageEditRequest, CreateImageRequest, CreateImageVariationRequest, ImageResponse,
@@ -10,12 +11,12 @@ use crate::{
 
 /// Given a prompt and/or an input image, the model will generate a new image.
 ///
 /// Related guide: [Image generation](https://platform.openai.com/docs/guides/images/introduction)
-pub struct Images<'c> {
-    client: &'c Client,
+pub struct Images<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> Images<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> Images<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs
index 553acb84..ef4500ef 100644
--- a/async-openai/src/lib.rs
+++ b/async-openai/src/lib.rs
@@ -3,23 +3,46 @@
 //! ## Creating client
 //!
 //! ```
-//! use async_openai::Client;
+//! use async_openai::{Client, config::OpenAIConfig};
 //!
-//! // Create a client with api key from env var OPENAI_API_KEY and default base url.
+//! // Create an OpenAI client with the API key from the env var OPENAI_API_KEY and the default base url.
 //! let client = Client::new();
 //!
-//! // OR use API key from different source
+//! // The above is a shortcut for
+//! let config = OpenAIConfig::default();
+//! let client = Client::with_config(config);
+//!
+//! // OR use an API key from a different source and a non-default organization
 //! let api_key = "sk-..."; // This secret could be from a file, or environment variable.
-//! let client = Client::new().with_api_key(api_key);
+//! let config = OpenAIConfig::new()
+//!     .with_api_key(api_key)
+//!     .with_org_id("the-continental");
 //!
-//! // Use organization other than default when making requests
-//! let client = Client::new().with_org_id("the-org");
+//! let client = Client::with_config(config);
 //!
 //! // Use custom reqwest client
 //! let http_client = reqwest::ClientBuilder::new().user_agent("async-openai").build().unwrap();
 //! let client = Client::new().with_http_client(http_client);
 //! ```
 //!
+//! ## Microsoft Azure Endpoints
+//!
+//! ```
+//! use async_openai::{Client, config::AzureConfig};
+//!
+//! let config = AzureConfig::new()
+//!     .with_api_base("https://my-resource-name.openai.azure.com")
+//!     .with_api_version("2023-03-15-preview")
+//!     .with_deployment_id("deployment-id")
+//!     .with_api_key("...");
+//!
+//! let client = Client::with_config(config);
+//!
+//! // Note that the Azure OpenAI service does not support all of the OpenAI APIs;
+//! // `async-openai` does not restrict this and still allows calls to all of them.
+//! ```
+//!
 //! ## Making requests
 //!
 //!```
@@ -57,6 +80,7 @@ mod audio;
 mod chat;
 mod client;
 mod completion;
+pub mod config;
 mod download;
 mod edit;
 mod embedding;
@@ -72,8 +96,6 @@ mod util;
 pub use audio::Audio;
 pub use chat::Chat;
 pub use client::Client;
-pub use client::API_BASE;
-pub use client::ORGANIZATION_HEADER;
 pub use completion::Completions;
 pub use edit::Edits;
 pub use embedding::Embeddings;
diff --git a/async-openai/src/model.rs b/async-openai/src/model.rs
index f040b30f..df1a443f 100644
--- a/async-openai/src/model.rs
+++ b/async-openai/src/model.rs
@@ -1,4 +1,5 @@
 use crate::{
+    config::Config,
     error::OpenAIError,
     types::{DeleteModelResponse, ListModelResponse, Model},
     Client,
@@ -7,12 +8,12 @@ use crate::{
 /// List and describe the various models available in the API.
 /// You can refer to the [Models](https://platform.openai.com/docs/models) documentation to understand what
 /// models are available and the differences between them.
-pub struct Models<'c> {
-    client: &'c Client,
+pub struct Models<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> Models<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> Models<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/async-openai/src/moderation.rs b/async-openai/src/moderation.rs
index 039ce882..b1cae0df 100644
--- a/async-openai/src/moderation.rs
+++ b/async-openai/src/moderation.rs
@@ -1,4 +1,5 @@
 use crate::{
+    config::Config,
     error::OpenAIError,
     types::{CreateModerationRequest, CreateModerationResponse},
     Client,
@@ -7,12 +8,12 @@ use crate::{
 /// Given a input text, outputs if the model classifies it as violating OpenAI's content policy.
 ///
 /// Related guide: [Moderations](https://platform.openai.com/docs/guides/moderation/overview)
-pub struct Moderations<'c> {
-    client: &'c Client,
+pub struct Moderations<'c, C: Config> {
+    client: &'c Client<C>,
 }
 
-impl<'c> Moderations<'c> {
-    pub fn new(client: &'c Client) -> Self {
+impl<'c, C: Config> Moderations<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
         Self { client }
     }
 
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 40edec34..9cfd9c47 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -2,6 +2,7 @@
 members = [
     "audio-transcribe",
     "audio-translate",
+    "azure-openai-service",
     "chat",
     "chat-stream",
     "create-edit",
diff --git a/examples/Makefile.toml b/examples/Makefile.toml
new file mode 100644
index 00000000..229ddb53
--- /dev/null
+++ b/examples/Makefile.toml
@@ -0,0 +1,20 @@
+
+[tasks.all]
+workspace = false
+script = '''
+cd audio-transcribe && cargo run && cd -
+cd audio-translate && cargo run && cd -
+cd chat && cargo run && cd -
+cd chat-stream && cargo run && cd -
+cd completions && cargo run && cd -
+cd completions-stream && cargo run && cd -
+cd create-edit && cargo run && cd -
+cd create-image && cargo run && cd -
+cd create-image-b64-json && cargo run && cd -
+cd create-image-edit && cargo run && cd -
+cd create-image-variation && cargo run && cd -
+cd embeddings && cargo run && cd -
+cd models && cargo run && cd -
+cd moderations && cargo run && cd -
+#cd rate-limit-completions && cargo run && cd -
+'''
diff --git a/examples/azure-openai-service/Cargo.toml b/examples/azure-openai-service/Cargo.toml
new file mode 100644
index 00000000..6cae8952
--- /dev/null
+++ b/examples/azure-openai-service/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "azure-openai-service"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+async-openai = { path = "../../async-openai" }
+tokio = { version = "1.27.0", features = ["full"] }
"0.3.26" diff --git a/examples/azure-openai-service/README.md b/examples/azure-openai-service/README.md new file mode 100644 index 00000000..229cc5ab --- /dev/null +++ b/examples/azure-openai-service/README.md @@ -0,0 +1,3 @@ +## Overview + +Please note: before running this example configure api_base, api_key, deploy_id and api_version in main.rs. diff --git a/examples/azure-openai-service/src/main.rs b/examples/azure-openai-service/src/main.rs new file mode 100644 index 00000000..c230d0de --- /dev/null +++ b/examples/azure-openai-service/src/main.rs @@ -0,0 +1,106 @@ +use std::error::Error; + +use async_openai::{ + config::AzureConfig, + types::{ + ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, + CreateEmbeddingRequestArgs, Role, + }, + Client, +}; + + +async fn chat_completion_example(client: &Client) -> Result<(), Box> { + let request = CreateChatCompletionRequestArgs::default() + .max_tokens(512u16) + .model("gpt-3.5-turbo") + .messages([ + ChatCompletionRequestMessageArgs::default() + .role(Role::System) + .content("You are a helpful assistant.") + .build()?, + ChatCompletionRequestMessageArgs::default() + .role(Role::User) + .content("How does large language model work?") + .build()?, + ]) + .build()?; + + let response = client.chat().create(request).await?; + + println!("\nResponse:\n"); + for choice in response.choices { + println!( + "{}: Role: {} Content: {}", + choice.index, choice.message.role, choice.message.content + ); + } + Ok(()) +} + +// Bug (help wanted): https://github.com/64bit/async-openai/pull/67#issuecomment-1555165805 +// async fn completions_stream_example(client: &Client) -> Result<(), Box> { +// let request = CreateCompletionRequestArgs::default() +// .model("text-davinci-003") +// .n(1) +// .prompt("Tell me a short bedtime story about Optimus Prime and Bumblebee in Sir David Attenborough voice") +// .stream(true) +// .max_tokens(512_u16) +// .build()?; + +// let mut stream = client.completions().create_stream(request).await?; + +// while let Some(response) = stream.next().await { +// match response { +// Ok(ccr) => ccr.choices.iter().for_each(|c| { +// print!("{}", c.text); +// }), +// Err(e) => eprintln!("{}", e), +// } +// } +// Ok(()) +// } + +async fn embedding_example(client: &Client) -> Result<(), Box> { + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-ada-002") + .input( + "Why do programmers hate nature? 
+        .input("Why do programmers hate nature? It has too many bugs.")
+        .build()?;
+
+    let response = client.embeddings().create(request).await?;
+
+    for data in response.data {
+        println!(
+            "[{}]: has embedding of length {}",
+            data.index,
+            data.embedding.len()
+        )
+    }
+
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error>> {
+    let config = AzureConfig::new()
+        .with_api_base("https://your-resource-name.openai.azure.com")
+        .with_api_key("...")
+        .with_deployment_id("deployment-id")
+        .with_api_version("2023-03-15-preview");
+
+    let client = Client::with_config(config);
+
+    // Run embedding example
+    embedding_example(&client).await?;
+
+    // Run completions stream example
+    // Bug (help wanted): https://github.com/64bit/async-openai/pull/67#issuecomment-1555165805
+    //completions_stream_example(&client).await?;
+
+    // Run chat completion example
+    chat_completion_example(&client).await?;
+
+    Ok(())
+}
diff --git a/examples/completions/Cargo.toml b/examples/completions/Cargo.toml
index 529d5749..9db08f8b 100644
--- a/examples/completions/Cargo.toml
+++ b/examples/completions/Cargo.toml
@@ -1,5 +1,5 @@
 [package]
-name = "example-completions"
+name = "completions"
 version = "0.1.0"
 edition = "2021"
 publish = false
diff --git a/examples/create-image/Cargo.toml b/examples/create-image/Cargo.toml
index d20518cc..6c09b7bb 100644
--- a/examples/create-image/Cargo.toml
+++ b/examples/create-image/Cargo.toml
@@ -1,5 +1,5 @@
 [package]
-name = "example-create-image"
+name = "create-image"
 version = "0.1.0"
 edition = "2021"
 publish = false
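---

Because `Config` is introduced here as a public trait, `Client` is not limited to the two bundled configurations: any OpenAI-compatible backend can be targeted by implementing the trait yourself. A minimal sketch follows; `ProxyConfig` and its base url are hypothetical names invented for illustration and are not part of this patch:

```rust
use async_openai::{config::Config, Client};
use reqwest::header::{HeaderMap, AUTHORIZATION};

/// Hypothetical configuration for an internal OpenAI-compatible proxy.
/// The struct name and base url are illustrative only.
struct ProxyConfig {
    api_key: String,
}

impl Config for ProxyConfig {
    fn headers(&self) -> HeaderMap {
        // Forward the key as a standard bearer token.
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            format!("Bearer {}", self.api_key).parse().unwrap(),
        );
        headers
    }

    fn url(&self, path: &str) -> String {
        // Same shape as OpenAIConfig::url: base url + request path.
        format!("{}{}", self.api_base(), path)
    }

    fn query(&self) -> Vec<(&str, &str)> {
        // No extra query parameters, unlike AzureConfig's api-version.
        vec![]
    }

    fn api_base(&self) -> &str {
        "https://openai-proxy.internal.example/v1"
    }

    fn api_key(&self) -> &str {
        &self.api_key
    }
}

fn main() {
    // All API groups (chat(), embeddings(), ...) then route through the
    // custom backend, exactly as they do for the two bundled configs.
    let _client = Client::with_config(ProxyConfig {
        api_key: "...".into(),
    });
}
```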