
Commit

fix: config file paths for src, tgt also for retro
irinaespejo committed May 24, 2024
1 parent f1a0b97 commit 6725e6d
Showing 1 changed file with 41 additions and 14 deletions.
55 changes: 41 additions & 14 deletions src/rxn/onmt_utils/train_command.py
@@ -13,6 +13,28 @@
logger.addHandler(logging.NullHandler())


def get_paths_src_tgt(dir: PathLike, model_task: str) -> Tuple[str, str, str, str]:
    # the reaction is A -> B regardless of the direction (forward or retro)
    if model_task == "forward":
        A = "precursors"
        B = "products"
    elif model_task == "retro":
        A = "products"
        B = "precursors"
    else:
        raise ValueError(
            f"Argument model_task can only be 'forward' or 'retro' but received {model_task}"
        )

    corpus_path_src = f"{dir}/data.processed.train.{A}_tokens"
    corpus_path_tgt = f"{dir}/data.processed.train.{B}_tokens"
    valid_path_src = f"{dir}/data.processed.validation.{A}_tokens"
    valid_path_tgt = f"{dir}/data.processed.validation.{B}_tokens"

    return corpus_path_src, corpus_path_tgt, valid_path_src, valid_path_tgt

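For context, a minimal usage sketch of the new helper (the run directory below is hypothetical, not part of the commit): for a forward model the precursors files are the source and the products files the target; for a retro model the mapping is reversed.

from rxn.onmt_utils.train_command import get_paths_src_tgt

# Hypothetical run directory, for illustration only.
paths = get_paths_src_tgt(dir="/runs/exp1", model_task="forward")
# -> ("/runs/exp1/data.processed.train.precursors_tokens",
#     "/runs/exp1/data.processed.train.products_tokens",
#     "/runs/exp1/data.processed.validation.precursors_tokens",
#     "/runs/exp1/data.processed.validation.products_tokens")
# With model_task="retro", precursors and products swap roles.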

class RxnCommand(Flag):
"""
Flag indicating which command(s) the parameters relate to.
@@ -106,12 +128,14 @@ def __init__(
        command_type: RxnCommand,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        model_task: str,
        **kwargs: Any,
    ):
        self._command_type = command_type
        self._no_gpu = no_gpu
        self._data_weights = data_weights
        self._kwargs = kwargs
        self.model_task = model_task

    def _build_cmd(self) -> List[str]:
        """
@@ -237,22 +261,19 @@ def save_to_config_cmd(self, config_file_path: PathLike) -> None:
train_config["save_data"] = str(path_save_prepr_data)
# TODO: update to > 1 corpus
train_config["data"] = {"corpus_1": {}, "valid": {}}
train_config["data"]["corpus_1"]["path_src"] = str(
path_save_prepr_data.parent.parent
/ "data.processed.train.precursors_tokens"
)
train_config["data"]["corpus_1"]["path_tgt"] = str(
path_save_prepr_data.parent.parent / "data.processed.train.products_tokens"
)
train_config["data"]["valid"]["path_src"] = str(
path_save_prepr_data.parent.parent
/ "data.processed.validation.precursors_tokens"
)
train_config["data"]["valid"]["path_tgt"] = str(
path_save_prepr_data.parent.parent
/ "data.processed.validation.products_tokens"

        # Get the data file paths. Caution: they depend on the model task, because
        # the files preprocessed with OpenNMT v3.5.1 are not fully processed as they
        # were with earlier versions.
        corpus_path_src, corpus_path_tgt, valid_path_src, valid_path_tgt = (
            get_paths_src_tgt(
                dir=path_save_prepr_data.parent.parent, model_task=self.model_task
            )
        )

train_config["data"]["corpus_1"]["path_src"] = corpus_path_src
train_config["data"]["corpus_1"]["path_tgt"] = corpus_path_tgt
train_config["data"]["valid"]["path_src"] = valid_path_src
train_config["data"]["valid"]["path_tgt"] = valid_path_tgt

train_config["src_vocab"] = str(
train_config["src_vocab"]
) # avoid posix bad format in yaml
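For reference, a minimal sketch of the "data" section that this builds and then dumps to the YAML training config, in the layout OpenNMT-py expects. The run directory below is hypothetical and the example assumes a forward model; for a retro model the precursors and products files swap between source and target.

# Hypothetical illustration of the resulting "data" section (forward model):
expected_data_section = {
    "corpus_1": {
        "path_src": "/runs/exp1/data.processed.train.precursors_tokens",
        "path_tgt": "/runs/exp1/data.processed.train.products_tokens",
    },
    "valid": {
        "path_src": "/runs/exp1/data.processed.validation.precursors_tokens",
        "path_tgt": "/runs/exp1/data.processed.validation.products_tokens",
    },
}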
@@ -297,6 +318,7 @@ def train(
        word_vec_size: int,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        model_task: str,
        keep_checkpoint: int = -1,
    ) -> "OnmtTrainCommand":
        return cls(
@@ -319,6 +341,7 @@
            transformer_ff=transformer_ff,
            warmup_steps=warmup_steps,
            word_vec_size=word_vec_size,
            model_task=model_task,
        )

    @classmethod
@@ -333,6 +356,7 @@ def continue_training(
        train_steps: int,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        model_task: str,
        keep_checkpoint: int = -1,
    ) -> "OnmtTrainCommand":
        return cls(
@@ -348,6 +372,7 @@
            seed=seed,
            train_from=train_from,
            train_steps=train_steps,
            model_task=model_task,
        )

    @classmethod
@@ -366,6 +391,7 @@ def finetune(
        data_weights: Tuple[int, ...],
        report_every: int,
        save_checkpoint_steps: int,
        model_task: str,
        keep_checkpoint: int = -1,
        hidden_size: Optional[int] = None,
    ) -> "OnmtTrainCommand":
@@ -396,6 +422,7 @@
            warmup_steps=warmup_steps,
            report_every=report_every,
            save_checkpoint_steps=save_checkpoint_steps,
            model_task=model_task,
        )


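With this change, every OnmtTrainCommand factory takes the task explicitly. A hypothetical call-site sketch, with all pre-existing arguments elided because their names are not shown in this diff; only the newly required keyword is spelled out:

# Hypothetical call site (not part of the commit):
# cmd = OnmtTrainCommand.train(
#     ...,                  # existing hyperparameters, unchanged by this commit
#     no_gpu=False,
#     data_weights=(1,),
#     model_task="retro",   # or "forward"
# )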
