// sqlx_postgres/types/numeric.rs
use sqlx_core::bytes::Buf;
use std::num::Saturating;
use crate::error::BoxDynError;
use crate::PgArgumentBuffer;
/// Represents a `NUMERIC` value in the **Postgres** wire protocol.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum PgNumeric {
    /// Equivalent to the `'NaN'::numeric` value in Postgres.
    NotANumber,

    /// A populated `NUMERIC` value.
    ///
    /// A description of these fields can be found here (although the type being described is the
    /// version for in-memory calculations, the field names are the same):
    /// <https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L224-L269>
    Number {
        /// The sign of the value: positive (also set for 0 and -0), or negative.
        sign: PgNumericSign,

        /// The digits of the number in base-10000 with the most significant digit first
        /// (big-endian).
        ///
        /// The length of this vector must not overflow `i16` for the binary protocol.
        ///
        /// *Note*: `PgNumeric::encode` returns an error (it does not panic) if any digit lies
        /// outside `0..10000`, i.e. is negative or `>= 10000`.
        digits: Vec<i16>,

        /// The scaling factor of the number, such that the value will be interpreted as
        ///
        /// ```text
        ///   digits[0] * 10,000 ^ weight
        /// + digits[1] * 10,000 ^ (weight - 1)
        /// ...
        /// + digits[N] * 10,000 ^ (weight - N) where N = digits.len() - 1
        /// ```
        ///
        /// May be negative.
        weight: i16,

        /// How many _decimal_ (base-10) digits following the decimal point to consider in
        /// arithmetic regardless of how many actually follow the decimal point as determined by
        /// `weight`--the comment in the Postgres code linked above recommends using this only for
        /// ignoring unnecessary trailing zeroes (as trimming nonzero digits means reducing the
        /// precision of the value).
        ///
        /// Must be `>= 0`.
        scale: i16,
    },
}
// Raw values of the `sign` header field, taken from the Postgres source:
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L167-L170
//
// They are typed `u16` because `SIGN_NAN` (0xC000 = 49152) does not fit in an `i16`; the C
// original relies on the integer literal being truncated instead.
const SIGN_POS: u16 = 0;
const SIGN_NEG: u16 = 0x4000;
const SIGN_NAN: u16 = 0xC000;
/// Sign of a populated [PgNumeric] value.
///
/// The discriminants are the raw `u16` sign-field values used on the wire. NaN has no variant
/// here; it is represented by `PgNumeric::NotANumber` instead.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u16)]
pub(crate) enum PgNumericSign {
    /// Non-negative value (zero and negative zero included).
    Positive = SIGN_POS,
    /// Negative value.
    Negative = SIGN_NEG,
}
impl PgNumericSign {
    /// Converts a raw `NUMERIC` sign field into a [`PgNumericSign`].
    ///
    /// # Errors
    ///
    /// * If `val` is the NaN sign value; callers are expected to map that to
    ///   `PgNumeric::NotANumber` before calling this.
    /// * If `val` is any other unrecognised sign value.
    fn try_from_u16(val: u16) -> Result<Self, BoxDynError> {
        match val {
            SIGN_POS => Ok(PgNumericSign::Positive),
            SIGN_NEG => Ok(PgNumericSign::Negative),
            // NaN has no `PgNumericSign` equivalent. Report it as an error rather than
            // panicking, so a caller that forgets to special-case it fails the decode
            // instead of crashing the connection task.
            SIGN_NAN => Err("sign value for NaN passed to PgNumericSign".into()),
            // `#06X` is the `0x` prefix plus four hex digits, so every `u16` prints at full
            // width (`#04X` would print e.g. `0x01` for 1).
            _ => Err(format!("invalid value for PgNumericSign: {val:#06X}").into()),
        }
    }
}
impl PgNumeric {
    /// Equivalent value of `0::numeric`.
    pub const ZERO: Self = PgNumeric::Number {
        sign: PgNumericSign::Positive,
        digits: vec![],
        weight: 0,
        scale: 0,
    };

    /// Returns `true` if `digit` is a valid base-10000 digit, i.e. lies in `0..10000`.
    pub(crate) fn is_valid_digit(digit: i16) -> bool {
        (0..10_000).contains(&digit)
    }

    /// Estimates the encoded size in bytes of a value with `decimal_digits` base-10 digits.
    ///
    /// This is only a hint for pre-allocating a buffer: all arithmetic saturates and the result
    /// is clamped to `usize::MAX` if it does not fit.
    pub(crate) fn size_hint(decimal_digits: u64) -> usize {
        let mut size_hint = Saturating(decimal_digits);

        // Four base-10 digits fit in one base-10000 digit. Since this is just a hint, always
        // round up by adding one.
        size_hint /= 4;
        size_hint += 1;

        // Each base-10000 digit is encoded as an `i16` (two bytes).
        size_hint *= 2;

        // Plus the 8-byte header: `ndigits`, `weight`, `sign` and `scale`, two bytes each.
        size_hint += 8;

        usize::try_from(size_hint.0).unwrap_or(usize::MAX)
    }

    /// Decodes a value from the Postgres binary `NUMERIC` format.
    ///
    /// The layout is a header of four big-endian 16-bit fields (`ndigits`, `weight`, `sign`,
    /// `scale`) followed by `ndigits` big-endian `i16` base-10000 digits.
    ///
    /// # Errors
    ///
    /// * If `buf` is shorter than the header, or than the digits the header announces
    /// * If the sign field is not a recognised value
    pub(crate) fn decode(mut buf: &[u8]) -> Result<Self, BoxDynError> {
        // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874
        const HEADER_LEN: usize = 8;

        // `Buf::get_*` panics when not enough bytes remain, so check lengths up front to turn
        // truncated or malformed input into a decode error instead of a panic.
        if buf.len() < HEADER_LEN {
            return Err(format!(
                "PgNumeric: expected at least {HEADER_LEN} header bytes, got {}",
                buf.len()
            )
            .into());
        }

        let num_digits = buf.get_u16();
        let weight = buf.get_i16();
        let sign = buf.get_u16();
        let scale = buf.get_i16();

        if sign == SIGN_NAN {
            return Ok(PgNumeric::NotANumber);
        }

        let sign = PgNumericSign::try_from_u16(sign)?;

        let digits_bytes = usize::from(num_digits) * 2;
        if buf.len() < digits_bytes {
            return Err(format!(
                "PgNumeric: header declares {num_digits} digits ({digits_bytes} bytes) \
                 but only {} bytes remain",
                buf.len()
            )
            .into());
        }

        let digits: Vec<i16> = (0..num_digits).map(|_| buf.get_i16()).collect();

        Ok(PgNumeric::Number {
            sign,
            scale,
            weight,
            digits,
        })
    }

    /// Encodes this value in the Postgres binary `NUMERIC` format, appending it to `buf`.
    ///
    /// # Errors
    ///
    /// * If `digits.len()` overflows `i16`
    /// * If any element in `digits` lies outside `0..10000`
    ///
    /// Both conditions are checked before anything is written, so on error `buf` is unchanged.
    pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) -> Result<(), String> {
        match *self {
            PgNumeric::Number {
                ref digits,
                sign,
                scale,
                weight,
            } => {
                let digits_len = i16::try_from(digits.len()).map_err(|_| {
                    format!(
                        "PgNumeric digits.len() ({}) should not overflow i16",
                        digits.len()
                    )
                })?;

                // Validate every digit before writing the header, so a bad digit cannot leave
                // a half-encoded value in `buf`.
                if let Some((i, digit)) = digits
                    .iter()
                    .copied()
                    .enumerate()
                    .find(|&(_, digit)| !Self::is_valid_digit(digit))
                {
                    return Err(format!("{i}th PgNumeric digit out of range: {digit}"));
                }

                buf.extend(&digits_len.to_be_bytes());
                buf.extend(&weight.to_be_bytes());
                buf.extend(&(sign as i16).to_be_bytes());
                buf.extend(&scale.to_be_bytes());

                for &digit in digits {
                    buf.extend(&digit.to_be_bytes());
                }
            }

            PgNumeric::NotANumber => {
                // NaN is written as a header with zero digits, zero weight, zero scale and the
                // NaN sign value.
                buf.extend(&0_i16.to_be_bytes());
                buf.extend(&0_i16.to_be_bytes());
                buf.extend(&SIGN_NAN.to_be_bytes());
                buf.extend(&0_i16.to_be_bytes());
            }
        }

        Ok(())
    }
}