20 changes: 15 additions & 5 deletions automl/google/cloud/automl_v1beta1/tables/gcs_client.py
@@ -41,7 +41,9 @@
class GcsClient(object):
"""Uploads Pandas DataFrame to a bucket in Google Cloud Storage."""

def __init__(self, bucket_name=None, client=None, credentials=None, project=None):
def __init__(
self, bucket_name=None, client=None, credentials=None, project=None
):
"""Constructor.

Args:
@@ -65,7 +67,9 @@ def __init__(self, bucket_name=None, client=None, credentials=None, project=None
if client is not None:
self.client = client
elif credentials is not None:
self.client = storage.Client(credentials=credentials, project=project)
self.client = storage.Client(
credentials=credentials, project=project
)
else:
self.client = storage.Client()

@@ -97,7 +101,9 @@ def ensure_bucket_exists(self, project, region):
except (exceptions.Forbidden, exceptions.NotFound) as e:
if isinstance(e, exceptions.Forbidden):
used_bucket_name = self.bucket_name
self.bucket_name = used_bucket_name + "-{}".format(int(time.time()))
self.bucket_name = used_bucket_name + "-{}".format(
int(time.time())
)
_LOGGER.warning(
"Created a bucket named {} because a bucket named {} already exists in a different project.".format(
self.bucket_name, used_bucket_name
@@ -123,10 +129,14 @@ def upload_pandas_dataframe(self, dataframe, uploaded_csv_name=None):
raise ImportError(_PANDAS_REQUIRED)

if not isinstance(dataframe, pandas.DataFrame):
raise ValueError("'dataframe' must be a pandas.DataFrame instance.")
raise ValueError(
"'dataframe' must be a pandas.DataFrame instance."
)

if self.bucket_name is None:
raise ValueError("Must ensure a bucket exists before uploading data.")
raise ValueError(
"Must ensure a bucket exists before uploading data."
)

if uploaded_csv_name is None:
uploaded_csv_name = "automl-tables-dataframe-{}.csv".format(
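The hunks above only re-wrap calls inside GcsClient; its behavior is unchanged. A minimal usage sketch of the helper, assuming default credentials and an illustrative project, region, and DataFrame (none of these names come from the diff):

import pandas

from google.cloud.automl_v1beta1.tables import gcs_client

dataframe = pandas.DataFrame({"feature": [1, 2, 3], "label": [0, 1, 0]})

# Mirrors what TablesClient.import_data() does internally for DataFrames.
client = gcs_client.GcsClient(project="my-project")
client.ensure_bucket_exists("my-project", "us-central1")

# Uploads the DataFrame as a CSV and returns its gs:// URI.
gcs_input_uri = client.upload_pandas_dataframe(dataframe)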
114 changes: 88 additions & 26 deletions automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -25,7 +25,9 @@
from google.cloud.automl_v1beta1.proto import data_types_pb2
from google.cloud.automl_v1beta1.tables import gcs_client

_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version
_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution(
"google-cloud-automl"
).version
_LOGGER = logging.getLogger(__name__)


@@ -187,7 +189,11 @@ def __dataset_from_args(
region=None,
**kwargs
):
if dataset is None and dataset_display_name is None and dataset_name is None:
if (
dataset is None
and dataset_display_name is None
and dataset_name is None
):
raise ValueError(
"One of 'dataset', 'dataset_name' or "
"'dataset_display_name' must be set."
@@ -216,7 +222,8 @@ def __model_from_args(
):
if model is None and model_display_name is None and model_name is None:
raise ValueError(
"One of 'model', 'model_name' or " "'model_display_name' must be set."
"One of 'model', 'model_name' or "
"'model_display_name' must be set."
)
# we prefer to make a live call here in the case that the
# model object is out-of-date
@@ -240,7 +247,11 @@ def __dataset_name_from_args(
region=None,
**kwargs
):
if dataset is None and dataset_display_name is None and dataset_name is None:
if (
dataset is None
and dataset_display_name is None
and dataset_name is None
):
raise ValueError(
"One of 'dataset', 'dataset_name' or "
"'dataset_display_name' must be set."
@@ -259,7 +270,10 @@ def __dataset_name_from_args(
else:
# we do this to force a NotFound error when needed
self.get_dataset(
dataset_name=dataset_name, project=project, region=region, **kwargs
dataset_name=dataset_name,
project=project,
region=region,
**kwargs
)
return dataset_name

@@ -283,7 +297,8 @@ def __table_spec_name_from_args(
)

table_specs = [
t for t in self.list_table_specs(dataset_name=dataset_name, **kwargs)
t
for t in self.list_table_specs(dataset_name=dataset_name, **kwargs)
]

table_spec_full_id = table_specs[table_spec_index].name
@@ -300,7 +315,8 @@ def __model_name_from_args(
):
if model is None and model_display_name is None and model_name is None:
raise ValueError(
"One of 'model', 'model_name' or " "'model_display_name' must be set."
"One of 'model', 'model_name' or "
"'model_display_name' must be set."
)

if model_name is None:
@@ -527,7 +543,8 @@ def get_dataset(
"""
if dataset_name is None and dataset_display_name is None:
raise ValueError(
"One of 'dataset_name' or " "'dataset_display_name' must be set."
"One of 'dataset_name' or "
"'dataset_display_name' must be set."
)

if dataset_name is not None:
@@ -540,7 +557,12 @@
)

def create_dataset(
self, dataset_display_name, metadata={}, project=None, region=None, **kwargs
self,
dataset_display_name,
metadata={},
project=None,
region=None,
**kwargs
):
"""Create a dataset. Keep in mind, importing data is a separate step.

@@ -580,7 +602,10 @@ def create_dataset(
"""
return self.auto_ml_client.create_dataset(
self.__location_path(project, region),
{"display_name": dataset_display_name, "tables_dataset_metadata": metadata},
{
"display_name": dataset_display_name,
"tables_dataset_metadata": metadata,
},
**kwargs
)

@@ -767,7 +792,9 @@ def import_data(
credentials = credentials or self.credentials
self.__ensure_gcs_client_is_initialized(credentials, project)
self.gcs_client.ensure_bucket_exists(project, region)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
pandas_dataframe
)
request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
elif gcs_input_uris is not None:
if type(gcs_input_uris) != list:
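This hunk only re-wraps the DataFrame upload call; the import flow itself is unchanged. A minimal sketch of creating a dataset and importing a DataFrame through TablesClient, assuming the client takes a default project and region at construction time (the project, region, and dataset names are placeholders, not from this diff):

import pandas

from google.cloud import automl_v1beta1

client = automl_v1beta1.TablesClient(project="my-project", region="us-central1")

# Creating a dataset and importing data into it are separate steps.
dataset = client.create_dataset("my_dataset")

dataframe = pandas.DataFrame({"feature": [1, 2, 3], "label": [0, 1, 0]})

# The DataFrame is uploaded to GCS via GcsClient, then imported from there.
op = client.import_data(dataset=dataset, pandas_dataframe=dataframe)
op.result()  # block until the import operation finishes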
@@ -868,9 +895,13 @@ def export_data(

request = {}
if gcs_output_uri_prefix is not None:
request = {"gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}}
request = {
"gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}
}
elif bigquery_output_uri is not None:
request = {"bigquery_destination": {"output_uri": bigquery_output_uri}}
request = {
"bigquery_destination": {"output_uri": bigquery_output_uri}
}
else:
raise ValueError(
"One of 'gcs_output_uri_prefix', or 'bigquery_output_uri' must be set."
@@ -880,7 +911,9 @@
self.__log_operation_info("Export data", op)
return op

def get_table_spec(self, table_spec_name, project=None, region=None, **kwargs):
def get_table_spec(
self, table_spec_name, project=None, region=None, **kwargs
):
"""Gets a single table spec in a particular project and region.

Example:
@@ -992,7 +1025,9 @@ def list_table_specs(

return self.auto_ml_client.list_table_specs(dataset_name, **kwargs)

def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs):
def get_column_spec(
self, column_spec_name, project=None, region=None, **kwargs
):
"""Gets a single column spec in a particular project and region.

Example:
@@ -1572,7 +1607,10 @@ def clear_time_column(
dataset_name=dataset_name, **kwargs
)

my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None}
my_table_spec = {
"name": table_spec_full_id,
"time_column_spec_id": None,
}

return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs)

@@ -1766,7 +1804,9 @@ def clear_weight_column(
**kwargs
)
metadata = dataset.tables_dataset_metadata
metadata = self.__update_metadata(metadata, "weight_column_spec_id", None)
metadata = self.__update_metadata(
metadata, "weight_column_spec_id", None
)

request = {"name": dataset.name, "tables_dataset_metadata": metadata}

@@ -1964,7 +2004,9 @@ def clear_test_train_column(
**kwargs
)
metadata = dataset.tables_dataset_metadata
metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None)
metadata = self.__update_metadata(
metadata, "ml_use_column_spec_id", None
)

request = {"name": dataset.name, "tables_dataset_metadata": metadata}

@@ -2217,7 +2259,9 @@ def create_model(
**kwargs
)

model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours
model_metadata[
"train_budget_milli_node_hours"
] = train_budget_milli_node_hours
if optimization_objective is not None:
model_metadata["optimization_objective"] = optimization_objective
if disable_early_stopping:
@@ -2255,7 +2299,9 @@ def create_model(
}

op = self.auto_ml_client.create_model(
self.__location_path(project=project, region=region), request, **kwargs
self.__location_path(project=project, region=region),
request,
**kwargs
)
self.__log_operation_info("Model creation", op)
return op
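The re-wrapped assignment above is where the training budget lands in the model metadata. A minimal create_model sketch with an illustrative one node hour budget (1000 milli node hours); the model and dataset names are placeholders:

op = client.create_model(
    "my_model",
    dataset_display_name="my_dataset",
    train_budget_milli_node_hours=1000,
)
model = op.result()  # training may take an hour or longer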
@@ -2377,7 +2423,9 @@ def get_model_evaluation(
to a retryable error and retry attempts failed.
ValueError: If required parameters are missing.
"""
return self.auto_ml_client.get_model_evaluation(model_evaluation_name, **kwargs)
return self.auto_ml_client.get_model_evaluation(
model_evaluation_name, **kwargs
)

def get_model(
self,
@@ -2440,7 +2488,9 @@ def get_model(
return self.auto_ml_client.get_model(model_name, **kwargs)

return self.__lookup_by_display_name(
"model", self.list_models(project, region, **kwargs), model_display_name
"model",
self.list_models(project, region, **kwargs),
model_display_name,
)

# TODO(jonathanskim): allow deployment from just model ID
@@ -2596,6 +2646,7 @@ def predict(
model=None,
model_name=None,
model_display_name=None,
params=None,
project=None,
region=None,
**kwargs
@@ -2642,6 +2693,9 @@ def predict(
The `model` instance you want to predict with. This must be
supplied if `model_display_name` or `model_name` are not
supplied.
params (dict[str, str]):
`feature_importance` can be set to True to enable local
explainability. The default is False.

Returns:
A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
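The new `params` argument is the substantive change in this file; the rest of the hunks are re-wrapping. A sketch of a single-row prediction with local feature importance turned on — since `params` is typed dict[str, str], the flag is passed as the string "true" here, which is an assumption rather than something this diff spells out; the model name and input values are placeholders:

# One input value per column spec, in the dataset's column order.
response = client.predict(
    model_display_name="my_model",
    inputs=["value-for-column-1", "value-for-column-2"],
    params={"feature_importance": "true"},
)
print(response)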
@@ -2678,12 +2732,16 @@

values = []
for i, c in zip(inputs, column_specs):
value_type = self.__type_code_to_value_type(c.data_type.type_code, i)
value_type = self.__type_code_to_value_type(
c.data_type.type_code, i
)
values.append(value_type)

request = {"row": {"values": values}}

return self.prediction_client.predict(model.name, request, **kwargs)
return self.prediction_client.predict(
model.name, request, params, **kwargs
)

def batch_predict(
self,
@@ -2795,14 +2853,18 @@ def batch_predict(
credentials = credentials or self.credentials
self.__ensure_gcs_client_is_initialized(credentials, project)
self.gcs_client.ensure_bucket_exists(project, region)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
pandas_dataframe
)
input_request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
elif gcs_input_uris is not None:
if type(gcs_input_uris) != list:
gcs_input_uris = [gcs_input_uris]
input_request = {"gcs_source": {"input_uris": gcs_input_uris}}
elif bigquery_input_uri is not None:
input_request = {"bigquery_source": {"input_uri": bigquery_input_uri}}
input_request = {
"bigquery_source": {"input_uri": bigquery_input_uri}
}
else:
raise ValueError(
"One of 'gcs_input_uris'/'bigquery_input_uris' must" "be set"
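batch_predict chooses its input source the same way import_data does (an uploaded DataFrame, GCS URIs, or a BigQuery URI). A sketch of the GCS-to-GCS form with placeholder URIs; the gcs_output_uri_prefix argument for the output side is an assumption, since this hunk only shows how the input request is built:

op = client.batch_predict(
    model_display_name="my_model",
    gcs_input_uris="gs://my-bucket/batch-inputs.csv",
    gcs_output_uri_prefix="gs://my-bucket/batch-predictions",
)
op.result()  # wait for the batch prediction job to finish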