diff --git a/spotify_tensorflow/luigi/tensorflow_task.py b/spotify_tensorflow/luigi/tensorflow_task.py index f94f2a0..db62959 100644 --- a/spotify_tensorflow/luigi/tensorflow_task.py +++ b/spotify_tensorflow/luigi/tensorflow_task.py @@ -44,6 +44,9 @@ class TensorFlowTask(luigi.Task): to the latest stable version. See https://cloud.google.com/ml/docs/concepts/runtime-version-list for a list of accepted versions. + python_version = None The Google Cloud AI Platform python version for this job. See + https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training + for more information. scale_tier = None Specifies the machine types, the number of replicas for workers and parameter servers. SCALE_TIER must be one of: basic, basic-gpu, basic-tpu, custom, premium-1, standard-1. @@ -74,6 +77,12 @@ class TensorFlowTask(luigi.Task): runtime_version = luigi.Parameter(default=None, description="The Google Cloud AI Platform runtime version " "for this job.") + python_version = luigi.Parameter( + default=None, + description="The Google Cloud AI Platform python version for this job. See " + "https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training" + "for more information." + ) scale_tier = luigi.Parameter(default=None, description="Specifies the machine types, the number of replicas " "for workers and parameter servers.") @@ -126,6 +135,8 @@ def _mk_cloud_params(self): params.append("--stream-logs") # makes the execution "blocking" if self.runtime_version: params.append("--runtime-version=%s" % self.runtime_version) + if self.python_version: + params.append("--python-version=%s" % self.python_version) if self.scale_tier: params.append("--scale-tier=%s" % self.scale_tier) return params