Skip to content

Commit

Permalink
bigtable and legacy support
Browse files Browse the repository at this point in the history
  • Loading branch information
riteshghorse committed Feb 22, 2024
1 parent 4a13130 commit 2d0376c
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging

__all__ = [
Expand All @@ -23,7 +22,9 @@

from typing import List

from google.cloud.aiplatform_v1 import FetchFeatureValuesRequest, FeatureOnlineStoreServiceClient
import proto
from google.api_core.exceptions import NotFound
from google.cloud import aiplatform

import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
Expand All @@ -43,8 +44,7 @@ def __init__(
api_endpoint: str,
feature_store_name: str,
feature_view_name: str,
entity_type_name: str,
feature_ids: List[str]):
entity_type_name: str):
"""Initializes an instance of `VertexAIFeatureStoreEnrichmentHandler`.
Args:
Expand All @@ -56,39 +56,106 @@ def __init__(
feature store.
entity_type_name (str): The name of the entity type within the
feature store.
feature_ids (List[str]): A list of feature IDs to fetch
from the feature store.
"""
self.project = project
self.location = location
self.api_endpoint = api_endpoint
self.feature_store_name = feature_store_name
self.feature_view_name = feature_view_name
self.entity_type_name = entity_type_name
self.feature_ids = feature_ids

def __enter__(self):
  """Creates the Vertex AI online-store client used by `__call__`.

  The client is built against `self.api_endpoint` and stored on
  `self.client`.

  Returns:
    The handler itself, so that `with handler as h:` binds the handler.
  """
  # Only the `aiplatform.gapic` construction is kept; the earlier bare
  # `FeatureOnlineStoreServiceClient(` line was a superseded, unterminated
  # duplicate of this assignment.
  self.client = aiplatform.gapic.FeatureOnlineStoreServiceClient(
      client_options={"api_endpoint": self.api_endpoint})
  return self

def __call__(self, request: beam.Row, *args, **kwargs):
  """Fetches the feature values for the entity referenced by `request`.

  Args:
    request: the input `beam.Row`. It must contain a field named
      `self.entity_type_name` whose value is the entity id to look up.
    *args: unused.
    **kwargs: unused.

  Returns:
    A tuple `(request, response_row)` where `response_row` is a `beam.Row`
    built from the fields of the response's `proto_struct`.

  Raises:
    ValueError: if `request` has no `self.entity_type_name` field, or if
      the fetch call raises `NotFound` for the entity id.
  """
  try:
    entity_id = request._asdict()[self.entity_type_name]
  except KeyError as err:
    raise ValueError(
        "no entry found for entity_type_name %s in input row" %
        self.entity_type_name) from err

  feature_view = (
      "projects/%s/locations/%s/featureOnlineStores/%s/featureViews/%s" % (
          self.project,
          self.location,
          self.feature_store_name,
          self.feature_view_name))
  fetch_request = aiplatform.gapic.FetchFeatureValuesRequest(
      data_key=aiplatform.gapic.FeatureViewDataKey(key=entity_id),
      feature_view=feature_view,
      # PROTO_STRUCT is requested because the result is read back from
      # `response.proto_struct` below.
      data_format=aiplatform.gapic.FeatureViewDataFormat.PROTO_STRUCT,
  )

  try:
    response = self.client.fetch_feature_values(request=fetch_request)
  except NotFound as err:
    raise ValueError("entity_id: %s not found" % entity_id) from err

  return request, beam.Row(**dict(response.proto_struct))

def __exit__(self, exc_type, exc_val, exc_tb):
  """Closes the client created in `__enter__` and drops the reference.

  Args:
    exc_type: type of the exception raised inside the `with` body, if any.
    exc_val: the exception instance, if any.
    exc_tb: the traceback, if any.
  """
  try:
    # `__exit__` takes the exception triple; calling it with no arguments
    # raises TypeError, so forward what this handler received.
    self.client.__exit__(exc_type, exc_val, exc_tb)
  finally:
    # Drop the reference even when closing the client fails.
    self.client = None

def get_cache_key(self, request):
  """Returns the cache key for the entity referenced by `request`.

  The key embeds the value of the `self.entity_type_name` field, so rows
  that refer to different entities get different keys.

  Args:
    request: the input row; must expose `_asdict()` and contain a field
      named `self.entity_type_name`.

  Returns:
    A string of the form `'entity_id: <id>'`.

  Raises:
    ValueError: if `request` has no `self.entity_type_name` field.
  """
  try:
    entity_id = request._asdict()[self.entity_type_name]
  except KeyError as err:
    raise ValueError(
        "no entry found for entity_type_name %s in input row" %
        self.entity_type_name) from err
  # The one-element tuple keeps formatting correct if the id is a tuple.
  return 'entity_id: %s' % (entity_id, )


class VertexAIFeatureStoreLegacyEnrichmentHandler(EnrichmentSourceHandler):
  """Enrichment handler for the legacy Vertex AI feature store.

  For each input `beam.Row` the handler reads the field named by
  `entity_id`, requests the configured `feature_ids` for that entity with
  `FeaturestoreOnlineServingServiceClient.read_feature_values`, and returns
  the request together with a `beam.Row` of the fetched values.
  """
  def __init__(
      self,
      project: str,
      location: str,
      api_endpoint: str,
      feature_store_id: str,
      entity_type_id: str,
      feature_ids: List[str],
      entity_id: str):
    """Initializes a `VertexAIFeatureStoreLegacyEnrichmentHandler`.

    Args:
      project (str): The GCP project for the Vertex AI feature store.
      location (str): The region for the Vertex AI feature store.
      api_endpoint (str): The API endpoint for the Vertex AI feature store.
      feature_store_id (str): The id of the Vertex AI feature store.
      entity_type_id (str): The id of the entity type within the
        feature store.
      feature_ids (List[str]): A list of feature IDs to fetch
        from the feature store.
      entity_id (str): The name of the input row field whose value is the
        id of the entity to look up.
    """
    self.project = project
    self.location = location
    self.api_endpoint = api_endpoint
    self.feature_store_id = feature_store_id
    self.entity_type_id = entity_type_id
    self.feature_ids = feature_ids
    self.entity_id = entity_id

  def __enter__(self):
    """Creates the online serving client and returns the handler."""
    self.client = aiplatform.gapic.FeaturestoreOnlineServingServiceClient(
        client_options={'api_endpoint': self.api_endpoint})
    return self

  def __call__(self, request: beam.Row, *args, **kwargs):
    """Reads the configured features for the entity referenced by `request`.

    Args:
      request: the input `beam.Row`; must contain a field named
        `self.entity_id` holding the id of the entity to read.
      *args: unused.
      **kwargs: unused.

    Returns:
      A tuple `(request, response_row)` where `response_row` is a
      `beam.Row` keyed by the ids of the returned feature descriptors.

    Raises:
      KeyError: if `request` has no field named `self.entity_id`.
    """
    try:
      entity_id = request._asdict()[self.entity_id]
    except KeyError:
      # Same exception type as before, but say which field is missing.
      raise KeyError(
          "no entry found for entity_id field %s in input row" %
          self.entity_id) from None

    entity_type = (
        "projects/%s/locations/%s/featurestores/%s/entityTypes/%s" % (
            self.project,
            self.location,
            self.feature_store_id,
            self.entity_type_id))
    selector = aiplatform.gapic.FeatureSelector(
        id_matcher=aiplatform.gapic.IdMatcher(ids=self.feature_ids))
    response = self.client.read_feature_values(
        request=aiplatform.gapic.ReadFeatureValuesRequest(
            entity_type=entity_type,
            entity_id=entity_id,
            feature_selector=selector))

    entity_view = proto.Message.to_dict(response.entity_view)
    response_dict = {}
    # Descriptors and data entries are paired positionally.
    for descriptor, datum in zip(response.header.feature_descriptors,
                                 entity_view['data']):
      feature_value = datum['value']
      if feature_value:
        # Only the first entry is kept; the original code skipped the rest
        # as metadata.
        response_dict[descriptor.id] = next(iter(feature_value.values()))
    return request, beam.Row(**response_dict)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Drops the reference to the client created in `__enter__`."""
    self.client = None
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,78 @@
#
import unittest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreLegacyEnrichmentHandler


class TestVertexAIFeatureStoreHandler(unittest.TestCase):
  """Integration tests for the Vertex AI feature store handlers.

  Each test builds one request row and one handler and runs them through
  `Enrichment` in an integration `TestPipeline`; no output is asserted.
  """
  def setUp(self) -> None:
    self.project = 'google.com:clouddfe'
    self.location = 'us-central1'
    self.feature_view_name = "registry_product"
    self.entity_type_name = "entity_id"

  def _run_enrichment(self, requests, handler):
    """Runs `requests` through `Enrichment(handler)` in a test pipeline."""
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (test_pipeline | beam.Create(requests) | Enrichment(handler))

  def test_vertex_ai_feature_store_bigtable_serving_enrichment(self):
    rows = [
        beam.Row(
            entity_id="847", name="fred perry men\'s sharp stripe t-shirt")
    ]
    bigtable_handler = VertexAIFeatureStoreEnrichmentHandler(
        project=self.project,
        location=self.location,
        api_endpoint="us-central1-aiplatform.googleapis.com",
        feature_store_name="the_look_demo_unique",
        feature_view_name=self.feature_view_name,
        entity_type_name=self.entity_type_name)
    self._run_enrichment(rows, bigtable_handler)

  def test_vertex_ai_feature_store_optimized_serving_enrichment(self):
    rows = [
        beam.Row(
            entity_id="16050", name="fred perry men\'s sharp stripe t-shirt")
    ]
    optimized_endpoint = (
        "2484122222587805696.us-central1-927334603519.featurestore."
        "vertexai.goog")
    optimized_handler = VertexAIFeatureStoreEnrichmentHandler(
        project=self.project,
        location=self.location,
        api_endpoint=optimized_endpoint,
        feature_store_name="the_look_demo_optimized_public_unique",
        feature_view_name=self.feature_view_name,
        entity_type_name=self.entity_type_name)
    self._run_enrichment(rows, optimized_handler)

  def test_vertex_ai_legacy_feature_store_enrichment(self):
    rows = [
        beam.Row(
            entity_id="movie_02",
            name="fred perry men\'s sharp stripe t-shirt")
    ]
    legacy_handler = VertexAIFeatureStoreLegacyEnrichmentHandler(
        project=self.project,
        location=self.location,
        api_endpoint="us-central1-aiplatform.googleapis.com",
        feature_store_id="movie_prediction_unique",
        entity_type_id="movies",
        feature_ids=['title', 'genres'],
        entity_id=self.entity_type_name)
    self._run_enrichment(rows, legacy_handler)


if __name__ == '__main__':
Expand Down

0 comments on commit 2d0376c

Please sign in to comment.