Box Pgconnection fields (#3529)

* Update PgConnection code

* rustfmt
This commit is contained in:
joeydewaal 2024-10-02 20:42:54 +02:00 committed by GitHub
parent 81298b86b3
commit 68da5aefea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 137 additions and 104 deletions

View file

@ -163,7 +163,7 @@ impl PgConnection {
}
// next we check a local cache for user-defined type names <-> object id
if let Some(info) = self.cache_type_info.get(&oid) {
if let Some(info) = self.inner.cache_type_info.get(&oid) {
return Ok(info.clone());
}
@ -173,8 +173,9 @@ impl PgConnection {
// cache the type name <-> oid relationship in a paired hashmap
// so we don't come down this road again
self.cache_type_info.insert(oid, info.clone());
self.cache_type_oid
self.inner.cache_type_info.insert(oid, info.clone());
self.inner
.cache_type_oid
.insert(info.0.name().to_string().into(), oid);
Ok(info)
@ -374,7 +375,7 @@ WHERE rngtypid = $1
}
pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
if let Some(oid) = self.cache_type_oid.get(name) {
if let Some(oid) = self.inner.cache_type_oid.get(name) {
return Ok(*oid);
}
@ -387,15 +388,18 @@ WHERE rngtypid = $1
type_name: name.into(),
})?;
self.cache_type_oid.insert(name.to_string().into(), oid);
self.inner
.cache_type_oid
.insert(name.to_string().into(), oid);
Ok(oid)
}
pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
if let Some(oid) = self
.inner
.cache_type_oid
.get(&array.elem_name)
.and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid))
.and_then(|elem_oid| self.inner.cache_elem_type_to_array.get(elem_oid))
{
return Ok(*oid);
}
@ -411,10 +415,13 @@ WHERE rngtypid = $1
})?;
// Avoids copying `elem_name` until necessary
self.cache_type_oid
self.inner
.cache_type_oid
.entry_ref(&array.elem_name)
.insert(elem_oid);
self.cache_elem_type_to_array.insert(elem_oid, array_oid);
self.inner
.cache_elem_type_to_array
.insert(elem_oid, array_oid);
Ok(array_oid)
}
@ -475,8 +482,16 @@ WHERE rngtypid = $1
})?;
// If the server is CockroachDB or Materialize, skip this step (#1248).
if !self.stream.parameter_statuses.contains_key("crdb_version")
&& !self.stream.parameter_statuses.contains_key("mz_version")
if !self
.inner
.stream
.parameter_statuses
.contains_key("crdb_version")
&& !self
.inner
.stream
.parameter_statuses
.contains_key("mz_version")
{
// patch up our null inference with data from EXPLAIN
let nullable_patch = self

View file

@ -9,6 +9,8 @@ use crate::message::{
};
use crate::{PgConnectOptions, PgConnection};
use super::PgConnectionInner;
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11
@ -134,18 +136,20 @@ impl PgConnection {
}
Ok(PgConnection {
stream,
process_id,
secret_key,
transaction_status,
transaction_depth: 0,
pending_ready_for_query_count: 0,
next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
cache_elem_type_to_array: HashMap::new(),
log_settings: options.log_settings.clone(),
inner: Box::new(PgConnectionInner {
stream,
process_id,
secret_key,
transaction_status,
transaction_depth: 0,
pending_ready_for_query_count: 0,
next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
cache_elem_type_to_array: HashMap::new(),
log_settings: options.log_settings.clone(),
}),
})
}
}

View file

@ -26,8 +26,8 @@ async fn prepare(
parameters: &[PgTypeInfo],
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
let id = conn.next_statement_id;
conn.next_statement_id = id.next();
let id = conn.inner.next_statement_id;
conn.inner.next_statement_id = id.next();
// build a list of type OIDs to send to the database in the PARSE command
// we have not yet started the query sequence, so we are *safe* to cleanly make
@ -43,7 +43,7 @@ async fn prepare(
conn.wait_until_ready().await?;
// next we send the PARSE command to the server
conn.stream.write_msg(Parse {
conn.inner.stream.write_msg(Parse {
param_types: &param_types,
query: sql,
statement: id,
@ -51,15 +51,17 @@ async fn prepare(
if metadata.is_none() {
// get the statement columns and parameters
conn.stream.write_msg(message::Describe::Statement(id))?;
conn.inner
.stream
.write_msg(message::Describe::Statement(id))?;
}
// we ask for the server to immediately send us the result of the PARSE command
conn.write_sync();
conn.stream.flush().await?;
conn.inner.stream.flush().await?;
// indicates that the SQL query string is now successfully parsed and has semantic validity
conn.stream.recv_expect::<ParseComplete>().await?;
conn.inner.stream.recv_expect::<ParseComplete>().await?;
let metadata = if let Some(metadata) = metadata {
// each SYNC produces one READY FOR QUERY
@ -94,11 +96,11 @@ async fn prepare(
}
async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
conn.stream.recv_expect().await
conn.inner.stream.recv_expect().await
}
async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
let rows: Option<RowDescription> = match conn.stream.recv().await? {
let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
// describes the rows that will be returned when the statement is eventually executed
message if message.format == BackendMessageFormat::RowDescription => {
Some(message.decode()?)
@ -123,7 +125,7 @@ impl PgConnection {
pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
// we need to wait for the [CloseComplete] to be returned from the server
while count > 0 {
match self.stream.recv().await? {
match self.inner.stream.recv().await? {
message if message.format == BackendMessageFormat::PortalSuspended => {
// there was an open portal
// this can happen if the last time a statement was used it was not fully executed
@ -148,12 +150,13 @@ impl PgConnection {
#[inline(always)]
pub(crate) fn write_sync(&mut self) {
self.stream
self.inner
.stream
.write_msg(message::Sync)
.expect("BUG: Sync should not be too big for protocol");
// all SYNC messages will return a ReadyForQuery
self.pending_ready_for_query_count += 1;
self.inner.pending_ready_for_query_count += 1;
}
async fn get_or_prepare<'a>(
@ -166,18 +169,18 @@ impl PgConnection {
// a statement object
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
return Ok((*statement).clone());
}
let statement = prepare(self, sql, parameters, metadata).await?;
if store_to_cache && self.cache_statement.is_enabled() {
if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
self.stream.write_msg(Close::Statement(id))?;
if store_to_cache && self.inner.cache_statement.is_enabled() {
if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
self.inner.stream.write_msg(Close::Statement(id))?;
self.write_sync();
self.stream.flush().await?;
self.inner.stream.flush().await?;
self.wait_for_close_complete(1).await?;
self.recv_ready_for_query().await?;
@ -195,7 +198,7 @@ impl PgConnection {
persistent: bool,
metadata_opt: Option<Arc<PgStatementMetadata>>,
) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
let mut logger = QueryLogger::new(query, self.log_settings.clone());
let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
// before we continue, wait until we are "ready" to accept more queries
self.wait_until_ready().await?;
@ -231,7 +234,7 @@ impl PgConnection {
self.wait_until_ready().await?;
// bind to attach the arguments to the statement and create a portal
self.stream.write_msg(Bind {
self.inner.stream.write_msg(Bind {
portal: PortalId::UNNAMED,
statement,
formats: &[PgValueFormat::Binary],
@ -242,7 +245,7 @@ impl PgConnection {
// executes the portal up to the passed limit
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
self.stream.write_msg(message::Execute {
self.inner.stream.write_msg(message::Execute {
portal: PortalId::UNNAMED,
limit: limit.into(),
})?;
@ -255,7 +258,9 @@ impl PgConnection {
// we ask the database server to close the unnamed portal and free the associated resources
// earlier - after the execution of the current query.
self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;
self.inner
.stream
.write_msg(Close::Portal(PortalId::UNNAMED))?;
// finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
@ -268,8 +273,8 @@ impl PgConnection {
PgValueFormat::Binary
} else {
// Query will trigger a ReadyForQuery
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1;
self.inner.stream.write_msg(Query(query))?;
self.inner.pending_ready_for_query_count += 1;
// metadata starts out as "nothing"
metadata = Arc::new(PgStatementMetadata::default());
@ -278,11 +283,11 @@ impl PgConnection {
PgValueFormat::Text
};
self.stream.flush().await?;
self.inner.stream.flush().await?;
Ok(try_stream! {
loop {
let message = self.stream.recv().await?;
let message = self.inner.stream.recv().await?;
match message.format {
BackendMessageFormat::BindComplete

View file

@ -31,6 +31,10 @@ mod tls;
/// A connection to a PostgreSQL database.
pub struct PgConnection {
pub(crate) inner: Box<PgConnectionInner>,
}
pub struct PgConnectionInner {
// underlying TCP or UDS stream,
// wrapped in a potentially TLS stream,
// wrapped in a buffered stream
@ -71,17 +75,17 @@ pub struct PgConnection {
impl PgConnection {
/// the version number of the server in `libpq` format
pub fn server_version_num(&self) -> Option<u32> {
self.stream.server_version_num
self.inner.stream.server_version_num
}
// will return when the connection is ready for another query
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.stream.write_buffer_mut().is_empty() {
self.stream.flush().await?;
if !self.inner.stream.write_buffer_mut().is_empty() {
self.inner.stream.flush().await?;
}
while self.pending_ready_for_query_count > 0 {
let message = self.stream.recv().await?;
while self.inner.pending_ready_for_query_count > 0 {
let message = self.inner.stream.recv().await?;
if let BackendMessageFormat::ReadyForQuery = message.format {
self.handle_ready_for_query(message)?;
@ -92,22 +96,23 @@ impl PgConnection {
}
async fn recv_ready_for_query(&mut self) -> Result<(), Error> {
let r: ReadyForQuery = self.stream.recv_expect().await?;
let r: ReadyForQuery = self.inner.stream.recv_expect().await?;
self.pending_ready_for_query_count -= 1;
self.transaction_status = r.transaction_status;
self.inner.pending_ready_for_query_count -= 1;
self.inner.transaction_status = r.transaction_status;
Ok(())
}
#[inline(always)]
fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> {
self.pending_ready_for_query_count = self
self.inner.pending_ready_for_query_count = self
.inner
.pending_ready_for_query_count
.checked_sub(1)
.ok_or_else(|| err_protocol!("received more ReadyForQuery messages than expected"))?;
self.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
self.inner.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
Ok(())
}
@ -117,8 +122,8 @@ impl PgConnection {
/// Used for rolling back transactions and releasing advisory locks.
#[inline(always)]
pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> {
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1;
self.inner.stream.write_msg(Query(query))?;
self.inner.pending_ready_for_query_count += 1;
Ok(())
}
@ -143,8 +148,8 @@ impl Connection for PgConnection {
// connection and terminates.
Box::pin(async move {
self.stream.send(Terminate).await?;
self.stream.shutdown().await?;
self.inner.stream.send(Terminate).await?;
self.inner.stream.shutdown().await?;
Ok(())
})
@ -152,7 +157,7 @@ impl Connection for PgConnection {
fn close_hard(mut self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move {
self.stream.shutdown().await?;
self.inner.stream.shutdown().await?;
Ok(())
})
@ -178,25 +183,25 @@ impl Connection for PgConnection {
}
fn cached_statements_size(&self) -> usize {
self.cache_statement.len()
self.inner.cache_statement.len()
}
fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
self.cache_type_oid.clear();
self.inner.cache_type_oid.clear();
let mut cleared = 0_usize;
self.wait_until_ready().await?;
while let Some((id, _)) = self.cache_statement.remove_lru() {
self.stream.write_msg(Close::Statement(id))?;
while let Some((id, _)) = self.inner.cache_statement.remove_lru() {
self.inner.stream.write_msg(Close::Statement(id))?;
cleared += 1;
}
if cleared > 0 {
self.write_sync();
self.stream.flush().await?;
self.inner.stream.flush().await?;
self.wait_for_close_complete(cleared).await?;
self.recv_ready_for_query().await?;
@ -207,7 +212,7 @@ impl Connection for PgConnection {
}
fn shrink_buffers(&mut self) {
self.stream.shrink_buffers();
self.inner.stream.shrink_buffers();
}
#[doc(hidden)]
@ -217,7 +222,7 @@ impl Connection for PgConnection {
#[doc(hidden)]
fn should_flush(&self) -> bool {
!self.stream.write_buffer().is_empty()
!self.inner.stream.write_buffer().is_empty()
}
}

View file

@ -145,12 +145,12 @@ pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
async fn begin(mut conn: C, statement: &str) -> Result<Self> {
conn.wait_until_ready().await?;
conn.stream.send(Query(statement)).await?;
conn.inner.stream.send(Query(statement)).await?;
let response = match conn.stream.recv_expect::<CopyInResponse>().await {
let response = match conn.inner.stream.recv_expect::<CopyInResponse>().await {
Ok(res) => res.0,
Err(e) => {
conn.stream.recv().await?;
conn.inner.stream.recv().await?;
return Err(e);
}
};
@ -191,6 +191,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
self.conn
.as_deref_mut()
.expect("send_data: conn taken")
.inner
.stream
.send(CopyData(data))
.await?;
@ -215,7 +216,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
loop {
let buf = conn.stream.write_buffer_mut();
let buf = conn.inner.stream.write_buffer_mut();
// Write the CopyData format code and reserve space for the length.
// This may end up sending an empty `CopyData` packet if, after this point,
@ -234,7 +235,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
(&mut buf.get_mut()[1..]).put_u32(read32 + 4);
conn.stream.flush().await?;
conn.inner.stream.flush().await?;
}
Ok(self)
@ -251,9 +252,9 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
.take()
.expect("PgCopyIn::fail_with: conn taken illegally");
conn.stream.send(CopyFail::new(msg)).await?;
conn.inner.stream.send(CopyFail::new(msg)).await?;
match conn.stream.recv().await {
match conn.inner.stream.recv().await {
Ok(msg) => Err(err_protocol!(
"fail_with: expected ErrorResponse, got: {:?}",
msg.format
@ -262,7 +263,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
match e.code() {
Some(Cow::Borrowed("57014")) => {
// postgres abort received error code
conn.stream.recv_expect::<ReadyForQuery>().await?;
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
Ok(())
}
_ => Err(Error::Database(e)),
@ -281,16 +282,16 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
.take()
.expect("CopyWriter::finish: conn taken illegally");
conn.stream.send(CopyDone).await?;
let cc: CommandComplete = match conn.stream.recv_expect().await {
conn.inner.stream.send(CopyDone).await?;
let cc: CommandComplete = match conn.inner.stream.recv_expect().await {
Ok(cc) => cc,
Err(e) => {
conn.stream.recv().await?;
conn.inner.stream.recv().await?;
return Err(e);
}
};
conn.stream.recv_expect::<ReadyForQuery>().await?;
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
Ok(cc.rows_affected())
}
@ -299,7 +300,8 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
fn drop(&mut self) {
if let Some(mut conn) = self.conn.take() {
conn.stream
conn.inner
.stream
.write_msg(CopyFail::new(
"PgCopyIn dropped without calling finish() or fail()",
))
@ -313,23 +315,23 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
statement: &str,
) -> Result<BoxStream<'c, Result<Bytes>>> {
conn.wait_until_ready().await?;
conn.stream.send(Query(statement)).await?;
conn.inner.stream.send(Query(statement)).await?;
let _: CopyOutResponse = conn.stream.recv_expect().await?;
let _: CopyOutResponse = conn.inner.stream.recv_expect().await?;
let stream: TryAsyncStream<'c, Bytes> = try_stream! {
loop {
match conn.stream.recv().await {
match conn.inner.stream.recv().await {
Err(e) => {
conn.stream.recv_expect::<ReadyForQuery>().await?;
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
return Err(e);
},
Ok(msg) => match msg.format {
BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
BackendMessageFormat::CopyDone => {
let _ = msg.decode::<CopyDone>()?;
conn.stream.recv_expect::<CommandComplete>().await?;
conn.stream.recv_expect::<ReadyForQuery>().await?;
conn.inner.stream.recv_expect::<CommandComplete>().await?;
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
return Ok(())
},
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))

View file

@ -58,7 +58,7 @@ impl PgListener {
// Setup a notification buffer
let (sender, receiver) = mpsc::unbounded();
connection.stream.notifications = Some(sender);
connection.inner.stream.notifications = Some(sender);
Ok(Self {
pool: pool.clone(),
@ -155,7 +155,7 @@ impl PgListener {
async fn connect_if_needed(&mut self) -> Result<(), Error> {
if self.connection.is_none() {
let mut connection = self.pool.acquire().await?;
connection.stream.notifications = self.buffer_tx.take();
connection.inner.stream.notifications = self.buffer_tx.take();
connection
.execute(&*build_listen_all_query(&self.channels))
@ -243,7 +243,7 @@ impl PgListener {
let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
loop {
let next_message = self.connection().await?.stream.recv_unchecked();
let next_message = self.connection().await?.inner.stream.recv_unchecked();
let res = if let Some(ref mut close_event) = close_event {
// cancels the wait and returns `Err(PoolClosed)` if the pool is closed
@ -263,7 +263,7 @@ impl PgListener {
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
{
if let Some(mut conn) = self.connection.take() {
self.buffer_tx = conn.stream.notifications.take();
self.buffer_tx = conn.inner.stream.notifications.take();
// Close the connection in a background task, so we can continue.
conn.close_on_drop();
}
@ -286,7 +286,7 @@ impl PgListener {
// Mark the connection as ready for another query
BackendMessageFormat::ReadyForQuery => {
self.connection().await?.pending_ready_for_query_count -= 1;
self.connection().await?.inner.pending_ready_for_query_count -= 1;
}
// Ignore unexpected messages

View file

@ -16,9 +16,9 @@ impl TransactionManager for PgTransactionManager {
fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
let rollback = Rollback::new(conn);
let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth);
let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth);
rollback.conn.queue_simple_query(&query)?;
rollback.conn.transaction_depth += 1;
rollback.conn.inner.transaction_depth += 1;
rollback.conn.wait_until_ready().await?;
rollback.defuse();
@ -28,11 +28,11 @@ impl TransactionManager for PgTransactionManager {
fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
if conn.transaction_depth > 0 {
conn.execute(&*commit_ansi_transaction_sql(conn.transaction_depth))
if conn.inner.transaction_depth > 0 {
conn.execute(&*commit_ansi_transaction_sql(conn.inner.transaction_depth))
.await?;
conn.transaction_depth -= 1;
conn.inner.transaction_depth -= 1;
}
Ok(())
@ -41,11 +41,13 @@ impl TransactionManager for PgTransactionManager {
fn rollback(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
if conn.transaction_depth > 0 {
conn.execute(&*rollback_ansi_transaction_sql(conn.transaction_depth))
.await?;
if conn.inner.transaction_depth > 0 {
conn.execute(&*rollback_ansi_transaction_sql(
conn.inner.transaction_depth,
))
.await?;
conn.transaction_depth -= 1;
conn.inner.transaction_depth -= 1;
}
Ok(())
@ -53,11 +55,11 @@ impl TransactionManager for PgTransactionManager {
}
fn start_rollback(conn: &mut PgConnection) {
if conn.transaction_depth > 0 {
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth))
if conn.inner.transaction_depth > 0 {
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.inner.transaction_depth))
.expect("BUG: Rollback query somehow too large for protocol");
conn.transaction_depth -= 1;
conn.inner.transaction_depth -= 1;
}
}
}