From 0012e9f111ce0c196c13bf10e12ca07db726867d Mon Sep 17 00:00:00 2001
From: dragonliu
Date: Tue, 5 Nov 2024 10:29:54 +0800
Subject: [PATCH 1/5] support redis dict batch
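
Replace the per-row `dict_get` round trip with a batched lookup:
`Scalar` implements `FromRedisValue` so redis-rs can decode replies
directly, the Redis dictionary operator holds a `redis::Connection`
instead of an OpenDAL operator, and a new `dict_get_batch` fetches a
whole key column at once.

For reference, a minimal standalone sketch of the batched read this
relies on (not part of the patch; it assumes a reachable server at
redis://127.0.0.1/ and uses the synchronous API of this stage, and
`fetch_batch` is a hypothetical name):

    use redis::{Client, Commands, RedisResult};

    fn fetch_batch(keys: &[&str]) -> RedisResult<Vec<Option<String>>> {
        let client = Client::open("redis://127.0.0.1/")?;
        let mut con = client.get_connection()?;
        // For a list of keys this becomes a single MGET; decoding into
        // Option<String> keeps missing keys as None instead of failing
        // the whole conversion.
        con.get(keys)
    }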
---
 Cargo.lock                              | 20 +++-
 Cargo.toml                              |  1 +
 src/query/expression/Cargo.toml         |  1 +
 src/query/expression/src/values.rs      | 24 +++++++
 src/query/service/Cargo.toml            |  1 +
 .../transforms/transform_dictionary.rs  | 71 +++++++++++++------
 6 files changed, 95 insertions(+), 23 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 80aeebcb0389..3dccf38c4d1a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3514,6 +3514,7 @@ dependencies = [
  "pretty_assertions",
  "rand 0.8.5",
  "recursive",
+ "redis 0.27.5 (git+https://github.com/redis-rs/redis-rs.git)",
  "rmp-serde",
  "roaring",
  "rust_decimal",
@@ -5294,6 +5295,7 @@ dependencies = [
  "prost 0.12.6",
  "rand 0.8.5",
  "recursive",
+ "redis 0.27.5 (git+https://github.com/redis-rs/redis-rs.git)",
  "regex",
  "reqwest",
  "rmp-serde",
@@ -10865,7 +10867,7 @@ dependencies = [
  "prometheus-client",
  "prost 0.13.1",
  "quick-xml 0.36.1",
- "redis 0.27.5",
+ "redis 0.27.5 (registry+https://github.com/rust-lang/crates.io-index)",
  "reqsign",
  "reqwest",
  "serde",
@@ -12741,6 +12743,22 @@ dependencies = [
  "url",
 ]
 
+[[package]]
+name = "redis"
+version = "0.27.5"
+source = "git+https://github.com/redis-rs/redis-rs.git#3642621d3a73330076b8ab9af28fa491716a38dd"
+dependencies = [
+ "arc-swap",
+ "combine",
+ "itoa",
+ "num-bigint",
+ "percent-encoding",
+ "ryu",
+ "sha1_smol",
+ "socket2 0.5.7",
+ "url",
+]
+
 [[package]]
 name = "redox_syscall"
 version = "0.2.16"

diff --git a/Cargo.toml b/Cargo.toml
index 679998b3c7e2..283d6e6e5524 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -293,6 +293,7 @@ prometheus-client = "0.22"
 prost = { version = "0.12.1" }
 prost-build = { version = "0.12.1" }
 rand = { version = "0.8.5", features = ["small_rng"] }
+redis = { version = "0.27.5", git = "https://github.com/redis-rs/redis-rs.git" }
 regex = "1.8.1"
 reqwest = { version = "0.12", default-features = false, features = [
     "json",

diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml
index 08d4232286c4..5fde9d8551ad 100644
--- a/src/query/expression/Cargo.toml
+++ b/src/query/expression/Cargo.toml
@@ -49,6 +49,7 @@ num-bigint = "0.4.6"
 num-traits = "0.2.15"
 rand = { workspace = true }
 recursive = "0.1.1"
+redis = { workspace = true }
 roaring = { version = "0.10.1", features = ["serde"] }
 rust_decimal = "1.26"
 serde = { workspace = true }

diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs
index 0196bab4a955..c9a6c5662e1d 100755
--- a/src/query/expression/src/values.rs
+++ b/src/query/expression/src/values.rs
@@ -37,6 +37,8 @@ use geo::Point;
 use geozero::CoordDimensions;
 use geozero::ToWkb;
 use itertools::Itertools;
+use redis::FromRedisValue;
+use redis::RedisResult;
 use roaring::RoaringTreemap;
 use serde::de::Visitor;
 use serde::Deserialize;
@@ -764,6 +766,28 @@ impl PartialEq for Scalar {
     }
 }
 
+impl FromRedisValue for Scalar {
+    fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
+        match v {
+            redis::Value::BulkString(bs) => {
+                let str = unsafe { String::from_utf8_unchecked(bs.to_vec()) };
+                Ok(Scalar::String(str))
+            }
+            redis::Value::Array(arr) => {
+                // How should the builder be configured? Is hard-coding the DataType OK?
+                let mut builder = ColumnBuilder::with_capacity(&DataType::String, 1);
+                for item in arr {
+                    let scalar = Scalar::from_redis_value(item)?;
+                    builder.push(scalar.as_ref());
+                }
+                Ok(Scalar::Array(builder.build()))
+            }
+            redis::Value::Nil => Ok(Scalar::default_value(&DataType::String)),
+            _ => unreachable!(),
+        }
+    }
+}
+
 impl<'a, 'b> PartialOrd<ScalarRef<'b>> for ScalarRef<'a> {
     fn partial_cmp(&self, other: &ScalarRef<'b>) -> Option<Ordering> {
         match (self, other) {

diff --git a/src/query/service/Cargo.toml b/src/query/service/Cargo.toml
index 4ce4711dff6f..678ec2b6b0cf 100644
--- a/src/query/service/Cargo.toml
+++ b/src/query/service/Cargo.toml
@@ -153,6 +153,7 @@ poem = { workspace = true }
 prost = { workspace = true }
 rand = { workspace = true }
 recursive = "0.1.1"
+redis = { workspace = true }
 regex = { workspace = true }
 reqwest = { workspace = true }
 rustls = "0.22"

diff --git a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
index 7082a4429be6..193ecf441b93 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
@@ -24,8 +24,10 @@ use databend_common_expression::types::DataType;
 use databend_common_expression::types::Number;
 use databend_common_expression::types::NumberDataType;
 use databend_common_expression::types::NumberScalar;
+use databend_common_expression::types::StringColumn;
 use databend_common_expression::with_integer_mapped_type;
 use databend_common_expression::BlockEntry;
+use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::DataBlock;
 use databend_common_expression::Scalar;
@@ -34,6 +36,7 @@ use databend_common_expression::Value;
 use databend_common_storage::build_operator;
 use opendal::services::Redis;
 use opendal::Operator;
+use redis::Commands;
 use sqlx::MySqlPool;
 
 use crate::pipelines::processors::transforms::TransformAsyncFunction;
@@ -45,6 +48,7 @@ use crate::sql::IndexType;
 
 pub(crate) enum DictionaryOperator {
     Operator(Operator),
+    Redis(redis::Connection),
     Mysql((MySqlPool, String)),
 }
 
@@ -130,6 +134,31 @@ impl DictionaryOperator {
                     "unsupported value type {data_type}"
                 ))),
             },
+            _ => unreachable!(),
+        }
+    }
+
+    async fn dict_get_batch(
+        & self,
+        column: &Column,
+        data_type: &DataType,
+    ) -> Result<Option<Scalar>> {
+        match self {
+            DictionaryOperator::Redis(con) => match column {
+                Column::String(str_col) => {
+                    let mut keys: Vec<&str> = vec![];
+                    for i in 0..str_col.len() {
+                        keys.push(unsafe { str_col.index_unchecked(i) });
+                    }
+                    let value = con.get(keys);
+                    match value {
+                        Ok(val) => Ok(val),
+                        _ => unreachable!(),
+                    }
+                }
+                _ => Ok(None),
+            },
+            _ => unreachable!(),
         }
     }
 }
@@ -143,18 +172,18 @@ impl TransformAsyncFunction {
         if let AsyncFunctionArgument::DictGetFunction(dict_arg) = &async_func_desc.func_arg {
             match &dict_arg.dict_source {
                 DictionarySource::Redis(redis_source) => {
-                    let mut builder = Redis::default().endpoint(&redis_source.connection_url);
-                    if let Some(ref username) = redis_source.username {
-                        builder = builder.username(username);
-                    }
-                    if let Some(ref password) = redis_source.password {
-                        builder = builder.password(password);
-                    }
-                    if let Some(db_index) = redis_source.db_index {
-                        builder = builder.db(db_index);
-                    }
-                    let op = build_operator(builder)?;
-                    operators.insert(i, Arc::new(DictionaryOperator::Operator(op)));
+                    let client =
+                        redis::Client::open(redis_source.connection_url.clone()); // TODO: "redis://127.0.0.1/"
+                    let con = match client {
+                        Ok(cli) => {
+                            let connection = cli.get_connection();
+                            match connection {
+                                Ok(con) => con,
+                                _ => unreachable!(),
+                            }
+                        }
+                        _ => unreachable!(),
+                    };
+                    operators.insert(i, Arc::new(DictionaryOperator::Redis(con)));
                 }
                 DictionarySource::Mysql(sql_source) => {
                     let mysql_pool = databend_common_base::runtime::block_on(
@@ -187,22 +216,20 @@ impl TransformAsyncFunction {
             let entry = data_block.get_by_offset(arg_index);
             let value = match &entry.value {
                 Value::Scalar(scalar) => {
+                    let mut builder = ColumnBuilder::with_capacity(data_type, 1);
+                    builder.push(scalar.as_ref());
                     let value = op
-                        .dict_get(scalar.as_ref(), data_type)
+                        .dict_get_batch(&builder.build(), data_type)
                         .await?
                         .unwrap_or(dict_arg.default_value.clone());
                     Value::Scalar(value)
                 }
                 Value::Column(column) => {
-                    let mut builder = ColumnBuilder::with_capacity(data_type, column.len());
-                    for scalar_ref in column.iter() {
-                        let value = op
-                            .dict_get(scalar_ref, data_type)
-                            .await?
-                            .unwrap_or(dict_arg.default_value.clone());
-                        builder.push(value.as_ref());
-                    }
-                    Value::Column(builder.build())
+                    let value = op
+                        .dict_get_batch(column, data_type)
+                        .await?
+                        .unwrap_or(dict_arg.default_value.clone());
+                    Value::Scalar(value)
                 }
             };
             let entry = BlockEntry {

From 7f04302a5f91fe0e79bca964ee0e0d75591ea397 Mon Sep 17 00:00:00 2001
From: dragonliu
Date: Tue, 5 Nov 2024 16:41:00 +0800
Subject: [PATCH 2/5] tmp code
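
Move Redis access onto the async path: pin the crates.io release of the
redis crate with the "connection-manager" feature instead of the git
dependency, hold a cloneable `redis::aio::ConnectionManager` in the
operator, convert `RedisError` into `ErrorCode`, and emit `redis://`
URLs from the dictionary options (the previous `tcp://` scheme is not a
URL `redis::Client::open` accepts).

A small sketch of the connection handling this switches to (not part of
the patch; it assumes a Tokio runtime is already running):

    use redis::aio::ConnectionManager;
    use redis::{AsyncCommands, Client, RedisResult};

    async fn connect(url: &str) -> RedisResult<ConnectionManager> {
        let client = Client::open(url)?;
        // One multiplexed connection; clones are cheap and it
        // reconnects on failure ("connection-manager" feature).
        ConnectionManager::new(client).await
    }

    async fn lookup(mut conn: ConnectionManager, key: &str) -> RedisResult<Option<String>> {
        conn.get(key).await
    }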
---
 Cargo.lock                                       | 23 ++------
 Cargo.toml                                       |  4 +-
 src/common/exception/Cargo.toml                  |  1 +
 src/common/exception/src/exception_into.rs       |  6 ++
 src/meta/app/src/schema/dictionary.rs            |  2 +-
 .../transforms/transform_dictionary.rs           | 59 ++++++++---------
 .../sql/src/planner/plans/scalar_expr.rs         |  2 +
 .../sql/src/planner/semantic/type_check.rs       |  1 +
 8 files changed, 49 insertions(+), 49 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 3dccf38c4d1a..e039ac49d1cd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3459,6 +3459,7 @@ dependencies = [
  "parquet",
  "paste",
  "prost 0.12.6",
+ "redis 0.27.5",
  "reqwest",
  "serde",
  "serde_json",
@@ -3514,7 +3515,7 @@ dependencies = [
  "pretty_assertions",
  "rand 0.8.5",
  "recursive",
- "redis 0.27.5 (git+https://github.com/redis-rs/redis-rs.git)",
+ "redis 0.27.5",
  "rmp-serde",
  "roaring",
  "rust_decimal",
@@ -5295,7 +5296,7 @@ dependencies = [
  "prost 0.12.6",
  "rand 0.8.5",
  "recursive",
- "redis 0.27.5 (git+https://github.com/redis-rs/redis-rs.git)",
+ "redis 0.27.5",
  "regex",
  "reqwest",
  "rmp-serde",
@@ -10867,7 +10868,7 @@ dependencies = [
  "prometheus-client",
  "prost 0.13.1",
  "quick-xml 0.36.1",
- "redis 0.27.5 (registry+https://github.com/rust-lang/crates.io-index)",
+ "redis 0.27.5",
  "reqsign",
  "reqwest",
  "serde",
@@ -12743,22 +12744,6 @@ dependencies = [
  "url",
 ]
 
-[[package]]
-name = "redis"
-version = "0.27.5"
-source = "git+https://github.com/redis-rs/redis-rs.git#3642621d3a73330076b8ab9af28fa491716a38dd"
-dependencies = [
- "arc-swap",
- "combine",
- "itoa",
- "num-bigint",
- "percent-encoding",
- "ryu",
- "sha1_smol",
- "socket2 0.5.7",
- "url",
-]
-
 [[package]]
 name = "redox_syscall"
 version = "0.2.16"

diff --git a/Cargo.toml b/Cargo.toml
index 283d6e6e5524..abf254ce5efc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -293,7 +293,9 @@ prometheus-client = "0.22"
 prost = { version = "0.12.1" }
 prost-build = { version = "0.12.1" }
 rand = { version = "0.8.5", features = ["small_rng"] }
-redis = { version = "0.27.5", git = "https://github.com/redis-rs/redis-rs.git" }
+redis = { version = "0.27.5", features = [
+    "connection-manager",
+]}
 regex = "1.8.1"
 reqwest = { version = "0.12", default-features = false, features = [
     "json",

diff --git a/src/common/exception/Cargo.toml b/src/common/exception/Cargo.toml
index ac0872ba5d65..cea1d8301794 100644
--- a/src/common/exception/Cargo.toml
+++ b/src/common/exception/Cargo.toml
@@ -26,6 +26,7 @@ opendal = { workspace = true }
 parquet = { workspace = true }
 paste = { workspace = true }
 prost = { workspace = true }
+redis = { workspace = true }
 reqwest = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }

diff --git a/src/common/exception/src/exception_into.rs b/src/common/exception/src/exception_into.rs
index cb18b0e451cc..9839abf9dd90 100644
--- a/src/common/exception/src/exception_into.rs
+++ b/src/common/exception/src/exception_into.rs
@@ -447,3 +447,9 @@ impl From<sqlx::Error> for ErrorCode {
         ErrorCode::DictionarySourceError(format!("Dictionary Sqlx Error, cause: {}", error))
     }
 }
+
+impl From<redis::RedisError> for ErrorCode {
+    fn from(error: redis::RedisError) -> Self {
+        ErrorCode::DictionarySourceError(format!("Dictionary Redis Error, cause: {}", error))
+    }
+}

diff --git a/src/meta/app/src/schema/dictionary.rs b/src/meta/app/src/schema/dictionary.rs
index 99679b2a8c99..984a40b5eccf 100644
--- a/src/meta/app/src/schema/dictionary.rs
+++ b/src/meta/app/src/schema/dictionary.rs
@@ -116,7 +116,7 @@ impl DictionaryMeta {
             .options
             .get("port")
             .ok_or_else(|| ErrorCode::BadArguments("Miss option `port`"))?;
-        Ok(format!("tcp://{}:{}", host, port))
+        Ok(format!("redis://{}:{}", host, port))
     }
 }

diff --git a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
index 193ecf441b93..cc7fa87f3db4 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
@@ -36,7 +36,9 @@ use databend_common_expression::Value;
 use databend_common_storage::build_operator;
 use opendal::services::Redis;
 use opendal::Operator;
-use redis::Commands;
+use redis::aio::ConnectionManager;
+use redis::AsyncCommands;
+use redis::Client;
 use sqlx::MySqlPool;
 
 use crate::pipelines::processors::transforms::TransformAsyncFunction;
@@ -48,7 +50,7 @@ use crate::sql::IndexType;
 
 pub(crate) enum DictionaryOperator {
     Operator(Operator),
-    Redis(redis::Connection),
+    Redis(ConnectionManager),
     Mysql((MySqlPool, String)),
 }
 
@@ -139,22 +141,30 @@ impl DictionaryOperator {
     }
 
     async fn dict_get_batch(
-        & self,
+        &self,
         column: &Column,
         data_type: &DataType,
-    ) -> Result<Option<Scalar>> {
+    ) -> Result<Option<Column>> {
         match self {
-            DictionaryOperator::Redis(con) => match column {
+            DictionaryOperator::Redis(conn) => match column {
                 Column::String(str_col) => {
-                    let mut keys: Vec<&str> = vec![];
-                    for i in 0..str_col.len() {
-                        keys.push(unsafe { str_col.index_unchecked(i) });
+                    let key_cnt = str_col.len();
+                    let mut key_map = BTreeMap::new();
+                    for i in 0..key_cnt {
+                        let key = unsafe { str_col.index_unchecked(i) };
+                        let index = key_map.len();
+                        key_map.insert(key, index);
                     }
-                    let value = con.get(keys);
-                    match value {
-                        Ok(val) => Ok(val),
-                        _ => unreachable!(),
-                    }
+                    let keys = Vec::from_iter(key_map.iter().map(|(k, _)| k));
+                    let mut conn = conn.clone();
+                    let values: Vec<String> = conn.get(keys).await?;
+                    let mut builder = ColumnBuilder::with_capacity(data_type, values.len());
+                    for i in 0..key_cnt {
+                        let key = unsafe { str_col.index_unchecked(i) };
+                        let index = key_map[key];
+                        builder.push(ScalarRef::String(values[index].as_str()));
+                    }
+                    Ok(Some(builder.build()))
                 }
                 _ => Ok(None),
             },
             _ => unreachable!(),
         }
     }
 }
@@ -172,18 +182,11 @@ impl TransformAsyncFunction {
         if let AsyncFunctionArgument::DictGetFunction(dict_arg) = &async_func_desc.func_arg {
             match &dict_arg.dict_source {
                 DictionarySource::Redis(redis_source) => {
-                    let client =
-                        redis::Client::open(redis_source.connection_url.clone()); // TODO: "redis://127.0.0.1/"
-                    let con = match client {
-                        Ok(cli) => {
-                            let connection = cli.get_connection();
-                            match connection {
-                                Ok(con) => con,
-                                _ => unreachable!(),
-                            }
-                        }
-                        _ => unreachable!(),
-                    };
-                    operators.insert(i, Arc::new(DictionaryOperator::Redis(con)));
+                    let client = Client::open(redis_source.connection_url.clone())?;
+                    let conn = databend_common_base::runtime::block_on(
+                        ConnectionManager::new(client),
+                    )?;
+                    operators.insert(i, Arc::new(DictionaryOperator::Redis(conn)));
                 }
                 DictionarySource::Mysql(sql_source) => {
                     let mysql_pool = databend_common_base::runtime::block_on(
@@ -221,15 +224,15 @@ impl TransformAsyncFunction {
                     let value = op
                         .dict_get_batch(&builder.build(), data_type)
                         .await?
-                        .unwrap_or(dict_arg.default_value.clone());
-                    Value::Scalar(value)
+                        .ok_or(()).unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
+                    Value::Column(value)
                 }
                 Value::Column(column) => {
                     let value = op
                         .dict_get_batch(column, data_type)
                         .await?
-                        .unwrap_or(dict_arg.default_value.clone());
-                    Value::Scalar(value)
+                        .ok_or(()).unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
+                    Value::Column(value)
                 }
             };
             let entry = BlockEntry {

diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs
index ea4f0201cd61..2424caeb03c0 100644
--- a/src/query/sql/src/planner/plans/scalar_expr.rs
+++ b/src/query/sql/src/planner/plans/scalar_expr.rs
@@ -25,6 +25,7 @@ use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberScalar;
+use databend_common_expression::Column;
 use databend_common_expression::RemoteExpr;
 use databend_common_expression::Scalar;
 use databend_common_meta_app::schema::GetSequenceNextValueReq;
@@ -830,6 +831,7 @@ pub enum DictionarySource {
 pub struct DictGetFunctionArgument {
     pub dict_source: DictionarySource,
     pub default_value: Scalar,
+    // pub default_value: Column,
 }
 
 // Asynchronous functions are functions that need to call remote interfaces.

diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs
index 43538c30792f..f6fced80aa35 100644
--- a/src/query/sql/src/planner/semantic/type_check.rs
+++ b/src/query/sql/src/planner/semantic/type_check.rs
@@ -70,6 +70,7 @@ use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberDataType;
 use databend_common_expression::types::NumberScalar;
 use databend_common_expression::types::F32;
+use databend_common_expression::Column;
 use databend_common_expression::ColumnIndex;
 use databend_common_expression::ConstantFolder;
 use databend_common_expression::DataField;

From 71e0436c7af1d369e746558e2f6ceef3b521ed20 Mon Sep 17 00:00:00 2001
From: dragonliu
Date: Wed, 6 Nov 2024 11:29:52 +0800
Subject: [PATCH 3/5] support nullable columns and more
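
Handle nullable key columns by unwrapping to the inner string column,
deduplicate keys before the MGET so each distinct key is fetched once,
decode `redis::Value` inside the operator (the `FromRedisValue` impl on
`Scalar` is removed again), and cast nullable key expressions with
`remove_nullable` during type check.

The dedup-then-reexpand step in `get_values_from_redis`, shown as a
standalone sketch (`dedup_keys` is a hypothetical helper, not part of
the patch):

    use std::collections::BTreeMap;

    /// Returns the distinct keys in first-seen order plus, for every
    /// input row, the index of its key in that distinct list.
    fn dedup_keys<'a>(rows: &[&'a str]) -> (Vec<&'a str>, Vec<usize>) {
        let mut index_of: BTreeMap<&'a str, usize> = BTreeMap::new();
        let mut distinct: Vec<&'a str> = Vec::new();
        let mut row_to_distinct = Vec::with_capacity(rows.len());
        for &key in rows {
            let idx = *index_of.entry(key).or_insert_with(|| {
                distinct.push(key);
                distinct.len() - 1
            });
            row_to_distinct.push(idx);
        }
        (distinct, row_to_distinct)
    }

After the batched fetch, row i's value is `values[row_to_distinct[i]]`.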
---
 src/query/expression/src/values.rs               | 24 -----
 .../transforms/transform_dictionary.rs           | 94 ++++++++++++++-----
 .../sql/src/planner/plans/scalar_expr.rs         |  1 -
 .../sql/src/planner/semantic/type_check.rs       |  3 +-
 4 files changed, 73 insertions(+), 49 deletions(-)

diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs
index c9a6c5662e1d..0196bab4a955 100755
--- a/src/query/expression/src/values.rs
+++ b/src/query/expression/src/values.rs
@@ -37,8 +37,6 @@ use geo::Point;
 use geozero::CoordDimensions;
 use geozero::ToWkb;
 use itertools::Itertools;
-use redis::FromRedisValue;
-use redis::RedisResult;
 use roaring::RoaringTreemap;
 use serde::de::Visitor;
 use serde::Deserialize;
@@ -766,28 +764,6 @@ impl PartialEq for Scalar {
     }
 }
 
-impl FromRedisValue for Scalar {
-    fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
-        match v {
-            redis::Value::BulkString(bs) => {
-                let str = unsafe { String::from_utf8_unchecked(bs.to_vec()) };
-                Ok(Scalar::String(str))
-            }
-            redis::Value::Array(arr) => {
-                // How should the builder be configured? Is hard-coding the DataType OK?
-                let mut builder = ColumnBuilder::with_capacity(&DataType::String, 1);
-                for item in arr {
-                    let scalar = Scalar::from_redis_value(item)?;
-                    builder.push(scalar.as_ref());
-                }
-                Ok(Scalar::Array(builder.build()))
-            }
-            redis::Value::Nil => Ok(Scalar::default_value(&DataType::String)),
-            _ => unreachable!(),
-        }
-    }
-}
-
 impl<'a, 'b> PartialOrd<ScalarRef<'b>> for ScalarRef<'a> {
     fn partial_cmp(&self, other: &ScalarRef<'b>) -> Option<Ordering> {
         match (self, other) {

diff --git a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
index cc7fa87f3db4..a2023da29104 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
@@ -33,8 +33,6 @@ use databend_common_expression::DataBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::ScalarRef;
 use databend_common_expression::Value;
-use databend_common_storage::build_operator;
-use opendal::services::Redis;
 use opendal::Operator;
 use redis::aio::ConnectionManager;
 use redis::AsyncCommands;
@@ -146,31 +144,81 @@ impl DictionaryOperator {
         data_type: &DataType,
     ) -> Result<Option<Column>> {
         match self {
-            DictionaryOperator::Redis(conn) => match column {
-                Column::String(str_col) => {
-                    let key_cnt = str_col.len();
-                    let mut key_map = BTreeMap::new();
-                    for i in 0..key_cnt {
-                        let key = unsafe { str_col.index_unchecked(i) };
-                        let index = key_map.len();
-                        key_map.insert(key, index);
-                    }
-                    let keys = Vec::from_iter(key_map.iter().map(|(k, _)| k));
-                    let mut conn = conn.clone();
-                    let values: Vec<String> = conn.get(keys).await?;
-                    let mut builder = ColumnBuilder::with_capacity(data_type, values.len());
-                    for i in 0..key_cnt {
-                        let key = unsafe { str_col.index_unchecked(i) };
-                        let index = key_map[key];
-                        builder.push(ScalarRef::String(values[index].as_str()));
-                    }
-                    Ok(Some(builder.build()))
+            DictionaryOperator::Redis(connection) => match column {
+                Column::Nullable(box nullable_col) => match &nullable_col.column {
+                    Column::String(str_col) => {
+                        self.get_values_from_redis(str_col, connection, data_type)
+                            .await
+                    }
+                    _ => unreachable!(),
+                },
+                Column::String(str_col) => {
+                    self.get_values_from_redis(str_col, connection, data_type)
+                        .await
                 }
                 _ => Ok(None),
             },
             _ => unreachable!(),
         }
     }
+
+    async fn get_values_from_redis(
+        &self,
+        str_col: &StringColumn,
+        connection: &ConnectionManager,
+        data_type: &DataType,
+    ) -> Result<Option<Column>> {
+        let key_cnt = str_col.len();
+        let mut keys: Vec<&str> = vec![];
+        let mut key_map = BTreeMap::new();
+        for i in 0..key_cnt {
+            let key = unsafe { str_col.index_unchecked(i) };
+            if !key_map.contains_key(key) {
+                let index = key_map.len();
+                keys.push(key);
+                key_map.insert(key, index);
+            }
+        }
+        let mut conn = connection.clone();
+        let redis_val: redis::Value = conn.get(keys).await.unwrap();
+        let res = self.from_redis_value_to_scalar(&redis_val)?;
+        let mut builder = ColumnBuilder::with_capacity(data_type, key_cnt);
+        match res {
+            Scalar::Array(arr) => {
+                for i in 0..key_cnt {
+                    let key = unsafe { str_col.index_unchecked(i) };
+                    let index = key_map[key];
+                    builder.push(unsafe { arr.index_unchecked(index) });
+                }
+            }
+            Scalar::String(str) => {
+                for _ in 0..key_cnt {
+                    builder.push(ScalarRef::String(str.as_str()));
+                }
+            }
+            _ => unreachable!(),
+        }
+        Ok(Some(builder.build()))
+    }
+
+    fn from_redis_value_to_scalar(&self, rv: &redis::Value) -> Result<Scalar> {
+        match rv {
+            redis::Value::BulkString(bs) => {
+                let str = unsafe { String::from_utf8_unchecked(bs.to_vec()) };
+                Ok(Scalar::String(str))
+            }
+            redis::Value::Array(arr) => {
+                let mut builder = ColumnBuilder::with_capacity(&DataType::String, 1);
+                for item in arr {
+                    let scalar = self.from_redis_value_to_scalar(item)?;
+                    builder.push(scalar.as_ref());
+                }
+                Ok(Scalar::Array(builder.build()))
+            }
+            redis::Value::Nil => Ok(Scalar::default_value(&DataType::String)),
+            _ => unreachable!(),
+        }
+    }
 }
@@ -275,16 +323,18 @@ impl TransformAsyncFunction {
                     let value = op
                         .dict_get_batch(&builder.build(), data_type)
                         .await?
-                        .ok_or(()).unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
+                        .ok_or(())
+                        .unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
                     Value::Column(value)
                 }
                 Value::Column(column) => {
                     let value = op
                         .dict_get_batch(column, data_type)
                         .await?
-                        .ok_or(()).unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
+                        .ok_or(())
+                        .unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
                     Value::Column(value)
                 }
             };

diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs
index 2424caeb03c0..4f805ab1d9bb 100644
--- a/src/query/sql/src/planner/plans/scalar_expr.rs
+++ b/src/query/sql/src/planner/plans/scalar_expr.rs
@@ -25,7 +25,6 @@ use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberScalar;
-use databend_common_expression::Column;
 use databend_common_expression::RemoteExpr;
 use databend_common_expression::Scalar;
 use databend_common_meta_app::schema::GetSequenceNextValueReq;

diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs
index f6fced80aa35..99647e9e95b2 100644
--- a/src/query/sql/src/planner/semantic/type_check.rs
+++ b/src/query/sql/src/planner/semantic/type_check.rs
@@ -70,7 +70,6 @@ use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberDataType;
 use databend_common_expression::types::NumberScalar;
 use databend_common_expression::types::F32;
-use databend_common_expression::Column;
 use databend_common_expression::ColumnIndex;
 use databend_common_expression::ConstantFolder;
 use databend_common_expression::DataField;
@@ -4052,7 +4051,7 @@ impl<'a> TypeChecker<'a> {
         let mut args = Vec::with_capacity(1);
         let box (key_scalar, key_type) = self.resolve(key_arg)?;
-        if primary_type != key_type {
+        if primary_type != key_type.remove_nullable() {
             args.push(wrap_cast(&key_scalar, &primary_type));
         } else {
             args.push(key_scalar);

From 681c78195a881f8d3378be9d9b0744a74ec785fc Mon Sep 17 00:00:00 2001
From: dragonliu
Date: Wed, 6 Nov 2024 22:03:20 +0800
Subject: [PATCH 4/5] code format
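
Consolidate the operator: drop the unused OpenDAL variant, replace the
`unreachable!()` and `ok_or(()).unwrap()` escapes with proper
`DictionarySourceError`s, fold the per-type MySQL lookup into a
`get_data_from_mysql` helper, and remove the now-unused `default_value`
plumbing from the planner.

The MySQL helper boils down to one typed `query_scalar` call per value
type. A reduced sketch for the string case (`fetch_string` is a
hypothetical helper; the SQL text comes from the dictionary source,
e.g. "SELECT value FROM dict WHERE key = ?"):

    use sqlx::MySqlPool;

    async fn fetch_string(
        pool: &MySqlPool,
        sql: &str,
        key: &str,
    ) -> Result<Option<String>, sqlx::Error> {
        // fetch_optional yields None when the key has no row, which
        // the caller turns into a NULL scalar.
        sqlx::query_scalar(sql).bind(key).fetch_optional(pool).await
    }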
---
 .../transforms/transform_async_function.rs       |   3 +-
 .../transforms/transform_dictionary.rs           | 205 ++++++++----------
 .../sql/src/planner/plans/scalar_expr.rs         |   2 -
 .../sql/src/planner/semantic/type_check.rs       |   7 +-
 4 files changed, 94 insertions(+), 123 deletions(-)

diff --git a/src/query/service/src/pipelines/processors/transforms/transform_async_function.rs b/src/query/service/src/pipelines/processors/transforms/transform_async_function.rs
index 385cd1a40ce5..be2eeb93aaaf 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_async_function.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_async_function.rs
@@ -99,11 +99,10 @@ impl AsyncTransform for TransformAsyncFunction {
                     )
                     .await?;
             }
-            AsyncFunctionArgument::DictGetFunction(dict_arg) => {
+            AsyncFunctionArgument::DictGetFunction(_) => {
                 self.transform_dict_get(
                     i,
                     &mut data_block,
-                    dict_arg,
                     &async_func_desc.arg_indices,
                     &async_func_desc.data_type,
                 )

diff --git a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
index a2023da29104..eaf8e5aff7be 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
@@ -33,7 +33,6 @@ use databend_common_expression::DataBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::ScalarRef;
 use databend_common_expression::Value;
-use opendal::Operator;
 use redis::aio::ConnectionManager;
 use redis::AsyncCommands;
 use redis::Client;
@@ -42,132 +41,50 @@ use sqlx::MySqlPool;
 use crate::pipelines::processors::transforms::TransformAsyncFunction;
 use crate::sql::executor::physical_plans::AsyncFunctionDesc;
 use crate::sql::plans::AsyncFunctionArgument;
-use crate::sql::plans::DictGetFunctionArgument;
 use crate::sql::plans::DictionarySource;
 use crate::sql::IndexType;
 
 pub(crate) enum DictionaryOperator {
-    Operator(Operator),
     Redis(ConnectionManager),
     Mysql((MySqlPool, String)),
 }
 
 impl DictionaryOperator {
-    fn format_key(&self, key: ScalarRef<'_>) -> String {
-        match key {
-            ScalarRef::String(s) => s.to_string(),
-            ScalarRef::Date(d) => format!("{}", date_to_string(d as i64, Tz::UTC)),
-            ScalarRef::Timestamp(t) => format!("{}", timestamp_to_string(t, Tz::UTC)),
-            _ => format!("{}", key),
-        }
-    }
-
-    async fn dict_get(&self, key: ScalarRef<'_>, data_type: &DataType) -> Result<Option<Scalar>> {
-        if key == ScalarRef::Null {
-            return Ok(None);
-        }
-        match self {
-            DictionaryOperator::Operator(op) => {
-                if let ScalarRef::String(key) = key {
-                    let buffer = op.read(key).await;
-                    match buffer {
-                        Ok(res) => {
-                            let value =
-                                unsafe { String::from_utf8_unchecked(res.current().to_vec()) };
-                            Ok(Some(Scalar::String(value)))
-                        }
-                        Err(e) => {
-                            if e.kind() == opendal::ErrorKind::NotFound {
-                                Ok(None)
-                            } else {
-                                Err(ErrorCode::DictionarySourceError(format!(
-                                    "dictionary source error: {e}"
-                                )))
-                            }
-                        }
-                    }
-                } else {
-                    Ok(None)
-                }
-            }
-            DictionaryOperator::Mysql((pool, sql)) => match data_type.remove_nullable() {
-                DataType::Boolean => {
-                    let value: Option<bool> = sqlx::query_scalar(sql)
-                        .bind(self.format_key(key))
-                        .fetch_optional(pool)
-                        .await?;
-                    Ok(value.map(Scalar::Boolean))
-                }
-                DataType::String => {
-                    let value: Option<String> = sqlx::query_scalar(sql)
-                        .bind(self.format_key(key))
-                        .fetch_optional(pool)
-                        .await?;
-                    Ok(value.map(Scalar::String))
-                }
-                DataType::Number(num_ty) => {
-                    with_integer_mapped_type!(|NUM_TYPE| match num_ty {
-                        NumberDataType::NUM_TYPE => {
-                            let value: Option<NUM_TYPE> = sqlx::query_scalar(&sql)
-                                .bind(self.format_key(key))
-                                .fetch_optional(pool)
-                                .await?;
-                            Ok(value.map(|v| Scalar::Number(NUM_TYPE::upcast_scalar(v))))
-                        }
-                        NumberDataType::Float32 => {
-                            let value: Option<f32> = sqlx::query_scalar(sql)
-                                .bind(self.format_key(key))
-                                .fetch_optional(pool)
-                                .await?;
-                            Ok(value.map(|v| Scalar::Number(NumberScalar::Float32(v.into()))))
-                        }
-                        NumberDataType::Float64 => {
-                            let value: Option<f64> = sqlx::query_scalar(sql)
-                                .bind(self.format_key(key))
-                                .fetch_optional(pool)
-                                .await?;
-                            Ok(value.map(|v| Scalar::Number(NumberScalar::Float64(v.into()))))
-                        }
-                    })
-                }
-                _ => Err(ErrorCode::DictionarySourceError(format!(
-                    "unsupported value type {data_type}"
-                ))),
-            },
-            _ => unreachable!(),
-        }
-    }
-
-    async fn dict_get_batch(
-        &self,
-        column: &Column,
-        data_type: &DataType,
-    ) -> Result<Option<Column>> {
+    async fn dict_get(&self, column: &Column, data_type: &DataType) -> Result<Column> {
         match self {
             DictionaryOperator::Redis(connection) => match column {
                 Column::Nullable(box nullable_col) => match &nullable_col.column {
                     Column::String(str_col) => {
-                        self.get_values_from_redis(str_col, connection, data_type)
+                        self.get_values_from_redis(str_col, data_type, connection)
                             .await
                     }
-                    _ => unreachable!(),
+                    _ => Err(ErrorCode::DictionarySourceError(format!(
+                        "Redis dictionary operator currently does not support value type {}",
+                        column.data_type()
+                    ))),
                 },
                 Column::String(str_col) => {
-                    self.get_values_from_redis(str_col, connection, data_type)
+                    self.get_values_from_redis(str_col, data_type, connection)
                         .await
                 }
-                _ => Ok(None),
+                _ => Err(ErrorCode::DictionarySourceError(format!(
+                    "Redis dictionary operator currently does not support value type {}",
+                    column.data_type()
+                ))),
+            },
+            DictionaryOperator::Mysql((pool, sql)) => {
+                self.get_data_from_mysql(&column, &data_type, &pool, &sql)
+                    .await
             },
-            _ => unreachable!(),
         }
     }
 
     async fn get_values_from_redis(
         &self,
         str_col: &StringColumn,
-        connection: &ConnectionManager,
         data_type: &DataType,
-    ) -> Result<Option<Column>> {
+        connection: &ConnectionManager,
+    ) -> Result<Column> {
         let key_cnt = str_col.len();
         let mut keys: Vec<&str> = vec![];
         let mut key_map = BTreeMap::new();
@@ -112,7 +115,7 @@ impl DictionaryOperator {
             _ => unreachable!(),
         }
-        Ok(Some(builder.build()))
+        Ok(builder.build())
     }
 
     fn from_redis_value_to_scalar(&self, rv: &redis::Value) -> Result<Scalar> {
@@ -135,6 +138,76 @@ impl DictionaryOperator {
             _ => unreachable!(),
         }
     }
+
+    async fn get_data_from_mysql(
+        &self,
+        column: &Column,
+        data_type: &DataType,
+        pool: &MySqlPool,
+        sql: &String,
+    ) -> Result<Column> {
+        let key = unsafe { column.index_unchecked(0) };
+        let res = match data_type.remove_nullable() {
+            DataType::Boolean => {
+                let value: Option<bool> = sqlx::query_scalar(sql)
+                    .bind(self.format_key(key))
+                    .fetch_optional(pool)
+                    .await?;
+                Ok(value.map(Scalar::Boolean))
+            }
+            DataType::String => {
+                let value: Option<String> = sqlx::query_scalar(sql)
+                    .bind(self.format_key(key))
+                    .fetch_optional(pool)
+                    .await?;
+                Ok(value.map(Scalar::String))
+            }
+            DataType::Number(num_ty) => {
+                with_integer_mapped_type!(|NUM_TYPE| match num_ty {
+                    NumberDataType::NUM_TYPE => {
+                        let value: Option<NUM_TYPE> = sqlx::query_scalar(&sql)
+                            .bind(self.format_key(key))
+                            .fetch_optional(pool)
+                            .await?;
+                        Ok(value.map(|v| Scalar::Number(NUM_TYPE::upcast_scalar(v))))
+                    }
+                    NumberDataType::Float32 => {
+                        let value: Option<f32> = sqlx::query_scalar(sql)
+                            .bind(self.format_key(key))
+                            .fetch_optional(pool)
+                            .await?;
+                        Ok(value.map(|v| Scalar::Number(NumberScalar::Float32(v.into()))))
+                    }
+                    NumberDataType::Float64 => {
+                        let value: Option<f64> = sqlx::query_scalar(sql)
+                            .bind(self.format_key(key))
+                            .fetch_optional(pool)
+                            .await?;
+                        Ok(value.map(|v| Scalar::Number(NumberScalar::Float64(v.into()))))
+                    }
+                })
+            }
+            _ => Err(ErrorCode::DictionarySourceError(format!(
+                "MySQL dictionary operator currently does not support value type {data_type}"
+            ))),
+        }?;
+        let res_data_type = match res {
+            Some(_) => data_type.remove_nullable(),
+            None => DataType::Null,
+        };
+        let mut builder = ColumnBuilder::with_capacity(&res_data_type, 1);
+        builder.push(res.unwrap_or(Scalar::Null).as_ref());
+        Ok(builder.build())
+    }
+
+    fn format_key(&self, key: ScalarRef<'_>) -> String {
+        match key {
+            ScalarRef::String(s) => s.to_string(),
+            ScalarRef::Date(d) => format!("{}", date_to_string(d as i64, Tz::UTC)),
+            ScalarRef::Timestamp(t) => format!("{}", timestamp_to_string(t, Tz::UTC)),
+            _ => format!("{}", key),
+        }
+    }
 }
@@ -257,7 +244,6 @@ impl TransformAsyncFunction {
         &self,
         i: usize,
         data_block: &mut DataBlock,
-        dict_arg: &DictGetFunctionArgument,
         arg_indices: &[IndexType],
         data_type: &DataType,
     ) -> Result<()> {
@@ -267,21 +253,14 @@ impl TransformAsyncFunction {
             let entry = data_block.get_by_offset(arg_index);
             let value = match &entry.value {
                 Value::Scalar(scalar) => {
-                    let mut builder = ColumnBuilder::with_capacity(data_type, 1);
+                    let scalar_data_type = scalar.as_ref().infer_data_type();
+                    let mut builder = ColumnBuilder::with_capacity(&scalar_data_type, 1);
                     builder.push(scalar.as_ref());
-                    let value = op
-                        .dict_get_batch(&builder.build(), data_type)
-                        .await?
-                        .ok_or(())
-                        .unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
+                    let value = op.dict_get(&builder.build(), data_type).await?;
                     Value::Column(value)
                 }
                 Value::Column(column) => {
-                    let value = op
-                        .dict_get_batch(column, data_type)
-                        .await?
-                        .ok_or(())
-                        .unwrap(); // TODO: .unwrap_or(dict_arg.default_value.clone());
+                    let value = op.dict_get(column, data_type).await?;
                     Value::Column(value)
                 }
             };

diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs
index 4f805ab1d9bb..5a58e48f863a 100644
--- a/src/query/sql/src/planner/plans/scalar_expr.rs
+++ b/src/query/sql/src/planner/plans/scalar_expr.rs
@@ -829,8 +829,6 @@ pub enum DictionarySource {
 #[educe(PartialEq, Eq, Hash)]
 pub struct DictGetFunctionArgument {
     pub dict_source: DictionarySource,
-    pub default_value: Scalar,
-    // pub default_value: Column,
 }
 
 // Asynchronous functions are functions that need to call remote interfaces.

diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs
index 99647e9e95b2..eccdcfb3b64a 100644
--- a/src/query/sql/src/planner/semantic/type_check.rs
+++ b/src/query/sql/src/planner/semantic/type_check.rs
@@ -117,7 +117,6 @@ use crate::binder::CteInfo;
 use crate::binder::ExprContext;
 use crate::binder::InternalColumnBinding;
 use crate::binder::NameResolutionResult;
-use crate::field_default_value;
 use crate::optimizer::RelExpr;
 use crate::optimizer::SExpr;
 use crate::parse_lambda_expr;
@@ -4041,7 +4040,6 @@ impl<'a> TypeChecker<'a> {
         };
         let attr_field = dictionary.schema.field_with_name(attr_name)?;
         let attr_type: DataType = (&attr_field.data_type).into();
-        let default_value = field_default_value(self.ctx.clone(), attr_field)?;
 
         // Get primary_key_value and check type.
         let primary_column_id = dictionary.primary_column_ids[0];
@@ -4093,10 +4091,7 @@ impl<'a> TypeChecker<'a> {
             }
         };
 
-        let dict_get_func_arg = DictGetFunctionArgument {
-            dict_source,
-            default_value,
-        };
+        let dict_get_func_arg = DictGetFunctionArgument { dict_source };
         let display_name = format!(
             "{}({}.{}, {}, {})",
             func_name, db_name, dict_name, field_arg, key_arg,

From a6adf5e5871f847b33a37d16fff94c29ccb49fa4 Mon Sep 17 00:00:00 2001
From: dragonliu
Date: Thu, 7 Nov 2024 10:57:02 +0800
Subject: [PATCH 5/5] add redis cache
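
Put an in-memory LRU (`InMemoryLruCache`) in front of the Redis source:
cache hits are served locally, only missing keys go into the MGET, and
fetched values are written back. The capacity is a placeholder (1000)
for now.

The read path follows the usual cache-aside shape. A reduced sketch
with a plain map standing in for the LRU (`split_hits` is a
hypothetical helper, not part of the patch):

    use std::collections::HashMap;

    /// Partition row keys into cache hits (index, value) and misses
    /// (index, key); misses are then fetched in one batch and the
    /// results inserted back into the cache.
    fn split_hits<'a>(
        cache: &HashMap<String, String>,
        keys: &[&'a str],
    ) -> (Vec<(usize, String)>, Vec<(usize, &'a str)>) {
        let mut hits = Vec::new();
        let mut misses = Vec::new();
        for (i, &k) in keys.iter().enumerate() {
            match cache.get(k) {
                Some(v) => hits.push((i, v.clone())),
                None => misses.push((i, k)),
            }
        }
        (hits, misses)
    }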
---
 .../transforms/transform_dictionary.rs | 38 ++++++++++++++-----
 1 file changed, 29 insertions(+), 9 deletions(-)

diff --git a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
index eaf8e5aff7be..2a3760d5b869 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use std::collections::BTreeMap;
+use std::string::String;
 use std::sync::Arc;
 
 use chrono_tz::Tz;
@@ -33,6 +34,8 @@ use databend_common_expression::DataBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::ScalarRef;
 use databend_common_expression::Value;
+use databend_storages_common_cache::CacheAccessor;
+use databend_storages_common_cache::InMemoryLruCache;
 use redis::aio::ConnectionManager;
 use redis::AsyncCommands;
 use redis::Client;
@@ -44,18 +47,20 @@ use sqlx::MySqlPool;
 use crate::sql::plans::DictionarySource;
 use crate::sql::IndexType;
 
+pub type RedisCache = InMemoryLruCache<String>;
+
 pub(crate) enum DictionaryOperator {
-    Redis(ConnectionManager),
+    Redis((ConnectionManager, RedisCache)),
     Mysql((MySqlPool, String)),
 }
 
 impl DictionaryOperator {
     async fn dict_get(&self, column: &Column, data_type: &DataType) -> Result<Column> {
         match self {
-            DictionaryOperator::Redis(connection) => match column {
+            DictionaryOperator::Redis((connection, redis_cache)) => match column {
                 Column::Nullable(box nullable_col) => match &nullable_col.column {
                     Column::String(str_col) => {
-                        self.get_values_from_redis(str_col, data_type, connection)
+                        self.get_values_from_redis(str_col, data_type, connection, redis_cache)
                             .await
                     }
                     _ => Err(ErrorCode::DictionarySourceError(format!(
@@ -64,7 +69,7 @@ impl DictionaryOperator {
                 },
                 Column::String(str_col) => {
-                    self.get_values_from_redis(str_col, data_type, connection)
+                    self.get_values_from_redis(str_col, data_type, connection, redis_cache)
                         .await
                 }
                 _ => Err(ErrorCode::DictionarySourceError(format!(
@@ -75,7 +80,7 @@ impl DictionaryOperator {
             DictionaryOperator::Mysql((pool, sql)) => {
                 self.get_data_from_mysql(&column, &data_type, &pool, &sql)
                     .await
-            },
+            }
         }
     }
 
@@ -84,13 +89,14 @@ impl DictionaryOperator {
         str_col: &StringColumn,
         data_type: &DataType,
         connection: &ConnectionManager,
+        redis_cache: &RedisCache,
     ) -> Result<Column> {
         let key_cnt = str_col.len();
         let mut keys: Vec<&str> = vec![];
         let mut key_map = BTreeMap::new();
         for i in 0..key_cnt {
             let key = unsafe { str_col.index_unchecked(i) };
-            if !key_map.contains_key(key) {
+            if !redis_cache.contains_key(key) && !key_map.contains_key(key) {
                 let index = key_map.len();
                 keys.push(key);
                 key_map.insert(key, index);
@@ -104,11 +110,20 @@ impl DictionaryOperator {
             Scalar::Array(arr) => {
                 for i in 0..key_cnt {
                     let key = unsafe { str_col.index_unchecked(i) };
-                    let index = key_map[key];
-                    builder.push(unsafe { arr.index_unchecked(index) });
+                    if redis_cache.contains_key(key) {
+                        let val = redis_cache.get(key).unwrap().clone();
+                        builder.push(ScalarRef::String(val.as_str()));
+                    } else {
+                        let index = key_map[key];
+                        let val = unsafe { arr.index_unchecked(index) };
+                        builder.push(val.clone());
+                        redis_cache.insert(key.to_string(), val.to_string());
+                    }
                 }
             }
             Scalar::String(str) => {
+                let key = unsafe { str_col.index_unchecked(0) };
+                redis_cache.insert(key.to_string(), str.clone());
                 for _ in 0..key_cnt {
                     builder.push(ScalarRef::String(str.as_str()));
                 }
@@ -221,7 +236,12 @@ impl TransformAsyncFunction {
                     let conn = databend_common_base::runtime::block_on(
                         ConnectionManager::new(client),
                     )?;
-                    operators.insert(i, Arc::new(DictionaryOperator::Redis(conn)));
+                    let redis_cache = RedisCache::with_items_capacity(
+                        String::from("memory_cache_redis"),
+                        1000, // TODO: what capacity to set
+                    );
+                    operators
+                        .insert(i, Arc::new(DictionaryOperator::Redis((conn, redis_cache))));
                 }
                 DictionarySource::Mysql(sql_source) => {
                     let mysql_pool = databend_common_base::runtime::block_on(