Skip to content

Commit 39749b0

Browse files
rsamborskiriathakkar
authored and committed
feat(GenAI): Gemma2 samples for Model Garden deployments to Vertex AI endpoints (GoogleCloudPlatform#12598)
* Initial working samples * Tests ready with mocks * Updated CODEOWNERS * Fix lint and tests * Fix lint and tests * Moving samples to fix testing issues when in subfolder of generative-ai * Add missing pytest * Trailing newline
1 parent c4cc173 commit 39749b0

File tree

7 files changed

+308
-0
lines changed

7 files changed

+308
-0
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
/cdn/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
2222
/compute/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
2323
/dns/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
24+
/gemma2/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
25+
/generative_ai/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
2426
/iam/cloud-client/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
2527
/kms/**/** @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
2628
/media_cdn/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers

gemma2/gemma2_predict_gpu.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
18+
# Google Cloud project used to build the endpoint resource path below.
# os.getenv returns None when GOOGLE_CLOUD_PROJECT is not set.
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
19+
20+
21+
def gemma2_predict_gpu(ENDPOINT_REGION: str, ENDPOINT_ID: str) -> str:
    # [START generativeaionvertexai_gemma2_predict_gpu]
    """
    Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.

    Args:
        ENDPOINT_REGION: Region of the endpoint; used both for the regional
            API host name and for the endpoint resource path.
        ENDPOINT_ID: ID of the endpoint the model is deployed to.

    Returns:
        The first prediction returned by the endpoint.
    """

    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value

    # TODO(developer): Update & uncomment lines below
    # PROJECT_ID = "your-project-id"
    # ENDPOINT_REGION = "your-vertex-endpoint-region"
    # ENDPOINT_ID = "your-vertex-endpoint-id"

    # Default configuration
    config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

    # Prompt used in the prediction
    prompt = "Why is the sky blue?"

    # Encapsulate the prompt in a correct format for GPUs
    # Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
    # (named `instance` rather than `input` so the builtin is not shadowed)
    instance = {"inputs": prompt, "parameters": config}

    # Convert input message to a list of GAPIC instances for model input
    instances = [json_format.ParseDict(instance, Value())]

    # Create a client bound to the regional API host
    api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )

    # Call the Gemma2 endpoint
    gemma2_end_point = (
        f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
    )
    response = client.predict(
        endpoint=gemma2_end_point,
        instances=instances,
    )
    text_responses = response.predictions
    print(text_responses[0])

    # [END generativeaionvertexai_gemma2_predict_gpu]
    return text_responses[0]
68+
69+
70+
if __name__ == "__main__":
    # Exactly two positional arguments (region, endpoint id) are required
    # after the script name; anything else prints usage and exits with 1.
    if len(sys.argv) == 3:
        ENDPOINT_REGION, ENDPOINT_ID = sys.argv[1], sys.argv[2]
        gemma2_predict_gpu(ENDPOINT_REGION, ENDPOINT_ID)
    else:
        print(
            "Usage: python gemma2_predict_gpu.py <GEMMA2_ENDPOINT_REGION> <GEMMA2_ENDPOINT_ID>"
        )
        sys.exit(1)

gemma2/gemma2_predict_tpu.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
18+
# Google Cloud project used to build the endpoint resource path below.
# os.getenv returns None when GOOGLE_CLOUD_PROJECT is not set.
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
19+
20+
21+
def gemma2_predict_tpu(ENDPOINT_REGION: str, ENDPOINT_ID: str) -> str:
    # [START generativeaionvertexai_gemma2_predict_tpu]
    """
    Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.

    Args:
        ENDPOINT_REGION: Region of the endpoint; used both for the regional
            API host name and for the endpoint resource path.
        ENDPOINT_ID: ID of the endpoint the model is deployed to.

    Returns:
        The first prediction returned by the endpoint.
    """

    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value

    # TODO(developer): Update & uncomment lines below
    # PROJECT_ID = "your-project-id"
    # ENDPOINT_REGION = "your-vertex-endpoint-region"
    # ENDPOINT_ID = "your-vertex-endpoint-id"

    # Default configuration
    config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

    # Prompt used in the prediction
    prompt = "Why is the sky blue?"

    # Encapsulate the prompt in a correct format for TPUs: the sampling
    # settings sit at the top level, next to the prompt.
    # Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    # (named `instance` rather than `input` so the builtin is not shadowed)
    instance = {"prompt": prompt, **config}

    # Convert input message to a list of GAPIC instances for model input
    instances = [json_format.ParseDict(instance, Value())]

    # Create a client bound to the regional API host
    api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )

    # Call the Gemma2 endpoint
    gemma2_end_point = (
        f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
    )
    response = client.predict(
        endpoint=gemma2_end_point,
        instances=instances,
    )
    text_responses = response.predictions
    print(text_responses[0])

    # [END generativeaionvertexai_gemma2_predict_tpu]
    return text_responses[0]
69+
70+
71+
if __name__ == "__main__":
    # Exactly two positional arguments (region, endpoint id) are required
    # after the script name; anything else prints usage and exits with 1.
    if len(sys.argv) == 3:
        ENDPOINT_REGION, ENDPOINT_ID = sys.argv[1], sys.argv[2]
        gemma2_predict_tpu(ENDPOINT_REGION, ENDPOINT_ID)
    else:
        print(
            "Usage: python gemma2_predict_tpu.py <GEMMA2_ENDPOINT_REGION> <GEMMA2_ENDPOINT_ID>"
        )
        sys.exit(1)

gemma2/gemma2_test.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from typing import MutableSequence, Optional
17+
from unittest import mock
18+
from unittest.mock import MagicMock
19+
20+
from google.cloud.aiplatform_v1.types import prediction_service
21+
import google.protobuf.struct_pb2 as struct_pb2
22+
from google.protobuf.struct_pb2 import Value
23+
24+
from gemma2_predict_gpu import gemma2_predict_gpu
25+
from gemma2_predict_tpu import gemma2_predict_tpu
26+
27+
# Shared test fixtures.
# The GPU and TPU endpoints use different regions and IDs so the fake
# predict call can tell from the endpoint path which sample invoked it.
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")

GPU_ENDPOINT_REGION = "us-east1"
GPU_ENDPOINT_ID = "123456789"  # Mock ID used to check if GPU was called

TPU_ENDPOINT_REGION = "us-west1"
TPU_ENDPOINT_ID = "987654321"  # Mock ID used to check if TPU was called

# Canned model answer returned by the fake predict call.
MODEL_RESPONSES = """
The sky appears blue due to a phenomenon called **Rayleigh scattering**.

**Here's how it works:**

1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.

2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.

3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.

4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.

**Why not other colors?**

* **Violet light** has an even shorter wavelength than blue and is scattered even more. However, our eyes are less sensitive to violet light, so we perceive the sky as blue.
* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.
"""
54+
55+
56+
# Mocked function - we check if proper format was used depending on selected architecture
57+
def mock_predict(
    endpoint: Optional[str] = None,
    instances: Optional[MutableSequence[struct_pb2.Value]] = None,
) -> prediction_service.PredictResponse:
    """Fake ``predict`` call that validates the request payload layout.

    The endpoint path identifies which sample made the call; the first
    instance is then checked against the payload layout that sample is
    expected to send (nested ``parameters`` for GPU, flat fields for TPU).

    Args:
        endpoint: Full endpoint resource path passed by the sample.
        instances: Request instances; only the first one is inspected.

    Returns:
        A ``PredictResponse`` whose single prediction is ``MODEL_RESPONSES``.

    Raises:
        AssertionError: If the payload layout is wrong or the endpoint is
            neither the GPU nor the TPU test endpoint.
    """
    gpu_endpoint = f"projects/{PROJECT_ID}/locations/{GPU_ENDPOINT_REGION}/endpoints/{GPU_ENDPOINT_ID}"
    tpu_endpoint = f"projects/{PROJECT_ID}/locations/{TPU_ENDPOINT_REGION}/endpoints/{TPU_ENDPOINT_ID}"
    instance_fields = instances[0].struct_value.fields

    # Sampling settings both samples send; each must arrive as a number.
    numeric_keys = ("max_tokens", "temperature", "top_p", "top_k")

    if endpoint == gpu_endpoint:
        # GPU layout: prompt under "inputs", settings nested in "parameters".
        assert "string_value" in instance_fields["inputs"]
        assert "struct_value" in instance_fields["parameters"]
        parameters = instance_fields["parameters"].struct_value.fields
        for key in numeric_keys:
            assert "number_value" in parameters[key], key
    elif endpoint == tpu_endpoint:
        # TPU layout: prompt and settings side by side at the top level.
        assert "string_value" in instance_fields["prompt"]
        for key in numeric_keys:
            assert "number_value" in instance_fields[key], key
    else:
        # Raise explicitly instead of `assert False`: it still fails under
        # `python -O` and reports which endpoint was unexpected.
        raise AssertionError(f"Unexpected endpoint: {endpoint!r}")

    response = prediction_service.PredictResponse()
    response.predictions.append(Value(string_value=MODEL_RESPONSES))
    return response
85+
86+
87+
@mock.patch("google.cloud.aiplatform.gapic.PredictionServiceClient")
def test_gemma2_predict_gpu(mock_client: MagicMock) -> None:
    """Run the GPU sample against the patched client and check its output."""
    # The client the sample constructs is the patched class's return value;
    # route its predict() through the payload-checking fake.
    mock_client.return_value.predict = mock_predict

    result = gemma2_predict_gpu(GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID)

    assert "Rayleigh scattering" in result
94+
95+
96+
@mock.patch("google.cloud.aiplatform.gapic.PredictionServiceClient")
def test_gemma2_predict_tpu(mock_client: MagicMock) -> None:
    """Run the TPU sample against the patched client and check its output."""
    # The client the sample constructs is the patched class's return value;
    # route its predict() through the payload-checking fake.
    mock_client.return_value.predict = mock_predict

    result = gemma2_predict_tpu(TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID)

    assert "Rayleigh scattering" in result

gemma2/noxfile_config.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Default TEST_CONFIG_OVERRIDE for python repos.
16+
17+
# You can copy this file into your directory, then it will be imported from
18+
# the noxfile.py.
19+
20+
# The source of truth:
21+
# https://github.com/GoogleCloudPlatform/python-docs-samples/blob/main/noxfile_config.py
22+
23+
TEST_CONFIG_OVERRIDE = {
    # Python versions for which the tests are opted out.
    "ignored_versions": ["2.7", "3.7", "3.9", "3.10", "3.11"],
    # Type hints are enforced here; only old samples are opted out of
    # this, and every new sample is expected to carry them.
    "enforce_type_hints": True,
    # Name of the environment variable that holds the project id to use.
    # Set it to 'BUILD_SPECIFIC_GCLOUD_PROJECT' to opt in to a build
    # specific Cloud project, or to your own string to use your own
    # Cloud project.
    "gcloud_project_env": "GOOGLE_CLOUD_PROJECT",
    # To pin pip, replace None with the version as a string,
    # for example "20.2.4".
    "pip_version_override": None,
    # Extra values injected into the test; they override predefined
    # values. Never put secrets here.
    "envs": {},
}

gemma2/requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pytest==8.3.3

gemma2/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
google-cloud-aiplatform[all]==1.64.0
2+
protobuf==5.28.1

0 commit comments

Comments
 (0)