diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index fe758262..6037a9b5 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -1,4 +1,3 @@ -# copied from Luca Palmieri's gist here: https://gist.github.com/LukeMathWalker/5ae1107432ce283310c3e601fac915f3 name: Rust on: [push, pull_request] @@ -29,7 +28,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: clippy - # - uses: Swatinem/rust-cache@v2 - name: Linting run: cargo clippy -- -D warnings @@ -51,28 +49,51 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Wait for Postgres to be ready + run: | + until pg_isready -h localhost -p 5430; do + echo "Waiting for Postgres..." + sleep 1 + done + + - name: Enable logical WAL + run: | + PGPASSWORD=postgres psql -h localhost -p 5430 -U postgres -c "ALTER SYSTEM SET wal_level = 'logical';" + + - name: Restart Postgres service container + run: | + docker restart ${{ job.services.postgres.id }} + - name: Install sqlx-cli - run: cargo install sqlx-cli - --features native-tls,postgres - --no-default-features - --locked + run: | + cargo install sqlx-cli \ + --features native-tls,postgres \ + --no-default-features \ + --locked + - name: Migrate database run: | sudo apt-get install libpq-dev -y - cd api SKIP_DOCKER=true ./scripts/init_db.sh + - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov + - name: Generate code coverage id: coverage run: | - cargo llvm-cov test --workspace --no-fail-fast --lcov --output-path lcov.info + cargo llvm-cov test \ + --workspace --no-fail-fast \ + --lcov --output-path lcov.info + - name: Coveralls upload uses: coverallsapp/github-action@v2 with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info debug: true + docker: name: Docker runs-on: ubuntu-latest @@ -96,4 +117,4 @@ jobs: with: file: ./replicator/Dockerfile push: true - tags: ${{ vars.DOCKERHUB_USERNAME }}/replicator:${{ github.head_ref || github.ref_name }}.${{ github.sha }} + tags: ${{ vars.DOCKERHUB_USERNAME }}/replicator:${{ github.head_ref || github.ref_name }}.${{ github.sha }} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index fd72004b..bcf2dd02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,20 @@ [workspace] - resolver = "2" - -members = ["api", "pg_replicate", "replicator", "telemetry"] +members = [ + "api", + "pg_replicate", + "postgres", + "replicator", + "telemetry" +] [workspace.dependencies] +api = { path = "api" } +pg_replicate = { path = "pg_replicate" } +postgres = { path = "postgres" } +replicator = { path = "replicator" } +telemetry = { path = "telemetry" } + actix-web = { version = "4", default-features = false } actix-web-httpauth = { version = "0.8.2", default-features = false } anyhow = { version = "1.0", default-features = false } diff --git a/api/Cargo.toml b/api/Cargo.toml index 55594831..6b6d294a 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -11,13 +11,15 @@ path = "src/main.rs" name = "api" [dependencies] +postgres = { workspace = true, features = ["sqlx"] } +telemetry = { workspace = true } + actix-web = { workspace = true, features = ["macros", "http2"] } actix-web-httpauth = { workspace = true } anyhow = { workspace = true, features = ["std"] } async-trait = { workspace = true } aws-lc-rs = { workspace = true, features = ["alloc", "aws-lc-sys"] } base64 = { workspace = true, features = ["std"] } -bytes = { workspace = true } config = { workspace = true, features = ["yaml"] } constant_time_eq = { workspace = true } 
k8s-openapi = { workspace = true, features = ["latest"] } @@ -30,7 +32,6 @@ kube = { workspace = true, features = [ pg_escape = { workspace = true } rand = { workspace = true, features = ["std"] } reqwest = { workspace = true, features = ["json"] } -secrecy = { workspace = true, features = ["serde", "alloc"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } sqlx = { workspace = true, features = [ @@ -41,10 +42,12 @@ sqlx = { workspace = true, features = [ "migrate", ] } thiserror = { workspace = true } -telemetry = { path = "../telemetry" } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true, default-features = false } tracing-actix-web = { workspace = true, features = ["emit_event_on_error"] } utoipa = { workspace = true, features = ["actix_extras"] } utoipa-swagger-ui = { workspace = true, features = ["actix-web", "reqwest"] } uuid = { version = "1.10.0", features = ["v4"] } + +[dev-dependencies] +postgres = { workspace = true, features = ["test-utils", "sqlx"] } \ No newline at end of file diff --git a/api/README.md b/api/README.md index 4715439e..6489e0f4 100644 --- a/api/README.md +++ b/api/README.md @@ -26,7 +26,7 @@ Before you begin, ensure you have the following installed: ## Database Management ### Initial Setup -To set up and initialize the database, run the following command from the `api` directory: +To set up and initialize the database, run the following command from the main directory: ```bash ./scripts/init_db.sh diff --git a/api/src/configuration.rs b/api/src/configuration.rs index 1a34c929..4bd36095 100644 --- a/api/src/configuration.rs +++ b/api/src/configuration.rs @@ -1,12 +1,11 @@ use std::fmt::{self, Display}; use base64::{prelude::BASE64_STANDARD, Engine}; -use secrecy::{ExposeSecret, Secret}; +use postgres::sqlx::options::PgDatabaseOptions; use serde::{ de::{self, MapAccess, Unexpected, Visitor}, Deserialize, Deserializer, }; -use sqlx::postgres::{PgConnectOptions, PgSslMode}; use thiserror::Error; #[derive(serde::Deserialize, Clone)] @@ -100,59 +99,12 @@ impl<'de> Deserialize<'de> for ApiKey { #[derive(serde::Deserialize, Clone)] pub struct Settings { - pub database: DatabaseSettings, + pub database: PgDatabaseOptions, pub application: ApplicationSettings, pub encryption_key: EncryptionKey, pub api_key: String, } -#[derive(serde::Deserialize, Clone)] -pub struct DatabaseSettings { - /// Host on which Postgres is running - pub host: String, - - /// Port on which Postgres is running - pub port: u16, - - /// Postgres database name - pub name: String, - - /// Postgres database user name - pub username: String, - - /// Postgres database user password - pub password: Option>, - - /// Whether to enable ssl or not - pub require_ssl: bool, -} - -impl DatabaseSettings { - pub fn without_db(&self) -> PgConnectOptions { - let ssl_mode = if self.require_ssl { - PgSslMode::Require - } else { - PgSslMode::Prefer - }; - - let options = PgConnectOptions::new_without_pgpass() - .host(&self.host) - .username(&self.username) - .port(self.port) - .ssl_mode(ssl_mode); - - if let Some(password) = &self.password { - options.password(password.expose_secret()) - } else { - options - } - } - - pub fn with_db(&self) -> PgConnectOptions { - self.without_db().database(&self.name) - } -} - #[derive(serde::Deserialize, Clone)] pub struct ApplicationSettings { /// host the api listens on diff --git a/api/src/main.rs b/api/src/main.rs index 09a809e6..2f76d377 100644 --- a/api/src/main.rs +++ 
b/api/src/main.rs @@ -2,9 +2,10 @@ use std::env; use anyhow::anyhow; use api::{ - configuration::{get_settings, DatabaseSettings, Settings}, + configuration::{get_settings, Settings}, startup::Application, }; +use postgres::sqlx::options::PgDatabaseOptions; use telemetry::init_tracing; use tracing::{error, info}; @@ -23,7 +24,7 @@ pub async fn main() -> anyhow::Result<()> { // Run the application server 1 => { let configuration = get_settings::<'_, Settings>()?; - log_database_settings(&configuration.database); + log_pg_database_options(&configuration.database); let application = Application::build(configuration.clone()).await?; application.run_until_stopped().await?; } @@ -32,9 +33,9 @@ pub async fn main() -> anyhow::Result<()> { let command = args.nth(1).unwrap(); match command.as_str() { "migrate" => { - let configuration = get_settings::<'_, DatabaseSettings>()?; - log_database_settings(&configuration); - Application::migrate_database(configuration).await?; + let options = get_settings::<'_, PgDatabaseOptions>()?; + log_pg_database_options(&options); + Application::migrate_database(options).await?; info!("database migrated successfully"); } _ => { @@ -54,13 +55,13 @@ pub async fn main() -> anyhow::Result<()> { Ok(()) } -fn log_database_settings(settings: &DatabaseSettings) { +fn log_pg_database_options(options: &PgDatabaseOptions) { info!( - host = settings.host, - port = settings.port, - dbname = settings.name, - username = settings.username, - require_ssl = settings.require_ssl, - "database details", + host = options.host, + port = options.port, + dbname = options.name, + username = options.username, + require_ssl = options.require_ssl, + "pg database options", ); } diff --git a/api/src/startup.rs b/api/src/startup.rs index 9661bf5d..3fa4b4c9 100644 --- a/api/src/startup.rs +++ b/api/src/startup.rs @@ -4,6 +4,7 @@ use actix_web::{dev::Server, web, App, HttpServer}; use actix_web_httpauth::middleware::HttpAuthentication; use aws_lc_rs::aead::{RandomizedNonceKey, AES_256_GCM}; use base64::{prelude::BASE64_STANDARD, Engine}; +use postgres::sqlx::options::PgDatabaseOptions; use sqlx::{postgres::PgPoolOptions, PgPool}; use tracing_actix_web::TracingLogger; use utoipa::OpenApi; @@ -11,7 +12,7 @@ use utoipa_swagger_ui::SwaggerUi; use crate::{ authentication::auth_validator, - configuration::{DatabaseSettings, Settings}, + configuration::Settings, db::publications::Publication, encryption, k8s_client::HttpK8sClient, @@ -90,10 +91,8 @@ impl Application { Ok(Self { port, server }) } - pub async fn migrate_database( - database_settings: DatabaseSettings, - ) -> Result<(), anyhow::Error> { - let connection_pool = get_connection_pool(&database_settings); + pub async fn migrate_database(options: PgDatabaseOptions) -> Result<(), anyhow::Error> { + let connection_pool = get_connection_pool(&options); sqlx::migrate!("./migrations").run(&connection_pool).await?; @@ -109,8 +108,8 @@ impl Application { } } -pub fn get_connection_pool(configuration: &DatabaseSettings) -> PgPool { - PgPoolOptions::new().connect_lazy_with(configuration.with_db()) +pub fn get_connection_pool(options: &PgDatabaseOptions) -> PgPool { + PgPoolOptions::new().connect_lazy_with(options.with_db()) } // HttpK8sClient is wrapped in an option because creating it diff --git a/api/tests/common/database.rs b/api/tests/common/database.rs index f78c3520..c08e3ed0 100644 --- a/api/tests/common/database.rs +++ b/api/tests/common/database.rs @@ -1,20 +1,16 @@ -use api::configuration::DatabaseSettings; -use sqlx::{Connection, Executor, 
PgConnection, PgPool}; +use postgres::sqlx::options::PgDatabaseOptions; +use postgres::sqlx::test_utils::create_pg_database; +use sqlx::PgPool; -pub async fn create_and_configure_database(settings: &DatabaseSettings) -> PgPool { - // Create the database via a single connection. - let mut connection = PgConnection::connect_with(&settings.without_db()) - .await - .expect("Failed to connect to Postgres"); - connection - .execute(&*format!(r#"CREATE DATABASE "{}";"#, settings.name)) - .await - .expect("Failed to create database"); +/// Creates and configures a new PostgreSQL database for the API. +/// +/// Similar to [`create_pg_database`], but additionally runs all database migrations +/// from the "./migrations" directory after creation. Returns a [`PgPool`] +/// connected to the newly created and migrated database. Panics if database +/// creation or migration fails. +pub async fn create_pg_replicate_api_database(options: &PgDatabaseOptions) -> PgPool { + let connection_pool = create_pg_database(&options).await; - // Create a connection pool to the database and run the migration. - let connection_pool = PgPool::connect_with(settings.with_db()) - .await - .expect("Failed to connect to Postgres"); sqlx::migrate!("./migrations") .run(&connection_pool) .await @@ -22,30 +18,3 @@ pub async fn create_and_configure_database(settings: &DatabaseSettings) -> PgPoo connection_pool } - -pub async fn destroy_database(settings: &DatabaseSettings) { - // Connect to the default database. - let mut connection = PgConnection::connect_with(&settings.without_db()) - .await - .expect("Failed to connect to Postgres"); - - // Forcefully terminate any remaining connections to the database. This code assumes that those - // connections are not used anymore and do not outlive the `TestApp` instance. - connection - .execute(&*format!( - r#" - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' - AND pid <> pg_backend_pid();"#, - settings.name - )) - .await - .expect("Failed to terminate database connections"); - - // Drop the database. - connection - .execute(&*format!(r#"DROP DATABASE IF EXISTS "{}";"#, settings.name)) - .await - .expect("Failed to destroy database"); -} diff --git a/api/tests/common/mod.rs b/api/tests/common/mod.rs index 5cc827d5..0f6d8dcf 100644 --- a/api/tests/common/mod.rs +++ b/api/tests/common/mod.rs @@ -14,6 +14,5 @@ //! - Provide connection pools for tests //! //! These utilities help maintain consistency across tests and reduce code duplication. 
- pub mod database; pub mod test_app; diff --git a/api/tests/common/test_app.rs b/api/tests/common/test_app.rs index c280c61a..c7d0144d 100644 --- a/api/tests/common/test_app.rs +++ b/api/tests/common/test_app.rs @@ -1,40 +1,19 @@ -use std::io; -use std::net::TcpListener; - -use crate::common::database::{create_and_configure_database, destroy_database}; +use crate::common::database::create_pg_replicate_api_database; use api::{ configuration::{get_settings, Settings}, db::{pipelines::PipelineConfig, sinks::SinkConfig, sources::SourceConfig}, encryption::{self, generate_random_key}, startup::run, }; +use postgres::sqlx::options::PgDatabaseOptions; +use postgres::sqlx::test_utils::drop_pg_database; use reqwest::{IntoUrl, RequestBuilder}; use serde::{Deserialize, Serialize}; +use std::io; +use std::net::TcpListener; use tokio::runtime::Handle; use uuid::Uuid; -pub struct TestApp { - pub address: String, - pub api_client: reqwest::Client, - pub api_key: String, - settings: Settings, - server_handle: tokio::task::JoinHandle>, -} - -impl Drop for TestApp { - fn drop(&mut self) { - // First, abort the server task to ensure it's terminated. - self.server_handle.abort(); - - // To use `block_in_place,` we need a multithreaded runtime since when a blocking - // task is issued, the runtime will offload existing tasks to another worker. - tokio::task::block_in_place(move || { - Handle::current() - .block_on(async move { destroy_database(&self.settings.database).await }); - }); - } -} - #[derive(Serialize)] pub struct CreateTenantRequest { pub id: String, @@ -217,6 +196,27 @@ pub struct UpdateImageRequest { pub is_default: bool, } +pub struct TestApp { + pub address: String, + pub api_client: reqwest::Client, + pub api_key: String, + options: PgDatabaseOptions, + server_handle: tokio::task::JoinHandle>, +} + +impl Drop for TestApp { + fn drop(&mut self) { + // First, abort the server task to ensure it's terminated. + self.server_handle.abort(); + + // To use `block_in_place,` we need a multithreaded runtime since when a blocking + // task is issued, the runtime will offload existing tasks to another worker. + tokio::task::block_in_place(move || { + Handle::current().block_on(async move { drop_pg_database(&self.options).await }); + }); + } +} + impl TestApp { fn get_authenticated(&self, url: U) -> RequestBuilder { self.api_client.get(url).bearer_auth(self.api_key.clone()) @@ -533,9 +533,10 @@ pub async fn spawn_test_app() -> TestApp { let port = listener.local_addr().unwrap().port(); let mut settings = get_settings::<'_, Settings>().expect("Failed to read configuration"); + // We use a random database name. 
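+    // Each test therefore gets its own throwaway database, which is dropped again when the `TestApp` is dropped.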
settings.database.name = Uuid::new_v4().to_string(); - let connection_pool = create_and_configure_database(&settings.database).await; + let connection_pool = create_pg_replicate_api_database(&settings.database).await; let key = generate_random_key::<32>().expect("failed to generate random key"); let encryption_key = encryption::EncryptionKey { id: 0, key }; @@ -557,7 +558,7 @@ pub async fn spawn_test_app() -> TestApp { address: format!("http://{base_address}:{port}"), api_client: reqwest::Client::new(), api_key, - settings, + options: settings.database, server_handle, } } diff --git a/pg_replicate/Cargo.toml b/pg_replicate/Cargo.toml index 80d56061..93c5350a 100644 --- a/pg_replicate/Cargo.toml +++ b/pg_replicate/Cargo.toml @@ -18,6 +18,8 @@ name = "stdout" required-features = ["stdout"] [dependencies] +postgres = { workspace = true, features = ["tokio"]} + async-trait = { workspace = true } bigdecimal = { workspace = true, features = ["std"] } bytes = { workspace = true } @@ -38,7 +40,7 @@ rustls = { workspace = true, features = ["aws-lc-rs", "logging"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } thiserror = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "sync"] } tokio-postgres = { workspace = true, features = [ "runtime", "with-chrono-0_4", @@ -50,6 +52,8 @@ tracing = { workspace = true, default-features = true } uuid = { workspace = true, features = ["v4"] } [dev-dependencies] +postgres = { workspace = true, features = ["test-utils", "tokio"]} + clap = { workspace = true, default-features = true, features = [ "std", "derive", @@ -58,10 +62,11 @@ tracing-subscriber = { workspace = true, default-features = true, features = [ "env-filter", ] } + [features] bigquery = ["dep:gcp-bigquery-client", "dep:prost"] duckdb = ["dep:duckdb"] stdout = [] # When enabled converts unknown types to bytes unknown_types_to_bytes = [] -default = ["unknown_types_to_bytes"] +default = ["unknown_types_to_bytes"] \ No newline at end of file diff --git a/pg_replicate/examples/bigquery.rs b/pg_replicate/examples/bigquery.rs index 9cfd41de..e8f16429 100644 --- a/pg_replicate/examples/bigquery.rs +++ b/pg_replicate/examples/bigquery.rs @@ -8,9 +8,10 @@ use pg_replicate::{ sources::postgres::{PostgresSource, TableNamesFrom}, PipelineAction, }, - table::TableName, SslMode, }; +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -120,22 +121,22 @@ async fn main_impl() -> Result<(), Box> { let db_args = args.db_args; let bq_args = args.bq_args; + let options = PgDatabaseOptions { + host: db_args.db_host, + port: db_args.db_port, + name: db_args.db_name, + username: db_args.db_username, + password: db_args.db_password, + ssl_mode: SslMode::Disable, + }; + let (postgres_source, action) = match args.command { Command::CopyTable { schema, name } => { let table_names = vec![TableName { schema, name }]; - let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, - vec![], - None, - TableNamesFrom::Vec(table_names), - ) - .await?; + let postgres_source = + PostgresSource::new(options, vec![], None, TableNamesFrom::Vec(table_names)) + .await?; (postgres_source, PipelineAction::TableCopiesOnly) } Command::Cdc { @@ 
-143,12 +144,7 @@ async fn main_impl() -> Result<(), Box> { slot_name, } => { let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, + options, vec![], Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/pg_replicate/examples/duckdb.rs b/pg_replicate/examples/duckdb.rs index c222203e..a989de64 100644 --- a/pg_replicate/examples/duckdb.rs +++ b/pg_replicate/examples/duckdb.rs @@ -8,9 +8,10 @@ use pg_replicate::{ sources::postgres::{PostgresSource, TableNamesFrom}, PipelineAction, }, - table::TableName, SslMode, }; +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -114,22 +115,22 @@ async fn main_impl() -> Result<(), Box> { let args = AppArgs::parse(); let db_args = args.db_args; + let options = PgDatabaseOptions { + host: db_args.db_host, + port: db_args.db_port, + name: db_args.db_name, + username: db_args.db_username, + password: db_args.db_password, + ssl_mode: SslMode::Disable, + }; + let (postgres_source, action) = match args.command { Command::CopyTable { schema, name } => { let table_names = vec![TableName { schema, name }]; - let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, - vec![], - None, - TableNamesFrom::Vec(table_names), - ) - .await?; + let postgres_source = + PostgresSource::new(options, vec![], None, TableNamesFrom::Vec(table_names)) + .await?; (postgres_source, PipelineAction::TableCopiesOnly) } Command::Cdc { @@ -137,12 +138,7 @@ async fn main_impl() -> Result<(), Box> { slot_name, } => { let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, + options, vec![], Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/pg_replicate/examples/stdout.rs b/pg_replicate/examples/stdout.rs index 330c8fab..5172cbe2 100644 --- a/pg_replicate/examples/stdout.rs +++ b/pg_replicate/examples/stdout.rs @@ -8,9 +8,10 @@ use pg_replicate::{ sources::postgres::{PostgresSource, TableNamesFrom}, PipelineAction, }, - table::TableName, SslMode, }; +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; use tracing::error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -90,22 +91,22 @@ async fn main_impl() -> Result<(), Box> { let args = AppArgs::parse(); let db_args = args.db_args; + let options = PgDatabaseOptions { + host: db_args.db_host, + port: db_args.db_port, + name: db_args.db_name, + username: db_args.db_username, + password: db_args.db_password, + ssl_mode: SslMode::Disable, + }; + let (postgres_source, action) = match args.command { Command::CopyTable { schema, name } => { let table_names = vec![TableName { schema, name }]; - let postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, - vec![], - None, - TableNamesFrom::Vec(table_names), - ) - .await?; + let postgres_source = + PostgresSource::new(options, vec![], None, TableNamesFrom::Vec(table_names)) + .await?; (postgres_source, PipelineAction::TableCopiesOnly) } Command::Cdc { @@ -113,12 +114,7 @@ async fn main_impl() -> Result<(), Box> { slot_name, } => { let 
postgres_source = PostgresSource::new( - &db_args.db_host, - db_args.db_port, - &db_args.db_name, - &db_args.db_username, - db_args.db_password, - SslMode::Disable, + options, vec![], Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/pg_replicate/src/clients/bigquery.rs b/pg_replicate/src/clients/bigquery.rs index c22a62ee..6b1f213e 100644 --- a/pg_replicate/src/clients/bigquery.rs +++ b/pg_replicate/src/clients/bigquery.rs @@ -15,17 +15,15 @@ use gcp_bigquery_client::{ storage::{ColumnType, FieldDescriptor, StreamName, TableDescriptor}, Client, }; +use postgres::schema::{ColumnSchema, TableId, TableSchema}; use prost::Message; use tokio_postgres::types::{PgLsn, Type}; use tracing::info; use uuid::Uuid; use crate::conversions::numeric::PgNumeric; +use crate::conversions::table_row::TableRow; use crate::conversions::{ArrayCell, Cell}; -use crate::{ - conversions::table_row::TableRow, - table::{ColumnSchema, TableId, TableSchema}, -}; pub struct BigQueryClient { project_id: String, @@ -990,102 +988,105 @@ impl ArrayCell { } } -impl From<&TableSchema> for TableDescriptor { - fn from(table_schema: &TableSchema) -> Self { - let mut field_descriptors = Vec::with_capacity(table_schema.column_schemas.len()); - let mut number = 1; - for column_schema in &table_schema.column_schemas { - let typ = match column_schema.typ { - Type::BOOL => ColumnType::Bool, - Type::CHAR | Type::BPCHAR | Type::VARCHAR | Type::NAME | Type::TEXT => { - ColumnType::String - } - Type::INT2 => ColumnType::Int32, - Type::INT4 => ColumnType::Int32, - Type::INT8 => ColumnType::Int64, - Type::FLOAT4 => ColumnType::Float, - Type::FLOAT8 => ColumnType::Double, - Type::NUMERIC => ColumnType::String, - Type::DATE => ColumnType::String, - Type::TIME => ColumnType::String, - Type::TIMESTAMP => ColumnType::String, - Type::TIMESTAMPTZ => ColumnType::String, - Type::UUID => ColumnType::String, - Type::JSON => ColumnType::String, - Type::JSONB => ColumnType::String, - Type::OID => ColumnType::Int32, - Type::BYTEA => ColumnType::Bytes, - Type::BOOL_ARRAY => ColumnType::Bool, - Type::CHAR_ARRAY - | Type::BPCHAR_ARRAY - | Type::VARCHAR_ARRAY - | Type::NAME_ARRAY - | Type::TEXT_ARRAY => ColumnType::String, - Type::INT2_ARRAY => ColumnType::Int32, - Type::INT4_ARRAY => ColumnType::Int32, - Type::INT8_ARRAY => ColumnType::Int64, - Type::FLOAT4_ARRAY => ColumnType::Float, - Type::FLOAT8_ARRAY => ColumnType::Double, - Type::NUMERIC_ARRAY => ColumnType::String, - Type::DATE_ARRAY => ColumnType::String, - Type::TIME_ARRAY => ColumnType::String, - Type::TIMESTAMP_ARRAY => ColumnType::String, - Type::TIMESTAMPTZ_ARRAY => ColumnType::String, - Type::UUID_ARRAY => ColumnType::String, - Type::JSON_ARRAY => ColumnType::String, - Type::JSONB_ARRAY => ColumnType::String, - Type::OID_ARRAY => ColumnType::Int32, - Type::BYTEA_ARRAY => ColumnType::Bytes, - _ => ColumnType::String, - }; - - let mode = match column_schema.typ { - Type::BOOL_ARRAY - | Type::CHAR_ARRAY - | Type::BPCHAR_ARRAY - | Type::VARCHAR_ARRAY - | Type::NAME_ARRAY - | Type::TEXT_ARRAY - | Type::INT2_ARRAY - | Type::INT4_ARRAY - | Type::INT8_ARRAY - | Type::FLOAT4_ARRAY - | Type::FLOAT8_ARRAY - | Type::NUMERIC_ARRAY - | Type::DATE_ARRAY - | Type::TIME_ARRAY - | Type::TIMESTAMP_ARRAY - | Type::TIMESTAMPTZ_ARRAY - | Type::UUID_ARRAY - | Type::JSON_ARRAY - | Type::JSONB_ARRAY - | Type::OID_ARRAY - | Type::BYTEA_ARRAY => ColumnMode::Repeated, - _ => { - if column_schema.nullable { - ColumnMode::Nullable - } else { - ColumnMode::Required - } +/// Converts a [`TableSchema`] 
to [`TableDescriptor`]. +/// +/// This function is defined here and doesn't use the [`From`] trait because it's not possible since +/// [`TableSchema`] is in another crate and we don't want to pollute the `postgres` crate with sink +/// specific internals. +pub fn table_schema_to_descriptor(table_schema: &TableSchema) -> TableDescriptor { + let mut field_descriptors = Vec::with_capacity(table_schema.column_schemas.len()); + let mut number = 1; + for column_schema in &table_schema.column_schemas { + let typ = match column_schema.typ { + Type::BOOL => ColumnType::Bool, + Type::CHAR | Type::BPCHAR | Type::VARCHAR | Type::NAME | Type::TEXT => { + ColumnType::String + } + Type::INT2 => ColumnType::Int32, + Type::INT4 => ColumnType::Int32, + Type::INT8 => ColumnType::Int64, + Type::FLOAT4 => ColumnType::Float, + Type::FLOAT8 => ColumnType::Double, + Type::NUMERIC => ColumnType::String, + Type::DATE => ColumnType::String, + Type::TIME => ColumnType::String, + Type::TIMESTAMP => ColumnType::String, + Type::TIMESTAMPTZ => ColumnType::String, + Type::UUID => ColumnType::String, + Type::JSON => ColumnType::String, + Type::JSONB => ColumnType::String, + Type::OID => ColumnType::Int32, + Type::BYTEA => ColumnType::Bytes, + Type::BOOL_ARRAY => ColumnType::Bool, + Type::CHAR_ARRAY + | Type::BPCHAR_ARRAY + | Type::VARCHAR_ARRAY + | Type::NAME_ARRAY + | Type::TEXT_ARRAY => ColumnType::String, + Type::INT2_ARRAY => ColumnType::Int32, + Type::INT4_ARRAY => ColumnType::Int32, + Type::INT8_ARRAY => ColumnType::Int64, + Type::FLOAT4_ARRAY => ColumnType::Float, + Type::FLOAT8_ARRAY => ColumnType::Double, + Type::NUMERIC_ARRAY => ColumnType::String, + Type::DATE_ARRAY => ColumnType::String, + Type::TIME_ARRAY => ColumnType::String, + Type::TIMESTAMP_ARRAY => ColumnType::String, + Type::TIMESTAMPTZ_ARRAY => ColumnType::String, + Type::UUID_ARRAY => ColumnType::String, + Type::JSON_ARRAY => ColumnType::String, + Type::JSONB_ARRAY => ColumnType::String, + Type::OID_ARRAY => ColumnType::Int32, + Type::BYTEA_ARRAY => ColumnType::Bytes, + _ => ColumnType::String, + }; + + let mode = match column_schema.typ { + Type::BOOL_ARRAY + | Type::CHAR_ARRAY + | Type::BPCHAR_ARRAY + | Type::VARCHAR_ARRAY + | Type::NAME_ARRAY + | Type::TEXT_ARRAY + | Type::INT2_ARRAY + | Type::INT4_ARRAY + | Type::INT8_ARRAY + | Type::FLOAT4_ARRAY + | Type::FLOAT8_ARRAY + | Type::NUMERIC_ARRAY + | Type::DATE_ARRAY + | Type::TIME_ARRAY + | Type::TIMESTAMP_ARRAY + | Type::TIMESTAMPTZ_ARRAY + | Type::UUID_ARRAY + | Type::JSON_ARRAY + | Type::JSONB_ARRAY + | Type::OID_ARRAY + | Type::BYTEA_ARRAY => ColumnMode::Repeated, + _ => { + if column_schema.nullable { + ColumnMode::Nullable + } else { + ColumnMode::Required } - }; - - field_descriptors.push(FieldDescriptor { - number, - name: column_schema.name.clone(), - typ, - mode, - }); - number += 1; - } + } + }; field_descriptors.push(FieldDescriptor { number, - name: "_CHANGE_TYPE".to_string(), - typ: ColumnType::String, - mode: ColumnMode::Required, + name: column_schema.name.clone(), + typ, + mode, }); - - TableDescriptor { field_descriptors } + number += 1; } + + field_descriptors.push(FieldDescriptor { + number, + name: "_CHANGE_TYPE".to_string(), + typ: ColumnType::String, + mode: ColumnMode::Required, + }); + + TableDescriptor { field_descriptors } } diff --git a/pg_replicate/src/clients/duckdb.rs b/pg_replicate/src/clients/duckdb.rs index 389a29f6..e1421af4 100644 --- a/pg_replicate/src/clients/duckdb.rs +++ b/pg_replicate/src/clients/duckdb.rs @@ -5,12 +5,10 @@ use duckdb::{ types::{ToSqlOutput, 
Value}, Config, Connection, ToSql, }; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use tokio_postgres::types::{PgLsn, Type}; -use crate::{ - conversions::{table_row::TableRow, ArrayCell, Cell}, - table::{ColumnSchema, TableId, TableName, TableSchema}, -}; +use crate::conversions::{table_row::TableRow, ArrayCell, Cell}; pub struct DuckDbClient { conn: Connection, diff --git a/pg_replicate/src/clients/postgres.rs b/pg_replicate/src/clients/postgres.rs index 9581c13d..b2c1b03c 100644 --- a/pg_replicate/src/clients/postgres.rs +++ b/pg_replicate/src/clients/postgres.rs @@ -1,19 +1,19 @@ use std::collections::HashMap; use pg_escape::{quote_identifier, quote_literal}; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; +use postgres::tokio::options::PgDatabaseOptions; use postgres_replication::LogicalReplicationStream; use rustls::{pki_types::CertificateDer, ClientConfig}; use thiserror::Error; use tokio_postgres::{ - config::{ReplicationMode, SslMode}, + config::ReplicationMode, types::{Kind, PgLsn, Type}, Client as PostgresClient, Config, CopyOutStream, NoTls, SimpleQueryMessage, }; use tokio_postgres_rustls::MakeRustlsConnect; use tracing::{info, warn}; -use crate::table::{ColumnSchema, TableId, TableName, TableSchema}; - pub struct SlotInfo { pub confirmed_flush_lsn: PgLsn, } @@ -63,25 +63,12 @@ pub enum ReplicationClientError { impl ReplicationClient { /// Connect to a postgres database in logical replication mode without TLS pub async fn connect_no_tls( - host: &str, - port: u16, - database: &str, - username: &str, - password: Option, + options: PgDatabaseOptions, ) -> Result { info!("connecting to postgres without TLS"); - let mut config = Config::new(); - config - .host(host) - .port(port) - .dbname(database) - .user(username) - .replication_mode(ReplicationMode::Logical); - - if let Some(password) = password { - config.password(password); - } + let mut config: Config = options.into(); + config.replication_mode(ReplicationMode::Logical); let (postgres_client, connection) = config.connect(NoTls).await?; @@ -89,7 +76,9 @@ impl ReplicationClient { info!("waiting for connection to terminate"); if let Err(e) = connection.await { warn!("connection error: {}", e); + return; } + info!("connection terminated successfully") }); info!("successfully connected to postgres"); @@ -102,28 +91,13 @@ impl ReplicationClient { /// Connect to a postgres database in logical replication mode with TLS pub async fn connect_tls( - host: &str, - port: u16, - database: &str, - username: &str, - password: Option, - ssl_mode: SslMode, + options: PgDatabaseOptions, trusted_root_certs: Vec>, ) -> Result { info!("connecting to postgres with TLS"); - let mut config = Config::new(); - config - .host(host) - .port(port) - .dbname(database) - .user(username) - .ssl_mode(ssl_mode) - .replication_mode(ReplicationMode::Logical); - - if let Some(password) = password { - config.password(password); - } + let mut config: Config = options.into(); + config.replication_mode(ReplicationMode::Logical); let mut root_store = rustls::RootCertStore::empty(); for trusted_root_cert in trusted_root_certs { @@ -141,7 +115,9 @@ impl ReplicationClient { info!("waiting for connection to terminate"); if let Err(e) = connection.await { warn!("connection error: {}", e); + return; } + info!("connection terminated successfully") }); info!("successfully connected to postgres"); diff --git a/pg_replicate/src/conversions/cdc_event.rs b/pg_replicate/src/conversions/cdc_event.rs index c3f1d935..a6f493f1 100644 
--- a/pg_replicate/src/conversions/cdc_event.rs +++ b/pg_replicate/src/conversions/cdc_event.rs @@ -1,16 +1,14 @@ use core::str; use std::{collections::HashMap, str::Utf8Error}; +use postgres::schema::{ColumnSchema, TableId, TableSchema}; use postgres_replication::protocol::{ BeginBody, CommitBody, DeleteBody, InsertBody, LogicalReplicationMessage, OriginBody, RelationBody, ReplicationMessage, TruncateBody, TupleData, TypeBody, UpdateBody, }; use thiserror::Error; -use crate::{ - pipeline::batching::BatchBoundary, - table::{ColumnSchema, TableId, TableSchema}, -}; +use crate::pipeline::batching::BatchBoundary; use super::{ table_row::TableRow, diff --git a/pg_replicate/src/conversions/table_row.rs b/pg_replicate/src/conversions/table_row.rs index 411177d5..5273b071 100644 --- a/pg_replicate/src/conversions/table_row.rs +++ b/pg_replicate/src/conversions/table_row.rs @@ -1,6 +1,7 @@ use core::str; use std::str::Utf8Error; +use postgres::schema::ColumnSchema; use thiserror::Error; use tokio_postgres::types::Type; use tracing::error; @@ -9,7 +10,7 @@ use crate::{conversions::text::TextFormatConverter, pipeline::batching::BatchBou use super::{text::FromTextError, Cell}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TableRow { pub values: Vec, } @@ -44,7 +45,7 @@ impl TableRowConverter { // parses text produced by this code in Postgres: https://github.com/postgres/postgres/blob/263a3f5f7f508167dbeafc2aefd5835b41d77481/src/backend/commands/copyto.c#L988-L1134 pub fn try_from( row: &[u8], - column_schemas: &[crate::table::ColumnSchema], + column_schemas: &[ColumnSchema], ) -> Result { let mut values = Vec::with_capacity(column_schemas.len()); diff --git a/pg_replicate/src/lib.rs b/pg_replicate/src/lib.rs index d725db6a..5eab8fde 100644 --- a/pg_replicate/src/lib.rs +++ b/pg_replicate/src/lib.rs @@ -1,5 +1,4 @@ pub mod clients; pub mod conversions; pub mod pipeline; -pub mod table; pub use tokio_postgres::config::SslMode; diff --git a/pg_replicate/src/pipeline/batching/data_pipeline.rs b/pg_replicate/src/pipeline/batching/data_pipeline.rs index e1f371fd..3785a78c 100644 --- a/pg_replicate/src/pipeline/batching/data_pipeline.rs +++ b/pg_replicate/src/pipeline/batching/data_pipeline.rs @@ -1,7 +1,9 @@ -use std::{collections::HashSet, time::Instant}; - use futures::StreamExt; +use postgres::schema::TableId; +use std::sync::Arc; +use std::{collections::HashSet, time::Instant}; use tokio::pin; +use tokio::sync::Notify; use tokio_postgres::types::PgLsn; use tracing::{debug, info}; @@ -13,16 +15,31 @@ use crate::{ sources::{postgres::CdcStreamError, CommonSourceError, Source}, PipelineAction, PipelineError, }, - table::TableId, }; use super::BatchConfig; +#[derive(Debug, Clone)] +pub struct BatchDataPipelineHandle { + stream_stop: Arc, +} + +impl BatchDataPipelineHandle { + pub fn stop(&self) { + // We want to notify all waiters that their streams have to be stopped. + // + // Technically, we should not need to notify multiple waiters since we can't have multiple + // streams active in parallel, but in this way we cover for the future. 
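+        //
+        // Note that `notify_waiters` only wakes tasks that are already waiting on `notified()`;
+        // it does not store a permit for streams created after this call.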
+ self.stream_stop.notify_waiters(); + } +} + pub struct BatchDataPipeline { source: Src, sink: Snk, action: PipelineAction, batch_config: BatchConfig, + stream_stop: Arc, } impl BatchDataPipeline { @@ -32,6 +49,7 @@ impl BatchDataPipeline { sink, action, batch_config, + stream_stop: Arc::new(Notify::new()), } } @@ -77,8 +95,11 @@ impl BatchDataPipeline { .await .map_err(PipelineError::Source)?; - let batch_timeout_stream = - BatchTimeoutStream::new(table_rows, self.batch_config.clone()); + let batch_timeout_stream = BatchTimeoutStream::new( + table_rows, + self.batch_config.clone(), + self.stream_stop.notified(), + ); pin!(batch_timeout_stream); @@ -123,16 +144,19 @@ impl BatchDataPipeline { let mut last_lsn: u64 = last_lsn.into(); last_lsn += 1; + let cdc_events = self .source .get_cdc_stream(last_lsn.into()) .await .map_err(PipelineError::Source)?; - pin!(cdc_events); - let batch_timeout_stream = BatchTimeoutStream::new(cdc_events, self.batch_config.clone()); - + let batch_timeout_stream = BatchTimeoutStream::new( + cdc_events, + self.batch_config.clone(), + self.stream_stop.notified(), + ); pin!(batch_timeout_stream); while let Some(batch) = batch_timeout_stream.next().await { @@ -201,4 +225,18 @@ impl BatchDataPipeline { Ok(()) } + + pub fn handle(&self) -> BatchDataPipelineHandle { + BatchDataPipelineHandle { + stream_stop: self.stream_stop.clone(), + } + } + + pub fn source(&self) -> &Src { + &self.source + } + + pub fn sink(&self) -> &Snk { + &self.sink + } } diff --git a/pg_replicate/src/pipeline/batching/stream.rs b/pg_replicate/src/pipeline/batching/stream.rs index 3495e010..73ab5c3e 100644 --- a/pg_replicate/src/pipeline/batching/stream.rs +++ b/pg_replicate/src/pipeline/batching/stream.rs @@ -2,40 +2,47 @@ use futures::{ready, Future, Stream}; use pin_project_lite::pin_project; use tokio::time::{sleep, Sleep}; +use super::{BatchBoundary, BatchConfig}; use core::pin::Pin; use core::task::{Context, Poll}; - -use super::{BatchBoundary, BatchConfig}; +use tokio::sync::futures::Notified; +use tracing::info; // Implementation adapted from https://github.com/tokio-rs/tokio/blob/master/tokio-stream/src/stream_ext/chunks_timeout.rs pin_project! { /// Adapter stream which batches the items of the underlying stream when it /// reaches max_size or when a timeout expires. The underlying streams items /// must implement [`BatchBoundary`]. A batch is guaranteed to end on an - /// item which returns true from [`BatchBoundary::is_last_in_batch`] + /// item which returns true from [`BatchBoundary::is_last_in_batch`] unless the + /// stream is forcefully stopped. 
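+    /// When the stop notification fires, the stream yields any items buffered so far and then ends.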
#[must_use = "streams do nothing unless polled"] #[derive(Debug)] - pub struct BatchTimeoutStream> { + pub struct BatchTimeoutStream<'a, B: BatchBoundary, S: Stream> { #[pin] stream: S, #[pin] deadline: Option, + #[pin] + stream_stop: Notified<'a>, items: Vec, batch_config: BatchConfig, reset_timer: bool, inner_stream_ended: bool, + stream_stopped: bool } } -impl> BatchTimeoutStream { - pub fn new(stream: S, batch_config: BatchConfig) -> Self { +impl<'a, B: BatchBoundary, S: Stream> BatchTimeoutStream<'a, B, S> { + pub fn new(stream: S, batch_config: BatchConfig, stream_stop: Notified<'a>) -> Self { BatchTimeoutStream { stream, deadline: None, + stream_stop, items: Vec::with_capacity(batch_config.max_batch_size), batch_config, reset_timer: true, inner_stream_ended: false, + stream_stopped: false, } } @@ -44,15 +51,33 @@ impl> BatchTimeoutStream { } } -impl> Stream for BatchTimeoutStream { +impl<'a, B: BatchBoundary, S: Stream> Stream for BatchTimeoutStream<'a, B, S> { type Item = Vec; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.as_mut().project(); + if *this.inner_stream_ended { return Poll::Ready(None); } + loop { + if *this.stream_stopped { + return Poll::Ready(None); + } + + // If the stream has been asked to stop, we mark the stream as stopped and return the + // remaining elements, irrespectively of boundaries. + if this.stream_stop.as_mut().poll(cx).is_ready() { + info!("the stream has been forcefully stopped"); + *this.stream_stopped = true; + return if !this.items.is_empty() { + Poll::Ready(Some(std::mem::take(this.items))) + } else { + Poll::Ready(None) + }; + } + if *this.reset_timer { this.deadline .set(Some(sleep(this.batch_config.max_batch_fill_time))); diff --git a/pg_replicate/src/pipeline/mod.rs b/pg_replicate/src/pipeline/mod.rs index b64d1637..8af5e194 100644 --- a/pg_replicate/src/pipeline/mod.rs +++ b/pg_replicate/src/pipeline/mod.rs @@ -1,12 +1,11 @@ use std::collections::HashSet; +use postgres::schema::TableId; use sinks::SinkError; use sources::SourceError; use thiserror::Error; use tokio_postgres::types::PgLsn; -use crate::table::TableId; - pub mod batching; pub mod sinks; pub mod sources; @@ -18,6 +17,7 @@ pub enum PipelineAction { Both, } +#[derive(Debug)] pub struct PipelineResumptionState { pub copied_tables: HashSet, pub last_lsn: PgLsn, diff --git a/pg_replicate/src/pipeline/sinks/bigquery.rs b/pg_replicate/src/pipeline/sinks/bigquery.rs index 57048f59..20330a28 100644 --- a/pg_replicate/src/pipeline/sinks/bigquery.rs +++ b/pg_replicate/src/pipeline/sinks/bigquery.rs @@ -2,19 +2,19 @@ use std::collections::HashMap; use async_trait::async_trait; use gcp_bigquery_client::error::BQError; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use thiserror::Error; use tokio_postgres::types::{PgLsn, Type}; use tracing::info; +use super::{BatchSink, SinkError}; +use crate::clients::bigquery::table_schema_to_descriptor; use crate::{ clients::bigquery::BigQueryClient, conversions::{cdc_event::CdcEvent, table_row::TableRow, Cell}, pipeline::PipelineResumptionState, - table::{ColumnSchema, TableId, TableName, TableSchema}, }; -use super::{BatchSink, SinkError}; - #[derive(Debug, Error)] pub enum BigQuerySinkError { #[error("big query error: {0}")] @@ -183,7 +183,7 @@ impl BatchSink for BigQueryBatchSink { ) -> Result<(), Self::Error> { let table_schema = self.get_table_schema(table_id)?; let table_name = Self::table_name_in_bq(&table_schema.table_name); - let table_descriptor = table_schema.into(); + let 
table_descriptor = table_schema_to_descriptor(table_schema); for table_row in &mut table_rows { table_row.values.push(Cell::String("UPSERT".to_string())); @@ -250,7 +250,7 @@ impl BatchSink for BigQueryBatchSink { for (table_id, table_rows) in table_name_to_table_rows { let table_schema = self.get_table_schema(table_id)?; let table_name = Self::table_name_in_bq(&table_schema.table_name); - let table_descriptor = table_schema.into(); + let table_descriptor = table_schema_to_descriptor(table_schema); self.client .stream_rows(&self.dataset_id, table_name, &table_descriptor, &table_rows) .await?; diff --git a/pg_replicate/src/pipeline/sinks/duckdb/executor.rs b/pg_replicate/src/pipeline/sinks/duckdb/executor.rs index 905d56e6..b6cdd64a 100644 --- a/pg_replicate/src/pipeline/sinks/duckdb/executor.rs +++ b/pg_replicate/src/pipeline/sinks/duckdb/executor.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use thiserror::Error; use tokio::sync::mpsc::{error::SendError, Receiver, Sender}; use tokio_postgres::types::{PgLsn, Type}; @@ -9,7 +10,6 @@ use crate::{ clients::duckdb::DuckDbClient, conversions::{cdc_event::CdcEvent, table_row::TableRow}, pipeline::{sinks::SinkError, PipelineResumptionState}, - table::{ColumnSchema, TableId, TableName, TableSchema}, }; pub enum DuckDbRequest { diff --git a/pg_replicate/src/pipeline/sinks/duckdb/sink.rs b/pg_replicate/src/pipeline/sinks/duckdb/sink.rs index 48393255..11f1171d 100644 --- a/pg_replicate/src/pipeline/sinks/duckdb/sink.rs +++ b/pg_replicate/src/pipeline/sinks/duckdb/sink.rs @@ -3,13 +3,13 @@ use std::{collections::HashMap, path::Path}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use async_trait::async_trait; +use postgres::schema::{TableId, TableSchema}; use tokio_postgres::types::PgLsn; use crate::{ clients::duckdb::DuckDbClient, conversions::{cdc_event::CdcEvent, table_row::TableRow}, pipeline::{sinks::BatchSink, PipelineResumptionState}, - table::{TableId, TableSchema}, }; use super::{ diff --git a/pg_replicate/src/pipeline/sinks/mod.rs b/pg_replicate/src/pipeline/sinks/mod.rs index 94e9c101..79b6d2cb 100644 --- a/pg_replicate/src/pipeline/sinks/mod.rs +++ b/pg_replicate/src/pipeline/sinks/mod.rs @@ -1,13 +1,11 @@ use std::collections::HashMap; use async_trait::async_trait; +use postgres::schema::{TableId, TableSchema}; use thiserror::Error; use tokio_postgres::types::PgLsn; -use crate::{ - conversions::{cdc_event::CdcEvent, table_row::TableRow}, - table::{TableId, TableSchema}, -}; +use crate::conversions::{cdc_event::CdcEvent, table_row::TableRow}; use super::PipelineResumptionState; diff --git a/pg_replicate/src/pipeline/sinks/stdout.rs b/pg_replicate/src/pipeline/sinks/stdout.rs index 6d79dd1d..9e98525b 100644 --- a/pg_replicate/src/pipeline/sinks/stdout.rs +++ b/pg_replicate/src/pipeline/sinks/stdout.rs @@ -1,13 +1,13 @@ use std::collections::{HashMap, HashSet}; use async_trait::async_trait; +use postgres::schema::{TableId, TableSchema}; use tokio_postgres::types::PgLsn; use tracing::info; use crate::{ conversions::{cdc_event::CdcEvent, table_row::TableRow}, pipeline::PipelineResumptionState, - table::{TableId, TableSchema}, }; use super::{BatchSink, InfallibleSinkError}; diff --git a/pg_replicate/src/pipeline/sources/mod.rs b/pg_replicate/src/pipeline/sources/mod.rs index d35d1643..51f2feaf 100644 --- a/pg_replicate/src/pipeline/sources/mod.rs +++ b/pg_replicate/src/pipeline/sources/mod.rs @@ -1,12 +1,11 @@ use std::collections::HashMap; +use 
::postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use async_trait::async_trait; use thiserror::Error; use tokio_postgres::types::PgLsn; -use crate::table::{ColumnSchema, TableId, TableName, TableSchema}; - -use self::postgres::{ +use postgres::{ CdcStream, CdcStreamError, PostgresSourceError, StatusUpdateError, TableCopyStream, TableCopyStreamError, }; diff --git a/pg_replicate/src/pipeline/sources/postgres.rs b/pg_replicate/src/pipeline/sources/postgres.rs index 9c100131..6f7a8bf9 100644 --- a/pg_replicate/src/pipeline/sources/postgres.rs +++ b/pg_replicate/src/pipeline/sources/postgres.rs @@ -8,6 +8,8 @@ use std::{ use async_trait::async_trait; use futures::{ready, Stream}; use pin_project_lite::pin_project; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; +use postgres::tokio::options::PgDatabaseOptions; use postgres_replication::LogicalReplicationStream; use rustls::pki_types::CertificateDer; use thiserror::Error; @@ -20,7 +22,6 @@ use crate::{ cdc_event::{CdcEvent, CdcEventConversionError, CdcEventConverter}, table_row::{TableRow, TableRowConversionError, TableRowConverter}, }, - table::{ColumnSchema, TableId, TableName, TableSchema}, }; use super::{Source, SourceError}; @@ -52,41 +53,31 @@ pub struct PostgresSource { } impl PostgresSource { - #[allow(clippy::too_many_arguments)] pub async fn new( - host: &str, - port: u16, - database: &str, - username: &str, - password: Option, - ssl_mode: SslMode, + options: PgDatabaseOptions, trusted_root_certs: Vec>, slot_name: Option, table_names_from: TableNamesFrom, ) -> Result { - let mut replication_client = if ssl_mode == SslMode::Disable { - ReplicationClient::connect_no_tls(host, port, database, username, password).await? - } else { - ReplicationClient::connect_tls( - host, - port, - database, - username, - password, - ssl_mode, - trusted_root_certs, - ) - .await? + let mut replication_client = match options.ssl_mode { + SslMode::Disable => ReplicationClient::connect_no_tls(options).await?, + _ => ReplicationClient::connect_tls(options, trusted_root_certs).await?, }; + + // TODO: we have to fix this whole block which starts the transaction and loads the data. + // We will not do this here in the future but rather let the pipeline drive this based on + // its internal state. 
replication_client.begin_readonly_transaction().await?; if let Some(ref slot_name) = slot_name { replication_client.get_or_create_slot(slot_name).await?; } + let (table_names, publication) = Self::get_table_names_and_publication(&replication_client, table_names_from).await?; let table_schemas = replication_client .get_table_schemas(&table_names, publication.as_deref()) .await?; + Ok(PostgresSource { replication_client, table_schemas, @@ -163,12 +154,14 @@ impl Source for PostgresSource { async fn get_cdc_stream(&self, start_lsn: PgLsn) -> Result { info!("starting cdc stream at lsn {start_lsn}"); + let publication = self .publication() .ok_or(PostgresSourceError::MissingPublication)?; let slot_name = self .slot_name() .ok_or(PostgresSourceError::MissingSlotName)?; + let stream = self .replication_client .get_logical_replication_stream(publication, slot_name, start_lsn) diff --git a/pg_replicate/src/table.rs b/pg_replicate/src/table.rs deleted file mode 100644 index 73fabfd1..00000000 --- a/pg_replicate/src/table.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::fmt::Display; - -use pg_escape::quote_identifier; -use tokio_postgres::types::Type; - -#[derive(Debug, Clone)] -pub struct TableName { - pub schema: String, - pub name: String, -} - -impl TableName { - pub fn as_quoted_identifier(&self) -> String { - let quoted_schema = quote_identifier(&self.schema); - let quoted_name = quote_identifier(&self.name); - format!("{quoted_schema}.{quoted_name}") - } -} - -impl Display for TableName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("{0}.{1}", self.schema, self.name)) - } -} - -type TypeModifier = i32; - -#[derive(Debug, Clone)] -pub struct ColumnSchema { - pub name: String, - pub typ: Type, - pub modifier: TypeModifier, - pub nullable: bool, - pub primary: bool, -} - -pub type TableId = u32; - -#[derive(Debug, Clone)] -pub struct TableSchema { - pub table_name: TableName, - pub table_id: TableId, - pub column_schemas: Vec, -} - -impl TableSchema { - pub fn has_primary_keys(&self) -> bool { - self.column_schemas.iter().any(|cs| cs.primary) - } -} diff --git a/pg_replicate/tests/common/database.rs b/pg_replicate/tests/common/database.rs new file mode 100644 index 00000000..0dda445a --- /dev/null +++ b/pg_replicate/tests/common/database.rs @@ -0,0 +1,55 @@ +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; +use postgres::tokio::test_utils::PgDatabase; +use tokio_postgres::config::SslMode; +use uuid::Uuid; + +/// The schema name used for organizing test tables. +/// +/// This constant defines the default schema where test tables are created, +/// providing isolation from other database objects. +const TEST_DATABASE_SCHEMA: &str = "test"; + +/// Creates a [`TableName`] in the test schema. +/// +/// This helper function constructs a [`TableName`] with the schema set to [`TEST_DATABASE_SCHEMA`] +/// and the provided name as the table name. It's used to ensure consistent table naming +/// across test scenarios. +pub fn test_table_name(name: &str) -> TableName { + TableName { + schema: TEST_DATABASE_SCHEMA.to_owned(), + name: name.to_owned(), + } +} + +/// Creates a new test database instance with a unique name. +/// +/// This function spawns a new PostgreSQL database with a random UUID as its name, +/// using default credentials and disabled SSL. It automatically creates the test schema +/// for organizing test tables. +/// +/// # Panics +/// +/// Panics if the test schema cannot be created. 
+pub async fn spawn_database() -> PgDatabase { + let options = PgDatabaseOptions { + host: "localhost".to_owned(), + port: 5430, + // We create a random database name to avoid conflicts with existing databases. + name: Uuid::new_v4().to_string(), + username: "postgres".to_owned(), + password: Some("postgres".to_owned()), + ssl_mode: SslMode::Disable, + }; + + let database = PgDatabase::new(options).await; + + // Create the test schema. + database + .client + .execute(&format!("CREATE SCHEMA {}", TEST_DATABASE_SCHEMA), &[]) + .await + .expect("Failed to create test schema"); + + database +} diff --git a/pg_replicate/tests/common/mod.rs b/pg_replicate/tests/common/mod.rs new file mode 100644 index 00000000..ada4cb7e --- /dev/null +++ b/pg_replicate/tests/common/mod.rs @@ -0,0 +1,49 @@ +/// Common utilities and helpers for testing PostgreSQL replication functionality. +/// +/// This module provides shared testing infrastructure including database management, +/// pipeline testing utilities, sink testing helpers, and table manipulation utilities. +/// It also includes common testing patterns like waiting for conditions to be met. +use std::time::{Duration, Instant}; +use tokio::time::sleep; + +pub mod database; +pub mod pipeline; +pub mod sink; +pub mod table; + +/// The maximum duration to wait for test conditions to be met. +/// +/// This constant defines the timeout period for asynchronous test assertions, +/// ensuring tests don't hang indefinitely while waiting for expected states. +const MAX_ASSERTION_DURATION: Duration = Duration::from_secs(20); + +/// The interval between condition checks during test assertions. +/// +/// This constant defines how frequently we poll for condition changes while +/// waiting for test assertions to complete. +const ASSERTION_FREQUENCY_DURATION: Duration = Duration::from_millis(10); + +/// Waits asynchronously for a condition to be met within the maximum timeout period. +/// +/// This function repeatedly evaluates the provided condition until it returns true +/// or the maximum duration is exceeded. It's useful for testing asynchronous +/// operations where the exact completion time is not known. +/// +/// # Panics +/// +/// Panics if the condition is not met within [`MAX_ASSERTION_DURATION`]. +pub async fn wait_for_condition(condition: F) +where + F: Fn() -> bool, +{ + let start = Instant::now(); + while start.elapsed() < MAX_ASSERTION_DURATION { + if condition() { + return; + } + + sleep(ASSERTION_FREQUENCY_DURATION).await; + } + + panic!("Failed to process all events within timeout") +} diff --git a/pg_replicate/tests/common/pipeline.rs b/pg_replicate/tests/common/pipeline.rs new file mode 100644 index 00000000..3ba16a94 --- /dev/null +++ b/pg_replicate/tests/common/pipeline.rs @@ -0,0 +1,156 @@ +use pg_replicate::pipeline::batching::data_pipeline::{BatchDataPipeline, BatchDataPipelineHandle}; +use pg_replicate::pipeline::batching::BatchConfig; +use pg_replicate::pipeline::sinks::BatchSink; +use pg_replicate::pipeline::sources::postgres::{PostgresSource, TableNamesFrom}; +use pg_replicate::pipeline::PipelineAction; +use postgres::schema::TableName; +use postgres::tokio::options::PgDatabaseOptions; +use std::time::Duration; +use tokio::task::JoinHandle; + +/// Defines the operational mode for a PostgreSQL replication pipeline. +pub enum PipelineMode { + /// Initializes a pipeline to copy specified tables. + CopyTable { table_names: Vec }, + /// Initializes a pipeline to consume changes from a publication and replication slot. 
+ /// + /// If no slot name is provided, a new slot will be created on the specified publication. + Cdc { + publication: String, + slot_name: String, + }, +} + +/// Generates a test-specific replication slot name. +/// +/// This function prefixes the provided slot name with "test_" to avoid conflicts +/// with other replication slots. +pub fn test_slot_name(slot_name: &str) -> String { + format!("test_{}", slot_name) +} + +/// Creates a new PostgreSQL replication pipeline. +/// +/// This function initializes a pipeline with a batch size of 1000 records and +/// a maximum batch duration of 10 seconds. +/// +/// # Panics +/// +/// Panics if the PostgreSQL source cannot be created. +pub async fn spawn_pg_pipeline( + options: &PgDatabaseOptions, + mode: PipelineMode, + sink: Snk, +) -> BatchDataPipeline { + let batch_config = BatchConfig::new(1000, Duration::from_secs(10)); + + let pipeline = match mode { + PipelineMode::CopyTable { table_names } => { + let source = PostgresSource::new( + options.clone(), + vec![], + None, + TableNamesFrom::Vec(table_names), + ) + .await + .expect("Failure when creating the Postgres source for copying tables"); + let action = PipelineAction::TableCopiesOnly; + BatchDataPipeline::new(source, sink, action, batch_config) + } + PipelineMode::Cdc { + publication, + slot_name, + } => { + let source = PostgresSource::new( + options.clone(), + vec![], + Some(test_slot_name(&slot_name)), + TableNamesFrom::Publication(publication), + ) + .await + .expect("Failure when creating the Postgres source for cdc"); + let action = PipelineAction::CdcOnly; + BatchDataPipeline::new(source, sink, action, batch_config) + } + }; + + pipeline +} + +/// Creates and spawns a new asynchronous PostgreSQL replication pipeline. +/// +/// This function creates a pipeline and wraps it in a [`PipelineRunner`] for +/// easier management of the pipeline lifecycle. +pub async fn spawn_async_pg_pipeline( + options: &PgDatabaseOptions, + mode: PipelineMode, + sink: Snk, +) -> PipelineRunner { + let pipeline = spawn_pg_pipeline(options, mode, sink).await; + PipelineRunner::new(pipeline) +} + +/// Manages the lifecycle of a PostgreSQL replication pipeline. +/// +/// This struct provides methods to run and stop a pipeline, handling the +/// pipeline's state and ensuring proper cleanup. +pub struct PipelineRunner { + pipeline: Option>, + pipeline_handle: BatchDataPipelineHandle, +} + +impl PipelineRunner { + /// Creates a new pipeline runner with the specified pipeline. + pub fn new(pipeline: BatchDataPipeline) -> Self { + let pipeline_handle = pipeline.handle(); + Self { + pipeline: Some(pipeline), + pipeline_handle, + } + } + + /// Starts the pipeline asynchronously. + /// + /// # Panics + /// + /// Panics if the pipeline has already been run. + pub async fn run(&mut self) -> JoinHandle> { + if let Some(mut pipeline) = self.pipeline.take() { + return tokio::spawn(async move { + pipeline + .start() + .await + .expect("The pipeline experienced an error"); + + pipeline + }); + } + + panic!("The pipeline has already been run"); + } + + /// Stops the pipeline and waits for it to complete. + /// + /// This method signals the pipeline to stop and waits for it to finish + /// before returning. The pipeline is then restored to its initial state + /// for potential reuse. + /// + /// # Panics + /// + /// Panics if the pipeline task fails. + pub async fn stop_and_wait( + &mut self, + pipeline_task_handle: JoinHandle>, + ) { + // We signal the existing pipeline to stop. 
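+        // The handle shares the pipeline's `Arc<Notify>`, so this wakes the active batch stream,
+        // which emits any items it has buffered and then terminates.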
+        self.pipeline_handle.stop();
+
+        // We wait for the pipeline to finish, and we put it back for the next run.
+        let pipeline = pipeline_task_handle
+            .await
+            .expect("The pipeline task has failed");
+        // We recreate the handle to make sure the handle and the returned pipeline stay connected.
+        self.pipeline_handle = pipeline.handle();
+        self.pipeline = Some(pipeline);
+    }
+}
diff --git a/pg_replicate/tests/common/sink.rs b/pg_replicate/tests/common/sink.rs
new file mode 100644
index 00000000..cfd6151d
--- /dev/null
+++ b/pg_replicate/tests/common/sink.rs
@@ -0,0 +1,169 @@
+use async_trait::async_trait;
+use pg_replicate::conversions::cdc_event::CdcEvent;
+use pg_replicate::conversions::table_row::TableRow;
+use pg_replicate::pipeline::sinks::{BatchSink, InfallibleSinkError};
+use pg_replicate::pipeline::PipelineResumptionState;
+use postgres::schema::{TableId, TableSchema};
+use std::cmp::max;
+use std::collections::{HashMap, HashSet};
+use std::sync::{Arc, Mutex};
+use tokio_postgres::types::PgLsn;
+
+/// A test sink that captures replication events and data for verification.
+///
+/// This sink is designed to be shared across multiple pipelines, simulating
+/// persistent storage while maintaining thread safety through interior mutability.
+#[derive(Debug, Clone)]
+pub struct TestSink {
+    inner: Arc<Mutex<TestSinkInner>>,
+}
+
+/// Internal state of the test sink.
+///
+/// This struct maintains the sink's state including table schemas, rows,
+/// CDC events, and tracking information for copied and truncated tables.
+#[derive(Debug)]
+struct TestSinkInner {
+    // We use a Vec to store all the changes to the schemas that we receive over time.
+    tables_schemas: Vec<HashMap<TableId, TableSchema>>,
+    tables_rows: HashMap<TableId, Vec<TableRow>>,
+    events: Vec<Arc<CdcEvent>>,
+    copied_tables: HashSet<TableId>,
+    truncated_tables: HashSet<TableId>,
+    last_lsn: u64,
+}
+
+impl TestSink {
+    /// Creates a new test sink with an empty state.
+    pub fn new() -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(TestSinkInner {
+                tables_schemas: Vec::new(),
+                tables_rows: HashMap::new(),
+                events: Vec::new(),
+                copied_tables: HashSet::new(),
+                truncated_tables: HashSet::new(),
+                last_lsn: 0,
+            })),
+        }
+    }
+
+    /// Updates the last LSN based on received events.
+    ///
+    /// This method ensures that the last LSN is monotonically increasing,
+    /// taking the maximum between the current LSN and the maximum LSN from
+    /// the received events.
+    fn receive_events(&mut self, events: &[CdcEvent]) {
+        let mut max_lsn = 0;
+        for event in events {
+            if let CdcEvent::Commit(commit_body) = event {
+                max_lsn = max(max_lsn, commit_body.commit_lsn());
+            }
+        }
+
+        // We update the last lsn by taking the maximum of the current lsn and the highest lsn
+        // in the event stream: lsns are monotonically increasing, so once a given lsn has been
+        // seen, all events before that point must have been received.
+        let mut inner = self.inner.lock().unwrap();
+        inner.last_lsn = max(inner.last_lsn, max_lsn);
+    }
+
+    /// Returns a copy of all table schemas received by the sink.
+    pub fn get_tables_schemas(&self) -> Vec<HashMap<TableId, TableSchema>> {
+        self.inner.lock().unwrap().tables_schemas.clone()
+    }
+
+    /// Returns a copy of all table rows received by the sink.
+    pub fn get_tables_rows(&self) -> HashMap<TableId, Vec<TableRow>> {
+        self.inner.lock().unwrap().tables_rows.clone()
+    }
+
+    /// Returns a copy of all CDC events received by the sink.
+    pub fn get_events(&self) -> Vec<Arc<CdcEvent>> {
+        self.inner.lock().unwrap().events.clone()
+    }
+
+    /// Returns a copy of the set of tables that have been copied.
+    pub fn get_copied_tables(&self) -> HashSet<TableId> {
+        self.inner.lock().unwrap().copied_tables.clone()
+    }
+
+    /// Returns the number of tables that have been copied.
+    pub fn get_tables_copied(&self) -> u8 {
+        self.inner.lock().unwrap().copied_tables.len() as u8
+    }
+
+    /// Returns the number of tables that have been truncated.
+    pub fn get_tables_truncated(&self) -> u8 {
+        self.inner.lock().unwrap().truncated_tables.len() as u8
+    }
+
+    /// Returns the last LSN processed by the sink.
+    pub fn get_last_lsn(&self) -> u64 {
+        self.inner.lock().unwrap().last_lsn
+    }
+}
+
+#[async_trait]
+impl BatchSink for TestSink {
+    type Error = InfallibleSinkError;
+
+    async fn get_resumption_state(&mut self) -> Result<PipelineResumptionState, Self::Error> {
+        Ok(PipelineResumptionState {
+            copied_tables: self.get_copied_tables(),
+            last_lsn: PgLsn::from(self.get_last_lsn()),
+        })
+    }
+
+    async fn write_table_schemas(
+        &mut self,
+        table_schemas: HashMap<TableId, TableSchema>,
+    ) -> Result<(), Self::Error> {
+        self.inner
+            .lock()
+            .unwrap()
+            .tables_schemas
+            .push(table_schemas);
+
+        Ok(())
+    }
+
+    async fn write_table_rows(
+        &mut self,
+        rows: Vec<TableRow>,
+        table_id: TableId,
+    ) -> Result<(), Self::Error> {
+        self.inner
+            .lock()
+            .unwrap()
+            .tables_rows
+            .entry(table_id)
+            .or_default()
+            .extend(rows);
+
+        Ok(())
+    }
+
+    async fn write_cdc_events(&mut self, events: Vec<CdcEvent>) -> Result<PgLsn, Self::Error> {
+        self.receive_events(&events);
+
+        // Since CdcEvent is not Clone, we wrap each event in an Arc. The events are immutable
+        // once stored, so sharing them this way needs no extra synchronization.
+        let arc_events = events.into_iter().map(Arc::new).collect::<Vec<_>>();
+        self.inner.lock().unwrap().events.extend(arc_events);
+
+        Ok(PgLsn::from(self.inner.lock().unwrap().last_lsn))
+    }
+
+    async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error> {
+        self.inner.lock().unwrap().copied_tables.insert(table_id);
+
+        Ok(())
+    }
+
+    async fn truncate_table(&mut self, table_id: TableId) -> Result<(), Self::Error> {
+        self.inner.lock().unwrap().truncated_tables.insert(table_id);
+
+        Ok(())
+    }
+}
diff --git a/pg_replicate/tests/common/table.rs b/pg_replicate/tests/common/table.rs
new file mode 100644
index 00000000..9f5e7868
--- /dev/null
+++ b/pg_replicate/tests/common/table.rs
@@ -0,0 +1,39 @@
+use crate::common::sink::TestSink;
+use postgres::schema::{ColumnSchema, TableId, TableName};
+
+/// Verifies that a table's schema matches the expected configuration.
+///
+/// This function compares a table's actual schema against the expected schema,
+/// checking the table name, ID, and all column properties including name, type,
+/// modifiers, nullability, and primary key status.
+/// +/// # Panics +/// +/// Panics if: +/// - The table ID is not found in the sink's schema +/// - The schema index is out of bounds +/// - Any column property doesn't match the expected configuration +pub fn assert_table_schema( + sink: &TestSink, + table_id: TableId, + schema_index: usize, + expected_table_name: TableName, + expected_columns: &[ColumnSchema], +) { + let tables_schemas = &sink.get_tables_schemas()[schema_index]; + let table_schema = tables_schemas.get(&table_id).unwrap(); + + assert_eq!(table_schema.table_id, table_id); + assert_eq!(table_schema.table_name, expected_table_name); + + let columns = &table_schema.column_schemas; + assert_eq!(columns.len(), expected_columns.len()); + + for (actual, expected) in columns.iter().zip(expected_columns.iter()) { + assert_eq!(actual.name, expected.name); + assert_eq!(actual.typ, expected.typ); + assert_eq!(actual.modifier, expected.modifier); + assert_eq!(actual.nullable, expected.nullable); + assert_eq!(actual.primary, expected.primary); + } +} diff --git a/pg_replicate/tests/integration/mod.rs b/pg_replicate/tests/integration/mod.rs new file mode 100644 index 00000000..8f0125e2 --- /dev/null +++ b/pg_replicate/tests/integration/mod.rs @@ -0,0 +1 @@ +mod pipeline_test; diff --git a/pg_replicate/tests/integration/pipeline_test.rs b/pg_replicate/tests/integration/pipeline_test.rs new file mode 100644 index 00000000..5f8622f5 --- /dev/null +++ b/pg_replicate/tests/integration/pipeline_test.rs @@ -0,0 +1,219 @@ +use crate::common::database::{spawn_database, test_table_name}; +use crate::common::pipeline::{spawn_async_pg_pipeline, spawn_pg_pipeline, PipelineMode}; +use crate::common::sink::TestSink; +use crate::common::table::assert_table_schema; +use crate::common::wait_for_condition; +use pg_replicate::conversions::cdc_event::CdcEvent; +use pg_replicate::conversions::Cell; +use postgres::schema::{ColumnSchema, TableId}; +use postgres::tokio::test_utils::PgDatabase; +use std::ops::Range; +use tokio_postgres::types::Type; + +fn get_expected_ages_sum(num_users: usize) -> i32 { + ((num_users * (num_users + 1)) / 2) as i32 +} + +async fn create_users_table(database: &PgDatabase) -> TableId { + let table_id = database + .create_table(test_table_name("users"), &[("age", "integer")]) + .await + .unwrap(); + + table_id +} + +async fn create_users_table_with_publication( + database: &PgDatabase, + publication_name: &str, +) -> TableId { + let table_id = create_users_table(database).await; + + database + .create_publication(publication_name, &[test_table_name("users")]) + .await + .unwrap(); + + table_id +} + +async fn fill_users(database: &PgDatabase, num_users: usize) { + for i in 0..num_users { + let age = i as i32 + 1; + database + .insert_values(test_table_name("users"), &["age"], &[&age]) + .await + .unwrap(); + } +} + +async fn double_users_ages(database: &PgDatabase) { + database + .update_values(test_table_name("users"), &["age"], &["age * 2"]) + .await + .unwrap(); +} + +fn assert_users_table_schema(sink: &TestSink, users_table_id: TableId, schema_index: usize) { + let expected_columns = vec![ + ColumnSchema { + name: "id".to_string(), + typ: Type::INT8, + modifier: -1, + nullable: false, + primary: true, + }, + ColumnSchema { + name: "age".to_string(), + typ: Type::INT4, + modifier: -1, + nullable: true, + primary: false, + }, + ]; + + assert_table_schema( + sink, + users_table_id, + schema_index, + test_table_name("users"), + &expected_columns, + ); +} + +fn get_users_age_sum_from_rows(sink: &TestSink, users_table_id: TableId) -> 
i32 {
+    let mut actual_sum = 0;
+
+    let tables_rows = sink.get_tables_rows();
+    let table_rows = tables_rows.get(&users_table_id).unwrap();
+    for table_row in table_rows {
+        if let Cell::I32(age) = &table_row.values[1] {
+            actual_sum += age;
+        }
+    }
+
+    actual_sum
+}
+
+fn get_users_age_sum_from_events(
+    sink: &TestSink,
+    users_table_id: TableId,
+    // We use a range since events are not indexed by table id but form an ordered sequence
+    // which we want to slice through.
+    range: Range<usize>,
+) -> i32 {
+    let mut actual_sum = 0;
+
+    let mut i = 0;
+    for event in sink.get_events() {
+        match event.as_ref() {
+            CdcEvent::Insert((table_id, table_row)) | CdcEvent::Update((table_id, table_row))
+                if table_id == &users_table_id && range.contains(&i) =>
+            {
+                if let Cell::I32(age) = &table_row.values[1] {
+                    actual_sum += age;
+                }
+                i += 1;
+            }
+            _ => {}
+        }
+    }
+
+    actual_sum
+}
+
+/*
+Tests to write:
+- Insert -> cdc -> Update -> cdc
+- Insert -> cdc -> add table -> recreate pipeline and source -> check schema
+ */
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_table_copy_with_insert_and_update() {
+    let database = spawn_database().await;
+
+    // We insert 100 rows.
+    let users_table_id = create_users_table(&database).await;
+    fill_users(&database, 100).await;
+
+    // We create a pipeline that copies the users table.
+    let mut pipeline = spawn_pg_pipeline(
+        &database.options,
+        PipelineMode::CopyTable {
+            table_names: vec![test_table_name("users")],
+        },
+        TestSink::new(),
+    )
+    .await;
+    pipeline.start().await.unwrap();
+
+    assert_users_table_schema(pipeline.sink(), users_table_id, 0);
+    let expected_sum = get_expected_ages_sum(100);
+    let actual_sum = get_users_age_sum_from_rows(pipeline.sink(), users_table_id);
+    assert_eq!(actual_sum, expected_sum);
+    assert_eq!(pipeline.sink().get_tables_copied(), 1);
+    assert_eq!(pipeline.sink().get_tables_truncated(), 1);
+
+    // We double the user ages.
+    double_users_ages(&database).await;
+
+    // We recreate the pipeline to copy again and see if we have the new data.
+    let mut pipeline = spawn_pg_pipeline(
+        &database.options,
+        PipelineMode::CopyTable {
+            table_names: vec![test_table_name("users")],
+        },
+        TestSink::new(),
+    )
+    .await;
+    pipeline.start().await.unwrap();
+
+    assert_users_table_schema(pipeline.sink(), users_table_id, 0);
+    let expected_sum = expected_sum * 2;
+    let actual_sum = get_users_age_sum_from_rows(pipeline.sink(), users_table_id);
+    assert_eq!(actual_sum, expected_sum);
+    assert_eq!(pipeline.sink().get_tables_copied(), 1);
+    assert_eq!(pipeline.sink().get_tables_truncated(), 1);
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_cdc_with_multiple_inserts() {
+    let database = spawn_database().await;
+
+    // We create the table and publication.
+    let users_table_id = create_users_table_with_publication(&database, "users_publication").await;
+
+    // We create a pipeline that subscribes to the changes of the users table.
+    let sink = TestSink::new();
+    let mut pipeline = spawn_async_pg_pipeline(
+        &database.options,
+        PipelineMode::Cdc {
+            publication: "users_publication".to_owned(),
+            slot_name: "users_slot".to_string(),
+        },
+        sink.clone(),
+    )
+    .await;
+
+    // We insert 100 rows.
+    fill_users(&database, 100).await;
+
+    // We run the pipeline in the background, which should correctly pick up entries from the
+    // start even though the pipeline was started after the insertions.
+    let pipeline_task_handle = pipeline.run().await;
+
+    // Wait for all events to be processed.
+ let expected_sum = get_expected_ages_sum(100); + wait_for_condition(|| { + let actual_sum = get_users_age_sum_from_events(&sink, users_table_id, 0..100); + actual_sum == expected_sum + }) + .await; + + // We stop the pipeline and wait for it to finish. + pipeline.stop_and_wait(pipeline_task_handle).await; + + assert_users_table_schema(&sink, users_table_id, 0); + assert_eq!(sink.get_tables_copied(), 0); + assert_eq!(sink.get_tables_truncated(), 0); +} diff --git a/pg_replicate/tests/mod.rs b/pg_replicate/tests/mod.rs new file mode 100644 index 00000000..e78302ae --- /dev/null +++ b/pg_replicate/tests/mod.rs @@ -0,0 +1,2 @@ +mod common; +mod integration; diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml new file mode 100644 index 00000000..fcab8197 --- /dev/null +++ b/postgres/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "postgres" +version = "0.1.0" +edition = "2024" + +[dependencies] +pg_escape = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true, features = ["std"] } +secrecy = { workspace = true, features = ["serde", "alloc"] } +sqlx = { workspace = true, features = [ + "runtime-tokio-rustls", + "macros", + "postgres", + "json", + "migrate", +] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tokio-postgres = { workspace = true, features = [ + "runtime", + "with-chrono-0_4", + "with-uuid-1", + "with-serde_json-1", +] } +tokio-postgres-rustls = { workspace = true } + + +[features] +test-utils = [] +tokio = [] +sqlx = [] \ No newline at end of file diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs new file mode 100644 index 00000000..7ae1123c --- /dev/null +++ b/postgres/src/lib.rs @@ -0,0 +1,16 @@ +//! PostgreSQL database connection utilities for all crates. +//! +//! This crate provides database connection options and utilities for working with PostgreSQL. +//! It supports both the [`sqlx`] and [`tokio-postgres`] crates through feature flags. +//! +//! # Features +//! +//! - `sqlx`: Enables SQLx-specific database connection options and utilities +//! - `tokio`: Enables tokio-postgres-specific database connection options and utilities +//! - `test-utils`: Enables test utilities for both SQLx and tokio-postgres implementations + +pub mod schema; +#[cfg(feature = "sqlx")] +pub mod sqlx; +#[cfg(feature = "tokio")] +pub mod tokio; diff --git a/postgres/src/schema.rs b/postgres/src/schema.rs new file mode 100644 index 00000000..f8b1cacb --- /dev/null +++ b/postgres/src/schema.rs @@ -0,0 +1,86 @@ +use std::fmt; + +use pg_escape::quote_identifier; +use tokio_postgres::types::Type; + +/// A fully qualified PostgreSQL table name consisting of a schema and table name. +/// +/// This type represents a table identifier in PostgreSQL, which requires both a schema name +/// and a table name. It provides methods for formatting the name in different contexts. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct TableName { + /// The schema name containing the table + pub schema: String, + /// The name of the table within the schema + pub name: String, +} + +impl TableName { + /// Returns the table name as a properly quoted PostgreSQL identifier. + /// + /// This method ensures the schema and table names are properly escaped according to + /// PostgreSQL identifier quoting rules. 
+    pub fn as_quoted_identifier(&self) -> String {
+        let quoted_schema = quote_identifier(&self.schema);
+        let quoted_name = quote_identifier(&self.name);
+        format!("{quoted_schema}.{quoted_name}")
+    }
+}
+
+impl fmt::Display for TableName {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_fmt(format_args!("{0}.{1}", self.schema, self.name))
+    }
+}
+
+/// A type alias for PostgreSQL type modifiers.
+///
+/// Type modifiers in PostgreSQL are used to specify additional type-specific attributes,
+/// such as length for varchar or precision for numeric types.
+type TypeModifier = i32;
+
+/// Represents the schema of a single column in a PostgreSQL table.
+///
+/// This type contains all metadata about a column including its name, data type,
+/// type modifier, nullability, and whether it's part of the primary key.
+#[derive(Debug, Clone)]
+pub struct ColumnSchema {
+    /// The name of the column
+    pub name: String,
+    /// The PostgreSQL data type of the column
+    pub typ: Type,
+    /// Type-specific modifier value (e.g., length for varchar)
+    pub modifier: TypeModifier,
+    /// Whether the column can contain NULL values
+    pub nullable: bool,
+    /// Whether the column is part of the table's primary key
+    pub primary: bool,
+}
+
+/// A type alias for PostgreSQL table OIDs.
+///
+/// Table OIDs are unique identifiers assigned to tables in PostgreSQL.
+pub type TableId = u32;
+
+/// Represents the complete schema of a PostgreSQL table.
+///
+/// This type contains all metadata about a table including its name, OID,
+/// and the schemas of all its columns.
+#[derive(Debug, Clone)]
+pub struct TableSchema {
+    /// The fully qualified name of the table
+    pub table_name: TableName,
+    /// The PostgreSQL OID of the table
+    pub table_id: TableId,
+    /// The schemas of all columns in the table
+    pub column_schemas: Vec<ColumnSchema>,
+}
+
+impl TableSchema {
+    /// Returns whether the table has any primary key columns.
+    ///
+    /// This method checks if any column in the table is marked as part of the primary key.
+    pub fn has_primary_keys(&self) -> bool {
+        self.column_schemas.iter().any(|cs| cs.primary)
+    }
+}
diff --git a/postgres/src/sqlx/mod.rs b/postgres/src/sqlx/mod.rs
new file mode 100644
index 00000000..090183a4
--- /dev/null
+++ b/postgres/src/sqlx/mod.rs
@@ -0,0 +1,3 @@
+pub mod options;
+#[cfg(feature = "test-utils")]
+pub mod test_utils;
diff --git a/postgres/src/sqlx/options.rs b/postgres/src/sqlx/options.rs
new file mode 100644
index 00000000..f78363fc
--- /dev/null
+++ b/postgres/src/sqlx/options.rs
@@ -0,0 +1,61 @@
+use secrecy::{ExposeSecret, Secret};
+use serde::Deserialize;
+use sqlx::postgres::{PgConnectOptions, PgSslMode};
+
+/// Connection options for a PostgreSQL database.
+///
+/// Contains the connection parameters needed to establish a connection to a PostgreSQL
+/// database server, including network location, authentication credentials, and security
+/// settings.
+#[derive(Debug, Clone, Deserialize)]
+pub struct PgDatabaseOptions {
+    /// Host name or IP address of the PostgreSQL server
+    pub host: String,
+    /// Port number that the PostgreSQL server listens on
+    pub port: u16,
+    /// Name of the target database
+    pub name: String,
+    /// Username for authentication
+    pub username: String,
+    /// Optional password for authentication, wrapped in [`Secret`] for secure handling
+    pub password: Option<Secret<String>>,
+    /// If true, requires SSL/TLS encryption for the connection
+    pub require_ssl: bool,
+}
+
+impl PgDatabaseOptions {
+    /// Creates connection options for connecting to the PostgreSQL server without
+    /// specifying a database.
+    ///
+    /// Returns [`PgConnectOptions`] configured with the host, port, username, SSL mode
+    /// and optional password from this instance. Useful for administrative operations
+    /// that must be performed before connecting to a specific database, like database
+    /// creation.
+    pub fn without_db(&self) -> PgConnectOptions {
+        let ssl_mode = if self.require_ssl {
+            PgSslMode::Require
+        } else {
+            PgSslMode::Prefer
+        };
+
+        let options = PgConnectOptions::new_without_pgpass()
+            .host(&self.host)
+            .username(&self.username)
+            .port(self.port)
+            .ssl_mode(ssl_mode);
+
+        if let Some(password) = &self.password {
+            options.password(password.expose_secret())
+        } else {
+            options
+        }
+    }
+
+    /// Creates connection options for connecting to a specific database.
+    ///
+    /// Returns [`PgConnectOptions`] configured with all connection parameters including
+    /// the database name from this instance.
+    pub fn with_db(&self) -> PgConnectOptions {
+        self.without_db().database(&self.name)
+    }
+}
diff --git a/postgres/src/sqlx/test_utils.rs b/postgres/src/sqlx/test_utils.rs
new file mode 100644
index 00000000..d860996c
--- /dev/null
+++ b/postgres/src/sqlx/test_utils.rs
@@ -0,0 +1,55 @@
+use crate::sqlx::options::PgDatabaseOptions;
+use sqlx::{Connection, Executor, PgConnection, PgPool};
+
+/// Creates a new PostgreSQL database and returns a connection pool to it.
+///
+/// Establishes a connection to the PostgreSQL server using the provided options,
+/// creates a new database, and returns a [`PgPool`] connected to the new database.
+/// Panics if the connection fails or if database creation fails.
+pub async fn create_pg_database(options: &PgDatabaseOptions) -> PgPool {
+    // Create the database via a single connection.
+    let mut connection = PgConnection::connect_with(&options.without_db())
+        .await
+        .expect("Failed to connect to Postgres");
+    connection
+        .execute(&*format!(r#"create database "{}";"#, options.name))
+        .await
+        .expect("Failed to create database");
+
+    // Create a connection pool to the database.
+    PgPool::connect_with(options.with_db())
+        .await
+        .expect("Failed to connect to Postgres")
+}
+
+/// Drops a PostgreSQL database and cleans up all connections.
+///
+/// Connects to the PostgreSQL server, forcefully terminates all active connections
+/// to the target database, and drops the database if it exists. Useful for cleaning
+/// up test databases. Takes a reference to [`PgDatabaseOptions`] specifying the database
+/// to drop. Panics if any operation fails.
+pub async fn drop_pg_database(options: &PgDatabaseOptions) {
+    // Connect to the default database.
+    let mut connection = PgConnection::connect_with(&options.without_db())
+        .await
+        .expect("Failed to connect to Postgres");
+
+    // Forcefully terminate any remaining connections to the database.
+    connection
+        .execute(&*format!(
+            r#"
+            select pg_terminate_backend(pg_stat_activity.pid)
+            from pg_stat_activity
+            where pg_stat_activity.datname = '{}'
+            and pid <> pg_backend_pid();"#,
+            options.name
+        ))
+        .await
+        .expect("Failed to terminate database connections");
+
+    // Drop the database.
+    connection
+        .execute(&*format!(r#"drop database if exists "{}";"#, options.name))
+        .await
+        .expect("Failed to destroy database");
+}
diff --git a/postgres/src/tokio/mod.rs b/postgres/src/tokio/mod.rs
new file mode 100644
index 00000000..090183a4
--- /dev/null
+++ b/postgres/src/tokio/mod.rs
@@ -0,0 +1,3 @@
+pub mod options;
+#[cfg(feature = "test-utils")]
+pub mod test_utils;
diff --git a/postgres/src/tokio/options.rs b/postgres/src/tokio/options.rs
new file mode 100644
index 00000000..4bb298f2
--- /dev/null
+++ b/postgres/src/tokio/options.rs
@@ -0,0 +1,71 @@
+use tokio_postgres::Config;
+use tokio_postgres::config::SslMode;
+
+/// Connection options for a PostgreSQL database.
+///
+/// Contains the connection parameters needed to establish a connection to a PostgreSQL
+/// database server, including network location, authentication credentials, and security
+/// settings.
+#[derive(Debug, Clone)]
+pub struct PgDatabaseOptions {
+    /// Host name or IP address of the PostgreSQL server
+    pub host: String,
+    /// Port number that the PostgreSQL server listens on
+    pub port: u16,
+    /// Name of the target database
+    pub name: String,
+    /// Username for authentication
+    pub username: String,
+    /// Optional password for authentication
+    pub password: Option<String>,
+    /// SSL mode for the connection
+    pub ssl_mode: SslMode,
+}
+
+impl PgDatabaseOptions {
+    /// Creates connection options for connecting to the PostgreSQL server without
+    /// specifying a database.
+    ///
+    /// Returns [`Config`] configured with the host, port, username, SSL mode and optional
+    /// password from this instance. The database name is set to the username as per
+    /// PostgreSQL convention. Useful for administrative operations that must be performed
+    /// before connecting to a specific database, like database creation.
+    pub fn without_db(&self) -> Config {
+        let mut this = self.clone();
+        // Postgres requires a database name, so we default to the database named after the
+        // user, which is the conventional default.
+        this.name = this.username.clone();
+
+        this.into()
+    }
+
+    /// Creates connection options for connecting to a specific database.
+    ///
+    /// Returns [`Config`] configured with all connection parameters including the database
+    /// name from this instance.
+    pub fn with_db(&self) -> Config {
+        self.clone().into()
+    }
+}
+
+impl From<PgDatabaseOptions> for Config {
+    /// Converts [`PgDatabaseOptions`] into a [`Config`] instance.
+    ///
+    /// Sets all connection parameters including host, port, database name, username,
+    /// SSL mode, and optional password.
+    fn from(value: PgDatabaseOptions) -> Self {
+        let mut config = Config::new();
+        config
+            .host(value.host)
+            .port(value.port)
+            .dbname(value.name)
+            .user(value.username)
+            .ssl_mode(value.ssl_mode);
+
+        if let Some(password) = value.password {
+            config.password(password);
+        }
+
+        config
+    }
+}
diff --git a/postgres/src/tokio/test_utils.rs b/postgres/src/tokio/test_utils.rs
new file mode 100644
index 00000000..a49eb9a4
--- /dev/null
+++ b/postgres/src/tokio/test_utils.rs
@@ -0,0 +1,254 @@
+use crate::schema::{TableId, TableName};
+use crate::tokio::options::PgDatabaseOptions;
+use tokio::runtime::Handle;
+use tokio_postgres::{Client, NoTls};
+
+pub struct PgDatabase {
+    pub options: PgDatabaseOptions,
+    pub client: Client,
+}
+
+impl PgDatabase {
+    pub async fn new(options: PgDatabaseOptions) -> Self {
+        let client = create_pg_database(&options).await;
+
+        Self { options, client }
+    }
+
+    /// Creates a new publication for the specified tables.
+    pub async fn create_publication(
+        &self,
+        publication_name: &str,
+        table_names: &[TableName],
+    ) -> Result<(), tokio_postgres::Error> {
+        let table_names = table_names
+            .iter()
+            .map(TableName::as_quoted_identifier)
+            .collect::<Vec<_>>();
+
+        let create_publication_query = format!(
+            "create publication {} for table {}",
+            publication_name,
+            table_names.join(", ")
+        );
+        self.client.execute(&create_publication_query, &[]).await?;
+
+        Ok(())
+    }
+
+    /// Creates a new table with the specified name and columns.
+    pub async fn create_table(
+        &self,
+        table_name: TableName,
+        columns: &[(&str, &str)], // (column_name, column_type)
+    ) -> Result<TableId, tokio_postgres::Error> {
+        let columns_str = columns
+            .iter()
+            .map(|(name, typ)| format!("{} {}", name, typ))
+            .collect::<Vec<_>>()
+            .join(", ");
+
+        let create_table_query = format!(
+            "create table {} (id bigserial primary key, {})",
+            table_name.as_quoted_identifier(),
+            columns_str
+        );
+        self.client.execute(&create_table_query, &[]).await?;
+
+        // Get the OID of the newly created table.
+        let row = self
+            .client
+            .query_one(
+                "select c.oid from pg_class c join pg_namespace n on n.oid = c.relnamespace \
+                 where n.nspname = $1 and c.relname = $2",
+                &[&table_name.schema, &table_name.name],
+            )
+            .await?;
+
+        let table_id: TableId = row.get(0);
+
+        Ok(table_id)
+    }
+
+    /// Inserts values into the specified table.
+    pub async fn insert_values(
+        &self,
+        table_name: TableName,
+        columns: &[&str],
+        values: &[&(dyn tokio_postgres::types::ToSql + Sync)],
+    ) -> Result<u64, tokio_postgres::Error> {
+        let columns_str = columns.join(", ");
+        let placeholders: Vec<String> = (1..=values.len()).map(|i| format!("${}", i)).collect();
+        let placeholders_str = placeholders.join(", ");
+
+        let insert_query = format!(
+            "insert into {} ({}) values ({})",
+            table_name.as_quoted_identifier(),
+            columns_str,
+            placeholders_str
+        );
+
+        self.client.execute(&insert_query, values).await
+    }
+
+    /// Updates all rows in the specified table with the given values.
+    pub async fn update_values(
+        &self,
+        table_name: TableName,
+        columns: &[&str],
+        values: &[&str],
+    ) -> Result<u64, tokio_postgres::Error> {
+        let set_clauses: Vec<String> = columns
+            .iter()
+            .zip(values.iter())
+            .map(|(col, val)| format!("{} = {}", col, val))
+            .collect();
+        let set_clause = set_clauses.join(", ");
+
+        let update_query = format!(
+            "update {} set {}",
+            table_name.as_quoted_identifier(),
+            set_clause
+        );
+
+        self.client.execute(&update_query, &[]).await
+    }
+
+    /// Queries rows from a single column of a table.
+    pub async fn query_table<T>(
+        &self,
+        table_name: &TableName,
+        column: &str,
+        where_clause: Option<&str>,
+    ) -> Result<Vec<T>, tokio_postgres::Error>
+    where
+        T: for<'a> tokio_postgres::types::FromSql<'a>,
+    {
+        let where_str = where_clause.map_or(String::new(), |w| format!(" where {}", w));
+        let query = format!(
+            "select {} from {}{}",
+            column,
+            table_name.as_quoted_identifier(),
+            where_str
+        );
+
+        let rows = self.client.query(&query, &[]).await?;
+        Ok(rows.iter().map(|row| row.get(0)).collect())
+    }
+}
+
+impl Drop for PgDatabase {
+    fn drop(&mut self) {
+        // To use `block_in_place`, we need a multithreaded runtime, since when a blocking
+        // task is issued, the runtime will offload existing tasks to another worker.
+        tokio::task::block_in_place(move || {
+            Handle::current().block_on(async move { drop_pg_database(&self.options).await });
+        });
+    }
+}
+
+/// Creates a new PostgreSQL database and returns a client connected to it.
+///
+/// Establishes a connection to the PostgreSQL server using the provided options,
+/// creates a new database, and returns a [`Client`] connected to the new database.
+/// Panics if the connection fails or if database creation fails.
+pub async fn create_pg_database(options: &PgDatabaseOptions) -> Client {
+    // Create the database via a single connection.
+    let (client, connection) = options
+        .without_db()
+        .connect(NoTls)
+        .await
+        .expect("Failed to connect to Postgres");
+
+    // Spawn the connection on a new task.
+    tokio::spawn(async move {
+        if let Err(e) = connection.await {
+            eprintln!("connection error: {}", e);
+        }
+    });
+
+    // Create the database.
+    client
+        .execute(&*format!(r#"create database "{}";"#, options.name), &[])
+        .await
+        .expect("Failed to create database");
+
+    // Create a new client connected to the created database.
+    let (client, connection) = options
+        .with_db()
+        .connect(NoTls)
+        .await
+        .expect("Failed to connect to Postgres");
+
+    // Spawn the connection on a new task.
+    tokio::spawn(async move {
+        if let Err(e) = connection.await {
+            eprintln!("connection error: {}", e);
+        }
+    });
+
+    client
+}
+
+/// Drops a PostgreSQL database and cleans up all connections.
+///
+/// Connects to the PostgreSQL server, forcefully terminates all active connections
+/// to the target database, and drops the database if it exists. Useful for cleaning
+/// up test databases. Takes a reference to [`PgDatabaseOptions`] specifying the database
+/// to drop. Panics if any operation fails.
+pub async fn drop_pg_database(options: &PgDatabaseOptions) { + // Connect to the default database + let (client, connection) = options + .without_db() + .connect(NoTls) + .await + .expect("Failed to connect to Postgres"); + + // Spawn the connection on a new task + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + // Forcefully terminate any remaining connections to the database + client + .execute( + &format!( + r#" + select pg_terminate_backend(pg_stat_activity.pid) + from pg_stat_activity + where pg_stat_activity.datname = '{}' + and pid <> pg_backend_pid();"#, + options.name + ), + &[], + ) + .await + .expect("Failed to terminate database connections"); + + // Drop any test replication slots + client + .execute( + &format!( + r#" + select pg_drop_replication_slot(slot_name) + from pg_replication_slots + where slot_name like 'test_%' + and database = '{}';"#, + options.name + ), + &[], + ) + .await + .expect("Failed to drop test replication slots"); + + // Drop the database + client + .execute( + &format!(r#"drop database if exists "{}";"#, options.name), + &[], + ) + .await + .expect("Failed to destroy database"); +} diff --git a/replicator/Cargo.toml b/replicator/Cargo.toml index bed4b67f..e5870055 100644 --- a/replicator/Cargo.toml +++ b/replicator/Cargo.toml @@ -4,15 +4,17 @@ version = "0.1.0" edition = "2021" [dependencies] +pg_replicate = { workspace = true, features = ["bigquery"] } +postgres = { workspace = true, features = ["tokio"] } +telemetry = { workspace = true } + anyhow = { workspace = true, features = ["std"] } config = { workspace = true, features = ["yaml"] } -pg_replicate = { path = "../pg_replicate", features = ["bigquery"] } rustls = { workspace = true, features = ["aws-lc-rs", "logging"] } secrecy = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } thiserror = { workspace = true } rustls-pemfile = { workspace = true, features = ["std"] } -telemetry = { path = "../telemetry" } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true, default-features = true } diff --git a/replicator/src/main.rs b/replicator/src/main.rs index fef6b948..7df9290f 100644 --- a/replicator/src/main.rs +++ b/replicator/src/main.rs @@ -12,6 +12,7 @@ use pg_replicate::{ }, SslMode, }; +use postgres::tokio::options::PgDatabaseOptions; use telemetry::init_tracing; use tracing::{info, instrument}; @@ -106,13 +107,17 @@ async fn start_replication(settings: Settings) -> anyhow::Result<()> { SslMode::Disable }; - let postgres_source = PostgresSource::new( - &host, + let options = PgDatabaseOptions { + host, port, - &name, - &username, + name, + username, password, ssl_mode, + }; + + let postgres_source = PostgresSource::new( + options, trusted_root_certs_vec, Some(slot_name), TableNamesFrom::Publication(publication), diff --git a/api/scripts/init_db.sh b/scripts/init_db.sh similarity index 90% rename from api/scripts/init_db.sh rename to scripts/init_db.sh index 7b1fa1a6..6866f222 100755 --- a/api/scripts/init_db.sh +++ b/scripts/init_db.sh @@ -1,9 +1,9 @@ #!/usr/bin/env bash set -eo pipefail -if [ ! -d "migrations" ]; then - echo >&2 "❌ Error: '/migrations' folder not found." - echo >&2 "Please run this script from the 'pg_replicate/api' directory." +if [ ! -d "api/migrations" ]; then + echo >&2 "❌ Error: 'api/migrations' folder not found." 
+ echo >&2 "Please run this script from the 'pg_replicate' directory." exit 1 fi @@ -59,7 +59,8 @@ then # Complete the docker run command DOCKER_RUN_CMD="${DOCKER_RUN_CMD} \ --name "postgres_$(date '+%s')" \ - postgres -N 1000" + postgres:15 -N 1000 \ + -c wal_level=logical" # Increased maximum number of connections for testing purposes # Start the container @@ -81,6 +82,6 @@ echo "✅ PostgreSQL is up and running on port ${DB_PORT}" echo "🔄 Setting up the database..." export DATABASE_URL=postgres://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME} sqlx database create -sqlx migrate run +sqlx migrate run --source api/migrations echo "✨ Database setup complete! Ready to go!"