fix(mysql): audit for bad casts
This commit is contained in:
parent
e99f0fa5b6
commit
1f669ae996
22 changed files with 228 additions and 120 deletions
|
@ -213,6 +213,8 @@ impl<'a> TryFrom<&'a AnyConnectOptions> for MySqlConnectOptions {
|
|||
fn map_result(result: MySqlQueryResult) -> AnyQueryResult {
|
||||
AnyQueryResult {
|
||||
rows_affected: result.rows_affected,
|
||||
// Don't expect this to be a problem
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
last_insert_id: Some(result.last_insert_id as i64),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ impl AuthPlugin {
|
|||
0x04 => {
|
||||
let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;
|
||||
|
||||
stream.write_packet(&*payload);
|
||||
stream.write_packet(&*payload)?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(false)
|
||||
|
@ -143,7 +143,7 @@ async fn encrypt_rsa<'s>(
|
|||
}
|
||||
|
||||
// client sends a public key request
|
||||
stream.write_packet(&[public_key_request_id][..]);
|
||||
stream.write_packet(&[public_key_request_id][..])?;
|
||||
stream.flush().await?;
|
||||
|
||||
// server sends a public key response
|
||||
|
|
|
@ -131,7 +131,7 @@ impl<'a> DoHandshake<'a> {
|
|||
database: options.database.as_deref(),
|
||||
auth_plugin: plugin,
|
||||
auth_response: auth_response.as_deref(),
|
||||
});
|
||||
})?;
|
||||
|
||||
stream.flush().await?;
|
||||
|
||||
|
@ -160,7 +160,7 @@ impl<'a> DoHandshake<'a> {
|
|||
)
|
||||
.await?;
|
||||
|
||||
stream.write_packet(AuthSwitchResponse(response));
|
||||
stream.write_packet(AuthSwitchResponse(response))?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
|
|
|
@ -187,7 +187,9 @@ impl MySqlConnection {
|
|||
// otherwise, this first packet is the start of the result-set metadata,
|
||||
*self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row;
|
||||
|
||||
let num_columns = packet.get_uint_lenenc() as usize; // column count
|
||||
let num_columns = packet.get_uint_lenenc(); // column count
|
||||
let num_columns = usize::try_from(num_columns)
|
||||
.map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?;
|
||||
|
||||
if needs_metadata {
|
||||
column_names = Arc::new(recv_result_metadata(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?);
|
||||
|
|
|
@ -113,17 +113,17 @@ impl<S: Socket> MySqlStream<S> {
|
|||
T: ProtocolEncode<'en, Capabilities>,
|
||||
{
|
||||
self.sequence_id = 0;
|
||||
self.write_packet(payload);
|
||||
self.write_packet(payload)?;
|
||||
self.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
|
||||
pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
|
||||
where
|
||||
T: ProtocolEncode<'en, Capabilities>,
|
||||
{
|
||||
self.socket
|
||||
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
|
||||
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id))
|
||||
}
|
||||
|
||||
async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
|
||||
|
@ -132,6 +132,8 @@ impl<S: Socket> MySqlStream<S> {
|
|||
|
||||
let mut header: Bytes = self.socket.read(4).await?;
|
||||
|
||||
// cannot overflow
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let packet_size = header.get_uint_le(3) as usize;
|
||||
let sequence_id = header.get_u8();
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ pub(super) async fn maybe_upgrade<S: Socket>(
|
|||
stream.write_packet(SslRequest {
|
||||
max_packet_size: super::MAX_PACKET_SIZE,
|
||||
collation: stream.collation as u8,
|
||||
});
|
||||
})?;
|
||||
|
||||
stream.flush().await?;
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ pub trait MySqlBufExt: Buf {
|
|||
fn get_str_lenenc(&mut self) -> Result<String, Error>;
|
||||
|
||||
// Read a length-encoded byte sequence.
|
||||
fn get_bytes_lenenc(&mut self) -> Bytes;
|
||||
fn get_bytes_lenenc(&mut self) -> Result<Bytes, Error>;
|
||||
}
|
||||
|
||||
impl MySqlBufExt for Bytes {
|
||||
|
@ -31,11 +31,17 @@ impl MySqlBufExt for Bytes {
|
|||
|
||||
fn get_str_lenenc(&mut self) -> Result<String, Error> {
|
||||
let size = self.get_uint_lenenc();
|
||||
self.get_str(size as usize)
|
||||
let size = usize::try_from(size)
|
||||
.map_err(|_| err_protocol!("string length overflows usize: {size}"))?;
|
||||
|
||||
self.get_str(size)
|
||||
}
|
||||
|
||||
fn get_bytes_lenenc(&mut self) -> Bytes {
|
||||
fn get_bytes_lenenc(&mut self) -> Result<Bytes, Error> {
|
||||
let size = self.get_uint_lenenc();
|
||||
self.split_to(size as usize)
|
||||
let size = usize::try_from(size)
|
||||
.map_err(|_| err_protocol!("string length overflows usize: {size}"))?;
|
||||
|
||||
Ok(self.split_to(size))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,17 +13,22 @@ impl MySqlBufMutExt for Vec<u8> {
|
|||
// https://dev.mysql.com/doc/internals/en/integer.html
|
||||
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
|
||||
|
||||
if v < 251 {
|
||||
self.push(v as u8);
|
||||
} else if v < 0x1_00_00 {
|
||||
self.push(0xfc);
|
||||
self.extend(&(v as u16).to_le_bytes());
|
||||
} else if v < 0x1_00_00_00 {
|
||||
self.push(0xfd);
|
||||
self.extend(&(v as u32).to_le_bytes()[..3]);
|
||||
} else {
|
||||
self.push(0xfe);
|
||||
self.extend(&v.to_le_bytes());
|
||||
let encoded_le = v.to_le_bytes();
|
||||
|
||||
match v {
|
||||
..251 => self.push(encoded_le[0]),
|
||||
251..0x1_00_00 => {
|
||||
self.push(0xfc);
|
||||
self.extend_from_slice(&encoded_le[..2]);
|
||||
}
|
||||
0x1_00_00..0x1_00_00_00 => {
|
||||
self.push(0xfd);
|
||||
self.extend_from_slice(&encoded_le[..3]);
|
||||
}
|
||||
_ => {
|
||||
self.push(0xfe);
|
||||
self.extend_from_slice(&encoded_le);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -231,7 +231,9 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
|
|||
WHERE version = ?
|
||||
"#,
|
||||
)
|
||||
.bind(elapsed.as_nanos() as i64)
|
||||
// Unlikely unless the execution time exceeds ~292 years,
|
||||
// then we're probably okay losing that information.
|
||||
.bind(i64::try_from(elapsed.as_nanos()).ok())
|
||||
.bind(migration.version)
|
||||
.execute(self)
|
||||
.await?;
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html
|
||||
// https://mariadb.com/kb/en/library/connection/#capabilities
|
||||
//
|
||||
// MySQL defines the capabilities flags as fitting in an `int<4>` but MariaDB
|
||||
// extends this with more bits sent in a separate field.
|
||||
// For simplicity, we've chosen to combine these into one type.
|
||||
bitflags::bitflags! {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct Capabilities: u64 {
|
||||
|
@ -43,45 +47,65 @@ bitflags::bitflags! {
|
|||
const TRANSACTIONS = 8192;
|
||||
|
||||
// 4.1+ authentication
|
||||
const SECURE_CONNECTION = (1 << 15);
|
||||
const SECURE_CONNECTION = 1 << 15;
|
||||
|
||||
// Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE
|
||||
const MULTI_STATEMENTS = (1 << 16);
|
||||
const MULTI_STATEMENTS = 1 << 16;
|
||||
|
||||
// Enable/disable multi-results for COM_QUERY
|
||||
const MULTI_RESULTS = (1 << 17);
|
||||
const MULTI_RESULTS = 1 << 17;
|
||||
|
||||
// Enable/disable multi-results for COM_STMT_PREPARE
|
||||
const PS_MULTI_RESULTS = (1 << 18);
|
||||
const PS_MULTI_RESULTS = 1 << 18;
|
||||
|
||||
// Client supports plugin authentication
|
||||
const PLUGIN_AUTH = (1 << 19);
|
||||
const PLUGIN_AUTH = 1 << 19;
|
||||
|
||||
// Client supports connection attributes
|
||||
const CONNECT_ATTRS = (1 << 20);
|
||||
const CONNECT_ATTRS = 1 << 20;
|
||||
|
||||
// Enable authentication response packet to be larger than 255 bytes.
|
||||
const PLUGIN_AUTH_LENENC_DATA = (1 << 21);
|
||||
const PLUGIN_AUTH_LENENC_DATA = 1 << 21;
|
||||
|
||||
// Don't close the connection for a user account with expired password.
|
||||
const CAN_HANDLE_EXPIRED_PASSWORDS = (1 << 22);
|
||||
const CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22;
|
||||
|
||||
// Capable of handling server state change information.
|
||||
const SESSION_TRACK = (1 << 23);
|
||||
const SESSION_TRACK = 1 << 23;
|
||||
|
||||
// Client no longer needs EOF_Packet and will use OK_Packet instead.
|
||||
const DEPRECATE_EOF = (1 << 24);
|
||||
const DEPRECATE_EOF = 1 << 24;
|
||||
|
||||
// Support ZSTD protocol compression
|
||||
const ZSTD_COMPRESSION_ALGORITHM = (1 << 26);
|
||||
const ZSTD_COMPRESSION_ALGORITHM = 1 << 26;
|
||||
|
||||
// Verify server certificate
|
||||
const SSL_VERIFY_SERVER_CERT = (1 << 30);
|
||||
const SSL_VERIFY_SERVER_CERT = 1 << 30;
|
||||
|
||||
// The client can handle optional metadata information in the resultset
|
||||
const OPTIONAL_RESULTSET_METADATA = (1 << 25);
|
||||
const OPTIONAL_RESULTSET_METADATA = 1 << 25;
|
||||
|
||||
// Don't reset the options after an unsuccessful connect
|
||||
const REMEMBER_OPTIONS = (1 << 31);
|
||||
const REMEMBER_OPTIONS = 1 << 31;
|
||||
|
||||
// Extended capabilities (MariaDB only, as of writing)
|
||||
// Client support progress indicator (since 10.2)
|
||||
const MARIADB_CLIENT_PROGRESS = 1 << 32;
|
||||
|
||||
// Permit COM_MULTI protocol
|
||||
const MARIADB_CLIENT_MULTI = 1 << 33;
|
||||
|
||||
// Permit bulk insert
|
||||
const MARIADB_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
|
||||
|
||||
// Add extended metadata information
|
||||
const MARIADB_CLIENT_EXTENDED_TYPE_INFO = 1 << 35;
|
||||
|
||||
// Permit skipping metadata
|
||||
const MARIADB_CLIENT_CACHE_METADATA = 1 << 36;
|
||||
|
||||
// when enabled, indicate that Bulk command can use STMT_BULK_FLAG_SEND_UNIT_RESULTS flag
|
||||
// that permit to return a result-set of all affected rows and auto-increment values
|
||||
const MARIADB_CLIENT_BULK_UNIT_RESULTS = 1 << 37;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,8 +62,8 @@ impl ProtocolDecode<'_> for Handshake {
|
|||
}
|
||||
|
||||
let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
|
||||
let len = cmp::max((auth_plugin_data_len as isize) - 9, 12) as usize;
|
||||
let v = buf.get_bytes(len);
|
||||
let len = cmp::max(auth_plugin_data_len.saturating_sub(9), 12);
|
||||
let v = buf.get_bytes(len as usize);
|
||||
buf.advance(1); // NUL-terminator
|
||||
|
||||
v
|
||||
|
|
|
@ -52,7 +52,10 @@ impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> {
|
|||
} else if context.contains(Capabilities::SECURE_CONNECTION) {
|
||||
let response = self.auth_response.unwrap_or_default();
|
||||
|
||||
buf.push(response.len() as u8);
|
||||
let response_len = u8::try_from(response.len())
|
||||
.map_err(|_| err_protocol!("auth_response.len() too long: {}", response.len()))?;
|
||||
|
||||
buf.push(response_len);
|
||||
buf.extend(response);
|
||||
} else {
|
||||
buf.push(0);
|
||||
|
|
|
@ -12,6 +12,8 @@ pub struct SslRequest {
|
|||
|
||||
impl ProtocolEncode<'_, Capabilities> for SslRequest {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, context: Capabilities) -> Result<(), crate::Error> {
|
||||
// truncation is intended
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
buf.extend(&(context.bits() as u32).to_le_bytes());
|
||||
buf.extend(&self.max_packet_size.to_le_bytes());
|
||||
buf.push(self.collation);
|
||||
|
|
|
@ -40,6 +40,8 @@ where
|
|||
let len = buf.len() - offset - 4;
|
||||
let header = &mut buf[offset..];
|
||||
|
||||
// // `min(.., 0xFF_FF_FF)` cannot overflow
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));
|
||||
|
||||
// add more packets if we need to split the data
|
||||
|
@ -49,6 +51,9 @@ where
|
|||
|
||||
for chunk in chunks.by_ref() {
|
||||
buf.reserve(chunk.len() + 4);
|
||||
|
||||
// `chunk.len() == 0xFF_FF_FF`
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
buf.extend(&next_header(chunk.len() as u32));
|
||||
buf.extend(chunk);
|
||||
}
|
||||
|
@ -56,6 +61,9 @@ where
|
|||
// this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
|
||||
let remainder = chunks.remainder();
|
||||
buf.reserve(remainder.len() + 4);
|
||||
|
||||
// `remainder.len() < 0xFF_FF_FF`
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
buf.extend(&next_header(remainder.len() as u32));
|
||||
buf.extend(remainder);
|
||||
}
|
||||
|
|
|
@ -34,8 +34,11 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow {
|
|||
for (column_idx, column) in columns.iter().enumerate() {
|
||||
// NOTE: the column index starts at the 3rd bit
|
||||
let column_null_idx = column_idx + 2;
|
||||
let is_null =
|
||||
null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0;
|
||||
|
||||
let byte_idx = column_null_idx / 8;
|
||||
let bit_idx = column_null_idx % 8;
|
||||
|
||||
let is_null = null_bitmap[byte_idx] & (1u8 << bit_idx) != 0;
|
||||
|
||||
if is_null {
|
||||
values.push(None);
|
||||
|
@ -72,7 +75,11 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow {
|
|||
| ColumnType::Bit
|
||||
| ColumnType::Decimal
|
||||
| ColumnType::Json
|
||||
| ColumnType::NewDecimal => buf.get_uint_lenenc() as usize,
|
||||
| ColumnType::NewDecimal => {
|
||||
let size = buf.get_uint_lenenc();
|
||||
usize::try_from(size)
|
||||
.map_err(|_| err_protocol!("BLOB length out of range: {size}"))?
|
||||
}
|
||||
|
||||
// Like strings and blobs, these values are variable-length.
|
||||
// Unlike strings and blobs, however, they exclusively use one byte for length.
|
||||
|
|
|
@ -136,12 +136,12 @@ impl ColumnDefinition {
|
|||
|
||||
impl ProtocolDecode<'_, Capabilities> for ColumnDefinition {
|
||||
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
|
||||
let catalog = buf.get_bytes_lenenc();
|
||||
let schema = buf.get_bytes_lenenc();
|
||||
let table_alias = buf.get_bytes_lenenc();
|
||||
let table = buf.get_bytes_lenenc();
|
||||
let alias = buf.get_bytes_lenenc();
|
||||
let name = buf.get_bytes_lenenc();
|
||||
let catalog = buf.get_bytes_lenenc()?;
|
||||
let schema = buf.get_bytes_lenenc()?;
|
||||
let table_alias = buf.get_bytes_lenenc()?;
|
||||
let table = buf.get_bytes_lenenc()?;
|
||||
let alias = buf.get_bytes_lenenc()?;
|
||||
let name = buf.get_bytes_lenenc()?;
|
||||
let _next_len = buf.get_uint_lenenc(); // always 0x0c
|
||||
let collation = buf.get_u16_le();
|
||||
let max_size = buf.get_u32_le();
|
||||
|
|
|
@ -22,7 +22,10 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow {
|
|||
values.push(None);
|
||||
buf.advance(1);
|
||||
} else {
|
||||
let size = buf.get_uint_lenenc() as usize;
|
||||
let size = buf.get_uint_lenenc();
|
||||
let size = usize::try_from(size)
|
||||
.map_err(|_| err_protocol!("TextRow length out of range: {size}"))?;
|
||||
|
||||
let offset = offset - buf.len();
|
||||
|
||||
values.push(Some(offset..(offset + size)));
|
||||
|
|
|
@ -59,7 +59,8 @@ impl TransactionManager for MySqlTransactionManager {
|
|||
conn.inner.stream.sequence_id = 0;
|
||||
conn.inner
|
||||
.stream
|
||||
.write_packet(Query(&rollback_ansi_transaction_sql(depth)));
|
||||
.write_packet(Query(&rollback_ansi_transaction_sql(depth)))
|
||||
.expect("BUG: unexpected error queueing ROLLBACK");
|
||||
|
||||
conn.inner.transaction_depth = depth - 1;
|
||||
}
|
||||
|
|
|
@ -70,8 +70,8 @@ impl Type<MySql> for NaiveTime {
|
|||
|
||||
impl Encode<'_, MySql> for NaiveTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
let len = naive_time_encoded_len(self);
|
||||
buf.push(len);
|
||||
|
||||
// NaiveTime is not negative
|
||||
buf.push(0);
|
||||
|
@ -80,19 +80,13 @@ impl Encode<'_, MySql> for NaiveTime {
|
|||
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
|
||||
buf.extend_from_slice(&[0_u8; 4]);
|
||||
|
||||
encode_time(self, len > 9, buf);
|
||||
encode_time(self, len > 8, buf);
|
||||
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
if self.nanosecond() == 0 {
|
||||
// if micro_seconds is 0, length is 8 and micro_seconds is not sent
|
||||
9
|
||||
} else {
|
||||
// otherwise length is 12
|
||||
13
|
||||
}
|
||||
naive_time_encoded_len(self) as usize + 1 // plus length byte
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -217,38 +211,20 @@ impl Type<MySql> for NaiveDateTime {
|
|||
|
||||
impl Encode<'_, MySql> for NaiveDateTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
let len = naive_dt_encoded_len(self);
|
||||
buf.push(len);
|
||||
|
||||
encode_date(&self.date(), buf)?;
|
||||
|
||||
if len > 4 {
|
||||
encode_time(&self.time(), len > 8, buf);
|
||||
encode_time(&self.time(), len > 7, buf);
|
||||
}
|
||||
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
// to save space the packet can be compressed:
|
||||
match (
|
||||
self.hour(),
|
||||
self.minute(),
|
||||
self.second(),
|
||||
#[allow(deprecated)]
|
||||
self.timestamp_subsec_nanos(),
|
||||
) {
|
||||
// if hour, minutes, seconds and micro_seconds are all 0,
|
||||
// length is 4 and no other field is sent
|
||||
(0, 0, 0, 0) => 5,
|
||||
|
||||
// if micro_seconds is 0, length is 7
|
||||
// and micro_seconds is not sent
|
||||
(_, _, _, 0) => 8,
|
||||
|
||||
// otherwise length is 11
|
||||
(_, _, _, _) => 12,
|
||||
}
|
||||
naive_dt_encoded_len(self) as usize + 1 // plus length byte
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -284,13 +260,18 @@ impl<'r> Decode<'r, MySql> for NaiveDateTime {
|
|||
}
|
||||
|
||||
fn encode_date(date: &NaiveDate, buf: &mut Vec<u8>) -> Result<(), BoxDynError> {
|
||||
// MySQL supports years from 1000 - 9999
|
||||
// MySQL supports years 1000 - 9999
|
||||
let year = u16::try_from(date.year())
|
||||
.map_err(|_| format!("NaiveDateTime out of range for Mysql: {date}"))?;
|
||||
|
||||
buf.extend_from_slice(&year.to_le_bytes());
|
||||
buf.push(date.month() as u8);
|
||||
buf.push(date.day() as u8);
|
||||
|
||||
// `NaiveDate` guarantees the ranges of these values
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
{
|
||||
buf.push(date.month() as u8);
|
||||
buf.push(date.day() as u8);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -314,9 +295,13 @@ fn decode_date(mut buf: &[u8]) -> Result<Option<NaiveDate>, BoxDynError> {
|
|||
}
|
||||
|
||||
fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
|
||||
buf.push(time.hour() as u8);
|
||||
buf.push(time.minute() as u8);
|
||||
buf.push(time.second() as u8);
|
||||
// `NaiveTime` API guarantees the ranges of these values
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
{
|
||||
buf.push(time.hour() as u8);
|
||||
buf.push(time.minute() as u8);
|
||||
buf.push(time.second() as u8);
|
||||
}
|
||||
|
||||
if include_micros {
|
||||
buf.extend((time.nanosecond() / 1000).to_le_bytes());
|
||||
|
@ -335,6 +320,43 @@ fn decode_time(len: u8, mut buf: &[u8]) -> Result<NaiveTime, BoxDynError> {
|
|||
0
|
||||
};
|
||||
|
||||
NaiveTime::from_hms_micro_opt(hour as u32, minute as u32, seconds as u32, micros as u32)
|
||||
let micros = u32::try_from(micros)
|
||||
.map_err(|_| format!("server returned microseconds out of range: {micros}"))?;
|
||||
|
||||
NaiveTime::from_hms_micro_opt(hour as u32, minute as u32, seconds as u32, micros)
|
||||
.ok_or_else(|| format!("server returned invalid time: {hour:02}:{minute:02}:{seconds:02}; micros: {micros}").into())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn naive_dt_encoded_len(time: &NaiveDateTime) -> u8 {
|
||||
// to save space the packet can be compressed:
|
||||
match (
|
||||
time.hour(),
|
||||
time.minute(),
|
||||
time.second(),
|
||||
#[allow(deprecated)]
|
||||
time.timestamp_subsec_nanos(),
|
||||
) {
|
||||
// if hour, minutes, seconds and micro_seconds are all 0,
|
||||
// length is 4 and no other field is sent
|
||||
(0, 0, 0, 0) => 4,
|
||||
|
||||
// if micro_seconds is 0, length is 7
|
||||
// and micro_seconds is not sent
|
||||
(_, _, _, 0) => 7,
|
||||
|
||||
// otherwise length is 11
|
||||
(_, _, _, _) => 11,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn naive_time_encoded_len(time: &NaiveTime) -> u8 {
|
||||
if time.nanosecond() == 0 {
|
||||
// if micro_seconds is 0, length is 8 and micro_seconds is not sent
|
||||
8
|
||||
} else {
|
||||
// otherwise length is 12
|
||||
12
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,6 +59,7 @@ impl Decode<'_, MySql> for f32 {
|
|||
4 => LittleEndian::read_f32(buf),
|
||||
// MySQL can return 8-byte DOUBLE values for a FLOAT
|
||||
// We take and truncate to f32 as that's the same behavior as *in* MySQL,
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
8 => LittleEndian::read_f64(buf) as f32,
|
||||
other => {
|
||||
// Users may try to decode a DECIMAL as floating point;
|
||||
|
|
|
@ -617,6 +617,8 @@ fn parse_microseconds(micros: &str) -> Result<u32, BoxDynError> {
|
|||
len @ ..=EXPECTED_DIGITS => {
|
||||
// Fewer than 6 digits, multiply to the correct magnitude
|
||||
let micros: u32 = micros.parse()?;
|
||||
// cast cannot overflow
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
Ok(micros * 10u32.pow((EXPECTED_DIGITS - len) as u32))
|
||||
}
|
||||
// More digits than expected, truncate
|
||||
|
|
|
@ -47,29 +47,23 @@ impl Type<MySql> for Time {
|
|||
|
||||
impl Encode<'_, MySql> for Time {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
let len = time_encoded_len(self);
|
||||
buf.push(len);
|
||||
|
||||
// Time is not negative
|
||||
// sign byte: Time is never negative
|
||||
buf.push(0);
|
||||
|
||||
// Number of days in the interval; always 0 for time-of-day values.
|
||||
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
|
||||
buf.extend_from_slice(&[0_u8; 4]);
|
||||
|
||||
encode_time(self, len > 9, buf);
|
||||
encode_time(self, len > 8, buf);
|
||||
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
if self.nanosecond() == 0 {
|
||||
// if micro_seconds is 0, length is 8 and micro_seconds is not sent
|
||||
9
|
||||
} else {
|
||||
// otherwise length is 12
|
||||
13
|
||||
}
|
||||
time_encoded_len(self) as usize + 1 // plus length byte
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,6 +93,7 @@ impl TryFrom<MySqlTime> for Time {
|
|||
return Err(format!("MySqlTime value out of range for `time::Time`: {time}").into());
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
Ok(Time::from_hms_micro(
|
||||
// `is_valid_time_of_day()` ensures this won't overflow
|
||||
time.hours() as u8,
|
||||
|
@ -111,6 +106,8 @@ impl TryFrom<MySqlTime> for Time {
|
|||
|
||||
impl From<MySqlTime> for time::Duration {
|
||||
fn from(time: MySqlTime) -> Self {
|
||||
// `subsec_nanos()` is guaranteed to be between 0 and 10^9
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
time::Duration::new(time.whole_seconds_signed(), time.subsec_nanos() as i32)
|
||||
}
|
||||
}
|
||||
|
@ -191,32 +188,20 @@ impl Type<MySql> for PrimitiveDateTime {
|
|||
|
||||
impl Encode<'_, MySql> for PrimitiveDateTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
let len = primitive_dt_encoded_len(self);
|
||||
buf.push(len);
|
||||
|
||||
encode_date(&self.date(), buf)?;
|
||||
|
||||
if len > 4 {
|
||||
encode_time(&self.time(), len > 8, buf);
|
||||
encode_time(&self.time(), len > 7, buf);
|
||||
}
|
||||
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
// to save space the packet can be compressed:
|
||||
match (self.hour(), self.minute(), self.second(), self.nanosecond()) {
|
||||
// if hour, minutes, seconds and micro_seconds are all 0,
|
||||
// length is 4 and no other field is sent
|
||||
(0, 0, 0, 0) => 5,
|
||||
|
||||
// if micro_seconds is 0, length is 7
|
||||
// and micro_seconds is not sent
|
||||
(_, _, _, 0) => 8,
|
||||
|
||||
// otherwise length is 11
|
||||
(_, _, _, _) => 12,
|
||||
}
|
||||
primitive_dt_encoded_len(self) as usize + 1 // plus length byte
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -316,6 +301,37 @@ fn decode_time(mut buf: &[u8]) -> Result<Time, BoxDynError> {
|
|||
0
|
||||
};
|
||||
|
||||
Time::from_hms_micro(hour, minute, seconds, micros as u32)
|
||||
let micros = u32::try_from(micros)
|
||||
.map_err(|_| format!("MySQL returned microseconds out of range: {micros}"))?;
|
||||
|
||||
Time::from_hms_micro(hour, minute, seconds, micros)
|
||||
.map_err(|e| format!("Time out of range for MySQL: {e}").into())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn primitive_dt_encoded_len(time: &PrimitiveDateTime) -> u8 {
|
||||
// to save space the packet can be compressed:
|
||||
match (time.hour(), time.minute(), time.second(), time.nanosecond()) {
|
||||
// if hour, minutes, seconds and micro_seconds are all 0,
|
||||
// length is 4 and no other field is sent
|
||||
(0, 0, 0, 0) => 4,
|
||||
|
||||
// if micro_seconds is 0, length is 7
|
||||
// and micro_seconds is not sent
|
||||
(_, _, _, 0) => 7,
|
||||
|
||||
// otherwise length is 11
|
||||
(_, _, _, _) => 11,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn time_encoded_len(time: &Time) -> u8 {
|
||||
if time.nanosecond() == 0 {
|
||||
// if micro_seconds is 0, length is 8 and micro_seconds is not sent
|
||||
8
|
||||
} else {
|
||||
// otherwise length is 12
|
||||
12
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue