From fc8b8b027435223bae673c6d40f607a120f4983a Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Mon, 19 May 2025 12:12:31 +0200 Subject: [PATCH 01/34] feat(tests): Add integration tests for pg_replicate --- Cargo.toml | 16 ++++++-- api/Cargo.toml | 8 ++-- api/README.md | 2 +- api/src/configuration.rs | 52 +----------------------- api/src/main.rs | 25 ++++++------ api/src/startup.rs | 30 +++++++------- api/tests/common/database.rs | 53 ++++++------------------- api/tests/common/test_app.rs | 57 ++++++++++++++------------- pg_replicate/tests/mod.rs | 1 + postgres/Cargo.toml | 19 +++++++++ postgres/src/lib.rs | 3 ++ postgres/src/options.rs | 61 +++++++++++++++++++++++++++++ postgres/src/test_utils.rs | 57 +++++++++++++++++++++++++++ replicator/Cargo.toml | 5 ++- {api/scripts => scripts}/init_db.sh | 8 ++-- 15 files changed, 235 insertions(+), 162 deletions(-) create mode 100644 pg_replicate/tests/mod.rs create mode 100644 postgres/Cargo.toml create mode 100644 postgres/src/lib.rs create mode 100644 postgres/src/options.rs create mode 100644 postgres/src/test_utils.rs rename {api/scripts => scripts}/init_db.sh (92%) diff --git a/Cargo.toml b/Cargo.toml index fd72004b..bcf2dd02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,20 @@ [workspace] - resolver = "2" - -members = ["api", "pg_replicate", "replicator", "telemetry"] +members = [ + "api", + "pg_replicate", + "postgres", + "replicator", + "telemetry" +] [workspace.dependencies] +api = { path = "api" } +pg_replicate = { path = "pg_replicate" } +postgres = { path = "postgres" } +replicator = { path = "replicator" } +telemetry = { path = "telemetry" } + actix-web = { version = "4", default-features = false } actix-web-httpauth = { version = "0.8.2", default-features = false } anyhow = { version = "1.0", default-features = false } diff --git a/api/Cargo.toml b/api/Cargo.toml index 55594831..a6ff06ea 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -11,13 +11,15 @@ path = "src/main.rs" name = "api" [dependencies] +postgres = { workspace = true, features = ["test_utils"] } +telemetry = { workspace = true } + actix-web = { workspace = true, features = ["macros", "http2"] } actix-web-httpauth = { workspace = true } anyhow = { workspace = true, features = ["std"] } async-trait = { workspace = true } aws-lc-rs = { workspace = true, features = ["alloc", "aws-lc-sys"] } base64 = { workspace = true, features = ["std"] } -bytes = { workspace = true } config = { workspace = true, features = ["yaml"] } constant_time_eq = { workspace = true } k8s-openapi = { workspace = true, features = ["latest"] } @@ -30,7 +32,6 @@ kube = { workspace = true, features = [ pg_escape = { workspace = true } rand = { workspace = true, features = ["std"] } reqwest = { workspace = true, features = ["json"] } -secrecy = { workspace = true, features = ["serde", "alloc"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } sqlx = { workspace = true, features = [ @@ -41,10 +42,9 @@ sqlx = { workspace = true, features = [ "migrate", ] } thiserror = { workspace = true } -telemetry = { path = "../telemetry" } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true, default-features = false } tracing-actix-web = { workspace = true, features = ["emit_event_on_error"] } utoipa = { workspace = true, features = ["actix_extras"] } utoipa-swagger-ui = { workspace = true, features = ["actix-web", "reqwest"] } -uuid = { version = "1.10.0", features = ["v4"] } +uuid = { version = 
"1.10.0", features = ["v4"] } \ No newline at end of file diff --git a/api/README.md b/api/README.md index 4715439e..6489e0f4 100644 --- a/api/README.md +++ b/api/README.md @@ -26,7 +26,7 @@ Before you begin, ensure you have the following installed: ## Database Management ### Initial Setup -To set up and initialize the database, run the following command from the `api` directory: +To set up and initialize the database, run the following command from the main directory: ```bash ./scripts/init_db.sh diff --git a/api/src/configuration.rs b/api/src/configuration.rs index 1a34c929..0c957260 100644 --- a/api/src/configuration.rs +++ b/api/src/configuration.rs @@ -1,12 +1,11 @@ use std::fmt::{self, Display}; use base64::{prelude::BASE64_STANDARD, Engine}; -use secrecy::{ExposeSecret, Secret}; +use postgres::options::PgDatabaseOptions; use serde::{ de::{self, MapAccess, Unexpected, Visitor}, Deserialize, Deserializer, }; -use sqlx::postgres::{PgConnectOptions, PgSslMode}; use thiserror::Error; #[derive(serde::Deserialize, Clone)] @@ -100,59 +99,12 @@ impl<'de> Deserialize<'de> for ApiKey { #[derive(serde::Deserialize, Clone)] pub struct Settings { - pub database: DatabaseSettings, + pub database: PgDatabaseOptions, pub application: ApplicationSettings, pub encryption_key: EncryptionKey, pub api_key: String, } -#[derive(serde::Deserialize, Clone)] -pub struct DatabaseSettings { - /// Host on which Postgres is running - pub host: String, - - /// Port on which Postgres is running - pub port: u16, - - /// Postgres database name - pub name: String, - - /// Postgres database user name - pub username: String, - - /// Postgres database user password - pub password: Option>, - - /// Whether to enable ssl or not - pub require_ssl: bool, -} - -impl DatabaseSettings { - pub fn without_db(&self) -> PgConnectOptions { - let ssl_mode = if self.require_ssl { - PgSslMode::Require - } else { - PgSslMode::Prefer - }; - - let options = PgConnectOptions::new_without_pgpass() - .host(&self.host) - .username(&self.username) - .port(self.port) - .ssl_mode(ssl_mode); - - if let Some(password) = &self.password { - options.password(password.expose_secret()) - } else { - options - } - } - - pub fn with_db(&self) -> PgConnectOptions { - self.without_db().database(&self.name) - } -} - #[derive(serde::Deserialize, Clone)] pub struct ApplicationSettings { /// host the api listens on diff --git a/api/src/main.rs b/api/src/main.rs index 09a809e6..a85ed26c 100644 --- a/api/src/main.rs +++ b/api/src/main.rs @@ -2,9 +2,10 @@ use std::env; use anyhow::anyhow; use api::{ - configuration::{get_settings, DatabaseSettings, Settings}, + configuration::{get_settings, Settings}, startup::Application, }; +use postgres::options::PgDatabaseOptions; use telemetry::init_tracing; use tracing::{error, info}; @@ -23,7 +24,7 @@ pub async fn main() -> anyhow::Result<()> { // Run the application server 1 => { let configuration = get_settings::<'_, Settings>()?; - log_database_settings(&configuration.database); + log_pg_database_options(&configuration.database); let application = Application::build(configuration.clone()).await?; application.run_until_stopped().await?; } @@ -32,9 +33,9 @@ pub async fn main() -> anyhow::Result<()> { let command = args.nth(1).unwrap(); match command.as_str() { "migrate" => { - let configuration = get_settings::<'_, DatabaseSettings>()?; - log_database_settings(&configuration); - Application::migrate_database(configuration).await?; + let options = get_settings::<'_, PgDatabaseOptions>()?; + 
log_pg_database_options(&options); + Application::migrate_database(options).await?; info!("database migrated successfully"); } _ => { @@ -54,13 +55,13 @@ pub async fn main() -> anyhow::Result<()> { Ok(()) } -fn log_database_settings(settings: &DatabaseSettings) { +fn log_pg_database_options(options: &PgDatabaseOptions) { info!( - host = settings.host, - port = settings.port, - dbname = settings.name, - username = settings.username, - require_ssl = settings.require_ssl, - "database details", + host = options.host, + port = options.port, + dbname = options.name, + username = options.username, + require_ssl = options.require_ssl, + "pg database options", ); } diff --git a/api/src/startup.rs b/api/src/startup.rs index 9661bf5d..e9d63446 100644 --- a/api/src/startup.rs +++ b/api/src/startup.rs @@ -1,17 +1,8 @@ use std::{net::TcpListener, sync::Arc}; -use actix_web::{dev::Server, web, App, HttpServer}; -use actix_web_httpauth::middleware::HttpAuthentication; -use aws_lc_rs::aead::{RandomizedNonceKey, AES_256_GCM}; -use base64::{prelude::BASE64_STANDARD, Engine}; -use sqlx::{postgres::PgPoolOptions, PgPool}; -use tracing_actix_web::TracingLogger; -use utoipa::OpenApi; -use utoipa_swagger_ui::SwaggerUi; - use crate::{ authentication::auth_validator, - configuration::{DatabaseSettings, Settings}, + configuration::Settings, db::publications::Publication, encryption, k8s_client::HttpK8sClient, @@ -54,6 +45,15 @@ use crate::{ }, span_builder::ApiRootSpanBuilder, }; +use actix_web::{dev::Server, web, App, HttpServer}; +use actix_web_httpauth::middleware::HttpAuthentication; +use aws_lc_rs::aead::{RandomizedNonceKey, AES_256_GCM}; +use base64::{prelude::BASE64_STANDARD, Engine}; +use postgres::options::PgDatabaseOptions; +use sqlx::{postgres::PgPoolOptions, PgPool}; +use tracing_actix_web::TracingLogger; +use utoipa::OpenApi; +use utoipa_swagger_ui::SwaggerUi; pub struct Application { port: u16, @@ -90,10 +90,8 @@ impl Application { Ok(Self { port, server }) } - pub async fn migrate_database( - database_settings: DatabaseSettings, - ) -> Result<(), anyhow::Error> { - let connection_pool = get_connection_pool(&database_settings); + pub async fn migrate_database(options: PgDatabaseOptions) -> Result<(), anyhow::Error> { + let connection_pool = get_connection_pool(&options); sqlx::migrate!("./migrations").run(&connection_pool).await?; @@ -109,8 +107,8 @@ impl Application { } } -pub fn get_connection_pool(configuration: &DatabaseSettings) -> PgPool { - PgPoolOptions::new().connect_lazy_with(configuration.with_db()) +pub fn get_connection_pool(options: &PgDatabaseOptions) -> PgPool { + PgPoolOptions::new().connect_lazy_with(options.with_db()) } // HttpK8sClient is wrapped in an option because creating it diff --git a/api/tests/common/database.rs b/api/tests/common/database.rs index f78c3520..b5900cec 100644 --- a/api/tests/common/database.rs +++ b/api/tests/common/database.rs @@ -1,20 +1,16 @@ -use api::configuration::DatabaseSettings; -use sqlx::{Connection, Executor, PgConnection, PgPool}; +use postgres::options::PgDatabaseOptions; +use postgres::test_utils::create_pg_database; +use sqlx::PgPool; -pub async fn create_and_configure_database(settings: &DatabaseSettings) -> PgPool { - // Create the database via a single connection. 
- let mut connection = PgConnection::connect_with(&settings.without_db()) - .await - .expect("Failed to connect to Postgres"); - connection - .execute(&*format!(r#"CREATE DATABASE "{}";"#, settings.name)) - .await - .expect("Failed to create database"); +/// Creates and configures a new PostgreSQL database for the API. +/// +/// Similar to [`create_pg_database`], but additionally runs all database migrations +/// from the "./migrations" directory after creation. Returns a [`PgPool`] +/// connected to the newly created and migrated database. Panics if database +/// creation or migration fails. +pub async fn create_pg_replicate_api_database(options: &PgDatabaseOptions) -> PgPool { + let connection_pool = create_pg_database(&options).await; - // Create a connection pool to the database and run the migration. - let connection_pool = PgPool::connect_with(settings.with_db()) - .await - .expect("Failed to connect to Postgres"); sqlx::migrate!("./migrations") .run(&connection_pool) .await @@ -22,30 +18,3 @@ pub async fn create_and_configure_database(settings: &DatabaseSettings) -> PgPoo connection_pool } - -pub async fn destroy_database(settings: &DatabaseSettings) { - // Connect to the default database. - let mut connection = PgConnection::connect_with(&settings.without_db()) - .await - .expect("Failed to connect to Postgres"); - - // Forcefully terminate any remaining connections to the database. This code assumes that those - // connections are not used anymore and do not outlive the `TestApp` instance. - connection - .execute(&*format!( - r#" - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' - AND pid <> pg_backend_pid();"#, - settings.name - )) - .await - .expect("Failed to terminate database connections"); - - // Drop the database. - connection - .execute(&*format!(r#"DROP DATABASE IF EXISTS "{}";"#, settings.name)) - .await - .expect("Failed to destroy database"); -} diff --git a/api/tests/common/test_app.rs b/api/tests/common/test_app.rs index c280c61a..4b041b17 100644 --- a/api/tests/common/test_app.rs +++ b/api/tests/common/test_app.rs @@ -1,40 +1,19 @@ -use std::io; -use std::net::TcpListener; - -use crate::common::database::{create_and_configure_database, destroy_database}; +use crate::common::database::create_pg_replicate_api_database; use api::{ configuration::{get_settings, Settings}, db::{pipelines::PipelineConfig, sinks::SinkConfig, sources::SourceConfig}, encryption::{self, generate_random_key}, startup::run, }; +use postgres::options::PgDatabaseOptions; +use postgres::test_utils::drop_pg_database; use reqwest::{IntoUrl, RequestBuilder}; use serde::{Deserialize, Serialize}; +use std::io; +use std::net::TcpListener; use tokio::runtime::Handle; use uuid::Uuid; -pub struct TestApp { - pub address: String, - pub api_client: reqwest::Client, - pub api_key: String, - settings: Settings, - server_handle: tokio::task::JoinHandle>, -} - -impl Drop for TestApp { - fn drop(&mut self) { - // First, abort the server task to ensure it's terminated. - self.server_handle.abort(); - - // To use `block_in_place,` we need a multithreaded runtime since when a blocking - // task is issued, the runtime will offload existing tasks to another worker. 
- tokio::task::block_in_place(move || { - Handle::current() - .block_on(async move { destroy_database(&self.settings.database).await }); - }); - } -} - #[derive(Serialize)] pub struct CreateTenantRequest { pub id: String, @@ -217,6 +196,27 @@ pub struct UpdateImageRequest { pub is_default: bool, } +pub struct TestApp { + pub address: String, + pub api_client: reqwest::Client, + pub api_key: String, + options: PgDatabaseOptions, + server_handle: tokio::task::JoinHandle>, +} + +impl Drop for TestApp { + fn drop(&mut self) { + // First, abort the server task to ensure it's terminated. + self.server_handle.abort(); + + // To use `block_in_place,` we need a multithreaded runtime since when a blocking + // task is issued, the runtime will offload existing tasks to another worker. + tokio::task::block_in_place(move || { + Handle::current().block_on(async move { drop_pg_database(&self.options).await }); + }); + } +} + impl TestApp { fn get_authenticated(&self, url: U) -> RequestBuilder { self.api_client.get(url).bearer_auth(self.api_key.clone()) @@ -533,9 +533,10 @@ pub async fn spawn_test_app() -> TestApp { let port = listener.local_addr().unwrap().port(); let mut settings = get_settings::<'_, Settings>().expect("Failed to read configuration"); + // We use a random database name. settings.database.name = Uuid::new_v4().to_string(); - let connection_pool = create_and_configure_database(&settings.database).await; + let connection_pool = create_pg_replicate_api_database(&settings.database).await; let key = generate_random_key::<32>().expect("failed to generate random key"); let encryption_key = encryption::EncryptionKey { id: 0, key }; @@ -557,7 +558,7 @@ pub async fn spawn_test_app() -> TestApp { address: format!("http://{base_address}:{port}"), api_client: reqwest::Client::new(), api_key, - settings, + options: settings.database, server_handle, } } diff --git a/pg_replicate/tests/mod.rs b/pg_replicate/tests/mod.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/pg_replicate/tests/mod.rs @@ -0,0 +1 @@ + diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml new file mode 100644 index 00000000..1b4dae4b --- /dev/null +++ b/postgres/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "postgres" +version = "0.1.0" +edition = "2024" + +[dependencies] +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true, features = ["std"] } +secrecy = { workspace = true, features = ["serde", "alloc"] } +sqlx = { workspace = true, features = [ + "runtime-tokio-rustls", + "macros", + "postgres", + "json", + "migrate", +] } + +[features] +test_utils = [] \ No newline at end of file diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs new file mode 100644 index 00000000..37f84807 --- /dev/null +++ b/postgres/src/lib.rs @@ -0,0 +1,3 @@ +pub mod options; +#[cfg(feature = "test_utils")] +pub mod test_utils; diff --git a/postgres/src/options.rs b/postgres/src/options.rs new file mode 100644 index 00000000..f78363fc --- /dev/null +++ b/postgres/src/options.rs @@ -0,0 +1,61 @@ +use secrecy::{ExposeSecret, Secret}; +use serde::Deserialize; +use sqlx::postgres::{PgConnectOptions, PgSslMode}; + +/// Connection options for a PostgreSQL database. +/// +/// Contains the connection parameters needed to establish a connection to a PostgreSQL +/// database server, including network location, authentication credentials, and security +/// settings. 
+#[derive(Debug, Clone, Deserialize)] +pub struct PgDatabaseOptions { + /// Host name or IP address of the PostgreSQL server + pub host: String, + /// Port number that the PostgreSQL server listens on + pub port: u16, + /// Name of the target database + pub name: String, + /// Username for authentication + pub username: String, + /// Optional password for authentication, wrapped in [`Secret`] for secure handling + pub password: Option>, + /// If true, requires SSL/TLS encryption for the connection + pub require_ssl: bool, +} + +impl PgDatabaseOptions { + /// Creates connection options for connecting to the PostgreSQL server without + /// specifying a database. + /// + /// Returns [`PgConnectOptions`] configured with the host, port, username, SSL mode + /// and optional password from this instance. Useful for administrative operations + /// that must be performed before connecting to a specific database, like database + /// creation. + pub fn without_db(&self) -> PgConnectOptions { + let ssl_mode = if self.require_ssl { + PgSslMode::Require + } else { + PgSslMode::Prefer + }; + + let options = PgConnectOptions::new_without_pgpass() + .host(&self.host) + .username(&self.username) + .port(self.port) + .ssl_mode(ssl_mode); + + if let Some(password) = &self.password { + options.password(password.expose_secret()) + } else { + options + } + } + + /// Creates connection options for connecting to a specific database. + /// + /// Returns [`PgConnectOptions`] configured with all connection parameters including + /// the database name from this instance. + pub fn with_db(&self) -> PgConnectOptions { + self.without_db().database(&self.name) + } +} diff --git a/postgres/src/test_utils.rs b/postgres/src/test_utils.rs new file mode 100644 index 00000000..6b5612b4 --- /dev/null +++ b/postgres/src/test_utils.rs @@ -0,0 +1,57 @@ +use crate::options::PgDatabaseOptions; +use sqlx::{Connection, Executor, PgConnection, PgPool}; + +/// Creates a new PostgreSQL database and returns a connection pool to it. +/// +/// Establishes a connection to the PostgreSQL server using the provided options, +/// creates a new database, and returns a [`PgPool`] connected to the new database. +/// Panics if the connection fails or if database creation fails. +pub async fn create_pg_database(options: &PgDatabaseOptions) -> PgPool { + // Create the database via a single connection. + let mut connection = PgConnection::connect_with(&options.without_db()) + .await + .expect("Failed to connect to Postgres"); + connection + .execute(&*format!(r#"CREATE DATABASE "{}";"#, options.name)) + .await + .expect("Failed to create database"); + + // Create a connection pool to the database and run the migration. + let connection_pool = PgPool::connect_with(options.with_db()) + .await + .expect("Failed to connect to Postgres"); + + connection_pool +} + +/// Drops a PostgreSQL database and cleans up all connections. +/// +/// Connects to the PostgreSQL server, forcefully terminates all active connections +/// to the target database, and drops the database if it exists. Useful for cleaning +/// up test databases. Takes a reference to [`PgDatabaseOptions`] specifying the database +/// to drop. Panics if any operation fails. +pub async fn drop_pg_database(options: &PgDatabaseOptions) { + // Connect to the default database. + let mut connection = PgConnection::connect_with(&options.without_db()) + .await + .expect("Failed to connect to Postgres"); + + // Forcefully terminate any remaining connections to the database. 
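+    // `DROP DATABASE` fails while other sessions are still connected to the target database, so those back-ends are terminated first.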
+ connection + .execute(&*format!( + r#" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{}' + AND pid <> pg_backend_pid();"#, + options.name + )) + .await + .expect("Failed to terminate database connections"); + + // Drop the database. + connection + .execute(&*format!(r#"DROP DATABASE IF EXISTS "{}";"#, options.name)) + .await + .expect("Failed to destroy database"); +} diff --git a/replicator/Cargo.toml b/replicator/Cargo.toml index bed4b67f..b24179cc 100644 --- a/replicator/Cargo.toml +++ b/replicator/Cargo.toml @@ -4,15 +4,16 @@ version = "0.1.0" edition = "2021" [dependencies] +pg_replicate = { workspace = true, features = ["bigquery"] } +telemetry = { workspace = true } + anyhow = { workspace = true, features = ["std"] } config = { workspace = true, features = ["yaml"] } -pg_replicate = { path = "../pg_replicate", features = ["bigquery"] } rustls = { workspace = true, features = ["aws-lc-rs", "logging"] } secrecy = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } thiserror = { workspace = true } rustls-pemfile = { workspace = true, features = ["std"] } -telemetry = { path = "../telemetry" } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true, default-features = true } diff --git a/api/scripts/init_db.sh b/scripts/init_db.sh similarity index 92% rename from api/scripts/init_db.sh rename to scripts/init_db.sh index 7b1fa1a6..15ff51a4 100755 --- a/api/scripts/init_db.sh +++ b/scripts/init_db.sh @@ -1,9 +1,9 @@ #!/usr/bin/env bash set -eo pipefail -if [ ! -d "migrations" ]; then - echo >&2 "❌ Error: '/migrations' folder not found." - echo >&2 "Please run this script from the 'pg_replicate/api' directory." +if [ ! -d "api/migrations" ]; then + echo >&2 "❌ Error: 'api/migrations' folder not found." + echo >&2 "Please run this script from the 'pg_replicate' directory." exit 1 fi @@ -81,6 +81,6 @@ echo "✅ PostgreSQL is up and running on port ${DB_PORT}" echo "🔄 Setting up the database..." export DATABASE_URL=postgres://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME} sqlx database create -sqlx migrate run +sqlx migrate run --source api/migrations echo "✨ Database setup complete! Ready to go!" 
From 52b3694864e90814104e007925fe721f8fc2dbc3 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Mon, 19 May 2025 13:40:24 +0200 Subject: [PATCH 02/34] Implement tokio facilities --- api/Cargo.toml | 7 +- api/src/configuration.rs | 2 +- api/src/main.rs | 2 +- api/src/startup.rs | 2 +- api/tests/common/database.rs | 4 +- api/tests/common/test_app.rs | 4 +- postgres/Cargo.toml | 13 ++- postgres/src/lib.rs | 7 +- postgres/src/sqlx/mod.rs | 3 + postgres/src/{ => sqlx}/options.rs | 0 postgres/src/{ => sqlx}/test_utils.rs | 26 +++++- postgres/src/tokio/mod.rs | 3 + postgres/src/tokio/options.rs | 49 +++++++++++ postgres/src/tokio/test_utils.rs | 116 ++++++++++++++++++++++++++ 14 files changed, 224 insertions(+), 14 deletions(-) create mode 100644 postgres/src/sqlx/mod.rs rename postgres/src/{ => sqlx}/options.rs (100%) rename postgres/src/{ => sqlx}/test_utils.rs (75%) create mode 100644 postgres/src/tokio/mod.rs create mode 100644 postgres/src/tokio/options.rs create mode 100644 postgres/src/tokio/test_utils.rs diff --git a/api/Cargo.toml b/api/Cargo.toml index a6ff06ea..d25af3a7 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -11,7 +11,7 @@ path = "src/main.rs" name = "api" [dependencies] -postgres = { workspace = true, features = ["test_utils"] } +postgres = { workspace = true, features = ["sqlx"] } telemetry = { workspace = true } actix-web = { workspace = true, features = ["macros", "http2"] } @@ -47,4 +47,7 @@ tracing = { workspace = true, default-features = false } tracing-actix-web = { workspace = true, features = ["emit_event_on_error"] } utoipa = { workspace = true, features = ["actix_extras"] } utoipa-swagger-ui = { workspace = true, features = ["actix-web", "reqwest"] } -uuid = { version = "1.10.0", features = ["v4"] } \ No newline at end of file +uuid = { version = "1.10.0", features = ["v4"] } + +[dev-dependencies] +postgres = { workspace = true, features = ["test_utils", "sqlx"] } \ No newline at end of file diff --git a/api/src/configuration.rs b/api/src/configuration.rs index 0c957260..4bd36095 100644 --- a/api/src/configuration.rs +++ b/api/src/configuration.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Display}; use base64::{prelude::BASE64_STANDARD, Engine}; -use postgres::options::PgDatabaseOptions; +use postgres::sqlx::options::PgDatabaseOptions; use serde::{ de::{self, MapAccess, Unexpected, Visitor}, Deserialize, Deserializer, diff --git a/api/src/main.rs b/api/src/main.rs index a85ed26c..2f76d377 100644 --- a/api/src/main.rs +++ b/api/src/main.rs @@ -5,7 +5,7 @@ use api::{ configuration::{get_settings, Settings}, startup::Application, }; -use postgres::options::PgDatabaseOptions; +use postgres::sqlx::options::PgDatabaseOptions; use telemetry::init_tracing; use tracing::{error, info}; diff --git a/api/src/startup.rs b/api/src/startup.rs index e9d63446..302dc758 100644 --- a/api/src/startup.rs +++ b/api/src/startup.rs @@ -49,7 +49,7 @@ use actix_web::{dev::Server, web, App, HttpServer}; use actix_web_httpauth::middleware::HttpAuthentication; use aws_lc_rs::aead::{RandomizedNonceKey, AES_256_GCM}; use base64::{prelude::BASE64_STANDARD, Engine}; -use postgres::options::PgDatabaseOptions; +use postgres::sqlx::options::PgDatabaseOptions; use sqlx::{postgres::PgPoolOptions, PgPool}; use tracing_actix_web::TracingLogger; use utoipa::OpenApi; diff --git a/api/tests/common/database.rs b/api/tests/common/database.rs index b5900cec..c08e3ed0 100644 --- a/api/tests/common/database.rs +++ b/api/tests/common/database.rs @@ -1,5 +1,5 @@ -use postgres::options::PgDatabaseOptions; -use 
postgres::test_utils::create_pg_database; +use postgres::sqlx::options::PgDatabaseOptions; +use postgres::sqlx::test_utils::create_pg_database; use sqlx::PgPool; /// Creates and configures a new PostgreSQL database for the API. diff --git a/api/tests/common/test_app.rs b/api/tests/common/test_app.rs index 4b041b17..c7d0144d 100644 --- a/api/tests/common/test_app.rs +++ b/api/tests/common/test_app.rs @@ -5,8 +5,8 @@ use api::{ encryption::{self, generate_random_key}, startup::run, }; -use postgres::options::PgDatabaseOptions; -use postgres::test_utils::drop_pg_database; +use postgres::sqlx::options::PgDatabaseOptions; +use postgres::sqlx::test_utils::drop_pg_database; use reqwest::{IntoUrl, RequestBuilder}; use serde::{Deserialize, Serialize}; use std::io; diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 1b4dae4b..f3305d3c 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -14,6 +14,17 @@ sqlx = { workspace = true, features = [ "json", "migrate", ] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tokio-postgres = { workspace = true, features = [ + "runtime", + "with-chrono-0_4", + "with-uuid-1", + "with-serde_json-1", +] } +tokio-postgres-rustls = { workspace = true } + [features] -test_utils = [] \ No newline at end of file +test_utils = [] +tokio = [] +sqlx = [] \ No newline at end of file diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index 37f84807..0d348139 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -1,3 +1,4 @@ -pub mod options; -#[cfg(feature = "test_utils")] -pub mod test_utils; +#[cfg(feature = "sqlx")] +pub mod sqlx; +#[cfg(feature = "tokio")] +pub mod tokio; diff --git a/postgres/src/sqlx/mod.rs b/postgres/src/sqlx/mod.rs new file mode 100644 index 00000000..37f84807 --- /dev/null +++ b/postgres/src/sqlx/mod.rs @@ -0,0 +1,3 @@ +pub mod options; +#[cfg(feature = "test_utils")] +pub mod test_utils; diff --git a/postgres/src/options.rs b/postgres/src/sqlx/options.rs similarity index 100% rename from postgres/src/options.rs rename to postgres/src/sqlx/options.rs diff --git a/postgres/src/test_utils.rs b/postgres/src/sqlx/test_utils.rs similarity index 75% rename from postgres/src/test_utils.rs rename to postgres/src/sqlx/test_utils.rs index 6b5612b4..f400f647 100644 --- a/postgres/src/test_utils.rs +++ b/postgres/src/sqlx/test_utils.rs @@ -1,5 +1,29 @@ -use crate::options::PgDatabaseOptions; +use crate::sqlx::options::PgDatabaseOptions; use sqlx::{Connection, Executor, PgConnection, PgPool}; +use tokio::runtime::Handle; + +struct PgDatabase { + options: PgDatabaseOptions, + pool: PgPool, +} + +impl PgDatabase { + pub async fn new(options: PgDatabaseOptions) -> Self { + let pool = create_pg_database(&options).await; + + Self { options, pool } + } +} + +impl Drop for PgDatabase { + fn drop(&mut self) { + // To use `block_in_place,` we need a multithreaded runtime since when a blocking + // task is issued, the runtime will offload existing tasks to another worker. + tokio::task::block_in_place(move || { + Handle::current().block_on(async move { drop_pg_database(&self.options).await }); + }); + } +} /// Creates a new PostgreSQL database and returns a connection pool to it. 
/// diff --git a/postgres/src/tokio/mod.rs b/postgres/src/tokio/mod.rs new file mode 100644 index 00000000..37f84807 --- /dev/null +++ b/postgres/src/tokio/mod.rs @@ -0,0 +1,3 @@ +pub mod options; +#[cfg(feature = "test_utils")] +pub mod test_utils; diff --git a/postgres/src/tokio/options.rs b/postgres/src/tokio/options.rs new file mode 100644 index 00000000..01e40eb6 --- /dev/null +++ b/postgres/src/tokio/options.rs @@ -0,0 +1,49 @@ +use tokio_postgres::Config; +use tokio_postgres::config::SslMode; + +#[derive(Debug, Clone)] +pub struct PgDatabaseOptions { + pub host: String, + pub port: u16, + pub database: String, + pub name: String, + pub username: String, + pub password: Option, + pub ssl_mode: Option, +} + +impl PgDatabaseOptions { + pub fn without_db(&self) -> Config { + let mut this = self.clone(); + // Postgres requires a database, so we default to the database which is equal to the username + // since this seems to be the standard. + this.database = this.username.clone(); + + this.into() + } + + pub fn with_db(&self) -> Config { + self.clone().into() + } +} + +impl From for Config { + fn from(value: PgDatabaseOptions) -> Self { + let mut config = Config::new(); + config + .host(value.host) + .port(value.port) + .dbname(value.database) + .user(value.username); + + if let Some(password) = value.password { + config.password(password); + } + + if let Some(ssl_mode) = value.ssl_mode { + config.ssl_mode(ssl_mode); + } + + config + } +} diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs new file mode 100644 index 00000000..2672d824 --- /dev/null +++ b/postgres/src/tokio/test_utils.rs @@ -0,0 +1,116 @@ +use crate::tokio::options::PgDatabaseOptions; +use tokio::runtime::Handle; +use tokio_postgres::{Client, NoTls}; + +struct PgDatabase { + options: PgDatabaseOptions, + client: Client, +} + +impl PgDatabase { + pub async fn new(options: PgDatabaseOptions) -> Self { + let client = create_pg_database(&options).await; + + Self { options, client } + } +} + +impl Drop for PgDatabase { + fn drop(&mut self) { + // To use `block_in_place,` we need a multithreaded runtime since when a blocking + // task is issued, the runtime will offload existing tasks to another worker. + tokio::task::block_in_place(move || { + Handle::current().block_on(async move { drop_pg_database(&self.options).await }); + }); + } +} + +/// Creates a new PostgreSQL database and returns a client connected to it. +/// +/// Establishes a connection to the PostgreSQL server using the provided options, +/// creates a new database, and returns a [`Client`] connected to the new database. +/// Panics if the connection fails or if database creation fails. 
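+/// The connection is driven on a background task via `tokio::spawn`, so an active Tokio runtime is required.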
+pub async fn create_pg_database(options: &PgDatabaseOptions) -> Client { + // Create the database via a single connection + let (client, connection) = options + .without_db() + .connect(NoTls) + .await + .expect("Failed to connect to Postgres"); + + // Spawn the connection on a new task + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + // Create the database + client + .execute(&*format!(r#"CREATE DATABASE "{}";"#, options.name), &[]) + .await + .expect("Failed to create database"); + + // Create a new client connected to the created database + let (client, connection) = options + .with_db() + .connect(NoTls) + .await + .expect("Failed to connect to Postgres"); + + // Spawn the connection on a new task + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + client +} + +/// Drops a PostgreSQL database and cleans up all connections. +/// +/// Connects to the PostgreSQL server, forcefully terminates all active connections +/// to the target database, and drops the database if it exists. Useful for cleaning +/// up test databases. Takes a reference to [`PgDatabaseOptions`] specifying the database +/// to drop. Panics if any operation fails. +pub async fn drop_pg_database(options: &PgDatabaseOptions) { + // Connect to the default database + let (client, connection) = options + .without_db() + .connect(NoTls) + .await + .expect("Failed to connect to Postgres"); + + // Spawn the connection on a new task + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + // Forcefully terminate any remaining connections to the database + client + .execute( + &*format!( + r#" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{}' + AND pid <> pg_backend_pid();"#, + options.name + ), + &[], + ) + .await + .expect("Failed to terminate database connections"); + + // Drop the database + client + .execute( + &*format!(r#"DROP DATABASE IF EXISTS "{}";"#, options.name), + &[], + ) + .await + .expect("Failed to destroy database"); +} From 0e32a1b2a75418900f94a212904970171c3d21f7 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Mon, 19 May 2025 14:08:25 +0200 Subject: [PATCH 03/34] Use options in pg_replicate --- api/Cargo.toml | 2 +- pg_replicate/Cargo.toml | 7 ++- pg_replicate/examples/bigquery.rs | 32 ++++++------- pg_replicate/examples/duckdb.rs | 32 ++++++------- pg_replicate/examples/stdout.rs | 32 ++++++------- pg_replicate/src/clients/postgres.rs | 46 ++++-------------- pg_replicate/src/pipeline/sources/postgres.rs | 47 ++++++++----------- postgres/Cargo.toml | 2 +- postgres/src/sqlx/mod.rs | 2 +- postgres/src/tokio/mod.rs | 2 +- postgres/src/tokio/options.rs | 14 ++---- replicator/Cargo.toml | 1 + replicator/src/main.rs | 13 +++-- 13 files changed, 95 insertions(+), 137 deletions(-) diff --git a/api/Cargo.toml b/api/Cargo.toml index d25af3a7..6b6d294a 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -50,4 +50,4 @@ utoipa-swagger-ui = { workspace = true, features = ["actix-web", "reqwest"] } uuid = { version = "1.10.0", features = ["v4"] } [dev-dependencies] -postgres = { workspace = true, features = ["test_utils", "sqlx"] } \ No newline at end of file +postgres = { workspace = true, features = ["test-utils", "sqlx"] } \ No newline at end of file diff --git a/pg_replicate/Cargo.toml b/pg_replicate/Cargo.toml index 80d56061..7c16d677 100644 --- 
a/pg_replicate/Cargo.toml +++ b/pg_replicate/Cargo.toml @@ -18,6 +18,8 @@ name = "stdout" required-features = ["stdout"] [dependencies] +postgres = { workspace = true, features = ["tokio"]} + async-trait = { workspace = true } bigdecimal = { workspace = true, features = ["std"] } bytes = { workspace = true } @@ -50,6 +52,8 @@ tracing = { workspace = true, default-features = true } uuid = { workspace = true, features = ["v4"] } [dev-dependencies] +postgres = { workspace = true, features = ["test-utils", "tokio"]} + clap = { workspace = true, default-features = true, features = [ "std", "derive", @@ -58,10 +62,11 @@ tracing-subscriber = { workspace = true, default-features = true, features = [ "env-filter", ] } + [features] bigquery = ["dep:gcp-bigquery-client", "dep:prost"] duckdb = ["dep:duckdb"] stdout = [] # When enabled converts unknown types to bytes unknown_types_to_bytes = [] -default = ["unknown_types_to_bytes"] +default = ["unknown_types_to_bytes"] \ No newline at end of file diff --git a/pg_replicate/examples/bigquery.rs b/pg_replicate/examples/bigquery.rs index 9cfd41de..7ae23225 100644 --- a/pg_replicate/examples/bigquery.rs +++ b/pg_replicate/examples/bigquery.rs @@ -11,6 +11,7 @@ use pg_replicate::{ table::TableName, SslMode, }; +use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -120,22 +121,22 @@ async fn main_impl() -> Result<(), Box> { let db_args = args.db_args; let bq_args = args.bq_args; + let options = PgDatabaseOptions { + host: db_args.db_host, + port: db_args.db_port, + name: db_args.db_name, + username: db_args.db_username, + password: db_args.db_password, + ssl_mode: SslMode::Disable, + }; + let (postgres_source, action) = match args.command { Command::CopyTable { schema, name } => { let table_names = vec![TableName { schema, name }]; - let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, - vec![], - None, - TableNamesFrom::Vec(table_names), - ) - .await?; + let postgres_source = + PostgresSource::new(options, vec![], None, TableNamesFrom::Vec(table_names)) + .await?; (postgres_source, PipelineAction::TableCopiesOnly) } Command::Cdc { @@ -143,12 +144,7 @@ async fn main_impl() -> Result<(), Box> { slot_name, } => { let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, + options, vec![], Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/pg_replicate/examples/duckdb.rs b/pg_replicate/examples/duckdb.rs index c222203e..90bb860d 100644 --- a/pg_replicate/examples/duckdb.rs +++ b/pg_replicate/examples/duckdb.rs @@ -11,6 +11,7 @@ use pg_replicate::{ table::TableName, SslMode, }; +use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -114,22 +115,22 @@ async fn main_impl() -> Result<(), Box> { let args = AppArgs::parse(); let db_args = args.db_args; + let options = PgDatabaseOptions { + host: db_args.db_host, + port: db_args.db_port, + name: db_args.db_name, + username: db_args.db_username, + password: db_args.db_password, + ssl_mode: SslMode::Disable, + }; + let (postgres_source, action) = match args.command { Command::CopyTable { schema, name } => { let table_names = vec![TableName { schema, name }]; - let postgres_source = 
PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, - vec![], - None, - TableNamesFrom::Vec(table_names), - ) - .await?; + let postgres_source = + PostgresSource::new(options, vec![], None, TableNamesFrom::Vec(table_names)) + .await?; (postgres_source, PipelineAction::TableCopiesOnly) } Command::Cdc { @@ -137,12 +138,7 @@ async fn main_impl() -> Result<(), Box> { slot_name, } => { let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, + options, vec![], Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/pg_replicate/examples/stdout.rs b/pg_replicate/examples/stdout.rs index 330c8fab..8121e8df 100644 --- a/pg_replicate/examples/stdout.rs +++ b/pg_replicate/examples/stdout.rs @@ -11,6 +11,7 @@ use pg_replicate::{ table::TableName, SslMode, }; +use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -90,22 +91,22 @@ async fn main_impl() -> Result<(), Box> { let args = AppArgs::parse(); let db_args = args.db_args; + let options = PgDatabaseOptions { + host: db_args.db_host, + port: db_args.db_port, + name: db_args.db_name, + username: db_args.db_username, + password: db_args.db_password, + ssl_mode: SslMode::Disable, + }; + let (postgres_source, action) = match args.command { Command::CopyTable { schema, name } => { let table_names = vec![TableName { schema, name }]; - let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, - vec![], - None, - TableNamesFrom::Vec(table_names), - ) - .await?; + let postgres_source = + PostgresSource::new(options, vec![], None, TableNamesFrom::Vec(table_names)) + .await?; (postgres_source, PipelineAction::TableCopiesOnly) } Command::Cdc { @@ -113,12 +114,7 @@ async fn main_impl() -> Result<(), Box> { slot_name, } => { let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, + options, vec![], Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/pg_replicate/src/clients/postgres.rs b/pg_replicate/src/clients/postgres.rs index 9581c13d..e351d156 100644 --- a/pg_replicate/src/clients/postgres.rs +++ b/pg_replicate/src/clients/postgres.rs @@ -1,19 +1,19 @@ use std::collections::HashMap; +use crate::table::{ColumnSchema, TableId, TableName, TableSchema}; use pg_escape::{quote_identifier, quote_literal}; +use postgres::tokio::options::PgDatabaseOptions; use postgres_replication::LogicalReplicationStream; use rustls::{pki_types::CertificateDer, ClientConfig}; use thiserror::Error; use tokio_postgres::{ - config::{ReplicationMode, SslMode}, + config::ReplicationMode, types::{Kind, PgLsn, Type}, Client as PostgresClient, Config, CopyOutStream, NoTls, SimpleQueryMessage, }; use tokio_postgres_rustls::MakeRustlsConnect; use tracing::{info, warn}; -use crate::table::{ColumnSchema, TableId, TableName, TableSchema}; - pub struct SlotInfo { pub confirmed_flush_lsn: PgLsn, } @@ -63,25 +63,12 @@ pub enum ReplicationClientError { impl ReplicationClient { /// Connect to a postgres database in logical replication mode without TLS pub async fn connect_no_tls( - host: &str, - port: u16, - database: &str, - username: &str, - 
password: Option, + options: PgDatabaseOptions, ) -> Result { info!("connecting to postgres without TLS"); - let mut config = Config::new(); - config - .host(host) - .port(port) - .dbname(database) - .user(username) - .replication_mode(ReplicationMode::Logical); - - if let Some(password) = password { - config.password(password); - } + let mut config: Config = options.into(); + config.replication_mode(ReplicationMode::Logical); let (postgres_client, connection) = config.connect(NoTls).await?; @@ -102,28 +89,13 @@ impl ReplicationClient { /// Connect to a postgres database in logical replication mode with TLS pub async fn connect_tls( - host: &str, - port: u16, - database: &str, - username: &str, - password: Option, - ssl_mode: SslMode, + options: PgDatabaseOptions, trusted_root_certs: Vec>, ) -> Result { info!("connecting to postgres with TLS"); - let mut config = Config::new(); - config - .host(host) - .port(port) - .dbname(database) - .user(username) - .ssl_mode(ssl_mode) - .replication_mode(ReplicationMode::Logical); - - if let Some(password) = password { - config.password(password); - } + let mut config: Config = options.into(); + config.replication_mode(ReplicationMode::Logical); let mut root_store = rustls::RootCertStore::empty(); for trusted_root_cert in trusted_root_certs { diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index 9c100131..92fc5e0e 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -5,15 +5,6 @@ use std::{ time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH}, }; -use async_trait::async_trait; -use futures::{ready, Stream}; -use pin_project_lite::pin_project; -use postgres_replication::LogicalReplicationStream; -use rustls::pki_types::CertificateDer; -use thiserror::Error; -use tokio_postgres::{config::SslMode, types::PgLsn, CopyOutStream}; -use tracing::info; - use crate::{ clients::postgres::{ReplicationClient, ReplicationClientError}, conversions::{ @@ -22,6 +13,15 @@ use crate::{ }, table::{ColumnSchema, TableId, TableName, TableSchema}, }; +use async_trait::async_trait; +use futures::{ready, Stream}; +use pin_project_lite::pin_project; +use postgres::tokio::options::PgDatabaseOptions; +use postgres_replication::LogicalReplicationStream; +use rustls::pki_types::CertificateDer; +use thiserror::Error; +use tokio_postgres::{config::SslMode, types::PgLsn, CopyOutStream}; +use tracing::info; use super::{Source, SourceError}; @@ -54,39 +54,30 @@ pub struct PostgresSource { impl PostgresSource { #[allow(clippy::too_many_arguments)] pub async fn new( - host: &str, - port: u16, - database: &str, - username: &str, - password: Option, - ssl_mode: SslMode, + options: PgDatabaseOptions, trusted_root_certs: Vec>, slot_name: Option, table_names_from: TableNamesFrom, ) -> Result { - let mut replication_client = if ssl_mode == SslMode::Disable { - ReplicationClient::connect_no_tls(host, port, database, username, password).await? - } else { - ReplicationClient::connect_tls( - host, - port, - database, - username, - password, - ssl_mode, - trusted_root_certs, - ) - .await? + let mut replication_client = match options.ssl_mode { + SslMode::Disable => ReplicationClient::connect_no_tls(options).await?, + _ => ReplicationClient::connect_tls(options, trusted_root_certs).await?, }; + + // TODO: we have to fix this whole block which starts the transaction and loads the data. 
+ // We will not do this here in the future but rather let the pipeline drive this based on + // its internal state. replication_client.begin_readonly_transaction().await?; if let Some(ref slot_name) = slot_name { replication_client.get_or_create_slot(slot_name).await?; } + let (table_names, publication) = Self::get_table_names_and_publication(&replication_client, table_names_from).await?; let table_schemas = replication_client .get_table_schemas(&table_names, publication.as_deref()) .await?; + Ok(PostgresSource { replication_client, table_schemas, diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index f3305d3c..5bed63ad 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -25,6 +25,6 @@ tokio-postgres-rustls = { workspace = true } [features] -test_utils = [] +test-utils = [] tokio = [] sqlx = [] \ No newline at end of file diff --git a/postgres/src/sqlx/mod.rs b/postgres/src/sqlx/mod.rs index 37f84807..090183a4 100644 --- a/postgres/src/sqlx/mod.rs +++ b/postgres/src/sqlx/mod.rs @@ -1,3 +1,3 @@ pub mod options; -#[cfg(feature = "test_utils")] +#[cfg(feature = "test-utils")] pub mod test_utils; diff --git a/postgres/src/tokio/mod.rs b/postgres/src/tokio/mod.rs index 37f84807..090183a4 100644 --- a/postgres/src/tokio/mod.rs +++ b/postgres/src/tokio/mod.rs @@ -1,3 +1,3 @@ pub mod options; -#[cfg(feature = "test_utils")] +#[cfg(feature = "test-utils")] pub mod test_utils; diff --git a/postgres/src/tokio/options.rs b/postgres/src/tokio/options.rs index 01e40eb6..0a98ced9 100644 --- a/postgres/src/tokio/options.rs +++ b/postgres/src/tokio/options.rs @@ -5,11 +5,10 @@ use tokio_postgres::config::SslMode; pub struct PgDatabaseOptions { pub host: String, pub port: u16, - pub database: String, pub name: String, pub username: String, pub password: Option, - pub ssl_mode: Option, + pub ssl_mode: SslMode, } impl PgDatabaseOptions { @@ -17,7 +16,7 @@ impl PgDatabaseOptions { let mut this = self.clone(); // Postgres requires a database, so we default to the database which is equal to the username // since this seems to be the standard. 
- this.database = this.username.clone(); + this.name = this.username.clone(); this.into() } @@ -33,17 +32,14 @@ impl From for Config { config .host(value.host) .port(value.port) - .dbname(value.database) - .user(value.username); + .dbname(value.name) + .user(value.username) + .ssl_mode(value.ssl_mode); if let Some(password) = value.password { config.password(password); } - if let Some(ssl_mode) = value.ssl_mode { - config.ssl_mode(ssl_mode); - } - config } } diff --git a/replicator/Cargo.toml b/replicator/Cargo.toml index b24179cc..e5870055 100644 --- a/replicator/Cargo.toml +++ b/replicator/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] pg_replicate = { workspace = true, features = ["bigquery"] } +postgres = { workspace = true, features = ["tokio"] } telemetry = { workspace = true } anyhow = { workspace = true, features = ["std"] } diff --git a/replicator/src/main.rs b/replicator/src/main.rs index fef6b948..7df9290f 100644 --- a/replicator/src/main.rs +++ b/replicator/src/main.rs @@ -12,6 +12,7 @@ use pg_replicate::{ }, SslMode, }; +use postgres::tokio::options::PgDatabaseOptions; use telemetry::init_tracing; use tracing::{info, instrument}; @@ -106,13 +107,17 @@ async fn start_replication(settings: Settings) -> anyhow::Result<()> { SslMode::Disable }; - let postgres_source = PostgresSource::new( - &host, + let options = PgDatabaseOptions { + host, port, - &name, - &username, + name, + username, password, ssl_mode, + }; + + let postgres_source = PostgresSource::new( + options, trusted_root_certs_vec, Some(slot_name), TableNamesFrom::Publication(publication), From 97c066fbbba4721dc2c8d40a7c3905dc329a6932 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Mon, 19 May 2025 16:04:57 +0200 Subject: [PATCH 04/34] Add test facilities --- api/tests/common/mod.rs | 17 -- pg_replicate/examples/bigquery.rs | 2 +- pg_replicate/examples/duckdb.rs | 2 +- pg_replicate/examples/stdout.rs | 2 +- pg_replicate/src/clients/bigquery.rs | 193 +++++++++--------- pg_replicate/src/clients/duckdb.rs | 6 +- pg_replicate/src/clients/postgres.rs | 2 +- pg_replicate/src/conversions/cdc_event.rs | 6 +- pg_replicate/src/conversions/table_row.rs | 3 +- pg_replicate/src/lib.rs | 1 - .../src/pipeline/batching/data_pipeline.rs | 11 +- pg_replicate/src/pipeline/mod.rs | 3 +- pg_replicate/src/pipeline/sinks/bigquery.rs | 10 +- .../src/pipeline/sinks/duckdb/executor.rs | 2 +- .../src/pipeline/sinks/duckdb/sink.rs | 2 +- pg_replicate/src/pipeline/sinks/mod.rs | 6 +- pg_replicate/src/pipeline/sinks/stdout.rs | 2 +- pg_replicate/src/pipeline/sources/mod.rs | 5 +- pg_replicate/src/pipeline/sources/postgres.rs | 4 +- pg_replicate/tests/common/mod.rs | 1 + pg_replicate/tests/common/pipeline.rs | 88 ++++++++ pg_replicate/tests/integration/base.rs | 2 + pg_replicate/tests/integration/mod.rs | 1 + pg_replicate/tests/mod.rs | 3 +- postgres/Cargo.toml | 1 + postgres/src/lib.rs | 12 ++ .../src/table.rs => postgres/src/schema.rs | 6 +- postgres/src/tokio/options.rs | 26 +++ postgres/src/tokio/test_utils.rs | 32 ++- 29 files changed, 292 insertions(+), 159 deletions(-) create mode 100644 pg_replicate/tests/common/mod.rs create mode 100644 pg_replicate/tests/common/pipeline.rs create mode 100644 pg_replicate/tests/integration/base.rs create mode 100644 pg_replicate/tests/integration/mod.rs rename pg_replicate/src/table.rs => postgres/src/schema.rs (88%) diff --git a/api/tests/common/mod.rs b/api/tests/common/mod.rs index 5cc827d5..3f8ea5a3 100644 --- a/api/tests/common/mod.rs +++ b/api/tests/common/mod.rs @@ -1,19 +1,2 @@ -//! 
Common test utilities for pg_replicate API tests. -//! -//! This module provides shared functionality used across integration tests: -//! -//! - `test_app`: A test application wrapper that provides: -//! - A running instance of the API server for testing -//! - Helper methods for making authenticated HTTP requests -//! - Request/response type definitions for all API endpoints -//! - Methods to create, read, update, and delete resources -//! -//! - `database`: Database configuration utilities that: -//! - Set up test databases with proper configuration -//! - Handle database migrations -//! - Provide connection pools for tests -//! -//! These utilities help maintain consistency across tests and reduce code duplication. - pub mod database; pub mod test_app; diff --git a/pg_replicate/examples/bigquery.rs b/pg_replicate/examples/bigquery.rs index 7ae23225..e8f16429 100644 --- a/pg_replicate/examples/bigquery.rs +++ b/pg_replicate/examples/bigquery.rs @@ -8,9 +8,9 @@ use pg_replicate::{ sources::postgres::{PostgresSource, TableNamesFrom}, PipelineAction, }, - table::TableName, SslMode, }; +use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/pg_replicate/examples/duckdb.rs b/pg_replicate/examples/duckdb.rs index 90bb860d..a989de64 100644 --- a/pg_replicate/examples/duckdb.rs +++ b/pg_replicate/examples/duckdb.rs @@ -8,9 +8,9 @@ use pg_replicate::{ sources::postgres::{PostgresSource, TableNamesFrom}, PipelineAction, }, - table::TableName, SslMode, }; +use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/pg_replicate/examples/stdout.rs b/pg_replicate/examples/stdout.rs index 8121e8df..5172cbe2 100644 --- a/pg_replicate/examples/stdout.rs +++ b/pg_replicate/examples/stdout.rs @@ -8,9 +8,9 @@ use pg_replicate::{ sources::postgres::{PostgresSource, TableNamesFrom}, PipelineAction, }, - table::TableName, SslMode, }; +use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/pg_replicate/src/clients/bigquery.rs b/pg_replicate/src/clients/bigquery.rs index c22a62ee..6b1f213e 100644 --- a/pg_replicate/src/clients/bigquery.rs +++ b/pg_replicate/src/clients/bigquery.rs @@ -15,17 +15,15 @@ use gcp_bigquery_client::{ storage::{ColumnType, FieldDescriptor, StreamName, TableDescriptor}, Client, }; +use postgres::schema::{ColumnSchema, TableId, TableSchema}; use prost::Message; use tokio_postgres::types::{PgLsn, Type}; use tracing::info; use uuid::Uuid; use crate::conversions::numeric::PgNumeric; +use crate::conversions::table_row::TableRow; use crate::conversions::{ArrayCell, Cell}; -use crate::{ - conversions::table_row::TableRow, - table::{ColumnSchema, TableId, TableSchema}, -}; pub struct BigQueryClient { project_id: String, @@ -990,102 +988,105 @@ impl ArrayCell { } } -impl From<&TableSchema> for TableDescriptor { - fn from(table_schema: &TableSchema) -> Self { - let mut field_descriptors = Vec::with_capacity(table_schema.column_schemas.len()); - let mut number = 1; - for column_schema in &table_schema.column_schemas { - let typ = match column_schema.typ { - Type::BOOL => ColumnType::Bool, - Type::CHAR | Type::BPCHAR | Type::VARCHAR | Type::NAME | Type::TEXT => { - ColumnType::String - } - Type::INT2 => 
ColumnType::Int32, - Type::INT4 => ColumnType::Int32, - Type::INT8 => ColumnType::Int64, - Type::FLOAT4 => ColumnType::Float, - Type::FLOAT8 => ColumnType::Double, - Type::NUMERIC => ColumnType::String, - Type::DATE => ColumnType::String, - Type::TIME => ColumnType::String, - Type::TIMESTAMP => ColumnType::String, - Type::TIMESTAMPTZ => ColumnType::String, - Type::UUID => ColumnType::String, - Type::JSON => ColumnType::String, - Type::JSONB => ColumnType::String, - Type::OID => ColumnType::Int32, - Type::BYTEA => ColumnType::Bytes, - Type::BOOL_ARRAY => ColumnType::Bool, - Type::CHAR_ARRAY - | Type::BPCHAR_ARRAY - | Type::VARCHAR_ARRAY - | Type::NAME_ARRAY - | Type::TEXT_ARRAY => ColumnType::String, - Type::INT2_ARRAY => ColumnType::Int32, - Type::INT4_ARRAY => ColumnType::Int32, - Type::INT8_ARRAY => ColumnType::Int64, - Type::FLOAT4_ARRAY => ColumnType::Float, - Type::FLOAT8_ARRAY => ColumnType::Double, - Type::NUMERIC_ARRAY => ColumnType::String, - Type::DATE_ARRAY => ColumnType::String, - Type::TIME_ARRAY => ColumnType::String, - Type::TIMESTAMP_ARRAY => ColumnType::String, - Type::TIMESTAMPTZ_ARRAY => ColumnType::String, - Type::UUID_ARRAY => ColumnType::String, - Type::JSON_ARRAY => ColumnType::String, - Type::JSONB_ARRAY => ColumnType::String, - Type::OID_ARRAY => ColumnType::Int32, - Type::BYTEA_ARRAY => ColumnType::Bytes, - _ => ColumnType::String, - }; - - let mode = match column_schema.typ { - Type::BOOL_ARRAY - | Type::CHAR_ARRAY - | Type::BPCHAR_ARRAY - | Type::VARCHAR_ARRAY - | Type::NAME_ARRAY - | Type::TEXT_ARRAY - | Type::INT2_ARRAY - | Type::INT4_ARRAY - | Type::INT8_ARRAY - | Type::FLOAT4_ARRAY - | Type::FLOAT8_ARRAY - | Type::NUMERIC_ARRAY - | Type::DATE_ARRAY - | Type::TIME_ARRAY - | Type::TIMESTAMP_ARRAY - | Type::TIMESTAMPTZ_ARRAY - | Type::UUID_ARRAY - | Type::JSON_ARRAY - | Type::JSONB_ARRAY - | Type::OID_ARRAY - | Type::BYTEA_ARRAY => ColumnMode::Repeated, - _ => { - if column_schema.nullable { - ColumnMode::Nullable - } else { - ColumnMode::Required - } +/// Converts a [`TableSchema`] to [`TableDescriptor`]. +/// +/// This function is defined here and doesn't use the [`From`] trait because it's not possible since +/// [`TableSchema`] is in another crate and we don't want to pollute the `postgres` crate with sink +/// specific internals. 
+pub fn table_schema_to_descriptor(table_schema: &TableSchema) -> TableDescriptor { + let mut field_descriptors = Vec::with_capacity(table_schema.column_schemas.len()); + let mut number = 1; + for column_schema in &table_schema.column_schemas { + let typ = match column_schema.typ { + Type::BOOL => ColumnType::Bool, + Type::CHAR | Type::BPCHAR | Type::VARCHAR | Type::NAME | Type::TEXT => { + ColumnType::String + } + Type::INT2 => ColumnType::Int32, + Type::INT4 => ColumnType::Int32, + Type::INT8 => ColumnType::Int64, + Type::FLOAT4 => ColumnType::Float, + Type::FLOAT8 => ColumnType::Double, + Type::NUMERIC => ColumnType::String, + Type::DATE => ColumnType::String, + Type::TIME => ColumnType::String, + Type::TIMESTAMP => ColumnType::String, + Type::TIMESTAMPTZ => ColumnType::String, + Type::UUID => ColumnType::String, + Type::JSON => ColumnType::String, + Type::JSONB => ColumnType::String, + Type::OID => ColumnType::Int32, + Type::BYTEA => ColumnType::Bytes, + Type::BOOL_ARRAY => ColumnType::Bool, + Type::CHAR_ARRAY + | Type::BPCHAR_ARRAY + | Type::VARCHAR_ARRAY + | Type::NAME_ARRAY + | Type::TEXT_ARRAY => ColumnType::String, + Type::INT2_ARRAY => ColumnType::Int32, + Type::INT4_ARRAY => ColumnType::Int32, + Type::INT8_ARRAY => ColumnType::Int64, + Type::FLOAT4_ARRAY => ColumnType::Float, + Type::FLOAT8_ARRAY => ColumnType::Double, + Type::NUMERIC_ARRAY => ColumnType::String, + Type::DATE_ARRAY => ColumnType::String, + Type::TIME_ARRAY => ColumnType::String, + Type::TIMESTAMP_ARRAY => ColumnType::String, + Type::TIMESTAMPTZ_ARRAY => ColumnType::String, + Type::UUID_ARRAY => ColumnType::String, + Type::JSON_ARRAY => ColumnType::String, + Type::JSONB_ARRAY => ColumnType::String, + Type::OID_ARRAY => ColumnType::Int32, + Type::BYTEA_ARRAY => ColumnType::Bytes, + _ => ColumnType::String, + }; + + let mode = match column_schema.typ { + Type::BOOL_ARRAY + | Type::CHAR_ARRAY + | Type::BPCHAR_ARRAY + | Type::VARCHAR_ARRAY + | Type::NAME_ARRAY + | Type::TEXT_ARRAY + | Type::INT2_ARRAY + | Type::INT4_ARRAY + | Type::INT8_ARRAY + | Type::FLOAT4_ARRAY + | Type::FLOAT8_ARRAY + | Type::NUMERIC_ARRAY + | Type::DATE_ARRAY + | Type::TIME_ARRAY + | Type::TIMESTAMP_ARRAY + | Type::TIMESTAMPTZ_ARRAY + | Type::UUID_ARRAY + | Type::JSON_ARRAY + | Type::JSONB_ARRAY + | Type::OID_ARRAY + | Type::BYTEA_ARRAY => ColumnMode::Repeated, + _ => { + if column_schema.nullable { + ColumnMode::Nullable + } else { + ColumnMode::Required } - }; - - field_descriptors.push(FieldDescriptor { - number, - name: column_schema.name.clone(), - typ, - mode, - }); - number += 1; - } + } + }; field_descriptors.push(FieldDescriptor { number, - name: "_CHANGE_TYPE".to_string(), - typ: ColumnType::String, - mode: ColumnMode::Required, + name: column_schema.name.clone(), + typ, + mode, }); - - TableDescriptor { field_descriptors } + number += 1; } + + field_descriptors.push(FieldDescriptor { + number, + name: "_CHANGE_TYPE".to_string(), + typ: ColumnType::String, + mode: ColumnMode::Required, + }); + + TableDescriptor { field_descriptors } } diff --git a/pg_replicate/src/clients/duckdb.rs b/pg_replicate/src/clients/duckdb.rs index 389a29f6..e1421af4 100644 --- a/pg_replicate/src/clients/duckdb.rs +++ b/pg_replicate/src/clients/duckdb.rs @@ -5,12 +5,10 @@ use duckdb::{ types::{ToSqlOutput, Value}, Config, Connection, ToSql, }; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use tokio_postgres::types::{PgLsn, Type}; -use crate::{ - conversions::{table_row::TableRow, ArrayCell, Cell}, - table::{ColumnSchema, TableId, 
TableName, TableSchema}, -}; +use crate::conversions::{table_row::TableRow, ArrayCell, Cell}; pub struct DuckDbClient { conn: Connection, diff --git a/pg_replicate/src/clients/postgres.rs b/pg_replicate/src/clients/postgres.rs index e351d156..166da7ca 100644 --- a/pg_replicate/src/clients/postgres.rs +++ b/pg_replicate/src/clients/postgres.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; -use crate::table::{ColumnSchema, TableId, TableName, TableSchema}; use pg_escape::{quote_identifier, quote_literal}; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use postgres::tokio::options::PgDatabaseOptions; use postgres_replication::LogicalReplicationStream; use rustls::{pki_types::CertificateDer, ClientConfig}; diff --git a/pg_replicate/src/conversions/cdc_event.rs b/pg_replicate/src/conversions/cdc_event.rs index c3f1d935..a6f493f1 100644 --- a/pg_replicate/src/conversions/cdc_event.rs +++ b/pg_replicate/src/conversions/cdc_event.rs @@ -1,16 +1,14 @@ use core::str; use std::{collections::HashMap, str::Utf8Error}; +use postgres::schema::{ColumnSchema, TableId, TableSchema}; use postgres_replication::protocol::{ BeginBody, CommitBody, DeleteBody, InsertBody, LogicalReplicationMessage, OriginBody, RelationBody, ReplicationMessage, TruncateBody, TupleData, TypeBody, UpdateBody, }; use thiserror::Error; -use crate::{ - pipeline::batching::BatchBoundary, - table::{ColumnSchema, TableId, TableSchema}, -}; +use crate::pipeline::batching::BatchBoundary; use super::{ table_row::TableRow, diff --git a/pg_replicate/src/conversions/table_row.rs b/pg_replicate/src/conversions/table_row.rs index 411177d5..12a7dd96 100644 --- a/pg_replicate/src/conversions/table_row.rs +++ b/pg_replicate/src/conversions/table_row.rs @@ -1,6 +1,7 @@ use core::str; use std::str::Utf8Error; +use postgres::schema::ColumnSchema; use thiserror::Error; use tokio_postgres::types::Type; use tracing::error; @@ -44,7 +45,7 @@ impl TableRowConverter { // parses text produced by this code in Postgres: https://github.com/postgres/postgres/blob/263a3f5f7f508167dbeafc2aefd5835b41d77481/src/backend/commands/copyto.c#L988-L1134 pub fn try_from( row: &[u8], - column_schemas: &[crate::table::ColumnSchema], + column_schemas: &[ColumnSchema], ) -> Result { let mut values = Vec::with_capacity(column_schemas.len()); diff --git a/pg_replicate/src/lib.rs b/pg_replicate/src/lib.rs index d725db6a..5eab8fde 100644 --- a/pg_replicate/src/lib.rs +++ b/pg_replicate/src/lib.rs @@ -1,5 +1,4 @@ pub mod clients; pub mod conversions; pub mod pipeline; -pub mod table; pub use tokio_postgres::config::SslMode; diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index e1f371fd..8c506fa9 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -1,10 +1,5 @@ use std::{collections::HashSet, time::Instant}; -use futures::StreamExt; -use tokio::pin; -use tokio_postgres::types::PgLsn; -use tracing::{debug, info}; - use crate::{ conversions::cdc_event::{CdcEvent, CdcEventConversionError}, pipeline::{ @@ -13,8 +8,12 @@ use crate::{ sources::{postgres::CdcStreamError, CommonSourceError, Source}, PipelineAction, PipelineError, }, - table::TableId, }; +use futures::StreamExt; +use postgres::schema::TableId; +use tokio::pin; +use tokio_postgres::types::PgLsn; +use tracing::{debug, info}; use super::BatchConfig; diff --git a/pg_replicate/src/pipeline/mod.rs b/pg_replicate/src/pipeline/mod.rs index b64d1637..ff734e8f 100644 
--- a/pg_replicate/src/pipeline/mod.rs +++ b/pg_replicate/src/pipeline/mod.rs @@ -1,12 +1,11 @@ use std::collections::HashSet; +use postgres::schema::TableId; use sinks::SinkError; use sources::SourceError; use thiserror::Error; use tokio_postgres::types::PgLsn; -use crate::table::TableId; - pub mod batching; pub mod sinks; pub mod sources; diff --git a/pg_replicate/src/pipeline/sinks/bigquery.rs b/pg_replicate/src/pipeline/sinks/bigquery.rs index 57048f59..20330a28 100644 --- a/pg_replicate/src/pipeline/sinks/bigquery.rs +++ b/pg_replicate/src/pipeline/sinks/bigquery.rs @@ -2,19 +2,19 @@ use std::collections::HashMap; use async_trait::async_trait; use gcp_bigquery_client::error::BQError; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use thiserror::Error; use tokio_postgres::types::{PgLsn, Type}; use tracing::info; +use super::{BatchSink, SinkError}; +use crate::clients::bigquery::table_schema_to_descriptor; use crate::{ clients::bigquery::BigQueryClient, conversions::{cdc_event::CdcEvent, table_row::TableRow, Cell}, pipeline::PipelineResumptionState, - table::{ColumnSchema, TableId, TableName, TableSchema}, }; -use super::{BatchSink, SinkError}; - #[derive(Debug, Error)] pub enum BigQuerySinkError { #[error("big query error: {0}")] @@ -183,7 +183,7 @@ impl BatchSink for BigQueryBatchSink { ) -> Result<(), Self::Error> { let table_schema = self.get_table_schema(table_id)?; let table_name = Self::table_name_in_bq(&table_schema.table_name); - let table_descriptor = table_schema.into(); + let table_descriptor = table_schema_to_descriptor(table_schema); for table_row in &mut table_rows { table_row.values.push(Cell::String("UPSERT".to_string())); @@ -250,7 +250,7 @@ impl BatchSink for BigQueryBatchSink { for (table_id, table_rows) in table_name_to_table_rows { let table_schema = self.get_table_schema(table_id)?; let table_name = Self::table_name_in_bq(&table_schema.table_name); - let table_descriptor = table_schema.into(); + let table_descriptor = table_schema_to_descriptor(table_schema); self.client .stream_rows(&self.dataset_id, table_name, &table_descriptor, &table_rows) .await?; diff --git a/pg_replicate/src/pipeline/sinks/duckdb/executor.rs b/pg_replicate/src/pipeline/sinks/duckdb/executor.rs index 905d56e6..b6cdd64a 100644 --- a/pg_replicate/src/pipeline/sinks/duckdb/executor.rs +++ b/pg_replicate/src/pipeline/sinks/duckdb/executor.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use thiserror::Error; use tokio::sync::mpsc::{error::SendError, Receiver, Sender}; use tokio_postgres::types::{PgLsn, Type}; @@ -9,7 +10,6 @@ use crate::{ clients::duckdb::DuckDbClient, conversions::{cdc_event::CdcEvent, table_row::TableRow}, pipeline::{sinks::SinkError, PipelineResumptionState}, - table::{ColumnSchema, TableId, TableName, TableSchema}, }; pub enum DuckDbRequest { diff --git a/pg_replicate/src/pipeline/sinks/duckdb/sink.rs b/pg_replicate/src/pipeline/sinks/duckdb/sink.rs index 48393255..11f1171d 100644 --- a/pg_replicate/src/pipeline/sinks/duckdb/sink.rs +++ b/pg_replicate/src/pipeline/sinks/duckdb/sink.rs @@ -3,13 +3,13 @@ use std::{collections::HashMap, path::Path}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use async_trait::async_trait; +use postgres::schema::{TableId, TableSchema}; use tokio_postgres::types::PgLsn; use crate::{ clients::duckdb::DuckDbClient, conversions::{cdc_event::CdcEvent, table_row::TableRow}, pipeline::{sinks::BatchSink, PipelineResumptionState}, - table::{TableId, 
TableSchema}, }; use super::{ diff --git a/pg_replicate/src/pipeline/sinks/mod.rs b/pg_replicate/src/pipeline/sinks/mod.rs index 94e9c101..79b6d2cb 100644 --- a/pg_replicate/src/pipeline/sinks/mod.rs +++ b/pg_replicate/src/pipeline/sinks/mod.rs @@ -1,13 +1,11 @@ use std::collections::HashMap; use async_trait::async_trait; +use postgres::schema::{TableId, TableSchema}; use thiserror::Error; use tokio_postgres::types::PgLsn; -use crate::{ - conversions::{cdc_event::CdcEvent, table_row::TableRow}, - table::{TableId, TableSchema}, -}; +use crate::conversions::{cdc_event::CdcEvent, table_row::TableRow}; use super::PipelineResumptionState; diff --git a/pg_replicate/src/pipeline/sinks/stdout.rs b/pg_replicate/src/pipeline/sinks/stdout.rs index 6d79dd1d..9e98525b 100644 --- a/pg_replicate/src/pipeline/sinks/stdout.rs +++ b/pg_replicate/src/pipeline/sinks/stdout.rs @@ -1,13 +1,13 @@ use std::collections::{HashMap, HashSet}; use async_trait::async_trait; +use postgres::schema::{TableId, TableSchema}; use tokio_postgres::types::PgLsn; use tracing::info; use crate::{ conversions::{cdc_event::CdcEvent, table_row::TableRow}, pipeline::PipelineResumptionState, - table::{TableId, TableSchema}, }; use super::{BatchSink, InfallibleSinkError}; diff --git a/pg_replicate/src/pipeline/sources/mod.rs b/pg_replicate/src/pipeline/sources/mod.rs index d35d1643..51f2feaf 100644 --- a/pg_replicate/src/pipeline/sources/mod.rs +++ b/pg_replicate/src/pipeline/sources/mod.rs @@ -1,12 +1,11 @@ use std::collections::HashMap; +use ::postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use async_trait::async_trait; use thiserror::Error; use tokio_postgres::types::PgLsn; -use crate::table::{ColumnSchema, TableId, TableName, TableSchema}; - -use self::postgres::{ +use postgres::{ CdcStream, CdcStreamError, PostgresSourceError, StatusUpdateError, TableCopyStream, TableCopyStreamError, }; diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index 92fc5e0e..b6b92e36 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -11,11 +11,12 @@ use crate::{ cdc_event::{CdcEvent, CdcEventConversionError, CdcEventConverter}, table_row::{TableRow, TableRowConversionError, TableRowConverter}, }, - table::{ColumnSchema, TableId, TableName, TableSchema}, }; + use async_trait::async_trait; use futures::{ready, Stream}; use pin_project_lite::pin_project; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use postgres::tokio::options::PgDatabaseOptions; use postgres_replication::LogicalReplicationStream; use rustls::pki_types::CertificateDer; @@ -52,7 +53,6 @@ pub struct PostgresSource { } impl PostgresSource { - #[allow(clippy::too_many_arguments)] pub async fn new( options: PgDatabaseOptions, trusted_root_certs: Vec>, diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs new file mode 100644 index 00000000..626c2e4c --- /dev/null +++ b/pg_replicate/tests/common/mod.rs @@ -0,0 +1 @@ +pub mod pipeline; diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs new file mode 100644 index 00000000..4715b783 --- /dev/null +++ b/pg_replicate/tests/common/pipeline.rs @@ -0,0 +1,88 @@ +use pg_replicate::pipeline::batching::data_pipeline::BatchDataPipeline; +use pg_replicate::pipeline::batching::BatchConfig; +use pg_replicate::pipeline::sinks::BatchSink; +use pg_replicate::pipeline::sources::postgres::{PostgresSource, TableNamesFrom}; +use 
pg_replicate::pipeline::PipelineAction; +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; +use postgres::tokio::test_utils::PgDatabase; +use std::time::Duration; +use tokio_postgres::config::SslMode; +use uuid::Uuid; + +pub enum PipelineMode { + /// In this mode the supplied tables will be copied. + CopyTable { table_names: Vec }, + /// In this mode the changes will be consumed from the given publication and slot. + /// + /// If the slot is not supplied, a new one will be created on the supplied publication. + Cdc { + publication: String, + slot_name: Option, + }, +} + +pub async fn spawn_database_with_publication( + table_names: Vec, + publication_name: Option, +) -> PgDatabase { + let options = PgDatabaseOptions { + host: "localhost".to_owned(), + port: 540, + name: Uuid::new_v4().to_string(), + username: "postgres".to_owned(), + password: Some("postgres".to_owned()), + ssl_mode: SslMode::Disable, + }; + + let database = PgDatabase::new(options).await; + + if let Some(publication_name) = publication_name { + database + .create_publication(&publication_name, &table_names) + .await + .expect("Error while creating a publication"); + } + + database +} + +pub async fn spawn_pg_pipeline( + options: &PgDatabaseOptions, + mode: PipelineMode, + sink: Snk, +) -> BatchDataPipeline { + let batch_config = BatchConfig::new(1000, Duration::from_secs(10)); + + let pipeline = match mode { + PipelineMode::CopyTable { table_names } => { + let source = PostgresSource::new( + options.clone(), + vec![], + None, + TableNamesFrom::Vec(table_names), + ) + .await + .expect("Failure when creating the Postgres source for copying tables"); + let action = PipelineAction::TableCopiesOnly; + BatchDataPipeline::new(source, sink, action, batch_config) + } + PipelineMode::Cdc { + publication, + slot_name, + } => { + let source = PostgresSource::new( + options.clone(), + vec![], + slot_name, + TableNamesFrom::Publication(publication), + ) + .await + .expect("Failure when creating the Postgres source for cdc"); + let action = PipelineAction::CdcOnly; + BatchDataPipeline::new(source, sink, action, batch_config) + } + }; + + pipeline +} diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs new file mode 100644 index 00000000..da58d16c --- /dev/null +++ b/pg_replicate/tests/integration/base.rs @@ -0,0 +1,2 @@ +#[tokio::test(flavor = "multi_thread")] +async fn test_simple() {} diff --git a/pg_replicate/tests/integration/mod.rs b/pg_replicate/tests/integration/mod.rs new file mode 100644 index 00000000..77ed8456 --- /dev/null +++ b/pg_replicate/tests/integration/mod.rs @@ -0,0 +1 @@ +mod base; diff --git a/pg_replicate/tests/mod.rs b/pg_replicate/tests/mod.rs index 8b137891..e78302ae 100644 --- a/pg_replicate/tests/mod.rs +++ b/pg_replicate/tests/mod.rs @@ -1 +1,2 @@ - +mod common; +mod integration; diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 5bed63ad..fcab8197 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +pg_escape = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } secrecy = { workspace = true, features = ["serde", "alloc"] } diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index 0d348139..7ae1123c 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -1,3 +1,15 @@ +//! PostgreSQL database connection utilities for all crates. +//! +//! 
This crate provides database connection options and utilities for working with PostgreSQL. +//! It supports both the [`sqlx`] and [`tokio-postgres`] crates through feature flags. +//! +//! # Features +//! +//! - `sqlx`: Enables SQLx-specific database connection options and utilities +//! - `tokio`: Enables tokio-postgres-specific database connection options and utilities +//! - `test-utils`: Enables test utilities for both SQLx and tokio-postgres implementations + +pub mod schema; #[cfg(feature = "sqlx")] pub mod sqlx; #[cfg(feature = "tokio")] diff --git a/pg_replicate/src/table.rs b/postgres/src/schema.rs similarity index 88% rename from pg_replicate/src/table.rs rename to postgres/src/schema.rs index 73fabfd1..bcafd36f 100644 --- a/pg_replicate/src/table.rs +++ b/postgres/src/schema.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::fmt; use pg_escape::quote_identifier; use tokio_postgres::types::Type; @@ -17,8 +17,8 @@ impl TableName { } } -impl Display for TableName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for TableName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_fmt(format_args!("{0}.{1}", self.schema, self.name)) } } diff --git a/postgres/src/tokio/options.rs b/postgres/src/tokio/options.rs index 0a98ced9..4bb298f2 100644 --- a/postgres/src/tokio/options.rs +++ b/postgres/src/tokio/options.rs @@ -1,17 +1,35 @@ use tokio_postgres::Config; use tokio_postgres::config::SslMode; +/// Connection options for a PostgreSQL database. +/// +/// Contains the connection parameters needed to establish a connection to a PostgreSQL +/// database server, including network location, authentication credentials, and security +/// settings. #[derive(Debug, Clone)] pub struct PgDatabaseOptions { + /// Host name or IP address of the PostgreSQL server pub host: String, + /// Port number that the PostgreSQL server listens on pub port: u16, + /// Name of the target database pub name: String, + /// Username for authentication pub username: String, + /// Optional password for authentication pub password: Option, + /// SSL mode for the connection pub ssl_mode: SslMode, } impl PgDatabaseOptions { + /// Creates connection options for connecting to the PostgreSQL server without + /// specifying a database. + /// + /// Returns [`Config`] configured with the host, port, username, SSL mode and optional + /// password from this instance. The database name is set to the username as per + /// PostgreSQL convention. Useful for administrative operations that must be performed + /// before connecting to a specific database, like database creation. pub fn without_db(&self) -> Config { let mut this = self.clone(); // Postgres requires a database, so we default to the database which is equal to the username @@ -21,12 +39,20 @@ impl PgDatabaseOptions { this.into() } + /// Creates connection options for connecting to a specific database. + /// + /// Returns [`Config`] configured with all connection parameters including the database + /// name from this instance. pub fn with_db(&self) -> Config { self.clone().into() } } impl From for Config { + /// Converts [`PgDatabaseOptions`] into a [`Config`] instance. + /// + /// Sets all connection parameters including host, port, database name, username, + /// SSL mode, and optional password. 
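A rough sketch of how callers are expected to use this options type; the host, port, and credentials below are placeholder values, not settings taken from this patch:

```rust
use postgres::tokio::options::PgDatabaseOptions;
use tokio_postgres::{config::SslMode, Config};

// Placeholder values; a real caller would read these from its own configuration.
fn example_configs() -> (Config, Config) {
    let options = PgDatabaseOptions {
        host: "localhost".to_owned(),
        port: 5432,
        name: "replicate_db".to_owned(),
        username: "postgres".to_owned(),
        password: None,
        ssl_mode: SslMode::Prefer,
    };

    // Administrative connection (e.g. for CREATE DATABASE): no database is selected,
    // so Postgres falls back to the database named after the user.
    let admin = options.without_db();
    // Connection to the configured database itself.
    let db = options.with_db();

    (admin, db)
}
```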
fn from(value: PgDatabaseOptions) -> Self { let mut config = Config::new(); config diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 2672d824..e7e10488 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -1,10 +1,11 @@ +use crate::schema::TableName; use crate::tokio::options::PgDatabaseOptions; use tokio::runtime::Handle; use tokio_postgres::{Client, NoTls}; -struct PgDatabase { - options: PgDatabaseOptions, - client: Client, +pub struct PgDatabase { + pub options: PgDatabaseOptions, + pub client: Client, } impl PgDatabase { @@ -13,6 +14,31 @@ impl PgDatabase { Self { options, client } } + + /// Creates a new publication for the specified tables. + pub async fn create_publication( + &self, + publication_name: &str, + table_names: &[TableName], + ) -> Result { + let table_names = table_names + .iter() + .map(|t| t.to_string()) + .collect::>(); + + let create_publication_query = format!( + "CREATE PUBLICATION {} FOR TABLE {}", + publication_name, + table_names + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") + ); + self.client.execute(&create_publication_query, &[]).await?; + + Ok(publication_name.to_string()) + } } impl Drop for PgDatabase { From 467aab1937f79d9f59ccb4f8d75aafc97d977201 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Mon, 19 May 2025 16:55:49 +0200 Subject: [PATCH 05/34] Improve --- pg_replicate/tests/common/database.rs | 34 ++++++++++++ pg_replicate/tests/common/mod.rs | 11 ++++ pg_replicate/tests/common/pipeline.rs | 28 ---------- pg_replicate/tests/integration/base.rs | 75 +++++++++++++++++++++++++- postgres/src/schema.rs | 36 +++++++++++++ postgres/src/tokio/test_utils.rs | 64 ++++++++++++++++++++++ 6 files changed, 219 insertions(+), 29 deletions(-) create mode 100644 pg_replicate/tests/common/database.rs diff --git a/pg_replicate/tests/common/database.rs b/pg_replicate/tests/common/database.rs new file mode 100644 index 00000000..ffba7cdd --- /dev/null +++ b/pg_replicate/tests/common/database.rs @@ -0,0 +1,34 @@ +use tokio_postgres::config::SslMode; +use uuid::Uuid; +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; +use postgres::tokio::test_utils::PgDatabase; + +pub async fn spawn_database() -> PgDatabase { + let options = PgDatabaseOptions { + host: "localhost".to_owned(), + port: 5430, + name: Uuid::new_v4().to_string(), + username: "postgres".to_owned(), + password: Some("postgres".to_owned()), + ssl_mode: SslMode::Disable, + }; + + PgDatabase::new(options).await +} + +pub async fn spawn_database_with_publication( + table_names: Vec, + publication_name: Option, +) -> PgDatabase { + let database = spawn_database().await; + + if let Some(publication_name) = publication_name { + database + .create_publication(&publication_name, &table_names) + .await + .expect("Error while creating a publication"); + } + + database +} \ No newline at end of file diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index 626c2e4c..6305ea2c 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -1 +1,12 @@ +use postgres::schema::TableName; + +pub mod database; pub mod pipeline; + +/// Creates a [`TableName`] on the `test` schema. 
+pub fn test_table_name(name: &str) -> TableName { + TableName { + schema: "test".to_owned(), + name: name.to_owned(), + } +} diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 4715b783..65efbf2b 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -5,10 +5,7 @@ use pg_replicate::pipeline::sources::postgres::{PostgresSource, TableNamesFrom}; use pg_replicate::pipeline::PipelineAction; use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; -use postgres::tokio::test_utils::PgDatabase; use std::time::Duration; -use tokio_postgres::config::SslMode; -use uuid::Uuid; pub enum PipelineMode { /// In this mode the supplied tables will be copied. @@ -22,31 +19,6 @@ pub enum PipelineMode { }, } -pub async fn spawn_database_with_publication( - table_names: Vec, - publication_name: Option, -) -> PgDatabase { - let options = PgDatabaseOptions { - host: "localhost".to_owned(), - port: 540, - name: Uuid::new_v4().to_string(), - username: "postgres".to_owned(), - password: Some("postgres".to_owned()), - ssl_mode: SslMode::Disable, - }; - - let database = PgDatabase::new(options).await; - - if let Some(publication_name) = publication_name { - database - .create_publication(&publication_name, &table_names) - .await - .expect("Error while creating a publication"); - } - - database -} - pub async fn spawn_pg_pipeline( options: &PgDatabaseOptions, mode: PipelineMode, diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index da58d16c..fda6c308 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -1,2 +1,75 @@ +use std::collections::HashMap; +use async_trait::async_trait; +use tokio_postgres::types::PgLsn; +use pg_replicate::conversions::cdc_event::CdcEvent; +use pg_replicate::conversions::table_row::TableRow; +use pg_replicate::pipeline::PipelineResumptionState; +use pg_replicate::pipeline::sinks::{BatchSink, InfallibleSinkError}; +use postgres::schema::{TableId, TableSchema}; +use postgres::tokio::test_utils::PgDatabase; +use crate::common::database::spawn_database; +use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; +use crate::common::test_table_name; + +struct SumSink { + +} + +#[async_trait] +impl BatchSink for SumSink { + type Error = InfallibleSinkError; + + async fn get_resumption_state(&mut self) -> Result { + todo!() + } + + async fn write_table_schemas(&mut self, table_schemas: HashMap) -> Result<(), Self::Error> { + todo!() + } + + async fn write_table_rows(&mut self, rows: Vec, table_id: TableId) -> Result<(), Self::Error> { + todo!() + } + + async fn write_cdc_events(&mut self, events: Vec) -> Result { + todo!() + } + + async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error> { + todo!() + } + + async fn truncate_table(&mut self, table_id: TableId) -> Result<(), Self::Error> { + todo!() + } +} + +async fn create_and_fill_users_table(database: &PgDatabase, num_rows: u32) { + database.create_table(test_table_name("users"), &vec![("age", "integer")]).await.unwrap(); + + for i in 0..num_rows { + let age = i + 1; + database.insert_values( + test_table_name("users"), + &["age"], + &[&age], + ).await.unwrap(); + } +} + #[tokio::test(flavor = "multi_thread")] -async fn test_simple() {} +async fn test_simple_table_copy() { + let database = spawn_database().await; + create_and_fill_users_table(&database, 100).await; + + let sink = SumSink {}; + let pipeline = 
spawn_pg_pipeline( + &database.options, + PipelineMode::CopyTable { + table_names: vec![test_table_name("users")] + }, + sink + ); + + +} diff --git a/postgres/src/schema.rs b/postgres/src/schema.rs index bcafd36f..6fcf565e 100644 --- a/postgres/src/schema.rs +++ b/postgres/src/schema.rs @@ -3,13 +3,23 @@ use std::fmt; use pg_escape::quote_identifier; use tokio_postgres::types::Type; +/// A fully qualified PostgreSQL table name consisting of a schema and table name. +/// +/// This type represents a table identifier in PostgreSQL, which requires both a schema name +/// and a table name. It provides methods for formatting the name in different contexts. #[derive(Debug, Clone)] pub struct TableName { + /// The schema name containing the table pub schema: String, + /// The name of the table within the schema pub name: String, } impl TableName { + /// Returns the table name as a properly quoted PostgreSQL identifier. + /// + /// This method ensures the schema and table names are properly escaped according to + /// PostgreSQL identifier quoting rules. pub fn as_quoted_identifier(&self) -> String { let quoted_schema = quote_identifier(&self.schema); let quoted_name = quote_identifier(&self.name); @@ -23,27 +33,53 @@ impl fmt::Display for TableName { } } +/// A type alias for PostgreSQL type modifiers. +/// +/// Type modifiers in PostgreSQL are used to specify additional type-specific attributes, +/// such as length for varchar or precision for numeric types. type TypeModifier = i32; +/// Represents the schema of a single column in a PostgreSQL table. +/// +/// This type contains all metadata about a column including its name, data type, +/// type modifier, nullability, and whether it's part of the primary key. #[derive(Debug, Clone)] pub struct ColumnSchema { + /// The name of the column pub name: String, + /// The PostgreSQL data type of the column pub typ: Type, + /// Type-specific modifier value (e.g., length for varchar) pub modifier: TypeModifier, + /// Whether the column can contain NULL values pub nullable: bool, + /// Whether the column is part of the table's primary key pub primary: bool, } +/// A type alias for PostgreSQL table OIDs. +/// +/// Table OIDs are unique identifiers assigned to tables in PostgreSQL. pub type TableId = u32; +/// Represents the complete schema of a PostgreSQL table. +/// +/// This type contains all metadata about a table including its name, OID, +/// and the schemas of all its columns. #[derive(Debug, Clone)] pub struct TableSchema { + /// The fully qualified name of the table pub table_name: TableName, + /// The PostgreSQL OID of the table pub table_id: TableId, + /// The schemas of all columns in the table pub column_schemas: Vec, } impl TableSchema { + /// Returns whether the table has any primary key columns. + /// + /// This method checks if any column in the table is marked as part of the primary key. pub fn has_primary_keys(&self) -> bool { self.column_schemas.iter().any(|cs| cs.primary) } diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index e7e10488..9f72fe14 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -39,6 +39,70 @@ impl PgDatabase { Ok(publication_name.to_string()) } + + /// Creates a new table with the specified name and columns. 
+ pub async fn create_table( + &self, + table_name: TableName, + columns: &[(&str, &str)], // (column_name, column_type) + ) -> Result<(), tokio_postgres::Error> { + let columns_str = columns + .iter() + .map(|(name, typ)| format!("{} {}", name, typ.to_string())) + .collect::>() + .join(", "); + + let create_table_query = format!( + "CREATE TABLE {} (id BIGSERIAL PRIMARY KEY, {})", + table_name.as_quoted_identifier(), + columns_str + ); + self.client.execute(&create_table_query, &[]).await?; + Ok(()) + } + + /// Inserts values into the specified table. + pub async fn insert_values( + &self, + table_name: TableName, + columns: &[&str], + values: &[&(dyn tokio_postgres::types::ToSql + Sync)], + ) -> Result { + let columns_str = columns.join(", "); + let placeholders: Vec = (1..=values.len()).map(|i| format!("${}", i)).collect(); + let placeholders_str = placeholders.join(", "); + + let insert_query = format!( + "INSERT INTO {} ({}) VALUES ({})", + table_name.as_quoted_identifier(), + columns_str, + placeholders_str + ); + + self.client.execute(&insert_query, values).await + } + + /// Queries rows from a single column of a table. + pub async fn query_table( + &self, + table_name: &TableName, + column: &str, + where_clause: Option<&str>, + ) -> Result, tokio_postgres::Error> + where + T: for<'a> tokio_postgres::types::FromSql<'a>, + { + let where_str = where_clause.map_or(String::new(), |w| format!(" WHERE {}", w)); + let query = format!( + "SELECT {} FROM {}{}", + column, + table_name.as_quoted_identifier(), + where_str + ); + + let rows = self.client.query(&query, &[]).await?; + Ok(rows.iter().map(|row| row.get(0)).collect()) + } } impl Drop for PgDatabase { From bbcfdbd1c87e84ca4005926dc5d63cc6099ebe84 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 10:41:57 +0200 Subject: [PATCH 06/34] Improve --- .../src/pipeline/batching/data_pipeline.rs | 8 + pg_replicate/tests/common/database.rs | 50 ++++- pg_replicate/tests/common/mod.rs | 10 - pg_replicate/tests/integration/base.rs | 188 ++++++++++++++---- postgres/src/schema.rs | 2 +- postgres/src/tokio/test_utils.rs | 27 ++- 6 files changed, 226 insertions(+), 59 deletions(-) diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 8c506fa9..2ba08a55 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -200,4 +200,12 @@ impl BatchDataPipeline { Ok(()) } + + pub fn source(&self) -> &Src { + &self.source + } + + pub fn sink(&self) -> &Snk { + &self.sink + } } diff --git a/pg_replicate/tests/common/database.rs b/pg_replicate/tests/common/database.rs index ffba7cdd..c48a7ae5 100644 --- a/pg_replicate/tests/common/database.rs +++ b/pg_replicate/tests/common/database.rs @@ -1,22 +1,64 @@ -use tokio_postgres::config::SslMode; -use uuid::Uuid; use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; use postgres::tokio::test_utils::PgDatabase; +use tokio_postgres::config::SslMode; +use uuid::Uuid; +/// The default schema name used for test tables. +const TEST_DATABASE_SCHEMA: &str = "test"; + +/// Creates a [`TableName`] in the test schema. +/// +/// This helper function constructs a [`TableName`] with the schema set to the test schema +/// and the provided name as the table name. +pub fn test_table_name(name: &str) -> TableName { + TableName { + schema: TEST_DATABASE_SCHEMA.to_owned(), + name: name.to_owned(), + } +} + +/// Creates a new test database instance. 
+/// +/// This function spawns a new PostgreSQL database with a random UUID as its name, +/// using default credentials and disabled SSL. It also creates a test schema +/// for organizing test tables. +/// +/// # Panics +/// +/// Panics if the test schema cannot be created. pub async fn spawn_database() -> PgDatabase { let options = PgDatabaseOptions { host: "localhost".to_owned(), port: 5430, + // We create a random database name to avoid conflicts with existing databases. name: Uuid::new_v4().to_string(), username: "postgres".to_owned(), password: Some("postgres".to_owned()), ssl_mode: SslMode::Disable, }; - PgDatabase::new(options).await + let database = PgDatabase::new(options).await; + + // Create the test schema. + database + .client + .execute(&format!("CREATE SCHEMA {}", TEST_DATABASE_SCHEMA), &[]) + .await + .expect("Failed to create test schema"); + + database } +/// Creates a new test database instance with an optional publication. +/// +/// This function creates a test database and optionally sets up a publication +/// for the specified tables. The publication can be used for testing replication +/// scenarios. +/// +/// # Panics +/// +/// Panics if the publication cannot be created. pub async fn spawn_database_with_publication( table_names: Vec, publication_name: Option, @@ -31,4 +73,4 @@ pub async fn spawn_database_with_publication( } database -} \ No newline at end of file +} diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index 6305ea2c..50308081 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -1,12 +1,2 @@ -use postgres::schema::TableName; - pub mod database; pub mod pipeline; - -/// Creates a [`TableName`] on the `test` schema. -pub fn test_table_name(name: &str) -> TableName { - TableName { - schema: "test".to_owned(), - name: name.to_owned(), - } -} diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index fda6c308..8ad922ae 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -1,75 +1,189 @@ -use std::collections::HashMap; +use crate::common::database::{spawn_database, test_table_name}; +use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; use async_trait::async_trait; -use tokio_postgres::types::PgLsn; use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::table_row::TableRow; -use pg_replicate::pipeline::PipelineResumptionState; +use pg_replicate::conversions::Cell; use pg_replicate::pipeline::sinks::{BatchSink, InfallibleSinkError}; -use postgres::schema::{TableId, TableSchema}; +use pg_replicate::pipeline::PipelineResumptionState; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use postgres::tokio::test_utils::PgDatabase; -use crate::common::database::spawn_database; -use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; -use crate::common::test_table_name; +use std::collections::{HashMap, HashSet}; +use tokio_postgres::types::{PgLsn, Type}; -struct SumSink { - +struct TestSink { + table_schemas: HashMap, + table_rows: HashMap>, + events: Vec, + tables_copied: u8, + tables_truncated: u8, +} + +impl TestSink { + fn new() -> Self { + Self { + table_schemas: HashMap::new(), + table_rows: HashMap::new(), + events: Vec::new(), + tables_copied: 0, + tables_truncated: 0, + } + } } #[async_trait] -impl BatchSink for SumSink { +impl BatchSink for TestSink { type Error = InfallibleSinkError; async fn get_resumption_state(&mut self) -> Result { - 
todo!() + Ok(PipelineResumptionState { + copied_tables: HashSet::new(), + last_lsn: PgLsn::from(0), + }) } - async fn write_table_schemas(&mut self, table_schemas: HashMap) -> Result<(), Self::Error> { - todo!() + async fn write_table_schemas( + &mut self, + table_schemas: HashMap, + ) -> Result<(), Self::Error> { + self.table_schemas = table_schemas; + Ok(()) } - async fn write_table_rows(&mut self, rows: Vec, table_id: TableId) -> Result<(), Self::Error> { - todo!() + async fn write_table_rows( + &mut self, + rows: Vec, + table_id: TableId, + ) -> Result<(), Self::Error> { + self.table_rows.entry(table_id).or_default().extend(rows); + Ok(()) } async fn write_cdc_events(&mut self, events: Vec) -> Result { - todo!() + self.events.extend(events); + Ok(PgLsn::from(0)) } - async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error> { - todo!() + async fn table_copied(&mut self, _table_id: TableId) -> Result<(), Self::Error> { + self.tables_copied += 1; + Ok(()) } - async fn truncate_table(&mut self, table_id: TableId) -> Result<(), Self::Error> { - todo!() + async fn truncate_table(&mut self, _table_id: TableId) -> Result<(), Self::Error> { + self.tables_truncated += 1; + Ok(()) + } +} + +async fn create_and_fill_users_table(database: &PgDatabase, num_users: usize) -> TableId { + let table_id = database + .create_table(test_table_name("users"), &vec![("age", "integer")]) + .await + .unwrap(); + + for i in 0..num_users { + let age = i as i32 + 1; + database + .insert_values(test_table_name("users"), &["age"], &[&age]) + .await + .unwrap(); + } + + table_id +} + +fn assert_table_schema( + sink: &TestSink, + table_id: TableId, + expected_table_name: TableName, + expected_columns: &[ColumnSchema], +) { + let table_schema = sink.table_schemas.get(&table_id).unwrap(); + + assert_eq!(table_schema.table_id, table_id); + assert_eq!(table_schema.table_name, expected_table_name); + + let columns = &table_schema.column_schemas; + assert_eq!(columns.len(), expected_columns.len()); + + for (actual, expected) in columns.iter().zip(expected_columns.iter()) { + assert_eq!(actual.name, expected.name); + assert_eq!(actual.typ, expected.typ); + assert_eq!(actual.modifier, expected.modifier); + assert_eq!(actual.nullable, expected.nullable); + assert_eq!(actual.primary, expected.primary); } } -async fn create_and_fill_users_table(database: &PgDatabase, num_rows: u32) { - database.create_table(test_table_name("users"), &vec![("age", "integer")]).await.unwrap(); +fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId) { + let expected_columns = vec![ + ColumnSchema { + name: "id".to_string(), + typ: Type::INT8, + modifier: -1, + nullable: false, + primary: true, + }, + ColumnSchema { + name: "age".to_string(), + typ: Type::INT4, + modifier: -1, + nullable: true, + primary: false, + }, + ]; + + assert_table_schema( + sink, + users_table_id, + test_table_name("users"), + &expected_columns, + ); +} - for i in 0..num_rows { - let age = i + 1; - database.insert_values( - test_table_name("users"), - &["age"], - &[&age], - ).await.unwrap(); +fn assert_users_age_sum(sink: &TestSink, users_table_id: TableId, num_users: usize) { + let mut actual_sum = 0; + let expected_sum = ((num_users * (num_users + 1)) / 2) as i32; + let rows = sink.table_rows.get(&users_table_id).unwrap(); + for row in rows { + if let Cell::I32(age) = &row.values[1] { + actual_sum += age; + } } + + assert_eq!(actual_sum, expected_sum); } +/* +Tests to write: +- Insert -> table copy +- Insert -> Update -> table copy +- 
Insert -> cdc +- Insert -> Update -> cdc +- Insert -> cdc -> Update -> cdc +- Insert -> table copy -> crash while copying -> add new table -> check if new table is in the snapshot + */ + #[tokio::test(flavor = "multi_thread")] async fn test_simple_table_copy() { let database = spawn_database().await; - create_and_fill_users_table(&database, 100).await; - let sink = SumSink {}; - let pipeline = spawn_pg_pipeline( + // We insert 100 rows. + let users_table_id = create_and_fill_users_table(&database, 100).await; + + // We create a pipeline that copies the users table. + let mut pipeline = spawn_pg_pipeline( &database.options, PipelineMode::CopyTable { - table_names: vec![test_table_name("users")] + table_names: vec![test_table_name("users")], }, - sink - ); - - + TestSink::new(), + ) + .await; + pipeline.start().await.unwrap(); + + assert_users_table_schema(pipeline.sink(), users_table_id); + assert_users_age_sum(pipeline.sink(), users_table_id, 100); + assert_eq!(pipeline.sink().tables_copied, 1); + assert_eq!(pipeline.sink().tables_truncated, 1); } diff --git a/postgres/src/schema.rs b/postgres/src/schema.rs index 6fcf565e..f8b1cacb 100644 --- a/postgres/src/schema.rs +++ b/postgres/src/schema.rs @@ -7,7 +7,7 @@ use tokio_postgres::types::Type; /// /// This type represents a table identifier in PostgreSQL, which requires both a schema name /// and a table name. It provides methods for formatting the name in different contexts. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct TableName { /// The schema name containing the table pub schema: String, diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 9f72fe14..6cdbb483 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -1,4 +1,4 @@ -use crate::schema::TableName; +use crate::schema::{TableId, TableName}; use crate::tokio::options::PgDatabaseOptions; use tokio::runtime::Handle; use tokio_postgres::{Client, NoTls}; @@ -45,7 +45,7 @@ impl PgDatabase { &self, table_name: TableName, columns: &[(&str, &str)], // (column_name, column_type) - ) -> Result<(), tokio_postgres::Error> { + ) -> Result { let columns_str = columns .iter() .map(|(name, typ)| format!("{} {}", name, typ.to_string())) @@ -58,7 +58,20 @@ impl PgDatabase { columns_str ); self.client.execute(&create_table_query, &[]).await?; - Ok(()) + + // Get the OID of the newly created table + let row = self + .client + .query_one( + "SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \ + WHERE n.nspname = $1 AND c.relname = $2", + &[&table_name.schema, &table_name.name], + ) + .await?; + + let table_id: TableId = row.get(0); + + Ok(table_id) } /// Inserts values into the specified table. 
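For reference, a sketch of what the `create_table` helper above does for the call these tests make; the wrapper function is illustrative only, and identifier quoting follows `quote_identifier`, so lowercase names stay unquoted:

```rust
use postgres::schema::{TableId, TableName};
use postgres::tokio::test_utils::PgDatabase;

// Illustrative wrapper, not part of the patch.
async fn create_users_table_sketch(database: &PgDatabase) -> TableId {
    let users = TableName {
        schema: "test".to_owned(),
        name: "users".to_owned(),
    };

    // Executes roughly: CREATE TABLE test.users (id BIGSERIAL PRIMARY KEY, age integer)
    // and returns the new table's OID as resolved through pg_class/pg_namespace,
    // which is also the id the sinks key their per-table state on.
    database
        .create_table(users, &[("age", "integer")])
        .await
        .expect("failed to create the users table")
}
```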
@@ -71,14 +84,14 @@ impl PgDatabase { let columns_str = columns.join(", "); let placeholders: Vec = (1..=values.len()).map(|i| format!("${}", i)).collect(); let placeholders_str = placeholders.join(", "); - + let insert_query = format!( "INSERT INTO {} ({}) VALUES ({})", table_name.as_quoted_identifier(), columns_str, placeholders_str ); - + self.client.execute(&insert_query, values).await } @@ -88,7 +101,7 @@ impl PgDatabase { table_name: &TableName, column: &str, where_clause: Option<&str>, - ) -> Result, tokio_postgres::Error> + ) -> Result, tokio_postgres::Error> where T: for<'a> tokio_postgres::types::FromSql<'a>, { @@ -99,7 +112,7 @@ impl PgDatabase { table_name.as_quoted_identifier(), where_str ); - + let rows = self.client.query(&query, &[]).await?; Ok(rows.iter().map(|row| row.get(0)).collect()) } From 4dde13659994ee0d9fd22339c4acc9d959fd74ff Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 12:38:57 +0200 Subject: [PATCH 07/34] Improve --- pg_replicate/src/conversions/table_row.rs | 2 +- pg_replicate/src/pipeline/sources/postgres.rs | 1 + pg_replicate/tests/common/database.rs | 25 --- pg_replicate/tests/common/mod.rs | 2 + pg_replicate/tests/common/pipeline.rs | 4 +- pg_replicate/tests/common/sink.rs | 117 ++++++++++ pg_replicate/tests/common/table.rs | 27 +++ pg_replicate/tests/integration/base.rs | 202 +++++++++--------- postgres/src/tokio/test_utils.rs | 23 ++ scripts/init_db.sh | 3 +- 10 files changed, 272 insertions(+), 134 deletions(-) create mode 100644 pg_replicate/tests/common/sink.rs create mode 100644 pg_replicate/tests/common/table.rs diff --git a/pg_replicate/src/conversions/table_row.rs b/pg_replicate/src/conversions/table_row.rs index 12a7dd96..5273b071 100644 --- a/pg_replicate/src/conversions/table_row.rs +++ b/pg_replicate/src/conversions/table_row.rs @@ -10,7 +10,7 @@ use crate::{conversions::text::TextFormatConverter, pipeline::batching::BatchBou use super::{text::FromTextError, Cell}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TableRow { pub values: Vec, } diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index b6b92e36..2e75ce9b 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -160,6 +160,7 @@ impl Source for PostgresSource { let slot_name = self .slot_name() .ok_or(PostgresSourceError::MissingSlotName)?; + let stream = self .replication_client .get_logical_replication_stream(publication, slot_name, start_lsn) diff --git a/pg_replicate/tests/common/database.rs b/pg_replicate/tests/common/database.rs index c48a7ae5..8284b49b 100644 --- a/pg_replicate/tests/common/database.rs +++ b/pg_replicate/tests/common/database.rs @@ -49,28 +49,3 @@ pub async fn spawn_database() -> PgDatabase { database } - -/// Creates a new test database instance with an optional publication. -/// -/// This function creates a test database and optionally sets up a publication -/// for the specified tables. The publication can be used for testing replication -/// scenarios. -/// -/// # Panics -/// -/// Panics if the publication cannot be created. 
-pub async fn spawn_database_with_publication( - table_names: Vec, - publication_name: Option, -) -> PgDatabase { - let database = spawn_database().await; - - if let Some(publication_name) = publication_name { - database - .create_publication(&publication_name, &table_names) - .await - .expect("Error while creating a publication"); - } - - database -} diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index 50308081..eb72c70f 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -1,2 +1,4 @@ pub mod database; pub mod pipeline; +pub mod sink; +pub mod table; diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 65efbf2b..a0ef2078 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -15,7 +15,7 @@ pub enum PipelineMode { /// If the slot is not supplied, a new one will be created on the supplied publication. Cdc { publication: String, - slot_name: Option, + slot_name: String, }, } @@ -46,7 +46,7 @@ pub async fn spawn_pg_pipeline( let source = PostgresSource::new( options.clone(), vec![], - slot_name, + Some(slot_name), TableNamesFrom::Publication(publication), ) .await diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs new file mode 100644 index 00000000..6676fc5f --- /dev/null +++ b/pg_replicate/tests/common/sink.rs @@ -0,0 +1,117 @@ +use async_trait::async_trait; +use pg_replicate::conversions::cdc_event::CdcEvent; +use pg_replicate::conversions::table_row::TableRow; +use pg_replicate::pipeline::sinks::{BatchSink, InfallibleSinkError}; +use pg_replicate::pipeline::PipelineResumptionState; +use postgres::schema::{TableId, TableSchema}; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, Mutex}; +use tokio_postgres::types::PgLsn; + +#[derive(Debug, Clone)] +pub struct TestSink { + // We use Arc to allow the sink to be shared by multiple pipelines, effectively + // simulating recreating pipelines with a sink that "persists" data. + inner: Arc>, +} + +#[derive(Debug)] +struct TestSinkInner { + // We have a Vec to store all the changes of the schema that we receive over time. 
+ tables_schemas: Vec>, + tables_rows: HashMap>, + events: Vec>, + tables_copied: u8, + tables_truncated: u8, +} + +impl TestSink { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(TestSinkInner { + tables_schemas: Vec::new(), + tables_rows: HashMap::new(), + events: Vec::new(), + tables_copied: 0, + tables_truncated: 0, + })), + } + } + + pub fn get_tables_schemas(&self) -> Vec> { + self.inner.lock().unwrap().tables_schemas.clone() + } + + pub fn get_tables_rows(&self) -> HashMap> { + self.inner.lock().unwrap().tables_rows.clone() + } + + pub fn get_events(&self) -> Vec> { + self.inner.lock().unwrap().events.clone() + } + + pub fn get_tables_copied(&self) -> u8 { + self.inner.lock().unwrap().tables_copied + } + + pub fn get_tables_truncated(&self) -> u8 { + self.inner.lock().unwrap().tables_truncated + } +} + +#[async_trait] +impl BatchSink for TestSink { + type Error = InfallibleSinkError; + + async fn get_resumption_state(&mut self) -> Result { + Ok(PipelineResumptionState { + copied_tables: HashSet::new(), + last_lsn: PgLsn::from(0), + }) + } + + async fn write_table_schemas( + &mut self, + table_schemas: HashMap, + ) -> Result<(), Self::Error> { + self.inner + .lock() + .unwrap() + .tables_schemas + .push(table_schemas); + Ok(()) + } + + async fn write_table_rows( + &mut self, + rows: Vec, + table_id: TableId, + ) -> Result<(), Self::Error> { + self.inner + .lock() + .unwrap() + .tables_rows + .entry(table_id) + .or_default() + .extend(rows); + Ok(()) + } + + async fn write_cdc_events(&mut self, events: Vec) -> Result { + // Since CdcEvent is not Clone, we have to wrap it in an Arc, and we are fine with this + // since it's not mutable, so we don't even have to use mutexes. + let arc_events = events.into_iter().map(Arc::new).collect::>(); + self.inner.lock().unwrap().events.extend(arc_events); + Ok(PgLsn::from(0)) + } + + async fn table_copied(&mut self, _table_id: TableId) -> Result<(), Self::Error> { + self.inner.lock().unwrap().tables_copied += 1; + Ok(()) + } + + async fn truncate_table(&mut self, _table_id: TableId) -> Result<(), Self::Error> { + self.inner.lock().unwrap().tables_truncated += 1; + Ok(()) + } +} diff --git a/pg_replicate/tests/common/table.rs b/pg_replicate/tests/common/table.rs new file mode 100644 index 00000000..3392da22 --- /dev/null +++ b/pg_replicate/tests/common/table.rs @@ -0,0 +1,27 @@ +use crate::common::sink::TestSink; +use postgres::schema::{ColumnSchema, TableId, TableName}; + +pub fn assert_table_schema( + sink: &TestSink, + table_id: TableId, + schema_index: usize, + expected_table_name: TableName, + expected_columns: &[ColumnSchema], +) { + let tables_schemas = &sink.get_tables_schemas()[schema_index]; + let table_schema = tables_schemas.get(&table_id).unwrap(); + + assert_eq!(table_schema.table_id, table_id); + assert_eq!(table_schema.table_name, expected_table_name); + + let columns = &table_schema.column_schemas; + assert_eq!(columns.len(), expected_columns.len()); + + for (actual, expected) in columns.iter().zip(expected_columns.iter()) { + assert_eq!(actual.name, expected.name); + assert_eq!(actual.typ, expected.typ); + assert_eq!(actual.modifier, expected.modifier); + assert_eq!(actual.nullable, expected.nullable); + assert_eq!(actual.primary, expected.primary); + } +} diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index 8ad922ae..d560fe4b 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -1,86 +1,40 @@ use 
crate::common::database::{spawn_database, test_table_name}; use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; -use async_trait::async_trait; -use pg_replicate::conversions::cdc_event::CdcEvent; -use pg_replicate::conversions::table_row::TableRow; +use crate::common::sink::TestSink; +use crate::common::table::assert_table_schema; use pg_replicate::conversions::Cell; -use pg_replicate::pipeline::sinks::{BatchSink, InfallibleSinkError}; -use pg_replicate::pipeline::PipelineResumptionState; -use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; +use postgres::schema::{ColumnSchema, TableId}; use postgres::tokio::test_utils::PgDatabase; -use std::collections::{HashMap, HashSet}; -use tokio_postgres::types::{PgLsn, Type}; - -struct TestSink { - table_schemas: HashMap, - table_rows: HashMap>, - events: Vec, - tables_copied: u8, - tables_truncated: u8, -} +use tokio_postgres::types::Type; -impl TestSink { - fn new() -> Self { - Self { - table_schemas: HashMap::new(), - table_rows: HashMap::new(), - events: Vec::new(), - tables_copied: 0, - tables_truncated: 0, - } - } +fn get_expected_ages_sum(num_users: usize) -> i32 { + ((num_users * (num_users + 1)) / 2) as i32 } -#[async_trait] -impl BatchSink for TestSink { - type Error = InfallibleSinkError; - - async fn get_resumption_state(&mut self) -> Result { - Ok(PipelineResumptionState { - copied_tables: HashSet::new(), - last_lsn: PgLsn::from(0), - }) - } - - async fn write_table_schemas( - &mut self, - table_schemas: HashMap, - ) -> Result<(), Self::Error> { - self.table_schemas = table_schemas; - Ok(()) - } - - async fn write_table_rows( - &mut self, - rows: Vec, - table_id: TableId, - ) -> Result<(), Self::Error> { - self.table_rows.entry(table_id).or_default().extend(rows); - Ok(()) - } - - async fn write_cdc_events(&mut self, events: Vec) -> Result { - self.events.extend(events); - Ok(PgLsn::from(0)) - } - - async fn table_copied(&mut self, _table_id: TableId) -> Result<(), Self::Error> { - self.tables_copied += 1; - Ok(()) - } +async fn create_users_table(database: &PgDatabase) -> TableId { + let table_id = database + .create_table(test_table_name("users"), &vec![("age", "integer")]) + .await + .unwrap(); - async fn truncate_table(&mut self, _table_id: TableId) -> Result<(), Self::Error> { - self.tables_truncated += 1; - Ok(()) - } + table_id } -async fn create_and_fill_users_table(database: &PgDatabase, num_users: usize) -> TableId { - let table_id = database - .create_table(test_table_name("users"), &vec![("age", "integer")]) +async fn create_users_table_with_publication( + database: &PgDatabase, + publication_name: &str, +) -> TableId { + let table_id = create_users_table(database).await; + + database + .create_publication(publication_name, &vec![test_table_name("users")]) .await .unwrap(); + table_id +} + +async fn fill_users(database: &PgDatabase, num_users: usize) { for i in 0..num_users { let age = i as i32 + 1; database @@ -88,34 +42,16 @@ async fn create_and_fill_users_table(database: &PgDatabase, num_users: usize) -> .await .unwrap(); } - - table_id } -fn assert_table_schema( - sink: &TestSink, - table_id: TableId, - expected_table_name: TableName, - expected_columns: &[ColumnSchema], -) { - let table_schema = sink.table_schemas.get(&table_id).unwrap(); - - assert_eq!(table_schema.table_id, table_id); - assert_eq!(table_schema.table_name, expected_table_name); - - let columns = &table_schema.column_schemas; - assert_eq!(columns.len(), expected_columns.len()); - - for (actual, expected) in 
columns.iter().zip(expected_columns.iter()) { - assert_eq!(actual.name, expected.name); - assert_eq!(actual.typ, expected.typ); - assert_eq!(actual.modifier, expected.modifier); - assert_eq!(actual.nullable, expected.nullable); - assert_eq!(actual.primary, expected.primary); - } +async fn double_users_ages(database: &PgDatabase) { + database + .update_values(test_table_name("users"), &["age"], &[&"age * 2"]) + .await + .unwrap(); } -fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId) { +fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId, schema_index: usize) { let expected_columns = vec![ ColumnSchema { name: "id".to_string(), @@ -136,17 +72,19 @@ fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId) { assert_table_schema( sink, users_table_id, + schema_index, test_table_name("users"), &expected_columns, ); } -fn assert_users_age_sum(sink: &TestSink, users_table_id: TableId, num_users: usize) { +fn assert_users_age_sum(sink: &TestSink, users_table_id: TableId, expected_sum: i32) { let mut actual_sum = 0; - let expected_sum = ((num_users * (num_users + 1)) / 2) as i32; - let rows = sink.table_rows.get(&users_table_id).unwrap(); - for row in rows { - if let Cell::I32(age) = &row.values[1] { + + let tables_rows = sink.get_tables_rows(); + let table_rows = tables_rows.get(&users_table_id).unwrap(); + for table_row in table_rows { + if let Cell::I32(age) = &table_row.values[1] { actual_sum += age; } } @@ -162,14 +100,17 @@ Tests to write: - Insert -> Update -> cdc - Insert -> cdc -> Update -> cdc - Insert -> table copy -> crash while copying -> add new table -> check if new table is in the snapshot + +The main test we want to do is to check if resuming after a new table has been added causes problems */ #[tokio::test(flavor = "multi_thread")] -async fn test_simple_table_copy() { +async fn test_table_copy_with_insert_and_update() { let database = spawn_database().await; // We insert 100 rows. - let users_table_id = create_and_fill_users_table(&database, 100).await; + let users_table_id = create_users_table(&database).await; + fill_users(&database, 100).await; // We create a pipeline that copies the users table. let mut pipeline = spawn_pg_pipeline( @@ -182,8 +123,59 @@ async fn test_simple_table_copy() { .await; pipeline.start().await.unwrap(); - assert_users_table_schema(pipeline.sink(), users_table_id); - assert_users_age_sum(pipeline.sink(), users_table_id, 100); - assert_eq!(pipeline.sink().tables_copied, 1); - assert_eq!(pipeline.sink().tables_truncated, 1); + assert_users_table_schema(pipeline.sink(), users_table_id, 0); + let expected_sum = get_expected_ages_sum(100); + assert_users_age_sum(pipeline.sink(), users_table_id, expected_sum); + assert_eq!(pipeline.sink().get_tables_copied(), 1); + assert_eq!(pipeline.sink().get_tables_truncated(), 1); + + // We double the user ages. + double_users_ages(&database).await; + + // We recreate the pipeline to copy again and see if we have the new data. 
+ let mut pipeline = spawn_pg_pipeline( + &database.options, + PipelineMode::CopyTable { + table_names: vec![test_table_name("users")], + }, + TestSink::new(), + ) + .await; + pipeline.start().await.unwrap(); + + assert_users_table_schema(pipeline.sink(), users_table_id, 0); + let expected_sum = expected_sum * 2; + assert_users_age_sum(pipeline.sink(), users_table_id, expected_sum); + assert_eq!(pipeline.sink().get_tables_copied(), 1); + assert_eq!(pipeline.sink().get_tables_truncated(), 1); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_cdc_with_insert_and_update() { + let database = spawn_database().await; + + // We create the table and publication. + let users_table_id = create_users_table_with_publication(&database, "users_publication").await; + + // We create a pipeline that subscribes to the changes of the users table. + let mut pipeline = spawn_pg_pipeline( + &database.options, + PipelineMode::Cdc { + publication: "users_publication".to_owned(), + slot_name: "users_slot".to_string(), + }, + TestSink::new(), + ) + .await; + pipeline.start().await.unwrap(); + + // We insert 100 rows. + fill_users(&database, 100).await; + + // assert_users_table_schema(pipeline.sink(), users_table_id, 0); + // let expected_sum = expected_sum * 2; + // assert_users_age_sum(pipeline.sink(), users_table_id, expected_sum); + println!("CDC {:?}", pipeline.sink().get_events()); + assert_eq!(pipeline.sink().get_tables_copied(), 0); + assert_eq!(pipeline.sink().get_tables_truncated(), 0); } diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 6cdbb483..2876c666 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -95,6 +95,29 @@ impl PgDatabase { self.client.execute(&insert_query, values).await } + /// Updates all rows in the specified table with the given values. + pub async fn update_values( + &self, + table_name: TableName, + columns: &[&str], + values: &[&str], + ) -> Result { + let set_clauses: Vec = columns + .iter() + .zip(values.iter()) + .map(|(col, val)| format!("{} = {}", col, val)) + .collect(); + let set_clause = set_clauses.join(", "); + + let update_query = format!( + "UPDATE {} SET {}", + table_name.as_quoted_identifier(), + set_clause + ); + + self.client.execute(&update_query, &[]).await + } + /// Queries rows from a single column of a table. 
pub async fn query_table( &self, diff --git a/scripts/init_db.sh b/scripts/init_db.sh index 15ff51a4..f50450ff 100755 --- a/scripts/init_db.sh +++ b/scripts/init_db.sh @@ -59,7 +59,8 @@ then # Complete the docker run command DOCKER_RUN_CMD="${DOCKER_RUN_CMD} \ --name "postgres_$(date '+%s')" \ - postgres -N 1000" + postgres -N 1000 \ + -c wal_level=logical" # Increased maximum number of connections for testing purposes # Start the container From 2fe00aa66a8f06b545fca390a231f29e1ad6fe42 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 12:51:04 +0200 Subject: [PATCH 08/34] Improve --- pg_replicate/tests/integration/base.rs | 44 ++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index d560fe4b..0ce5460b 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -2,9 +2,11 @@ use crate::common::database::{spawn_database, test_table_name}; use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; use crate::common::sink::TestSink; use crate::common::table::assert_table_schema; +use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::Cell; use postgres::schema::{ColumnSchema, TableId}; use postgres::tokio::test_utils::PgDatabase; +use std::ops::Range; use tokio_postgres::types::Type; fn get_expected_ages_sum(num_users: usize) -> i32 { @@ -78,7 +80,7 @@ fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId, schema_in ); } -fn assert_users_age_sum(sink: &TestSink, users_table_id: TableId, expected_sum: i32) { +fn assert_users_age_sum_from_rows(sink: &TestSink, users_table_id: TableId, expected_sum: i32) { let mut actual_sum = 0; let tables_rows = sink.get_tables_rows(); @@ -92,6 +94,33 @@ fn assert_users_age_sum(sink: &TestSink, users_table_id: TableId, expected_sum: assert_eq!(actual_sum, expected_sum); } +fn assert_users_age_sum_from_events( + sink: &TestSink, + users_table_id: TableId, + // We use a range since events are not indexed by table id but just an ordered sequence which + // we want to slice through. 
+ range: Range, + expected_sum: i32, +) { + let mut actual_sum = 0; + + let events = &sink.get_events()[range]; + for event in events { + match event.as_ref() { + CdcEvent::Insert((table_id, table_row)) | CdcEvent::Update((table_id, table_row)) + if table_id == &users_table_id => + { + if let Cell::I32(age) = &table_row.values[1] { + actual_sum += age; + } + } + _ => {} + } + } + + assert_eq!(actual_sum, expected_sum); +} + /* Tests to write: - Insert -> table copy @@ -102,6 +131,8 @@ Tests to write: - Insert -> table copy -> crash while copying -> add new table -> check if new table is in the snapshot The main test we want to do is to check if resuming after a new table has been added causes problems + +insert -> cdc -> add table -> recreate pipeline and source -> check schema */ #[tokio::test(flavor = "multi_thread")] @@ -125,7 +156,7 @@ async fn test_table_copy_with_insert_and_update() { assert_users_table_schema(pipeline.sink(), users_table_id, 0); let expected_sum = get_expected_ages_sum(100); - assert_users_age_sum(pipeline.sink(), users_table_id, expected_sum); + assert_users_age_sum_from_rows(pipeline.sink(), users_table_id, expected_sum); assert_eq!(pipeline.sink().get_tables_copied(), 1); assert_eq!(pipeline.sink().get_tables_truncated(), 1); @@ -145,7 +176,7 @@ async fn test_table_copy_with_insert_and_update() { assert_users_table_schema(pipeline.sink(), users_table_id, 0); let expected_sum = expected_sum * 2; - assert_users_age_sum(pipeline.sink(), users_table_id, expected_sum); + assert_users_age_sum_from_rows(pipeline.sink(), users_table_id, expected_sum); assert_eq!(pipeline.sink().get_tables_copied(), 1); assert_eq!(pipeline.sink().get_tables_truncated(), 1); } @@ -172,10 +203,9 @@ async fn test_cdc_with_insert_and_update() { // We insert 100 rows. 
fill_users(&database, 100).await; - // assert_users_table_schema(pipeline.sink(), users_table_id, 0); - // let expected_sum = expected_sum * 2; - // assert_users_age_sum(pipeline.sink(), users_table_id, expected_sum); - println!("CDC {:?}", pipeline.sink().get_events()); + assert_users_table_schema(pipeline.sink(), users_table_id, 0); + let expected_sum = get_expected_ages_sum(100); + assert_users_age_sum_from_events(pipeline.sink(), users_table_id, 0..100, expected_sum); assert_eq!(pipeline.sink().get_tables_copied(), 0); assert_eq!(pipeline.sink().get_tables_truncated(), 0); } From 1e7c9c76aab11bb038841a64d42874c2c435d1c6 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 15:28:09 +0200 Subject: [PATCH 09/34] Fix --- .../src/pipeline/batching/data_pipeline.rs | 3 +-- pg_replicate/tests/common/pipeline.rs | 6 +++++- postgres/src/tokio/test_utils.rs | 19 +++++++++++++------ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 2ba08a55..7346a407 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -122,16 +122,15 @@ impl BatchDataPipeline { let mut last_lsn: u64 = last_lsn.into(); last_lsn += 1; + let cdc_events = self .source .get_cdc_stream(last_lsn.into()) .await .map_err(PipelineError::Source)?; - pin!(cdc_events); let batch_timeout_stream = BatchTimeoutStream::new(cdc_events, self.batch_config.clone()); - pin!(batch_timeout_stream); while let Some(batch) = batch_timeout_stream.next().await { diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index a0ef2078..968fef24 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -19,6 +19,10 @@ pub enum PipelineMode { }, } +pub fn test_slot_name(slot_name: &str) -> String { + format!("test_{}", slot_name) +} + pub async fn spawn_pg_pipeline( options: &PgDatabaseOptions, mode: PipelineMode, @@ -46,7 +50,7 @@ pub async fn spawn_pg_pipeline( let source = PostgresSource::new( options.clone(), vec![], - Some(slot_name), + Some(test_slot_name(&slot_name)), TableNamesFrom::Publication(publication), ) .await diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 2876c666..6e131210 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -23,17 +23,13 @@ impl PgDatabase { ) -> Result { let table_names = table_names .iter() - .map(|t| t.to_string()) + .map(TableName::as_quoted_identifier) .collect::>(); let create_publication_query = format!( "CREATE PUBLICATION {} FOR TABLE {}", publication_name, - table_names - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(", ") + table_names.join(", ") ); self.client.execute(&create_publication_query, &[]).await?; @@ -231,6 +227,17 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { .await .expect("Failed to terminate database connections"); + // Drop any test replication slots + client + .execute( + "SELECT pg_drop_replication_slot(slot_name) + FROM pg_replication_slots + WHERE slot_name LIKE 'test_%';", + &[], + ) + .await + .expect("Failed to drop test replication slots"); + // Drop the database client .execute( From 63cd678efdf4e40d7f54bab90a6a131a8d045d01 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 16:37:49 +0200 Subject: [PATCH 10/34] Add facilities --- pg_replicate/Cargo.toml | 2 +- 
.../src/pipeline/batching/data_pipeline.rs | 40 ++++++++++++++++--- pg_replicate/src/pipeline/batching/stream.rs | 21 ++++++++-- pg_replicate/tests/common/pipeline.rs | 4 ++ pg_replicate/tests/integration/base.rs | 26 ++++++++---- 5 files changed, 76 insertions(+), 17 deletions(-) diff --git a/pg_replicate/Cargo.toml b/pg_replicate/Cargo.toml index 7c16d677..93c5350a 100644 --- a/pg_replicate/Cargo.toml +++ b/pg_replicate/Cargo.toml @@ -40,7 +40,7 @@ rustls = { workspace = true, features = ["aws-lc-rs", "logging"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } thiserror = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "sync"] } tokio-postgres = { workspace = true, features = [ "runtime", "with-chrono-0_4", diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 7346a407..6d941636 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -1,5 +1,3 @@ -use std::{collections::HashSet, time::Instant}; - use crate::{ conversions::cdc_event::{CdcEvent, CdcEventConversionError}, pipeline::{ @@ -11,17 +9,33 @@ use crate::{ }; use futures::StreamExt; use postgres::schema::TableId; +use std::sync::Arc; +use std::{collections::HashSet, time::Instant}; use tokio::pin; +use tokio::sync::Notify; use tokio_postgres::types::PgLsn; use tracing::{debug, info}; use super::BatchConfig; +#[derive(Debug, Clone)] +pub struct BatchDataPipelineHandle { + stream_stop: Arc, +} + +impl BatchDataPipelineHandle { + pub fn stop(&self) { + // We want to notify all waiters that their streams have to be stopped. 
+ self.stream_stop.notify_waiters(); + } +} + pub struct BatchDataPipeline { source: Src, sink: Snk, action: PipelineAction, batch_config: BatchConfig, + stream_stop: Arc, } impl BatchDataPipeline { @@ -31,6 +45,7 @@ impl BatchDataPipeline { sink, action, batch_config, + stream_stop: Arc::new(Notify::new()), } } @@ -76,8 +91,11 @@ impl BatchDataPipeline { .await .map_err(PipelineError::Source)?; - let batch_timeout_stream = - BatchTimeoutStream::new(table_rows, self.batch_config.clone()); + let batch_timeout_stream = BatchTimeoutStream::new( + table_rows, + self.batch_config.clone(), + self.stream_stop.clone(), + ); pin!(batch_timeout_stream); @@ -122,7 +140,7 @@ impl BatchDataPipeline { let mut last_lsn: u64 = last_lsn.into(); last_lsn += 1; - + let cdc_events = self .source .get_cdc_stream(last_lsn.into()) @@ -130,7 +148,11 @@ impl BatchDataPipeline { .map_err(PipelineError::Source)?; pin!(cdc_events); - let batch_timeout_stream = BatchTimeoutStream::new(cdc_events, self.batch_config.clone()); + let batch_timeout_stream = BatchTimeoutStream::new( + cdc_events, + self.batch_config.clone(), + self.stream_stop.clone(), + ); pin!(batch_timeout_stream); while let Some(batch) = batch_timeout_stream.next().await { @@ -200,6 +222,12 @@ impl BatchDataPipeline { Ok(()) } + pub fn handle(&self) -> BatchDataPipelineHandle { + BatchDataPipelineHandle { + stream_stop: self.stream_stop.clone(), + } + } + pub fn source(&self) -> &Src { &self.source } diff --git a/pg_replicate/src/pipeline/batching/stream.rs b/pg_replicate/src/pipeline/batching/stream.rs index 3495e010..0fa82d8a 100644 --- a/pg_replicate/src/pipeline/batching/stream.rs +++ b/pg_replicate/src/pipeline/batching/stream.rs @@ -2,10 +2,13 @@ use futures::{ready, Future, Stream}; use pin_project_lite::pin_project; use tokio::time::{sleep, Sleep}; +use super::{BatchBoundary, BatchConfig}; use core::pin::Pin; use core::task::{Context, Poll}; - -use super::{BatchBoundary, BatchConfig}; +use std::sync::Arc; +use tokio::pin; +use tokio::sync::Notify; +use tracing::info; // Implementation adapted from https://github.com/tokio-rs/tokio/blob/master/tokio-stream/src/stream_ext/chunks_timeout.rs pin_project! { @@ -24,11 +27,12 @@ pin_project! 
{ batch_config: BatchConfig, reset_timer: bool, inner_stream_ended: bool, + stop: Arc } } impl> BatchTimeoutStream { - pub fn new(stream: S, batch_config: BatchConfig) -> Self { + pub fn new(stream: S, batch_config: BatchConfig, stop: Arc) -> Self { BatchTimeoutStream { stream, deadline: None, @@ -36,6 +40,7 @@ impl> BatchTimeoutStream { batch_config, reset_timer: true, inner_stream_ended: false, + stop, } } @@ -49,10 +54,20 @@ impl> Stream for BatchTimeoutStream fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.as_mut().project(); + if *this.inner_stream_ended { return Poll::Ready(None); } + loop { + let notified = this.stop.notified(); + pin!(notified); + + if notified.poll(cx).is_ready() { + info!("the stream has been forcefully stopped"); + return Poll::Ready(None); + } + if *this.reset_timer { this.deadline .set(Some(sleep(this.batch_config.max_batch_fill_time))); diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 968fef24..35cdcbbb 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -2,10 +2,14 @@ use pg_replicate::pipeline::batching::data_pipeline::BatchDataPipeline; use pg_replicate::pipeline::batching::BatchConfig; use pg_replicate::pipeline::sinks::BatchSink; use pg_replicate::pipeline::sources::postgres::{PostgresSource, TableNamesFrom}; +use pg_replicate::pipeline::sources::Source; use pg_replicate::pipeline::PipelineAction; use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::{mpsc, Mutex}; +use tokio::task::JoinHandle; pub enum PipelineMode { /// In this mode the supplied tables will be copied. diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index 0ce5460b..bb70e97a 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -1,12 +1,14 @@ use crate::common::database::{spawn_database, test_table_name}; -use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; +use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode, PipelineRunner}; use crate::common::sink::TestSink; use crate::common::table::assert_table_schema; use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::Cell; +use pg_replicate::pipeline::sources::postgres::PostgresSource; use postgres::schema::{ColumnSchema, TableId}; use postgres::tokio::test_utils::PgDatabase; use std::ops::Range; +use tokio::net::unix::pipe::pipe; use tokio_postgres::types::Type; fn get_expected_ages_sum(num_users: usize) -> i32 { @@ -189,23 +191,33 @@ async fn test_cdc_with_insert_and_update() { let users_table_id = create_users_table_with_publication(&database, "users_publication").await; // We create a pipeline that subscribes to the changes of the users table. + let sink = TestSink::new(); let mut pipeline = spawn_pg_pipeline( &database.options, PipelineMode::Cdc { publication: "users_publication".to_owned(), slot_name: "users_slot".to_string(), }, - TestSink::new(), + sink.clone(), ) .await; - pipeline.start().await.unwrap(); + let pipeline_handle = pipeline.handle(); + + // We start the pipeline in another task to not block. + let pipeline_task_handle = tokio::spawn(async move { + pipeline.start().await.unwrap(); + }); // We insert 100 rows. fill_users(&database, 100).await; - assert_users_table_schema(pipeline.sink(), users_table_id, 0); + // We stop the pipeline and wait for it to finish. 
+ pipeline_handle.stop(); + pipeline_task_handle.await.unwrap(); + + assert_users_table_schema(&sink, users_table_id, 0); let expected_sum = get_expected_ages_sum(100); - assert_users_age_sum_from_events(pipeline.sink(), users_table_id, 0..100, expected_sum); - assert_eq!(pipeline.sink().get_tables_copied(), 0); - assert_eq!(pipeline.sink().get_tables_truncated(), 0); + assert_users_age_sum_from_events(&sink, users_table_id, 0..100, expected_sum); + assert_eq!(sink.get_tables_copied(), 0); + assert_eq!(sink.get_tables_truncated(), 0); } From c5261039e65b4a81a505e40fb900e2e21026ccc6 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 17:03:35 +0200 Subject: [PATCH 11/34] Improve --- .../src/pipeline/batching/data_pipeline.rs | 7 ++++-- pg_replicate/src/pipeline/batching/stream.rs | 22 ++++++++----------- pg_replicate/tests/common/sink.rs | 3 +++ pg_replicate/tests/integration/base.rs | 17 +++++++++----- 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 6d941636..2fc0bca8 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -26,6 +26,9 @@ pub struct BatchDataPipelineHandle { impl BatchDataPipelineHandle { pub fn stop(&self) { // We want to notify all waiters that their streams have to be stopped. + // + // Technically, we should not need to notify multiple waiters since we can't have multiple + // streams active in parallel, but in this way we cover for the future. self.stream_stop.notify_waiters(); } } @@ -94,7 +97,7 @@ impl BatchDataPipeline { let batch_timeout_stream = BatchTimeoutStream::new( table_rows, self.batch_config.clone(), - self.stream_stop.clone(), + self.stream_stop.notified(), ); pin!(batch_timeout_stream); @@ -151,7 +154,7 @@ impl BatchDataPipeline { let batch_timeout_stream = BatchTimeoutStream::new( cdc_events, self.batch_config.clone(), - self.stream_stop.clone(), + self.stream_stop.notified(), ); pin!(batch_timeout_stream); diff --git a/pg_replicate/src/pipeline/batching/stream.rs b/pg_replicate/src/pipeline/batching/stream.rs index 0fa82d8a..a9ebf692 100644 --- a/pg_replicate/src/pipeline/batching/stream.rs +++ b/pg_replicate/src/pipeline/batching/stream.rs @@ -5,9 +5,7 @@ use tokio::time::{sleep, Sleep}; use super::{BatchBoundary, BatchConfig}; use core::pin::Pin; use core::task::{Context, Poll}; -use std::sync::Arc; -use tokio::pin; -use tokio::sync::Notify; +use tokio::sync::futures::Notified; use tracing::info; // Implementation adapted from https://github.com/tokio-rs/tokio/blob/master/tokio-stream/src/stream_ext/chunks_timeout.rs @@ -18,21 +16,22 @@ pin_project! 
{ /// item which returns true from [`BatchBoundary::is_last_in_batch`] #[must_use = "streams do nothing unless polled"] #[derive(Debug)] - pub struct BatchTimeoutStream> { + pub struct BatchTimeoutStream<'a, B: BatchBoundary, S: Stream> { #[pin] stream: S, #[pin] deadline: Option, + #[pin] + stream_stop: Notified<'a>, items: Vec, batch_config: BatchConfig, reset_timer: bool, inner_stream_ended: bool, - stop: Arc } } -impl> BatchTimeoutStream { - pub fn new(stream: S, batch_config: BatchConfig, stop: Arc) -> Self { +impl<'a, B: BatchBoundary, S: Stream> BatchTimeoutStream<'a, B, S> { + pub fn new(stream: S, batch_config: BatchConfig, stream_stop: Notified<'a>) -> Self { BatchTimeoutStream { stream, deadline: None, @@ -40,7 +39,7 @@ impl> BatchTimeoutStream { batch_config, reset_timer: true, inner_stream_ended: false, - stop, + stream_stop, } } @@ -49,7 +48,7 @@ impl> BatchTimeoutStream { } } -impl> Stream for BatchTimeoutStream { +impl<'a, B: BatchBoundary, S: Stream> Stream for BatchTimeoutStream<'a, B, S> { type Item = Vec; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -60,10 +59,7 @@ impl> Stream for BatchTimeoutStream } loop { - let notified = this.stop.notified(); - pin!(notified); - - if notified.poll(cx).is_ready() { + if this.stream_stop.as_mut().poll(cx).is_ready() { info!("the stream has been forcefully stopped"); return Poll::Ready(None); } diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs index 6676fc5f..757d2a81 100644 --- a/pg_replicate/tests/common/sink.rs +++ b/pg_replicate/tests/common/sink.rs @@ -98,6 +98,9 @@ impl BatchSink for TestSink { } async fn write_cdc_events(&mut self, events: Vec) -> Result { + for event in events.iter() { + println!("EVENT {:?}", event); + } // Since CdcEvent is not Clone, we have to wrap it in an Arc, and we are fine with this // since it's not mutable, so we don't even have to use mutexes. 
let arc_events = events.into_iter().map(Arc::new).collect::>(); diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index bb70e97a..23fc2b27 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -1,14 +1,14 @@ use crate::common::database::{spawn_database, test_table_name}; -use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode, PipelineRunner}; +use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; use crate::common::sink::TestSink; use crate::common::table::assert_table_schema; use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::Cell; -use pg_replicate::pipeline::sources::postgres::PostgresSource; use postgres::schema::{ColumnSchema, TableId}; use postgres::tokio::test_utils::PgDatabase; use std::ops::Range; -use tokio::net::unix::pipe::pipe; +use std::time::Duration; +use tokio::time::sleep; use tokio_postgres::types::Type; fn get_expected_ages_sum(num_users: usize) -> i32 { @@ -106,15 +106,16 @@ fn assert_users_age_sum_from_events( ) { let mut actual_sum = 0; - let events = &sink.get_events()[range]; - for event in events { + let mut i = 0; + for event in sink.get_events() { match event.as_ref() { CdcEvent::Insert((table_id, table_row)) | CdcEvent::Update((table_id, table_row)) - if table_id == &users_table_id => + if table_id == &users_table_id && range.contains(&i) => { if let Cell::I32(age) = &table_row.values[1] { actual_sum += age; } + i += 1; } _ => {} } @@ -208,9 +209,13 @@ async fn test_cdc_with_insert_and_update() { pipeline.start().await.unwrap(); }); + sleep(Duration::from_secs(5)).await; + // We insert 100 rows. fill_users(&database, 100).await; + sleep(Duration::from_secs(5)).await; + // We stop the pipeline and wait for it to finish. 
pipeline_handle.stop(); pipeline_task_handle.await.unwrap(); From 339ea67e836d55e4b33b97e4b31fd50a8de15ed6 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 17:26:13 +0200 Subject: [PATCH 12/34] Improve --- pg_replicate/tests/common/pipeline.rs | 17 ++++++++++++++++- pg_replicate/tests/integration/base.rs | 10 ++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 35cdcbbb..ca0f55d4 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -1,4 +1,4 @@ -use pg_replicate::pipeline::batching::data_pipeline::BatchDataPipeline; +use pg_replicate::pipeline::batching::data_pipeline::{BatchDataPipeline, BatchDataPipelineHandle}; use pg_replicate::pipeline::batching::BatchConfig; use pg_replicate::pipeline::sinks::BatchSink; use pg_replicate::pipeline::sources::postgres::{PostgresSource, TableNamesFrom}; @@ -66,3 +66,18 @@ pub async fn spawn_pg_pipeline( pipeline } + +pub async fn spawn_async_pg_pipeline( + options: &PgDatabaseOptions, + mode: PipelineMode, + sink: Snk, +) -> (BatchDataPipelineHandle, JoinHandle<()>) { + let mut pipeline = spawn_pg_pipeline(options, mode, sink).await; + + let pipeline_handle = pipeline.handle(); + let pipeline_task_handle = tokio::spawn(async move { + pipeline.start().await.expect("The pipeline failed to start within a task"); + }); + + (pipeline_handle, pipeline_task_handle) +} \ No newline at end of file diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index 23fc2b27..e7d5ebcf 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -1,5 +1,5 @@ use crate::common::database::{spawn_database, test_table_name}; -use crate::common::pipeline::{spawn_pg_pipeline, PipelineMode}; +use crate::common::pipeline::{spawn_async_pg_pipeline, spawn_pg_pipeline, PipelineMode}; use crate::common::sink::TestSink; use crate::common::table::assert_table_schema; use pg_replicate::conversions::cdc_event::CdcEvent; @@ -193,7 +193,7 @@ async fn test_cdc_with_insert_and_update() { // We create a pipeline that subscribes to the changes of the users table. let sink = TestSink::new(); - let mut pipeline = spawn_pg_pipeline( + let (pipeline_handle, pipeline_task_handle) = spawn_async_pg_pipeline( &database.options, PipelineMode::Cdc { publication: "users_publication".to_owned(), @@ -202,12 +202,6 @@ async fn test_cdc_with_insert_and_update() { sink.clone(), ) .await; - let pipeline_handle = pipeline.handle(); - - // We start the pipeline in another task to not block. 
- let pipeline_task_handle = tokio::spawn(async move { - pipeline.start().await.unwrap(); - }); sleep(Duration::from_secs(5)).await; From b6840344861c5e13ab9c94bf1591dd8bd2e63b36 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 17:40:07 +0200 Subject: [PATCH 13/34] Fix --- pg_replicate/tests/common/mod.rs | 26 ++++++++++++++++++++ pg_replicate/tests/common/pipeline.rs | 12 ++++----- pg_replicate/tests/integration/base.rs | 34 ++++++++++++++------------ 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index eb72c70f..3fcd1067 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -1,4 +1,30 @@ +use std::time::{Duration, Instant}; +use tokio::time::sleep; + pub mod database; pub mod pipeline; pub mod sink; pub mod table; + +/// The maximum time in seconds for which we should wait for a condition to be met +/// in tests. +const MAX_ASSERTION_DURATION: Duration = Duration::from_secs(20); + +/// The frequency at which we should check for a condition to be met in tests. +const ASSERTION_FREQUENCY_DURATION: Duration = Duration::from_millis(10); + +/// Wait for a condition to be met within the maximum timeout. +pub async fn wait_for_condition(condition: F) +where + F: Fn() -> bool, +{ + let start = Instant::now(); + while start.elapsed() < MAX_ASSERTION_DURATION { + if condition() { + break; + } + sleep(ASSERTION_FREQUENCY_DURATION).await; + } + + assert!(false, "Failed to process all events within timeout") +} diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index ca0f55d4..22d8a59d 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -2,13 +2,10 @@ use pg_replicate::pipeline::batching::data_pipeline::{BatchDataPipeline, BatchDa use pg_replicate::pipeline::batching::BatchConfig; use pg_replicate::pipeline::sinks::BatchSink; use pg_replicate::pipeline::sources::postgres::{PostgresSource, TableNamesFrom}; -use pg_replicate::pipeline::sources::Source; use pg_replicate::pipeline::PipelineAction; use postgres::schema::TableName; use postgres::tokio::options::PgDatabaseOptions; -use std::sync::Arc; use std::time::Duration; -use tokio::sync::{mpsc, Mutex}; use tokio::task::JoinHandle; pub enum PipelineMode { @@ -73,11 +70,14 @@ pub async fn spawn_async_pg_pipeline( sink: Snk, ) -> (BatchDataPipelineHandle, JoinHandle<()>) { let mut pipeline = spawn_pg_pipeline(options, mode, sink).await; - + let pipeline_handle = pipeline.handle(); let pipeline_task_handle = tokio::spawn(async move { - pipeline.start().await.expect("The pipeline failed to start within a task"); + pipeline + .start() + .await + .expect("The pipeline experienced an error"); }); (pipeline_handle, pipeline_task_handle) -} \ No newline at end of file +} diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index e7d5ebcf..6bf4ca4c 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -2,13 +2,12 @@ use crate::common::database::{spawn_database, test_table_name}; use crate::common::pipeline::{spawn_async_pg_pipeline, spawn_pg_pipeline, PipelineMode}; use crate::common::sink::TestSink; use crate::common::table::assert_table_schema; +use crate::common::wait_for_condition; use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::Cell; use postgres::schema::{ColumnSchema, TableId}; use 
postgres::tokio::test_utils::PgDatabase; use std::ops::Range; -use std::time::Duration; -use tokio::time::sleep; use tokio_postgres::types::Type; fn get_expected_ages_sum(num_users: usize) -> i32 { @@ -82,7 +81,7 @@ fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId, schema_in ); } -fn assert_users_age_sum_from_rows(sink: &TestSink, users_table_id: TableId, expected_sum: i32) { +fn get_users_age_sum_from_rows(sink: &TestSink, users_table_id: TableId) -> i32 { let mut actual_sum = 0; let tables_rows = sink.get_tables_rows(); @@ -93,17 +92,16 @@ fn assert_users_age_sum_from_rows(sink: &TestSink, users_table_id: TableId, expe } } - assert_eq!(actual_sum, expected_sum); + actual_sum } -fn assert_users_age_sum_from_events( +fn get_users_age_sum_from_events( sink: &TestSink, users_table_id: TableId, // We use a range since events are not indexed by table id but just an ordered sequence which // we want to slice through. range: Range, - expected_sum: i32, -) { +) -> i32 { let mut actual_sum = 0; let mut i = 0; @@ -121,7 +119,7 @@ fn assert_users_age_sum_from_events( } } - assert_eq!(actual_sum, expected_sum); + actual_sum } /* @@ -159,7 +157,8 @@ async fn test_table_copy_with_insert_and_update() { assert_users_table_schema(pipeline.sink(), users_table_id, 0); let expected_sum = get_expected_ages_sum(100); - assert_users_age_sum_from_rows(pipeline.sink(), users_table_id, expected_sum); + let actual_sum = get_users_age_sum_from_rows(pipeline.sink(), users_table_id); + assert_eq!(actual_sum, expected_sum); assert_eq!(pipeline.sink().get_tables_copied(), 1); assert_eq!(pipeline.sink().get_tables_truncated(), 1); @@ -179,7 +178,8 @@ async fn test_table_copy_with_insert_and_update() { assert_users_table_schema(pipeline.sink(), users_table_id, 0); let expected_sum = expected_sum * 2; - assert_users_age_sum_from_rows(pipeline.sink(), users_table_id, expected_sum); + let actual_sum = get_users_age_sum_from_rows(pipeline.sink(), users_table_id); + assert_eq!(actual_sum, expected_sum); assert_eq!(pipeline.sink().get_tables_copied(), 1); assert_eq!(pipeline.sink().get_tables_truncated(), 1); } @@ -203,20 +203,24 @@ async fn test_cdc_with_insert_and_update() { ) .await; - sleep(Duration::from_secs(5)).await; - // We insert 100 rows. fill_users(&database, 100).await; - sleep(Duration::from_secs(5)).await; + // Wait for all events to be processed + let expected_sum = get_expected_ages_sum(100); + wait_for_condition( + || { + let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 0..100); + actual_sum == expected_sum + }, + ) + .await; // We stop the pipeline and wait for it to finish. 
pipeline_handle.stop(); pipeline_task_handle.await.unwrap(); assert_users_table_schema(&sink, users_table_id, 0); - let expected_sum = get_expected_ages_sum(100); - assert_users_age_sum_from_events(&sink, users_table_id, 0..100, expected_sum); assert_eq!(sink.get_tables_copied(), 0); assert_eq!(sink.get_tables_truncated(), 0); } From 2689ee030f710c6e204e1529b3f7364e5e03bfa6 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Tue, 20 May 2025 17:46:32 +0200 Subject: [PATCH 14/34] Implement wait mechanism --- pg_replicate/tests/integration/base.rs | 3 +++ postgres/src/sqlx/test_utils.rs | 24 ------------------------ 2 files changed, 3 insertions(+), 24 deletions(-) diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index 6bf4ca4c..50770137 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -219,6 +219,9 @@ async fn test_cdc_with_insert_and_update() { // We stop the pipeline and wait for it to finish. pipeline_handle.stop(); pipeline_task_handle.await.unwrap(); + + // TODO: figure out why the stopping causes the dropping of the db to start before the dropping + // of the pipeline. assert_users_table_schema(&sink, users_table_id, 0); assert_eq!(sink.get_tables_copied(), 0); diff --git a/postgres/src/sqlx/test_utils.rs b/postgres/src/sqlx/test_utils.rs index f400f647..024ee434 100644 --- a/postgres/src/sqlx/test_utils.rs +++ b/postgres/src/sqlx/test_utils.rs @@ -1,29 +1,5 @@ use crate::sqlx::options::PgDatabaseOptions; use sqlx::{Connection, Executor, PgConnection, PgPool}; -use tokio::runtime::Handle; - -struct PgDatabase { - options: PgDatabaseOptions, - pool: PgPool, -} - -impl PgDatabase { - pub async fn new(options: PgDatabaseOptions) -> Self { - let pool = create_pg_database(&options).await; - - Self { options, pool } - } -} - -impl Drop for PgDatabase { - fn drop(&mut self) { - // To use `block_in_place,` we need a multithreaded runtime since when a blocking - // task is issued, the runtime will offload existing tasks to another worker. - tokio::task::block_in_place(move || { - Handle::current().block_on(async move { drop_pg_database(&self.options).await }); - }); - } -} /// Creates a new PostgreSQL database and returns a connection pool to it. /// From cd130f11a5c7b8e998f4dcb621ae8b509fe97b4b Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 09:50:09 +0200 Subject: [PATCH 15/34] Improve --- pg_replicate/src/pipeline/batching/stream.rs | 20 +++++++++++++++++--- pg_replicate/tests/common/mod.rs | 3 ++- pg_replicate/tests/common/sink.rs | 3 --- pg_replicate/tests/integration/base.rs | 13 ++++--------- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/pg_replicate/src/pipeline/batching/stream.rs b/pg_replicate/src/pipeline/batching/stream.rs index a9ebf692..73ab5c3e 100644 --- a/pg_replicate/src/pipeline/batching/stream.rs +++ b/pg_replicate/src/pipeline/batching/stream.rs @@ -13,7 +13,8 @@ pin_project! { /// Adapter stream which batches the items of the underlying stream when it /// reaches max_size or when a timeout expires. The underlying streams items /// must implement [`BatchBoundary`]. A batch is guaranteed to end on an - /// item which returns true from [`BatchBoundary::is_last_in_batch`] + /// item which returns true from [`BatchBoundary::is_last_in_batch`] unless the + /// stream is forcefully stopped. 
#[must_use = "streams do nothing unless polled"] #[derive(Debug)] pub struct BatchTimeoutStream<'a, B: BatchBoundary, S: Stream> { @@ -27,6 +28,7 @@ pin_project! { batch_config: BatchConfig, reset_timer: bool, inner_stream_ended: bool, + stream_stopped: bool } } @@ -35,11 +37,12 @@ impl<'a, B: BatchBoundary, S: Stream> BatchTimeoutStream<'a, B, S> { BatchTimeoutStream { stream, deadline: None, + stream_stop, items: Vec::with_capacity(batch_config.max_batch_size), batch_config, reset_timer: true, inner_stream_ended: false, - stream_stop, + stream_stopped: false, } } @@ -59,9 +62,20 @@ impl<'a, B: BatchBoundary, S: Stream> Stream for BatchTimeoutStream<'a } loop { + if *this.stream_stopped { + return Poll::Ready(None); + } + + // If the stream has been asked to stop, we mark the stream as stopped and return the + // remaining elements, irrespectively of boundaries. if this.stream_stop.as_mut().poll(cx).is_ready() { info!("the stream has been forcefully stopped"); - return Poll::Ready(None); + *this.stream_stopped = true; + return if !this.items.is_empty() { + Poll::Ready(Some(std::mem::take(this.items))) + } else { + Poll::Ready(None) + }; } if *this.reset_timer { diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index 3fcd1067..e4300c34 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -21,8 +21,9 @@ where let start = Instant::now(); while start.elapsed() < MAX_ASSERTION_DURATION { if condition() { - break; + return; } + sleep(ASSERTION_FREQUENCY_DURATION).await; } diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs index 757d2a81..6676fc5f 100644 --- a/pg_replicate/tests/common/sink.rs +++ b/pg_replicate/tests/common/sink.rs @@ -98,9 +98,6 @@ impl BatchSink for TestSink { } async fn write_cdc_events(&mut self, events: Vec) -> Result { - for event in events.iter() { - println!("EVENT {:?}", event); - } // Since CdcEvent is not Clone, we have to wrap it in an Arc, and we are fine with this // since it's not mutable, so we don't even have to use mutexes. let arc_events = events.into_iter().map(Arc::new).collect::>(); diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index 50770137..6f36c6e8 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -208,20 +208,15 @@ async fn test_cdc_with_insert_and_update() { // Wait for all events to be processed let expected_sum = get_expected_ages_sum(100); - wait_for_condition( - || { - let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 0..100); - actual_sum == expected_sum - }, - ) + wait_for_condition(|| { + let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 0..100); + actual_sum == expected_sum + }) .await; // We stop the pipeline and wait for it to finish. pipeline_handle.stop(); pipeline_task_handle.await.unwrap(); - - // TODO: figure out why the stopping causes the dropping of the db to start before the dropping - // of the pipeline. 
assert_users_table_schema(&sink, users_table_id, 0); assert_eq!(sink.get_tables_copied(), 0); From cfdc826135f53e13177be838586ed6e579492a45 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 10:54:01 +0200 Subject: [PATCH 16/34] Improve --- pg_replicate/src/clients/postgres.rs | 3 +- .../src/pipeline/batching/data_pipeline.rs | 2 + pg_replicate/src/pipeline/mod.rs | 1 + pg_replicate/src/pipeline/sources/postgres.rs | 2 + pg_replicate/tests/common/pipeline.rs | 49 +++++++++++---- pg_replicate/tests/common/sink.rs | 59 +++++++++++++++---- pg_replicate/tests/integration/base.rs | 32 ++++++++-- postgres/src/tokio/test_utils.rs | 15 +++-- 8 files changed, 128 insertions(+), 35 deletions(-) diff --git a/pg_replicate/src/clients/postgres.rs b/pg_replicate/src/clients/postgres.rs index 166da7ca..a2dafd0e 100644 --- a/pg_replicate/src/clients/postgres.rs +++ b/pg_replicate/src/clients/postgres.rs @@ -71,12 +71,13 @@ impl ReplicationClient { config.replication_mode(ReplicationMode::Logical); let (postgres_client, connection) = config.connect(NoTls).await?; - + tokio::spawn(async move { info!("waiting for connection to terminate"); if let Err(e) = connection.await { warn!("connection error: {}", e); } + info!("connection terminated successfully") }); info!("successfully connected to postgres"); diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 2fc0bca8..70a459bf 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -205,6 +205,8 @@ impl BatchDataPipeline { .get_resumption_state() .await .map_err(PipelineError::Sink)?; + + println!("RESUMPTION STATE {:?}", resumption_state); match self.action { PipelineAction::TableCopiesOnly => { diff --git a/pg_replicate/src/pipeline/mod.rs b/pg_replicate/src/pipeline/mod.rs index ff734e8f..8af5e194 100644 --- a/pg_replicate/src/pipeline/mod.rs +++ b/pg_replicate/src/pipeline/mod.rs @@ -17,6 +17,7 @@ pub enum PipelineAction { Both, } +#[derive(Debug)] pub struct PipelineResumptionState { pub copied_tables: HashSet, pub last_lsn: PgLsn, diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index 2e75ce9b..e6dfd58f 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -154,6 +154,7 @@ impl Source for PostgresSource { async fn get_cdc_stream(&self, start_lsn: PgLsn) -> Result { info!("starting cdc stream at lsn {start_lsn}"); + println!("STARTING STREAM AT {:?}", start_lsn); let publication = self .publication() .ok_or(PostgresSourceError::MissingPublication)?; @@ -166,6 +167,7 @@ impl Source for PostgresSource { .get_logical_replication_stream(publication, slot_name, start_lsn) .await .map_err(PostgresSourceError::ReplicationClient)?; + println!("STREAM STARTED AT {:?}", start_lsn); const TIME_SEC_CONVERSION: u64 = 946_684_800; let postgres_epoch = UNIX_EPOCH + Duration::from_secs(TIME_SEC_CONVERSION); diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 22d8a59d..4db5f7fa 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -68,16 +68,45 @@ pub async fn spawn_async_pg_pipeline( options: &PgDatabaseOptions, mode: PipelineMode, sink: Snk, -) -> (BatchDataPipelineHandle, JoinHandle<()>) { - let mut pipeline = spawn_pg_pipeline(options, mode, sink).await; +) -> PipelineRunner { + let pipeline = 
spawn_pg_pipeline(options, mode, sink).await; + PipelineRunner::new(pipeline) +} - let pipeline_handle = pipeline.handle(); - let pipeline_task_handle = tokio::spawn(async move { - pipeline - .start() - .await - .expect("The pipeline experienced an error"); - }); +pub struct PipelineRunner { + pipeline: Option>, + pipeline_handle: BatchDataPipelineHandle, +} + +impl PipelineRunner { + pub fn new(pipeline: BatchDataPipeline) -> Self { + let pipeline_handle = pipeline.handle(); + Self { pipeline: Some(pipeline), pipeline_handle } + } + + pub async fn run(&mut self) -> JoinHandle> { + if let Some(mut pipeline) = self.pipeline.take() { + return tokio::spawn(async move { + pipeline + .start() + .await + .expect("The pipeline experienced an error"); + + pipeline + }) + } - (pipeline_handle, pipeline_task_handle) + panic!("The pipeline has already been run"); + } + + pub async fn stop_and_wait(&mut self, pipeline_task_handle: JoinHandle>) { + // We signal the existing pipeline to stop. + self.pipeline_handle.stop(); + + // We wait for the pipeline to finish, and we put it back for the next run. + let pipeline = pipeline_task_handle.await.expect("The pipeline task has failed"); + // We recreate the handle just to make sure the pipeline handle and pipelines connected. + self.pipeline_handle = pipeline.handle(); + self.pipeline = Some(pipeline); + } } diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs index 6676fc5f..2be5c6bb 100644 --- a/pg_replicate/tests/common/sink.rs +++ b/pg_replicate/tests/common/sink.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use async_trait::async_trait; use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::table_row::TableRow; @@ -21,8 +22,9 @@ struct TestSinkInner { tables_schemas: Vec>, tables_rows: HashMap>, events: Vec>, - tables_copied: u8, - tables_truncated: u8, + copied_tables: HashSet, + truncated_tables: HashSet, + last_lsn: u64 } impl TestSink { @@ -32,11 +34,27 @@ impl TestSink { tables_schemas: Vec::new(), tables_rows: HashMap::new(), events: Vec::new(), - tables_copied: 0, - tables_truncated: 0, + copied_tables: HashSet::new(), + truncated_tables: HashSet::new(), + last_lsn: 0 })), } } + + fn receive_events(&mut self, events: &[CdcEvent]) { + let mut max_lsn = 0; + for event in events { + if let CdcEvent::Commit(commit_body) = event { + max_lsn = max(max_lsn, commit_body.commit_lsn()); + } + } + + // We update the last lsn taking the maximum between the maximum of the event stream and + // the current lsn, since we assume that lsns are guaranteed to be monotonically increasing, + // so if we see a max lsn, we can be sure that all events before that point have been received. 
+ let last_lsn = self.inner.lock().unwrap().last_lsn; + self.inner.lock().unwrap().last_lsn = max(last_lsn, max_lsn); + } pub fn get_tables_schemas(&self) -> Vec> { self.inner.lock().unwrap().tables_schemas.clone() @@ -49,13 +67,21 @@ impl TestSink { pub fn get_events(&self) -> Vec> { self.inner.lock().unwrap().events.clone() } + + pub fn get_copied_tables(&self) -> HashSet { + self.inner.lock().unwrap().copied_tables.clone() + } pub fn get_tables_copied(&self) -> u8 { - self.inner.lock().unwrap().tables_copied + self.inner.lock().unwrap().copied_tables.len() as u8 } pub fn get_tables_truncated(&self) -> u8 { - self.inner.lock().unwrap().tables_truncated + self.inner.lock().unwrap().truncated_tables.len() as u8 + } + + pub fn get_last_lsn(&self) -> u64 { + self.inner.lock().unwrap().last_lsn } } @@ -65,8 +91,8 @@ impl BatchSink for TestSink { async fn get_resumption_state(&mut self) -> Result { Ok(PipelineResumptionState { - copied_tables: HashSet::new(), - last_lsn: PgLsn::from(0), + copied_tables: self.get_copied_tables(), + last_lsn: PgLsn::from(self.get_last_lsn()), }) } @@ -79,6 +105,7 @@ impl BatchSink for TestSink { .unwrap() .tables_schemas .push(table_schemas); + Ok(()) } @@ -94,24 +121,30 @@ impl BatchSink for TestSink { .entry(table_id) .or_default() .extend(rows); + Ok(()) } async fn write_cdc_events(&mut self, events: Vec) -> Result { + self.receive_events(&events); + // Since CdcEvent is not Clone, we have to wrap it in an Arc, and we are fine with this // since it's not mutable, so we don't even have to use mutexes. let arc_events = events.into_iter().map(Arc::new).collect::>(); self.inner.lock().unwrap().events.extend(arc_events); - Ok(PgLsn::from(0)) + + Ok(PgLsn::from(self.inner.lock().unwrap().last_lsn)) } - async fn table_copied(&mut self, _table_id: TableId) -> Result<(), Self::Error> { - self.inner.lock().unwrap().tables_copied += 1; + async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error> { + self.inner.lock().unwrap().copied_tables.insert(table_id); + Ok(()) } - async fn truncate_table(&mut self, _table_id: TableId) -> Result<(), Self::Error> { - self.inner.lock().unwrap().tables_truncated += 1; + async fn truncate_table(&mut self, table_id: TableId) -> Result<(), Self::Error> { + self.inner.lock().unwrap().truncated_tables.insert(table_id); + Ok(()) } } diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/base.rs index 6f36c6e8..418052c5 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/base.rs @@ -185,7 +185,7 @@ async fn test_table_copy_with_insert_and_update() { } #[tokio::test(flavor = "multi_thread")] -async fn test_cdc_with_insert_and_update() { +async fn test_cdc_with_multiple_inserts() { let database = spawn_database().await; // We create the table and publication. @@ -193,7 +193,7 @@ async fn test_cdc_with_insert_and_update() { // We create a pipeline that subscribes to the changes of the users table. let sink = TestSink::new(); - let (pipeline_handle, pipeline_task_handle) = spawn_async_pg_pipeline( + let mut pipeline = spawn_async_pg_pipeline( &database.options, PipelineMode::Cdc { publication: "users_publication".to_owned(), @@ -202,11 +202,15 @@ async fn test_cdc_with_insert_and_update() { sink.clone(), ) .await; - + // We insert 100 rows. 
fill_users(&database, 100).await; - // Wait for all events to be processed + // We run the pipeline in the background which should correctly pick up entries from the start + // even though the pipeline was started after insertions. + let pipeline_task_handle = pipeline.run().await; + + // Wait for all events to be processed. let expected_sum = get_expected_ages_sum(100); wait_for_condition(|| { let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 0..100); @@ -215,10 +219,26 @@ async fn test_cdc_with_insert_and_update() { .await; // We stop the pipeline and wait for it to finish. - pipeline_handle.stop(); - pipeline_task_handle.await.unwrap(); + pipeline.stop_and_wait(pipeline_task_handle).await; assert_users_table_schema(&sink, users_table_id, 0); assert_eq!(sink.get_tables_copied(), 0); assert_eq!(sink.get_tables_truncated(), 0); + + // We insert an additional 100 rows. + // fill_users(&database, 100).await; + + // We run the pipeline in the background. + let pipeline_task_handle = pipeline.run().await; + + // Wait for all events to be processed. + let expected_sum = get_expected_ages_sum(100); + wait_for_condition(|| { + let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 100..200); + actual_sum == expected_sum + }) + .await; + + // We stop the pipeline and wait for it to finish. + pipeline.stop_and_wait(pipeline_task_handle).await; } diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 6e131210..5441076f 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -214,7 +214,7 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { // Forcefully terminate any remaining connections to the database client .execute( - &*format!( + &format!( r#" SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity @@ -230,9 +230,14 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { // Drop any test replication slots client .execute( - "SELECT pg_drop_replication_slot(slot_name) - FROM pg_replication_slots - WHERE slot_name LIKE 'test_%';", + &format!( + r#" + SELECT pg_drop_replication_slot(slot_name) + FROM pg_replication_slots + WHERE slot_name LIKE 'test_%' + AND database = '{}';"#, + options.name + ), &[], ) .await @@ -241,7 +246,7 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { // Drop the database client .execute( - &*format!(r#"DROP DATABASE IF EXISTS "{}";"#, options.name), + &format!(r#"DROP DATABASE IF EXISTS "{}";"#, options.name), &[], ) .await From a74bdc7299e9af91820402976cd6f3940fd1f930 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:02:47 +0200 Subject: [PATCH 17/34] Remove code --- .../src/pipeline/batching/data_pipeline.rs | 2 -- pg_replicate/src/pipeline/sources/postgres.rs | 3 +-- pg_replicate/tests/integration/mod.rs | 2 +- .../integration/{base.rs => pipeline_test.rs} | 17 ----------------- 4 files changed, 2 insertions(+), 22 deletions(-) rename pg_replicate/tests/integration/{base.rs => pipeline_test.rs} (92%) diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 70a459bf..2fc0bca8 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -205,8 +205,6 @@ impl BatchDataPipeline { .get_resumption_state() .await .map_err(PipelineError::Sink)?; - - println!("RESUMPTION STATE {:?}", resumption_state); match self.action { PipelineAction::TableCopiesOnly => { diff --git 
a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index e6dfd58f..121e3c1d 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -154,7 +154,7 @@ impl Source for PostgresSource { async fn get_cdc_stream(&self, start_lsn: PgLsn) -> Result { info!("starting cdc stream at lsn {start_lsn}"); - println!("STARTING STREAM AT {:?}", start_lsn); + let publication = self .publication() .ok_or(PostgresSourceError::MissingPublication)?; @@ -167,7 +167,6 @@ impl Source for PostgresSource { .get_logical_replication_stream(publication, slot_name, start_lsn) .await .map_err(PostgresSourceError::ReplicationClient)?; - println!("STREAM STARTED AT {:?}", start_lsn); const TIME_SEC_CONVERSION: u64 = 946_684_800; let postgres_epoch = UNIX_EPOCH + Duration::from_secs(TIME_SEC_CONVERSION); diff --git a/pg_replicate/tests/integration/mod.rs b/pg_replicate/tests/integration/mod.rs index 77ed8456..8f0125e2 100644 --- a/pg_replicate/tests/integration/mod.rs +++ b/pg_replicate/tests/integration/mod.rs @@ -1 +1 @@ -mod base; +mod pipeline_test; diff --git a/pg_replicate/tests/integration/base.rs b/pg_replicate/tests/integration/pipeline_test.rs similarity index 92% rename from pg_replicate/tests/integration/base.rs rename to pg_replicate/tests/integration/pipeline_test.rs index 418052c5..27ad0c47 100644 --- a/pg_replicate/tests/integration/base.rs +++ b/pg_replicate/tests/integration/pipeline_test.rs @@ -224,21 +224,4 @@ async fn test_cdc_with_multiple_inserts() { assert_users_table_schema(&sink, users_table_id, 0); assert_eq!(sink.get_tables_copied(), 0); assert_eq!(sink.get_tables_truncated(), 0); - - // We insert an additional 100 rows. - // fill_users(&database, 100).await; - - // We run the pipeline in the background. - let pipeline_task_handle = pipeline.run().await; - - // Wait for all events to be processed. - let expected_sum = get_expected_ages_sum(100); - wait_for_condition(|| { - let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 100..200); - actual_sum == expected_sum - }) - .await; - - // We stop the pipeline and wait for it to finish. 
- pipeline.stop_and_wait(pipeline_task_handle).await; } From e399c3ce47665b230cff6a89b4a277cb3c467b48 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:03:01 +0200 Subject: [PATCH 18/34] Reformat --- pg_replicate/src/clients/postgres.rs | 2 +- pg_replicate/src/pipeline/sources/postgres.rs | 2 +- pg_replicate/tests/common/pipeline.rs | 24 ++++++++++++------- pg_replicate/tests/common/sink.rs | 22 ++++++++--------- .../tests/integration/pipeline_test.rs | 2 +- 5 files changed, 30 insertions(+), 22 deletions(-) diff --git a/pg_replicate/src/clients/postgres.rs b/pg_replicate/src/clients/postgres.rs index a2dafd0e..0db023a5 100644 --- a/pg_replicate/src/clients/postgres.rs +++ b/pg_replicate/src/clients/postgres.rs @@ -71,7 +71,7 @@ impl ReplicationClient { config.replication_mode(ReplicationMode::Logical); let (postgres_client, connection) = config.connect(NoTls).await?; - + tokio::spawn(async move { info!("waiting for connection to terminate"); if let Err(e) = connection.await { diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index 121e3c1d..a7b1cd7b 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -154,7 +154,7 @@ impl Source for PostgresSource { async fn get_cdc_stream(&self, start_lsn: PgLsn) -> Result { info!("starting cdc stream at lsn {start_lsn}"); - + let publication = self .publication() .ok_or(PostgresSourceError::MissingPublication)?; diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 4db5f7fa..0db0da03 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -81,9 +81,12 @@ pub struct PipelineRunner { impl PipelineRunner { pub fn new(pipeline: BatchDataPipeline) -> Self { let pipeline_handle = pipeline.handle(); - Self { pipeline: Some(pipeline), pipeline_handle } + Self { + pipeline: Some(pipeline), + pipeline_handle, + } } - + pub async fn run(&mut self) -> JoinHandle> { if let Some(mut pipeline) = self.pipeline.take() { return tokio::spawn(async move { @@ -91,20 +94,25 @@ impl PipelineRunner { .start() .await .expect("The pipeline experienced an error"); - + pipeline - }) + }); } panic!("The pipeline has already been run"); } - - pub async fn stop_and_wait(&mut self, pipeline_task_handle: JoinHandle>) { + + pub async fn stop_and_wait( + &mut self, + pipeline_task_handle: JoinHandle>, + ) { // We signal the existing pipeline to stop. self.pipeline_handle.stop(); - + // We wait for the pipeline to finish, and we put it back for the next run. - let pipeline = pipeline_task_handle.await.expect("The pipeline task has failed"); + let pipeline = pipeline_task_handle + .await + .expect("The pipeline task has failed"); // We recreate the handle just to make sure the pipeline handle and pipelines connected. 
self.pipeline_handle = pipeline.handle(); self.pipeline = Some(pipeline); diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs index 2be5c6bb..54053641 100644 --- a/pg_replicate/tests/common/sink.rs +++ b/pg_replicate/tests/common/sink.rs @@ -1,10 +1,10 @@ -use std::cmp::max; use async_trait::async_trait; use pg_replicate::conversions::cdc_event::CdcEvent; use pg_replicate::conversions::table_row::TableRow; use pg_replicate::pipeline::sinks::{BatchSink, InfallibleSinkError}; use pg_replicate::pipeline::PipelineResumptionState; use postgres::schema::{TableId, TableSchema}; +use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::sync::{Arc, Mutex}; use tokio_postgres::types::PgLsn; @@ -24,7 +24,7 @@ struct TestSinkInner { events: Vec>, copied_tables: HashSet, truncated_tables: HashSet, - last_lsn: u64 + last_lsn: u64, } impl TestSink { @@ -36,11 +36,11 @@ impl TestSink { events: Vec::new(), copied_tables: HashSet::new(), truncated_tables: HashSet::new(), - last_lsn: 0 + last_lsn: 0, })), } } - + fn receive_events(&mut self, events: &[CdcEvent]) { let mut max_lsn = 0; for event in events { @@ -48,13 +48,13 @@ impl TestSink { max_lsn = max(max_lsn, commit_body.commit_lsn()); } } - + // We update the last lsn taking the maximum between the maximum of the event stream and - // the current lsn, since we assume that lsns are guaranteed to be monotonically increasing, + // the current lsn, since we assume that lsns are guaranteed to be monotonically increasing, // so if we see a max lsn, we can be sure that all events before that point have been received. let last_lsn = self.inner.lock().unwrap().last_lsn; - self.inner.lock().unwrap().last_lsn = max(last_lsn, max_lsn); - } + self.inner.lock().unwrap().last_lsn = max(last_lsn, max_lsn); + } pub fn get_tables_schemas(&self) -> Vec> { self.inner.lock().unwrap().tables_schemas.clone() @@ -67,7 +67,7 @@ impl TestSink { pub fn get_events(&self) -> Vec> { self.inner.lock().unwrap().events.clone() } - + pub fn get_copied_tables(&self) -> HashSet { self.inner.lock().unwrap().copied_tables.clone() } @@ -79,7 +79,7 @@ impl TestSink { pub fn get_tables_truncated(&self) -> u8 { self.inner.lock().unwrap().truncated_tables.len() as u8 } - + pub fn get_last_lsn(&self) -> u64 { self.inner.lock().unwrap().last_lsn } @@ -127,7 +127,7 @@ impl BatchSink for TestSink { async fn write_cdc_events(&mut self, events: Vec) -> Result { self.receive_events(&events); - + // Since CdcEvent is not Clone, we have to wrap it in an Arc, and we are fine with this // since it's not mutable, so we don't even have to use mutexes. let arc_events = events.into_iter().map(Arc::new).collect::>(); diff --git a/pg_replicate/tests/integration/pipeline_test.rs b/pg_replicate/tests/integration/pipeline_test.rs index 27ad0c47..11348352 100644 --- a/pg_replicate/tests/integration/pipeline_test.rs +++ b/pg_replicate/tests/integration/pipeline_test.rs @@ -202,7 +202,7 @@ async fn test_cdc_with_multiple_inserts() { sink.clone(), ) .await; - + // We insert 100 rows. 
fill_users(&database, 100).await; From 6ff33169f89e3442a174a5c7fb04b6dc6b6145f4 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:09:49 +0200 Subject: [PATCH 19/34] Improve --- pg_replicate/tests/common/database.rs | 14 ++++--- pg_replicate/tests/common/mod.rs | 26 ++++++++++-- pg_replicate/tests/common/pipeline.rs | 42 +++++++++++++++++-- pg_replicate/tests/common/sink.rs | 23 +++++++++- pg_replicate/tests/common/table.rs | 12 ++++++ .../tests/integration/pipeline_test.rs | 10 +---- 6 files changed, 104 insertions(+), 23 deletions(-) diff --git a/pg_replicate/tests/common/database.rs b/pg_replicate/tests/common/database.rs index 8284b49b..0dda445a 100644 --- a/pg_replicate/tests/common/database.rs +++ b/pg_replicate/tests/common/database.rs @@ -4,13 +4,17 @@ use postgres::tokio::test_utils::PgDatabase; use tokio_postgres::config::SslMode; use uuid::Uuid; -/// The default schema name used for test tables. +/// The schema name used for organizing test tables. +/// +/// This constant defines the default schema where test tables are created, +/// providing isolation from other database objects. const TEST_DATABASE_SCHEMA: &str = "test"; /// Creates a [`TableName`] in the test schema. /// -/// This helper function constructs a [`TableName`] with the schema set to the test schema -/// and the provided name as the table name. +/// This helper function constructs a [`TableName`] with the schema set to [`TEST_DATABASE_SCHEMA`] +/// and the provided name as the table name. It's used to ensure consistent table naming +/// across test scenarios. pub fn test_table_name(name: &str) -> TableName { TableName { schema: TEST_DATABASE_SCHEMA.to_owned(), @@ -18,10 +22,10 @@ pub fn test_table_name(name: &str) -> TableName { } } -/// Creates a new test database instance. +/// Creates a new test database instance with a unique name. /// /// This function spawns a new PostgreSQL database with a random UUID as its name, -/// using default credentials and disabled SSL. It also creates a test schema +/// using default credentials and disabled SSL. It automatically creates the test schema /// for organizing test tables. /// /// # Panics diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index e4300c34..40e96caa 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -1,3 +1,8 @@ +/// Common utilities and helpers for testing PostgreSQL replication functionality. +/// +/// This module provides shared testing infrastructure including database management, +/// pipeline testing utilities, sink testing helpers, and table manipulation utilities. +/// It also includes common testing patterns like waiting for conditions to be met. use std::time::{Duration, Instant}; use tokio::time::sleep; @@ -6,14 +11,27 @@ pub mod pipeline; pub mod sink; pub mod table; -/// The maximum time in seconds for which we should wait for a condition to be met -/// in tests. +/// The maximum duration to wait for test conditions to be met. +/// +/// This constant defines the timeout period for asynchronous test assertions, +/// ensuring tests don't hang indefinitely while waiting for expected states. const MAX_ASSERTION_DURATION: Duration = Duration::from_secs(20); -/// The frequency at which we should check for a condition to be met in tests. +/// The interval between condition checks during test assertions. +/// +/// This constant defines how frequently we poll for condition changes while +/// waiting for test assertions to complete. 
const ASSERTION_FREQUENCY_DURATION: Duration = Duration::from_millis(10); -/// Wait for a condition to be met within the maximum timeout. +/// Waits asynchronously for a condition to be met within the maximum timeout period. +/// +/// This function repeatedly evaluates the provided condition until it returns true +/// or the maximum duration is exceeded. It's useful for testing asynchronous +/// operations where the exact completion time is not known. +/// +/// # Panics +/// +/// Panics if the condition is not met within [`MAX_ASSERTION_DURATION`]. pub async fn wait_for_condition(condition: F) where F: Fn() -> bool, diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs index 0db0da03..3ba16a94 100644 --- a/pg_replicate/tests/common/pipeline.rs +++ b/pg_replicate/tests/common/pipeline.rs @@ -8,22 +8,35 @@ use postgres::tokio::options::PgDatabaseOptions; use std::time::Duration; use tokio::task::JoinHandle; +/// Defines the operational mode for a PostgreSQL replication pipeline. pub enum PipelineMode { - /// In this mode the supplied tables will be copied. + /// Initializes a pipeline to copy specified tables. CopyTable { table_names: Vec }, - /// In this mode the changes will be consumed from the given publication and slot. + /// Initializes a pipeline to consume changes from a publication and replication slot. /// - /// If the slot is not supplied, a new one will be created on the supplied publication. + /// If no slot name is provided, a new slot will be created on the specified publication. Cdc { publication: String, slot_name: String, }, } +/// Generates a test-specific replication slot name. +/// +/// This function prefixes the provided slot name with "test_" to avoid conflicts +/// with other replication slots. pub fn test_slot_name(slot_name: &str) -> String { format!("test_{}", slot_name) } +/// Creates a new PostgreSQL replication pipeline. +/// +/// This function initializes a pipeline with a batch size of 1000 records and +/// a maximum batch duration of 10 seconds. +/// +/// # Panics +/// +/// Panics if the PostgreSQL source cannot be created. pub async fn spawn_pg_pipeline( options: &PgDatabaseOptions, mode: PipelineMode, @@ -64,6 +77,10 @@ pub async fn spawn_pg_pipeline( pipeline } +/// Creates and spawns a new asynchronous PostgreSQL replication pipeline. +/// +/// This function creates a pipeline and wraps it in a [`PipelineRunner`] for +/// easier management of the pipeline lifecycle. pub async fn spawn_async_pg_pipeline( options: &PgDatabaseOptions, mode: PipelineMode, @@ -73,12 +90,17 @@ pub async fn spawn_async_pg_pipeline( PipelineRunner::new(pipeline) } +/// Manages the lifecycle of a PostgreSQL replication pipeline. +/// +/// This struct provides methods to run and stop a pipeline, handling the +/// pipeline's state and ensuring proper cleanup. pub struct PipelineRunner { pipeline: Option>, pipeline_handle: BatchDataPipelineHandle, } impl PipelineRunner { + /// Creates a new pipeline runner with the specified pipeline. pub fn new(pipeline: BatchDataPipeline) -> Self { let pipeline_handle = pipeline.handle(); Self { @@ -87,6 +109,11 @@ impl PipelineRunner { } } + /// Starts the pipeline asynchronously. + /// + /// # Panics + /// + /// Panics if the pipeline has already been run. 
pub async fn run(&mut self) -> JoinHandle> { if let Some(mut pipeline) = self.pipeline.take() { return tokio::spawn(async move { @@ -102,6 +129,15 @@ impl PipelineRunner { panic!("The pipeline has already been run"); } + /// Stops the pipeline and waits for it to complete. + /// + /// This method signals the pipeline to stop and waits for it to finish + /// before returning. The pipeline is then restored to its initial state + /// for potential reuse. + /// + /// # Panics + /// + /// Panics if the pipeline task fails. pub async fn stop_and_wait( &mut self, pipeline_task_handle: JoinHandle>, diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs index 54053641..c94dfd11 100644 --- a/pg_replicate/tests/common/sink.rs +++ b/pg_replicate/tests/common/sink.rs @@ -9,13 +9,19 @@ use std::collections::{HashMap, HashSet}; use std::sync::{Arc, Mutex}; use tokio_postgres::types::PgLsn; +/// A test sink that captures replication events and data for verification. +/// +/// This sink is designed to be shared across multiple pipelines, simulating +/// persistent storage while maintaining thread safety through interior mutability. #[derive(Debug, Clone)] pub struct TestSink { - // We use Arc to allow the sink to be shared by multiple pipelines, effectively - // simulating recreating pipelines with a sink that "persists" data. inner: Arc>, } +/// Internal state of the test sink. +/// +/// This struct maintains the sink's state including table schemas, rows, +/// CDC events, and tracking information for copied and truncated tables. #[derive(Debug)] struct TestSinkInner { // We have a Vec to store all the changes of the schema that we receive over time. @@ -28,6 +34,7 @@ struct TestSinkInner { } impl TestSink { + /// Creates a new test sink with an empty state. pub fn new() -> Self { Self { inner: Arc::new(Mutex::new(TestSinkInner { @@ -41,6 +48,11 @@ impl TestSink { } } + /// Updates the last LSN based on received events. + /// + /// This method ensures that the last LSN is monotonically increasing, + /// taking the maximum between the current LSN and the maximum LSN from + /// the received events. fn receive_events(&mut self, events: &[CdcEvent]) { let mut max_lsn = 0; for event in events { @@ -56,30 +68,37 @@ impl TestSink { self.inner.lock().unwrap().last_lsn = max(last_lsn, max_lsn); } + /// Returns a copy of all table schemas received by the sink. pub fn get_tables_schemas(&self) -> Vec> { self.inner.lock().unwrap().tables_schemas.clone() } + /// Returns a copy of all table rows received by the sink. pub fn get_tables_rows(&self) -> HashMap> { self.inner.lock().unwrap().tables_rows.clone() } + /// Returns a copy of all CDC events received by the sink. pub fn get_events(&self) -> Vec> { self.inner.lock().unwrap().events.clone() } + /// Returns a copy of the set of tables that have been copied. pub fn get_copied_tables(&self) -> HashSet { self.inner.lock().unwrap().copied_tables.clone() } + /// Returns the number of tables that have been copied. pub fn get_tables_copied(&self) -> u8 { self.inner.lock().unwrap().copied_tables.len() as u8 } + /// Returns the number of tables that have been truncated. pub fn get_tables_truncated(&self) -> u8 { self.inner.lock().unwrap().truncated_tables.len() as u8 } + /// Returns the last LSN processed by the sink. 
pub fn get_last_lsn(&self) -> u64 { self.inner.lock().unwrap().last_lsn } diff --git a/pg_replicate/tests/common/table.rs b/pg_replicate/tests/common/table.rs index 3392da22..9f5e7868 100644 --- a/pg_replicate/tests/common/table.rs +++ b/pg_replicate/tests/common/table.rs @@ -1,6 +1,18 @@ use crate::common::sink::TestSink; use postgres::schema::{ColumnSchema, TableId, TableName}; +/// Verifies that a table's schema matches the expected configuration. +/// +/// This function compares a table's actual schema against the expected schema, +/// checking the table name, ID, and all column properties including name, type, +/// modifiers, nullability, and primary key status. +/// +/// # Panics +/// +/// Panics if: +/// - The table ID is not found in the sink's schema +/// - The schema index is out of bounds +/// - Any column property doesn't match the expected configuration pub fn assert_table_schema( sink: &TestSink, table_id: TableId, diff --git a/pg_replicate/tests/integration/pipeline_test.rs b/pg_replicate/tests/integration/pipeline_test.rs index 11348352..5f20b25c 100644 --- a/pg_replicate/tests/integration/pipeline_test.rs +++ b/pg_replicate/tests/integration/pipeline_test.rs @@ -124,16 +124,8 @@ fn get_users_age_sum_from_events( /* Tests to write: -- Insert -> table copy -- Insert -> Update -> table copy -- Insert -> cdc -- Insert -> Update -> cdc - Insert -> cdc -> Update -> cdc -- Insert -> table copy -> crash while copying -> add new table -> check if new table is in the snapshot - -The main test we want to do is to check if resuming after a new table has been added causes problems - -insert -> cdc -> add table -> recreate pipeline and source -> check schema +- Insert -> cdc -> add table -> recreate pipeline and source -> check schema */ #[tokio::test(flavor = "multi_thread")] From ee7ce40a0c5f74531d108e67bebe9a7d65e137df Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:10:40 +0200 Subject: [PATCH 20/34] Update action --- .github/workflows/general.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index fe758262..9fc77094 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -59,7 +59,6 @@ jobs: - name: Migrate database run: | sudo apt-get install libpq-dev -y - cd api SKIP_DOCKER=true ./scripts/init_db.sh - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov From 16a328a28b59e8fc3ce3cca17dfd2029369249a8 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:16:26 +0200 Subject: [PATCH 21/34] Improve --- api/src/startup.rs | 19 ++++++++++--------- api/tests/common/mod.rs | 16 ++++++++++++++++ .../src/pipeline/batching/data_pipeline.rs | 17 +++++++++-------- pg_replicate/src/pipeline/sources/postgres.rs | 16 ++++++++-------- 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/api/src/startup.rs b/api/src/startup.rs index 302dc758..3fa4b4c9 100644 --- a/api/src/startup.rs +++ b/api/src/startup.rs @@ -1,5 +1,15 @@ use std::{net::TcpListener, sync::Arc}; +use actix_web::{dev::Server, web, App, HttpServer}; +use actix_web_httpauth::middleware::HttpAuthentication; +use aws_lc_rs::aead::{RandomizedNonceKey, AES_256_GCM}; +use base64::{prelude::BASE64_STANDARD, Engine}; +use postgres::sqlx::options::PgDatabaseOptions; +use sqlx::{postgres::PgPoolOptions, PgPool}; +use tracing_actix_web::TracingLogger; +use utoipa::OpenApi; +use utoipa_swagger_ui::SwaggerUi; + use crate::{ authentication::auth_validator, configuration::Settings, 
@@ -45,15 +55,6 @@ use crate::{ }, span_builder::ApiRootSpanBuilder, }; -use actix_web::{dev::Server, web, App, HttpServer}; -use actix_web_httpauth::middleware::HttpAuthentication; -use aws_lc_rs::aead::{RandomizedNonceKey, AES_256_GCM}; -use base64::{prelude::BASE64_STANDARD, Engine}; -use postgres::sqlx::options::PgDatabaseOptions; -use sqlx::{postgres::PgPoolOptions, PgPool}; -use tracing_actix_web::TracingLogger; -use utoipa::OpenApi; -use utoipa_swagger_ui::SwaggerUi; pub struct Application { port: u16, diff --git a/api/tests/common/mod.rs b/api/tests/common/mod.rs index 3f8ea5a3..0f6d8dcf 100644 --- a/api/tests/common/mod.rs +++ b/api/tests/common/mod.rs @@ -1,2 +1,18 @@ +//! Common test utilities for pg_replicate API tests. +//! +//! This module provides shared functionality used across integration tests: +//! +//! - `test_app`: A test application wrapper that provides: +//! - A running instance of the API server for testing +//! - Helper methods for making authenticated HTTP requests +//! - Request/response type definitions for all API endpoints +//! - Methods to create, read, update, and delete resources +//! +//! - `database`: Database configuration utilities that: +//! - Set up test databases with proper configuration +//! - Handle database migrations +//! - Provide connection pools for tests +//! +//! These utilities help maintain consistency across tests and reduce code duplication. pub mod database; pub mod test_app; diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index 2fc0bca8..3785a78c 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -1,3 +1,12 @@ +use futures::StreamExt; +use postgres::schema::TableId; +use std::sync::Arc; +use std::{collections::HashSet, time::Instant}; +use tokio::pin; +use tokio::sync::Notify; +use tokio_postgres::types::PgLsn; +use tracing::{debug, info}; + use crate::{ conversions::cdc_event::{CdcEvent, CdcEventConversionError}, pipeline::{ @@ -7,14 +16,6 @@ use crate::{ PipelineAction, PipelineError, }, }; -use futures::StreamExt; -use postgres::schema::TableId; -use std::sync::Arc; -use std::{collections::HashSet, time::Instant}; -use tokio::pin; -use tokio::sync::Notify; -use tokio_postgres::types::PgLsn; -use tracing::{debug, info}; use super::BatchConfig; diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index a7b1cd7b..6f7a8bf9 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -5,14 +5,6 @@ use std::{ time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH}, }; -use crate::{ - clients::postgres::{ReplicationClient, ReplicationClientError}, - conversions::{ - cdc_event::{CdcEvent, CdcEventConversionError, CdcEventConverter}, - table_row::{TableRow, TableRowConversionError, TableRowConverter}, - }, -}; - use async_trait::async_trait; use futures::{ready, Stream}; use pin_project_lite::pin_project; @@ -24,6 +16,14 @@ use thiserror::Error; use tokio_postgres::{config::SslMode, types::PgLsn, CopyOutStream}; use tracing::info; +use crate::{ + clients::postgres::{ReplicationClient, ReplicationClientError}, + conversions::{ + cdc_event::{CdcEvent, CdcEventConversionError, CdcEventConverter}, + table_row::{TableRow, TableRowConversionError, TableRowConverter}, + }, +}; + use super::{Source, SourceError}; pub enum TableNamesFrom { From ca0ac1b1459803111c39f56fecdf48e31594af98 Mon Sep 17 
00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:19:06 +0200 Subject: [PATCH 22/34] Setup wal level in ci --- .github/workflows/general.yml | 2 ++ scripts/init_db.sh | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 9fc77094..752e1115 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -48,6 +48,8 @@ jobs: POSTGRES_DB: postgres ports: - 5430:5432 + options: >- + --command "postgres -c wal_level=logical" steps: - name: Checkout repository uses: actions/checkout@v3 diff --git a/scripts/init_db.sh b/scripts/init_db.sh index f50450ff..6866f222 100755 --- a/scripts/init_db.sh +++ b/scripts/init_db.sh @@ -59,7 +59,7 @@ then # Complete the docker run command DOCKER_RUN_CMD="${DOCKER_RUN_CMD} \ --name "postgres_$(date '+%s')" \ - postgres -N 1000 \ + postgres:15 -N 1000 \ -c wal_level=logical" # Increased maximum number of connections for testing purposes From b358b4ecf5dc59ce4d666d9812e273cb525a8ad3 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:24:40 +0200 Subject: [PATCH 23/34] Fix --- .github/workflows/general.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 752e1115..3554fe23 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -49,7 +49,7 @@ jobs: ports: - 5430:5432 options: >- - --command "postgres -c wal_level=logical" + -c wal_level=logical steps: - name: Checkout repository uses: actions/checkout@v3 From f6a7d3a5d7ab2cd5c7b18ca53bf90723b6531750 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:33:13 +0200 Subject: [PATCH 24/34] Fix --- .github/workflows/general.yml | 41 ++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 3554fe23..bf17833d 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -1,4 +1,3 @@ -# copied from Luca Palmieri's gist here: https://gist.github.com/LukeMathWalker/5ae1107432ce283310c3e601fac915f3 name: Rust on: [push, pull_request] @@ -29,7 +28,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: clippy - # - uses: Swatinem/rust-cache@v2 - name: Linting run: cargo clippy -- -D warnings @@ -48,32 +46,55 @@ jobs: POSTGRES_DB: postgres ports: - 5430:5432 - options: >- - -c wal_level=logical steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Wait for Postgres to be ready + run: | + until pg_isready -h localhost -p 5430; do + echo "Waiting for Postgres..." 
+ sleep 1 + done + + - name: Enable logical WAL + run: | + # Append the setting to the Postgres config + docker exec "${{ job.services.postgres.id }}" \ + bash -c "echo 'wal_level = logical' >> /var/lib/postgresql/data/postgresql.conf" + # Reload Postgres configuration + docker exec "${{ job.services.postgres.id }}" \ + bash -c "kill -HUP 1" + - name: Install sqlx-cli - run: cargo install sqlx-cli - --features native-tls,postgres - --no-default-features - --locked + run: | + cargo install sqlx-cli \ + --features native-tls,postgres \ + --no-default-features \ + --locked + - name: Migrate database run: | sudo apt-get install libpq-dev -y SKIP_DOCKER=true ./scripts/init_db.sh + - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov + - name: Generate code coverage id: coverage run: | - cargo llvm-cov test --workspace --no-fail-fast --lcov --output-path lcov.info + cargo llvm-cov test \ + --workspace --no-fail-fast \ + --lcov --output-path lcov.info + - name: Coveralls upload uses: coverallsapp/github-action@v2 with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info debug: true + docker: name: Docker runs-on: ubuntu-latest @@ -97,4 +118,4 @@ jobs: with: file: ./replicator/Dockerfile push: true - tags: ${{ vars.DOCKERHUB_USERNAME }}/replicator:${{ github.head_ref || github.ref_name }}.${{ github.sha }} + tags: ${{ vars.DOCKERHUB_USERNAME }}/replicator:${{ github.head_ref || github.ref_name }}.${{ github.sha }} \ No newline at end of file From c99c451a74e085ceda73023854436d91848610e1 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:45:03 +0200 Subject: [PATCH 25/34] Trying to fix --- .github/workflows/general.yml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index bf17833d..9a3a852a 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -59,12 +59,11 @@ jobs: - name: Enable logical WAL run: | - # Append the setting to the Postgres config - docker exec "${{ job.services.postgres.id }}" \ - bash -c "echo 'wal_level = logical' >> /var/lib/postgresql/data/postgresql.conf" - # Reload Postgres configuration - docker exec "${{ job.services.postgres.id }}" \ - bash -c "kill -HUP 1" + PGPASSWORD=postgres psql -h localhost -U postgres -c "ALTER SYSTEM SET wal_level = 'logical';" + + - name: Restart Postgres service container + run: | + docker restart ${{ job.services.postgres.id }} - name: Install sqlx-cli run: | From b3977aa919850bc9c120efad922155aff25b144e Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:46:47 +0200 Subject: [PATCH 26/34] Improve --- .github/workflows/general.yml | 2 +- pg_replicate/src/clients/postgres.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 9a3a852a..6037a9b5 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -59,7 +59,7 @@ jobs: - name: Enable logical WAL run: | - PGPASSWORD=postgres psql -h localhost -U postgres -c "ALTER SYSTEM SET wal_level = 'logical';" + PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres -c "ALTER SYSTEM SET wal_level = 'logical';" - name: Restart Postgres service container run: | diff --git a/pg_replicate/src/clients/postgres.rs b/pg_replicate/src/clients/postgres.rs index 0db023a5..b2c1b03c 100644 --- a/pg_replicate/src/clients/postgres.rs +++ b/pg_replicate/src/clients/postgres.rs @@ -76,6 +76,7 @@ impl 
ReplicationClient { info!("waiting for connection to terminate"); if let Err(e) = connection.await { warn!("connection error: {}", e); + return; } info!("connection terminated successfully") }); @@ -114,7 +115,9 @@ impl ReplicationClient { info!("waiting for connection to terminate"); if let Err(e) = connection.await { warn!("connection error: {}", e); + return; } + info!("connection terminated successfully") }); info!("successfully connected to postgres"); From 372e6e7f6b7d263f986e51fb187cbdbe1db6b1ba Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:48:02 +0200 Subject: [PATCH 27/34] Improve --- .github/workflows/general.yml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 6037a9b5..15cb7d20 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -57,13 +57,14 @@ jobs: sleep 1 done - - name: Enable logical WAL + - name: Enable logical WAL without restart run: | - PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres -c "ALTER SYSTEM SET wal_level = 'logical';" - - - name: Restart Postgres service container - run: | - docker restart ${{ job.services.postgres.id }} + # Persist the change into postgresql.auto.conf + PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres \ + -c "ALTER SYSTEM SET wal_level = 'logical';" + # Tell Postgres to re-read its config (no restart) + PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres \ + -c "SELECT pg_reload_conf();" - name: Install sqlx-cli run: | From 98857d5aa916817da0bd90711a69c8375a8d0fa1 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:48:38 +0200 Subject: [PATCH 28/34] Improve --- .github/workflows/general.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 15cb7d20..65e39e0b 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -57,7 +57,7 @@ jobs: sleep 1 done - - name: Enable logical WAL without restart + - name: Enable logical WAL run: | # Persist the change into postgresql.auto.conf PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres \ From e997309f4c7ab0597fe16499a625c423910a7157 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:58:29 +0200 Subject: [PATCH 29/34] Improve --- .github/workflows/general.yml | 11 +++++------ scripts/init_db.sh | 3 +-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 65e39e0b..6037a9b5 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -59,12 +59,11 @@ jobs: - name: Enable logical WAL run: | - # Persist the change into postgresql.auto.conf - PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres \ - -c "ALTER SYSTEM SET wal_level = 'logical';" - # Tell Postgres to re-read its config (no restart) - PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres \ - -c "SELECT pg_reload_conf();" + PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres -c "ALTER SYSTEM SET wal_level = 'logical';" + + - name: Restart Postgres service container + run: | + docker restart ${{ job.services.postgres.id }} - name: Install sqlx-cli run: | diff --git a/scripts/init_db.sh b/scripts/init_db.sh index 6866f222..8597cb18 100755 --- a/scripts/init_db.sh +++ b/scripts/init_db.sh @@ -59,8 +59,7 @@ then # Complete the docker run command DOCKER_RUN_CMD="${DOCKER_RUN_CMD} \ --name "postgres_$(date 
'+%s')" \ - postgres:15 -N 1000 \ - -c wal_level=logical" + postgres:15 -N 1000" # Increased maximum number of connections for testing purposes # Start the container From 0d3543008fabca5f8a5fcadf353b08a87c63d5ef Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 12:58:54 +0200 Subject: [PATCH 30/34] Improve --- scripts/init_db.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/init_db.sh b/scripts/init_db.sh index 8597cb18..6866f222 100755 --- a/scripts/init_db.sh +++ b/scripts/init_db.sh @@ -59,7 +59,8 @@ then # Complete the docker run command DOCKER_RUN_CMD="${DOCKER_RUN_CMD} \ --name "postgres_$(date '+%s')" \ - postgres:15 -N 1000" + postgres:15 -N 1000 \ + -c wal_level=logical" # Increased maximum number of connections for testing purposes # Start the container From 12faaeca6b03d03f52ea95f534e4a71740582d7e Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Wed, 21 May 2025 17:28:26 +0530 Subject: [PATCH 31/34] fix: clippy warnings --- pg_replicate/tests/common/mod.rs | 2 +- pg_replicate/tests/integration/pipeline_test.rs | 6 +++--- postgres/src/sqlx/test_utils.rs | 6 ++---- postgres/src/tokio/test_utils.rs | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs index 40e96caa..ada4cb7e 100644 --- a/pg_replicate/tests/common/mod.rs +++ b/pg_replicate/tests/common/mod.rs @@ -45,5 +45,5 @@ where sleep(ASSERTION_FREQUENCY_DURATION).await; } - assert!(false, "Failed to process all events within timeout") + panic!("Failed to process all events within timeout") } diff --git a/pg_replicate/tests/integration/pipeline_test.rs b/pg_replicate/tests/integration/pipeline_test.rs index 5f20b25c..5f8622f5 100644 --- a/pg_replicate/tests/integration/pipeline_test.rs +++ b/pg_replicate/tests/integration/pipeline_test.rs @@ -16,7 +16,7 @@ fn get_expected_ages_sum(num_users: usize) -> i32 { async fn create_users_table(database: &PgDatabase) -> TableId { let table_id = database - .create_table(test_table_name("users"), &vec![("age", "integer")]) + .create_table(test_table_name("users"), &[("age", "integer")]) .await .unwrap(); @@ -30,7 +30,7 @@ async fn create_users_table_with_publication( let table_id = create_users_table(database).await; database - .create_publication(publication_name, &vec![test_table_name("users")]) + .create_publication(publication_name, &[test_table_name("users")]) .await .unwrap(); @@ -49,7 +49,7 @@ async fn fill_users(database: &PgDatabase, num_users: usize) { async fn double_users_ages(database: &PgDatabase) { database - .update_values(test_table_name("users"), &["age"], &[&"age * 2"]) + .update_values(test_table_name("users"), &["age"], &["age * 2"]) .await .unwrap(); } diff --git a/postgres/src/sqlx/test_utils.rs b/postgres/src/sqlx/test_utils.rs index 024ee434..1b73b3b8 100644 --- a/postgres/src/sqlx/test_utils.rs +++ b/postgres/src/sqlx/test_utils.rs @@ -17,11 +17,9 @@ pub async fn create_pg_database(options: &PgDatabaseOptions) -> PgPool { .expect("Failed to create database"); // Create a connection pool to the database and run the migration. - let connection_pool = PgPool::connect_with(options.with_db()) + PgPool::connect_with(options.with_db()) .await - .expect("Failed to connect to Postgres"); - - connection_pool + .expect("Failed to connect to Postgres") } /// Drops a PostgreSQL database and cleans up all connections. 
diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 5441076f..ecb1c538 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -44,7 +44,7 @@ impl PgDatabase { ) -> Result { let columns_str = columns .iter() - .map(|(name, typ)| format!("{} {}", name, typ.to_string())) + .map(|(name, typ)| format!("{} {}", name, typ)) .collect::>() .join(", "); From c7cb5959ceb85b995c5f24eff5cf0e7718bcb278 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 14:00:41 +0200 Subject: [PATCH 32/34] Fix PR comments --- pg_replicate/tests/common/sink.rs | 4 ++-- postgres/src/sqlx/test_utils.rs | 4 ++-- postgres/src/tokio/test_utils.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs index c94dfd11..cfd6151d 100644 --- a/pg_replicate/tests/common/sink.rs +++ b/pg_replicate/tests/common/sink.rs @@ -64,8 +64,8 @@ impl TestSink { // We update the last lsn taking the maximum between the maximum of the event stream and // the current lsn, since we assume that lsns are guaranteed to be monotonically increasing, // so if we see a max lsn, we can be sure that all events before that point have been received. - let last_lsn = self.inner.lock().unwrap().last_lsn; - self.inner.lock().unwrap().last_lsn = max(last_lsn, max_lsn); + let mut inner = self.inner.lock().unwrap(); + inner.last_lsn = max(inner.last_lsn, max_lsn); } /// Returns a copy of all table schemas received by the sink. diff --git a/postgres/src/sqlx/test_utils.rs b/postgres/src/sqlx/test_utils.rs index 1b73b3b8..9a525817 100644 --- a/postgres/src/sqlx/test_utils.rs +++ b/postgres/src/sqlx/test_utils.rs @@ -15,8 +15,8 @@ pub async fn create_pg_database(options: &PgDatabaseOptions) -> PgPool { .execute(&*format!(r#"CREATE DATABASE "{}";"#, options.name)) .await .expect("Failed to create database"); - - // Create a connection pool to the database and run the migration. + + // Create a connection pool to the database. PgPool::connect_with(options.with_db()) .await .expect("Failed to connect to Postgres") diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index ecb1c538..081ac091 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -20,7 +20,7 @@ impl PgDatabase { &self, publication_name: &str, table_names: &[TableName], - ) -> Result { + ) -> Result<(), tokio_postgres::Error> { let table_names = table_names .iter() .map(TableName::as_quoted_identifier) @@ -33,7 +33,7 @@ impl PgDatabase { ); self.client.execute(&create_publication_query, &[]).await?; - Ok(publication_name.to_string()) + Ok(()) } /// Creates a new table with the specified name and columns. 
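Note: the single-lock change in PATCH 32 above relies on the assumption, stated in the sink's own comments, that commit LSNs are monotonically increasing, so the test sink only needs to keep a high-watermark. The standalone sketch below restates that pattern outside the patch series for clarity; the names SinkState, SharedSink and advance_last_lsn are illustrative only and do not appear in the repository.

use std::cmp::max;
use std::sync::{Arc, Mutex};

// A stripped-down stand-in for the test sink's inner state: only the LSN high-watermark.
struct SinkState {
    last_lsn: u64,
}

// Shared handle, analogous to wrapping the state in Arc<Mutex<..>> so several
// pipelines can reuse the same sink.
#[derive(Clone)]
struct SharedSink {
    inner: Arc<Mutex<SinkState>>,
}

impl SharedSink {
    fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(SinkState { last_lsn: 0 })),
        }
    }

    // Lock the state once and advance the high-watermark. Taking the lock twice
    // (one lock to read, another to write) would let another thread interleave
    // between the two critical sections, which is what the patch above removes.
    fn advance_last_lsn(&self, max_lsn_in_batch: u64) {
        let mut inner = self.inner.lock().unwrap();
        inner.last_lsn = max(inner.last_lsn, max_lsn_in_batch);
    }

    fn last_lsn(&self) -> u64 {
        self.inner.lock().unwrap().last_lsn
    }
}

fn main() {
    let sink = SharedSink::new();
    sink.advance_last_lsn(42);
    sink.advance_last_lsn(7); // a stale, lower LSN never moves the watermark back
    assert_eq!(sink.last_lsn(), 42);
}

Locking once per update keeps the read-modify-write atomic with respect to other threads sharing the sink, which is the point of the two-line fix applied to receive_events in the patch above.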
From 6bd16bdf0069c15a412a4ddd76f8b5418df2a91d Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 14:03:26 +0200 Subject: [PATCH 33/34] Make queries lowercase --- postgres/src/sqlx/test_utils.rs | 12 +++++------ postgres/src/tokio/test_utils.rs | 36 ++++++++++++++++---------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/postgres/src/sqlx/test_utils.rs b/postgres/src/sqlx/test_utils.rs index 9a525817..a21bb0ee 100644 --- a/postgres/src/sqlx/test_utils.rs +++ b/postgres/src/sqlx/test_utils.rs @@ -12,7 +12,7 @@ pub async fn create_pg_database(options: &PgDatabaseOptions) -> PgPool { .await .expect("Failed to connect to Postgres"); connection - .execute(&*format!(r#"CREATE DATABASE "{}";"#, options.name)) + .execute(&*format!(r#"create database "{}";"#, options.name)) .await .expect("Failed to create database"); @@ -38,10 +38,10 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { connection .execute(&*format!( r#" - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' - AND pid <> pg_backend_pid();"#, + select pg_terminate_backend(pg_stat_activity.pid) + from pg_stat_activity + where pg_stat_activity.datname = '{}' + and pid <> pg_backend_pid();"#, options.name )) .await @@ -49,7 +49,7 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { // Drop the database. connection - .execute(&*format!(r#"DROP DATABASE IF EXISTS "{}";"#, options.name)) + .execute(&*format!(r#"drop database if exists "{}";"#, options.name)) .await .expect("Failed to destroy database"); } diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs index 081ac091..a49eb9a4 100644 --- a/postgres/src/tokio/test_utils.rs +++ b/postgres/src/tokio/test_utils.rs @@ -27,7 +27,7 @@ impl PgDatabase { .collect::>(); let create_publication_query = format!( - "CREATE PUBLICATION {} FOR TABLE {}", + "create publication {} for table {}", publication_name, table_names.join(", ") ); @@ -49,7 +49,7 @@ impl PgDatabase { .join(", "); let create_table_query = format!( - "CREATE TABLE {} (id BIGSERIAL PRIMARY KEY, {})", + "create table {} (id bigserial primary key, {})", table_name.as_quoted_identifier(), columns_str ); @@ -59,8 +59,8 @@ impl PgDatabase { let row = self .client .query_one( - "SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \ - WHERE n.nspname = $1 AND c.relname = $2", + "select c.oid from pg_class c join pg_namespace n on n.oid = c.relnamespace \ + where n.nspname = $1 and c.relname = $2", &[&table_name.schema, &table_name.name], ) .await?; @@ -82,7 +82,7 @@ impl PgDatabase { let placeholders_str = placeholders.join(", "); let insert_query = format!( - "INSERT INTO {} ({}) VALUES ({})", + "insert into {} ({}) values ({})", table_name.as_quoted_identifier(), columns_str, placeholders_str @@ -106,7 +106,7 @@ impl PgDatabase { let set_clause = set_clauses.join(", "); let update_query = format!( - "UPDATE {} SET {}", + "update {} set {}", table_name.as_quoted_identifier(), set_clause ); @@ -124,9 +124,9 @@ impl PgDatabase { where T: for<'a> tokio_postgres::types::FromSql<'a>, { - let where_str = where_clause.map_or(String::new(), |w| format!(" WHERE {}", w)); + let where_str = where_clause.map_or(String::new(), |w| format!(" where {}", w)); let query = format!( - "SELECT {} FROM {}{}", + "select {} from {}{}", column, table_name.as_quoted_identifier(), where_str @@ -169,7 +169,7 @@ pub async fn create_pg_database(options: &PgDatabaseOptions) -> Client { // 
Create the database client - .execute(&*format!(r#"CREATE DATABASE "{}";"#, options.name), &[]) + .execute(&*format!(r#"create database "{}";"#, options.name), &[]) .await .expect("Failed to create database"); @@ -216,10 +216,10 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { .execute( &format!( r#" - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' - AND pid <> pg_backend_pid();"#, + select pg_terminate_backend(pg_stat_activity.pid) + from pg_stat_activity + where pg_stat_activity.datname = '{}' + and pid <> pg_backend_pid();"#, options.name ), &[], @@ -232,10 +232,10 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { .execute( &format!( r#" - SELECT pg_drop_replication_slot(slot_name) - FROM pg_replication_slots - WHERE slot_name LIKE 'test_%' - AND database = '{}';"#, + select pg_drop_replication_slot(slot_name) + from pg_replication_slots + where slot_name like 'test_%' + and database = '{}';"#, options.name ), &[], @@ -246,7 +246,7 @@ pub async fn drop_pg_database(options: &PgDatabaseOptions) { // Drop the database client .execute( - &format!(r#"DROP DATABASE IF EXISTS "{}";"#, options.name), + &format!(r#"drop database if exists "{}";"#, options.name), &[], ) .await From 7068ec0a28f09156dfe50090b2d414282c382e54 Mon Sep 17 00:00:00 2001 From: Riccardo Busetti Date: Wed, 21 May 2025 14:15:28 +0200 Subject: [PATCH 34/34] Improve --- postgres/src/sqlx/test_utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres/src/sqlx/test_utils.rs b/postgres/src/sqlx/test_utils.rs index a21bb0ee..d860996c 100644 --- a/postgres/src/sqlx/test_utils.rs +++ b/postgres/src/sqlx/test_utils.rs @@ -15,7 +15,7 @@ pub async fn create_pg_database(options: &PgDatabaseOptions) -> PgPool { .execute(&*format!(r#"create database "{}";"#, options.name)) .await .expect("Failed to create database"); - + // Create a connection pool to the database. PgPool::connect_with(options.with_db()) .await