diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index cc6fc052..df1b05d5 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -11,7 +11,10 @@ _CustomClass = TypeVar( class QueryResult: """Result.""" - def result(self: Self) -> list[dict[Any, Any]]: + def result( + self: Self, + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + ) -> list[dict[Any, Any]]: """Return result from database as a list of dicts.""" def as_class( self: Self, diff --git a/python/psqlpy/_internal/extra_types.pyi b/python/psqlpy/_internal/extra_types.pyi index 4346412c..525b314d 100644 --- a/python/psqlpy/_internal/extra_types.pyi +++ b/python/psqlpy/_internal/extra_types.pyi @@ -123,3 +123,6 @@ class PyMacAddr8: ### Parameters: - `value`: value for MACADDR8 field. """ + +class PyCustomType: + def __init__(self, value: bytes) -> None: ... diff --git a/python/psqlpy/extra_types.py b/python/psqlpy/extra_types.py index a65ff2ec..3fc13125 100644 --- a/python/psqlpy/extra_types.py +++ b/python/psqlpy/extra_types.py @@ -1,6 +1,7 @@ from ._internal.extra_types import ( BigInt, Integer, + PyCustomType, PyJSON, PyJSONB, PyMacAddr6, @@ -22,4 +23,5 @@ "PyMacAddr8", "PyVarChar", "PyText", + "PyCustomType", ] diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 24aa8706..c11070a9 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -8,6 +8,7 @@ from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass from psqlpy import ConnectionPool +from psqlpy._internal.extra_types import PyCustomType from psqlpy.extra_types import ( BigInt, Integer, @@ -425,3 +426,52 @@ class TopLevelModel(BaseModel): ) assert isinstance(model_result[0], TopLevelModel) + + +async def test_custom_type_as_parameter( + psql_pool: ConnectionPool, +) -> None: + """Tests that we can use `PyCustomType`.""" + await psql_pool.execute("DROP TABLE IF EXISTS for_test") + await psql_pool.execute( + "CREATE TABLE for_test (nickname VARCHAR)", + ) + + await psql_pool.execute( + querystring="INSERT INTO for_test VALUES ($1)", + parameters=[PyCustomType(b"Some Real Nickname")], + ) + + qs_result = await psql_pool.execute( + "SELECT * FROM for_test", + ) + + result = qs_result.result() + assert result[0]["nickname"] == "Some Real Nickname" + + +async def test_custom_decoder( + psql_pool: ConnectionPool, +) -> None: + await psql_pool.execute("DROP TABLE IF EXISTS for_test") + await psql_pool.execute( + "CREATE TABLE for_test (geo_point POINT)", + ) + + await psql_pool.execute( + "INSERT INTO for_test VALUES ('(1, 1)')", + ) + + def point_encoder(point_bytes: bytes) -> str: + return "Just An Example" + + qs_result = await psql_pool.execute( + "SELECT * FROM for_test", + ) + result = qs_result.result( + custom_decoders={ + "geo_point": point_encoder, + }, + ) + + assert result[0]["geo_point"] == "Just An Example" diff --git a/src/driver/connection.rs b/src/driver/connection.rs index c14580bd..159e5da6 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -257,7 +257,7 @@ impl Connection { }; Python::with_gil(|gil| match result.columns().first() { - Some(first_column) => postgres_to_py(gil, &result, first_column, 0), + Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None), None => Ok(gil.None()), }) } diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 1d0cabfa..8ed00cd7 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -331,7 +331,7 @@ impl Transaction { }; Python::with_gil(|gil| match result.columns().first() { - Some(first_column) => postgres_to_py(gil, &result, first_column, 0), + Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None), None => Ok(gil.None()), }) } diff --git a/src/extra_types.rs b/src/extra_types.rs index 76d6af74..d86b9a62 100644 --- a/src/extra_types.rs +++ b/src/extra_types.rs @@ -190,6 +190,27 @@ macro_rules! build_macaddr_type { build_macaddr_type!(PyMacAddr6, MacAddr6); build_macaddr_type!(PyMacAddr8, MacAddr8); +#[pyclass] +#[derive(Clone, Debug)] +pub struct PyCustomType { + inner: Vec, +} + +impl PyCustomType { + #[must_use] + pub fn inner(&self) -> Vec { + self.inner.clone() + } +} + +#[pymethods] +impl PyCustomType { + #[new] + fn new_class(type_bytes: Vec) -> Self { + PyCustomType { inner: type_bytes } + } +} + #[allow(clippy::module_name_repetitions)] #[allow(clippy::missing_errors_doc)] pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { @@ -203,5 +224,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; Ok(()) } diff --git a/src/query_result.rs b/src/query_result.rs index 3b50d637..a5c5a774 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -13,10 +13,11 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, + custom_decoders: &Option>, ) -> RustPSQLDriverPyResult> { let python_dict = PyDict::new_bound(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { - let python_type = postgres_to_py(py, postgres_row, column, column_idx)?; + let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; python_dict.set_item(column.name().to_object(py), python_type)?; } Ok(python_dict) @@ -55,10 +56,14 @@ impl PSQLDriverPyQueryResult { /// postgres type to python or set new key-value pair /// in python dict. #[allow(clippy::needless_pass_by_value)] - pub fn result(&self, py: Python<'_>) -> RustPSQLDriverPyResult> { + pub fn result( + &self, + py: Python<'_>, + custom_decoders: Option>, + ) -> RustPSQLDriverPyResult> { let mut result: Vec> = vec![]; for row in &self.inner { - result.push(row_to_dict(py, row)?); + result.push(row_to_dict(py, row, &custom_decoders)?); } Ok(result.to_object(py)) } @@ -77,7 +82,7 @@ impl PSQLDriverPyQueryResult { ) -> RustPSQLDriverPyResult> { let mut res: Vec> = vec![]; for row in &self.inner { - let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row)?; + let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?; let convert_class_inst = as_class.call_bound(py, (), Some(&pydict))?; res.push(convert_class_inst); } @@ -117,7 +122,7 @@ impl PSQLDriverSinglePyQueryResult { /// postgres type to python, can not set new key-value pair /// in python dict or there are no result. pub fn result(&self, py: Python<'_>) -> RustPSQLDriverPyResult> { - Ok(row_to_dict(py, &self.inner)?.to_object(py)) + Ok(row_to_dict(py, &self.inner, &None)?.to_object(py)) } /// Convert result from database to any class passed from Python. @@ -133,7 +138,7 @@ impl PSQLDriverSinglePyQueryResult { py: Python<'a>, as_class: Py, ) -> RustPSQLDriverPyResult> { - let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner)?; + let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?; Ok(as_class.call_bound(py, (), Some(&pydict))?) } } diff --git a/src/value_converter.rs b/src/value_converter.rs index 3f051354..59a3735f 100644 --- a/src/value_converter.rs +++ b/src/value_converter.rs @@ -23,8 +23,8 @@ use crate::{ additional_types::{RustMacAddr6, RustMacAddr8}, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, extra_types::{ - BigInt, Integer, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8, PyText, PyUUID, PyVarChar, - SmallInt, + BigInt, Integer, PyCustomType, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8, PyText, PyUUID, + PyVarChar, SmallInt, }, }; @@ -62,6 +62,7 @@ pub enum PythonDTO { PyJson(Value), PyMacAddr6(MacAddr6), PyMacAddr8(MacAddr8), + PyCustomType(Vec), } impl PythonDTO { @@ -174,6 +175,9 @@ impl ToSql for PythonDTO { match self { PythonDTO::PyNone => {} + PythonDTO::PyCustomType(some_bytes) => { + <&[u8] as ToSql>::to_sql(&some_bytes.as_slice(), ty, out)?; + } PythonDTO::PyBytes(pybytes) => { as ToSql>::to_sql(pybytes, ty, out)?; } @@ -284,6 +288,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< return Ok(PythonDTO::PyNone); } + if parameter.is_instance_of::() { + return Ok(PythonDTO::PyCustomType( + parameter.extract::()?.inner(), + )); + } + if parameter.is_instance_of::() { return Ok(PythonDTO::PyBool(parameter.extract::()?)); } @@ -652,10 +662,9 @@ fn postgres_bytes_to_py( None => Ok(py.None().to_object(py)), } } - _ => Ok( - _composite_field_postgres_to_py::>>(type_, buf, is_simple)? - .to_object(py), - ), + _ => Err(RustPSQLDriverError::RustToPyValueConversionError( + format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.") + )), } } @@ -720,9 +729,21 @@ pub fn postgres_to_py( row: &Row, column: &Column, column_i: usize, + custom_decoders: &Option>, ) -> RustPSQLDriverPyResult> { - let column_type = column.type_(); let raw_bytes_data = row.col_buffer(column_i); + + if let Some(custom_decoders) = custom_decoders { + let py_encoder_func = custom_decoders + .bind(py) + .get_item(column.name().to_lowercase()); + + if let Ok(Some(py_encoder_func)) = py_encoder_func { + return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind()); + } + } + + let column_type = column.type_(); match raw_bytes_data { Some(mut raw_bytes_data) => match column_type.kind() { Kind::Simple | Kind::Array(_) => {