From 810d7de08f6af419c323be781a627eff05debcc6 Mon Sep 17 00:00:00 2001 From: drachenbach Date: Tue, 27 Aug 2019 16:52:03 +0200 Subject: [PATCH] Add python-version argument to tensorflow-task (#210) --- spotify_tensorflow/luigi/tensorflow_task.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/spotify_tensorflow/luigi/tensorflow_task.py b/spotify_tensorflow/luigi/tensorflow_task.py index f94f2a0..db62959 100644 --- a/spotify_tensorflow/luigi/tensorflow_task.py +++ b/spotify_tensorflow/luigi/tensorflow_task.py @@ -44,6 +44,9 @@ class TensorFlowTask(luigi.Task): to the latest stable version. See https://cloud.google.com/ml/docs/concepts/runtime-version-list for a list of accepted versions. + python_version = None The Google Cloud AI Platform python version for this job. See + https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training + for more information. scale_tier = None Specifies the machine types, the number of replicas for workers and parameter servers. SCALE_TIER must be one of: basic, basic-gpu, basic-tpu, custom, premium-1, standard-1. @@ -74,6 +77,12 @@ class TensorFlowTask(luigi.Task): runtime_version = luigi.Parameter(default=None, description="The Google Cloud AI Platform runtime version " "for this job.") + python_version = luigi.Parameter( + default=None, + description="The Google Cloud AI Platform python version for this job. See " + "https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training" + "for more information." + ) scale_tier = luigi.Parameter(default=None, description="Specifies the machine types, the number of replicas " "for workers and parameter servers.") @@ -126,6 +135,8 @@ def _mk_cloud_params(self): params.append("--stream-logs") # makes the execution "blocking" if self.runtime_version: params.append("--runtime-version=%s" % self.runtime_version) + if self.python_version: + params.append("--python-version=%s" % self.python_version) if self.scale_tier: params.append("--scale-tier=%s" % self.scale_tier) return params