|
20 | 20 | Description:
|
21 | 21 | Model Class
|
22 | 22 | """
|
23 |
| - |
24 | 23 | import time
|
25 | 24 | import json
|
26 | 25 | import logging
|
@@ -251,23 +250,65 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
|
251 | 250 | response["error"] = msg
|
252 | 251 | return response
|
253 | 252 |
|
254 |
| - def check_finetune_status(self): |
| 253 | + def check_finetune_status(self, after_epoch: Optional[int] = None): |
255 | 254 | """Check the status of the FineTune model.
|
256 | 255 |
|
| 256 | + Args: |
| 257 | + after_epoch (Optional[int], optional): status after a given epoch. Defaults to None. |
| 258 | +
|
257 | 259 | Raises:
|
258 | 260 | Exception: If the 'TEAM_API_KEY' is not provided.
|
259 | 261 |
|
260 | 262 | Returns:
|
261 |
| - str: The status of the FineTune model. |
| 263 | + FinetuneStatus: The status of the FineTune model. |
262 | 264 | """
|
| 265 | + from aixplain.enums.asset_status import AssetStatus |
| 266 | + from aixplain.modules.finetune.status import FinetuneStatus |
263 | 267 | headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
|
| 268 | + resp = None |
264 | 269 | try:
|
265 |
| - url = urljoin(self.backend_url, f"sdk/models/{self.id}") |
| 270 | + url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs") |
266 | 271 | logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}")
|
267 | 272 | r = _request_with_retry("get", url, headers=headers)
|
268 | 273 | resp = r.json()
|
269 |
| - status = resp["status"] |
270 |
| - logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status}.") |
| 274 | + finetune_status = AssetStatus(resp["finetuneStatus"]) |
| 275 | + model_status = AssetStatus(resp["modelStatus"]) |
| 276 | + logs = sorted(resp["logs"], key=lambda x: float(x["epoch"])) |
| 277 | + |
| 278 | + target_epoch = None |
| 279 | + if after_epoch is not None: |
| 280 | + logs = [log for log in logs if float(log["epoch"]) > after_epoch] |
| 281 | + if len(logs) > 0: |
| 282 | + target_epoch = float(logs[0]["epoch"]) |
| 283 | + elif len(logs) > 0: |
| 284 | + target_epoch = float(logs[-1]["epoch"]) |
| 285 | + |
| 286 | + if target_epoch is not None: |
| 287 | + log = None |
| 288 | + for log_ in logs: |
| 289 | + if int(log_["epoch"]) == target_epoch: |
| 290 | + if log is None: |
| 291 | + log = log_ |
| 292 | + else: |
| 293 | + if log_["trainLoss"] is not None: |
| 294 | + log["trainLoss"] = log_["trainLoss"] |
| 295 | + if log_["evalLoss"] is not None: |
| 296 | + log["evalLoss"] = log_["evalLoss"] |
| 297 | + |
| 298 | + status = FinetuneStatus( |
| 299 | + status=finetune_status, |
| 300 | + model_status=model_status, |
| 301 | + epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None, |
| 302 | + training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None, |
| 303 | + validation_loss=float(log["evalLoss"]) if "evalLoss" in log and log["evalLoss"] is not None else None, |
| 304 | + ) |
| 305 | + else: |
| 306 | + status = FinetuneStatus( |
| 307 | + status=finetune_status, |
| 308 | + model_status=model_status, |
| 309 | + ) |
| 310 | + |
| 311 | + logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.") |
271 | 312 | return status
|
272 | 313 | except Exception as e:
|
273 | 314 | message = ""
|
|
0 commit comments