From 20247b5a57e7c96090d806aaaf6eb24b87198e4a Mon Sep 17 00:00:00 2001 From: bymoye Date: Tue, 1 Jul 2025 12:57:02 +0800 Subject: [PATCH 1/7] change result --- docs/components/results.md | 14 ++++++++++---- python/tests/test_query_results.py | 4 ++-- python/tests/test_value_converter.py | 2 +- src/query_result.rs | 13 +++++++------ src/value_converter/to_python.rs | 2 +- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/docs/components/results.md b/docs/components/results.md index 7cea19a5..765571fa 100644 --- a/docs/components/results.md +++ b/docs/components/results.md @@ -14,8 +14,9 @@ Currently there are two results: ### Result #### Parameters + - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) -- `as_tuple`: return result as a tuple instead of dict. +- `as_tuple`: Headless tuple output Get the result as a list of dicts @@ -32,7 +33,7 @@ async def main() -> None: list_dict_result: List[Dict[str, Any]] = query_result.result() # Result as tuple - list_tuple_result: List[Tuple[Tuple[str, typing.Any], ...]] = query_result.result( + list_tuple_result: List[Tuple[str, typing.Any], ...] = query_result.result( as_tuple=True, ) ``` @@ -40,6 +41,7 @@ async def main() -> None: ### As class #### Parameters + - `as_class`: Custom class from Python. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) @@ -68,6 +70,7 @@ async def main() -> None: ### Row Factory #### Parameters + - `row_factory`: custom callable object. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) @@ -78,8 +81,9 @@ async def main() -> None: ### Result #### Parameters + - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) -- `as_tuple`: return result as a tuple instead of dict. +- `as_tuple`: Headless tuple output Get the result as a dict @@ -96,7 +100,7 @@ async def main() -> None: dict_result: Dict[str, Any] = query_result.result() # Result as tuple - tuple_result: Tuple[Tuple[str, typing.Any], ...] = query_result.result( + tuple_result: Tuple[str, typing.Any] = query_result.result( as_tuple=True, ) ``` @@ -104,6 +108,7 @@ async def main() -> None: ### As class #### Parameters + - `as_class`: Custom class from Python. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) @@ -131,6 +136,7 @@ async def main() -> None: ### Row Factory #### Parameters + - `row_factory`: custom callable object. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) diff --git a/python/tests/test_query_results.py b/python/tests/test_query_results.py index 95de93c7..ff136fb4 100644 --- a/python/tests/test_query_results.py +++ b/python/tests/test_query_results.py @@ -39,7 +39,7 @@ async def test_result_as_tuple( assert isinstance(conn_result, QueryResult) assert isinstance(single_tuple_row, tuple) - assert single_tuple_row[0][0] == "id" + assert single_tuple_row[0] == 1 async def test_single_result_as_dict( @@ -73,4 +73,4 @@ async def test_single_result_as_tuple( assert isinstance(conn_result, SingleQueryResult) assert isinstance(result_tuple, tuple) - assert result_tuple[0][0] == "id" + assert result_tuple[0] == 1 diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 07833848..122201ef 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -672,7 +672,7 @@ def point_encoder(point_bytes: bytes) -> str: # noqa: ARG001 as_tuple=True, ) - assert result[0][0][1] == "Just An Example" + assert result[0][0] == "Just An Example" async def test_row_factory_query_result( diff --git a/src/query_result.rs b/src/query_result.rs index a5af132d..46047848 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -42,14 +42,15 @@ fn row_to_tuple<'a>( postgres_row: &'a Row, custom_decoders: &Option>, ) -> PSQLPyResult> { - let mut rows: Vec> = vec![]; + let columns = postgres_row.columns(); + let mut tuple_items = Vec::with_capacity(columns.len()); - for (column_idx, column) in postgres_row.columns().iter().enumerate() { - let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; - let timed_tuple = PyTuple::new(py, vec![column.name().into_py_any(py)?, python_type])?; - rows.push(timed_tuple); + for (column_idx, column) in columns.iter().enumerate() { + let python_value = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; + tuple_items.push(python_value); } - Ok(PyTuple::new(py, rows)?) + + Ok(PyTuple::new(py, tuple_items)?) } #[pyclass(name = "QueryResult")] diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index abc734c8..e742af1a 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -604,7 +604,7 @@ pub fn raw_bytes_data_process( if let Ok(Some(py_encoder_func)) = py_encoder_func { return Ok(py_encoder_func - .call((raw_bytes_data.to_vec(),), None)? + .call1((PyBytes::new(py, raw_bytes_data),))? .unbind()); } } From 8cc74de28b24864c1ebcd40fbe416560d66598cb Mon Sep 17 00:00:00 2001 From: bymoye Date: Tue, 1 Jul 2025 14:53:10 +0800 Subject: [PATCH 2/7] more optimizations --- src/driver/common.rs | 20 +++----- src/value_converter/to_python.rs | 86 ++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 50 deletions(-) diff --git a/src/driver/common.rs b/src/driver/common.rs index 3c22517a..3789c1f6 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -77,19 +77,13 @@ macro_rules! impl_config_py_methods { #[cfg(not(unix))] #[getter] fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - _ => unreachable!(), - } - } - - hosts_vec + self.pg_config + .get_hosts() + .iter() + .map(|host| match host { + Host::Tcp(host) => host.to_string(), + }) + .collect() } #[getter] diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index e742af1a..cf0f6d35 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -95,13 +95,9 @@ fn postgres_array_to_py<'py, T: IntoPyObject<'py> + Clone>( array: Option>, ) -> Option> { array.map(|array| { - inner_postgres_array_to_py( - py, - array.dimensions(), - array.iter().cloned().collect::>(), - 0, - 0, - ) + // Collect data once instead of creating copies in recursion + let data: Vec = array.iter().cloned().collect(); + inner_postgres_array_to_py(py, array.dimensions(), &data, 0, 0) }) } @@ -110,44 +106,60 @@ fn postgres_array_to_py<'py, T: IntoPyObject<'py> + Clone>( fn inner_postgres_array_to_py<'py, T>( py: Python<'py>, dimensions: &[Dimension], - data: Vec, + data: &[T], dimension_index: usize, - mut lower_bound: usize, + data_offset: usize, ) -> Py where T: IntoPyObject<'py> + Clone, { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let final_list = PyList::empty(py); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_postgres_array_to_py( - py, - dimensions, - data[lower_bound..next_dimension.len as usize + lower_bound].to_vec(), - dimension_index + 1, - 0, - ); - final_list.append(inner_pylist).unwrap(); - lower_bound += next_dimension.len as usize; - } - } - - return final_list.unbind(); - } - None => { - return PyList::new(py, data).unwrap().unbind(); // TODO unwrap is unsafe - } + // Check bounds early + if dimension_index >= dimensions.len() || data_offset >= data.len() { + return PyList::empty(py).unbind(); + } + + let current_dimension = &dimensions[dimension_index]; + let current_len = current_dimension.len as usize; + + // If this is the last dimension, create a list with the actual data + if dimension_index + 1 >= dimensions.len() { + let end_offset = (data_offset + current_len).min(data.len()); + let slice = &data[data_offset..end_offset]; + + // Create Python list more efficiently + return match PyList::new(py, slice.iter().cloned()) { + Ok(list) => list.unbind(), + Err(_) => PyList::empty(py).unbind(), + }; + } + + // For multi-dimensional arrays, recursively create nested lists + let final_list = PyList::empty(py); + + // Calculate the size of each sub-array + let sub_array_size = dimensions[dimension_index + 1..] + .iter() + .map(|d| d.len as usize) + .product::(); + + let mut current_offset = data_offset; + + for _ in 0..current_len { + if current_offset >= data.len() { + break; } + + let inner_list = + inner_postgres_array_to_py(py, dimensions, data, dimension_index + 1, current_offset); + + if final_list.append(inner_list).is_err() { + break; + } + + current_offset += sub_array_size; } - PyList::empty(py).unbind() + final_list.unbind() } #[allow(clippy::too_many_lines)] From e7995268db4a9948a3b38b394df9e63844e0dad1 Mon Sep 17 00:00:00 2001 From: bymoye Date: Tue, 1 Jul 2025 15:12:38 +0800 Subject: [PATCH 3/7] fix typing hint --- python/psqlpy/_internal/__init__.pyi | 64 +++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 2665678b..17a2d482 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -36,7 +36,7 @@ class QueryResult: self: Self, as_tuple: typing.Literal[True], custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, - ) -> list[tuple[tuple[str, typing.Any], ...]]: ... + ) -> list[tuple[typing.Any, ...]]: ... @typing.overload def result( self: Self, @@ -50,6 +50,7 @@ class QueryResult: `custom_decoders` must be used when you use PostgreSQL Type which isn't supported, read more in our docs. """ + def as_class( self: Self, as_class: Callable[..., _CustomClass], @@ -83,6 +84,7 @@ class QueryResult: ) ``` """ + def row_factory( self, row_factory: Callable[[dict[str, Any]], _RowFactoryRV], @@ -124,7 +126,7 @@ class SingleQueryResult: self: Self, as_tuple: typing.Literal[True], custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, - ) -> tuple[tuple[str, typing.Any]]: ... + ) -> tuple[typing.Any, ...]: ... @typing.overload def result( self: Self, @@ -138,6 +140,7 @@ class SingleQueryResult: `custom_decoders` must be used when you use PostgreSQL Type which isn't supported, read more in our docs. """ + def as_class( self: Self, as_class: Callable[..., _CustomClass], @@ -174,6 +177,7 @@ class SingleQueryResult: ) ``` """ + def row_factory( self, row_factory: Callable[[dict[str, Any]], _RowFactoryRV], @@ -328,11 +332,13 @@ class Cursor: Execute DECLARE command for the cursor. """ + def close(self: Self) -> None: """Close the cursor. Execute CLOSE command for the cursor. """ + async def execute( self: Self, querystring: str, @@ -343,10 +349,13 @@ class Cursor: Method should be used instead of context manager and `start` method. """ + async def fetchone(self: Self) -> QueryResult: """Return next one row from the cursor.""" + async def fetchmany(self: Self, size: int | None = None) -> QueryResult: """Return rows from the cursor.""" + async def fetchall(self: Self, size: int | None = None) -> QueryResult: """Return all remaining rows from the cursor.""" @@ -379,6 +388,7 @@ class Transaction: `begin()` can be called only once per transaction. """ + async def commit(self: Self) -> None: """Commit the transaction. @@ -386,6 +396,7 @@ class Transaction: `commit()` can be called only once per transaction. """ + async def rollback(self: Self) -> None: """Rollback all queries in the transaction. @@ -406,6 +417,7 @@ class Transaction: await transaction.rollback() ``` """ + async def execute( self: Self, querystring: str, @@ -443,6 +455,7 @@ class Transaction: await transaction.commit() ``` """ + async def execute_batch( self: Self, querystring: str, @@ -458,6 +471,7 @@ class Transaction: ### Parameters: - `querystring`: querystrings separated by semicolons. """ + async def execute_many( self: Self, querystring: str, @@ -516,6 +530,7 @@ class Transaction: - `prepared`: should the querystring be prepared before the request. By default any querystring will be prepared. """ + async def fetch_row( self: Self, querystring: str, @@ -555,6 +570,7 @@ class Transaction: await transaction.commit() ``` """ + async def fetch_val( self: Self, querystring: str, @@ -595,6 +611,7 @@ class Transaction: ) ``` """ + async def pipeline( self, queries: list[tuple[str, list[Any] | None]], @@ -659,6 +676,7 @@ class Transaction: ) ``` """ + async def create_savepoint(self: Self, savepoint_name: str) -> None: """Create new savepoint. @@ -687,6 +705,7 @@ class Transaction: await transaction.rollback_savepoint("my_savepoint") ``` """ + async def rollback_savepoint(self: Self, savepoint_name: str) -> None: """ROLLBACK to the specified `savepoint_name`. @@ -712,6 +731,7 @@ class Transaction: await transaction.rollback_savepoint("my_savepoint") ``` """ + async def release_savepoint(self: Self, savepoint_name: str) -> None: """Execute ROLLBACK TO SAVEPOINT. @@ -736,6 +756,7 @@ class Transaction: await transaction.release_savepoint ``` """ + def cursor( self: Self, querystring: str, @@ -779,6 +800,7 @@ class Transaction: await cursor.close() ``` """ + async def binary_copy_to_table( self: Self, source: bytes | bytearray | Buffer | BytesIO, @@ -860,6 +882,7 @@ class Connection: Return representation of prepared statement. """ + async def execute( self: Self, querystring: str, @@ -896,6 +919,7 @@ class Connection: dict_result: List[Dict[Any, Any]] = query_result.result() ``` """ + async def execute_batch( self: Self, querystring: str, @@ -911,6 +935,7 @@ class Connection: ### Parameters: - `querystring`: querystrings separated by semicolons. """ + async def execute_many( self: Self, querystring: str, @@ -964,6 +989,7 @@ class Connection: - `prepared`: should the querystring be prepared before the request. By default any querystring will be prepared. """ + async def fetch_row( self: Self, querystring: str, @@ -1000,6 +1026,7 @@ class Connection: dict_result: Dict[Any, Any] = query_result.result() ``` """ + async def fetch_val( self: Self, querystring: str, @@ -1039,6 +1066,7 @@ class Connection: ) ``` """ + def transaction( self, isolation_level: IsolationLevel | None = None, @@ -1052,6 +1080,7 @@ class Connection: - `read_variant`: configure read variant of the transaction. - `deferrable`: configure deferrable of the transaction. """ + def cursor( self: Self, querystring: str, @@ -1090,6 +1119,7 @@ class Connection: ... # do something with this result. ``` """ + def close(self: Self) -> None: """Return connection back to the pool. @@ -1234,6 +1264,7 @@ class ConnectionPool: - `ca_file`: Loads trusted root certificates from a file. The file should contain a sequence of PEM-formatted CA certificates. """ + def __iter__(self: Self) -> Self: ... def __enter__(self: Self) -> Self: ... def __exit__( @@ -1248,6 +1279,7 @@ class ConnectionPool: ### Returns `ConnectionPoolStatus` """ + def resize(self: Self, new_max_size: int) -> None: """Resize the connection pool. @@ -1257,11 +1289,13 @@ class ConnectionPool: ### Parameters: - `new_max_size`: new size for the connection pool. """ + async def connection(self: Self) -> Connection: """Create new connection. It acquires new connection from the database pool. """ + def acquire(self: Self) -> Connection: """Create new connection for async context manager. @@ -1279,6 +1313,7 @@ class ConnectionPool: res = await connection.execute(...) ``` """ + def listener(self: Self) -> Listener: """Create new listener.""" @@ -1390,6 +1425,7 @@ class ConnectionPoolBuilder: def __init__(self: Self) -> None: """Initialize new instance of `ConnectionPoolBuilder`.""" + def build(self: Self) -> ConnectionPool: """ Build `ConnectionPool`. @@ -1397,6 +1433,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPool` """ + def max_pool_size(self: Self, pool_size: int) -> Self: """ Set maximum connection pool size. @@ -1407,6 +1444,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def conn_recycling_method( self: Self, conn_recycling_method: ConnRecyclingMethod, @@ -1422,6 +1460,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def user(self: Self, user: str) -> Self: """ Set username to `PostgreSQL`. @@ -1432,6 +1471,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def password(self: Self, password: str) -> Self: """ Set password for `PostgreSQL`. @@ -1442,6 +1482,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def dbname(self: Self, dbname: str) -> Self: """ Set database name for the `PostgreSQL`. @@ -1452,6 +1493,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def options(self: Self, options: str) -> Self: """ Set command line options used to configure the server. @@ -1462,6 +1504,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def application_name(self: Self, application_name: str) -> Self: """ Set the value of the `application_name` runtime parameter. @@ -1472,6 +1515,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def ssl_mode(self: Self, ssl_mode: SslMode) -> Self: """ Set the SSL configuration. @@ -1482,6 +1526,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def ca_file(self: Self, ca_file: str) -> Self: """ Set ca_file for SSL. @@ -1492,6 +1537,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def host(self: Self, host: str) -> Self: """ Add a host to the configuration. @@ -1509,6 +1555,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def hostaddr(self: Self, hostaddr: IPv4Address | IPv6Address) -> Self: """ Add a hostaddr to the configuration. @@ -1524,6 +1571,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def port(self: Self, port: int) -> Self: """ Add a port to the configuration. @@ -1540,6 +1588,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def connect_timeout(self: Self, connect_timeout: int) -> Self: """ Set the timeout applied to socket-level connection attempts. @@ -1554,6 +1603,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def tcp_user_timeout(self: Self, tcp_user_timeout: int) -> Self: """ Set the TCP user timeout. @@ -1569,6 +1619,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def target_session_attrs( self: Self, target_session_attrs: TargetSessionAttrs, @@ -1586,6 +1637,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def load_balance_hosts( self: Self, load_balance_hosts: LoadBalanceHosts, @@ -1601,6 +1653,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives( self: Self, keepalives: bool, @@ -1618,6 +1671,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_idle( self: Self, keepalives_idle: int, @@ -1636,6 +1690,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_interval( self: Self, keepalives_interval: int, @@ -1655,6 +1710,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_retries( self: Self, keepalives_retries: int, @@ -1747,11 +1803,13 @@ class Listener: Each listener MUST be started up. """ + async def shutdown(self: Self) -> None: """Shutdown the listener. Abort listen and release underlying connection. """ + async def add_callback( self: Self, channel: str, @@ -1814,7 +1872,9 @@ class Column: class PreparedStatement: async def execute(self: Self) -> QueryResult: """Execute prepared statement.""" + def cursor(self: Self) -> Cursor: """Create new server-side cursor based on prepared statement.""" + def columns(self: Self) -> list[Column]: """Return information about statement columns.""" From 8ed07240c2740ec3e8da5d89ffb798e29e994ffb Mon Sep 17 00:00:00 2001 From: bymoye Date: Wed, 2 Jul 2025 11:43:24 +0800 Subject: [PATCH 4/7] more optimizations --- src/value_converter/dto/funcs.rs | 2 +- src/value_converter/models/serde_value.rs | 128 ++++++++++++---------- 2 files changed, 70 insertions(+), 60 deletions(-) diff --git a/src/value_converter/dto/funcs.rs b/src/value_converter/dto/funcs.rs index eec045e0..e869966c 100644 --- a/src/value_converter/dto/funcs.rs +++ b/src/value_converter/dto/funcs.rs @@ -4,7 +4,7 @@ use postgres_types::Type; pub fn array_type_to_single_type(array_type: &Type) -> Type { match *array_type { Type::BOOL_ARRAY => Type::BOOL, - Type::UUID_ARRAY => Type::UUID_ARRAY, + Type::UUID_ARRAY => Type::UUID, Type::VARCHAR_ARRAY => Type::VARCHAR, Type::TEXT_ARRAY => Type::TEXT, Type::INT2_ARRAY => Type::INT2, diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index 222ffe56..8c499c0a 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -3,8 +3,8 @@ use postgres_types::FromSql; use serde_json::{json, Map, Value}; use pyo3::{ - types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyTuple}, - Bound, FromPyObject, IntoPyObject, Py, PyAny, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, + Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python, }; use tokio_postgres::types::Type; @@ -37,7 +37,7 @@ impl<'py> IntoPyObject<'py> for InternalSerdeValue { type Error = RustPSQLDriverError; fn into_pyobject(self, py: Python<'py>) -> Result { - match build_python_from_serde_value(py, self.0.clone()) { + match build_python_from_serde_value(py, self.0) { Ok(ok_value) => Ok(ok_value.bind(py).clone()), Err(err) => Err(err), } @@ -57,25 +57,30 @@ impl<'a> FromSql<'a> for InternalSerdeValue { } } -fn serde_value_from_list(gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { - let mut result_vec: Vec = vec![]; +fn serde_value_from_list(_gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { + let py_list = bind_value.downcast::().map_err(|e| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Parameter must be a list, but it's not: {}", + e + )) + })?; - let params = bind_value.extract::>>()?; + let mut result_vec: Vec = Vec::with_capacity(py_list.len()); - for inner in params { - let inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = from_python_untyped(inner_bind)?; + for item in py_list.iter() { + if item.is_instance_of::() { + let python_dto = from_python_untyped(&item)?; result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner.bind(gil))?; + } else if item.is_instance_of::() { + let serde_value = build_serde_value(&item)?; result_vec.push(serde_value); } else { return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), + "Items in JSON array must be dicts or lists.".to_string(), )); } } + Ok(json!(result_vec)) } @@ -86,19 +91,18 @@ fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { )) })?; - let mut serde_map: Map = Map::new(); + let dict_len = dict.len(); + let mut serde_map: Map = Map::with_capacity(dict_len); - for dict_item in dict.items() { - let py_list = dict_item.downcast::().map_err(|error| { + for (key, value) in dict.iter() { + let key_str = key.extract::().map_err(|error| { RustPSQLDriverError::PyToRustValueConversionError(format!( - "Cannot cast to list: {error}" + "Cannot extract dict key as string: {error}" )) })?; - let key = py_list.get_item(0)?.extract::()?; - let value = from_python_untyped(&py_list.get_item(1)?)?; - - serde_map.insert(key, value.to_serde_value()?); + let value_dto = from_python_untyped(&value)?; + serde_map.insert(key_str, value_dto.to_serde_value()?); } Ok(Value::Object(serde_map)) @@ -131,12 +135,10 @@ pub fn build_serde_value(value: &Bound<'_, PyAny>) -> PSQLPyResult { /// May return error if cannot create serde value. pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult { match array { - Some(array) => inner_pythondto_array_to_serde( - array.dimensions(), - array.iter().collect::>().as_slice(), - 0, - 0, - ), + Some(array) => { + let data: Vec = array.iter().cloned().collect(); + inner_pythondto_array_to_serde(array.dimensions(), &data, 0, 0) + } None => Ok(Value::Null), } } @@ -145,41 +147,49 @@ pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult #[allow(clippy::cast_sign_loss)] fn inner_pythondto_array_to_serde( dimensions: &[Dimension], - data: &[&PythonDTO], + data: &[PythonDTO], dimension_index: usize, - mut lower_bound: usize, + data_offset: usize, ) -> PSQLPyResult { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let mut final_list: Value = Value::Array(vec![]); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_pythondto_array_to_serde( - dimensions, - &data[lower_bound..next_dimension.len as usize + lower_bound], - dimension_index + 1, - 0, - )?; - match final_list { - Value::Array(ref mut array) => array.push(inner_pylist), - _ => unreachable!(), - } - lower_bound += next_dimension.len as usize; - } - } - - return Ok(final_list); - } - None => { - return data.iter().map(|x| x.to_serde_value()).collect(); - } + if dimension_index >= dimensions.len() || data_offset >= data.len() { + return Ok(Value::Array(vec![])); + } + + let current_dimension = &dimensions[dimension_index]; + let current_len = current_dimension.len as usize; + + if dimension_index + 1 >= dimensions.len() { + let end_offset = (data_offset + current_len).min(data.len()); + let slice = &data[data_offset..end_offset]; + + let mut result_values = Vec::with_capacity(slice.len()); + for item in slice { + result_values.push(item.to_serde_value()?); } + + return Ok(Value::Array(result_values)); + } + + let mut final_array = Vec::with_capacity(current_len); + + let sub_array_size = dimensions[dimension_index + 1..] + .iter() + .map(|d| d.len as usize) + .product::(); + + let mut current_offset = data_offset; + + for _ in 0..current_len { + if current_offset >= data.len() { + break; + } + + let inner_value = + inner_pythondto_array_to_serde(dimensions, data, dimension_index + 1, current_offset)?; + + final_array.push(inner_value); + current_offset += sub_array_size; } - Ok(Value::Array(vec![])) + Ok(Value::Array(final_array)) } From ba63201d6eba88042ee2f37b8856888c40b1eede Mon Sep 17 00:00:00 2001 From: bymoye Date: Wed, 2 Jul 2025 11:59:08 +0800 Subject: [PATCH 5/7] refactor execute_many for better performance --- src/connection/impls.rs | 84 +++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/src/connection/impls.rs b/src/connection/impls.rs index 62d9a830..ebeb9bb0 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -409,42 +409,62 @@ impl PSQLPyConnection { parameters: Option>>, prepared: Option, ) -> PSQLPyResult<()> { - let mut statements: Vec = vec![]; - if let Some(parameters) = parameters { - for vec_of_py_any in parameters { - // TODO: Fix multiple qs creation - let statement = - StatementBuilder::new(&querystring, &Some(vec_of_py_any), self, prepared) - .build() - .await?; - - statements.push(statement); - } - } + let Some(parameters) = parameters else { + return Ok(()); + }; let prepared = prepared.unwrap_or(true); - for statement in statements { - let querystring_result = if prepared { - let prepared_stmt = &self.prepare(statement.raw_query(), true).await; - if let Err(error) = prepared_stmt { - return Err(RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement in execute_many, operation rolled back {error}", - ))); - } - self.query( - &self.prepare(statement.raw_query(), true).await?, - &statement.params(), - ) - .await - } else { - self.query(statement.raw_query(), &statement.params()).await - }; + let mut statements: Vec = Vec::with_capacity(parameters.len()); + + for param_set in parameters { + let statement = + StatementBuilder::new(&querystring, &Some(param_set), self, Some(prepared)) + .build() + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot build statement in execute_many: {err}" + )) + })?; + statements.push(statement); + } + + if statements.is_empty() { + return Ok(()); + } - if let Err(error) = querystring_result { - return Err(RustPSQLDriverError::ConnectionExecuteError(format!( - "Error occured in `execute_many` statement: {error}" - ))); + if prepared { + let first_statement = &statements[0]; + let prepared_stmt = self + .prepare(first_statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement in execute_many: {err}" + )) + })?; + + // Execute all statements using the same prepared statement + for statement in statements { + self.query(&prepared_stmt, &statement.params()) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + )) + })?; + } + } else { + // Execute each statement without preparation + for statement in statements { + self.query(statement.raw_query(), &statement.params()) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + )) + })?; } } From d48163adf52d642e71f6d2105aefc494be5e76d3 Mon Sep 17 00:00:00 2001 From: bymoye Date: Wed, 2 Jul 2025 15:15:40 +0800 Subject: [PATCH 6/7] more optimizations --- src/driver/common.rs | 98 ++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 39 deletions(-) diff --git a/src/driver/common.rs b/src/driver/common.rs index 3789c1f6..b2ff6a52 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -16,7 +16,7 @@ use crate::{ use bytes::BytesMut; use futures_util::pin_mut; -use pyo3::{buffer::PyBuffer, PyErr, Python}; +use pyo3::{buffer::PyBuffer, Python}; use tokio_postgres::binary_copy::BinaryCopyInWriter; use crate::format_helpers::quote_ident; @@ -249,50 +249,70 @@ macro_rules! impl_binary_copy_method { columns: Option>, schema_name: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); - let mut table_name = quote_ident(&table_name); - if let Some(schema_name) = schema_name { - table_name = format!("{}.{}", quote_ident(&schema_name), table_name); - } - - let mut formated_columns = String::default(); - if let Some(columns) = columns { - formated_columns = format!("({})", columns.join(", ")); - } + let (db_client, mut bytes_mut) = + Python::with_gil(|gil| -> PSQLPyResult<(Option<_>, BytesMut)> { + let db_client = self_.borrow(gil).conn.clone(); + + let Some(db_client) = db_client else { + return Ok((None, BytesMut::new())); + }; + + let data_bytes_mut = + if let Ok(py_buffer) = source.extract::>(gil) { + let buffer_len = py_buffer.len_bytes(); + let mut bytes_mut = BytesMut::zeroed(buffer_len); + + py_buffer.copy_to_slice(gil, &mut bytes_mut[..])?; + bytes_mut + } else if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { + if let Ok(bytes_vec) = py_bytes.extract::>(gil) { + let bytes_mut = BytesMut::from(&bytes_vec[..]); + bytes_mut + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )); + } + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )); + }; + + Ok((Some(db_client), data_bytes_mut)) + })?; - let copy_qs = - format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); + let Some(db_client) = db_client else { + return Ok(0); + }; - if let Some(db_client) = db_client { - let mut psql_bytes: BytesMut = Python::with_gil(|gil| { - let possible_py_buffer: Result, PyErr> = - source.extract::>(gil); - if let Ok(py_buffer) = possible_py_buffer { - let vec_buf = py_buffer.to_vec(gil)?; - return Ok(BytesMut::from(vec_buf.as_slice())); - } + let full_table_name = match schema_name { + Some(schema) => { + format!("{}.{}", quote_ident(&schema), quote_ident(&table_name)) + } + None => quote_ident(&table_name), + }; - if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { - if let Ok(bytes) = py_bytes.extract::>(gil) { - return Ok(BytesMut::from(bytes.as_slice())); - } - } + let copy_qs = match columns { + Some(ref cols) if !cols.is_empty() => { + format!( + "COPY {}({}) FROM STDIN (FORMAT binary)", + full_table_name, + cols.join(", ") + ) + } + _ => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name), + }; - Err(RustPSQLDriverError::PyToRustValueConversionError( - "source must be bytes or support Buffer protocol".into(), - )) - })?; + let read_conn_g = db_client.read().await; + let sink = read_conn_g.copy_in(©_qs).await?; + let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); + pin_mut!(writer); - let read_conn_g = db_client.read().await; - let sink = read_conn_g.copy_in(©_qs).await?; - let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); - pin_mut!(writer); - writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; - let rows_created = writer.as_mut().finish_empty().await?; - return Ok(rows_created); - } + writer.as_mut().write_raw_bytes(&mut bytes_mut).await?; + let rows_created = writer.as_mut().finish_empty().await?; - Ok(0) + Ok(rows_created) } } }; From 5bf59e2aa2384390ca49a1669a9f473aa1af20c6 Mon Sep 17 00:00:00 2001 From: bymoye Date: Thu, 3 Jul 2025 23:22:56 +0800 Subject: [PATCH 7/7] fix clippy warning --- src/value_converter/models/serde_value.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index 8c499c0a..ebe95f58 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -60,8 +60,7 @@ impl<'a> FromSql<'a> for InternalSerdeValue { fn serde_value_from_list(_gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { let py_list = bind_value.downcast::().map_err(|e| { RustPSQLDriverError::PyToRustValueConversionError(format!( - "Parameter must be a list, but it's not: {}", - e + "Parameter must be a list, but it's not: {e}" )) })?;