From ef2527ff3ebf2505f7e6e506ccbc7e936721bea8 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sun, 7 Jun 2020 01:59:59 -0700 Subject: [PATCH] feat(mssql): fix a few bugs and implement Connection::describe --- .gitattributes | 1 + Cargo.lock | 2 + Cargo.toml | 5 + sqlx-core/Cargo.toml | 4 +- sqlx-core/src/encode.rs | 71 ++++++----- sqlx-core/src/lib.rs | 8 +- sqlx-core/src/mssql/arguments.rs | 14 +++ sqlx-core/src/mssql/connection/establish.rs | 3 +- sqlx-core/src/mssql/connection/executor.rs | 126 ++++++++++++++++--- sqlx-core/src/mssql/connection/stream.rs | 5 +- sqlx-core/src/mssql/options.rs | 12 ++ sqlx-core/src/mssql/protocol/done.rs | 2 +- sqlx-core/src/mssql/protocol/message.rs | 4 + sqlx-core/src/mssql/protocol/mod.rs | 1 + sqlx-core/src/mssql/protocol/return_value.rs | 50 ++++++++ sqlx-core/src/mssql/protocol/type_info.rs | 24 ++-- sqlx-core/src/mssql/types/mod.rs | 35 ++++++ sqlx-core/src/mysql/mod.rs | 4 + sqlx-core/src/postgres/mod.rs | 4 + sqlx-core/src/sqlite/mod.rs | 4 + tests/.dockerignore | 2 + tests/docker-compose.yml | 12 +- tests/mssql/Dockerfile | 21 ++++ tests/mssql/configure-db.sh | 7 ++ tests/mssql/describe.rs | 37 ++++++ tests/mssql/entrypoint.sh | 7 ++ tests/mssql/setup.sql | 20 +++ 27 files changed, 424 insertions(+), 61 deletions(-) create mode 100644 .gitattributes create mode 100644 sqlx-core/src/mssql/protocol/return_value.rs create mode 100644 tests/mssql/Dockerfile create mode 100644 tests/mssql/configure-db.sh create mode 100644 tests/mssql/describe.rs create mode 100644 tests/mssql/entrypoint.sh create mode 100644 tests/mssql/setup.sql diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..6313b56c --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf diff --git a/Cargo.lock b/Cargo.lock index 82349ed7..a9c21751 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2112,10 +2112,12 @@ dependencies = [ "md-5", "memchr", "num-bigint", + "once_cell", "parking_lot 0.10.2", "percent-encoding 2.1.0", "phf", "rand", + "regex", "serde", "serde_json", "sha-1", diff --git a/Cargo.toml b/Cargo.toml index d9577613..31f82bc7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -159,3 +159,8 @@ required-features = [ "mssql" ] name = "mssql-types" path = "tests/mssql/types.rs" required-features = [ "mssql" ] + +[[test]] +name = "mssql-describe" +path = "tests/mssql/describe.rs" +required-features = [ "mssql" ] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 105a39c8..5c4f3b66 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -19,7 +19,7 @@ default = [ "runtime-async-std" ] postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ] mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] sqlite = [ "libsqlite3-sys" ] -mssql = [ "uuid", "encoding_rs" ] +mssql = [ "uuid", "encoding_rs", "regex" ] # types all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ] @@ -65,11 +65,13 @@ log = { version = "0.4.8", default-features = false } md-5 = { version = "0.8.0", default-features = false, optional = true } memchr = { version = "2.3.3", default-features = false } num-bigint = { version = "0.2.6", default-features = false, optional = true, features = [ "std" ] } +once_cell = "1.4.0" percent-encoding = "2.1.0" parking_lot = "0.10.2" threadpool = "*" phf = { version = "0.8.0", features = [ "macros" ] } rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] } +regex = { version = "1.3.9", optional = true } serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true } serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true } sha-1 = { version = "0.8.2", default-features = false, optional = true } diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index e8fa3dd9..ee75931d 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -68,36 +68,49 @@ where } } -impl<'q, T: 'q + Encode<'q, DB>, DB: Database> Encode<'q, DB> for Option { - #[inline] - fn produces(&self) -> DB::TypeInfo { - if let Some(v) = self { - v.produces() - } else { - T::type_info() - } - } +#[allow(unused_macros)] +macro_rules! impl_encode_for_option { + ($DB:ident) => { + impl<'q, T: 'q + crate::encode::Encode<'q, $DB>> crate::encode::Encode<'q, $DB> + for Option + { + #[inline] + fn produces(&self) -> <$DB as crate::database::Database>::TypeInfo { + if let Some(v) = self { + v.produces() + } else { + T::type_info() + } + } - #[inline] - fn encode(self, buf: &mut >::ArgumentBuffer) -> IsNull { - if let Some(v) = self { - v.encode(buf) - } else { - IsNull::Yes - } - } + #[inline] + fn encode( + self, + buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer, + ) -> crate::encode::IsNull { + if let Some(v) = self { + v.encode(buf) + } else { + crate::encode::IsNull::Yes + } + } - #[inline] - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { - if let Some(v) = self { - v.encode_by_ref(buf) - } else { - IsNull::Yes - } - } + #[inline] + fn encode_by_ref( + &self, + buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer, + ) -> crate::encode::IsNull { + if let Some(v) = self { + v.encode_by_ref(buf) + } else { + crate::encode::IsNull::Yes + } + } - #[inline] - fn size_hint(&self) -> usize { - self.as_ref().map_or(0, Encode::size_hint) - } + #[inline] + fn size_hint(&self) -> usize { + self.as_ref().map_or(0, crate::encode::Encode::size_hint) + } + } + }; } diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 315ab346..b8e1876c 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -31,10 +31,12 @@ pub mod connection; #[macro_use] pub mod transaction; +#[macro_use] +pub mod encode; + pub mod database; pub mod decode; pub mod describe; -pub mod encode; pub mod executor; mod ext; pub mod from_row; @@ -59,3 +61,7 @@ pub mod sqlite; #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] pub mod mysql; + +#[cfg(feature = "mssql")] +#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))] +pub mod mssql; diff --git a/sqlx-core/src/mssql/arguments.rs b/sqlx-core/src/mssql/arguments.rs index 81506f07..0c92ae9e 100644 --- a/sqlx-core/src/mssql/arguments.rs +++ b/sqlx-core/src/mssql/arguments.rs @@ -2,6 +2,7 @@ use crate::arguments::Arguments; use crate::encode::Encode; use crate::mssql::database::MsSql; use crate::mssql::io::MsSqlBufMutExt; +use crate::mssql::protocol::rpc::StatusFlags; #[derive(Default)] pub struct MsSqlArguments { @@ -31,6 +32,19 @@ impl MsSqlArguments { self.add_named("", value); } + pub(crate) fn declare<'q, T: Encode<'q, MsSql>>(&mut self, name: &str, initial_value: T) { + let ty = initial_value.produces(); + + let mut ty_name = String::new(); + ty.0.fmt(&mut ty_name); + + self.data.put_b_varchar(name); // [ParamName] + self.data.push(StatusFlags::BY_REF_VALUE.bits()); // [StatusFlags] + + ty.0.put(&mut self.data); // [TYPE_INFO] + ty.0.put_value(&mut self.data, initial_value); // [ParamLenData] + } + pub(crate) fn append(&mut self, arguments: &mut MsSqlArguments) { self.ordinal += arguments.ordinal; self.data.append(&mut arguments.data); diff --git a/sqlx-core/src/mssql/connection/establish.rs b/sqlx-core/src/mssql/connection/establish.rs index 1936da79..85235196 100644 --- a/sqlx-core/src/mssql/connection/establish.rs +++ b/sqlx-core/src/mssql/connection/establish.rs @@ -49,8 +49,7 @@ impl MsSqlConnection { server_name: "", client_interface_name: "", language: "", - // FIXME: connect this to options.database - database: "", + database: &*options.database, client_id: [0; 6], }, ); diff --git a/sqlx-core/src/mssql/connection/executor.rs b/sqlx-core/src/mssql/connection/executor.rs index 17542aa2..9a0c0399 100644 --- a/sqlx-core/src/mssql/connection/executor.rs +++ b/sqlx-core/src/mssql/connection/executor.rs @@ -3,16 +3,19 @@ use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::TryStreamExt; +use once_cell::sync::Lazy; +use regex::Regex; -use crate::describe::Describe; +use crate::describe::{Column, Describe}; use crate::error::Error; use crate::executor::{Execute, Executor}; -use crate::mssql::protocol::done::Done; +use crate::mssql::protocol::col_meta_data::Flags; +use crate::mssql::protocol::done::{Done, Status}; use crate::mssql::protocol::message::Message; use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; use crate::mssql::protocol::sql_batch::SqlBatch; -use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow}; +use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow, MsSqlTypeInfo}; impl MsSqlConnection { pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { @@ -25,8 +28,10 @@ impl MsSqlConnection { let message = self.stream.recv_message().await?; if let Message::DoneProc(done) | Message::Done(done) = message { - // finished RPC procedure *OR* SQL batch - self.handle_done(done); + if !done.status.contains(Status::DONE_MORE) { + // finished RPC procedure *OR* SQL batch + self.handle_done(done); + } } } @@ -106,20 +111,23 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection { yield v; } - Message::DoneProc(done) => { - self.handle_done(done); - break; + Message::Done(done) | Message::DoneProc(done) => { + if done.status.contains(Status::DONE_COUNT) { + let v = Either::Left(done.affected_rows); + yield v; + } + + if !done.status.contains(Status::DONE_MORE) { + self.handle_done(done); + break; + } } Message::DoneInProc(done) => { - // finished SQL query *within* procedure - let v = Either::Left(done.affected_rows); - yield v; - } - - Message::Done(done) => { - self.handle_done(done); - break; + if done.status.contains(Status::DONE_COUNT) { + let v = Either::Left(done.affected_rows); + yield v; + } } _ => {} @@ -157,6 +165,90 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection { 'c: 'e, E: Execute<'q, Self::Database>, { - unimplemented!() + let s = query.query(); + + // [sp_prepare] will emit the column meta data + // small issue is that we need to declare all the used placeholders with a "fallback" type + // we currently use regex to collect them; false positives are *okay* but false + // negatives would break the query + let proc = Either::Right(Procedure::Prepare); + + // NOTE: this does not support unicode identifiers; as we don't even support + // named parameters (yet) this is probably fine, for now + + static PARAMS_RE: Lazy = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap()); + + let mut params = String::new(); + let mut num_params = 0; + + for m in PARAMS_RE.captures_iter(s) { + if !params.is_empty() { + params.push_str(","); + } + + params.push_str(&m[0]); + + // NOTE: this means that a query! of `SELECT @p1` will have the macros believe + // it will return nvarchar(1); this is a greater issue with `query!` that we + // we need to circle back to. This doesn't happen much in practice however. + params.push_str(" nvarchar(1)"); + + num_params += 1; + } + + let params = if params.is_empty() { + None + } else { + Some(&*params) + }; + + let mut args = MsSqlArguments::default(); + + args.declare("", 0_i32); + args.add_unnamed(params); + args.add_unnamed(s); + args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA + + self.stream.write_packet( + PacketType::Rpc, + RpcRequest { + transaction_descriptor: self.stream.transaction_descriptor, + arguments: &args, + procedure: proc, + options: OptionFlags::empty(), + }, + ); + + Box::pin(async move { + self.stream.flush().await?; + + loop { + match self.stream.recv_message().await? { + Message::DoneProc(done) | Message::Done(done) => { + if !done.status.contains(Status::DONE_MORE) { + // done with prepare + break; + } + } + + _ => {} + } + } + + let mut columns = Vec::with_capacity(self.stream.columns.len()); + + for col in &self.stream.columns { + columns.push(Column { + name: col.col_name.clone(), + type_info: Some(MsSqlTypeInfo(col.type_info.clone())), + not_null: Some(!col.flags.contains(Flags::NULLABLE)), + }); + } + + Ok(Describe { + params: vec![None; num_params], + columns, + }) + }) } } diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index 710e2d86..2f80aba4 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -14,6 +14,7 @@ use crate::mssql::protocol::login_ack::LoginAck; use crate::mssql::protocol::message::{Message, MessageType}; use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status}; use crate::mssql::protocol::return_status::ReturnStatus; +use crate::mssql::protocol::return_value::ReturnValue; use crate::mssql::protocol::row::Row; use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError}; use crate::net::MaybeTlsStream; @@ -30,7 +31,7 @@ pub(crate) struct MsSqlStream { // most recent column data from ColMetaData // we need to store this as its needed when decoding - columns: Vec, + pub(crate) columns: Vec, } impl MsSqlStream { @@ -112,6 +113,7 @@ impl MsSqlStream { }; let ty = MessageType::get(buf)?; + let message = match ty { MessageType::EnvChange => { match EnvChange::get(buf)? { @@ -137,6 +139,7 @@ impl MsSqlStream { MessageType::Row => Message::Row(Row::get(buf, &self.columns)?), MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?), MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?), + MessageType::ReturnValue => Message::ReturnValue(ReturnValue::get(buf)?), MessageType::Done => Message::Done(Done::get(buf)?), MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?), MessageType::DoneProc => Message::DoneProc(Done::get(buf)?), diff --git a/sqlx-core/src/mssql/options.rs b/sqlx-core/src/mssql/options.rs index e3848b58..2f3f30b8 100644 --- a/sqlx-core/src/mssql/options.rs +++ b/sqlx-core/src/mssql/options.rs @@ -9,6 +9,7 @@ pub struct MsSqlConnectOptions { pub(crate) host: String, pub(crate) port: u16, pub(crate) username: String, + pub(crate) database: String, pub(crate) password: Option, } @@ -23,6 +24,7 @@ impl MsSqlConnectOptions { Self { port: 1433, host: String::from("localhost"), + database: String::from("master"), username: String::from("sa"), password: None, } @@ -47,6 +49,11 @@ impl MsSqlConnectOptions { self.password = Some(password.to_owned()); self } + + pub fn database(mut self, database: &str) -> Self { + self.database = database.to_owned(); + self + } } impl FromStr for MsSqlConnectOptions { @@ -73,6 +80,11 @@ impl FromStr for MsSqlConnectOptions { options = options.password(password); } + let path = url.path().trim_start_matches('/'); + if !path.is_empty() { + options = options.database(path); + } + Ok(options) } } diff --git a/sqlx-core/src/mssql/protocol/done.rs b/sqlx-core/src/mssql/protocol/done.rs index 5543ab73..a6ac624b 100644 --- a/sqlx-core/src/mssql/protocol/done.rs +++ b/sqlx-core/src/mssql/protocol/done.rs @@ -5,7 +5,7 @@ use crate::error::Error; #[derive(Debug)] pub(crate) struct Done { - status: Status, + pub(crate) status: Status, // The token of the current SQL statement. The token value is provided and controlled by the // application layer, which utilizes TDS. The TDS layer does not evaluate the value. diff --git a/sqlx-core/src/mssql/protocol/message.rs b/sqlx-core/src/mssql/protocol/message.rs index d04c99a5..22727677 100644 --- a/sqlx-core/src/mssql/protocol/message.rs +++ b/sqlx-core/src/mssql/protocol/message.rs @@ -3,6 +3,7 @@ use bytes::{Buf, Bytes}; use crate::mssql::protocol::done::Done; use crate::mssql::protocol::login_ack::LoginAck; use crate::mssql::protocol::return_status::ReturnStatus; +use crate::mssql::protocol::return_value::ReturnValue; use crate::mssql::protocol::row::Row; #[derive(Debug)] @@ -13,6 +14,7 @@ pub(crate) enum Message { DoneProc(Done), Row(Row), ReturnStatus(ReturnStatus), + ReturnValue(ReturnValue), } #[derive(Debug)] @@ -27,6 +29,7 @@ pub(crate) enum MessageType { Error, ColMetaData, ReturnStatus, + ReturnValue, } impl MessageType { @@ -35,6 +38,7 @@ impl MessageType { 0x81 => MessageType::ColMetaData, 0xaa => MessageType::Error, 0xab => MessageType::Info, + 0xac => MessageType::ReturnValue, 0xad => MessageType::LoginAck, 0xd1 => MessageType::Row, 0xe3 => MessageType::EnvChange, diff --git a/sqlx-core/src/mssql/protocol/mod.rs b/sqlx-core/src/mssql/protocol/mod.rs index 31ef0561..03196738 100644 --- a/sqlx-core/src/mssql/protocol/mod.rs +++ b/sqlx-core/src/mssql/protocol/mod.rs @@ -10,6 +10,7 @@ pub(crate) mod message; pub(crate) mod packet; pub(crate) mod pre_login; pub(crate) mod return_status; +pub(crate) mod return_value; pub(crate) mod row; pub(crate) mod rpc; pub(crate) mod sql_batch; diff --git a/sqlx-core/src/mssql/protocol/return_value.rs b/sqlx-core/src/mssql/protocol/return_value.rs new file mode 100644 index 00000000..edec80c5 --- /dev/null +++ b/sqlx-core/src/mssql/protocol/return_value.rs @@ -0,0 +1,50 @@ +use bitflags::bitflags; +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::mssql::io::MsSqlBufExt; +use crate::mssql::protocol::col_meta_data::Flags; +use crate::mssql::protocol::type_info::TypeInfo; + +#[derive(Debug)] +pub(crate) struct ReturnValue { + param_ordinal: u16, + param_name: String, + status: ReturnValueStatus, + user_type: u32, + flags: Flags, + type_info: TypeInfo, + value: Bytes, +} + +bitflags! { + pub(crate) struct ReturnValueStatus: u8 { + // If ReturnValue corresponds to OUTPUT parameter of a stored procedure invocation + const OUTPUT_PARAM = 0x01; + + // If ReturnValue corresponds to return value of User Defined Function. + const USER_DEFINED = 0x02; + } +} + +impl ReturnValue { + pub(crate) fn get(buf: &mut Bytes) -> Result { + let ordinal = buf.get_u16_le(); + let name = buf.get_b_varchar()?; + let status = ReturnValueStatus::from_bits_truncate(buf.get_u8()); + let user_type = buf.get_u32_le(); + let flags = Flags::from_bits_truncate(buf.get_u16_le()); + let type_info = TypeInfo::get(buf)?; + let value = type_info.get_value(buf); + + Ok(Self { + param_ordinal: ordinal, + param_name: name, + status, + user_type, + flags, + type_info, + value, + }) + } +} diff --git a/sqlx-core/src/mssql/protocol/type_info.rs b/sqlx-core/src/mssql/protocol/type_info.rs index 6109b97a..5336fdf9 100644 --- a/sqlx-core/src/mssql/protocol/type_info.rs +++ b/sqlx-core/src/mssql/protocol/type_info.rs @@ -2,7 +2,7 @@ use bitflags::bitflags; use bytes::{Buf, Bytes}; use encoding_rs::Encoding; -use crate::encode::Encode; +use crate::encode::{Encode, IsNull}; use crate::error::Error; use crate::mssql::MsSql; @@ -413,9 +413,13 @@ impl TypeInfo { let offset = buf.len(); buf.push(0); - let _ = value.encode(buf); + let size = if let IsNull::Yes = value.encode(buf) { + 0xFF + } else { + (buf.len() - offset - 1) as u8 + }; - buf[offset] = (buf.len() - offset - 1) as u8; + buf[offset] = size; } pub(crate) fn put_short_len_value<'q, T: Encode<'q, MsSql>>( @@ -426,9 +430,12 @@ impl TypeInfo { let offset = buf.len(); buf.extend(&0_u16.to_le_bytes()); - let _ = value.encode(buf); + let size = if let IsNull::Yes = value.encode(buf) { + 0xFFFF + } else { + (buf.len() - offset - 2) as u16 + }; - let size = (buf.len() - offset - 2) as u16; buf[offset..(offset + 2)].copy_from_slice(&size.to_le_bytes()); } @@ -436,9 +443,12 @@ impl TypeInfo { let offset = buf.len(); buf.extend(&0_u32.to_le_bytes()); - let _ = value.encode(buf); + let size = if let IsNull::Yes = value.encode(buf) { + 0xFFFF_FFFF + } else { + (buf.len() - offset - 4) as u32 + }; - let size = (buf.len() - offset - 4) as u32; buf[offset..(offset + 4)].copy_from_slice(&size.to_le_bytes()); } diff --git a/sqlx-core/src/mssql/types/mod.rs b/sqlx-core/src/mssql/types/mod.rs index 653c149c..6dc4ec52 100644 --- a/sqlx-core/src/mssql/types/mod.rs +++ b/sqlx-core/src/mssql/types/mod.rs @@ -1,3 +1,38 @@ +use crate::encode::{Encode, IsNull}; +use crate::mssql::protocol::type_info::{DataType, TypeInfo}; +use crate::mssql::{MsSql, MsSqlTypeInfo}; + mod float; mod int; mod str; + +impl<'q, T: 'q + Encode<'q, MsSql>> Encode<'q, MsSql> for Option { + fn produces(&self) -> MsSqlTypeInfo { + if let Some(v) = self { + v.produces() + } else { + // MSSQL requires a special NULL type ID + MsSqlTypeInfo(TypeInfo::new(DataType::Null, 0)) + } + } + + fn encode(self, buf: &mut Vec) -> IsNull { + if let Some(v) = self { + v.encode(buf) + } else { + IsNull::Yes + } + } + + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + if let Some(v) = self { + v.encode_by_ref(buf) + } else { + IsNull::Yes + } + } + + fn size_hint(&self) -> usize { + self.as_ref().map_or(0, Encode::size_hint) + } +} diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 114a0004..64a5bbba 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -30,3 +30,7 @@ pub type MySqlPool = crate::pool::Pool; impl_into_arguments_for_arguments!(MySqlArguments); impl_executor_for_pool_connection!(MySql, MySqlConnection, MySqlRow); impl_executor_for_transaction!(MySql, MySqlRow); + +// required because some databases have a different handling +// of NULL +impl_encode_for_option!(MySql); diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 77b5595a..fde4a38c 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -33,3 +33,7 @@ pub type PgPool = crate::pool::Pool; impl_into_arguments_for_arguments!(PgArguments); impl_executor_for_pool_connection!(Postgres, PgConnection, PgRow); impl_executor_for_transaction!(Postgres, PgRow); + +// required because some databases have a different handling +// of NULL +impl_encode_for_option!(Postgres); diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index e56c97fb..b6b547db 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -34,3 +34,7 @@ pub type SqlitePool = crate::pool::Pool; impl_into_arguments_for_arguments!(SqliteArguments<'q>); impl_executor_for_pool_connection!(Sqlite, SqliteConnection, SqliteRow); impl_executor_for_transaction!(Sqlite, SqliteRow); + +// required because some databases have a different handling +// of NULL +impl_encode_for_option!(Postgres); diff --git a/tests/.dockerignore b/tests/.dockerignore index 1557a10f..6c513a8a 100644 --- a/tests/.dockerignore +++ b/tests/.dockerignore @@ -1,3 +1,5 @@ * !certs/* !keys/* +!mssql/*.sh +!*/*.sql diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 8cc6a603..2ad2a122 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -180,13 +180,21 @@ services: # mssql_2019: - image: mcr.microsoft.com/mssql/server:2019-latest + build: + context: . + dockerfile: mssql/Dockerfile + args: + VERSION: 2019-latest environment: ACCEPT_EULA: Y SA_PASSWORD: Password123! mssql_2017: - image: mcr.microsoft.com/mssql/server:2017-latest + build: + context: . + dockerfile: mssql/Dockerfile + args: + VERSION: 2017-latest environment: ACCEPT_EULA: Y SA_PASSWORD: Password123! diff --git a/tests/mssql/Dockerfile b/tests/mssql/Dockerfile new file mode 100644 index 00000000..6c2389d8 --- /dev/null +++ b/tests/mssql/Dockerfile @@ -0,0 +1,21 @@ +ARG VERSION +FROM mcr.microsoft.com/mssql/server:${VERSION} + +# Create a config directory +RUN mkdir -p /usr/config +WORKDIR /usr/config + +# Bundle config source +COPY mssql/entrypoint.sh /usr/config/entrypoint.sh +COPY mssql/configure-db.sh /usr/config/configure-db.sh +COPY mssql/setup.sql /usr/config/setup.sql + +# Grant permissions for to our scripts to be executable +USER root +RUN chmod +x /usr/config/entrypoint.sh +RUN chmod +x /usr/config/configure-db.sh +RUN chown 10001 /usr/config/entrypoint.sh +RUN chown 10001 /usr/config/configure-db.sh +USER 10001 + +ENTRYPOINT ["/usr/config/entrypoint.sh"] diff --git a/tests/mssql/configure-db.sh b/tests/mssql/configure-db.sh new file mode 100644 index 00000000..654cab45 --- /dev/null +++ b/tests/mssql/configure-db.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +# Wait 60 seconds for SQL Server to start up +sleep 60 + +# Run the setup script to create the DB and the schema in the DB +/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P $SA_PASSWORD -d master -i setup.sql diff --git a/tests/mssql/describe.rs b/tests/mssql/describe.rs new file mode 100644 index 00000000..8bb4b3b9 --- /dev/null +++ b/tests/mssql/describe.rs @@ -0,0 +1,37 @@ +use sqlx::mssql::MsSql; +use sqlx::{describe::Column, Executor}; +use sqlx_test::new; + +fn type_names(columns: &[Column]) -> Vec { + columns + .iter() + .filter_map(|col| Some(col.type_info.as_ref()?.to_string())) + .collect() +} + +#[sqlx_macros::test] +async fn it_describes_simple() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn.describe("SELECT * FROM tweet").await?; + let columns = d.columns; + + assert_eq!(columns[0].name, "id"); + assert_eq!(columns[1].name, "text"); + assert_eq!(columns[2].name, "is_sent"); + assert_eq!(columns[3].name, "owner_id"); + + assert_eq!(columns[0].not_null, Some(true)); + assert_eq!(columns[1].not_null, Some(true)); + assert_eq!(columns[2].not_null, Some(true)); + assert_eq!(columns[3].not_null, Some(false)); + + let column_type_names = type_names(&columns); + + assert_eq!(column_type_names[0], "bigint"); + assert_eq!(column_type_names[1], "nvarchar(max)"); + assert_eq!(column_type_names[2], "tinyint"); + assert_eq!(column_type_names[3], "bigint"); + + Ok(()) +} diff --git a/tests/mssql/entrypoint.sh b/tests/mssql/entrypoint.sh new file mode 100644 index 00000000..c4b5e45c --- /dev/null +++ b/tests/mssql/entrypoint.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +# Start the script to create the DB and user +/usr/config/configure-db.sh & + +# Start SQL Server +/opt/mssql/bin/sqlservr diff --git a/tests/mssql/setup.sql b/tests/mssql/setup.sql new file mode 100644 index 00000000..a033227b --- /dev/null +++ b/tests/mssql/setup.sql @@ -0,0 +1,20 @@ +IF DB_ID('sqlx') IS NULL + BEGIN + CREATE DATABASE sqlx; + END; +GO + +USE sqlx; +GO + +IF OBJECT_ID('tweet') IS NULL + BEGIN + CREATE TABLE tweet + ( + id BIGINT NOT NULL PRIMARY KEY, + text NVARCHAR(4000) NOT NULL, + is_sent TINYINT NOT NULL DEFAULT 1, + owner_id BIGINT + ); + END; +GO