diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py
index c3a03823cb7e..6687bdf69337 100644
--- a/ibis/backends/sql/__init__.py
+++ b/ibis/backends/sql/__init__.py
@@ -423,7 +423,7 @@ def insert(
         Parameters
         ----------
         name
-            The name of the table to which data needs will be inserted
+            The name of the table to which data will be inserted
         obj
             The source data or expression to insert
         database
@@ -453,22 +453,31 @@ def insert(
         with self._safe_raw_sql(query):
             pass
 
-    def _build_insert_from_table(
+    def _get_columns_to_insert(
         self, *, target: str, source, db: str | None = None, catalog: str | None = None
     ):
-        compiler = self.compiler
-        quoted = compiler.quoted
+        """Return the column names to list when writing `source` into `target`."""
         # Compare the columns between the target table and the object to be inserted
         # If source is a subset of target, use source columns for insert list
         # Otherwise, assume auto-generated column names and use positional ordering.
         target_cols = self.get_schema(target, catalog=catalog, database=db).keys()
 
-        columns = (
+        return (
             source_cols
             if (source_cols := source.schema().keys()) <= target_cols
             else target_cols
         )
 
+    def _build_insert_from_table(
+        self, *, target: str, source, db: str | None = None, catalog: str | None = None
+    ):
+        compiler = self.compiler
+        quoted = compiler.quoted
+
+        columns = self._get_columns_to_insert(
+            target=target, source=source, db=db, catalog=catalog
+        )
+
         query = sge.insert(
             expression=self.compile(source),
             into=sg.table(target, db=db, catalog=catalog, quoted=quoted),
@@ -526,6 +535,126 @@ def _build_insert_template(
             ),
         ).sql(self.dialect)
 
+    def upsert(
+        self,
+        name: str,
+        /,
+        obj: pd.DataFrame | ir.Table | list | dict,
+        on: str,
+        *,
+        database: str | None = None,
+    ) -> None:
+        """Upsert data into a table.
+
+        ::: {.callout-note}
+        ## Ibis does not use the word `schema` to refer to database hierarchy.
+
+        A collection of `table` is referred to as a `database`.
+        A collection of `database` is referred to as a `catalog`.
+
+        These terms are mapped onto the corresponding features in each
+        backend (where available), regardless of whether the backend itself
+        uses the same terminology.
+        :::
+
+        Parameters
+        ----------
+        name
+            The name of the table to which data will be upserted
+        obj
+            The source data or expression to upsert
+        on
+            Name of the column used to match rows of `obj` with rows of the table
+        database
+            Name of the attached database that the table is located in.
+
+            For backends that support multi-level table hierarchies, you can
+            pass in a dotted string path like `"catalog.database"` or a tuple of
+            strings like `("catalog", "database")`.
+        """
+        table_loc = self._to_sqlglot_table(database)
+        catalog, db = self._to_catalog_db_tuple(table_loc)
+
+        if not isinstance(obj, ir.Table):
+            obj = ibis.memtable(obj)
+
+        self._run_pre_execute_hooks(obj)
+
+        query = self._build_upsert_from_table(
+            target=name, source=obj, on=on, db=db, catalog=catalog
+        )
+
+        with self._safe_raw_sql(query):
+            pass
+
+    def _build_upsert_from_table(
+        self,
+        *,
+        target: str,
+        source,
+        on: str,
+        db: str | None = None,
+        catalog: str | None = None,
+    ):
+        """Build the `MERGE INTO` statement that upserts `source` into `target`.
+
+        Target rows whose `on` column matches a source row get their other
+        columns updated; source rows without a match are inserted.
+        """
+        compiler = self.compiler
+        quoted = compiler.quoted
+
+        columns = self._get_columns_to_insert(
+            target=target, source=source, db=db, catalog=catalog
+        )
+
+        source_alias = util.gen_name("source")
+        target_alias = util.gen_name("target")
+        # `using` is handed to sqlglot as a string and parsed with the backend
+        # dialect, so render the alias with that same dialect's quoting rules.
+        source_ident = sg.to_identifier(source_alias, quoted=quoted).sql(
+            compiler.dialect
+        )
+        query = sge.merge(
+            sge.When(
+                matched=True,
+                then=sge.Update(
+                    expressions=[
+                        sg.column(col, quoted=quoted).eq(
+                            sg.column(col, table=source_alias, quoted=quoted)
+                        )
+                        for col in columns
+                        if col != on
+                    ]
+                ),
+            ),
+            sge.When(
+                matched=False,
+                then=sge.Insert(
+                    this=sge.Tuple(
+                        expressions=[sg.column(col, quoted=quoted) for col in columns]
+                    ),
+                    expression=sge.Tuple(
+                        expressions=[
+                            sg.column(col, table=source_alias, quoted=quoted)
+                            for col in columns
+                        ]
+                    ),
+                ),
+            ),
+            into=sg.table(target, db=db, catalog=catalog, quoted=quoted).as_(
+                sg.to_identifier(target_alias, quoted=quoted), table=True
+            ),
+            using=f"({self.compile(source)}) AS {source_ident}",
+            on=sge.Paren(
+                this=sg.column(on, table=target_alias, quoted=quoted).eq(
+                    sg.column(on, table=source_alias, quoted=quoted)
+                )
+            ),
+            dialect=compiler.dialect,
+        )
+        return query
+
     def truncate_table(self, name: str, /, *, database: str | None = None) -> None:
         """Delete all rows from a table.
 
diff --git a/ibis/backends/tests/conftest.py b/ibis/backends/tests/conftest.py
index a9c69b64b57b..6a5e16a94b4b 100644
--- a/ibis/backends/tests/conftest.py
+++ b/ibis/backends/tests/conftest.py
@@ -1,9 +1,21 @@
 from __future__ import annotations
 
+import sqlite3
+
 import pytest
+from packaging.version import parse as vparse
 
 import ibis.common.exceptions as com
-from ibis.backends.tests.errors import MySQLOperationalError
+from ibis.backends.tests.errors import (
+    ClickHouseDatabaseError,
+    ImpalaHiveServer2Error,
+    MySQLOperationalError,
+    MySQLProgrammingError,
+    PsycoPg2InternalError,
+    Py4JJavaError,
+    PySparkUnsupportedOperationException,
+    TrinoUserError,
+)
 
 
 def combine_marks(marks: list) -> callable:
@@ -50,7 +62,6 @@ def decorator(func):
 ]
 NO_ARRAY_SUPPORT = combine_marks(NO_ARRAY_SUPPORT_MARKS)
 
-
 NO_STRUCT_SUPPORT_MARKS = [
     pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
     pytest.mark.notyet(["impala"]),
@@ -78,3 +89,53 @@ def decorator(func):
     pytest.mark.notimpl(["datafusion", "exasol", "mssql", "druid", "oracle"]),
 ]
 NO_JSON_SUPPORT = combine_marks(NO_JSON_SUPPORT_MARKS)
+
+try:
+    import pyspark
+
+    pyspark_merge_exception = (
+        PySparkUnsupportedOperationException
+        if vparse(pyspark.__version__) >= vparse("3.5")
+        else Py4JJavaError
+    )
+except ImportError:
+    pyspark_merge_exception = None
+
+NO_MERGE_SUPPORT_MARKS = [
+    pytest.mark.notyet(
+        ["clickhouse"],
+        raises=ClickHouseDatabaseError,
+        reason="MERGE INTO is not supported",
+    ),
+    pytest.mark.notyet(["datafusion"], reason="MERGE INTO is not supported"),
+    pytest.mark.notyet(
+        ["impala"],
+        raises=ImpalaHiveServer2Error,
+        reason="target table must be an Iceberg table",
+    ),
+    pytest.mark.notyet(
+        ["mysql"], raises=MySQLProgrammingError, reason="MERGE INTO is not supported"
+    ),
+    pytest.mark.notimpl(["polars"], reason="`upsert` method not implemented"),
+    pytest.mark.notyet(
+        ["pyspark"],
+        raises=pyspark_merge_exception,
+        reason="MERGE INTO TABLE is not supported temporarily",
+    ),
+    pytest.mark.notyet(
+        ["risingwave"],
+        raises=PsycoPg2InternalError,
+        reason="MERGE INTO is not supported",
+    ),
+    pytest.mark.notyet(
+        ["sqlite"],
+        raises=sqlite3.OperationalError,
+        reason="MERGE INTO is not supported",
+    ),
+    pytest.mark.notyet(
+        ["trino"],
+        raises=TrinoUserError,
+        reason="connector does not support modifying table rows",
+    ),
+]
+NO_MERGE_SUPPORT = combine_marks(NO_MERGE_SUPPORT_MARKS)
diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py
index 0c06dcbf8d12..a961166c98b3 100644
--- a/ibis/backends/tests/errors.py
+++ b/ibis/backends/tests/errors.py
@@ -55,13 +55,18 @@
     from pyspark.errors.exceptions.base import ParseException as PySparkParseException
     from pyspark.errors.exceptions.base import PySparkValueError
     from pyspark.errors.exceptions.base import PythonException as PySparkPythonException
+    from pyspark.errors.exceptions.base import (
+        UnsupportedOperationException as PySparkUnsupportedOperationException,
+    )
     from pyspark.errors.exceptions.connect import (
         SparkConnectGrpcException as PySparkConnectGrpcException,
     )
 except ImportError:
     PySparkParseException = PySparkAnalysisException = PySparkArithmeticException = (
         PySparkPythonException
-    ) = PySparkConnectGrpcException = PySparkValueError = None
+    ) = PySparkUnsupportedOperationException = PySparkConnectGrpcException = (
+        PySparkValueError
+    ) = None
 
 try:
     from google.api_core.exceptions import BadRequest as GoogleBadRequest
diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py
index 45a8d1af74b5..f3d0d015fab6 100644
--- a/ibis/backends/tests/test_client.py
+++ b/ibis/backends/tests/test_client.py
@@ -25,6 +25,7 @@
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 from ibis.backends.conftest import ALL_BACKENDS
+from ibis.backends.tests.conftest import NO_MERGE_SUPPORT
 from ibis.backends.tests.errors import (
     DatabricksServerOperationError,
     ExaQueryError,
@@ -493,8 +494,6 @@ def employee_data_1_temp_table(backend, con, test_employee_schema):
 
 @pytest.fixture
 def test_employee_data_2():
-    import pandas as pd
-
     df2 = pd.DataFrame(
         {
             "first_name": ["X", "Y", "Z"],
@@ -519,6 +518,32 @@ def employee_data_2_temp_table(
     con.drop_table(temp_table_name, force=True)
 
 
+@pytest.fixture
+def test_employee_data_3():
+    df3 = pd.DataFrame(
+        {
+            "first_name": ["B", "Y", "Z"],
+            "last_name": ["A", "B", "C"],
+            "department_name": ["XX", "YY", "ZZ"],
+            "salary": [400.0, 500.0, 600.0],
+        }
+    )
+
+    return df3
+
+
+@pytest.fixture
+def employee_data_3_temp_table(
+    backend, con, test_employee_schema, test_employee_data_3
+):
+    temp_table_name = gen_name("temp_employee_data_3")
+    _create_temp_table_with_schema(
+        backend, con, temp_table_name, test_employee_schema, data=test_employee_data_3
+    )
+    yield temp_table_name
+    con.drop_table(temp_table_name, force=True)
+
+
 @pytest.mark.notimpl(["polars"], reason="`insert` method not implemented")
 def test_insert_no_overwrite_from_dataframe(
     backend, con, test_employee_data_2, employee_empty_temp_table
@@ -626,6 +651,105 @@ def _emp(a, b, c, d):
     assert len(con.table(employee_data_1_temp_table).execute()) == 3
 
 
+@NO_MERGE_SUPPORT
+def test_upsert_from_dataframe(
+    backend, con, employee_data_1_temp_table, test_employee_data_3
+):
+    temporary = con.table(employee_data_1_temp_table)
+    df1 = temporary.execute().set_index("first_name")
+
+    con.upsert(employee_data_1_temp_table, obj=test_employee_data_3, on="first_name")
+    result = temporary.execute()
+    df2 = test_employee_data_3.set_index("first_name")
+    expected = pd.concat([df1[~df1.index.isin(df2.index)], df2]).reset_index()
+    assert len(result) == len(expected)
+    backend.assert_frame_equal(
+        result.sort_values("first_name").reset_index(drop=True),
+        expected.sort_values("first_name").reset_index(drop=True),
+    )
+
+
+@NO_MERGE_SUPPORT
+@pytest.mark.parametrize(
+    "with_order_by",
+    [
+        pytest.param(
+            True,
+            marks=pytest.mark.notyet(
+                ["mssql"],
+                reason="MSSQL doesn't allow ORDER BY in subqueries, unless "
+                "TOP, OFFSET or FOR XML is also specified",
+            ),
+        ),
+        False,
+    ],
+)
+def test_upsert_from_expr(
+    backend, con, employee_data_1_temp_table, employee_data_3_temp_table, with_order_by
+):
+    temporary = con.table(employee_data_1_temp_table)
+    from_table = con.table(employee_data_3_temp_table)
+    if with_order_by:
+        from_table = from_table.filter(ibis._.salary > 0).order_by("first_name")
+
+    df1 = temporary.execute().set_index("first_name")
+
+    con.upsert(employee_data_1_temp_table, obj=from_table, on="first_name")
+    result = temporary.execute()
+    df2 = from_table.execute().set_index("first_name")
+    expected = pd.concat([df1[~df1.index.isin(df2.index)], df2]).reset_index()
+    assert len(result) == len(expected)
+    backend.assert_frame_equal(
+        result.sort_values("first_name").reset_index(drop=True),
+        expected.sort_values("first_name").reset_index(drop=True),
+    )
+
+
+@NO_MERGE_SUPPORT
+@pytest.mark.notyet(["druid"], raises=NotImplementedError)
+@pytest.mark.notimpl(
+    ["flink"],
+    raises=com.IbisError,
+    reason="`tbl_properties` is required when creating table with schema",
+)
+@pytest.mark.parametrize(
+    ("sch", "expectation"),
+    [
+        ({"x": "int64", "y": "float64", "z": "string"}, contextlib.nullcontext()),
+        ({"z": "!string", "y": "float32", "x": "int8"}, contextlib.nullcontext()),
+        ({"x": "int64"}, pytest.raises(Exception)),  # No cols to insert
+        ({"x": "int64", "z": "string"}, contextlib.nullcontext()),
+        ({"z": "string"}, pytest.raises(Exception)),  # Missing `on` col
+    ],
+)
+def test_upsert_from_memtable(backend, con, temp_table, sch, expectation):
+    t1 = ibis.memtable({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0], "z": ["a", "b", "c"]})
+    table_name = temp_table
+
+    data = {
+        k: v
+        for k, v in {"x": [3, 2, 6], "y": [7.0, 8.0, 9.0], "z": ["d", "e", "f"]}.items()
+        if k in sch
+    }
+    t2 = ibis.memtable(data, schema=sch)
+
+    con.create_table(table_name, schema=t1.schema())
+    con.upsert(table_name, t1, on="x")
+    temporary = con.table(table_name)
+    df1 = temporary.execute().set_index("x")
+
+    with expectation:
+        con.upsert(table_name, t2, on="x")
+
+        result = temporary.execute()
+        expected = pd.DataFrame(data).set_index("x").combine_first(df1).reset_index()
+        assert len(result) == len(expected)
+        backend.assert_frame_equal(
+            result.sort_values("x").reset_index(drop=True),
+            expected.sort_values("x").reset_index(drop=True),
+        )
+
+
 @pytest.mark.notimpl(
     ["polars"], raises=AttributeError, reason="`insert` method not implemented"
 )