Skip to content

Commit af88c42

Browse files
committed
Adding capability to invoke python catalog and schema providers as well
1 parent ed951fd commit af88c42

File tree

7 files changed

+267
-35
lines changed

7 files changed

+267
-35
lines changed

examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,19 @@
1919

2020
import pyarrow as pa
2121

22-
from datafusion import SessionContext
22+
from datafusion import SessionContext, Table
2323
from datafusion_ffi_example import MyCatalogProvider
2424

25+
from datafusion.context import PyCatalogProvider, PySchemaProvider
26+
27+
2528
def test_catalog_provider():
2629
ctx = SessionContext()
2730

2831
my_catalog_name = "my_catalog"
2932
expected_schema_name = "my_schema"
3033
expected_table_name = "my_table"
31-
expected_table_columns = ['units', 'price']
34+
expected_table_columns = ["units", "price"]
3235

3336
catalog_provider = MyCatalogProvider()
3437
ctx.register_catalog_provider(my_catalog_name, catalog_provider)
@@ -41,12 +44,9 @@ def test_catalog_provider():
4144
my_table = my_database.table(expected_table_name)
4245
assert expected_table_columns == my_table.schema.names
4346

44-
ctx.register_table(expected_table_name, my_table)
45-
expected_df = ctx.sql(f"SELECT * FROM {expected_table_name}").to_pandas()
46-
assert len(expected_df) == 5
47-
assert expected_table_columns == expected_df.columns.tolist()
48-
49-
result = ctx.table(f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}").collect()
47+
result = ctx.table(
48+
f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}"
49+
).collect()
5050
assert len(result) == 2
5151

5252
col0_result = [r.column(0) for r in result]
@@ -60,4 +60,58 @@ def test_catalog_provider():
6060
pa.array([1.5, 2.5], type=pa.float64()),
6161
]
6262
assert col0_result == expected_col0
63-
assert col1_result == expected_col1
63+
assert col1_result == expected_col1
64+
65+
66+
class MyPyCatalogProvider(PyCatalogProvider):
67+
my_schemas = ['my_schema']
68+
69+
def schema_names(self) -> list[str]:
70+
return self.my_schemas
71+
72+
def schema(self, name: str) -> PySchemaProvider:
73+
return MyPySchemaProvider()
74+
75+
76+
class MyPySchemaProvider(PySchemaProvider):
77+
my_tables = ['table1', 'table2', 'table3']
78+
79+
def table_names(self) -> list[str]:
80+
return self.my_tables
81+
82+
def table_exist(self, table_name: str) -> bool:
83+
return table_name in self.my_tables
84+
85+
def table(self, table_name: str) -> Table:
86+
raise RuntimeError(f"Can not get table: {table_name}")
87+
88+
def register_table(self, table: Table) -> None:
89+
raise RuntimeError(f"Can not register {table} as table")
90+
91+
def deregister_table(self, table_name: str) -> None:
92+
raise RuntimeError(f"Can not deregister table: {table_name}")
93+
94+
95+
def test_python_catalog_provider():
96+
ctx = SessionContext()
97+
98+
my_catalog_name = "my_py_catalog"
99+
expected_schema_name = "my_schema"
100+
my_py_catalog_provider = MyPyCatalogProvider()
101+
ctx.register_catalog_provider(my_catalog_name, my_py_catalog_provider)
102+
my_py_catalog = ctx.catalog(my_catalog_name)
103+
assert MyPyCatalogProvider.my_schemas == my_py_catalog.names()
104+
105+
my_database = my_py_catalog.database(expected_schema_name)
106+
assert set(MyPySchemaProvider.my_tables) == my_database.names()
107+
108+
# asserting a non-compliant provider fails at the python level as expected
109+
try:
110+
ctx.register_catalog_provider(my_catalog_name, "non_compliant_provider")
111+
except TypeError:
112+
# expect a TypeError because we can not register a str as a catalog provider
113+
pass
114+
115+
116+
if __name__ == "__main__":
117+
test_python_catalog_provider()

examples/datafusion-ffi-example/src/catalog_provider.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::{any::Any, fmt::Debug, sync::Arc};
1918
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
19+
use std::{any::Any, fmt::Debug, sync::Arc};
2020

2121
use arrow::datatypes::Schema;
2222
use async_trait::async_trait;
2323
use datafusion::{
2424
catalog::{
25-
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
26-
TableProvider,
25+
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider,
2726
},
2827
common::exec_err,
2928
datasource::MemTable,
@@ -46,12 +45,12 @@ pub fn my_table() -> Arc<dyn TableProvider + 'static> {
4645
("units", Int32, vec![10, 20, 30]),
4746
("price", Float64, vec![1.0, 2.0, 5.0])
4847
)
49-
.unwrap(),
48+
.unwrap(),
5049
record_batch!(
5150
("units", Int32, vec![5, 7]),
5251
("price", Float64, vec![1.5, 2.5])
5352
)
54-
.unwrap(),
53+
.unwrap(),
5554
];
5655

5756
Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
@@ -68,9 +67,7 @@ impl Default for FixedSchemaProvider {
6867

6968
let table = my_table();
7069

71-
let _ = inner
72-
.register_table("my_table".to_string(), table)
73-
.unwrap();
70+
let _ = inner.register_table("my_table".to_string(), table).unwrap();
7471

7572
Self { inner }
7673
}
@@ -86,10 +83,7 @@ impl SchemaProvider for FixedSchemaProvider {
8683
self.inner.table_names()
8784
}
8885

89-
async fn table(
90-
&self,
91-
name: &str,
92-
) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
86+
async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
9387
self.inner.table(name).await
9488
}
9589

@@ -110,7 +104,6 @@ impl SchemaProvider for FixedSchemaProvider {
110104
}
111105
}
112106

113-
114107
/// This catalog provider is intended only for unit tests. It prepopulates with one
115108
/// schema and only allows for schemas named after four types of fruit.
116109
#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)]
@@ -178,4 +171,4 @@ impl MyCatalogProvider {
178171

179172
PyCapsule::new(py, catalog_provider, Some(name))
180173
}
181-
}
174+
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::catalog_provider::MyCatalogProvider;
1819
use crate::table_function::MyTableFunction;
1920
use crate::table_provider::MyTableProvider;
20-
use crate::catalog_provider::MyCatalogProvider;
2121
use pyo3::prelude::*;
2222

23+
pub(crate) mod catalog_provider;
2324
pub(crate) mod table_function;
2425
pub(crate) mod table_provider;
25-
pub(crate) mod catalog_provider;
2626

2727
#[pymodule]
2828
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {

python/datafusion/context.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from __future__ import annotations
2121

2222
import warnings
23-
from typing import TYPE_CHECKING, Any, Protocol
23+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
2424

2525
import pyarrow as pa
2626

@@ -79,13 +79,40 @@ class TableProviderExportable(Protocol):
7979

8080
def __datafusion_table_provider__(self) -> object: ... # noqa: D105
8181

82-
82+
@runtime_checkable
8383
class CatalogProviderExportable(Protocol):
8484
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
8585
8686
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
8787
"""
88-
def __datafusion_catalog_provider__(self) -> object: ...
88+
89+
def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105
90+
91+
@runtime_checkable
92+
class PySchemaProvider(Protocol):
93+
def table_names(self) -> list[str]:
94+
...
95+
96+
def register_table(self, table: Table) -> None:
97+
...
98+
99+
def deregister_table(self, table_name: str) -> None:
100+
...
101+
102+
def table_exist(self, table_name: str) -> bool:
103+
...
104+
105+
def table(self, table_name: str) -> Table:
106+
...
107+
108+
109+
@runtime_checkable
110+
class PyCatalogProvider(Protocol):
111+
def schema_names(self) -> list[str]:
112+
...
113+
114+
def schema(self, name: str) -> PySchemaProvider:
115+
...
89116

90117

91118
class SessionConfig:
@@ -758,9 +785,14 @@ def deregister_table(self, name: str) -> None:
758785
self.ctx.deregister_table(name)
759786

760787
def register_catalog_provider(
761-
self, name: str, provider: CatalogProviderExportable
788+
self, name: str, provider: PyCatalogProvider | CatalogProviderExportable
762789
) -> None:
763790
"""Register a catalog provider."""
791+
if not isinstance(provider, (PyCatalogProvider, CatalogProviderExportable)):
792+
raise TypeError(
793+
f"Expected provider to be CatalogProviderProtocol or rust version exposed through python, but got {type(provider)} instead."
794+
)
795+
764796
self.ctx.register_catalog_provider(name, provider)
765797

766798
def register_table_provider(

0 commit comments

Comments
 (0)