Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add GenerationConfig and SafetySettings parameters to Google Cloud Multimodal Model Operators #40126

Merged
merged 4 commits into from
Jun 14, 2024
Merged

Add GenerationConfig and SafetySettings parameters to Google Cloud Multimodal Model Operators #40126

merged 4 commits into from
Jun 14, 2024

Conversation

CYarros10
Copy link
Contributor

@CYarros10 CYarros10 commented Jun 7, 2024

Gemini / Multimodal model Airflow Hooks/Operators should mimic typical usage as defined in the sample code block below. This provides fine-tuning model output via the additional configurations GenerationConfig and SafetySettings

Updated Operator Usage

    prompt_multimodal_model_task = PromptMultimodalModelOperator(
        task_id="prompt_multimodal_model_task",
        project_id=PROJECT_ID,
        location=REGION,
        prompt=PROMPT,
        generation_config=GENERATION_CONFIG,
        safety_settings=SAFETY_SETTINGS,
        pretrained_model=MULTIMODAL_MODEL,
    )

Sample Code

vertexai.init(project=project_id, location="us-central1")

model = generative_models.GenerativeModel(model_name="gemini-1.0-pro-vision-001")

# Generation config
generation_config = generative_models.GenerationConfig(
    max_output_tokens=2048, temperature=0.4, top_p=1, top_k=32
)

# Safety config
safety_config = [
    generative_models.SafetySetting(
        category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    ),
    generative_models.SafetySetting(
        category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    ),
]

image_file = Part.from_uri(
    "gs://cloud-samples-data/generative-ai/image/scones.jpg", "image/jpeg"
)

# Generate content
responses = model.generate_content(
    [image_file, "What is in this image?"],
    generation_config=generation_config,
    safety_settings=safety_config,
    stream=True,
)

^ Add meaningful description above
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in newsfragments.

@boring-cyborg boring-cyborg bot added area:providers area:system-tests provider:google Google (including GCP) related issues labels Jun 7, 2024
Copy link
Contributor

@MaksYermak MaksYermak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@amoghrajesh amoghrajesh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@potiuk potiuk merged commit e2b8f68 into apache:main Jun 14, 2024
108 checks passed
romsharon98 pushed a commit to romsharon98/airflow that referenced this pull request Jul 26, 2024
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
area:providers area:system-tests provider:google Google (including GCP) related issues
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants