use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
use crate::message::{ParameterDescription, RowDescription};
use crate::query_as::query_as;
use crate::query_scalar::query_scalar;
use crate::statement::PgStatementMetadata;
use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind};
use crate::types::Json;
use crate::types::Oid;
use crate::HashMap;
use crate::{PgColumn, PgConnection, PgTypeInfo};
use futures_core::future::BoxFuture;
use smallvec::SmallVec;
use sqlx_core::query_builder::QueryBuilder;
use std::sync::Arc;
/// Mirrors the `typtype` column of `pg_catalog.pg_type`: the broad
/// classification of a Postgres type (base, composite, domain, …).
///
/// Decoded from the raw catalog byte via the `TryFrom<i8>` impl below.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypType {
    Base,
    Composite,
    Domain,
    Enum,
    Pseudo,
    Range,
}
impl TryFrom<i8> for TypType {
    type Error = ();

    /// Decode the `pg_type.typtype` column, which arrives as a one-byte
    /// `"char"` value; any byte outside the known set is rejected.
    fn try_from(t: i8) -> Result<Self, Self::Error> {
        match u8::try_from(t).map_err(|_| ())? {
            b'b' => Ok(Self::Base),
            b'c' => Ok(Self::Composite),
            b'd' => Ok(Self::Domain),
            b'e' => Ok(Self::Enum),
            b'p' => Ok(Self::Pseudo),
            b'r' => Ok(Self::Range),
            _ => Err(()),
        }
    }
}
/// Mirrors the `typcategory` column of `pg_catalog.pg_type`: a coarse
/// grouping used here to disambiguate how a user-defined type should be
/// modeled (array vs. enum vs. range, etc.).
///
/// Decoded from the raw catalog byte via the `TryFrom<i8>` impl below.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypCategory {
    Array,
    Boolean,
    Composite,
    DateTime,
    Enum,
    Geometric,
    Network,
    Numeric,
    Pseudo,
    Range,
    String,
    Timespan,
    User,
    BitString,
    Unknown,
}
impl TryFrom<i8> for TypCategory {
    type Error = ();

    /// Decode the `pg_type.typcategory` column, which arrives as a one-byte
    /// `"char"` value; any byte outside the documented set is rejected.
    fn try_from(c: i8) -> Result<Self, Self::Error> {
        match u8::try_from(c).map_err(|_| ())? {
            b'A' => Ok(Self::Array),
            b'B' => Ok(Self::Boolean),
            b'C' => Ok(Self::Composite),
            b'D' => Ok(Self::DateTime),
            b'E' => Ok(Self::Enum),
            b'G' => Ok(Self::Geometric),
            b'I' => Ok(Self::Network),
            b'N' => Ok(Self::Numeric),
            b'P' => Ok(Self::Pseudo),
            b'R' => Ok(Self::Range),
            b'S' => Ok(Self::String),
            b'T' => Ok(Self::Timespan),
            b'U' => Ok(Self::User),
            b'V' => Ok(Self::BitString),
            b'X' => Ok(Self::Unknown),
            _ => Err(()),
        }
    }
}
impl PgConnection {
    /// Build column metadata from a `RowDescription` message.
    ///
    /// Returns the columns in wire order together with a name → ordinal map.
    /// If `desc` is `None` (the statement produces no rows), both collections
    /// are empty. `should_fetch` decides whether unknown type OIDs are resolved
    /// against the system catalogs now or deferred as `PgType::DeclareWithOid`.
    pub(super) async fn handle_row_description(
        &mut self,
        desc: Option<RowDescription>,
        should_fetch: bool,
    ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
        let mut columns = Vec::new();
        let mut column_names = HashMap::new();

        let desc = if let Some(desc) = desc {
            desc
        } else {
            // no row description means there are no result columns
            return Ok((columns, column_names));
        };

        columns.reserve(desc.fields.len());
        column_names.reserve(desc.fields.len());

        for (index, field) in desc.fields.into_iter().enumerate() {
            let name = UStr::from(field.name);

            let type_info = self
                .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
                .await?;

            let column = PgColumn {
                ordinal: index,
                name: name.clone(),
                type_info,
                relation_id: field.relation_id,
                relation_attribute_no: field.relation_attribute_no,
            };

            columns.push(column);
            column_names.insert(name, index);
        }

        Ok((columns, column_names))
    }

    /// Resolve a `PgTypeInfo` for every parameter OID in a
    /// `ParameterDescription` message, fetching from the catalogs if needed.
    pub(super) async fn handle_parameter_description(
        &mut self,
        desc: ParameterDescription,
    ) -> Result<Vec<PgTypeInfo>, Error> {
        let mut params = Vec::with_capacity(desc.types.len());

        for ty in desc.types {
            params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
        }

        Ok(params)
    }

    /// Look up a type by OID: built-ins first, then the per-connection cache,
    /// then (only when `should_fetch` is set) the system catalogs.
    ///
    /// When fetching is disabled — e.g. we are in the middle of another
    /// query and cannot issue a new one — an unresolved OID is deferred as
    /// `PgType::DeclareWithOid`.
    async fn maybe_fetch_type_info_by_oid(
        &mut self,
        oid: Oid,
        should_fetch: bool,
    ) -> Result<PgTypeInfo, Error> {
        // fast path: a built-in type; the vast majority of lookups end here
        if let Some(info) = PgTypeInfo::try_from_oid(oid) {
            return Ok(info);
        }

        // next, the per-connection cache of user-defined types
        if let Some(info) = self.cache_type_info.get(&oid) {
            return Ok(info.clone());
        }

        if should_fetch {
            let info = self.fetch_type_by_oid(oid).await?;

            // cache both directions (oid -> info, name -> oid) so we don't
            // hit the catalogs for this type again on this connection
            self.cache_type_info.insert(oid, info.clone());
            self.cache_type_oid
                .insert(info.0.name().to_string().into(), oid);

            Ok(info)
        } else {
            // defer resolution; the OID will be fetched on a later round trip
            Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
        }
    }

    /// Fetch a user-defined type from `pg_catalog.pg_type` by OID and build a
    /// `PgTypeInfo` for it, recursing into element/base/field types as needed.
    ///
    /// Returns a boxed future because this is (indirectly) recursive via
    /// `maybe_fetch_type_info_by_oid`.
    fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let (name, typ_type, category, relation_id, element, base_type): (
                String,
                i8,
                i8,
                Oid,
                Oid,
                Oid,
            ) = query_as(
                "SELECT oid::regtype::text, \
                     typtype, \
                     typcategory, \
                     typrelid, \
                     typelem, \
                     typbasetype \
                     FROM pg_catalog.pg_type \
                     WHERE oid = $1",
            )
            .bind(oid)
            .fetch_one(&mut *self)
            .await?;

            let typ_type = TypType::try_from(typ_type);
            let category = TypCategory::try_from(category);

            match (typ_type, category) {
                (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,

                (Ok(TypType::Base), Ok(TypCategory::Array)) => {
                    Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                        kind: PgTypeKind::Array(
                            self.maybe_fetch_type_info_by_oid(element, true).await?,
                        ),
                        name: name.into(),
                        oid,
                    }))))
                }

                (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
                    Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                        kind: PgTypeKind::Pseudo,
                        name: name.into(),
                        oid,
                    }))))
                }

                (Ok(TypType::Range), Ok(TypCategory::Range)) => {
                    self.fetch_range_by_oid(oid, name).await
                }

                (Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
                    self.fetch_enum_by_oid(oid, name).await
                }

                (Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
                    self.fetch_composite_by_oid(oid, relation_id, name).await
                }

                // anything unrecognized is modeled as an opaque "simple" type
                _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                    kind: PgTypeKind::Simple,
                    name: name.into(),
                    oid,
                })))),
            }
        })
    }

    /// Build the type info for an enum by fetching its labels, in declared
    /// sort order, from `pg_catalog.pg_enum`.
    async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
        let variants: Vec<String> = query_scalar(
            r#"
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
            "#,
        )
        .bind(oid)
        .fetch_all(self)
        .await?;

        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
            oid,
            name: name.into(),
            kind: PgTypeKind::Enum(Arc::from(variants)),
        }))))
    }

    /// Build the type info for a composite type by fetching its (live,
    /// user-defined) attributes from `pg_catalog.pg_attribute` and resolving
    /// each field's type. Boxed: recursive via `maybe_fetch_type_info_by_oid`.
    fn fetch_composite_by_oid(
        &mut self,
        oid: Oid,
        relation_id: Oid,
        name: String,
    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let raw_fields: Vec<(String, Oid)> = query_as(
                r#"
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
                "#,
            )
            .bind(relation_id)
            .fetch_all(&mut *self)
            .await?;

            let mut fields = Vec::new();

            for (field_name, field_oid) in raw_fields.into_iter() {
                let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;

                fields.push((field_name, field_type));
            }

            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                oid,
                name: name.into(),
                kind: PgTypeKind::Composite(Arc::from(fields)),
            }))))
        })
    }

    /// Build the type info for a domain by resolving its base type.
    /// Boxed: recursive via `maybe_fetch_type_info_by_oid`.
    fn fetch_domain_by_oid(
        &mut self,
        oid: Oid,
        base_type: Oid,
        name: String,
    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;

            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                oid,
                name: name.into(),
                kind: PgTypeKind::Domain(base_type),
            }))))
        })
    }

    /// Build the type info for a range type by looking up its subtype in
    /// `pg_catalog.pg_range` and resolving it.
    /// Boxed: recursive via `maybe_fetch_type_info_by_oid`.
    fn fetch_range_by_oid(
        &mut self,
        oid: Oid,
        name: String,
    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let element_oid: Oid = query_scalar(
                r#"
SELECT rngsubtype
FROM pg_catalog.pg_range
WHERE rngtypid = $1
                "#,
            )
            .bind(oid)
            .fetch_one(&mut *self)
            .await?;

            let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;

            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                kind: PgTypeKind::Range(element),
                name: name.into(),
                oid,
            }))))
        })
    }

    /// Resolve the OID for a type, fetching it from the server only for the
    /// `DeclareWithName` / `DeclareArrayOf` placeholder variants.
    pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {
        if let Some(oid) = ty.try_oid() {
            return Ok(oid);
        }

        match ty {
            PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await,
            PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await,
            // `try_oid()` returns `Some` for every other variant
            _ => unreachable!("(bug) OID should be resolvable for type {ty:?}"),
        }
    }

    /// Look up a type OID by name via `regtype`, consulting and updating the
    /// per-connection cache. Returns `Error::TypeNotFound` for unknown names.
    pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
        if let Some(oid) = self.cache_type_oid.get(name) {
            return Ok(*oid);
        }

        // language=SQL
        let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid")
            .bind(name)
            .fetch_optional(&mut *self)
            .await?
            .ok_or_else(|| Error::TypeNotFound {
                type_name: name.into(),
            })?;

        self.cache_type_oid.insert(name.to_string().into(), oid);

        Ok(oid)
    }

    /// Look up the array type OID for an element type (`pg_type.typarray`),
    /// consulting and updating both the element-name and element→array caches.
    pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
        if let Some(oid) = self
            .cache_type_oid
            .get(&array.elem_name)
            .and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid))
        {
            return Ok(*oid);
        }

        // language=SQL
        let (elem_oid, array_oid): (Oid, Oid) =
            query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid")
                .bind(&*array.elem_name)
                .fetch_optional(&mut *self)
                .await?
                .ok_or_else(|| Error::TypeNotFound {
                    type_name: array.name.to_string(),
                })?;

        // cache both the element name resolution and the element -> array link
        self.cache_type_oid
            .entry_ref(&array.elem_name)
            .insert(elem_oid);
        self.cache_elem_type_to_array.insert(elem_oid, array_oid);

        Ok(array_oid)
    }

    /// Best-effort nullability inference for the columns of a prepared
    /// statement. Combines `pg_attribute.attnotnull` lookups with (where the
    /// server supports it) an EXPLAIN-based pass that marks outer-join outputs
    /// as nullable. A `None` entry means "unknown".
    pub(crate) async fn get_nullable_for_columns(
        &mut self,
        stmt_id: StatementId,
        meta: &PgStatementMetadata,
    ) -> Result<Vec<Option<bool>>, Error> {
        if meta.columns.is_empty() {
            return Ok(vec![]);
        }

        // Each column contributes three bind parameters to the probe query
        // below, and a Bind message carries its parameter count as int16,
        // capping it at 65535. Past that limit the query cannot be encoded at
        // all, so report "unknown" for every column instead of issuing a
        // query that is guaranteed to fail.
        if meta.columns.len() * 3 > 65535 {
            tracing::debug!(
                ?stmt_id,
                num_columns = meta.columns.len(),
                "number of columns in query is too large to pull nullability for"
            );

            return Ok(vec![None; meta.columns.len()]);
        }

        // probe `pg_attribute.attnotnull` for every (relation, attribute)
        // pair; columns not backed by a table yield NULL (-> unknown)
        let mut nullable_query = QueryBuilder::new("SELECT NOT pg_attribute.attnotnull FROM ( ");

        nullable_query.push_values(meta.columns.iter().zip(0i32..), |mut tuple, (column, i)| {
            tuple.push_bind(i).push_unseparated("::int4");
            tuple
                .push_bind(column.relation_id)
                .push_unseparated("::int4");
            tuple
                .push_bind(column.relation_attribute_no)
                .push_unseparated("::int2");
        });

        nullable_query.push(
            ") as col(idx, table_id, col_idx) \
            LEFT JOIN pg_catalog.pg_attribute \
                ON table_id IS NOT NULL \
                AND attrelid = table_id \
                AND attnum = col_idx \
            ORDER BY col.idx",
        );

        let mut nullables: Vec<Option<bool>> = nullable_query
            .build_query_scalar()
            .fetch_all(&mut *self)
            .await
            .map_err(|e| {
                err_protocol!(
                    "error from nullables query: {e}; query: {:?}",
                    nullable_query.sql()
                )
            })?;

        // Skip the EXPLAIN pass for servers that advertise `crdb_version` or
        // `mz_version` (CockroachDB / Materialize, respectively) — presumably
        // they don't support this EXPLAIN form; inferred from the key names.
        if !self.stream.parameter_statuses.contains_key("crdb_version")
            && !self.stream.parameter_statuses.contains_key("mz_version")
        {
            // patch up the catalog-based inference with plan-tree data
            let nullable_patch = self
                .nullables_from_explain(stmt_id, meta.parameters.len())
                .await?;

            for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
                *nullable = patch.or(*nullable);
            }
        }

        Ok(nullables)
    }

    /// Infer nullability by running `EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE`
    /// on the prepared statement (with all-NULL arguments) and walking the
    /// resulting plan tree for outer joins.
    async fn nullables_from_explain(
        &mut self,
        stmt_id: StatementId,
        params_len: usize,
    ) -> Result<Vec<Option<bool>>, Error> {
        let stmt_id_display = stmt_id
            .display()
            .ok_or_else(|| err_protocol!("cannot EXPLAIN unnamed statement: {stmt_id:?}"))?;

        let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id_display}");
        let mut comma = false;

        if params_len > 0 {
            explain += "(";

            // supply NULL for every parameter so the statement can execute
            for _ in 0..params_len {
                if comma {
                    explain += ", ";
                }

                explain += "NULL";
                comma = true;
            }

            explain += ")";
        }

        let (Json(explains),): (Json<SmallVec<[Explain; 1]>>,) =
            query_as(&explain).fetch_one(self).await?;

        let mut nullables = Vec::new();

        // only a top-level plan that reports its outputs is usable;
        // utility statements and output-less plans leave `nullables` empty
        if let Some(Explain::Plan {
            plan:
                plan @ Plan {
                    output: Some(ref outputs),
                    ..
                },
        }) = explains.first()
        {
            nullables.resize(outputs.len(), None);
            visit_plan(plan, outputs, &mut nullables);
        }

        Ok(nullables)
    }
}
/// Walk an EXPLAIN plan tree, flagging as nullable (`Some(true)`) every
/// top-level output that comes from the nullable side of a join: all outputs
/// of a full join, and the outputs of the inner half of an outer join.
/// Recursion only descends through left/right join nodes.
fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
    let from_nullable_side = plan.join_type.as_deref() == Some("Full")
        || plan.parent_relation.as_deref() == Some("Inner");

    if from_nullable_side {
        for output in plan.output.iter().flatten() {
            // this may over-approximate, but a false `Some(true)` is harmless
            if let Some(i) = outputs.iter().position(|o| o == output) {
                nullables[i] = Some(true);
            }
        }
    }

    if matches!(plan.join_type.as_deref(), Some("Left" | "Right")) {
        for child in plan.plans.iter().flatten() {
            visit_plan(child, outputs, nullables);
        }
    }
}
/// One element of the JSON array returned by `EXPLAIN (FORMAT JSON)`.
///
/// Untagged: an object carrying a `"Plan"` key deserializes into the `Plan`
/// variant; anything else (e.g. the bare string `"Utility Statement"`) falls
/// through to `Other` and is ignored.
#[derive(serde::Deserialize, Debug)]
#[serde(untagged)]
enum Explain {
    Plan {
        #[serde(rename = "Plan")]
        plan: Plan,
    },
    // catch-all so unexpected EXPLAIN shapes don't fail deserialization
    Other(serde::de::IgnoredAny),
}
/// A (possibly nested) node of an `EXPLAIN (VERBOSE, FORMAT JSON)` plan tree.
///
/// Only the fields needed for nullability inference are deserialized; all
/// other keys in the JSON object are ignored by serde.
#[derive(serde::Deserialize, Debug)]
struct Plan {
    #[serde(rename = "Join Type")]
    join_type: Option<String>,
    #[serde(rename = "Parent Relationship")]
    parent_relation: Option<String>,
    #[serde(rename = "Output")]
    output: Option<Vec<String>>,
    // child plan nodes; recursion happens in `visit_plan`
    #[serde(rename = "Plans")]
    plans: Option<Vec<Plan>>,
}
/// Verifies that the untagged `Explain` enum accepts the three JSON shapes
/// EXPLAIN can produce: a plain plan object, a plan object with extra keys,
/// and a bare utility-statement string.
#[test]
fn explain_parsing() {
    // ordinary plan output: must parse into `Explain::Plan`
    let normal_plan = r#"[
   {
     "Plan": {
       "Node Type": "Result",
       "Parallel Aware": false,
       "Async Capable": false,
       "Startup Cost": 0.00,
       "Total Cost": 0.01,
       "Plan Rows": 1,
       "Plan Width": 4,
       "Output": ["1"]
     }
   }
]"#;
    // same, but with an unrelated sibling key that serde must ignore
    let extra_field = r#"[
   {
     "Plan": {
       "Node Type": "Result",
       "Parallel Aware": false,
       "Async Capable": false,
       "Startup Cost": 0.00,
       "Total Cost": 0.01,
       "Plan Rows": 1,
       "Plan Width": 4,
       "Output": ["1"]
     },
     "Query Identifier": 1147616880456321454
   }
]"#;
    // a non-object element: must fall through to `Explain::Other`
    let utility_statement = r#"["Utility Statement"]"#;
    let normal_plan_parsed = serde_json::from_str::<[Explain; 1]>(normal_plan).unwrap();
    let extra_field_parsed = serde_json::from_str::<[Explain; 1]>(extra_field).unwrap();
    let utility_statement_parsed = serde_json::from_str::<[Explain; 1]>(utility_statement).unwrap();
    assert!(
        matches!(normal_plan_parsed, [Explain::Plan { plan: Plan { .. } }]),
        "unexpected parse from {normal_plan:?}: {normal_plan_parsed:?}"
    );
    assert!(
        matches!(extra_field_parsed, [Explain::Plan { plan: Plan { .. } }]),
        "unexpected parse from {extra_field:?}: {extra_field_parsed:?}"
    );
    assert!(
        matches!(utility_statement_parsed, [Explain::Other(_)]),
        "unexpected parse from {utility_statement:?}: {utility_statement_parsed:?}"
    )
}