fix(postgres): add missing type resolution for arrays by name
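This change adds the missing piece for resolving Postgres array types by their element name. Previously, arrays of custom types had to be declared by returning `PgTypeInfo::with_name()` with a `_`-prefixed element name (how `pg_catalog.pg_type` spells array types), which could not express schema-qualified names. The commit introduces `PgType::DeclareArrayOf` and `PgTypeInfo::array_of()`, resolves the array OID from `pg_catalog.pg_type` on first use (cached per connection), and makes `#[derive(sqlx::Type)]` emit `PgHasArrayType` impls on top of it.

A minimal sketch of the intended usage; the `foo.mood` enum and the `Mood` type are hypothetical, not part of this commit:

```rust
use sqlx::postgres::{PgHasArrayType, PgTypeInfo};

// Hypothetical custom type: `CREATE TYPE foo.mood AS ENUM ('happy', 'sad');`
#[derive(sqlx::Type)]
#[sqlx(type_name = "foo.mood", rename_all = "lowercase", no_pg_array)]
enum Mood {
    Happy,
    Sad,
}

// A hand-written impl no longer needs the fragile `_`-prefixed name (`_mood`),
// which could not represent a schema-qualified element type:
impl PgHasArrayType for Mood {
    fn array_type_info() -> PgTypeInfo {
        // The array OID is looked up in `pg_catalog.pg_type` by element name
        // on first use and cached per connection.
        PgTypeInfo::array_of("foo.mood")
    }
}
```

When `no_pg_array` is not set, the derive now generates an equivalent impl automatically, so the manual impl above is only needed for hand-written `Type` implementations.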

Austin Bonander 2024-07-04 17:17:20 -07:00
parent efbf57265c
commit 16e3f1025a
19 changed files with 333 additions and 84 deletions
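The commit also reworks type-compatibility checks: the default `Type::compatible()` now delegates to a new `TypeInfo::type_compatible()` hook instead of comparing with `==` directly, and `PgTypeInfo`'s `==` deliberately opts out of the check when one side is declared only by OID and the other only by name, since neither can be resolved without a live connection (see the doc comment added to `PgTypeInfo` below). Strict comparison remains available via the new `PgTypeInfo::type_eq()`. A short illustration, adapted from the doctest in the diff; the strict result is inferred from `eq_impl()` as changed here:

```rust
use sqlx::postgres::{types::Oid, PgTypeInfo};

// `==` is deliberately lenient here: neither side can be resolved without a
// connection, so the comparison opts out and returns `true`.
assert_eq!(
    PgTypeInfo::with_oid(Oid(1)),
    PgTypeInfo::with_name("definitely_not_real")
);

// `type_eq()` skips that opt-out and compares strictly, so this is `false`.
assert!(!PgTypeInfo::with_oid(Oid(1)).type_eq(&PgTypeInfo::with_name("definitely_not_real")));
```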

Cargo.lock generated
View file

@ -35,7 +35,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"getrandom",
"once_cell",
"version_check",
"zerocopy",
@ -574,9 +573,9 @@ dependencies = [
[[package]]
name = "borsh"
version = "1.3.1"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f58b559fd6448c6e2fd0adb5720cd98a2506594cafa4737ff98c396f3e82f667"
checksum = "a6362ed55def622cddc70a4746a68554d7b687713770de539e59a739b249f8ed"
dependencies = [
"borsh-derive",
"cfg_aliases",
@ -584,9 +583,9 @@ dependencies = [
[[package]]
name = "borsh-derive"
version = "1.3.1"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aadb5b6ccbd078890f6d7003694e33816e6b784358f18e15e7e6d9f065a57cd"
checksum = "c3ef8005764f53cd4dca619f5bf64cafd4664dada50ece25e4d81de54c80cc0b"
dependencies = [
"once_cell",
"proc-macro-crate",
@ -704,9 +703,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.1.1"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
@ -1561,9 +1560,9 @@ dependencies = [
[[package]]
name = "hashbrown"
version = "0.14.3"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash 0.8.11",
"allocator-api2",
@ -1575,7 +1574,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "692eaaf7f7607518dd3cef090f1474b61edc5301d8012f09579920df68b725ee"
dependencies = [
"hashbrown 0.14.3",
"hashbrown 0.14.5",
]
[[package]]
@ -1789,7 +1788,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
dependencies = [
"equivalent",
"hashbrown 0.14.3",
"hashbrown 0.14.5",
]
[[package]]
@ -3227,7 +3226,6 @@ dependencies = [
name = "sqlx-core"
version = "0.8.0-alpha.0"
dependencies = [
"ahash 0.8.11",
"async-io 1.13.0",
"async-std",
"atoi",
@ -3248,6 +3246,7 @@ dependencies = [
"futures-intrusive",
"futures-io",
"futures-util",
"hashbrown 0.14.5",
"hashlink",
"hex",
"indexmap 2.2.5",
@ -3524,6 +3523,7 @@ dependencies = [
"serde_json",
"sha2",
"smallvec",
"sqlx",
"sqlx-core",
"stringprep",
"thiserror",
@ -3837,9 +3837,9 @@ dependencies = [
[[package]]
name = "toml_datetime"
version = "0.6.5"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1"
checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf"
[[package]]
name = "toml_edit"

View file

@ -135,7 +135,7 @@ bit-vec = "0.6.3"
chrono = { version = "0.4.22", default-features = false }
ipnetwork = "0.20.0"
mac_address = "1.1.5"
rust_decimal = "1.26.1"
rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] }
time = { version = "0.3.36", features = ["formatting", "parsing", "macros"] }
uuid = "1.1.2"

View file

@ -51,7 +51,6 @@ uuid = { workspace = true, optional = true }
async-io = { version = "1.9.0", optional = true }
paste = "1.0.6"
ahash = "0.8.7"
atoi = "2.0"
bytes = "1.1.0"
@ -88,6 +87,7 @@ bstr = { version = "1.0", default-features = false, features = ["std"], optional
hashlink = "0.9.0"
indexmap = "2.0"
event-listener = "5.2.0"
hashbrown = "0.14.5"
[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }

View file

@ -17,6 +17,14 @@ impl UStr {
pub fn new(s: &str) -> Self {
UStr::Shared(Arc::from(s.to_owned()))
}
/// Apply [str::strip_prefix], without copying if possible.
pub fn strip_prefix(this: &Self, prefix: &str) -> Option<Self> {
match this {
UStr::Static(s) => s.strip_prefix(prefix).map(Self::Static),
UStr::Shared(s) => s.strip_prefix(prefix).map(|s| Self::Shared(s.into())),
}
}
}
impl Deref for UStr {
@ -60,6 +68,12 @@ impl From<&'static str> for UStr {
}
}
impl<'a> From<&'a UStr> for UStr {
fn from(value: &'a UStr) -> Self {
value.clone()
}
}
impl From<String> for UStr {
#[inline]
fn from(s: String) -> Self {

View file

@ -95,9 +95,8 @@ pub mod testing;
pub use error::{Error, Result};
/// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance.
pub use ahash::AHashMap as HashMap;
pub use either::Either;
pub use hashbrown::{hash_map, HashMap};
pub use indexmap::IndexMap;
pub use percent_encoding;
pub use smallvec::SmallVec;
@ -105,8 +104,6 @@ pub use url::{self, Url};
pub use bytes;
//type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
/// Helper module to get drivers compiling again that used to be in this crate,
/// to avoid having to replace tons of `use crate::<...>` imports.
///
@ -119,6 +116,6 @@ pub mod driver_prelude {
};
pub use crate::error::{Error, Result};
pub use crate::HashMap;
pub use crate::{hash_map, HashMap};
pub use either::Either;
}

View file

@ -9,6 +9,16 @@ pub trait TypeInfo: Debug + Display + Clone + PartialEq<Self> + Send + Sync {
/// should be a rough approximation of how they are written in SQL in the given database.
fn name(&self) -> &str;
/// Return `true` if `self` and `other` represent mutually compatible types.
///
/// Defaults to `self == other`.
fn type_compatible(&self, other: &Self) -> bool
where
Self: Sized,
{
self == other
}
#[doc(hidden)]
fn is_void(&self) -> bool {
false

View file

@ -210,8 +210,10 @@ pub trait Type<DB: Database> {
///
/// When binding arguments with `query!` or `query_as!`, this method is consulted to determine
/// if the Rust type is acceptable.
///
/// Defaults to checking [`TypeInfo::type_compatible()`].
fn compatible(ty: &DB::TypeInfo) -> bool {
*ty == Self::type_info()
Self::type_info().type_compatible(ty)
}
}

View file

@ -14,28 +14,27 @@ use syn::{
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
let attrs = parse_container_attributes(&input.attrs)?;
match &input.data {
// Newtype structs:
// struct Foo(i32);
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
}) => {
if unnamed.len() == 1 {
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
} else {
Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
))
}
}
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
None => expand_derive_has_sql_type_strong_enum(input, variants),
},
// Record types
// struct Foo { foo: i32, bar: String }
Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) => expand_derive_has_sql_type_struct(input, named),
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
Data::Struct(DataStruct {
fields: Fields::Unnamed(..),
..
}) => Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
)),
Data::Struct(DataStruct {
fields: Fields::Unit,
..
@ -43,6 +42,14 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
input,
"unit structs are not supported",
)),
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
// Enums that encode to/from integers (weak enums)
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
// Enums that decode to/from strings (strong enums)
None => expand_derive_has_sql_type_strong_enum(input, variants),
},
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
}
}
@ -148,9 +155,10 @@ fn expand_derive_has_sql_type_weak_enum(
if cfg!(feature = "postgres") && !attrs.no_pg_array {
ts.extend(quote!(
#[automatically_derived]
impl ::sqlx::postgres::PgHasArrayType for #ident {
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
<#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info()
<#repr as ::sqlx::postgres::PgHasArrayType>::array_type_info()
}
}
));
@ -197,9 +205,10 @@ fn expand_derive_has_sql_type_strong_enum(
if !attributes.no_pg_array {
tts.extend(quote!(
#[automatically_derived]
impl ::sqlx::postgres::PgHasArrayType for #ident {
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
<#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info()
::sqlx::postgres::PgTypeInfo::array_of(#ty_name)
}
}
));
@ -244,6 +253,17 @@ fn expand_derive_has_sql_type_struct(
}
}
));
if !attributes.no_pg_array {
tts.extend(quote!(
#[automatically_derived]
impl ::sqlx::postgres::PgHasArrayType for #ident {
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
::sqlx::postgres::PgTypeInfo::array_of(#ty_name)
}
}
));
}
}
Ok(tts)

View file

@ -71,5 +71,8 @@ workspace = true
# We use JSON in the driver implementation itself so there's no reason not to enable it here.
features = ["json"]
[dev-dependencies]
sqlx.workspace = true
[target.'cfg(target_os = "windows")'.dependencies]
etcetera = "0.8.0"

View file

@ -1,5 +1,6 @@
use std::fmt::{self, Write};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use crate::encode::{Encode, IsNull};
use crate::error::Error;
@ -7,6 +8,7 @@ use crate::ext::ustr::UStr;
use crate::types::Type;
use crate::{PgConnection, PgTypeInfo, Postgres};
use crate::type_info::PgArrayOf;
pub(crate) use sqlx_core::arguments::Arguments;
use sqlx_core::error::BoxDynError;
@ -41,7 +43,12 @@ pub struct PgArgumentBuffer {
// This is done for Records and Arrays as the OID is needed well before we are in an async
// function and can just ask postgres.
//
type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, hole kind }>
}
enum HoleKind {
Type { name: UStr },
Array(Arc<PgArrayOf>),
}
struct Patch {
@ -106,8 +113,11 @@ impl PgArguments {
(patch.callback)(buf, ty);
}
for (offset, name) in type_holes {
let oid = conn.fetch_type_id_by_name(name).await?;
for (offset, kind) in type_holes {
let oid = match kind {
HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
};
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
}
@ -186,7 +196,19 @@ impl PgArgumentBuffer {
let offset = self.len();
self.extend_from_slice(&0_u32.to_be_bytes());
self.type_holes.push((offset, type_name.clone()));
self.type_holes.push((
offset,
HoleKind::Type {
name: type_name.clone(),
},
));
}
pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
let offset = self.len();
self.extend_from_slice(&0_u32.to_be_bytes());
self.type_holes.push((offset, HoleKind::Array(array)));
}
fn snapshot(&self) -> PgArgumentBufferSnapshot {

View file

@ -4,7 +4,7 @@ use crate::message::{ParameterDescription, RowDescription};
use crate::query_as::query_as;
use crate::query_scalar::{query_scalar, query_scalar_with};
use crate::statement::PgStatementMetadata;
use crate::type_info::{PgCustomType, PgType, PgTypeKind};
use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind};
use crate::types::Json;
use crate::types::Oid;
use crate::HashMap;
@ -355,6 +355,19 @@ WHERE rngtypid = $1
})
}
pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {
if let Some(oid) = ty.try_oid() {
return Ok(oid);
}
match ty {
PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await,
PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await,
// `.try_oid()` should return `Some()` or it should be covered here
_ => unreachable!("(bug) OID should be resolvable for type {ty:?}"),
}
}
pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
if let Some(oid) = self.cache_type_oid.get(name) {
return Ok(*oid);
@ -366,13 +379,41 @@ WHERE rngtypid = $1
.fetch_optional(&mut *self)
.await?
.ok_or_else(|| Error::TypeNotFound {
type_name: String::from(name),
type_name: name.into(),
})?;
self.cache_type_oid.insert(name.to_string().into(), oid);
Ok(oid)
}
pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
if let Some(oid) = self
.cache_type_oid
.get(&array.elem_name)
.and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid))
{
return Ok(*oid);
}
// language=SQL
let (elem_oid, array_oid): (Oid, Oid) =
query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid")
.bind(&*array.elem_name)
.fetch_optional(&mut *self)
.await?
.ok_or_else(|| Error::TypeNotFound {
type_name: array.name.to_string(),
})?;
// Avoids copying `elem_name` until necessary
self.cache_type_oid
.entry_ref(&array.elem_name)
.insert(elem_oid);
self.cache_elem_type_to_array.insert(elem_oid, array_oid);
Ok(array_oid)
}
pub(crate) async fn get_nullable_for_columns(
&mut self,
stmt_id: Oid,

View file

@ -146,6 +146,7 @@ impl PgConnection {
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
cache_elem_type_to_array: HashMap::new(),
log_settings: options.log_settings.clone(),
})
}

View file

@ -7,7 +7,6 @@ use crate::message::{
RowDescription,
};
use crate::statement::PgStatementMetadata;
use crate::type_info::PgType;
use crate::types::Oid;
use crate::{
statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
@ -36,11 +35,7 @@ async fn prepare(
let mut param_types = Vec::with_capacity(parameters.len());
for ty in parameters {
param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
conn.fetch_type_id_by_name(name).await?
} else {
ty.0.oid()
});
param_types.push(conn.resolve_type_id(&ty.0).await?);
}
// flush and wait until we are re-ready

View file

@ -55,6 +55,7 @@ pub struct PgConnection {
// cache user-defined types by id <-> info
cache_type_info: HashMap<Oid, PgTypeInfo>,
cache_type_oid: HashMap<UStr, Oid>,
cache_elem_type_to_array: HashMap<Oid, Oid>,
// number of ReadyForQuery messages that we are currently expecting
pub(crate) pending_ready_for_query_count: usize,

View file

@ -11,6 +11,34 @@ use crate::types::Oid;
pub(crate) use sqlx_core::type_info::TypeInfo;
/// Type information for a PostgreSQL type.
///
/// ### Note: Implementation of `==` ([`PartialEq::eq()`])
/// Because `==` on [`TypeInfo`]s has been used throughout the SQLx API as a synonym for type compatibility,
/// e.g. in the default impl of [`Type::compatible()`][sqlx_core::types::Type::compatible],
/// some concessions have been made in the implementation.
///
/// When comparing two `PgTypeInfo`s using the `==` operator ([`PartialEq::eq()`]),
/// if one was constructed with [`Self::with_oid()`] and the other with [`Self::with_name()`] or
/// [`Self::array_of()`], `==` will return `true`:
///
/// ```
/// # use sqlx::postgres::{types::Oid, PgTypeInfo};
/// // Potentially surprising result, this assert will pass:
/// assert_eq!(PgTypeInfo::with_oid(Oid(1)), PgTypeInfo::with_name("definitely_not_real"));
/// ```
///
/// Since it is not possible in this case to prove the types are _not_ compatible (because
/// both `PgTypeInfo`s need to be resolved by an active connection to know for sure)
/// and type compatibility is mainly done as a sanity check anyway,
/// it was deemed acceptable to fudge equality in this very specific case.
///
/// This also applies when querying with the text protocol (not using prepared statements,
/// e.g. [`sqlx::raw_sql()`][sqlx_core::raw_sql::raw_sql]), as the connection will be unable
/// to look up the type info like it normally does when preparing a statement: it won't know
/// what the OIDs of the output columns will be until it's in the middle of reading the result,
/// and by that time it's too late.
///
/// To compare types for exact equality, use [`Self::type_eq()`] instead.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct PgTypeInfo(pub(crate) PgType);
@ -132,6 +160,8 @@ pub enum PgType {
// NOTE: Do we want to bring back type declaration by ID? It's notoriously fragile but
// someone may have a user for it
DeclareWithOid(Oid),
DeclareArrayOf(Arc<PgArrayOf>),
}
#[derive(Debug, Clone)]
@ -155,6 +185,13 @@ pub enum PgTypeKind {
Range(PgTypeInfo),
}
#[derive(Debug)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct PgArrayOf {
pub(crate) elem_name: UStr,
pub(crate) name: Box<str>,
}
impl PgTypeInfo {
/// Returns the corresponding `PgTypeInfo` if the OID is a built-in type and recognized by SQLx.
pub(crate) fn try_from_oid(oid: Oid) -> Option<Self> {
@ -233,18 +270,79 @@ impl PgTypeInfo {
///
/// The OID for the type will be fetched from Postgres on use of
/// a value of this type. The fetched OID will be cached per-connection.
///
/// ### Note: Type Names Prefixed with `_`
/// In `pg_catalog.pg_type`, Postgres prefixes a type name with `_` to denote an array of that
/// type, e.g. `int4[]` actually exists in `pg_type` as `_int4`.
///
/// Previously, it was necessary in manual [`PgHasArrayType`][crate::PgHasArrayType] impls
/// to return [`PgTypeInfo::with_name()`] with the type name prefixed with `_` to denote
/// an array type, but this would not work with schema-qualified names.
///
/// As of 0.8, [`PgTypeInfo::array_of()`] is used to declare an array type,
/// and the Postgres driver is now able to properly resolve arrays of custom types,
/// even in other schemas, which was not previously supported.
///
/// It is highly recommended to migrate existing usages to [`PgTypeInfo::array_of()`] where
/// applicable.
///
/// However, to maintain compatibility, the driver now infers any type name prefixed with `_`
/// to be an array of that type. This may introduce some breakages for types which use
/// a `_` prefix but which are not arrays.
///
/// As a workaround, type names with `_` as a prefix but which are not arrays should be wrapped
/// in quotes, e.g.:
/// ```
/// use sqlx::postgres::PgTypeInfo;
/// use sqlx::Type;
///
/// /// `CREATE TYPE "_foo" AS ENUM ('Bar', 'Baz');`
/// #[derive(sqlx::Type)]
/// // Will prevent SQLx from inferring `_foo` as an array type.
/// #[sqlx(type_name = r#""_foo""#)]
/// enum Foo {
/// Bar,
/// Baz
/// }
///
/// assert_eq!(Foo::type_info().name(), r#""_foo""#);
/// ```
pub const fn with_name(name: &'static str) -> Self {
Self(PgType::DeclareWithName(UStr::Static(name)))
}
/// Create a `PgTypeInfo` of an array from the name of its element type.
///
/// The array type OID will be fetched from Postgres on use of a value of this type.
/// The fetched OID will be cached per-connection.
pub fn array_of(elem_name: &'static str) -> Self {
// to satisfy `name()` and `display_name()`, we need to construct strings to return
Self(PgType::DeclareArrayOf(Arc::new(PgArrayOf {
elem_name: elem_name.into(),
name: format!("{elem_name}[]").into(),
})))
}
/// Create a `PgTypeInfo` from an OID.
///
/// Note that the OID for a type is very dependent on the environment. If you only ever use
/// one database or if this is an unhandled built-in type, you should be fine. Otherwise,
/// you will be better served using [`with_name`](Self::with_name).
/// you will be better served using [`Self::with_name()`].
///
/// ### Note: Interaction with `==`
/// This constructor may give surprising results with `==`.
///
/// See [the type-level docs][Self] for details.
pub const fn with_oid(oid: Oid) -> Self {
Self(PgType::DeclareWithOid(oid))
}
/// Returns `true` if `self` can be compared exactly to `other`.
///
/// Unlike `==`, this will return `false` if one type is declared only by OID and the
/// other only by name, since their equality cannot be proven without resolving them.
pub fn type_eq(&self, other: &Self) -> bool {
self.eq_impl(other, false)
}
}
// DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file
@ -464,6 +562,9 @@ impl PgType {
PgType::DeclareWithName(_) => {
return None;
}
PgType::DeclareArrayOf(_) => {
return None;
}
})
}
@ -564,6 +665,7 @@ impl PgType {
PgType::Custom(ty) => &ty.name,
PgType::DeclareWithOid(_) => "?",
PgType::DeclareWithName(name) => name,
PgType::DeclareArrayOf(array) => &array.name,
}
}
@ -664,6 +766,7 @@ impl PgType {
PgType::Custom(ty) => &ty.name,
PgType::DeclareWithOid(_) => "?",
PgType::DeclareWithName(name) => name,
PgType::DeclareArrayOf(array) => &array.name,
}
}
@ -771,13 +874,16 @@ impl PgType {
PgType::DeclareWithName(name) => {
unreachable!("(bug) use of unresolved type declaration [name={name}]");
}
PgType::DeclareArrayOf(array) => {
unreachable!(
"(bug) use of unresolved type declaration [array of={}]",
array.elem_name
);
}
}
}
/// If `self` is an array type, return the type info for its element.
///
/// This method should only be called on resolved types: calling it on
/// a type that is merely declared (DeclareWithOid/Name) is a bug.
pub(crate) fn try_array_element(&self) -> Option<Cow<'_, PgTypeInfo>> {
// We explicitly match on all the `None` cases to ensure an exhaustive match.
match self {
@ -885,14 +991,50 @@ impl PgType {
PgTypeKind::Enum(_) => None,
PgTypeKind::Range(_) => None,
},
PgType::DeclareWithOid(oid) => {
unreachable!("(bug) use of unresolved type declaration [oid={}]", oid.0);
}
PgType::DeclareWithOid(_) => None,
PgType::DeclareWithName(name) => {
unreachable!("(bug) use of unresolved type declaration [name={name}]");
// LEGACY: infer the array element name from a `_` prefix
UStr::strip_prefix(name, "_")
.map(|elem| Cow::Owned(PgTypeInfo(PgType::DeclareWithName(elem))))
}
PgType::DeclareArrayOf(array) => Some(Cow::Owned(PgTypeInfo(PgType::DeclareWithName(
array.elem_name.clone(),
)))),
}
}
/// Returns `true` if this type cannot be matched by name.
fn is_declare_with_oid(&self) -> bool {
matches!(self, Self::DeclareWithOid(_))
}
/// Compare two `PgType`s, first by OID, then by array element, then by name.
///
/// If `soft_eq` is true and `self` or `other` is `DeclareWithOid` but not both, return `true`
/// before checking names.
fn eq_impl(&self, other: &Self, soft_eq: bool) -> bool {
if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) {
// If there are OIDs available, use OIDs to perform a direct match
return a == b;
}
if soft_eq && (self.is_declare_with_oid() || other.is_declare_with_oid()) {
// If we get to this point, one instance is `DeclareWithOid()` and the other is
// `DeclareArrayOf()` or `DeclareWithName()`, which means we can't compare the two.
//
// Since this is only likely to occur when using the text protocol where we can't
// resolve type names before executing a query, we can just opt out of typechecking.
return true;
}
if let (Some(elem_a), Some(elem_b)) = (self.try_array_element(), other.try_array_element())
{
return elem_a == elem_b;
}
// Otherwise, perform a match on the name
name_eq(self.name(), other.name())
}
}
impl TypeInfo for PgTypeInfo {
@ -907,6 +1049,13 @@ impl TypeInfo for PgTypeInfo {
fn is_void(&self) -> bool {
matches!(self.0, PgType::Void)
}
fn type_compatible(&self, other: &Self) -> bool
where
Self: Sized,
{
self == other
}
}
impl PartialEq<PgCustomType> for PgCustomType {
@ -1140,22 +1289,7 @@ impl Display for PgTypeInfo {
impl PartialEq<PgType> for PgType {
fn eq(&self, other: &PgType) -> bool {
if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) {
// If there are OIDs available, use OIDs to perform a direct match
a == b
} else if matches!(
(self, other),
(PgType::DeclareWithName(_), PgType::DeclareWithOid(_))
| (PgType::DeclareWithOid(_), PgType::DeclareWithName(_))
) {
// One is a declare-with-name and the other is a declare-with-id
// This only occurs in the TEXT protocol with custom types
// Just opt-out of type checking here
true
} else {
// Otherwise, perform a match on the name
name_eq(self.name(), other.name())
}
self.eq_impl(other, true)
}
}

View file

@ -156,11 +156,10 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
let type_info = if self.is_empty() {
T::type_info()
} else {
self[0].produces().unwrap_or_else(T::type_info)
};
let type_info = self
.first()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags
@ -168,6 +167,7 @@ where
// element type
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
ty => {
buf.extend(&ty.oid().0.to_be_bytes());

View file

@ -5,7 +5,7 @@ use crate::from_row::FromRow;
use crate::logger::{BranchParent, BranchResult, DebugDiff};
use crate::type_info::DataType;
use crate::SqliteTypeInfo;
use sqlx_core::HashMap;
use sqlx_core::{hash_map, HashMap};
use std::fmt::Debug;
use std::str::from_utf8;
@ -535,13 +535,13 @@ impl BranchList {
) {
logger.add_branch(&state, &state.branch_parent.unwrap());
match self.visited_branch_state.entry(state.mem) {
std::collections::hash_map::Entry::Vacant(entry) => {
hash_map::Entry::Vacant(entry) => {
//this state is not identical to another state, so it will need to be processed
state.mem = entry.key().clone(); //replace state.mem since .entry() moved it
entry.insert(state.get_reference());
self.states.push(state);
}
std::collections::hash_map::Entry::Occupied(entry) => {
hash_map::Entry::Occupied(entry) => {
//already saw a state identical to this one, so no point in processing it
state.mem = entry.key().clone(); //replace state.mem since .entry() moved it
logger.add_result(state, BranchResult::Dedup(*entry.get()));

View file

@ -1,3 +1,4 @@
#![cfg(unix)]
use sqlx::migrate::Migrator;
use std::path::Path;

View file

@ -155,6 +155,9 @@ test_type!(weak_enum<Weak>(Postgres,
"0::int4" == Weak::One,
"2::int4" == Weak::Two,
"4::int4" == Weak::Three,
));
test_type!(weak_enum_array<Vec<Weak>>(Postgres,
"'{0, 2, 4}'::int4[]" == vec![Weak::One, Weak::Two, Weak::Three],
));
@ -162,7 +165,10 @@ test_type!(strong_enum<Strong>(Postgres,
"'one'::text" == Strong::One,
"'two'::text" == Strong::Two,
"'four'::text" == Strong::Three,
"'{'one', 'two', 'four'}'::text[]" == vec![Strong::One, Strong::Two, Strong::Three],
));
test_type!(strong_enum_array<Vec<Strong>>(Postgres,
"ARRAY['one', 'two', 'four']" == vec![Strong::One, Strong::Two, Strong::Three],
));
test_type!(floatrange<FloatRange>(Postgres,
@ -753,11 +759,13 @@ async fn test_enum_with_schema() -> anyhow::Result<()> {
assert_eq!(foo, Foo::Baz);
let foos: Vec<Foo> = sqlx::query_scalar!("SELECT ARRAY($1::foo.\"Foo\", $2::foo.\"Foo\")")
let foos: Vec<Foo> = sqlx::query_scalar("SELECT ARRAY[$1::foo.\"Foo\", $2::foo.\"Foo\"]")
.bind(Foo::Bar)
.bind(Foo::Baz)
.fetch_one(&mut conn)
.await?;
assert_eq!(foos, [Foo::Bar, Foo::Baz]);
Ok(())
}