From e2b8f68ddf84bf55f350461aaf1c801e9a152ac1 Mon Sep 17 00:00:00 2001 From: Christian Yarros Date: Fri, 14 Jun 2024 10:13:50 -0400 Subject: [PATCH] add generation_config and safety_settings to google cloud multimodal model operators (#40126) --- .../cloud/hooks/vertex_ai/generative_model.py | 18 ++++++++++--- .../operators/vertex_ai/generative_model.py | 16 +++++++++++ airflow/providers/google/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- .../hooks/vertex_ai/test_generative_model.py | 27 +++++++++++++++++-- .../vertex_ai/test_generative_model.py | 24 +++++++++++++++++ .../example_vertex_ai_generative_model.py | 13 +++++++++ 7 files changed, 95 insertions(+), 7 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index c71759c890901..eb3db0f3c6c79 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -141,6 +141,8 @@ def prompt_multimodal_model( self, prompt: str, location: str, + generation_config: dict | None = None, + safety_settings: dict | None = None, pretrained_model: str = "gemini-pro", project_id: str = PROVIDE_PROJECT_ID, ) -> str: @@ -149,17 +151,21 @@ def prompt_multimodal_model( :param prompt: Required. Inputs or queries that a user or a program gives to the Multi-modal model, in order to elicit a specific response. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param generation_config: Optional. Generation configuration settings. + :param safety_settings: Optional. Per request settings for blocking unsafe content. :param pretrained_model: By default uses the pre-trained model `gemini-pro`, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can output text and code. - :param location: Required. The ID of the Google Cloud location that the service belongs to. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. """ vertexai.init(project=project_id, location=location, credentials=self.get_credentials()) model = self.get_generative_model(pretrained_model) - response = model.generate_content(prompt) + response = model.generate_content( + contents=[prompt], generation_config=generation_config, safety_settings=safety_settings + ) return response.text @@ -170,6 +176,8 @@ def prompt_multimodal_model_with_media( location: str, media_gcs_path: str, mime_type: str, + generation_config: dict | None = None, + safety_settings: dict | None = None, pretrained_model: str = "gemini-pro-vision", project_id: str = PROVIDE_PROJECT_ID, ) -> str: @@ -178,6 +186,8 @@ def prompt_multimodal_model_with_media( :param prompt: Required. Inputs or queries that a user or a program gives to the Multi-modal model, in order to elicit a specific response. + :param generation_config: Optional. Generation configuration settings. + :param safety_settings: Optional. Per request settings for blocking unsafe content. :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can @@ -192,6 +202,8 @@ def prompt_multimodal_model_with_media( model = self.get_generative_model(pretrained_model) part = self.get_generative_model_part(media_gcs_path, mime_type) - response = model.generate_content([prompt, part]) + response = model.generate_content( + contents=[prompt, part], generation_config=generation_config, safety_settings=safety_settings + ) return response.text diff --git a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index da1436a6ab3cc..a42b00c6777da 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -187,6 +187,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator): service belongs to (templated). :param prompt: Required. Inputs or queries that a user or a program gives to the Multi-modal model, in order to elicit a specific response (templated). + :param generation_config: Optional. Generation configuration settings. + :param safety_settings: Optional. Per request settings for blocking unsafe content. :param pretrained_model: By default uses the pre-trained model `gemini-pro`, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can @@ -210,6 +212,8 @@ def __init__( project_id: str, location: str, prompt: str, + generation_config: dict | None = None, + safety_settings: dict | None = None, pretrained_model: str = "gemini-pro", gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -219,6 +223,8 @@ def __init__( self.project_id = project_id self.location = location self.prompt = prompt + self.generation_config = generation_config + self.safety_settings = safety_settings self.pretrained_model = pretrained_model self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -232,6 +238,8 @@ def execute(self, context: Context): project_id=self.project_id, location=self.location, prompt=self.prompt, + generation_config=self.generation_config, + safety_settings=self.safety_settings, pretrained_model=self.pretrained_model, ) @@ -251,6 +259,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator): service belongs to (templated). :param prompt: Required. Inputs or queries that a user or a program gives to the Multi-modal model, in order to elicit a specific response (templated). + :param generation_config: Optional. Generation configuration settings. + :param safety_settings: Optional. Per request settings for blocking unsafe content. :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can @@ -279,6 +289,8 @@ def __init__( prompt: str, media_gcs_path: str, mime_type: str, + generation_config: dict | None = None, + safety_settings: dict | None = None, pretrained_model: str = "gemini-pro-vision", gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -288,6 +300,8 @@ def __init__( self.project_id = project_id self.location = location self.prompt = prompt + self.generation_config = generation_config + self.safety_settings = safety_settings self.pretrained_model = pretrained_model self.media_gcs_path = media_gcs_path self.mime_type = mime_type @@ -303,6 +317,8 @@ def execute(self, context: Context): project_id=self.project_id, location=self.location, prompt=self.prompt, + generation_config=self.generation_config, + safety_settings=self.safety_settings, pretrained_model=self.pretrained_model, media_gcs_path=self.media_gcs_path, mime_type=self.mime_type, diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 205e9f398edd4..376dcb92f2a8d 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -113,7 +113,7 @@ dependencies: - google-api-python-client>=2.0.2 - google-auth>=2.29.0 - google-auth-httplib2>=0.0.1 - - google-cloud-aiplatform>=1.42.1 + - google-cloud-aiplatform>=1.54.0 - google-cloud-automl>=2.12.0 # google-cloud-bigquery version 3.21.0 introduced a performance enhancement in QueryJob.result(), # which has led to backward compatibility issues diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 6c7ca749f7d86..6fdf5d9c2988d 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -593,7 +593,7 @@ "google-api-python-client>=2.0.2", "google-auth-httplib2>=0.0.1", "google-auth>=2.29.0", - "google-cloud-aiplatform>=1.42.1", + "google-cloud-aiplatform>=1.54.0", "google-cloud-automl>=2.12.0", "google-cloud-batch>=0.13.0", "google-cloud-bigquery-datatransfer>=3.13.0", diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py index 1899222174505..2308903485c19 100644 --- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py @@ -23,6 +23,8 @@ # For no Pydantic environment, we need to skip the tests pytest.importorskip("google.cloud.aiplatform_v1") +vertexai = pytest.importorskip("vertexai.generative_models") +from vertexai.generative_models import HarmBlockThreshold, HarmCategory from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import ( GenerativeModelHook, @@ -45,6 +47,17 @@ TEST_TEXT_EMBEDDING_MODEL = "" TEST_MULTIMODAL_PRETRAINED_MODEL = "gemini-pro" +TEST_SAFETY_SETTINGS = { + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, +} +TEST_GENERATION_CONFIG = { + "max_output_tokens": TEST_MAX_OUTPUT_TOKENS, + "top_p": TEST_TOP_P, + "temperature": TEST_TEMPERATURE, +} TEST_MULTIMODAL_VISION_MODEL = "gemini-pro-vision" TEST_VISION_PROMPT = "In 10 words or less, describe this content." @@ -104,10 +117,16 @@ def test_prompt_multimodal_model(self, mock_model) -> None: project_id=GCP_PROJECT, location=GCP_LOCATION, prompt=TEST_PROMPT, + generation_config=TEST_GENERATION_CONFIG, + safety_settings=TEST_SAFETY_SETTINGS, pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL, ) mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL) - mock_model.return_value.generate_content.assert_called_once_with(TEST_PROMPT) + mock_model.return_value.generate_content.assert_called_once_with( + contents=[TEST_PROMPT], + generation_config=TEST_GENERATION_CONFIG, + safety_settings=TEST_SAFETY_SETTINGS, + ) @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model_part")) @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model")) @@ -116,6 +135,8 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None project_id=GCP_PROJECT, location=GCP_LOCATION, prompt=TEST_VISION_PROMPT, + generation_config=TEST_GENERATION_CONFIG, + safety_settings=TEST_SAFETY_SETTINGS, pretrained_model=TEST_MULTIMODAL_VISION_MODEL, media_gcs_path=TEST_MEDIA_GCS_PATH, mime_type=TEST_MIME_TYPE, @@ -124,5 +145,7 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None mock_part.assert_called_once_with(TEST_MEDIA_GCS_PATH, TEST_MIME_TYPE) mock_model.return_value.generate_content.assert_called_once_with( - [TEST_VISION_PROMPT, mock_part.return_value] + contents=[TEST_VISION_PROMPT, mock_part.return_value], + generation_config=TEST_GENERATION_CONFIG, + safety_settings=TEST_SAFETY_SETTINGS, ) diff --git a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py index c9c23019f5aab..a5afd8ac2766f 100644 --- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py +++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py @@ -22,6 +22,8 @@ # For no Pydantic environment, we need to skip the tests pytest.importorskip("google.cloud.aiplatform_v1") +vertexai = pytest.importorskip("vertexai.generative_models") +from vertexai.generative_models import HarmBlockThreshold, HarmCategory from airflow.providers.google.cloud.operators.vertex_ai.generative_model import ( GenerateTextEmbeddingsOperator, @@ -112,12 +114,21 @@ class TestVertexAIPromptMultimodalModelOperator: def test_execute(self, mock_hook): prompt = "In 10 words or less, what is Apache Airflow?" pretrained_model = "gemini-pro" + safety_settings = { + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + } + generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0} op = PromptMultimodalModelOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, prompt=prompt, + generation_config=generation_config, + safety_settings=safety_settings, pretrained_model=pretrained_model, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -131,6 +142,8 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT, location=GCP_LOCATION, prompt=prompt, + generation_config=generation_config, + safety_settings=safety_settings, pretrained_model=pretrained_model, ) @@ -142,12 +155,21 @@ def test_execute(self, mock_hook): vision_prompt = "In 10 words or less, describe this content." media_gcs_path = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg" mime_type = "image/jpeg" + safety_settings = { + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + } + generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0} op = PromptMultimodalModelWithMediaOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, prompt=vision_prompt, + generation_config=generation_config, + safety_settings=safety_settings, pretrained_model=pretrained_model, media_gcs_path=media_gcs_path, mime_type=mime_type, @@ -163,6 +185,8 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT, location=GCP_LOCATION, prompt=vision_prompt, + generation_config=generation_config, + safety_settings=safety_settings, pretrained_model=pretrained_model, media_gcs_path=media_gcs_path, mime_type=mime_type, diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py index 101cedaf7e9e5..95c141c7af33d 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py @@ -25,6 +25,8 @@ import os from datetime import datetime +from vertexai.generative_models import HarmBlockThreshold, HarmCategory + from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.vertex_ai.generative_model import ( GenerateTextEmbeddingsOperator, @@ -44,6 +46,13 @@ VISION_PROMPT = "In 10 words or less, describe this content." MEDIA_GCS_PATH = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg" MIME_TYPE = "image/jpeg" +GENERATION_CONFIG = {"max_output_tokens": 256, "top_p": 0.95, "temperature": 0.0} +SAFETY_SETTINGS = { + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, +} with DAG( dag_id=DAG_ID, @@ -79,6 +88,8 @@ project_id=PROJECT_ID, location=REGION, prompt=PROMPT, + generation_config=GENERATION_CONFIG, + safety_settings=SAFETY_SETTINGS, pretrained_model=MULTIMODAL_MODEL, ) # [END how_to_cloud_vertex_ai_prompt_multimodal_model_operator] @@ -89,6 +100,8 @@ project_id=PROJECT_ID, location=REGION, prompt=VISION_PROMPT, + generation_config=GENERATION_CONFIG, + safety_settings=SAFETY_SETTINGS, pretrained_model=MULTIMODAL_VISION_MODEL, media_gcs_path=MEDIA_GCS_PATH, mime_type=MIME_TYPE,