Skip to content

Commit

Permalink
fix: correct gpu argument format onmt-py
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed Apr 30, 2024
1 parent d568d0f commit 765d97e
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/rxn/onmt_utils/train_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import yaml
from rxn.utilities.files import PathLike

Expand Down Expand Up @@ -212,6 +213,10 @@ def save_to_config_cmd(self, config_file_path: PathLike) -> None:
# See structure https://opennmt.net/OpenNMT-py/quickstart.html (Step 2: Train)
train_config: Dict[str, Any] = {}

# GPUs
if torch.cuda.is_available() and self._no_gpu is False:
train_config["gpu_ranks"] = [0]

# Dump all cli arguments to dict
for kwarg, value in self._kwargs.items():
if self.is_valid_kwarg_value(kwarg, value):
Expand Down

0 comments on commit 765d97e

Please # to comment.