Skip to content

Commit

Permalink
fix: correct format of config.yaml to train
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed Apr 16, 2024
1 parent 7dc2c85 commit d568d0f
Showing 1 changed file with 71 additions and 2 deletions.
73 changes: 71 additions & 2 deletions src/rxn/onmt_utils/train_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
Arg("warmup_steps", None, RxnCommand.TF),
Arg("word_vec_size", None, RxnCommand.T),
]
# TODO(Irina): add the new OpenNMT-py v3.5.1 arguments (e.g. lora_layers, quant_layers) if necessary


class OnmtTrainCommand:
Expand Down Expand Up @@ -173,15 +174,83 @@ def cmd(self) -> List[str]:
"""
return self._build_cmd()

def is_valid_kwarg_value(self, kwarg, value) -> bool:
    """Check that ``kwarg`` is a known training argument with a usable value.

    NOTE: upgrade to v.3.5.1.
    Per the original author's note, most of these checks come from
    ``self._build_cmd()``. That method could in theory be deprecated, but it
    stays until that is 100% safe; here we just need the checks and do not
    construct a command.
    TODO: assess deprecation of self._build_cmd()

    Args:
        kwarg: argument name, compared against the ``key`` of every entry in
            ``ONMT_TRAIN_ARGS``.
        value: value supplied for that argument.

    Returns:
        ``True`` when every check passes. Failures are reported by raising,
        so this method never returns ``False``.

    Raises:
        NameError: if ``kwarg`` does not match any entry of ``ONMT_TRAIN_ARGS``.
        ValueError: if the argument is not needed for ``self._command_type``,
            or if it has no default and ``value`` is ``None``.
    """
    # Check if argument is in ONMT_TRAIN_ARGS. No ``break``: when several
    # entries share a key, the last one is used.
    onmt_train_kwarg = None
    for arg in ONMT_TRAIN_ARGS:
        if arg.key == kwarg:
            onmt_train_kwarg = arg

    if onmt_train_kwarg is None:
        # Fail right away with an explicit message instead of continuing
        # with an unbound variable.
        raise NameError(f"Argument {kwarg} doesn't exist in ONMT_TRAIN_ARGS.")

    # Check argument is needed for command
    if self._command_type not in onmt_train_kwarg.needed_for:
        raise ValueError(
            f'"{value}" value given for arg {kwarg}, but not necessary for command {self._command_type}'
        )

    # Check if argument has no default and needs a value
    if onmt_train_kwarg.default is None and value is None:
        raise ValueError(f"No value given for {kwarg} and needs one.")

    return True

def save_to_config_cmd(self, config_file_path: PathLike) -> None:
    """Save the training config to a YAML file.

    See https://opennmt.net/OpenNMT-py/quickstart.html (Step 2: Train).
    Only the arguments present in ``self._kwargs`` are written; defaults
    that were not specified on the command line are not included.

    Args:
        config_file_path: destination of the YAML file; it is created or
            overwritten.

    Raises:
        KeyError: if ``data``, ``src_vocab``, ``tgt_vocab`` or ``save_model``
            is missing from ``self._kwargs``.
        ValueError: if ``is_valid_kwarg_value`` returns a falsy result for
            an argument.
        Any exception raised by ``is_valid_kwarg_value`` is propagated.
    """
    # Imported locally so this method does not depend on the module already
    # importing ``Path``.
    from pathlib import Path

    train_config: Dict[str, Any] = {}

    # Dump all cli arguments to dict
    for kwarg, value in self._kwargs.items():
        if not self.is_valid_kwarg_value(kwarg, value):
            raise ValueError(f'"Value {value}" for argument {kwarg} is invalid')
        train_config[kwarg] = value

    # Reformat the "data" argument into the nested layout of the OpenNMT-py
    # v3 YAML config: the original value becomes "save_data" and "data"
    # becomes a mapping of corpora.
    raw_data_path = train_config["data"]
    train_config["save_data"] = str(raw_data_path)

    # Accept plain strings as well as Path objects: ``.parent`` only exists
    # on paths. The token files are looked up two directory levels above the
    # "data" path.
    tokens_dir = Path(raw_data_path).parent.parent

    # TODO: update to > 1 corpus
    train_config["data"] = {
        "corpus_1": {
            "path_src": str(tokens_dir / "data.processed.train.precursors_tokens"),
            "path_tgt": str(tokens_dir / "data.processed.train.products_tokens"),
        },
        "valid": {
            "path_src": str(
                tokens_dir / "data.processed.validation.precursors_tokens"
            ),
            "path_tgt": str(
                tokens_dir / "data.processed.validation.products_tokens"
            ),
        },
    }

    # Convert to plain strings to avoid the bad posix-path format in yaml
    for path_key in ("src_vocab", "tgt_vocab", "save_model"):
        train_config[path_key] = str(train_config[path_key])

    # Dump to config.yaml. The file is only written, so "w" is enough.
    with open(config_file_path, "w", encoding="utf-8") as file:
        yaml.dump(train_config, file)

Expand Down

0 comments on commit d568d0f

Please sign in to comment.