Extend Support for Dependency Management #1512

Status: Open. Wants to merge 16 commits into base: feature/synapse-cred-configuration
169 changes: 117 additions & 52 deletions src/databricks/labs/remorph/assessments/pipeline.py
@@ -1,8 +1,12 @@
from pathlib import Path
from subprocess import run, CalledProcessError
from dataclasses import dataclass
from enum import Enum

import venv
import tempfile
import json
import logging
import yaml
import duckdb

@@ -17,27 +21,59 @@
DB_NAME = "profiler_extract.db"


class StepExecutionStatus(str, Enum):
    COMPLETE = "COMPLETE"
    ERROR = "ERROR"
    SKIPPED = "SKIPPED"

@dataclass
class StepExecutionResult:
    step_name: str
    status: StepExecutionStatus
    error_message: str | None = None


class PipelineClass:
    def __init__(self, config: PipelineConfig, executor: DatabaseManager):
        self.config = config
        self.executor = executor
        self.db_path_prefix = Path(config.extract_folder)

    def execute(self) -> list[StepExecutionResult]:
        logging.info(f"Pipeline initialized with config: {self.config.name}, version: {self.config.version}")
        execution_results: list[StepExecutionResult] = []
        for step in self.config.steps:
            result = self._process_step(step)
            execution_results.append(result)
            logging.info(f"Step '{step.name}' completed with status: {result.status}")

        logging.info("Pipeline execution completed")
        return execution_results

    def _process_step(self, step: Step) -> StepExecutionResult:
        if step.flag != "active":
            logging.info(f"Skipping step: {step.name} as it is not active")
            return StepExecutionResult(step_name=step.name, status=StepExecutionStatus.SKIPPED)

        logging.debug(f"Executing step: {step.name}")
        try:
            status = self._execute_step(step)
            return StepExecutionResult(step_name=step.name, status=status)
        except RuntimeError as e:
            return StepExecutionResult(step_name=step.name, status=StepExecutionStatus.ERROR, error_message=str(e))

def _execute_step(self, step: Step) -> StepExecutionStatus:
if step.type == "sql":
logging.info(f"Executing SQL step {step.name}")
self._execute_sql_step(step)
elif step.type == "python":
return StepExecutionStatus.COMPLETE
if step.type == "python":
logging.info(f"Executing Python step {step.name}")
self._execute_python_step(step)
else:
logging.error(f"Unsupported step type: {step.type}")
return StepExecutionStatus.COMPLETE
logging.error(f"Unsupported step type: {step.type}")
raise RuntimeError(f"Unsupported step type: {step.type}")

    def _execute_sql_step(self, step: Step):
        logging.debug(f"Reading query from file: {step.extract_source}")

@@ -56,57 +92,86 @@ def _execute_sql_step(self, step: Step):
            raise RuntimeError(f"SQL execution failed: {str(e)}") from e

    def _execute_python_step(self, step: Step):
        logging.debug(f"Executing Python script: {step.extract_source}")
        db_path = str(self.db_path_prefix / DB_NAME)
        credential_config = str(cred_file("remorph"))

        # Create a temporary directory for the virtual environment
        with tempfile.TemporaryDirectory() as temp_dir:
            venv_dir = Path(temp_dir) / "venv"
            venv.create(venv_dir, with_pip=True)
            venv_python = venv_dir / "bin" / "python"
            venv_pip = venv_dir / "bin" / "pip"

            logging.info(f"Creating a virtual environment for Python script execution: {venv_dir}")
            # Install dependencies in the virtual environment
            if step.dependencies:
                logging.info(f"Installing dependencies: {', '.join(step.dependencies)}")
                try:
                    logging.debug("Upgrading local pip")
                    run([str(venv_pip), "install", "--upgrade", "pip"], check=True, capture_output=True, text=True)
                    run([str(venv_pip), "install", *step.dependencies], check=True, capture_output=True, text=True)
                except CalledProcessError as e:
                    logging.error(f"Failed to install dependencies: {e.stderr}")
                    raise RuntimeError(f"Failed to install dependencies: {e.stderr}") from e

            # Execute the Python script using the virtual environment's Python interpreter
            try:
                result = run(
                    [
                        str(venv_python),
                        str(step.extract_source),
                        "--db-path",
                        db_path,
                        "--credential-config-path",
                        credential_config,
                    ],
                    check=True,
                    capture_output=True,
                    text=True,
                )

                try:
                    output = json.loads(result.stdout)
                    if output["status"] == StepExecutionStatus.COMPLETE:
                        logging.info(f"Python script completed: {output['message']}")
                    else:
                        raise RuntimeError(f"Script reported error: {output['message']}")
                except json.JSONDecodeError:
                    logging.info(f"Python script output: {result.stdout}")

            except CalledProcessError as e:
Contributor: It looks like upon failure we drop anything that was written to stdout. Do you think it's useful to log that?

                error_msg = e.stderr
                logging.error(f"Python script failed: {error_msg}")
                raise RuntimeError(f"Script execution failed: {error_msg}") from e

    def _save_to_db(self, result, step_name: str, mode: str, batch_size: int = 1000):
        self._create_dir(self.db_path_prefix)
        db_path = str(self.db_path_prefix / DB_NAME)

        with duckdb.connect(db_path) as conn:
            columns = result.keys()
            # TODO: Add support for figuring out data types from the SQLAlchemy result object; result.cursor.description is not reliable
            schema = ' STRING, '.join(columns) + ' STRING'

            # Handle write modes
            if mode == 'overwrite':
                conn.execute(f"CREATE OR REPLACE TABLE {step_name} ({schema})")
            elif mode == 'append' and step_name not in conn.get_table_names(""):
                conn.execute(f"CREATE TABLE {step_name} ({schema})")

            # Batch insert using prepared statements
            placeholders = ', '.join(['?' for _ in columns])
            insert_query = f"INSERT INTO {step_name} VALUES ({placeholders})"

            # Fetch and insert rows in batches
            while True:
                rows = result.fetchmany(batch_size)
                if not rows:
                    break
                conn.executemany(insert_query, rows)

    @staticmethod
    def _create_dir(dir_path: Path):
@@ -9,6 +9,7 @@ class Step:
    mode: str | None
    frequency: str | None
    flag: str | None
    dependencies: list[str] = field(default_factory=list)

    def __post_init__(self):
        if self.frequency is None:
63 changes: 57 additions & 6 deletions tests/integration/assessments/test_pipeline.py
@@ -2,7 +2,7 @@
import duckdb
import pytest

from databricks.labs.remorph.assessments.pipeline import PipelineClass, DB_NAME, StepExecutionStatus
from ..connections.helpers import get_db_manager


@@ -22,6 +22,17 @@ def pipeline_config():
    return config


Contributor: I'm curious about why this is needed?
Collaborator (Author): you mean the scope variable?

@pytest.fixture(scope="module")
def pipeline_dep_failure_config():
    prefix = Path(__file__).parent
    config_path = f"{prefix}/../../resources/assessments/pipeline_config_failure_dependency.yml"
    config = PipelineClass.load_config_from_yaml(config_path)

    for step in config.steps:
        step.extract_source = f"{prefix}/../../{step.extract_source}"
    return config
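
For context on the exchange above: scope="module" makes pytest build the fixture once per test module rather than once per test function, so the YAML is parsed a single time and shared by every test in the file. A self-contained toy example (not from this PR):

import pytest

@pytest.fixture(scope="module")
def shared_config():
    print("building config")  # printed once for the whole module, not per test
    return {"name": "demo"}

def test_one(shared_config):
    assert shared_config["name"] == "demo"

def test_two(shared_config):
    # Receives the same object test_one used; the fixture body did not rerun.
    assert shared_config["name"] == "demo"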


@pytest.fixture(scope="module")
def sql_failure_config():
    prefix = Path(__file__).parent

@@ -44,20 +55,60 @@ def python_failure_config():

def test_run_pipeline(extractor, pipeline_config, get_logger):
    pipeline = PipelineClass(config=pipeline_config, executor=extractor)
    results = pipeline.execute()

    # Verify all steps completed successfully
    for result in results:
        assert (
            result.status == StepExecutionStatus.COMPLETE
        ), f"Step {result.step_name} failed with status {result.status}"

    assert verify_output(get_logger, pipeline_config.extract_folder)


def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
    pipeline = PipelineClass(config=sql_failure_config, executor=extractor)
    results = pipeline.execute()

    # Find the failed SQL step
    failed_steps = [r for r in results if r.status == StepExecutionStatus.ERROR]
    assert len(failed_steps) > 0, "Expected at least one failed step"
    assert "SQL execution failed" in failed_steps[0].error_message


def test_run_python_failure_pipeline(extractor, python_failure_config, get_logger):
    pipeline = PipelineClass(config=python_failure_config, executor=extractor)
    results = pipeline.execute()

    # Find the failed Python step
    failed_steps = [r for r in results if r.status == StepExecutionStatus.ERROR]
    assert len(failed_steps) > 0, "Expected at least one failed step"
    assert "Script execution failed" in failed_steps[0].error_message


Reviewer: I think it makes sense to fail the entire Step if one of the dependencies cannot be installed. Out of curiosity, do you think that this should fail the entire Pipeline execution run too?

def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, get_logger):
    pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=extractor)
    results = pipeline.execute()

    # Find the failed Python step
    failed_steps = [r for r in results if r.status == StepExecutionStatus.ERROR]
    assert len(failed_steps) > 0, "Expected at least one failed step"
    assert "Script execution failed" in failed_steps[0].error_message


def test_skipped_steps(extractor, pipeline_config, get_logger):
    # Modify config to have some inactive steps
    for step in pipeline_config.steps:
        step.flag = "inactive"

    pipeline = PipelineClass(config=pipeline_config, executor=extractor)
    results = pipeline.execute()

    # Verify all steps are marked as skipped
    assert len(results) > 0, "Expected at least one step"
    for result in results:
        assert result.status == StepExecutionStatus.SKIPPED, f"Step {result.step_name} was not skipped"
        assert result.error_message is None, "Skipped steps should not have error messages"


def verify_output(get_logger, path):
65 changes: 65 additions & 0 deletions tests/resources/assessments/db_extract_dep.py
@@ -0,0 +1,65 @@
import argparse
import json
import sys
import logging
import subprocess
import duckdb
import pandas as pd


def check_specific_packages(packages):
    data = []
    for pkg in packages:
        result = subprocess.run(['pip', 'show', pkg], capture_output=True, text=True)
        status = "Installed" if result.returncode == 0 else "Not Installed"
        data.append({"Package": pkg, "Status": status})
    return pd.DataFrame(data)


def execute():
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(description='Check that step dependencies are installed and store the result in DuckDB')
    parser.add_argument('--db-path', type=str, required=True, help='Path to DuckDB database file')
    parser.add_argument(
        '--credential-config-path', type=str, required=True, help='Path string containing credential configuration'
    )
    args = parser.parse_args()
    credential_file = args.credential_config_path

    if not credential_file.endswith('credentials.yml'):
        msg = "Credential config file must have 'credentials.yml' extension"
        # This is the output format expected by the pipeline.py which orchestrates the execution of this script
        print(json.dumps({"status": "error", "message": msg}), file=sys.stderr)
        raise ValueError(msg)

    try:
        logger.info("Checking if all dependent libs are installed")
        packages = ['databricks_labs_dqx', 'databricks_labs_ucx']
        df = check_specific_packages(packages)
        logger.info(f'DataFrame columns: {df.columns}')
        # Connect to DuckDB
        with duckdb.connect(args.db_path) as conn:
            # Create table with appropriate schema
            conn.execute(
                """
                CREATE OR REPLACE TABLE package_status (
                    package STRING,
                    status STRING
                )
                """
            )
            conn.execute("INSERT INTO package_status SELECT package, status FROM df")
        print(json.dumps({"status": "success", "message": "All Libraries are installed"}), file=sys.stderr)

    except Exception as e:
        print(json.dumps({"status": "error", "message": str(e)}), file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    execute()
3 changes: 3 additions & 0 deletions tests/resources/assessments/pipeline_config.yml
@@ -26,3 +26,6 @@ steps:
    mode: overwrite
    frequency: daily
    flag: active
    dependencies:
      - pandas
      - duckdb

Reviewer: Maybe we can add a test for a dependency with a version specified as well?

Collaborator (Author): Let me check.
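
On the version-pin question: the dependencies entries are passed to pip install verbatim, so standard requirement specifiers should flow through unchanged. A hypothetical config sketch (versions are illustrative only):

    dependencies:
      - pandas==2.2.2
      - duckdb>=1.0.0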
