sqlx_mysql/connection/
stream.rsuse std::collections::VecDeque;
use std::ops::{Deref, DerefMut};
use bytes::{Buf, Bytes, BytesMut};
use crate::collation::{CharSet, Collation};
use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{ProtocolDecode, ProtocolEncode};
use crate::net::{BufferedSocket, Socket};
use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::protocol::{Capabilities, Packet};
use crate::{MySqlConnectOptions, MySqlDatabaseError};
pub struct MySqlStream<S = Box<dyn Socket>> {
pub(crate) socket: BufferedSocket<S>,
pub(crate) server_version: (u16, u16, u16),
pub(super) capabilities: Capabilities,
pub(crate) sequence_id: u8,
pub(crate) waiting: VecDeque<Waiting>,
pub(crate) charset: CharSet,
pub(crate) collation: Collation,
pub(crate) is_tls: bool,
}
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum Waiting {
Result,
Row,
}
impl<S: Socket> MySqlStream<S> {
pub(crate) fn with_socket(
charset: CharSet,
collation: Collation,
options: &MySqlConnectOptions,
socket: S,
) -> Self {
let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::DEPRECATE_EOF
| Capabilities::FOUND_ROWS
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::PS_MULTI_RESULTS
| Capabilities::SSL;
if options.database.is_some() {
capabilities |= Capabilities::CONNECT_WITH_DB;
}
Self {
waiting: VecDeque::new(),
capabilities,
server_version: (0, 0, 0),
sequence_id: 0,
collation,
charset,
socket: BufferedSocket::new(socket),
is_tls: false,
}
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.socket.write_buffer().is_empty() {
self.socket.flush().await?;
}
while !self.waiting.is_empty() {
while self.waiting.front() == Some(&Waiting::Row) {
let packet = self.recv_packet().await?;
if !packet.is_empty() && packet[0] == 0xfe && packet.len() < 9 {
let eof = packet.eof(self.capabilities)?;
if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
*self.waiting.front_mut().unwrap() = Waiting::Result;
} else {
self.waiting.pop_front();
};
}
}
while self.waiting.front() == Some(&Waiting::Result) {
let packet = self.recv_packet().await?;
if !packet.is_empty() && (packet[0] == 0x00 || packet[0] == 0xff) {
let ok = packet.ok()?;
if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
self.waiting.pop_front();
}
} else {
*self.waiting.front_mut().unwrap() = Waiting::Row;
self.skip_result_metadata(packet).await?;
}
}
}
Ok(())
}
pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
where
T: ProtocolEncode<'en, Capabilities>,
{
self.sequence_id = 0;
self.write_packet(payload)?;
self.flush().await?;
Ok(())
}
pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
where
T: ProtocolEncode<'en, Capabilities>,
{
self.socket
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id))
}
async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
let mut header: Bytes = self.socket.read(4).await?;
#[allow(clippy::cast_possible_truncation)]
let packet_size = header.get_uint_le(3) as usize;
let sequence_id = header.get_u8();
self.sequence_id = sequence_id.wrapping_add(1);
let payload: Bytes = self.socket.read(packet_size).await?;
Ok(payload)
}
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
let payload = self.recv_packet_part().await?;
let payload = if payload.len() < 0xFF_FF_FF {
payload
} else {
let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
final_payload.extend_from_slice(&payload);
drop(payload); let mut last_read = 0xFF_FF_FF;
while last_read == 0xFF_FF_FF {
let part = self.recv_packet_part().await?;
last_read = part.len();
final_payload.extend_from_slice(&part);
}
final_payload.into()
};
if payload
.first()
.ok_or(err_protocol!("Packet empty"))?
.eq(&0xff)
{
self.waiting.pop_front();
return Err(
MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
);
}
Ok(Packet(payload))
}
pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
where
T: ProtocolDecode<'de, Capabilities>,
{
self.recv_packet().await?.decode_with(self.capabilities)
}
pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
self.recv_packet().await?.ok()
}
pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
Ok(None)
} else {
self.recv().await.map(Some)
}
}
async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
let num_columns: u64 = packet.get_uint_lenenc(); for _ in 0..num_columns {
let _ = self.recv_packet().await?;
}
self.maybe_recv_eof().await?;
Ok(())
}
pub fn boxed_socket(self) -> MySqlStream {
MySqlStream {
socket: self.socket.boxed(),
server_version: self.server_version,
capabilities: self.capabilities,
sequence_id: self.sequence_id,
waiting: self.waiting,
charset: self.charset,
collation: self.collation,
is_tls: self.is_tls,
}
}
}
impl<S> Deref for MySqlStream<S> {
type Target = BufferedSocket<S>;
fn deref(&self) -> &Self::Target {
&self.socket
}
}
impl<S> DerefMut for MySqlStream<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.socket
}
}