From 91ee0108fba8442e84900e8379c0abeaab483539 Mon Sep 17 00:00:00 2001
From: Googler <nobody@google.com>
Date: Tue, 26 Mar 2024 18:55:06 -0700
Subject: [PATCH] chore(components): Add test machine spec support to
 `preview.llm` pipelines

PiperOrigin-RevId: 619378459
---
 .../_implementation/llm/function_based.py     | 33 ++++++++++++++-----
 .../_implementation/llm/validate_pipeline.py  |  9 ++++-
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/function_based.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/function_based.py
index 099d7b4f96c..f0e82152dc7 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/function_based.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/function_based.py
@@ -22,7 +22,7 @@
 
 @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
 def resolve_machine_spec(
-    accelerator_type: str = '',
+    accelerator_type: str = 'GPU',
     use_test_spec: bool = False,
 ) -> NamedTuple(
     'MachineSpec',
@@ -37,7 +37,8 @@ def resolve_machine_spec(
     accelerator_type: One of 'TPU' or 'GPU'. If 'TPU' is specified, tuning
       components run in europe-west4. Otherwise tuning components run in
       us-central1 on GPUs. Default is 'GPU'.
-    use_test_spec: Whether to use a lower resource machine for testing.
+    use_test_spec: Whether to use a lower resource machine for testing. If True,
+      a machine with the specified `accelerator_type` is provisioned.
 
   Returns:
     Machine spec.
@@ -61,14 +62,27 @@ def resolve_machine_spec(
           accelerator_count=32,
           tuning_location='europe-west4',
       )
-    else:
+    elif accelerator_type == 'GPU':
       return outputs(
           machine_type='a2-highgpu-1g',
           accelerator_type='NVIDIA_TESLA_A100',
           accelerator_count=1,
           tuning_location='us-central1',
       )
-  elif accelerator_type == 'TPU':
+    elif accelerator_type == 'CPU':
+      return outputs(
+          machine_type='e2-standard-16',
+          accelerator_type='ACCELERATOR_TYPE_UNSPECIFIED',
+          accelerator_count=0,
+          tuning_location='us-central1',
+      )
+    else:
+      raise ValueError(
+          f'Unsupported test accelerator_type {accelerator_type}. Must be one '
+          'of TPU, GPU or CPU.'
+      )
+
+  if accelerator_type == 'TPU':
     return outputs(
         machine_type='cloud-tpu',
         accelerator_type='TPU_V3',
@@ -82,10 +96,11 @@ def resolve_machine_spec(
         accelerator_count=8,
         tuning_location='us-central1',
     )
-  raise ValueError(
-      f'Unsupported accelerator type {accelerator_type}. Must be one of'
-      'TPU or GPU.'
-  )
+  else:
+    raise ValueError(
+        f'Unsupported accelerator_type {accelerator_type}. Must be one of'
+        'TPU or GPU.'
+    )
 
 
 @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
@@ -114,7 +129,7 @@ def resolve_refined_image_uri(
   Raises:
     ValueError: if an unsupported accelerator type is provided.
   """
-  if not accelerator_type:
+  if not accelerator_type or accelerator_type == 'ACCELERATOR_TYPE_UNSPECIFIED':
     accelerator_postfix = 'cpu'
   elif 'TPU' in accelerator_type:
     accelerator_postfix = 'tpu'
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/validate_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/validate_pipeline.py
index 232b20af52f..44623fb2c2d 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/validate_pipeline.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/validate_pipeline.py
@@ -88,7 +88,14 @@ def validate_pipeline(
           f' {supported_pipeline_regions}.'
       )
 
-    valid_cmek_config = location == 'us-central1' and accelerator_type == 'GPU'
+    valid_cmek_accelerator_types = {
+        'GPU',
+        'CPU',  # Only used for testing.
+    }
+    valid_cmek_config = (
+        location == 'us-central1'
+        and accelerator_type in valid_cmek_accelerator_types
+    )
     if encryption_spec_key_name and not valid_cmek_config:
       raise ValueError(
           'encryption_spec_key_name (CMEK) is only supported for GPU training'