diff --git a/Cargo.lock b/Cargo.lock index 0ee3647..eaf42c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -365,14 +365,37 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core 0.14.4", + "darling_macro 0.14.4", +] + [[package]] name = "darling" version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.3", + "darling_macro 0.20.3", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] @@ -389,13 +412,24 @@ dependencies = [ "syn 2.0.33", ] +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core 0.14.4", + "quote", + "syn 1.0.109", +] + [[package]] name = "darling_macro" version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ - "darling_core", + "darling_core 0.20.3", "quote", "syn 2.0.33", ] @@ -430,6 +464,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.3" @@ -559,6 +599,19 @@ dependencies = [ "slab", ] +[[package]] 
+name = "gdc_rust_types" +version = "1.0.2" +source = "git+https://github.com/hasura/gdc_rust_types?rev=bc57c40#bc57c406e530ffff2345c6a02f7b784959a600b5" +dependencies = [ + "indexmap 1.9.3", + "openapiv3", + "serde", + "serde-enum-str", + "serde_json", + "serde_with 3.3.0", +] + [[package]] name = "getrandom" version = "0.2.10" @@ -588,7 +641,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -601,6 +654,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heck" version = "0.4.1" @@ -754,7 +813,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", + "serde", +] + +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", "serde", ] @@ -920,14 +990,14 @@ version = "0.1.0" source = "git+http://github.com/hasura/ndc-spec.git?tag=v0.1.0-rc.5#0842f61fc5e29d19994ff3439abb8d7eabc28449" dependencies = [ "async-trait", - "indexmap", + "indexmap 1.9.3", "opentelemetry", "reqwest", "schemars", "serde", "serde_derive", "serde_json", - "serde_with", + "serde_with 2.3.3", "url", ] @@ -939,6 +1009,8 @@ dependencies = [ "axum", "axum-macros", "clap", + "gdc_rust_types", + "indexmap 1.9.3", "ndc-client", "ndc-test", "opentelemetry", @@ -969,7 +1041,7 @@ dependencies = [ "async-trait", "clap", "colored", - "indexmap", + "indexmap 
1.9.3", "ndc-client", "proptest", "reqwest", @@ -1024,6 +1096,17 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "openapiv3" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e56d5c441965b6425165b7e3223cc933ca469834f4a8b4786817a1f9dc4f13" +dependencies = [ + "indexmap 1.9.3", + "serde", + "serde_json", +] + [[package]] name = "openssl" version = "0.10.57" @@ -1141,7 +1224,7 @@ checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b" dependencies = [ "futures-channel", "futures-util", - "indexmap", + "indexmap 1.9.3", "js-sys", "once_cell", "pin-project-lite", @@ -1537,7 +1620,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "763f8cd0d4c71ed8389c90cb8100cba87e763bd01a8e614d4f0af97bcd50a161" dependencies = [ "dyn-clone", - "indexmap", + "indexmap 1.9.3", "schemars_derive", "serde", "serde_json", @@ -1600,6 +1683,36 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-attributes" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eb8ec7724e4e524b2492b510e66957fe1a2c76c26a6975ec80823f2439da685" +dependencies = [ + "darling_core 0.14.4", + "serde-rename-rule", + "syn 1.0.109", +] + +[[package]] +name = "serde-enum-str" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26416dc95fcd46b0e4b12a3758043a229a6914050aaec2e8191949753ed4e9aa" +dependencies = [ + "darling 0.14.4", + "proc-macro2", + "quote", + "serde-attributes", + "syn 1.0.109", +] + +[[package]] +name = "serde-rename-rule" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "794e44574226fc701e3be5c651feb7939038fc67fb73f6f4dd5c4ba90fd3be70" + [[package]] name = "serde_derive" version = "1.0.188" @@ -1628,6 +1741,7 @@ version 
= "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ + "indexmap 2.0.0", "itoa", "ryu", "serde", @@ -1664,10 +1778,27 @@ dependencies = [ "base64 0.13.1", "chrono", "hex", - "indexmap", + "indexmap 1.9.3", + "serde", + "serde_json", + "serde_with_macros 2.3.3", + "time", +] + +[[package]] +name = "serde_with" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ca3b16a3d82c4088f343b7480a93550b3eabe1a358569c2dfe38bbcead07237" +dependencies = [ + "base64 0.21.4", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.0.0", "serde", "serde_json", - "serde_with_macros", + "serde_with_macros 3.3.0", "time", ] @@ -1677,7 +1808,19 @@ version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f" dependencies = [ - "darling", + "darling 0.20.3", + "proc-macro2", + "quote", + "syn 2.0.33", +] + +[[package]] +name = "serde_with_macros" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e6be15c453eb305019bfa438b1593c731f36a289a7853f7707ee29e870b3b3c" +dependencies = [ + "darling 0.20.3", "proc-macro2", "quote", "syn 2.0.33", @@ -1976,7 +2119,7 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap", + "indexmap 1.9.3", "pin-project", "pin-project-lite", "rand", diff --git a/rust-connector-sdk/Cargo.toml b/rust-connector-sdk/Cargo.toml index 9058670..e4eeae4 100644 --- a/rust-connector-sdk/Cargo.toml +++ b/rust-connector-sdk/Cargo.toml @@ -18,10 +18,13 @@ axum-macros = "^0.3.7" clap = { version = "^4.3.9", features = ["derive", "env"] } ndc-client = { git = "http://github.com/hasura/ndc-spec.git", tag = "v0.1.0-rc.5" } ndc-test = { git = "http://github.com/hasura/ndc-spec.git", tag = 
"v0.1.0-rc.5" } -opentelemetry = { version = "^0.20", features = [ "rt-tokio", "trace" ], default-features = false } +opentelemetry = { version = "^0.20", features = [ + "rt-tokio", + "trace", +], default-features = false } opentelemetry_api = "^0.20.0" opentelemetry_sdk = "^0.20.0" -opentelemetry-otlp = { version = "^0.13.0", features = [ "reqwest-client" ] } +opentelemetry-otlp = { version = "^0.13.0", features = ["reqwest-client"] } opentelemetry-semantic-conventions = "^0.12.0" prometheus = "^0.13.3" reqwest = "^0.11.20" @@ -29,10 +32,21 @@ schemars = { version = "^0.8.12", features = ["smol_str"] } serde = { version = "^1.0.164", features = ["derive"] } serde_json = { version = "^1.0.97", features = ["raw_value"] } thiserror = "^1.0" -tokio = { version = "^1.28.2", features = ["fs", "signal"] } -tower-http = { version = "^0.4.1", features = ["cors", "trace", "validate-request"] } +tokio = { version = "^1.28.2", features = ["fs", "signal"] } +tower-http = { version = "^0.4.1", features = [ + "cors", + "trace", + "validate-request", +] } tracing = "^0.1.37" tracing-opentelemetry = "^0.20.0" -tracing-subscriber = { version = "^0.3", default-features = false, features = ["ansi", "env-filter", "fmt", "json"] } +tracing-subscriber = { version = "^0.3", default-features = false, features = [ + "ansi", + "env-filter", + "fmt", + "json", +] } url = "2.4.1" uuid = "^1.3.4" +gdc_rust_types = { git = "https://github.com/hasura/gdc_rust_types", rev = "bc57c40" } +indexmap = "^1" diff --git a/rust-connector-sdk/src/default_main.rs b/rust-connector-sdk/src/default_main.rs index 9e31d8c..b2cb536 100644 --- a/rust-connector-sdk/src/default_main.rs +++ b/rust-connector-sdk/src/default_main.rs @@ -1,3 +1,5 @@ +mod v2_compat; + use crate::{ check_health, connector::{Connector, InvalidRange, SchemaError, UpdateConfigurationError}, @@ -38,6 +40,8 @@ use tower_http::{ use tracing::Level; use tracing_subscriber::{prelude::*, EnvFilter}; +use self::v2_compat::SourceConfig; + 
#[derive(Parser)] struct CliArgs { #[command(subcommand)] @@ -72,6 +76,8 @@ struct ServeCommand { service_token_secret: Option, #[arg(long, value_name = "OTEL_SERVICE_NAME", env = "OTEL_SERVICE_NAME")] service_name: Option, + #[arg(long, env = "ENABLE_V2_COMPATIBILITY")] + enable_v2_compatibility: bool, } #[derive(Clone, Parser)] @@ -219,41 +225,17 @@ where let server_state = init_server_state::(serve_command.configuration).await; - let expected_auth_header: Option = - serve_command - .service_token_secret - .and_then(|service_token_secret| { - let expected_bearer = format!("Bearer {}", service_token_secret); - HeaderValue::from_str(&expected_bearer).ok() - }); - - let router = create_router::(server_state) - .layer( - TraceLayer::new_for_http() - .make_span_with(DefaultMakeSpan::default().level(Level::INFO)), - ) - .layer(ValidateRequestHeaderLayer::custom( - move |request: &mut Request| { - // Validate the request - let auth_header = request.headers().get("Authorization").cloned(); + let router = create_router::( + server_state.clone(), + serve_command.service_token_secret.clone(), + ); - // NOTE: The comparison should probably be more permissive to allow for whitespace, etc. 
- if auth_header == expected_auth_header { - return Ok(()); - } - Err(( - StatusCode::UNAUTHORIZED, - Json(ErrorResponse { - message: "Internal error".into(), - details: serde_json::Value::Object(serde_json::Map::from_iter([( - "cause".into(), - serde_json::Value::String("Bearer token does not match.".to_string()), - )])), - }), - ) - .into_response()) - }, - )); + let router = if serve_command.enable_v2_compatibility { + let v2_router = create_v2_router(server_state, serve_command.service_token_secret.clone()); + Router::new().merge(router).nest("/v2", v2_router) + } else { + router + }; let port = serve_command.port; let address = net::SocketAddr::new(net::IpAddr::V4(net::Ipv4Addr::UNSPECIFIED), port); @@ -270,6 +252,7 @@ where .expect("unable to install signal handler") }; // wait for a SIGTERM, i.e. a normal `kill` command + #[cfg(unix)] let sigterm = async { tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") @@ -277,10 +260,15 @@ where .await }; // block until either of the above happens + #[cfg(unix)] tokio::select! { _ = sigint => (), _ = sigterm => (), } + #[cfg(windows)] + tokio::select! 
{ + _ = sigint => (), + } opentelemetry::global::shutdown_tracer_provider(); }) @@ -317,13 +305,16 @@ where } } -pub fn create_router(state: ServerState) -> Router +pub fn create_router( + state: ServerState, + service_token_secret: Option, +) -> Router where C::RawConfiguration: DeserializeOwned + Sync + Send, C::Configuration: Serialize + Clone + Sync + Send, C::State: Sync + Send + Clone, { - Router::new() + let router = Router::new() .route("/capabilities", get(get_capabilities::)) .route("/health", get(get_health::)) .route("/metrics", get(get_metrics::)) @@ -331,6 +322,98 @@ where .route("/query", post(post_query::)) .route("/explain", post(post_explain::)) .route("/mutation", post(post_mutation::)) + .with_state(state); + + let expected_auth_header: Option = + service_token_secret.and_then(|service_token_secret| { + let expected_bearer = format!("Bearer {}", service_token_secret); + HeaderValue::from_str(&expected_bearer).ok() + }); + + router + .layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::default().level(Level::INFO)), + ) + .layer(ValidateRequestHeaderLayer::custom( + move |request: &mut Request| { + // Validate the request + let auth_header = request.headers().get("Authorization").cloned(); + + // NOTE: The comparison should probably be more permissive to allow for whitespace, etc. 
+ if auth_header == expected_auth_header { + return Ok(()); + } + Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + message: "Internal error".into(), + details: serde_json::Value::Object(serde_json::Map::from_iter([( + "cause".into(), + serde_json::Value::String("Bearer token does not match.".to_string()), + )])), + }), + ) + .into_response()) + }, + )) +} + +pub fn create_v2_router( + state: ServerState, + service_token_secret: Option, +) -> Router +where + C::RawConfiguration: DeserializeOwned + Sync + Send, + C::Configuration: Serialize + Clone + Sync + Send, + C::State: Sync + Send + Clone, +{ + Router::new() + .route("/schema", post(v2_compat::post_schema::)) + .route("/query", post(v2_compat::post_query::)) + // .route("/mutation", post(v2_compat::post_mutation::)) + // .route("/raw", post(v2_compat::post_raw::)) + .route("/explain", post(v2_compat::post_explain::)) + .layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::default().level(Level::INFO)), + ) + .layer(ValidateRequestHeaderLayer::custom( + move |request: &mut Request| { + let provided_service_token_secret = request + .headers() + .get("x-hasura-dataconnector-config") + .and_then(|config_header| { + serde_json::from_slice::(config_header.as_bytes()).ok() + }) + .and_then(|config| config.service_token_secret); + + if service_token_secret == provided_service_token_secret { + // if token set & config header present & values match + // or token not set & config header not set/does not have value for token key + // allow request + Ok(()) + } else { + // all other cases, block request + Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + message: "Internal error".into(), + details: serde_json::Value::Object(serde_json::Map::from_iter([( + "cause".into(), + serde_json::Value::String( + "Service Token Secret does not match.".to_string(), + ), + )])), + }), + ) + .into_response()) + } + }, + )) + // capabilities and health endpoints are exempt from auth requirements + 
.route("/capabilities", get(v2_compat::get_capabilities::)) + .route("/health", get(v2_compat::get_health)) .with_state(state) } diff --git a/rust-connector-sdk/src/default_main/v2_compat.rs b/rust-connector-sdk/src/default_main/v2_compat.rs new file mode 100644 index 0000000..c02a029 --- /dev/null +++ b/rust-connector-sdk/src/default_main/v2_compat.rs @@ -0,0 +1,1146 @@ +use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; +use gdc_rust_types::{ + Aggregate, BinaryArrayComparisonOperator, BinaryComparisonOperator, Capabilities, + CapabilitiesResponse, ColumnInfo, ColumnSelector, ColumnType, ComparisonCapabilities, + ComparisonColumn, ComparisonValue, ConfigSchemaResponse, DetailLevel, ErrorResponse, + ErrorResponseType, ExistsInTable, ExplainResponse, Expression, Field, ForEachRow, FunctionInfo, + MutationCapabilities, ObjectTypeDefinition, OrderBy, OrderByElement, OrderByRelation, + OrderByTarget, OrderDirection, Query, QueryCapabilities, QueryRequest, QueryResponse, + Relationship, RelationshipType, ResponseFieldValue, ResponseRow, ScalarTypeCapabilities, + SchemaRequest, SchemaResponse, SubqueryComparisonCapabilities, TableInfo, TableRelationships, + Target, UnaryComparisonOperator, UpdateColumnOperatorDefinition, +}; +use indexmap::IndexMap; +use ndc_client::models; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::BTreeMap; + +use crate::{ + connector::{Connector, ExplainError, QueryError}, + default_main::ServerState, +}; + +pub async fn get_health() -> impl IntoResponse { + // todo: if source_name and config provided, check if that specific source is healthy + StatusCode::NO_CONTENT +} + +pub async fn get_capabilities( + State(state): State>, +) -> Result, (StatusCode, Json)> { + let v3_capabilities = C::get_capabilities().await; + let v3_schema = C::get_schema(&state.configuration).await.map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + details: None, + message: 
err.to_string(), + r#type: None, + }), + ) + })?; + + let scalar_types = IndexMap::from_iter(v3_schema.scalar_types.into_iter().map( + |(name, scalar_type)| { + ( + name, + ScalarTypeCapabilities { + aggregate_functions: Some(IndexMap::from_iter( + scalar_type.aggregate_functions.into_iter().filter_map( + |(function_name, aggregate_function)| match aggregate_function + .result_type + { + models::Type::Named { name } => Some((function_name, name)), + models::Type::Nullable { .. } => None, + models::Type::Array { .. } => None, + }, + ), + )), + comparison_operators: Some(IndexMap::from_iter( + scalar_type.comparison_operators.into_iter().filter_map( + |(operator_name, comparison_operator)| match comparison_operator + .argument_type + { + models::Type::Named { name } => Some((operator_name, name)), + models::Type::Nullable { .. } => None, + models::Type::Array { .. } => None, + }, + ), + )), + update_column_operators: Some(IndexMap::from_iter( + scalar_type.update_operators.into_iter().filter_map( + |(operator_name, update_operator)| match update_operator.argument_type { + models::Type::Named { name } => Some(( + operator_name, + UpdateColumnOperatorDefinition { + argument_type: name, + }, + )), + models::Type::Nullable { .. } => None, + models::Type::Array { .. 
} => None, + }, + ), + )), + graphql_type: None, + }, + ) + }, + )); + + let response = CapabilitiesResponse { + capabilities: Capabilities { + comparisons: Some(ComparisonCapabilities { + subquery: Some(SubqueryComparisonCapabilities { + supports_relations: v3_capabilities + .capabilities + .query + .as_ref() + .map(|capabilities| capabilities.relation_comparisons.is_some()), + }), + }), + data_schema: None, + datasets: None, + explain: v3_capabilities.capabilities.explain.to_owned(), + interpolated_queries: None, + licensing: None, + metrics: None, + mutations: v3_capabilities + .capabilities + .mutations + .as_ref() + .map(|capabilities| MutationCapabilities { + atomicity_support_level: None, + delete: None, + insert: None, + returning: capabilities.returning.to_owned(), + update: None, + }), + queries: v3_capabilities + .capabilities + .query + .as_ref() + .map(|capabilities| QueryCapabilities { + foreach: capabilities.foreach.to_owned(), + }), + raw: None, + relationships: v3_capabilities.capabilities.relationships.to_owned(), + scalar_types: Some(scalar_types), + subscriptions: None, + user_defined_functions: None, + post_schema: Some(json!({})), + }, + config_schemas: get_openapi_config_schema_response(), + display_name: None, + release_name: Some(v3_capabilities.versions.to_owned()), + }; + + Ok(Json(response)) +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SourceConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub service_token_secret: Option, +} + +fn get_openapi_config_schema_response() -> ConfigSchemaResponse { + // note: we should probably have some config for auth, will do later + let config_schema_json = json!({ + "type": "object", + "nullable": false, + "properties": { + "service_token_secret": { + "title": "Service Token Secret", + "description": "Service Token Secret, required if your connector is configured with a secret.", + "nullable": true, + "type": "string" + } + }, + "required": ["service_token_secret"] + }); + + 
ConfigSchemaResponse { + config_schema: serde_json::from_value(config_schema_json) + .expect("json value should be valid OpenAPI schema"), + other_schemas: serde_json::from_str("{}").expect("static string should be valid json"), + } +} + +pub async fn post_schema( + State(state): State>, + request: Option>, +) -> Result, (StatusCode, Json)> { + let v3_schema = C::get_schema(&state.configuration).await.map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + details: None, + message: err.to_string(), + r#type: None, + }), + ) + })?; + let schema = map_schema(v3_schema).map_err(|err| (StatusCode::BAD_REQUEST, Json(err)))?; + + let schema = if let Some(request) = request { + let SchemaResponse { + object_types, + tables, + functions, + } = schema; + + let tables = if let Some(requested_tables) = request + .filters + .as_ref() + .and_then(|filters| filters.only_tables.as_ref()) + { + tables + .into_iter() + .filter(|table| { + requested_tables + .iter() + .any(|requested_table| requested_table == &table.name) + }) + .collect() + } else { + tables + }; + + let tables = match request.detail_level { + Some(DetailLevel::BasicInfo) => tables + .into_iter() + .map(|table| TableInfo { + columns: None, + deletable: None, + description: None, + foreign_keys: None, + insertable: None, + name: table.name, + primary_key: None, + r#type: table.r#type, + updatable: None, + }) + .collect(), + _ => tables, + }; + + let functions = if let Some(requested_functions) = request + .filters + .as_ref() + .and_then(|filters| filters.only_functions.as_ref()) + { + functions.map(|functions| { + functions + .into_iter() + .filter(|function| { + requested_functions + .iter() + .any(|requested_function| requested_function == &function.name) + }) + .collect() + }) + } else { + functions + }; + + let functions = match request.detail_level { + Some(DetailLevel::BasicInfo) => functions.map(|functions| { + functions + .into_iter() + .map(|function| FunctionInfo { + args: None, 
+ description: None, + name: function.name, + response_cardinality: None, + returns: None, + r#type: function.r#type, + }) + .collect() + }), + _ => functions, + }; + + SchemaResponse { + object_types, + tables, + functions, + } + } else { + schema + }; + + Ok(Json(schema)) +} + +fn map_schema(schema: models::SchemaResponse) -> Result { + let tables = schema + .collections + .iter() + .map(|collection| { + let table_type = schema + .object_types + .get(&collection.collection_type) + .ok_or_else(|| ErrorResponse { + details: None, + message: format!( + "Could not find type {} for table {}", + collection.collection_type, collection.name + ), + r#type: Some(ErrorResponseType::UncaughtError), + })?; + let columns = table_type + .fields + .iter() + .map(|(field_name, field_info)| { + Ok(ColumnInfo { + name: field_name.to_owned(), + r#type: get_field_type(&field_info.r#type, &schema), + nullable: matches!(field_info.r#type, models::Type::Nullable { .. }), + description: field_info.description.to_owned(), + insertable: collection + .insertable_columns + .as_ref() + .map(|insertable_columns| insertable_columns.contains(field_name)), + updatable: collection + .updatable_columns + .as_ref() + .map(|updatable_columns| updatable_columns.contains(field_name)), + value_generated: None, + }) + }) + .collect::, _>>()?; + Ok(TableInfo { + name: vec![collection.name.to_owned()], + description: collection.description.to_owned(), + insertable: collection + .insertable_columns + .as_ref() + .map(|insertable_columns| !insertable_columns.is_empty()), + updatable: collection + .updatable_columns + .as_ref() + .map(|updatable_columns| !updatable_columns.is_empty()), + deletable: Some(collection.deletable), + primary_key: None, + foreign_keys: None, + r#type: None, + columns: Some(columns), + }) + }) + .collect::, _>>()?; + + let object_types = schema + .object_types + .iter() + .map(|(object_name, object_definition)| { + Ok(ObjectTypeDefinition { + name: object_name.to_owned(), + 
description: object_definition.description.to_owned(), + columns: object_definition + .fields + .iter() + .map(|(field_name, field_definition)| ColumnInfo { + description: field_definition.description.to_owned(), + insertable: None, + name: field_name.to_owned(), + nullable: matches!(field_definition.r#type, models::Type::Nullable { .. }), + r#type: get_field_type(&field_definition.r#type, &schema), + updatable: None, + value_generated: None, + }) + .collect(), + }) + }) + .collect::, _>>()?; + + Ok(SchemaResponse { + tables, + object_types: Some(object_types), + functions: None, + }) +} + +fn get_field_type(column_type: &models::Type, schema: &models::SchemaResponse) -> ColumnType { + match column_type { + models::Type::Named { name } => { + if schema.object_types.contains_key(name) { + ColumnType::ColumnTypeNonScalar(gdc_rust_types::ColumnTypeNonScalar::Object { + name: name.to_owned(), + }) + } else { + // silently assuming scalar if not object type + ColumnType::Scalar(name.to_owned()) + } + } + models::Type::Nullable { underlying_type } => get_field_type(underlying_type, schema), + models::Type::Array { element_type } => { + ColumnType::ColumnTypeNonScalar(gdc_rust_types::ColumnTypeNonScalar::Array { + element_type: Box::new(get_field_type(element_type, schema)), + nullable: matches!(**element_type, models::Type::Nullable { .. 
}), + }) + } + } +} + +pub async fn post_query( + State(state): State>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + let request = map_query_request(request).map_err(|err| (StatusCode::BAD_REQUEST, Json(err)))?; + let response = C::query(&state.configuration, &state.state, request) + .await + .map_err(|err| match err { + QueryError::InvalidRequest(message) | QueryError::UnsupportedOperation(message) => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + details: None, + message, + r#type: None, + }), + ), + QueryError::Other(err) => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + details: None, + message: err.to_string(), + r#type: None, + }), + ), + })?; + let response = map_query_response(response); + Ok(Json(response)) +} + +pub async fn post_explain( + State(state): State>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + let v2_ir_json = serde_json::to_string(&request).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + details: None, + message: format!("Error serializing v2 IR to JSON: {}", err), + r#type: None, + }), + ) + })?; + let request = map_query_request(request).map_err(|err| (StatusCode::BAD_REQUEST, Json(err)))?; + + let v3_ir_json = serde_json::to_string(&request).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + details: None, + message: format!("Error serializing v3 IR to JSON: {}", err), + r#type: None, + }), + ) + })?; + let response = C::explain(&state.configuration, &state.state, request.clone()) + .await + .map_err(|err| match err { + ExplainError::InvalidRequest(message) | ExplainError::UnsupportedOperation(message) => { + ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + details: None, + message, + r#type: None, + }), + ) + } + ExplainError::Other(err) => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + details: None, + message: err.to_string(), + r#type: None, + }), + ), + })?; + + let response = ExplainResponse { + lines: vec![ + 
// --- tail of the explain handler: assemble the human-readable plan lines.
// (The enclosing function's header lies above this chunk.)
            "v2 IR".to_string(),
            v2_ir_json,
            "v3 IR".to_string(),
            v3_ir_json,
        ]
        .into_iter()
        .chain(
            response
                .details
                .into_iter()
                .map(|(key, value)| format!("{key}: {value}")),
        )
        .collect(),
        query: "".to_string(),
    };
    Ok(Json(response))
}

/// Translate a v2 (GDC) `QueryRequest` into an NDC v3 `models::QueryRequest`.
///
/// * `foreach` rows become NDC `variables`, plus a synthetic equality
///   predicate per foreach column (AND-ed together when there is more than
///   one column) that is pushed down into the query's predicate.
/// * v2 table relationships are flattened into the v3
///   `collection_relationships` map, keyed as `"<source table>.<rel name>"`.
///
/// Interpolated queries are not supported and are ignored here (the field is
/// explicitly discarded in the destructuring).
fn map_query_request(request: QueryRequest) -> Result<models::QueryRequest, ErrorResponse> {
    let QueryRequest {
        foreach,
        target,
        relationships,
        query,
        interpolated_queries: _,
    } = request;

    // Build the foreach predicate from the FIRST row only: every row in a
    // v2 foreach has the same set of keys, and the actual values travel via
    // `variables` below.
    let foreach_expr = foreach
        .as_ref()
        .and_then(|foreach| foreach.first())
        .and_then(|first_row| {
            let mut expressions: Vec<_> = first_row
                .keys()
                .map(|key| models::Expression::BinaryComparisonOperator {
                    column: models::ComparisonTarget::Column {
                        name: key.to_owned(),
                        path: vec![],
                    },
                    operator: models::BinaryComparisonOperator::Equal,
                    value: models::ComparisonValue::Variable {
                        name: key.to_owned(),
                    },
                })
                .collect();

            if expressions.len() > 1 {
                Some(models::Expression::And { expressions })
            } else {
                // Zero or one expression: avoid a needless single-child `And`.
                expressions.pop()
            }
        });

    // Each foreach row becomes one NDC variable set.
    let variables = foreach.map(|foreach| {
        foreach
            .into_iter()
            .map(|map| BTreeMap::from_iter(map.into_iter().map(|(key, value)| (key, value.value))))
            .collect()
    });

    let (collection, arguments) = get_collection_and_arguments(&target)?;

    let collection_relationships = BTreeMap::from_iter(
        relationships
            .iter()
            .map(|source_table| {
                let collection = get_name(&source_table.source_table)?;
                source_table
                    .relationships
                    .iter()
                    .map(move |(relationship_name, relationship_info)| {
                        let Relationship {
                            column_mapping,
                            relationship_type,
                            target,
                        } = relationship_info;
                        let (target_collection, arguments) =
                            get_collection_and_relationship_arguments(target)?;
                        Ok((
                            // Qualify the relationship name with its source
                            // table so names stay unique across tables.
                            format!("{}.{}", collection, relationship_name),
                            models::Relationship {
                                column_mapping: BTreeMap::from_iter(column_mapping.clone()),
                                relationship_type: match relationship_type {
                                    RelationshipType::Object => models::RelationshipType::Object,
                                    RelationshipType::Array => models::RelationshipType::Array,
                                },
                                // Same value as `get_name(&source_table.source_table)?`,
                                // already computed above — no need to re-parse.
                                source_collection_or_type: collection.clone(),
                                target_collection,
                                arguments,
                            },
                        ))
                    })
                    .collect::<Result<Vec<_>, _>>()
            })
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .flatten(),
    );

    Ok(models::QueryRequest {
        collection: collection.clone(),
        arguments,
        variables,
        query: map_query(query, &collection, &relationships, foreach_expr)?,
        collection_relationships,
    })
}

/// Translate a v2 `Query` into an NDC `models::Query`.
///
/// `collection` is the name of the collection this (sub)query targets (used
/// to qualify relationship names), `relationships` is the full v2
/// relationship table, and `foreach_expr` is an optional predicate derived
/// from a v2 foreach request, AND-ed with the query's own `where` clause.
///
/// v2 allows separate limits for fields and aggregates; v3 has a single
/// limit, so any combination that cannot be represented exactly is rejected.
fn map_query(
    query: Query,
    collection: &String,
    relationships: &Vec<TableRelationships>,
    foreach_expr: Option<models::Expression>,
) -> Result<models::Query, ErrorResponse> {
    let Query {
        aggregates,
        aggregates_limit,
        fields,
        limit,
        offset,
        order_by,
        r#where,
    } = query;

    let order_by = order_by
        .map(|order_by| {
            let OrderBy {
                elements,
                relations,
            } = order_by;
            Ok(models::OrderBy {
                elements: elements
                    .into_iter()
                    .map(|element| {
                        let OrderByElement {
                            order_direction,
                            target,
                            target_path,
                        } = element;

                        let element = models::OrderByElement {
                            order_direction: match order_direction {
                                OrderDirection::Asc => models::OrderDirection::Asc,
                                OrderDirection::Desc => models::OrderDirection::Desc,
                            },
                            target: match target {
                                OrderByTarget::StarCountAggregate {} => {
                                    models::OrderByTarget::StarCountAggregate {
                                        path: map_order_by_path(
                                            target_path,
                                            relations.to_owned(),
                                            collection,
                                            relationships,
                                        )?,
                                    }
                                }
                                OrderByTarget::SingleColumnAggregate {
                                    column,
                                    function,
                                    result_type: _,
                                } => models::OrderByTarget::SingleColumnAggregate {
                                    column,
                                    function,
                                    path: map_order_by_path(
                                        target_path,
                                        relations.to_owned(),
                                        collection,
                                        relationships,
                                    )?,
                                },
                                OrderByTarget::Column { column } => models::OrderByTarget::Column {
                                    name: get_col_name(&column)?,
                                    path: map_order_by_path(
                                        target_path,
                                        relations.to_owned(),
                                        collection,
                                        relationships,
                                    )?,
                                },
                            },
                        };
                        Ok(element)
                    })
                    .collect::<Result<Vec<_>, _>>()?,
            })
        })
        .transpose()?;

    let aggregates = aggregates.map(|aggregates| {
        IndexMap::from_iter(aggregates.into_iter().map(|(key, aggregate)| {
            (
                key,
                match aggregate {
                    Aggregate::ColumnCount { column, distinct } => {
                        models::Aggregate::ColumnCount { column, distinct }
                    }
                    Aggregate::SingleColumn {
                        column,
                        function,
                        result_type: _,
                    } => models::Aggregate::SingleColumn { column, function },
                    Aggregate::StarCount {} => models::Aggregate::StarCount {},
                },
            )
        }))
    });

    let fields = fields
        .map(|fields| {
            let fields = fields
                .into_iter()
                .map(|(key, field)| {
                    Ok((
                        key,
                        match field {
                            Field::Column {
                                column,
                                column_type: _,
                            } => models::Field::Column { column },
                            Field::Relationship {
                                query,
                                relationship,
                            } => {
                                let (target_collection, arguments) =
                                    get_relationship_collection_arguments(
                                        collection,
                                        &relationship,
                                        relationships,
                                    )?;

                                models::Field::Relationship {
                                    // Sub-queries never carry the foreach
                                    // predicate; it applies at the top level only.
                                    query: Box::new(map_query(
                                        query,
                                        &target_collection,
                                        relationships,
                                        None,
                                    )?),
                                    relationship: format!("{}.{}", collection, relationship),
                                    arguments,
                                }
                            }
                            Field::Object { .. } => {
                                return Err(ErrorResponse {
                                    details: None,
                                    message: "Object fields not supported".to_string(),
                                    r#type: None,
                                })
                            }
                            Field::Array { .. } => {
                                return Err(ErrorResponse {
                                    details: None,
                                    message: "Array fields not supported".to_string(),
                                    r#type: None,
                                })
                            }
                        },
                    ))
                })
                .collect::<Result<Vec<_>, _>>()?
                .into_iter();
            Ok(IndexMap::from_iter(fields))
        })
        .transpose()?;

    // v3 has one limit; only accept v2 limit combinations that translate
    // losslessly.
    let applicable_limit = match (limit, aggregates_limit) {
        (None, None) => None,
        (None, Some(aggregates_limit)) => {
            if fields.is_none() {
                Some(aggregates_limit)
            } else {
                return Err(ErrorResponse {
                    details: None,
                    message:
                        "Setting limit for aggregates when fields also requested is not supported"
                            .to_string(),
                    r#type: None,
                });
            }
        }
        (Some(limit), None) => {
            if aggregates.is_none() {
                Some(limit)
            } else {
                return Err(ErrorResponse {
                    details: None,
                    message:
                        "Setting limit for fields when aggregates also requested is not supported"
                            .to_string(),
                    r#type: None,
                });
            }
        }
        (Some(_), Some(_)) => {
            return Err(ErrorResponse {
                details: None,
                message: "Different limits for aggregates and fields not supported".to_string(),
                r#type: None,
            })
        }
    };

    let limit = applicable_limit
        .map(|limit| {
            limit.try_into().map_err(|_| ErrorResponse {
                details: None,
                message: "Limit must be valid u32".to_string(),
                r#type: None,
            })
        })
        .transpose()?;

    let offset = offset
        .map(|offset| {
            offset.try_into().map_err(|_| ErrorResponse {
                details: None,
                message: "Offset must be valid u32".to_string(),
                r#type: None,
            })
        })
        .transpose()?;

    let predicate = r#where
        .map(|r#where| map_expression(&r#where, collection, relationships))
        .transpose()?;

    // Combine the query's own predicate with the foreach predicate, if any.
    let predicate = match (predicate, foreach_expr) {
        (None, None) => None,
        (None, Some(foreach_expr)) => Some(foreach_expr),
        (Some(predicate), None) => Some(predicate),
        (Some(predicate), Some(foreach_expr)) => Some(models::Expression::And {
            expressions: vec![predicate, foreach_expr],
        }),
    };

    Ok(models::Query {
        aggregates,
        fields,
        limit,
        offset,
        order_by,
        predicate,
    })
}

/// Translate a v2 order-by relationship path into a v3 `PathElement` chain,
/// walking the nested `relations` map one segment at a time and qualifying
/// each relationship name with the table it starts from.
fn map_order_by_path(
    path: Vec<String>,
    relations: IndexMap<String, OrderByRelation>,
    collection: &String,
    relationships: &Vec<TableRelationships>,
) -> Result<Vec<models::PathElement>, ErrorResponse> {
    let mut mapped_path: Vec<models::PathElement> = vec![];

    let mut relations = relations;
    let mut source_table = collection.to_owned();
    for segment in path {
        let relation = relations.get(&segment).ok_or_else(|| ErrorResponse {
            details: None,
            message: format!("could not find order by relationship for path segment {segment}"),
            r#type: None,
        })?;

        let (target_table, arguments) =
            get_relationship_collection_arguments(&source_table, &segment, relationships)?;

        mapped_path.push(models::PathElement {
            relationship: format!("{}.{}", source_table, segment),
            arguments,
            predicate: if let Some(predicate) = &relation.r#where {
                Box::new(map_expression(predicate, &target_table, relationships)?)
            } else {
                // hack: predicate is not optional, so default to empty "And"
                // expression, which evaluates to true.
                Box::new(models::Expression::And {
                    expressions: vec![],
                })
            },
        });

        // Descend: the next segment is resolved against this relation's
        // target table and its sub-relations.
        source_table = target_table;
        relations = relation.subrelations.to_owned();
    }

    Ok(mapped_path)
}

/// Look up `relationship` on `source_table_name` in the v2 relationship
/// table and return the target collection name plus its relationship
/// arguments.
fn get_relationship_collection_arguments(
    source_table_name: &str,
    relationship: &str,
    table_relationships: &[TableRelationships],
) -> Result<(String, BTreeMap<String, models::RelationshipArgument>), ErrorResponse> {
    let source_table = table_relationships
        .iter()
        .find(
            // v2 table names are arrays; only single-element names are supported.
            |table_relationships| matches!(table_relationships.source_table.as_slice(), [name] if source_table_name == name),
        )
        .ok_or_else(|| ErrorResponse {
            details: None,
            message: format!("Could not find table {source_table_name} in relationships"),
            r#type: None,
        })?;

    let relationship = source_table
        .relationships
        .get(relationship)
        .ok_or_else(|| ErrorResponse {
            details: None,
            message: format!(
                "Could not find relationship {relationship} in table {source_table_name}"
            ),
            r#type: None,
        })?;

    get_collection_and_relationship_arguments(&relationship.target)
}

/// Recursively translate a v2 filter `Expression` into the v3 model.
///
/// v2 comparison operators without a direct v3 counterpart are mapped to
/// `BinaryComparisonOperator::Other` with a snake_case name; unknown unary
/// and binary-array operators are rejected.
fn map_expression(
    expression: &Expression,
    collection: &str,
    relationships: &Vec<TableRelationships>,
) -> Result<models::Expression, ErrorResponse> {
    Ok(match expression {
        Expression::And { expressions } => models::Expression::And {
            expressions: expressions
                .iter()
                .map(|expression| map_expression(expression, collection, relationships))
                .collect::<Result<Vec<_>, _>>()?,
        },
        Expression::Or { expressions } => models::Expression::Or {
            expressions: expressions
                .iter()
                .map(|expression| map_expression(expression, collection, relationships))
                .collect::<Result<Vec<_>, _>>()?,
        },
        Expression::Not { expression } => models::Expression::Not {
            expression: Box::new(map_expression(expression, collection, relationships)?),
        },
        Expression::ApplyUnaryComparison { column, operator } => {
            models::Expression::UnaryComparisonOperator {
                column: map_comparison_column(column)?,
                operator: match operator {
                    UnaryComparisonOperator::IsNull => models::UnaryComparisonOperator::IsNull,
                    UnaryComparisonOperator::Other(operator) => {
                        return Err(ErrorResponse {
                            details: None,
                            message: format!("Unknown unary comparison operator {operator}"),
                            r#type: None,
                        })
                    }
                },
            }
        }
        Expression::ApplyBinaryComparison {
            column,
            operator,
            value,
        } => models::Expression::BinaryComparisonOperator {
            column: map_comparison_column(column)?,
            operator: match operator {
                BinaryComparisonOperator::LessThan => models::BinaryComparisonOperator::Other {
                    name: "less_than".to_string(),
                },
                BinaryComparisonOperator::LessThanOrEqual => {
                    models::BinaryComparisonOperator::Other {
                        name: "less_than_or_equal".to_string(),
                    }
                }
                BinaryComparisonOperator::Equal => models::BinaryComparisonOperator::Equal,
                BinaryComparisonOperator::GreaterThan => models::BinaryComparisonOperator::Other {
                    name: "greater_than".to_string(),
                },
                BinaryComparisonOperator::GreaterThanOrEqual => {
                    models::BinaryComparisonOperator::Other {
                        name: "greater_than_or_equal".to_string(),
                    }
                }
                BinaryComparisonOperator::Other(operator) => {
                    models::BinaryComparisonOperator::Other {
                        name: operator.to_owned(),
                    }
                }
            },
            value: match value {
                ComparisonValue::Scalar {
                    value,
                    value_type: _,
                } => models::ComparisonValue::Scalar {
                    value: value.clone(),
                },
                ComparisonValue::Column { column } => models::ComparisonValue::Column {
                    column: map_comparison_column(column)?,
                },
            },
        },
        Expression::ApplyBinaryArrayComparison {
            column,
            operator,
            value_type: _,
            values,
        } => models::Expression::BinaryArrayComparisonOperator {
            column: map_comparison_column(column)?,
            operator: match operator {
                BinaryArrayComparisonOperator::In => models::BinaryArrayComparisonOperator::In,
                BinaryArrayComparisonOperator::Other(operator) => {
                    return Err(ErrorResponse {
                        details: None,
                        message: format!("Unknown binary array comparison operator {operator}"),
                        r#type: None,
                    })
                }
            },
            values: values
                .iter()
                .map(|value| models::ComparisonValue::Scalar {
                    value: value.clone(),
                })
                .collect(),
        },
        Expression::Exists { in_table, r#where } => match in_table {
            ExistsInTable::Unrelated { table } => {
                // Resolve the table name once; it serves both as the target
                // collection and as the scope for the nested predicate.
                let unrelated_collection = get_name(table)?;
                models::Expression::Exists {
                    in_collection: models::ExistsInCollection::Unrelated {
                        collection: unrelated_collection.clone(),
                        arguments: BTreeMap::new(),
                    },
                    predicate: Box::new(map_expression(
                        r#where,
                        &unrelated_collection,
                        relationships,
                    )?),
                }
            }
            ExistsInTable::Related { relationship } => {
                let (target_table, arguments) =
                    get_relationship_collection_arguments(collection, relationship, relationships)?;

                models::Expression::Exists {
                    in_collection: models::ExistsInCollection::Related {
                        relationship: format!("{}.{}", collection, relationship),
                        arguments,
                    },
                    predicate: Box::new(map_expression(r#where, &target_table, relationships)?),
                }
            }
        },
    })
}

/// Translate a v2 comparison column into a v3 `ComparisonTarget`.
///
/// An empty/absent path means a column on the current collection; the single
/// segment `"$"` means a column on the root (closest query target);
/// anything else is unsupported.
fn map_comparison_column(
    column: &ComparisonColumn,
) -> Result<models::ComparisonTarget, ErrorResponse> {
    match &column.path.as_deref() {
        Some([]) | None => Ok(models::ComparisonTarget::Column {
            name: get_col_name(&column.name)?,
            path: vec![],
        }),
        Some([p]) if p == "$" => Ok(models::ComparisonTarget::RootCollectionColumn {
            name: get_col_name(&column.name)?,
        }),
        Some(path) => Err(ErrorResponse {
            details: None,
            message: format!("Valid values for path are empty array, or array with $ reference to closest query target. Got {}", path.join(".")),
            r#type: None,
        }),
    }
}

/// Translate a v3 query response back into a v2 one.
///
/// A single row set maps to `QueryResponse::Single`; anything else (including
/// zero row sets) is treated as a foreach response with one entry per row set.
fn map_query_response(models::QueryResponse(response): models::QueryResponse) -> QueryResponse {
    if response.len() == 1 {
        QueryResponse::Single(get_reponse_row(
            response
                .into_iter()
                .next()
                .expect("we just checked there is exactly one element"),
        ))
    } else {
        QueryResponse::ForEach {
            rows: response
                .into_iter()
                .map(|row| ForEachRow {
                    query: get_reponse_row(row),
                })
                .collect(),
        }
    }
}

/// Convert one v3 `RowSet` into a v2 `ResponseRow`, wrapping every field
/// value as a plain column value.
fn get_reponse_row(row: models::RowSet) -> ResponseRow {
    ResponseRow {
        aggregates: row.aggregates,
        rows: row.rows.map(|rows| {
            rows.into_iter()
                .map(|row| {
                    IndexMap::from_iter(row.into_iter().map(
                        |(alias, models::RowFieldValue(value))| {
                            (alias, ResponseFieldValue::Column(value))
                        },
                    ))
                })
                .collect()
        }),
    }
}

/// Resolve a v2 query `Target` to a collection name plus literal query
/// arguments. Tables take no arguments; interpolated queries are rejected.
fn get_collection_and_arguments(
    target: &Target,
) -> Result<(String, BTreeMap<String, models::Argument>), ErrorResponse> {
    match target {
        Target::Table { name } => Ok((get_name(name)?, BTreeMap::new())),
        Target::Interpolated { .. } => Err(ErrorResponse {
            details: None,
            message: "Interpolated queries not supported".to_string(),
            r#type: None,
        }),
        Target::Function { name, arguments } => Ok((
            get_name(name)?,
            BTreeMap::from_iter(arguments.iter().map(|argument| match argument {
                gdc_rust_types::FunctionRequestArgument::Named { name, value } => (
                    name.to_owned(),
                    models::Argument::Literal {
                        value: match value {
                            gdc_rust_types::ArgumentValue::Scalar {
                                value,
                                value_type: _,
                            } => value.to_owned(),
                        },
                    },
                ),
            })),
        )),
    }
}

/// Same as [`get_collection_and_arguments`], but producing
/// `RelationshipArgument`s for use inside relationship definitions.
fn get_collection_and_relationship_arguments(
    target: &Target,
) -> Result<(String, BTreeMap<String, models::RelationshipArgument>), ErrorResponse> {
    match target {
        Target::Table { name } => Ok((get_name(name)?, BTreeMap::new())),
        Target::Interpolated { .. } => Err(ErrorResponse {
            details: None,
            message: "Interpolated queries not supported".to_string(),
            r#type: None,
        }),
        Target::Function { name, arguments } => Ok((
            get_name(name)?,
            BTreeMap::from_iter(arguments.iter().map(|argument| match argument {
                gdc_rust_types::FunctionRequestArgument::Named { name, value } => (
                    name.to_owned(),
                    models::RelationshipArgument::Literal {
                        value: match value {
                            gdc_rust_types::ArgumentValue::Scalar {
                                value,
                                value_type: _,
                            } => value.to_owned(),
                        },
                    },
                ),
            })),
        )),
    }
}

/// Extract the single-element name from a v2 array-form table/function name.
/// Multi-segment (schema-qualified) names are not supported.
fn get_name(target: &Vec<String>) -> Result<String, ErrorResponse> {
    match target.as_slice() {
        [name] => Ok(name.to_owned()),
        _ => Err(ErrorResponse {
            details: None,
            message: format!(
                "Expected function name to be array with exactly one string member, got {}",
                target.join(".")
            ),
            r#type: None,
        }),
    }
}

/// Extract a plain column name from a v2 `ColumnSelector`; compound
/// (nested-path) selectors are not supported.
fn get_col_name(column: &ColumnSelector) -> Result<String, ErrorResponse> {
    match column {
        ColumnSelector::Compound(name) => Err(ErrorResponse {
            details: None,
            message: format!(
                "Compound column selectors not supported, got {}",
                name.join(".")
            ),
            r#type: None,
        }),
        ColumnSelector::Name(name) => Ok(name.to_owned()),
    }
}