import logging
import shutil
import uuid
from pathlib import Path
from typing import Dict, List, Optional

from redis import Redis
from rq import Queue, get_current_job
from rq.job import Job

from orchestrator.core.calculator import TsunamiCalculator
from orchestrator.core.config import MASTER_PIPELINE, MODEL_DIR
from orchestrator.models.schemas import EarthquakeInput, JobStatus
from orchestrator.utils.file_utils import setup_workspace
from orchestrator.utils.processing import process_step
from orchestrator.utils.system import check_dependencies

logger = logging.getLogger(__name__)


class TSDHNQueue:
    """Thin wrapper around an RQ queue for TSDHN tsunami simulation jobs."""

    def __init__(self, redis_conn: Redis):
        self.redis = redis_conn
        self.queue = Queue("tsdhn_queue", connection=redis_conn)

    def enqueue_job(
        self, data: EarthquakeInput, skip_steps: Optional[List[str]] = None
    ) -> str:
        """Validate the request, enqueue a simulation job, and return its job ID."""
        skip_steps = skip_steps or []
        self._validate_skip_steps(skip_steps)
        try:
            job_id = str(uuid.uuid4())
            self.queue.enqueue(
                execute_pipeline,
                data.model_dump(),
                skip_steps,
                job_id=job_id,
                job_timeout="2h",
                meta={
                    "status": JobStatus.QUEUED.value,
                    "details": "Initializing simulation pipeline",
                    "data": data.model_dump(),
                },
                result_ttl=86400,  # keep the finished job result for 24 hours
            )
            return job_id
        except Exception as e:
            logger.exception("Job enqueue failed")
            raise RuntimeError(f"Failed to enqueue job: {str(e)}") from e

    def get_job_status(self, job_id: str) -> Dict:
        """Map the RQ job state and stored metadata to the API status payload."""
        try:
            job = Job.fetch(job_id, connection=self.redis)
            status_map = {
                "queued": JobStatus.QUEUED.value,
                "started": JobStatus.RUNNING.value,
                "finished": JobStatus.COMPLETED.value,
                "failed": JobStatus.FAILED.value,
            }
            return {
                "status": status_map.get(job.get_status(), JobStatus.QUEUED.value),
                "calculation": job.meta.get("calculation"),
                "travel_times": job.meta.get("travel_times"),
                "details": job.meta.get("details"),
                "error": job.meta.get("error"),
                "created_at": job.created_at.isoformat() if job.created_at else None,
                "started_at": job.started_at.isoformat() if job.started_at else None,
                "ended_at": job.ended_at.isoformat() if job.ended_at else None,
                "download_url": f"/job-result/{job_id}"
                if job.meta.get("status") == JobStatus.COMPLETED.value
                else None,
            }
        except Exception as e:
            logger.error(f"Invalid job ID {job_id}: {str(e)}")
            raise ValueError("Invalid or expired job ID") from e

    def is_redis_connected(self) -> bool:
        """Return True if the Redis backend responds to a ping."""
        try:
            return self.redis.ping()
        except Exception:
            return False

    @staticmethod
    def _validate_skip_steps(skip_steps: List[str]):
        """Reject any skip-step names that are not defined in MASTER_PIPELINE."""
        valid_steps = {step.name for step in MASTER_PIPELINE}
        invalid = set(skip_steps) - valid_steps
        if invalid:
            raise ValueError(f"Invalid steps to skip: {', '.join(invalid)}")


def execute_pipeline(data_dict: dict, skip_steps: List[str]):
    """Worker-side pipeline executor.

    Defined at module level so RQ workers can import it by dotted path.
    """
    job = get_current_job()
    job_id = job.id
    repo_root = Path(__file__).resolve().parent.parent.parent
    work_dir = repo_root / "jobs" / job_id
    data = EarthquakeInput(**data_dict)
    calculator = TsunamiCalculator()
    try:
        # Prepare a per-job workspace from the model directory
        setup_workspace(repo_root / MODEL_DIR, work_dir)

        def update_meta(details: str, **kwargs):
            job.meta.update({"details": details, **kwargs})
            job.save_meta()

        # Phase 1: Initial calculations
        update_meta("Running earthquake calculations")
        calc_result = calculator.calculate_earthquake_parameters(data, work_dir)
        update_meta(
            "Earthquake calculations complete",
            calculation=calc_result.dict(),
            status=JobStatus.RUNNING.value,
        )

        # Phase 2: Tsunami travel times
        update_meta("Calculating tsunami travel times")
        tsunami_result = calculator.calculate_tsunami_travel_times(data)
        update_meta(
            "Tsunami calculations complete",
            travel_times=tsunami_result.dict(),
            status=JobStatus.RUNNING.value,
        )

        # Phase 3: Main simulation pipeline
        check_dependencies()
        for step in MASTER_PIPELINE:
            if step.name in skip_steps:
                logger.info(f"Skipping step: {step.name}")
                continue
            update_meta(f"Processing {step.name}")
            step_dir = work_dir / step.working_dir if step.working_dir else work_dir
            step_dir.mkdir(parents=True, exist_ok=True)
            process_step(step, step_dir)

        update_meta(
            "Simulation completed successfully",
            status=JobStatus.COMPLETED.value,
            download_url=f"/job-result/{job_id}",
        )
        return {"status": "completed"}
    except Exception as e:
        logger.exception(f"Pipeline failed for job {job_id}")
        # Clean up the partial workspace and record the failure in job metadata
        if work_dir.exists():
            shutil.rmtree(work_dir, ignore_errors=True)
        if job:
            job.meta.update(
                {
                    "status": JobStatus.FAILED.value,
                    "error": f"{type(e).__name__}: {str(e)}",
                    "details": "Pipeline failed - check error logs",
                }
            )
            job.save_meta()
        raise


# Module-level queue instance; connects to a local Redis server on the default port.
tsdhn_queue = TSDHNQueue(Redis(host="localhost", port=6379, db=0))
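
# Hypothetical usage sketch (not part of the original module): it assumes a Redis
# server is reachable on localhost:6379, an RQ worker is listening on "tsdhn_queue",
# and `event_params` is a dict matching the EarthquakeInput schema (its fields are
# not shown here). It illustrates the intended call sequence: enqueue a job, then
# poll its status until a download URL appears.
#
#     event_params = {...}  # fill in real earthquake source parameters
#     job_id = tsdhn_queue.enqueue_job(EarthquakeInput(**event_params))
#     status = tsdhn_queue.get_job_status(job_id)
#     print(status["status"], status["details"], status["download_url"])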