diff --git a/example.py b/example.py index e151a72..3d1ccd5 100644 --- a/example.py +++ b/example.py @@ -1,8 +1,9 @@ # This is an example feature definition file -from google.protobuf.duration_pb2 import Duration +import datetime -from feast import Entity, Feature, FeatureView, ValueType +from feast import Entity, FeatureView, Field, ValueType +from feast.types import PrimitiveFeastType from feast_hive import HiveSource # Read data from Hive table @@ -10,30 +11,31 @@ # but you can replace to your own Table or Query. driver_hourly_stats = HiveSource( # table='driver_stats', - query = """ + query=""" SELECT from_unixtime(cast(event_timestamp / 1000000 as bigint)) AS event_timestamp, driver_id, conv_rate, acc_rate, avg_daily_trips, from_unixtime(cast(created / 1000000 as bigint)) AS created FROM driver_stats """, - event_timestamp_column="event_timestamp", + name="driver_stats", + timestamp_field="event_timestamp", created_timestamp_column="created", ) # Define an entity for the driver. -driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id", ) +driver = Entity(name="driver_id", join_keys=["driver_id"], value_type=ValueType.INT64, description="driver id", ) # Define FeatureView driver_hourly_stats_view = FeatureView( name="driver_hourly_stats", - entities=["driver_id"], - ttl=Duration(seconds=86400 * 1), - features=[ - Feature(name="conv_rate", dtype=ValueType.FLOAT), - Feature(name="acc_rate", dtype=ValueType.FLOAT), - Feature(name="avg_daily_trips", dtype=ValueType.INT64), + entities=[driver], + ttl=datetime.timedelta(seconds=86400 * 1), + schema=[ + Field(name="conv_rate", dtype=PrimitiveFeastType.FLOAT32), + Field(name="acc_rate", dtype=PrimitiveFeastType.FLOAT32), + Field(name="avg_daily_trips", dtype=PrimitiveFeastType.INT32), ], online=True, - input=driver_hourly_stats, + source=driver_hourly_stats, tags={}, -) \ No newline at end of file +) diff --git a/feast_hive/hive.py b/feast_hive/hive.py index d4b57fe..7fa95aa 100644 --- a/feast_hive/hive.py +++ b/feast_hive/hive.py @@ -1,11 +1,11 @@ import contextlib -from datetime import datetime +from datetime import date, datetime from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union +import uuid import numpy as np import pandas as pd import pyarrow as pa -from dateutil import parser from pydantic import StrictBool, StrictInt, StrictStr from pytz import utc from six import reraise @@ -15,12 +15,14 @@ from feast.errors import InvalidEntityType from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL from feast.infra.offline_stores import offline_utils -from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob +from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob, RetrievalMetadata from feast.on_demand_feature_view import OnDemandFeatureView -from feast.registry import Registry +from feast.infra.registry.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.saved_dataset import SavedDatasetStorage from feast_hive.hive_source import HiveSource from feast_hive.hive_type_map import hive_to_pa_value_type, pa_to_hive_value_type +from feast_hive.hive_source import SavedDatasetHiveStorage try: from impala.dbapi import connect as impala_connect @@ -140,13 +142,47 @@ def __exit__(self, exc_type, exc_val, exc_tb): class HiveOfflineStore(OfflineStore): + + @staticmethod + def pull_all_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + start_date: datetime, + end_date: datetime) -> RetrievalJob: + assert isinstance(config.offline_store, HiveOfflineStoreConfig) + assert isinstance(data_source, HiveSource) + + from_expression = data_source.get_table_query_string() + + field_string = ", ".join( + join_key_columns + feature_name_columns + [timestamp_field] + ) + + start_date = _format_datetime(start_date) + end_date = _format_datetime(end_date) + + queries = [ + "SET hive.resultset.use.unique.column.names=false", + f""" + SELECT {field_string} + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') + """, + ] + + conn = HiveConnection(config.offline_store) + return HiveRetrievalJob(conn, queries) + @staticmethod def pull_latest_from_table_or_query( config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], - event_timestamp_column: str, + timestamp_field: str, created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, @@ -161,7 +197,7 @@ def pull_latest_from_table_or_query( partition_by_join_key_string = ( "PARTITION BY " + partition_by_join_key_string ) - timestamps = [event_timestamp_column] + timestamps = [timestamp_field] if created_timestamp_column and created_timestamp_column not in timestamps: timestamps.append(created_timestamp_column) timestamp_desc_string = " DESC, ".join(timestamps) + " DESC" @@ -180,7 +216,7 @@ def pull_latest_from_table_or_query( SELECT {field_string}, ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_ FROM {from_expression} t1 - WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') + WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') ) t2 WHERE feast_row_ = 1 """, @@ -200,21 +236,32 @@ def get_historical_features( full_feature_names: bool = False, ) -> RetrievalJob: assert isinstance(config.offline_store, HiveOfflineStoreConfig) + for fv in feature_views: + assert isinstance(fv.batch_source, HiveSource) conn = HiveConnection(config.offline_store) + table_name = offline_utils.get_temp_entity_table_name() + + entity_schema = _get_entity_schema( + conn, + entity_df + ) + + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema + ) + + entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, entity_df_event_timestamp_col, conn, + ) + @contextlib.contextmanager def query_generator() -> ContextManager[List[str]]: - table_name = offline_utils.get_temp_entity_table_name() - try: - entity_schema = _upload_entity_df_and_get_entity_schema( + _upload_entity_df_and_get_entity_schema( config, conn, table_name, entity_df ) - entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema - ) - expected_join_keys = offline_utils.get_expected_join_keys( project, feature_views, registry ) @@ -223,10 +270,6 @@ def query_generator() -> ContextManager[List[str]]: entity_schema, expected_join_keys, entity_df_event_timestamp_col ) - entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, entity_df_event_timestamp_col, conn, table_name, - ) - query_contexts = offline_utils.get_feature_view_query_context( feature_refs, feature_views, @@ -266,6 +309,12 @@ def query_generator() -> ContextManager[List[str]]: on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs( feature_refs, project, registry ), + metadata=RetrievalMetadata( + features=feature_refs, + keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}), + min_event_timestamp=entity_df_event_timestamp_range[0], + max_event_timestamp=entity_df_event_timestamp_range[1], + ), ) @@ -276,6 +325,7 @@ def __init__( queries: Union[str, List[str], Callable[[], ContextManager[List[str]]]], full_feature_names: bool = False, on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, ): assert ( isinstance(queries, str) or isinstance(queries, list) or callable(queries) @@ -301,6 +351,7 @@ def query_generator() -> ContextManager[List[str]]: self._conn = conn self._full_feature_names = full_feature_names self._on_demand_feature_views = on_demand_feature_views + self._metadata = metadata @property def full_feature_names(self) -> bool: @@ -331,6 +382,43 @@ def _to_arrow_internal(self) -> pa.Table: ] return pa.Table.from_batches(pa_batches, schema) + def to_hive( + self, + destination_table: Optional[str] = None, + ) -> Optional[str]: + """ + Triggers the execution of a historical feature retrieval query and exports the results to a Hive table. + Args: + destination_table: the destination table name. + Returns: + Returns the destination table name. + """ + if not destination_table: + today = date.today().strftime("%Y%m%d") + rand_id = str(uuid.uuid4())[:7] + destination_table = f"historical_{today}_{rand_id}" + + with self._queries_generator() as queries: + with self._conn.cursor() as cursor: + queries[-1] = f""" + CREATE TABLE {destination_table} STORED AS PARQUET AS + {queries[-1]} + """ + for query in queries: + cursor.execute(query) + return destination_table + + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): + if not isinstance(storage, SavedDatasetHiveStorage): + raise ValueError( + f"The storage object is not a `SavedDatasetHiveStorage` but is instead a {type(storage)}" + ) + storage.hive_options.table = self.to_hive(destination_table=storage.hive_options.table) + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + @staticmethod def _convert_hive_batch_to_arrow_batch( hive_batch: ImpalaCBatch, schema: pa.Schema @@ -360,6 +448,26 @@ def _format_datetime(t: datetime): return t +def _get_entity_schema( + conn: HiveConnection, entity_df: Union[pd.DataFrame, str] +) -> Dict[str, np.dtype]: + if isinstance(entity_df, str): + entity_df_sample = HiveRetrievalJob( + conn, + [ + "SET hive.resultset.use.unique.column.names=false", + f"SELECT * FROM ({entity_df}) AS t LIMIT 1", + ], + ).to_df() + entity_schema = dict(zip(entity_df_sample.columns, entity_df_sample.dtypes)) + elif isinstance(entity_df, pd.DataFrame): + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) + else: + raise InvalidEntityType(type(entity_df)) + + return entity_schema + + def _upload_entity_df_and_get_entity_schema( config: RepoConfig, conn: HiveConnection, @@ -464,44 +572,37 @@ def _get_entity_df_event_timestamp_range( entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, conn: HiveConnection, - table_name: str, ) -> Tuple[datetime, datetime]: - # TODO haven't got time to test this yet, - # just use fake min and max datetime for now, since they will be detected inside the sql - return datetime.now(), datetime.now() - - # if isinstance(entity_df, pd.DataFrame): - # entity_df_event_timestamp = entity_df.loc[ - # :, entity_df_event_timestamp_col - # ].infer_objects() - # if pd.api.types.is_string_dtype(entity_df_event_timestamp): - # entity_df_event_timestamp = pd.to_datetime( - # entity_df_event_timestamp, utc=True - # ) - # entity_df_event_timestamp_range = ( - # entity_df_event_timestamp.min(), - # entity_df_event_timestamp.max(), - # ) - # elif isinstance(entity_df, str): - # # If the entity_df is a string (SQL query), determine range - # # from table - # with conn.cursor() as cursor: - # cursor.execute( - # f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM {table_name}", - # ) - # result = cursor.fetchone() - # assert ( - # result is not None - # ), "Fetching the EntityDataframe's timestamp range failed." - # # TODO haven't tested this yet - # entity_df_event_timestamp_range = ( - # parser.parse(result[0]), - # parser.parse(result[1]), - # ) - # else: - # raise InvalidEntityType(type(entity_df)) - # - # return entity_df_event_timestamp_range + if isinstance(entity_df, pd.DataFrame): + entity_df_event_timestamp = entity_df.loc[ + :, entity_df_event_timestamp_col + ].infer_objects() + if pd.api.types.is_string_dtype(entity_df_event_timestamp): + entity_df_event_timestamp = pd.to_datetime( + entity_df_event_timestamp, utc=True + ) + entity_df_event_timestamp_range = ( + entity_df_event_timestamp.min().to_pydatetime(), + entity_df_event_timestamp.max().to_pydatetime(), + ) + elif isinstance(entity_df, str): + with conn.cursor() as cursor: + cursor.execute( + f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max " + f"FROM ({entity_df}) AS t", + ) + result = cursor.fetchone() + assert ( + result is not None + ), "Fetching the EntityDataframe's timestamp range failed." + entity_df_event_timestamp_range = ( + pd.to_datetime(result[0]).to_pydatetime(), + pd.to_datetime(result[1]).to_pydatetime(), + ) + else: + raise InvalidEntityType(type(entity_df)) + + return entity_df_event_timestamp_range # This query is based on sdk/python/feast/infra/offline_stores/bigquery.py:MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN @@ -583,11 +684,11 @@ def _get_entity_df_event_timestamp_range( {{ featureview.name }}__subquery AS ( SELECT - {{ featureview.event_timestamp_column }} as event_timestamp, + {{ featureview.timestamp_field }} as event_timestamp, {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }} {{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %} {% for feature in featureview.features %} - {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}{% if loop.last %}{% else %}, {% endif %} + {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %} {% endfor %} FROM {{ featureview.table_subquery }} AS subquery INNER JOIN ( @@ -598,9 +699,9 @@ def _get_entity_df_event_timestamp_range( FROM entity_dataframe ) AS temp ON ( - {{ featureview.event_timestamp_column }} <= max_entity_timestamp_ + {{ featureview.timestamp_field }} <= max_entity_timestamp_ {% if featureview.ttl == 0 %}{% else %} - AND {{ featureview.event_timestamp_column }} >= min_entity_timestamp_ + AND {{ featureview.timestamp_field }} >= min_entity_timestamp_ {% endif %} ) ) @@ -710,7 +811,7 @@ def _get_entity_df_event_timestamp_range( SELECT {{featureview.name}}__entity_row_unique_id {% for feature in featureview.features %} - ,{% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %} + ,{% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %} {% endfor %} FROM {{ featureview.name }}__cleaned ) AS {{ featureview.name }}__joined diff --git a/feast_hive/hive_source.py b/feast_hive/hive_source.py index 6e5d25d..74c0f52 100644 --- a/feast_hive/hive_source.py +++ b/feast_hive/hive_source.py @@ -6,6 +6,10 @@ from feast.errors import DataSourceNoNameException from feast.errors import DataSourceNotFoundException from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.SavedDataset_pb2 import ( + SavedDatasetStorage as SavedDatasetStorageProto, +) +from feast.saved_dataset import SavedDatasetStorage from feast_hive.hive_type_map import hive_to_feast_value_type @@ -83,9 +87,9 @@ def __init__( self, *, name: Optional[str] = None, - table: Optional[str] = None, + timestamp_field: Optional[str] = None, query: Optional[str] = None, - event_timestamp_column: Optional[str] = "", + table: Optional[str] = None, created_timestamp_column: Optional[str] = "", field_mapping: Optional[Dict[str, str]] = None, date_partition_column: Optional[str] = "", @@ -109,10 +113,10 @@ def __init__( super().__init__( name=_name if _name else "", - event_timestamp_column, - created_timestamp_column, - field_mapping, - date_partition_column, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping, + date_partition_column=date_partition_column, description=description, tags=tags, owner=owner, @@ -130,7 +134,7 @@ def __eq__(self, other): self.name == other.name and self.hive_options.table == other.hive_options.table and self.hive_options.query == other.hive_options.query - and self.event_timestamp_column == other.event_timestamp_column + and self.timestamp_field == other.timestamp_field and self.created_timestamp_column == other.created_timestamp_column and self.field_mapping == other.field_mapping and self.date_partition_column == other.date_partition_column @@ -172,7 +176,7 @@ def to_proto(self) -> DataSourceProto: owner=self.owner, ) - data_source_proto.event_timestamp_column = self.event_timestamp_column + data_source_proto.timestamp_field = self.timestamp_field data_source_proto.created_timestamp_column = self.created_timestamp_column data_source_proto.date_partition_column = self.date_partition_column return data_source_proto @@ -189,7 +193,7 @@ def from_proto(data_source: DataSourceProto): field_mapping=dict(data_source.field_mapping), table=hive_options.table, query=hive_options.query, - event_timestamp_column=data_source.event_timestamp_column, + timestamp_field=data_source.timestamp_field, created_timestamp_column=data_source.created_timestamp_column, date_partition_column=data_source.date_partition_column, description=data_source.description, @@ -245,3 +249,26 @@ def get_table_column_names_and_types( return [(field[0], field[1]) for field in cursor.description] except HiveServer2Error: raise DataSourceNotFoundException(self.query) + + +class SavedDatasetHiveStorage(SavedDatasetStorage): + _proto_attr_name = "custom_storage" + + hive_options: HiveOptions + + def __init__(self, table: Optional[str] = None): + self.hive_options = HiveOptions(table=table, query=None) + + @staticmethod + def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage: + return SavedDatasetHiveStorage( + table=HiveOptions.from_proto(storage_proto.custom_storage).table, + ) + + def to_proto(self) -> SavedDatasetStorageProto: + return SavedDatasetStorageProto( + custom_storage=self.hive_options.to_proto() + ) + + def to_data_source(self) -> DataSource: + return HiveSource(table=self.hive_options.table) diff --git a/setup.py b/setup.py index 20e39f3..6a90614 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ readme = f.read() INSTALL_REQUIRE = [ - "feast>=0.17.0", + "feast>=0.26.0", "impyla[kerberos]>=0.15.0", ] @@ -21,7 +21,7 @@ setup( name="feast-hive", - version="0.17.0", + version="0.26.0", author="Benn Ma", author_email="bennmsg@gmail.com", description="Hive support for Feast offline store", diff --git a/tests/feast_tests_funcs.py b/tests/feast_tests_funcs.py index 0cb4dec..8c5142f 100644 --- a/tests/feast_tests_funcs.py +++ b/tests/feast_tests_funcs.py @@ -5,8 +5,9 @@ import numpy as np import pandas as pd -from feast import FeatureView, Feature, ValueType, FeatureStore +from feast import FeatureView, FeatureStore, Field, Entity from feast.data_source import DataSource +from feast.types import PrimitiveFeastType from pytz import FixedOffset, timezone, utc DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp" @@ -112,7 +113,7 @@ def create_driver_hourly_stats_df(drivers, start_date, end_date) -> pd.DataFrame "event_timestamp": [ pd.Timestamp(dt, unit="ms", tz="UTC").round("ms") for dt in pd.date_range( - start=start_date, end=end_date, freq="1H", closed="left" + start=start_date, end=end_date, freq="1H", inclusive="left" ) ] # include a fixed timestamp for get_historical_features in the quickstart @@ -173,7 +174,7 @@ def create_customer_daily_profile_df(customers, start_date, end_date) -> pd.Data "event_timestamp": [ pd.Timestamp(dt, unit="ms", tz="UTC").round("ms") for dt in pd.date_range( - start=start_date, end=end_date, freq="1D", closed="left" + start=start_date, end=end_date, freq="1D", inclusive="left" ) ] } @@ -220,32 +221,32 @@ def generate_entities(date, infer_event_timestamp_col, order_count: int = 1000): return customer_entities, driver_entities, end_date, orders_df, start_date -def create_driver_hourly_stats_feature_view(source): +def create_driver_hourly_stats_feature_view(source: DataSource, entity: Entity): driver_stats_feature_view = FeatureView( name="driver_stats", - entities=["driver"], - features=[ - Feature(name="conv_rate", dtype=ValueType.FLOAT), - Feature(name="acc_rate", dtype=ValueType.FLOAT), - Feature(name="avg_daily_trips", dtype=ValueType.INT32), + entities=[entity], + schema=[ + Field(name="conv_rate", dtype=PrimitiveFeastType.FLOAT32), + Field(name="acc_rate", dtype=PrimitiveFeastType.FLOAT32), + Field(name="avg_daily_trips", dtype=PrimitiveFeastType.INT32), ], - batch_source=source, + source=source, ttl=timedelta(hours=2), ) return driver_stats_feature_view -def create_customer_daily_profile_feature_view(source): +def create_customer_daily_profile_feature_view(source: DataSource, entity: Entity): customer_profile_feature_view = FeatureView( name="customer_profile", - entities=["customer_id"], - features=[ - Feature(name="current_balance", dtype=ValueType.FLOAT), - Feature(name="avg_passenger_count", dtype=ValueType.FLOAT), - Feature(name="lifetime_trip_count", dtype=ValueType.INT32), - Feature(name="avg_daily_trips", dtype=ValueType.INT32), + entities=[entity], + schema=[ + Field(name="current_balance", dtype=PrimitiveFeastType.FLOAT32), + Field(name="avg_passenger_count", dtype=PrimitiveFeastType.FLOAT32), + Field(name="lifetime_trip_count", dtype=PrimitiveFeastType.INT32), + Field(name="avg_daily_trips", dtype=PrimitiveFeastType.INT32), ], - batch_source=source, + source=source, ttl=timedelta(days=2), ) return customer_profile_feature_view @@ -282,35 +283,35 @@ def get_expected_training_df( driver_df: pd.DataFrame, driver_fv: FeatureView, orders_df: pd.DataFrame, - event_timestamp: str, + timestamp_field: str, full_feature_names: bool = False, ): # Convert all pandas dataframes into records with UTC timestamps order_records = convert_timestamp_records_to_utc( - orders_df.to_dict("records"), event_timestamp + orders_df.to_dict("records"), timestamp_field ) driver_records = convert_timestamp_records_to_utc( - driver_df.to_dict("records"), driver_fv.batch_source.event_timestamp_column + driver_df.to_dict("records"), driver_fv.batch_source.timestamp_field ) customer_records = convert_timestamp_records_to_utc( - customer_df.to_dict("records"), customer_fv.batch_source.event_timestamp_column + customer_df.to_dict("records"), customer_fv.batch_source.timestamp_field ) # Manually do point-in-time join of orders to drivers and customers records for order_record in order_records: driver_record = find_asof_record( driver_records, - ts_key=driver_fv.batch_source.event_timestamp_column, - ts_start=order_record[event_timestamp] - driver_fv.ttl, - ts_end=order_record[event_timestamp], + ts_key=driver_fv.batch_source.timestamp_field, + ts_start=order_record[timestamp_field] - driver_fv.ttl, + ts_end=order_record[timestamp_field], filter_key="driver_id", filter_value=order_record["driver_id"], ) customer_record = find_asof_record( customer_records, - ts_key=customer_fv.batch_source.event_timestamp_column, - ts_start=order_record[event_timestamp] - customer_fv.ttl, - ts_end=order_record[event_timestamp], + ts_key=customer_fv.batch_source.timestamp_field, + ts_start=order_record[timestamp_field] - customer_fv.ttl, + ts_end=order_record[timestamp_field], filter_key="customer_id", filter_value=order_record["customer_id"], ) @@ -342,8 +343,8 @@ def get_expected_training_df( # Move "event_timestamp" column to front current_cols = expected_df.columns.tolist() - current_cols.remove(event_timestamp) - expected_df = expected_df[[event_timestamp] + current_cols] + current_cols.remove(timestamp_field) + expected_df = expected_df[[timestamp_field] + current_cols] # Cast some columns to expected types, since we lose information when converting pandas DFs into Python objects. if full_feature_names: @@ -371,7 +372,7 @@ def create_dataset() -> pd.DataFrame: now = datetime.utcnow() ts = pd.Timestamp(now).round("ms") data = { - "id": [1, 2, 1, 3, 3], + "driver_id": [1, 2, 1, 3, 3], "value": [0.1, None, 0.3, 4, 5], "ts_1": [ ts - timedelta(hours=4), @@ -390,13 +391,13 @@ def create_dataset() -> pd.DataFrame: return pd.DataFrame.from_dict(data) -def correctness_feature_view(data_source: DataSource) -> FeatureView: +def correctness_feature_view(data_source: DataSource, entity: Entity) -> FeatureView: return FeatureView( name="test_correctness", - entities=["driver"], - features=[Feature("value", ValueType.FLOAT)], + entities=[entity], + schema=[Field(name="value", dtype=PrimitiveFeastType.FLOAT32)], ttl=timedelta(days=5), - input=data_source, + source=data_source, ) @@ -412,7 +413,7 @@ def check_offline_and_online_features( # Check online store response_dict = fs.get_online_features( [f"{fv.name}:value"], - [{"driver": driver_id}], + [{"driver_id": driver_id}], full_feature_names=full_feature_names, ).to_dict() diff --git a/tests/test_all.py b/tests/test_all.py index 260e0e4..40559d9 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -82,21 +82,21 @@ def prep_hive_fs_and_fv( ), TemporaryDirectory() as repo_dir_name, TemporaryDirectory() as data_dir_name: hive_source = HiveSource( + name=table_name, table=table_name if source_type == "table" else None, query=f"SELECT * FROM {table_name}" if source_type == "query" else None, - event_timestamp_column="ts", + timestamp_field="ts", created_timestamp_column="created_ts", date_partition_column="", - field_mapping={"ts_1": "ts", "id": "driver_id"}, + field_mapping={"ts_1": "ts"}, ) - - fv = feast_tests_funcs.correctness_feature_view(hive_source) e = Entity( name="driver", description="id for driver", - join_key="driver_id", + join_keys=["driver_id"], value_type=ValueType.INT32, ) + fv = feast_tests_funcs.correctness_feature_view(hive_source, e) config = RepoConfig( registry=str(Path(repo_dir_name) / "registry.db"), project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}", @@ -105,6 +105,7 @@ def prep_hive_fs_and_fv( path=str(Path(data_dir_name) / "online_store.db") ), offline_store=offline_store, + entity_key_serialization_version=2, ) fs = FeatureStore(config=config) fs.apply([fv, e]) @@ -159,28 +160,29 @@ def test_hive_source(hive_conn_info): path=os.path.join(temp_dir, "online_store.db"), ), offline_store=offline_store, + entity_key_serialization_version=2, ) non_existed_table = f"{table_name}_non_existed" # Test table doesn't exist - hive_source_table = HiveSource(table=non_existed_table) + hive_source_table = HiveSource(name=non_existed_table, table=non_existed_table) assertpy.assert_that(hive_source_table.validate).raises( errors.DataSourceNotFoundException ).when_called_with(config) - hive_source_table = HiveSource(query=f"SELECT * FROM {non_existed_table}") + hive_source_table = HiveSource(name=non_existed_table, query=f"SELECT * FROM {non_existed_table}") assertpy.assert_that(hive_source_table.validate).raises( errors.DataSourceNotFoundException ).when_called_with(config) # Test table - hive_source_table = HiveSource(table=table_name) + hive_source_table = HiveSource(name=table_name, table=table_name) schema1 = hive_source_table.get_table_column_names_and_types(config) assert expected_schema == schema1 # Test query - hive_source_table = HiveSource(query=f"SELECT * FROM {table_name} LIMIT 100") + hive_source_table = HiveSource(name=table_name, query=f"SELECT * FROM {table_name} LIMIT 100") schema2 = hive_source_table.get_table_column_names_and_types(config) assert expected_schema == schema2 @@ -290,26 +292,29 @@ def test_historical_features_from_hive_sources( with orders_context, driver_context, customer_context, TemporaryDirectory() as temp_dir: driver_source = HiveSource( + name=driver_table_name, table=driver_table_name, - event_timestamp_column="event_timestamp", + timestamp_field="event_timestamp", created_timestamp_column="created", ) + driver = Entity(name="driver", join_keys=["driver_id"], value_type=ValueType.INT64) driver_fv = feast_tests_funcs.create_driver_hourly_stats_feature_view( - driver_source + driver_source, + driver, ) customer_source = HiveSource( + name=customer_table_name, table=customer_table_name, - event_timestamp_column="event_timestamp", + timestamp_field="event_timestamp", created_timestamp_column="created", ) + customer = Entity(name="customer", join_keys=["customer_id"], value_type=ValueType.INT64) customer_fv = feast_tests_funcs.create_customer_daily_profile_feature_view( - customer_source + customer_source, + customer, ) - driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64) - customer = Entity(name="customer_id", value_type=ValueType.INT64) - if provider_type == "local": store = FeatureStore( config=RepoConfig( @@ -320,6 +325,7 @@ def test_historical_features_from_hive_sources( path=os.path.join(temp_dir, "online_store.db"), ), offline_store=offline_store, + entity_key_serialization_version=2, ) ) else: @@ -328,7 +334,7 @@ def test_historical_features_from_hive_sources( store.apply([driver, customer, driver_fv, customer_fv]) try: - event_timestamp = ( + timestamp_field = ( feast_tests_funcs.DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL if feast_tests_funcs.DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns @@ -340,7 +346,7 @@ def test_historical_features_from_hive_sources( driver_df, driver_fv, orders_df, - event_timestamp, + timestamp_field, full_feature_names, ) @@ -371,11 +377,11 @@ def test_historical_features_from_hive_sources( ) assert_frame_equal( expected_df.sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ).reset_index(drop=True), actual_df_from_sql_entities[expected_df.columns] .sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ) .reset_index(drop=True), check_dtype=False, @@ -384,11 +390,11 @@ def test_historical_features_from_hive_sources( table_from_sql_entities = job_from_sql.to_arrow() assert_frame_equal( actual_df_from_sql_entities.sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ).reset_index(drop=True), table_from_sql_entities.to_pandas() .sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ) .reset_index(drop=True), ) @@ -461,11 +467,11 @@ def test_historical_features_from_hive_sources( ) assert_frame_equal( expected_df.sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ).reset_index(drop=True), actual_df_from_df_entities[expected_df.columns] .sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ) .reset_index(drop=True), check_dtype=False, @@ -474,11 +480,11 @@ def test_historical_features_from_hive_sources( table_from_df_entities = job_from_df.to_arrow() assert_frame_equal( actual_df_from_df_entities.sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ).reset_index(drop=True), table_from_df_entities.to_pandas() .sort_values( - by=[event_timestamp, "order_id", "driver_id", "customer_id"] + by=[timestamp_field, "order_id", "driver_id", "customer_id"] ) .reset_index(drop=True), )