Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/psqlpy/_internal/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ _CustomClass = TypeVar(
class QueryResult:
"""Result."""

def result(self: Self) -> list[dict[Any, Any]]:
def result(
self: Self,
custom_decoders: dict[str, Callable[[bytes], Any]] | None = None,
) -> list[dict[Any, Any]]:
"""Return result from database as a list of dicts."""
def as_class(
self: Self,
Expand Down
3 changes: 3 additions & 0 deletions python/psqlpy/_internal/extra_types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,6 @@ class PyMacAddr8:
### Parameters:
- `value`: value for MACADDR8 field.
"""

class PyCustomType:
def __init__(self, value: bytes) -> None: ...
2 changes: 2 additions & 0 deletions python/psqlpy/extra_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._internal.extra_types import (
BigInt,
Integer,
PyCustomType,
PyJSON,
PyJSONB,
PyMacAddr6,
Expand All @@ -22,4 +23,5 @@
"PyMacAddr8",
"PyVarChar",
"PyText",
"PyCustomType",
]
50 changes: 50 additions & 0 deletions python/tests/test_value_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass

from psqlpy import ConnectionPool
from psqlpy._internal.extra_types import PyCustomType
from psqlpy.extra_types import (
BigInt,
Integer,
Expand Down Expand Up @@ -425,3 +426,52 @@ class TopLevelModel(BaseModel):
)

assert isinstance(model_result[0], TopLevelModel)


async def test_custom_type_as_parameter(
psql_pool: ConnectionPool,
) -> None:
"""Tests that we can use `PyCustomType`."""
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
await psql_pool.execute(
"CREATE TABLE for_test (nickname VARCHAR)",
)

await psql_pool.execute(
querystring="INSERT INTO for_test VALUES ($1)",
parameters=[PyCustomType(b"Some Real Nickname")],
)

qs_result = await psql_pool.execute(
"SELECT * FROM for_test",
)

result = qs_result.result()
assert result[0]["nickname"] == "Some Real Nickname"


async def test_custom_decoder(
psql_pool: ConnectionPool,
) -> None:
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
await psql_pool.execute(
"CREATE TABLE for_test (geo_point POINT)",
)

await psql_pool.execute(
"INSERT INTO for_test VALUES ('(1, 1)')",
)

def point_encoder(point_bytes: bytes) -> str:
return "Just An Example"

qs_result = await psql_pool.execute(
"SELECT * FROM for_test",
)
result = qs_result.result(
custom_decoders={
"geo_point": point_encoder,
},
)

assert result[0]["geo_point"] == "Just An Example"
2 changes: 1 addition & 1 deletion src/driver/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl Connection {
};

Python::with_gil(|gil| match result.columns().first() {
Some(first_column) => postgres_to_py(gil, &result, first_column, 0),
Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None),
None => Ok(gil.None()),
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/driver/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ impl Transaction {
};

Python::with_gil(|gil| match result.columns().first() {
Some(first_column) => postgres_to_py(gil, &result, first_column, 0),
Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None),
None => Ok(gil.None()),
})
}
Expand Down
22 changes: 22 additions & 0 deletions src/extra_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ macro_rules! build_macaddr_type {
build_macaddr_type!(PyMacAddr6, MacAddr6);
build_macaddr_type!(PyMacAddr8, MacAddr8);

#[pyclass]
#[derive(Clone, Debug)]
pub struct PyCustomType {
inner: Vec<u8>,
}

impl PyCustomType {
#[must_use]
pub fn inner(&self) -> Vec<u8> {
self.inner.clone()
}
}

#[pymethods]
impl PyCustomType {
#[new]
fn new_class(type_bytes: Vec<u8>) -> Self {
PyCustomType { inner: type_bytes }
}
}

#[allow(clippy::module_name_repetitions)]
#[allow(clippy::missing_errors_doc)]
pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -203,5 +224,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
pymod.add_class::<PyJSON>()?;
pymod.add_class::<PyMacAddr6>()?;
pymod.add_class::<PyMacAddr8>()?;
pymod.add_class::<PyCustomType>()?;
Ok(())
}
17 changes: 11 additions & 6 deletions src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po
fn row_to_dict<'a>(
py: Python<'a>,
postgres_row: &'a Row,
custom_decoders: &Option<Py<PyDict>>,
) -> RustPSQLDriverPyResult<pyo3::Bound<'a, PyDict>> {
let python_dict = PyDict::new_bound(py);
for (column_idx, column) in postgres_row.columns().iter().enumerate() {
let python_type = postgres_to_py(py, postgres_row, column, column_idx)?;
let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?;
python_dict.set_item(column.name().to_object(py), python_type)?;
}
Ok(python_dict)
Expand Down Expand Up @@ -55,10 +56,14 @@ impl PSQLDriverPyQueryResult {
/// postgres type to python or set new key-value pair
/// in python dict.
#[allow(clippy::needless_pass_by_value)]
pub fn result(&self, py: Python<'_>) -> RustPSQLDriverPyResult<Py<PyAny>> {
pub fn result(
&self,
py: Python<'_>,
custom_decoders: Option<Py<PyDict>>,
) -> RustPSQLDriverPyResult<Py<PyAny>> {
let mut result: Vec<pyo3::Bound<'_, PyDict>> = vec![];
for row in &self.inner {
result.push(row_to_dict(py, row)?);
result.push(row_to_dict(py, row, &custom_decoders)?);
}
Ok(result.to_object(py))
}
Expand All @@ -77,7 +82,7 @@ impl PSQLDriverPyQueryResult {
) -> RustPSQLDriverPyResult<Py<PyAny>> {
let mut res: Vec<Py<PyAny>> = vec![];
for row in &self.inner {
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row)?;
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?;
let convert_class_inst = as_class.call_bound(py, (), Some(&pydict))?;
res.push(convert_class_inst);
}
Expand Down Expand Up @@ -117,7 +122,7 @@ impl PSQLDriverSinglePyQueryResult {
/// postgres type to python, can not set new key-value pair
/// in python dict or there are no result.
pub fn result(&self, py: Python<'_>) -> RustPSQLDriverPyResult<Py<PyAny>> {
Ok(row_to_dict(py, &self.inner)?.to_object(py))
Ok(row_to_dict(py, &self.inner, &None)?.to_object(py))
}

/// Convert result from database to any class passed from Python.
Expand All @@ -133,7 +138,7 @@ impl PSQLDriverSinglePyQueryResult {
py: Python<'a>,
as_class: Py<PyAny>,
) -> RustPSQLDriverPyResult<Py<PyAny>> {
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner)?;
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?;
Ok(as_class.call_bound(py, (), Some(&pydict))?)
}
}
35 changes: 28 additions & 7 deletions src/value_converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::{
additional_types::{RustMacAddr6, RustMacAddr8},
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
extra_types::{
BigInt, Integer, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8, PyText, PyUUID, PyVarChar,
SmallInt,
BigInt, Integer, PyCustomType, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8, PyText, PyUUID,
PyVarChar, SmallInt,
},
};

Expand Down Expand Up @@ -62,6 +62,7 @@ pub enum PythonDTO {
PyJson(Value),
PyMacAddr6(MacAddr6),
PyMacAddr8(MacAddr8),
PyCustomType(Vec<u8>),
}

impl PythonDTO {
Expand Down Expand Up @@ -174,6 +175,9 @@ impl ToSql for PythonDTO {

match self {
PythonDTO::PyNone => {}
PythonDTO::PyCustomType(some_bytes) => {
<&[u8] as ToSql>::to_sql(&some_bytes.as_slice(), ty, out)?;
}
PythonDTO::PyBytes(pybytes) => {
<Vec<u8> as ToSql>::to_sql(pybytes, ty, out)?;
}
Expand Down Expand Up @@ -284,6 +288,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
return Ok(PythonDTO::PyNone);
}

if parameter.is_instance_of::<PyCustomType>() {
return Ok(PythonDTO::PyCustomType(
parameter.extract::<PyCustomType>()?.inner(),
));
}

if parameter.is_instance_of::<PyBool>() {
return Ok(PythonDTO::PyBool(parameter.extract::<bool>()?));
}
Expand Down Expand Up @@ -652,10 +662,9 @@ fn postgres_bytes_to_py(
None => Ok(py.None().to_object(py)),
}
}
_ => Ok(
_composite_field_postgres_to_py::<Option<Vec<u8>>>(type_, buf, is_simple)?
.to_object(py),
),
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
)),
}
}

Expand Down Expand Up @@ -720,9 +729,21 @@ pub fn postgres_to_py(
row: &Row,
column: &Column,
column_i: usize,
custom_decoders: &Option<Py<PyDict>>,
) -> RustPSQLDriverPyResult<Py<PyAny>> {
let column_type = column.type_();
let raw_bytes_data = row.col_buffer(column_i);

if let Some(custom_decoders) = custom_decoders {
let py_encoder_func = custom_decoders
.bind(py)
.get_item(column.name().to_lowercase());

if let Ok(Some(py_encoder_func)) = py_encoder_func {
return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
}
}

let column_type = column.type_();
match raw_bytes_data {
Some(mut raw_bytes_data) => match column_type.kind() {
Kind::Simple | Kind::Array(_) => {
Expand Down