#![allow(clippy::unusual_byte_groupings)]
// Backend arithmetic for the scalar field. The implementation file is chosen at
// compile time from the target's pointer width (32-bit or 64-bit) and mounted as
// the private module `scalar_impl`; its items are glob-imported below, which is
// where the `fiat_p384_scalar_*` names used in this file are expected to come from.
//
// The clippy lints allowed here apply to that backend module only.
// NOTE(review): presumably silenced because the backend is machine-generated —
// confirm before editing those files by hand.
#[cfg_attr(target_pointer_width = "32", path = "scalar/p384_scalar_32.rs")]
#[cfg_attr(target_pointer_width = "64", path = "scalar/p384_scalar_64.rs")]
#[allow(
clippy::identity_op,
clippy::too_many_arguments,
clippy::unnecessary_cast
)]
mod scalar_impl;
use self::scalar_impl::*;
use crate::{FieldBytes, NistP384, SecretKey, ORDER_HEX, U384};
use core::{
iter::{Product, Sum},
ops::{AddAssign, MulAssign, Neg, Shr, ShrAssign, SubAssign},
};
use elliptic_curve::{
bigint::{self, ArrayEncoding, Limb},
ff::PrimeField,
ops::{Invert, Reduce},
scalar::{FromUintUnchecked, IsHigh},
subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption},
Curve as _, Error, Result, ScalarPrimitive,
};
#[cfg(feature = "bits")]
use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
#[cfg(feature = "serde")]
use serdect::serde::{de, ser, Deserialize, Serialize};
#[cfg(doc)]
use core::ops::{Add, Mul, Sub};
/// An element of the NIST P-384 scalar field, i.e. an integer modulo
/// `NistP384::ORDER`, stored as a single [`U384`].
///
/// The wrapped integer is an internal representation: code in this file always
/// converts through `to_canonical()` / `from_uint_unchecked()` rather than
/// reading or writing it directly.
///
/// NOTE(review): the internal representation appears to be the Montgomery form
/// (the backend functions wired in below are `*_to_montgomery` /
/// `*_from_montgomery`). If so, the derived `PartialOrd`/`Ord` compare the
/// internal representation and do not order scalars by their canonical integer
/// value — confirm before relying on the ordering.
#[derive(Clone, Copy, Debug, PartialOrd, Ord)]
pub struct Scalar(U384);
// Wires `Scalar` up to the backend arithmetic: the macro receives the curve, the
// field element type, its byte and integer representations, the modulus
// (`NistP384::ORDER`) and the backend's Montgomery-domain type and functions
// (conversion from/to Montgomery form, add, sub, mul, negate, square).
//
// NOTE(review): items used in this file but not defined in it — e.g. `ONE`,
// `from_u64`, `from_hex`, `from_bytes`, `to_bytes`, `from_uint`,
// `from_uint_unchecked`, `to_canonical`, `is_zero`, `is_odd`, `square` and the
// arithmetic operators — are presumably generated here; verify against the
// `primeorder` macro definition.
primeorder::impl_mont_field_element!(
NistP384,
Scalar,
FieldBytes,
U384,
NistP384::ORDER,
fiat_p384_scalar_montgomery_domain_field_element,
fiat_p384_scalar_from_montgomery,
fiat_p384_scalar_to_montgomery,
fiat_p384_scalar_add,
fiat_p384_scalar_sub,
fiat_p384_scalar_mul,
fiat_p384_scalar_opp,
fiat_p384_scalar_square
);
impl Scalar {
    /// Computes the multiplicative inverse `1 / self`.
    ///
    /// The returned [`CtOption`] is "none" when `self` is zero; the inversion
    /// itself is always executed so the zero check does not branch.
    pub fn invert(&self) -> CtOption<Self> {
        CtOption::new(self.invert_unchecked(), !self.is_zero())
    }

    /// Computes the multiplicative inverse without checking that `self` is
    /// non-zero.
    ///
    /// Usable in `const` contexts (it defines `TWO_INV` and
    /// `ROOT_OF_UNITY_INV`). The work is delegated to `impl_field_invert!`,
    /// which is handed the canonical words of `self`, the stored words of
    /// `ONE`, the limb geometry and the backend's mul / negate / divstep /
    /// msat / selectznz routines.
    const fn invert_unchecked(&self) -> Self {
        let words = impl_field_invert!(
            self.to_canonical().as_words(),
            Self::ONE.0.to_words(),
            Limb::BITS,
            bigint::nlimbs!(U384::BITS),
            fiat_p384_scalar_mul,
            fiat_p384_scalar_opp,
            fiat_p384_scalar_divstep_precomp,
            fiat_p384_scalar_divstep,
            fiat_p384_scalar_msat,
            fiat_p384_scalar_selectznz,
        );
        Self(U384::from_words(words))
    }

    /// Computes a modular square root of `self`.
    ///
    /// A fixed addition chain raises `self` to a constant exponent; the
    /// candidate is then squared and compared against `self`, so the returned
    /// [`CtOption`] is "none" when `self` has no square root.
    ///
    /// NOTE(review): the exponent is expected to be `(n + 1) / 4` for the
    /// group order `n` (valid because `n ≡ 3 (mod 4)`, matching
    /// `PrimeField::S == 1`) — verify against the chain's generator output
    /// before modifying any step.
    pub fn sqrt(&self) -> CtOption<Self> {
        // Naming: `tBBBB` is `self` raised to the binary exponent `BBBB`
        // (e.g. `t101 = self^5`); `iN` / `xN` are intermediate chain values.
        let t1 = *self;
        let t10 = t1.square();
        let t11 = *self * t10;
        let t101 = t10 * t11;
        let t111 = t10 * t101;
        let t1001 = t10 * t111;
        let t1011 = t10 * t1001;
        let t1101 = t10 * t1011;
        let t1111 = t10 * t1101;
        let t11110 = t1111.square();
        let t11111 = t1 * t11110;
        let t1111100 = t11111.sqn(2);
        let t11111000 = t1111100.square();
        let i14 = t11111000.square();
        let i20 = i14.sqn(5) * i14;
        let i31 = i20.sqn(10) * i20;
        let i58 = (i31.sqn(4) * t11111000).sqn(21) * i31;
        let i110 = (i58.sqn(3) * t1111100).sqn(47) * i58;
        let x194 = i110.sqn(95) * i110 * t1111;
        let i225 = ((x194.sqn(6) * t111).sqn(3) * t11).sqn(7);
        let i235 = ((t1101 * i225).sqn(6) * t1101).square() * t1;
        let i258 = ((i235.sqn(11) * t11111).sqn(2) * t1).sqn(8);
        let i269 = ((t1101 * i258).sqn(2) * t11).sqn(6) * t1011;
        let i286 = ((i269.sqn(4) * t111).sqn(6) * t11111).sqn(5);
        let i308 = ((t1011 * i286).sqn(10) * t1101).sqn(9) * t1101;
        let i323 = ((i308.sqn(4) * t1011).sqn(6) * t1001).sqn(3);
        let i340 = ((t1 * i323).sqn(7) * t1011).sqn(7) * t101;
        let i357 = ((i340.sqn(5) * t111).sqn(5) * t1111).sqn(5);
        let i369 = ((t1011 * i357).sqn(4) * t1011).sqn(5) * t111;
        let i387 = ((i369.sqn(3) * t11).sqn(7) * t11).sqn(6);
        let i397 = ((t1011 * i387).sqn(4) * t101).sqn(3) * t11;
        let i413 = ((i397.sqn(4) * t11).sqn(4) * t11).sqn(6);
        let i427 = ((t101 * i413).sqn(5) * t101).sqn(6) * t1011;
        let x = i427.sqn(3) * t101;
        // Only accept the candidate if it really squares back to the input.
        CtOption::new(x, x.square().ct_eq(&t1))
    }

    /// Squares `self` `n` times, i.e. returns `self^(2^n)`.
    ///
    /// The loop count depends only on `n`, never on the scalar's value.
    fn sqn(&self, n: usize) -> Self {
        let mut x = *self;
        for _ in 0..n {
            x = x.square();
        }
        x
    }

    /// Right-shifts the integer this scalar represents by `shift` bits.
    ///
    /// Note: not constant-time with respect to the `shift` parameter.
    pub const fn shr_vartime(&self, shift: usize) -> Scalar {
        // `self.0` is the internal (Montgomery) representation; shifting it
        // directly would yield the encoding of an unrelated field element.
        // Shift the canonical integer instead and re-encode the result. The
        // shifted value is never larger than the input, so it is still below
        // the modulus and the unchecked constructor is sufficient.
        Self::from_uint_unchecked(self.to_canonical().shr_vartime(shift))
    }
}
/// Identity borrow, so APIs generic over `AsRef<Scalar>` accept a `Scalar`.
impl AsRef<Scalar> for Scalar {
    fn as_ref(&self) -> &Self {
        self
    }
}
// Exposes the unchecked integer-to-scalar constructor through the
// `FromUintUnchecked` trait.
impl FromUintUnchecked for Scalar {
type Uint = U384;
// The caller is responsible for `uint` being a valid scalar value; no range
// check is performed here.
fn from_uint_unchecked(uint: Self::Uint) -> Self {
// NOTE(review): this relies on `Self::from_uint_unchecked` resolving to an
// inherent associated function of the same name (inherent items take
// precedence over trait items); without one this call would recurse forever.
// The inherent function is not defined in this file — presumably it is
// generated by `impl_mont_field_element!`; confirm.
Self::from_uint_unchecked(uint)
}
}
impl Invert for Scalar {
type Output = CtOption<Self>;
fn invert(&self) -> CtOption<Self> {
self.invert()
}
}
/// "High half" test on the canonical integer value of the scalar.
impl IsHigh for Scalar {
    /// Returns a set [`Choice`] when the canonical value of `self` is strictly
    /// greater than `NistP384::ORDER >> 1`.
    fn is_high(&self) -> Choice {
        // Computed once at compile time.
        const HALF_ORDER: U384 = NistP384::ORDER.shr_vartime(1);

        let canonical = self.to_canonical();
        canonical.ct_gt(&HALF_ORDER)
    }
}
impl Shr<usize> for Scalar {
type Output = Self;
fn shr(self, rhs: usize) -> Self::Output {
self.shr_vartime(rhs)
}
}
impl Shr<usize> for &Scalar {
type Output = Scalar;
fn shr(self, rhs: usize) -> Self::Output {
self.shr_vartime(rhs)
}
}
/// `scalar >>= n`: replaces the scalar with its right-shifted value.
///
/// Not constant-time with respect to the shift amount.
impl ShrAssign<usize> for Scalar {
    fn shr_assign(&mut self, rhs: usize) {
        // Same operation the `>>` operator performs for `Scalar`.
        let shifted = self.shr_vartime(rhs);
        *self = shifted;
    }
}
// `ff::PrimeField` parameters and byte conversions for the scalar field.
// The constants must stay mutually consistent; the `impl_primefield_tests!`
// invocation in the test module is expected to check them.
impl PrimeField for Scalar {
// Byte encoding used by `from_repr` / `to_repr`.
type Repr = FieldBytes;
// Field modulus as a hex string, taken from the crate-level `ORDER_HEX`.
const MODULUS: &'static str = ORDER_HEX;
// One bit less than `NUM_BITS`.
const CAPACITY: u32 = 383;
// Bit length of the modulus.
const NUM_BITS: u32 = 384;
// Inverse of 2, computed at compile time.
const TWO_INV: Self = Self::from_u64(2).invert_unchecked();
// NOTE(review): 2 is declared as a generator of the multiplicative group; this
// cannot be verified from this file alone.
const MULTIPLICATIVE_GENERATOR: Self = Self::from_u64(2);
// Per the `ff` convention, `modulus - 1 = 2^S * t` with `t` odd.
const S: u32 = 1;
// A `2^S`-th root of unity. With `S = 1` this should be `-1`, i.e. the modulus
// minus one — presumably what the literal encodes; verify against `ORDER_HEX`.
const ROOT_OF_UNITY: Self = Self::from_hex("ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52972");
// Inverse of `ROOT_OF_UNITY`, computed at compile time.
const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unchecked();
// Per the `ff` convention, `MULTIPLICATIVE_GENERATOR^(2^S)`: here 2^2 = 4.
const DELTA: Self = Self::from_u64(4);
// Decodes a scalar from its byte representation by delegating to `from_bytes`.
#[inline]
fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
Self::from_bytes(&bytes)
}
// Encodes the scalar by delegating to `to_bytes`.
#[inline]
fn to_repr(&self) -> FieldBytes {
self.to_bytes()
}
// NOTE(review): relies on an inherent `is_odd` taking precedence over this trait
// method (otherwise the call would recurse); the inherent method is not defined
// in this file — presumably generated by `impl_mont_field_element!`.
#[inline]
fn is_odd(&self) -> Choice {
self.is_odd()
}
}
/// Bit-level views of scalars; only built with the `bits` feature.
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    type ReprBits = fiat_p384_scalar_montgomery_domain_field_element;

    /// Bits of the scalar's canonical value, built from its words.
    fn to_le_bits(&self) -> ScalarBits {
        let words = self.to_canonical().to_words();
        ScalarBits::from(words)
    }

    /// Bits of the field modulus (`NistP384::ORDER`), built from its words.
    fn char_le_bits() -> ScalarBits {
        let words = NistP384::ORDER.to_words();
        ScalarBits::from(words)
    }
}
/// Reduction of a 384-bit integer into the scalar field using one
/// conditional subtraction of the modulus.
impl Reduce<U384> for Scalar {
    type Bytes = FieldBytes;

    /// Returns `w - ORDER` when that subtraction does not borrow, and `w`
    /// itself otherwise. The choice is made with a constant-time select.
    fn reduce(w: U384) -> Self {
        let (difference, borrow) = w.sbb(&NistP384::ORDER, Limb::ZERO);
        // The top bit of the borrow limb is set exactly when `w < ORDER`.
        let top_bit = (borrow.0 >> (Limb::BITS - 1)) as u8;
        let borrowed = Choice::from(top_bit);
        let reduced = U384::conditional_select(&difference, &w, borrowed);
        Self::from_uint_unchecked(reduced)
    }

    /// Interprets `bytes` as a big-endian integer and reduces it.
    #[inline]
    fn reduce_bytes(bytes: &FieldBytes) -> Self {
        let w = U384::from_be_byte_array(*bytes);
        Self::reduce(w)
    }
}
impl From<ScalarPrimitive<NistP384>> for Scalar {
fn from(w: ScalarPrimitive<NistP384>) -> Self {
Scalar::from(&w)
}
}
impl From<&ScalarPrimitive<NistP384>> for Scalar {
fn from(w: &ScalarPrimitive<NistP384>) -> Scalar {
Scalar::from_uint_unchecked(*w.as_uint())
}
}
impl From<Scalar> for ScalarPrimitive<NistP384> {
fn from(scalar: Scalar) -> ScalarPrimitive<NistP384> {
ScalarPrimitive::from(&scalar)
}
}
impl From<&Scalar> for ScalarPrimitive<NistP384> {
fn from(scalar: &Scalar) -> ScalarPrimitive<NistP384> {
ScalarPrimitive::new(scalar.into()).unwrap()
}
}
/// Owned conversion to the byte representation; defers to the by-reference
/// conversion, which encodes with `to_repr()`.
impl From<Scalar> for FieldBytes {
    fn from(scalar: Scalar) -> Self {
        Self::from(&scalar)
    }
}
// Encodes a borrowed scalar into its byte representation using `to_repr()`
// (the `PrimeField` impl in this file implements `to_repr` via `to_bytes()`).
impl From<&Scalar> for FieldBytes {
fn from(scalar: &Scalar) -> Self {
scalar.to_repr()
}
}
impl From<Scalar> for U384 {
fn from(scalar: Scalar) -> U384 {
U384::from(&scalar)
}
}
/// Extracts the canonical integer value of a scalar.
impl From<&Scalar> for U384 {
    fn from(scalar: &Scalar) -> Self {
        (*scalar).to_canonical()
    }
}
/// Copies the scalar out of a secret key.
///
/// The key hands out a non-zero scalar wrapper, which is dereferenced to
/// obtain the plain `Scalar`.
impl From<&SecretKey> for Scalar {
    fn from(secret_key: &SecretKey) -> Self {
        let nonzero = secret_key.to_nonzero_scalar();
        *nonzero
    }
}
impl TryFrom<U384> for Scalar {
type Error = Error;
fn try_from(w: U384) -> Result<Self> {
Option::from(Self::from_uint(w)).ok_or(Error)
}
}
/// Serializes a scalar through its `ScalarPrimitive` form; only built with
/// the `serde` feature.
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let primitive = ScalarPrimitive::<NistP384>::from(self);
        primitive.serialize(serializer)
    }
}
/// Deserializes a scalar by first decoding a `ScalarPrimitive` and then
/// converting it; only built with the `serde` feature.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let primitive = ScalarPrimitive::<NistP384>::deserialize(deserializer)?;
        Ok(Self::from(primitive))
    }
}
#[cfg(test)]
mod tests {
    use super::Scalar;
    use crate::FieldBytes;
    use elliptic_curve::ff::PrimeField;
    use primeorder::impl_primefield_tests;

    // Parameter handed to `impl_primefield_tests!`.
    // NOTE(review): looks like `(ORDER - 1) >> S` as little-endian 64-bit
    // limbs — confirm against the macro's expectations.
    const T: [u64; 6] = [
        0x76760cb5666294b9,
        0xac0d06d9245853bd,
        0xe3b1a6c0fa1b96ef,
        0xffffffffffffffff,
        0xffffffffffffffff,
        0x7fffffffffffffff,
    ];

    impl_primefield_tests!(Scalar, T);

    /// Decoding a small big-endian value and re-encoding it yields the same bytes.
    #[test]
    fn from_to_bytes_roundtrip() {
        let value: u64 = 42;
        let mut encoded = FieldBytes::default();
        // Place the 8 value bytes at the least-significant end of the buffer.
        encoded[40..].copy_from_slice(&value.to_be_bytes());

        let decoded = Scalar::from_repr(encoded).unwrap();
        assert_eq!(encoded, decoded.to_bytes());
    }

    /// Small-number multiplication, including products of negated values.
    #[test]
    fn multiply() {
        let one = Scalar::ONE;
        let two = one + one;
        let three = two + one;
        let six = three + three;
        assert_eq!(six, two * three);

        let neg_two = -two;
        let neg_three = -three;
        assert_eq!(two, -neg_two);
        assert_eq!(neg_three * neg_two, neg_two * neg_three);
        assert_eq!(six, neg_two * neg_three);
    }

    /// Inversion of 3 and -3 and how the two inverses relate.
    #[test]
    fn invert() {
        let one = Scalar::ONE;
        let three = one + one + one;
        let three_inv = three.invert().unwrap();
        assert_eq!(three * three_inv, one);

        let neg_three = -three;
        let neg_three_inv = neg_three.invert().unwrap();
        assert_eq!(neg_three_inv, -three_inv);
        assert_eq!(three * neg_three_inv, -one);
    }

    /// Square roots of small perfect squares square back to the input.
    #[test]
    fn sqrt() {
        let squares = [1u64, 4, 9, 16, 25, 36, 49, 64];
        for n in squares.iter().copied() {
            let scalar = Scalar::from(n);
            let root = scalar.sqrt().unwrap();
            assert_eq!(root.square(), scalar);
        }
    }
}