Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 329 additions & 11 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ crate-type = ["cdylib"]

[dependencies]
deadpool-postgres = { git = "https://github.com/chandr-andr/deadpool.git", branch = "master" }
pyo3 = { version = "*", features = ["chrono", "experimental-async"] }
pyo3 = { version = "*", features = [
"chrono",
"experimental-async",
"rust_decimal",
] }
pyo3-asyncio = { git = "https://github.com/chandr-andr/pyo3-asyncio.git", version = "0.20.0", features = [
"tokio-runtime",
] }
Expand All @@ -34,3 +38,7 @@ postgres-types = { git = "https://github.com/chandr-andr/rust-postgres.git", bra
"derive",
] }
postgres-protocol = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "master" }
rust_decimal = { git = "https://github.com/chandr-andr/rust-decimal.git", branch = "psqlpy", features = [
"db-postgres",
"db-tokio-postgres",
] }
10 changes: 10 additions & 0 deletions python/psqlpy/_internal/extra_types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ class BigInt:
- `inner_value`: int object.
"""

class Money:
"""Represent `MONEY` in PostgreSQL and `i64` in Rust."""

def __init__(self: Self, inner_value: int) -> None:
"""Create new instance of class.

### Parameters:
- `inner_value`: int object.
"""

class Float32:
"""Represents `FLOAT4` in `PostgreSQL` and `f32` in Rust."""

Expand Down
2 changes: 2 additions & 0 deletions python/psqlpy/extra_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Float32,
Float64,
Integer,
Money,
PyCustomType,
PyJSON,
PyJSONB,
Expand All @@ -26,4 +27,5 @@
"PyCustomType",
"Float32",
"Float64",
"Money",
]
20 changes: 20 additions & 0 deletions python/tests/test_value_converter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import uuid
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, List, Union
Expand All @@ -15,6 +16,7 @@
Float32,
Float64,
Integer,
Money,
PyJSON,
PyJSONB,
PyMacAddr6,
Expand Down Expand Up @@ -72,10 +74,18 @@ async def test_as_class(
("BYTEA", b"Bytes", [66, 121, 116, 101, 115]),
("VARCHAR", "Some String", "Some String"),
("TEXT", "Some String", "Some String"),
(
"XML",
"""<?xml version="1.0"?><book><title>Manual</title><chapter>...</chapter></book>""",
"""<book><title>Manual</title><chapter>...</chapter></book>""",
),
("BOOL", True, True),
("INT2", SmallInt(12), 12),
("INT4", Integer(121231231), 121231231),
("INT8", BigInt(99999999999999999), 99999999999999999),
("MONEY", BigInt(99999999999999999), 99999999999999999),
("MONEY", Money(99999999999999999), 99999999999999999),
("NUMERIC(5, 2)", Decimal("120.12"), Decimal("120.12")),
("FLOAT4", 32.12329864501953, 32.12329864501953),
("FLOAT4", Float32(32.12329864501953), 32.12329864501953),
("FLOAT8", Float64(32.12329864501953), 32.12329864501953),
Expand Down Expand Up @@ -145,6 +155,16 @@ async def test_as_class(
[BigInt(99999999999999999), BigInt(99999999999999999)],
[99999999999999999, 99999999999999999],
),
(
"MONEY ARRAY",
[Money(99999999999999999), Money(99999999999999999)],
[99999999999999999, 99999999999999999],
),
(
"NUMERIC(5, 2) ARRAY",
[Decimal("121.23"), Decimal("188.99")],
[Decimal("121.23"), Decimal("188.99")],
),
(
"FLOAT4 ARRAY",
[32.12329864501953, 32.12329864501953],
Expand Down
5 changes: 4 additions & 1 deletion src/exceptions/rust_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub enum RustPSQLDriverError {
RustMacAddrConversionError(#[from] macaddr::ParseError),
#[error("Cannot execute future in Rust: {0}")]
RustRuntimeJoinError(#[from] JoinError),
#[error("Cannot convert python Decimal into rust Decimal")]
DecimalConversionError(#[from] rust_decimal::Error),
}

impl From<RustPSQLDriverError> for pyo3::PyErr {
Expand All @@ -92,7 +94,8 @@ impl From<RustPSQLDriverError> for pyo3::PyErr {
RustPSQLDriverError::RustToPyValueConversionError(_) => {
RustToPyValueMappingError::new_err((error_desc,))
}
RustPSQLDriverError::PyToRustValueConversionError(_) => {
RustPSQLDriverError::PyToRustValueConversionError(_)
| RustPSQLDriverError::DecimalConversionError(_) => {
PyToRustValueMappingError::new_err((error_desc,))
}
RustPSQLDriverError::ConnectionPoolConfigurationError(_) => {
Expand Down
2 changes: 2 additions & 0 deletions src/extra_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ macro_rules! build_python_type {
build_python_type!(SmallInt, i16);
build_python_type!(Integer, i32);
build_python_type!(BigInt, i64);
build_python_type!(Money, i64);
build_python_type!(Float32, f32);
build_python_type!(Float64, f64);

Expand Down Expand Up @@ -189,6 +190,7 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
pymod.add_class::<SmallInt>()?;
pymod.add_class::<Integer>()?;
pymod.add_class::<BigInt>()?;
pymod.add_class::<Money>()?;
pymod.add_class::<Float32>()?;
pymod.add_class::<Float64>()?;
pymod.add_class::<PyText>()?;
Expand Down
91 changes: 82 additions & 9 deletions src/value_converter.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
use macaddr::{MacAddr6, MacAddr8};
use postgres_types::{Field, FromSql, Kind, ToSql};
use rust_decimal::Decimal;
use serde_json::{json, Map, Value};
use std::{fmt::Debug, net::IpAddr};
use uuid::Uuid;

use bytes::{BufMut, BytesMut};
use postgres_protocol::types;
use pyo3::{
sync::GILOnceCell,
types::{
PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyDictMethods, PyFloat, PyInt,
PyList, PyListMethods, PyString, PyTime, PyTuple, PyTypeMethods,
PyList, PyListMethods, PyString, PyTime, PyTuple, PyType, PyTypeMethods,
},
Bound, Py, PyAny, Python, ToPyObject,
Bound, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use tokio_postgres::{
types::{to_sql_checked, Type},
Expand All @@ -23,13 +25,43 @@ use crate::{
additional_types::{RustMacAddr6, RustMacAddr8},
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
extra_types::{
BigInt, Float32, Float64, Integer, PyCustomType, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8,
PyText, PyVarChar, SmallInt,
BigInt, Float32, Float64, Integer, Money, PyCustomType, PyJSON, PyJSONB, PyMacAddr6,
PyMacAddr8, PyText, PyVarChar, SmallInt,
},
};

static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();

pub type QueryParameter = (dyn ToSql + Sync);

fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
DECIMAL_CLS
.get_or_try_init(py, || {
let type_object = py
.import_bound("decimal")?
.getattr("Decimal")?
.downcast_into()?;
Ok(type_object.unbind())
})
.map(|ty| ty.bind(py))
}

/// Struct for Decimal.
///
/// It's necessary because we use custom forks and there is
/// no implementation of `ToPyObject` for Decimal.
struct InnerDecimal(Decimal);

impl ToPyObject for InnerDecimal {
fn to_object(&self, py: Python<'_>) -> PyObject {
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
let ret = dec_cls
.call1((self.0.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
}

/// Additional type for types come from Python.
///
/// It's necessary because we need to pass this
Expand All @@ -51,6 +83,7 @@ pub enum PythonDTO {
PyIntU64(u64),
PyFloat32(f32),
PyFloat64(f64),
PyMoney(i64),
PyDate(NaiveDate),
PyTime(NaiveTime),
PyDateTime(NaiveDateTime),
Expand All @@ -62,6 +95,7 @@ pub enum PythonDTO {
PyJson(Value),
PyMacAddr6(MacAddr6),
PyMacAddr8(MacAddr8),
PyDecimal(Decimal),
PyCustomType(Vec<u8>),
}

Expand Down Expand Up @@ -89,6 +123,7 @@ impl PythonDTO {
PythonDTO::PyIntI64(_) => Ok(tokio_postgres::types::Type::INT8_ARRAY),
PythonDTO::PyFloat32(_) => Ok(tokio_postgres::types::Type::FLOAT4_ARRAY),
PythonDTO::PyFloat64(_) => Ok(tokio_postgres::types::Type::FLOAT8_ARRAY),
PythonDTO::PyMoney(_) => Ok(tokio_postgres::types::Type::MONEY_ARRAY),
PythonDTO::PyIpAddress(_) => Ok(tokio_postgres::types::Type::INET_ARRAY),
PythonDTO::PyJsonb(_) => Ok(tokio_postgres::types::Type::JSONB_ARRAY),
PythonDTO::PyJson(_) => Ok(tokio_postgres::types::Type::JSON_ARRAY),
Expand All @@ -98,6 +133,7 @@ impl PythonDTO {
PythonDTO::PyDateTimeTz(_) => Ok(tokio_postgres::types::Type::TIMESTAMPTZ_ARRAY),
PythonDTO::PyMacAddr6(_) => Ok(tokio_postgres::types::Type::MACADDR_ARRAY),
PythonDTO::PyMacAddr8(_) => Ok(tokio_postgres::types::Type::MACADDR8_ARRAY),
PythonDTO::PyDecimal(_) => Ok(tokio_postgres::types::Type::NUMERIC_ARRAY),
_ => Err(RustPSQLDriverError::PyToRustValueConversionError(
"Can't process array type, your type doesn't have support yet".into(),
)),
Expand Down Expand Up @@ -197,7 +233,7 @@ impl ToSql for PythonDTO {
}
PythonDTO::PyIntI16(int) => out.put_i16(*int),
PythonDTO::PyIntI32(int) => out.put_i32(*int),
PythonDTO::PyIntI64(int) => out.put_i64(*int),
PythonDTO::PyIntI64(int) | PythonDTO::PyMoney(int) => out.put_i64(*int),
PythonDTO::PyIntU32(int) => out.put_u32(*int),
PythonDTO::PyIntU64(int) => out.put_u64(*int),
PythonDTO::PyFloat32(float) => out.put_f32(*float),
Expand Down Expand Up @@ -237,6 +273,9 @@ impl ToSql for PythonDTO {
PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => {
<&Value as ToSql>::to_sql(&py_dict, ty, out)?;
}
PythonDTO::PyDecimal(py_decimal) => {
<Decimal as ToSql>::to_sql(py_decimal, ty, out)?;
}
}
if return_is_null_true {
Ok(tokio_postgres::types::IsNull::Yes)
Expand Down Expand Up @@ -286,6 +325,7 @@ pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<P
/// or value of the type is incorrect.
#[allow(clippy::too_many_lines)]
pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<PythonDTO> {
println!("{}", parameter.get_type().name()?);
if parameter.is_none() {
return Ok(PythonDTO::PyNone);
}
Expand Down Expand Up @@ -352,6 +392,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
));
}

if parameter.is_instance_of::<Money>() {
return Ok(PythonDTO::PyMoney(
parameter.extract::<Money>()?.retrieve_value(),
));
}

if parameter.is_instance_of::<PyInt>() {
return Ok(PythonDTO::PyIntI32(parameter.extract::<i32>()?));
}
Expand Down Expand Up @@ -443,6 +489,13 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
)?));
}

if parameter.get_type().name()? == "decimal.Decimal" {
println!("{}", parameter.str()?.extract::<&str>()?);
return Ok(PythonDTO::PyDecimal(Decimal::from_str_exact(
parameter.str()?.extract::<&str>()?,
)?));
}

if let Ok(id_address) = parameter.extract::<IpAddr>() {
return Ok(PythonDTO::PyIpAddress(id_address));
}
Expand Down Expand Up @@ -496,7 +549,7 @@ fn postgres_bytes_to_py(
.to_object(py)),
// // ---------- String Types ----------
// // Convert TEXT and VARCHAR type into String, then into str
Type::TEXT | Type::VARCHAR => Ok(_composite_field_postgres_to_py::<Option<String>>(
Type::TEXT | Type::VARCHAR | Type::XML => Ok(_composite_field_postgres_to_py::<Option<String>>(
type_, buf, is_simple,
)?
.to_object(py)),
Expand All @@ -515,7 +568,7 @@ fn postgres_bytes_to_py(
_composite_field_postgres_to_py::<Option<i32>>(type_, buf, is_simple)?.to_object(py),
),
// Convert BigInt into i64, then into int
Type::INT8 => Ok(
Type::INT8 | Type::MONEY => Ok(
_composite_field_postgres_to_py::<Option<i64>>(type_, buf, is_simple)?.to_object(py),
),
// Convert REAL into f32, then into float
Expand Down Expand Up @@ -592,13 +645,21 @@ fn postgres_bytes_to_py(
Ok(py.None().to_object(py))
}
}
Type::NUMERIC => {
if let Some(numeric_) = _composite_field_postgres_to_py::<Option<Decimal>>(
type_, buf, is_simple,
)? {
return Ok(InnerDecimal(numeric_).to_object(py));
}
Ok(py.None().to_object(py))
}
// ---------- Array Text Types ----------
Type::BOOL_ARRAY => Ok(_composite_field_postgres_to_py::<Option<Vec<bool>>>(
type_, buf, is_simple,
)?
.to_object(py)),
// Convert ARRAY of TEXT or VARCHAR into Vec<String>, then into list[str]
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => Ok(_composite_field_postgres_to_py::<
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY | Type::XML_ARRAY => Ok(_composite_field_postgres_to_py::<
Option<Vec<String>>,
>(type_, buf, is_simple)?
.to_object(py)),
Expand All @@ -614,7 +675,7 @@ fn postgres_bytes_to_py(
)?
.to_object(py)),
// Convert ARRAY of BigInt into Vec<i64>, then into list[int]
Type::INT8_ARRAY => Ok(_composite_field_postgres_to_py::<Option<Vec<i64>>>(
Type::INT8_ARRAY | Type::MONEY_ARRAY => Ok(_composite_field_postgres_to_py::<Option<Vec<i64>>>(
type_, buf, is_simple,
)?
.to_object(py)),
Expand Down Expand Up @@ -686,6 +747,18 @@ fn postgres_bytes_to_py(
None => Ok(py.None().to_object(py)),
}
}
Type::NUMERIC_ARRAY => {
if let Some(numeric_array) = _composite_field_postgres_to_py::<Option<Vec<Decimal>>>(
type_, buf, is_simple,
)? {
let py_list = PyList::empty_bound(py);
for numeric_ in numeric_array {
py_list.append(InnerDecimal(numeric_).to_object(py))?;
}
return Ok(py_list.to_object(py))
};
Ok(py.None().to_object(py))
},
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
)),
Expand Down