Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add python-version argument to tensorflow-task #210

Merged
merged 1 commit into from
Aug 27, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions spotify_tensorflow/luigi/tensorflow_task.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,9 @@ class TensorFlowTask(luigi.Task):
to the latest stable version. See
https://cloud.google.com/ml/docs/concepts/runtime-version-list for a
list of accepted versions.
python_version = None The Google Cloud AI Platform python version for this job. See
https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training
for more information.
scale_tier = None Specifies the machine types, the number of replicas for workers and
parameter servers. SCALE_TIER must be one of:
basic, basic-gpu, basic-tpu, custom, premium-1, standard-1.
@@ -74,6 +77,12 @@ class TensorFlowTask(luigi.Task):
runtime_version = luigi.Parameter(default=None,
description="The Google Cloud AI Platform runtime version "
"for this job.")
python_version = luigi.Parameter(
default=None,
description="The Google Cloud AI Platform python version for this job. See "
"https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training"
"for more information."
)
scale_tier = luigi.Parameter(default=None,
description="Specifies the machine types, the number of replicas "
"for workers and parameter servers.")
@@ -126,6 +135,8 @@ def _mk_cloud_params(self):
params.append("--stream-logs") # makes the execution "blocking"
if self.runtime_version:
params.append("--runtime-version=%s" % self.runtime_version)
if self.python_version:
params.append("--python-version=%s" % self.python_version)
if self.scale_tier:
params.append("--scale-tier=%s" % self.scale_tier)
return params