diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..0c146ea 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -29,8 +29,9 @@ pub enum ConnectionError { /// Container for query parameters /// This API has different endpoints and MIME types for different requests struct QueryContext { - path: &'static str, + path: String, accept_mime: &'static str, + method: reqwest::Method } pub enum QueryType { @@ -39,30 +40,40 @@ pub enum QueryType { CloseSession, JsonQuery, ArrowQuery, + ArrowQueryResult(String), } - impl QueryType { - const fn query_context(&self) -> QueryContext { + fn query_context(&self) -> QueryContext { match self { Self::LoginRequest => QueryContext { - path: "session/v1/login-request", + path: "session/v1/login-request".to_string(), accept_mime: "application/json", + method: reqwest::Method::POST, }, Self::TokenRequest => QueryContext { - path: "/session/token-request", + path: "/session/token-request".to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::POST, }, Self::CloseSession => QueryContext { - path: "session", + path: "session".to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::POST, }, Self::JsonQuery => QueryContext { - path: "queries/v1/query-request", + path: "queries/v1/query-request".to_string(), accept_mime: "application/json", + method: reqwest::Method::POST, }, Self::ArrowQuery => QueryContext { - path: "queries/v1/query-request", + path: "queries/v1/query-request".to_string(), + accept_mime: "application/snowflake", + method: reqwest::Method::POST, + }, + Self::ArrowQueryResult(query_result_url) => QueryContext { + path: query_result_url.to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::GET, }, } } @@ -163,14 +174,22 @@ impl Connection { } // todo: persist client to use connection polling - let resp = self - .client - .post(url) - .headers(headers) - .json(&body) - .send() - .await?; - + let resp = match context.method { + reqwest::Method::POST => self + .client + .post(url) + .headers(headers) + .json(&body) + .send() + .await?, + reqwest::Method::GET => self + .client + .get(url) + .headers(headers) + .send() + .await?, + _ => panic!("Unsupported method"), + }; Ok(resp.json::().await?) } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 1fa7b36..3d8e78c 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -407,6 +407,7 @@ impl SnowflakeApi { match resp { ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::PutGet(pg) => put::put(pg).await, ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, @@ -430,14 +431,21 @@ impl SnowflakeApi { } async fn exec_arrow_raw(&self, sql: &str) -> Result { - let resp = self + let mut resp = self .run_sql::(sql, QueryType::ArrowQuery) .await?; log::debug!("Got query response: {:?}", resp); + if let ExecResponse::QueryAsync(data) = &resp { + log::debug!("Got async exec response"); + resp = self.get_async_exec_result(&data.data.get_result_url).await?; + log::debug!("Got result for async exec: {:?}", resp); + } + let resp = match resp { // processable response ExecResponse::Query(qr) => Ok(qr), + ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, @@ -504,10 +512,38 @@ impl SnowflakeApi { &self.account_identifier, &[], Some(&parts.session_token_auth_header), - body, + Some(body), ) .await?; Ok(resp) } + + pub async fn get_async_exec_result(&self, query_result_url: &String) -> Result{ + log::debug!("Getting async exec result: {}", query_result_url); + + let mut delay = 1; // Initial delay of 1 second + + loop { + let parts = self.session.get_token().await?; + let resp = self + .connection + .request::( + QueryType::ArrowQueryResult(query_result_url.to_string()), + &self.account_identifier, + &[], + Some(&parts.session_token_auth_header), + serde_json::Value::default() + ) + .await?; + + if let ExecResponse::QueryAsync(_) = &resp { + // simple exponential retry with a maximum wait time of 5 seconds + tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; + delay = (delay * 2).min(5); // cap delay to 5 seconds + } else { + return Ok(resp); + } + }; + } } diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index b8a3e68..11034ce 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -7,6 +7,7 @@ use serde::Deserialize; #[serde(untagged)] pub enum ExecResponse { Query(QueryExecResponse), + QueryAsync(QueryAsyncExecResponse), PutGet(PutGetExecResponse), Error(ExecErrorResponse), } @@ -34,6 +35,7 @@ pub struct BaseRestResponse { pub type PutGetExecResponse = BaseRestResponse; pub type QueryExecResponse = BaseRestResponse; +pub type QueryAsyncExecResponse = BaseRestResponse; pub type ExecErrorResponse = BaseRestResponse; pub type AuthErrorResponse = BaseRestResponse; pub type AuthenticatorResponse = BaseRestResponse; @@ -54,7 +56,7 @@ pub struct ExecErrorResponseData { pub pos: Option, // fixme: only valid for exec query response error? present in any exec query response? - pub query_id: String, + pub query_id: Option, pub sql_state: String, } @@ -151,6 +153,13 @@ pub struct QueryExecResponseData { // `sendResultTime`, `queryResultFormat`, `queryContext` also exist } +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct QueryAsyncExecResponseData { + pub query_id: String, + pub get_result_url: String, +} + #[derive(Deserialize, Debug)] pub struct ExecResponseRowType { pub name: String,