// sqlx_postgres/connection/sasl.rs
use crate::connection::stream::PgStream;
use crate::error::Error;
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
use crate::PgConnectOptions;
use hmac::{Hmac, Mac};
use rand::Rng;
use sha2::{Digest, Sha256};
use stringprep::saslprep;
use base64::prelude::{Engine as _, BASE64_STANDARD};
// GS2 header prepended to the client-first-message; it is also base64-encoded
// into the channel-binding attribute of the client-final-message.
// In RFC 5802 terms, "n,," means "no channel binding, no authzid".
const GS2_HEADER: &str = "n,,";
// Attribute name of the channel-binding data in the client-final-message.
const CHANNEL_ATTR: &str = "c";
// Attribute name of the username in the client-first-message.
const USERNAME_ATTR: &str = "n";
// Attribute name of the base64-encoded client proof in the client-final-message.
const CLIENT_PROOF_ATTR: &str = "p";
// Attribute name of the nonce.
const NONCE_ATTR: &str = "r";
pub(crate) async fn authenticate(
stream: &mut PgStream,
options: &PgConnectOptions,
data: AuthenticationSasl,
) -> Result<(), Error> {
let mut has_sasl = false;
let mut has_sasl_plus = false;
let mut unknown = Vec::new();
for mechanism in data.mechanisms() {
match mechanism {
"SCRAM-SHA-256" => {
has_sasl = true;
}
"SCRAM-SHA-256-PLUS" => {
has_sasl_plus = true;
}
_ => {
unknown.push(mechanism.to_owned());
}
}
}
if !has_sasl_plus && !has_sasl {
return Err(err_protocol!(
"unsupported SASL authentication mechanisms: {}",
unknown.join(", ")
));
}
let mut channel_binding = format!("{CHANNEL_ATTR}=");
BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding);
let username = format!("{}={}", USERNAME_ATTR, options.username);
let username = match saslprep(&username) {
Ok(v) => v,
Err(_) => panic!("Failed to saslprep username"),
};
let nonce = gen_nonce();
let client_first_message_bare = format!("{username},{nonce}");
let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
stream
.send(SaslInitialResponse {
response: &client_first_message,
plus: false,
})
.await?;
let cont = match stream.recv_expect().await? {
Authentication::SaslContinue(data) => data,
auth => {
return Err(err_protocol!(
"expected SASLContinue but received {:?}",
auth
));
}
};
let salted_password = hi(
options.password.as_deref().unwrap_or_default(),
&cont.salt,
cont.iterations,
)?;
let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
mac.update(b"Client Key");
let client_key = mac.finalize().into_bytes();
let stored_key = Sha256::digest(client_key);
let client_final_message_wo_proof = format!(
"{channel_binding},r={nonce}",
channel_binding = channel_binding,
nonce = &cont.nonce
);
let auth_message = format!(
"{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
client_first_message_bare = client_first_message_bare,
server_first_message = cont.message,
client_final_message_wo_proof = client_final_message_wo_proof
);
let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
mac.update(auth_message.as_bytes());
let client_signature = mac.finalize().into_bytes();
let client_proof: Vec<u8> = client_key
.iter()
.zip(client_signature.iter())
.map(|(&a, &b)| a ^ b)
.collect();
let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
mac.update(b"Server Key");
let server_key = mac.finalize().into_bytes();
let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
mac.update(auth_message.as_bytes());
let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
stream.send(SaslResponse(&client_final_message)).await?;
let data = match stream.recv_expect().await? {
Authentication::SaslFinal(data) => data,
auth => {
return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
}
};
mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
Ok(())
}
/// Generates the client nonce attribute, formatted as `r=<nonce>`.
///
/// The nonce is between 64 and 127 characters long and consists of printable
/// ASCII characters (`0x21..=0x7E`) excluding the comma (`0x2C`), which is the
/// attribute separator in SCRAM messages.
fn gen_nonce() -> String {
    let mut rng = rand::thread_rng();
    let count = rng.gen_range(64..128);

    // printable = %x21-2B / %x2D-7E  (printable ASCII except ",")
    let nonce: String = std::iter::repeat_with(|| {
        // Rejection-sample until the drawn byte is not a comma.
        loop {
            let c = rng.gen_range(0x21u8..0x7F);
            if c != 0x2C {
                break c;
            }
        }
    })
    .take(count)
    .map(char::from)
    .collect();

    format!("{NONCE_ATTR}={nonce}")
}
fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
mac.update(salt);
mac.update(&1u32.to_be_bytes());
let mut u = mac.finalize_reset().into_bytes();
let mut hi = u;
for _ in 1..iter_count {
mac.update(u.as_slice());
u = mac.finalize_reset().into_bytes();
hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
}
Ok(hi.into())
}
/// Benchmarks `hi` with a fixed password, a random 64-byte alphanumeric salt
/// and 4096 iterations. Only built for optimized test builds.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_sasl_hi(b: &mut test::Bencher) {
    use test::black_box;

    let mut rng = rand::thread_rng();
    let salt: Vec<u8> = (0..64)
        .map(|_| rng.sample(rand::distributions::Alphanumeric))
        .collect();

    b.iter(|| {
        // `black_box` keeps the optimizer from constant-folding the inputs.
        let _ = hi(
            black_box("secret_password"),
            black_box(&salt),
            black_box(4096),
        );
    });
}