1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

add tests, refine uint optimizer

This commit is contained in:
schaeff 2020-05-28 15:28:47 +02:00
parent bb2b70593f
commit e4b5820722
8 changed files with 379 additions and 1016 deletions

View file

@ -12,7 +12,7 @@ fi
files=()
for file in $(git diff --name-only --cached); do
for file in $(git diff --diff-filter=d --name-only --cached); do
if [ ${file: -3} == ".rs" ]; then
rustfmt +nightly --check $file &>/dev/null
if [ $? != 0 ]; then

40
Cargo.lock generated
View file

@ -43,7 +43,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72342c21057a3cb5f7c2d849bf7999a83795434dd36d74fa8c24680581bd1930"
dependencies = [
"colored",
"difference",
"difference 1.0.0",
"environment",
"error-chain 0.11.0",
"serde_json",
@ -366,12 +366,28 @@ dependencies = [
"memchr",
]
[[package]]
name = "ctor"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf6b25ee9ac1995c54d7adb2eff8cfffb7260bc774fb63c601ec65467f43cd9d"
dependencies = [
"quote 1.0.5",
"syn 1.0.21",
]
[[package]]
name = "difference"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3304d19798a8e067e48d8e69b2c37f0b5e9b4e462504ad9e27e9f3fce02bba8"
[[package]]
name = "difference"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198"
[[package]]
name = "digest"
version = "0.8.1"
@ -916,6 +932,15 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "output_vt100"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53cdc5b785b7a58c5aad8216b3dfa114df64b0b06ae6e1501cef91df2fbdf8f9"
dependencies = [
"winapi",
]
[[package]]
name = "pairing_ce"
version = "0.21.0"
@ -1027,6 +1052,18 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
[[package]]
name = "pretty_assertions"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f81e1644e1b54f5a68959a29aa86cde704219254669da328ecfdf6a1f09d427"
dependencies = [
"ansi_term",
"ctor",
"difference 2.0.0",
"output_vt100",
]
[[package]]
name = "proc-macro2"
version = "0.4.30"
@ -1757,6 +1794,7 @@ dependencies = [
"num",
"num-bigint",
"pairing_ce",
"pretty_assertions",
"rand 0.4.6",
"reduce",
"regex",

View file

@ -43,6 +43,7 @@ features = ["serde"]
glob = "0.2.11"
assert_cli = "0.5"
wasm-bindgen-test = "0.3.0"
pretty_assertions = "0.6.1"
[build-dependencies]
cc = { version = "1.0", features = ["parallel"], optional = true }

View file

@ -23,7 +23,7 @@
// ## Optimization rules
// We maintain `s`, a set of substitutions as a mapping of `(variable => linear_combination)`. It starts empty.
// We also maintaint `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
// We also maintain `i`, a set of variables that should be ignored when trying to substitute. It starts empty.
// - input variables are inserted into `i`
// - the `~one` variable is inserted into `i`
@ -57,16 +57,7 @@ impl<T: Field> RedefinitionOptimizer<T> {
}
pub fn optimize(p: Prog<T>) -> Prog<T> {
let mut p = p;
loop {
let size_before = p.main.statements.len();
p = RedefinitionOptimizer::new().fold_module(p);
let size_after = p.main.statements.len();
if size_after == size_before {
return p;
}
}
RedefinitionOptimizer::new().fold_module(p)
}
}
@ -77,54 +68,48 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
let quad = self.fold_quadratic_combination(quad);
let lin = self.fold_linear_combination(lin);
if self.substitution.len() < 150000 {
let (keep_constraint, to_insert, to_ignore) = match lin.try_summand() {
// if the right side is a single variable
Some((variable, coefficient)) => {
match self.ignore.contains(&variable) {
// if the variable isn't tagged as ignored
false => match self.substitution.get(&variable) {
// if the variable is already defined
Some(_) => (true, None, None),
// if the variable is not defined yet
None => match quad.try_linear() {
// if the left side is linear
Some(l) => {
(false, Some((variable, l / &coefficient)), None)
}
// if the left side isn't linear
None => (true, None, Some(variable)),
},
let (keep_constraint, to_insert, to_ignore) = match lin.try_summand() {
// if the right side is a single variable
Some((variable, coefficient)) => {
match self.ignore.contains(&variable) {
// if the variable isn't tagged as ignored
false => match self.substitution.get(&variable) {
// if the variable is already defined
Some(_) => (true, None, None),
// if the variable is not defined yet
None => match quad.try_linear() {
// if the left side is linear
Some(l) => (false, Some((variable, l / &coefficient)), None),
// if the left side isn't linear
None => (true, None, Some(variable)),
},
true => (true, None, None),
}
},
true => (true, None, None),
}
None => (true, None, None),
};
// insert into the ignored set
match to_ignore {
Some(v) => {
self.ignore.insert(v);
}
None => {}
}
None => (true, None, None),
};
// insert into the substitution map
match to_insert {
Some((k, v)) => {
self.substitution.insert(k, v);
}
None => {}
};
// decide whether the constraint should be kept
match keep_constraint {
false => vec![],
true => vec![Statement::Constraint(quad, lin)],
// insert into the ignored set
match to_ignore {
Some(v) => {
self.ignore.insert(v);
}
} else {
vec![Statement::Constraint(quad, lin)]
None => {}
}
// insert into the substitution map
match to_insert {
Some((k, v)) => {
self.substitution.insert(k, v);
}
None => {}
};
// decide whether the constraint should be kept
match keep_constraint {
false => vec![],
true => vec![Statement::Constraint(quad, lin)],
}
}
Statement::Directive(d) => {
@ -140,15 +125,26 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
}
fn fold_linear_combination(&mut self, lc: LinComb<T>) -> LinComb<T> {
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
lc.0.into_iter()
.map(|(variable, coefficient)| {
self.substitution
.get(&variable)
.map(|l| l.clone() * &coefficient)
.unwrap_or(LinComb::summand(coefficient, variable))
})
.fold(LinComb::zero(), |acc, x| acc + x)
match lc
.0
.iter()
.find(|(variable, _)| self.substitution.get(&variable).is_some())
.is_some()
{
true =>
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
{
lc.0.into_iter()
.map(|(variable, coefficient)| {
self.substitution
.get(&variable)
.map(|l| l.clone() * &coefficient)
.unwrap_or(LinComb::summand(coefficient, variable))
})
.fold(LinComb::zero(), |acc, x| acc + x)
}
false => lc,
}
}
fn fold_argument(&mut self, a: FlatVariable) -> FlatVariable {

View file

@ -126,7 +126,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
// `2**n - max_left <= a - b + 2 ** n <= bound where bound = max_left + offset`
// If ´bound < N´, we set we return `bound` as the max of ´left - right`
// If ´bound < N´, we return `bound` as the max of ´left - right`
// Else we start again, reducing `left`. In this case `max_left` becomes `2**target - 1`
// Else we start again, reducing `right`. In this case `offset` becomes `2**target`
// Else we start again reducing both. In this case `bound` becomes `2**(target+1) - 1` which is always
@ -141,25 +141,33 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let offset =
T::from(2u32).pow(std::cmp::max(right_bitwidth, range as u32) as usize);
let target_offset = T::from(2u32).pow(range);
let (should_reduce_left, should_reduce_right, max) = left_max
.checked_add(&offset)
.map(|max| (false, false, max))
.unwrap_or_else(|| {
range_max
.clone()
let (should_reduce_left, should_reduce_right, max) =
if right_bitwidth as usize == T::get_required_bits() - 1 {
// if and only if `right_bitwidth` is `T::get_required_bits() - 1`, then `offset` is out of the interval
// [0, 2**(max_bitwidth)[, therefore we need to reduce `right`
left_max
.checked_add(&target_offset.clone())
.map(|max| (false, true, max))
.unwrap_or_else(|| (true, true, range_max.clone() + target_offset))
} else {
left_max
.checked_add(&offset)
.map(|max| (true, false, max))
.map(|max| (false, false, max))
.unwrap_or_else(|| {
left_max
.checked_add(&target_offset.clone())
.map(|max| (false, true, max))
.unwrap_or_else(|| {
(true, true, range_max.clone() + target_offset)
})
range_max
.clone()
.checked_add(&offset)
.map(|max| (true, false, max))
.unwrap_or_else(
// this is unreachable because the max value for `range_max + offset` is
// 2**32 + 2**(T::get_required_bits() - 2) < 2**(T::get_required_bits() - 1)
|| unreachable!(),
)
})
});
};
let left = if should_reduce_left {
force_reduce(left)
@ -244,14 +252,39 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
UExpression::left_shift(force_reduce(e), by).with_max(range_max)
let by_u = match by {
FieldElementExpression::Number(ref by) => {
by.to_dec_string().parse::<usize>().unwrap()
}
_ => unreachable!(),
};
let bitwidth = e.metadata.clone().unwrap().bitwidth();
let max =
T::from(2).pow(std::cmp::min(bitwidth as usize + by_u, range)) - T::from(1);
UExpression::left_shift(force_reduce(e), by).with_max(max)
}
RightShift(box e, box by) => {
// reduce the two terms
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
UExpression::right_shift(force_reduce(e), by).with_max(range_max)
let by_u = match by {
FieldElementExpression::Number(ref by) => {
by.to_dec_string().parse::<usize>().unwrap()
}
_ => unreachable!(),
};
let bitwidth = e.metadata.clone().unwrap().bitwidth();
let max = T::from(2)
.pow(bitwidth as usize - std::cmp::min(by_u, bitwidth as usize))
- T::from(1);
UExpression::right_shift(force_reduce(e), by).with_max(max)
}
IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_uint_expression(consequence);
@ -375,99 +408,227 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
mod tests {
use super::*;
use zokrates_field::Bn128Field;
use zokrates_field::Pow;
// #[should_panic]
// #[test]
// fn existing_metadata() {
// let e = UExpressionInner::Identifier("foo".into())
// .annotate(32)
// .metadata(UMetadata::with_max(2_u32.pow(33_u32) - 1));
extern crate pretty_assertions;
use self::pretty_assertions::assert_eq;
// let mut optimizer: UintOptimizer<Bn128Field> = UintOptimizer::new();
macro_rules! uint_test {
( $left_max:expr, $left_reduce:expr, $right_max:expr, $right_reduce:expr, $method:ident, $res_max:expr ) => {{
let left = e_with_max($left_max);
// let _ = optimizer.fold_uint_expression(e.clone());
// }
let right = e_with_max($right_max);
let left_expected = if $left_reduce {
force_reduce(left.clone())
} else {
force_no_reduce(left.clone())
};
let right_expected = if $right_reduce {
force_reduce(right.clone())
} else {
force_no_reduce(right.clone())
};
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::$method(left.clone(), right.clone())),
UExpression::$method(left_expected, right_expected).with_max($res_max)
);
}};
}
fn e_with_max<'a, U: Into<Bn128Field>>(max: U) -> UExpression<'a, Bn128Field> {
UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata::with_max(max))
}
#[test]
fn add() {
// max(left + right) = max(left) + max(right)
let left: UExpression<Bn128Field> = UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata::with_max(42u32));
let right = UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata::with_max(33u32));
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::add(left, right))
.metadata
.unwrap()
.max,
75u32.into()
// no reduction
uint_test!(42, false, 33, false, add, 75);
// left reduction
uint_test!(
Bn128Field::max_unique_value(),
true,
1,
false,
add,
0x100000000_u128
);
// right reduction
uint_test!(
1,
false,
Bn128Field::max_unique_value(),
true,
add,
0x100000000_u128
);
// right and left reductions
uint_test!(
Bn128Field::max_unique_value(),
true,
Bn128Field::max_unique_value(),
true,
add,
0x1fffffffe_u128
);
}
#[test]
fn sub() {
// `left` and `right` are smaller than the target
let left: UExpression<Bn128Field> = UExpressionInner::Identifier("a".into())
.annotate(32)
.metadata(UMetadata::with_max(42u32));
// no reduction
uint_test!(42, false, 33, false, sub, 0x100000000_u128 + 42);
// left reduction
uint_test!(
Bn128Field::max_unique_value(),
true,
1,
false,
sub,
0x1ffffffff_u128
);
// right reduction
uint_test!(
1,
false,
Bn128Field::max_unique_value(),
true,
sub,
0x100000001_u128
);
// right and left reductions
uint_test!(
Bn128Field::max_unique_value(),
true,
Bn128Field::max_unique_value(),
true,
sub,
0x1ffffffff_u128
);
}
let right = UExpressionInner::Identifier("b".into())
.annotate(32)
.metadata(UMetadata::with_max(33u32));
#[test]
fn mult() {
// no reduction
uint_test!(42, false, 33, false, mult, 1386);
// left reduction
uint_test!(
Bn128Field::max_unique_value(),
true,
2,
false,
mult,
0x1fffffffe_u128
);
// right reduction
uint_test!(
2,
false,
Bn128Field::max_unique_value(),
true,
mult,
0x1fffffffe_u128
);
// right and left reductions
uint_test!(
Bn128Field::max_unique_value(),
true,
Bn128Field::max_unique_value(),
true,
mult,
0xfffffffe00000001_u128
);
}
#[test]
fn bitwise() {
// xor
uint_test!(42, true, 33, true, xor, 0xffffffff_u32);
// or
uint_test!(42, true, 33, true, or, 0xffffffff_u32);
// and
uint_test!(42, true, 33, true, and, 0xffffffff_u32);
// not
let e = e_with_max(255);
let e_expected = force_reduce(e.clone());
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::sub(left, right))
.metadata
.unwrap()
.max,
Bn128Field::from(2u32).pow(32) + Bn128Field::from(42)
UintOptimizer::new().fold_uint_expression(UExpression::not(e)),
UExpression::not(e_expected).with_max(0xffffffff_u32)
);
}
#[test]
fn right_shift() {
let e = e_with_max(255);
let e_expected = force_reduce(e.clone());
assert_eq!(
UintOptimizer::new().fold_uint_expression(UExpression::right_shift(
e,
FieldElementExpression::Number(Bn128Field::from(2))
)),
UExpression::right_shift(
e_expected,
FieldElementExpression::Number(Bn128Field::from(2))
)
.with_max(63)
);
// `left` and `right` are larger than the target but no readjustment is required
let left: UExpression<Bn128Field> = UExpressionInner::Identifier("a".into())
.annotate(32)
.metadata(UMetadata::with_max(u64::MAX as u128));
let e = e_with_max(2);
let right = UExpressionInner::Identifier("b".into())
.annotate(32)
.metadata(UMetadata::with_max(u64::MAX as u128));
let e_expected = force_reduce(e.clone());
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::sub(left, right))
.metadata
.unwrap()
.max,
Bn128Field::from(2).pow(64) + Bn128Field::from(u64::MAX as u128)
UintOptimizer::new().fold_uint_expression(UExpression::right_shift(
e,
FieldElementExpression::Number(Bn128Field::from(2))
)),
UExpression::right_shift(
e_expected,
FieldElementExpression::Number(Bn128Field::from(2))
)
.with_max(0)
);
}
#[test]
fn left_shift() {
let e = e_with_max(255);
let e_expected = force_reduce(e.clone());
assert_eq!(
UintOptimizer::new().fold_uint_expression(UExpression::left_shift(
e,
FieldElementExpression::Number(Bn128Field::from(2))
)),
UExpression::left_shift(
e_expected,
FieldElementExpression::Number(Bn128Field::from(2))
)
.with_max(1023)
);
// `left` and `right` are larger than the target and needs to be readjusted
let left: UExpression<Bn128Field> = UExpressionInner::Identifier("a".into())
.annotate(32)
.metadata(UMetadata::with_max(
Bn128Field::from(2u32).pow(Bn128Field::get_required_bits() - 1)
- Bn128Field::from(1),
));
let e = e_with_max(0xffffffff_u32);
let right = UExpressionInner::Identifier("b".into())
.annotate(32)
.metadata(UMetadata::with_max(42u32));
let e_expected = force_reduce(e.clone());
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::sub(left, right))
.metadata
.unwrap()
.max,
Bn128Field::from(2u32).pow(32) * Bn128Field::from(2) - Bn128Field::from(1)
UintOptimizer::new().fold_uint_expression(UExpression::left_shift(
e,
FieldElementExpression::Number(Bn128Field::from(2))
)),
UExpression::left_shift(
e_expected,
FieldElementExpression::Number(Bn128Field::from(2))
)
.with_max(0xffffffff_u32)
);
}

View file

@ -29,6 +29,11 @@ impl<'ast, T: Field> UExpression<'ast, T> {
UExpressionInner::Xor(box self, box other).annotate(bitwidth)
}
pub fn not(self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
UExpressionInner::Not(box self).annotate(bitwidth)
}
pub fn or(self, other: Self) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(bitwidth, other.bitwidth);

View file

@ -1,854 +0,0 @@
//
// @file field.rs
// @author Dennis Kuhnert <dennis.kuhnert@campus.tu-berlin.de>
// @author Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>
// @date 2017
use lazy_static::lazy_static;
use num_bigint::{BigInt, BigUint, Sign, ToBigInt};
use num_integer::Integer;
use num_traits::{One, Zero};
use pairing::bn256::Bn256;
use pairing::ff::ScalarEngine;
use pairing::Engine;
use serde_derive::{Deserialize, Serialize};
use std::convert::From;
use std::fmt;
use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::ops::{Add, Div, Mul, Sub};
lazy_static! {
static ref P: BigInt = BigInt::parse_bytes(
b"21888242871839275222246405745257275088548364400416034343698204186575808495617",
10
)
.unwrap();
}
pub trait Pow<RHS> {
type Output;
fn pow(self, _: RHS) -> Self::Output;
}
pub trait Field:
From<i32>
+ From<u32>
+ From<usize>
+ From<u128>
+ Zero
+ One
+ Clone
+ PartialEq
+ Eq
+ Hash
+ PartialOrd
+ Display
+ Debug
+ Add<Self, Output = Self>
+ for<'a> Add<&'a Self, Output = Self>
+ Sub<Self, Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ Mul<Self, Output = Self>
+ for<'a> Mul<&'a Self, Output = Self>
+ Div<Self, Output = Self>
+ for<'a> Div<&'a Self, Output = Self>
+ Pow<usize, Output = Self>
+ Pow<Self, Output = Self>
+ for<'a> Pow<&'a Self, Output = Self>
+ num_traits::CheckedAdd
+ num_traits::CheckedMul
{
/// An associated type to be able to operate with Bellman ff traits
type BellmanEngine: Engine;
fn from_bellman(e: <Self::BellmanEngine as ScalarEngine>::Fr) -> Self {
use ff::{PrimeField, PrimeFieldRepr};
let mut res: Vec<u8> = vec![];
e.into_repr().write_le(&mut res).unwrap();
Self::from_byte_vector(res)
}
fn into_bellman(self) -> <Self::BellmanEngine as ScalarEngine>::Fr {
use ff::PrimeField;
let s = self.to_dec_string();
<Self::BellmanEngine as ScalarEngine>::Fr::from_str(&s).unwrap()
}
/// Returns this `Field`'s contents as little-endian byte vector
fn into_byte_vector(&self) -> Vec<u8>;
/// Returns an element of this `Field` from a little-endian byte vector
fn from_byte_vector(_: Vec<u8>) -> Self;
/// Returns this `Field`'s contents as decimal string
fn to_dec_string(&self) -> String;
/// Returns the multiplicative inverse, i.e.: self * self.inverse_mul() = Self::one()
fn inverse_mul(&self) -> Self;
/// Returns the smallest value that can be represented by this field type.
fn min_value() -> Self;
/// Returns the largest value that can be represented by this field type.
fn max_value() -> Self;
/// Returns the number of required bits to represent this field type.
fn get_required_bits() -> usize;
/// Tries to parse a string into this representation
fn try_from_dec_str<'a>(s: &'a str) -> Result<Self, ()>;
/// Returns a decimal string representing the member of the equivalence class of this `Field` in Z/pZ
/// which lies in [-(p-1)/2, (p-1)/2]
fn to_compact_dec_string(&self) -> String;
/// Converts to BigUint
fn into_big_uint(self) -> BigUint;
/// Gets the number of bits
fn bits(&self) -> u32;
}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Serialize, Deserialize)]
pub struct FieldPrime {
value: BigInt,
}
impl num_traits::CheckedAdd for FieldPrime {
fn checked_add(&self, other: &Self) -> Option<Self> {
use num_traits::Pow;
let res = self.value.clone() + other.value.clone();
let bound = BigInt::from(2u32).pow(Self::get_required_bits() - 1);
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
if res >= bound {
None
} else {
Some(FieldPrime { value: res })
}
}
}
impl num_traits::CheckedMul for FieldPrime {
fn checked_mul(&self, other: &Self) -> Option<Self> {
use num_traits::Pow;
let res = self.value.clone() * other.value.clone();
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
if res >= BigInt::from(2u32).pow(Self::get_required_bits() - 1) {
None
} else {
Some(FieldPrime { value: res })
}
}
}
impl Field for FieldPrime {
type BellmanEngine = Bn256;
fn into_big_uint(self) -> BigUint {
self.value.to_biguint().unwrap()
}
fn into_byte_vector(&self) -> Vec<u8> {
match self.value.to_biguint() {
Option::Some(val) => val.to_bytes_le(),
Option::None => panic!("Should never happen."),
}
}
fn from_byte_vector(bytes: Vec<u8>) -> Self {
let uval = BigUint::from_bytes_le(bytes.as_slice());
FieldPrime {
value: BigInt::from_biguint(Sign::Plus, uval),
}
}
fn to_dec_string(&self) -> String {
self.value.to_str_radix(10)
}
fn bits(&self) -> u32 {
self.value.bits() as u32
}
fn inverse_mul(&self) -> FieldPrime {
let (b, s, _) = extended_euclid(&self.value, &*P);
assert_eq!(b, BigInt::one());
FieldPrime {
value: &s - s.div_floor(&*P) * &*P,
}
}
fn min_value() -> FieldPrime {
FieldPrime {
value: ToBigInt::to_bigint(&0).unwrap(),
}
}
fn max_value() -> FieldPrime {
FieldPrime {
value: &*P - ToBigInt::to_bigint(&1).unwrap(),
}
}
fn get_required_bits() -> usize {
(*P).bits()
}
fn try_from_dec_str<'a>(s: &'a str) -> Result<Self, ()> {
let x = BigInt::parse_bytes(s.as_bytes(), 10).ok_or(())?;
Ok(FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
})
}
fn to_compact_dec_string(&self) -> String {
// values up to (p-1)/2 included are represented as positive, values between (p+1)/2 and p-1 are represented as negative by subtracting p
if self.value <= FieldPrime::max_value().value / 2 {
format!("{}", self.value.to_str_radix(10))
} else {
format!(
"({})",
(&self.value - (FieldPrime::max_value().value + BigInt::one())).to_str_radix(10)
)
}
}
}
impl Default for FieldPrime {
fn default() -> Self {
FieldPrime {
value: BigInt::default(),
}
}
}
impl Display for FieldPrime {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.value.to_str_radix(10))
}
}
impl Debug for FieldPrime {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.value.to_str_radix(10))
}
}
impl From<i32> for FieldPrime {
fn from(num: i32) -> Self {
let x = ToBigInt::to_bigint(&num).unwrap();
FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
}
}
}
impl From<u32> for FieldPrime {
fn from(num: u32) -> Self {
let x = ToBigInt::to_bigint(&num).unwrap();
FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
}
}
}
impl From<usize> for FieldPrime {
fn from(num: usize) -> Self {
let x = ToBigInt::to_bigint(&num).unwrap();
FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
}
}
}
impl From<u128> for FieldPrime {
fn from(num: u128) -> Self {
let x = ToBigInt::to_bigint(&num).unwrap();
FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
}
}
}
impl Zero for FieldPrime {
fn zero() -> FieldPrime {
FieldPrime {
value: ToBigInt::to_bigint(&0).unwrap(),
}
}
fn is_zero(&self) -> bool {
self.value == ToBigInt::to_bigint(&0).unwrap()
}
}
impl One for FieldPrime {
fn one() -> FieldPrime {
FieldPrime {
value: ToBigInt::to_bigint(&1).unwrap(),
}
}
}
impl Add<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn add(self, other: FieldPrime) -> FieldPrime {
FieldPrime {
value: (self.value + other.value) % &*P,
}
}
}
impl<'a> Add<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn add(self, other: &FieldPrime) -> FieldPrime {
FieldPrime {
value: (self.value + other.value.clone()) % &*P,
}
}
}
impl Sub<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn sub(self, other: FieldPrime) -> FieldPrime {
let x = self.value - other.value;
FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
}
}
}
impl<'a> Sub<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn sub(self, other: &FieldPrime) -> FieldPrime {
let x = self.value - other.value.clone();
FieldPrime {
value: &x - x.div_floor(&*P) * &*P,
}
}
}
impl Mul<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn mul(self, other: FieldPrime) -> FieldPrime {
FieldPrime {
value: (self.value * other.value) % &*P,
}
}
}
impl<'a> Mul<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn mul(self, other: &FieldPrime) -> FieldPrime {
FieldPrime {
value: (self.value * other.value.clone()) % &*P,
}
}
}
impl Div<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn div(self, other: FieldPrime) -> FieldPrime {
self * other.inverse_mul()
}
}
impl<'a> Div<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn div(self, other: &FieldPrime) -> FieldPrime {
self / other.clone()
}
}
impl Pow<usize> for FieldPrime {
type Output = FieldPrime;
fn pow(self, exp: usize) -> FieldPrime {
let mut res = FieldPrime::from(1);
for _ in 0..exp {
res = res * &self;
}
res
}
}
impl Pow<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn pow(self, exp: FieldPrime) -> FieldPrime {
let mut res = FieldPrime::one();
let mut current = FieldPrime::zero();
loop {
if current >= exp {
return res;
}
res = res * &self;
current = current + FieldPrime::one();
}
}
}
impl<'a> Pow<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn pow(self, exp: &'a FieldPrime) -> FieldPrime {
let mut res = FieldPrime::one();
let mut current = FieldPrime::zero();
loop {
if &current >= exp {
return res;
}
res = res * &self;
current = current + FieldPrime::one();
}
}
}
/// Calculates the gcd using an iterative implementation of the extended euclidian algorithm.
/// Returning `(d, s, t)` so that `d = s * a + t * b`
///
/// # Arguments
/// * `a` - First number as `BigInt`
/// * `b` - Second number as `BigInt`
fn extended_euclid(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
let (mut r, mut old_r) = (b.clone(), a.clone());
while !&r.is_zero() {
let quotient = &old_r / &r;
let tmp_r = old_r.clone();
old_r = r.clone();
r = &tmp_r - &quotient * &r;
let tmp_s = old_s.clone();
old_s = s.clone();
s = &tmp_s - &quotient * &s;
let tmp_t = old_t.clone();
old_t = t.clone();
t = &tmp_t - &quotient * &t;
}
return (old_r, old_s, old_t);
}
#[cfg(test)]
mod tests {
use super::*;
impl<'a> From<&'a str> for FieldPrime {
fn from(s: &'a str) -> FieldPrime {
FieldPrime::try_from_dec_str(s).unwrap()
}
}
#[cfg(test)]
mod field_prime {
use super::*;
use bincode::{deserialize, serialize, Infinite};
#[test]
fn positive_number() {
assert_eq!(
"1234245612".parse::<BigInt>().unwrap(),
FieldPrime::from("1234245612").value
);
}
#[test]
fn negative_number() {
assert_eq!(
P.checked_sub(&"12".parse::<BigInt>().unwrap()).unwrap(),
FieldPrime::from("-12").value
);
}
#[test]
fn addition() {
assert_eq!(
"65484493".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") + FieldPrime::from("68135")).value
);
assert_eq!(
"65484493".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") + &FieldPrime::from("68135")).value
);
}
#[test]
fn addition_negative_small() {
assert_eq!(
"3".parse::<BigInt>().unwrap(),
(FieldPrime::from("5") + FieldPrime::from("-2")).value
);
assert_eq!(
"3".parse::<BigInt>().unwrap(),
(FieldPrime::from("5") + &FieldPrime::from("-2")).value
);
}
#[test]
fn addition_negative() {
assert_eq!(
"65348223".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") + FieldPrime::from("-68135")).value
);
assert_eq!(
"65348223".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") + &FieldPrime::from("-68135")).value
);
}
#[test]
fn subtraction() {
assert_eq!(
"65348223".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") - FieldPrime::from("68135")).value
);
assert_eq!(
"65348223".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") - &FieldPrime::from("68135")).value
);
}
#[test]
fn subtraction_negative() {
assert_eq!(
"65484493".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") - FieldPrime::from("-68135")).value
);
assert_eq!(
"65484493".parse::<BigInt>().unwrap(),
(FieldPrime::from("65416358") - &FieldPrime::from("-68135")).value
);
}
#[test]
fn subtraction_overflow() {
assert_eq!(
"21888242871839275222246405745257275088548364400416034343698204186575743147394"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from("68135") - FieldPrime::from("65416358")).value
);
assert_eq!(
"21888242871839275222246405745257275088548364400416034343698204186575743147394"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from("68135") - &FieldPrime::from("65416358")).value
);
}
#[test]
fn multiplication() {
assert_eq!(
"13472".parse::<BigInt>().unwrap(),
(FieldPrime::from("32") * FieldPrime::from("421")).value
);
assert_eq!(
"13472".parse::<BigInt>().unwrap(),
(FieldPrime::from("32") * &FieldPrime::from("421")).value
);
}
#[test]
fn multiplication_negative() {
assert_eq!(
"21888242871839275222246405745257275088548364400416034343698204186575808014369"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from("54") * FieldPrime::from("-8912")).value
);
assert_eq!(
"21888242871839275222246405745257275088548364400416034343698204186575808014369"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from("54") * &FieldPrime::from("-8912")).value
);
}
#[test]
fn multiplication_two_negative() {
assert_eq!(
"648".parse::<BigInt>().unwrap(),
(FieldPrime::from("-54") * FieldPrime::from("-12")).value
);
assert_eq!(
"648".parse::<BigInt>().unwrap(),
(FieldPrime::from("-54") * &FieldPrime::from("-12")).value
);
}
#[test]
fn multiplication_overflow() {
assert_eq!(
"6042471409729479866150380306128222617399890671095126975526159292198160466142"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from(
"21888242871839225222246405785257275088694311157297823662689037894645225727"
) * FieldPrime::from("218882428715392752222464057432572755886923"))
.value
);
assert_eq!(
"6042471409729479866150380306128222617399890671095126975526159292198160466142"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from(
"21888242871839225222246405785257275088694311157297823662689037894645225727"
) * &FieldPrime::from("218882428715392752222464057432572755886923"))
.value
);
}
#[test]
fn division() {
assert_eq!(
FieldPrime::from(4),
FieldPrime::from(48) / FieldPrime::from(12)
);
assert_eq!(
FieldPrime::from(4),
FieldPrime::from(48) / &FieldPrime::from(12)
);
}
#[test]
fn division_negative() {
let res = FieldPrime::from(-54) / FieldPrime::from(12);
assert_eq!(FieldPrime::from(-54), FieldPrime::from(12) * res);
}
#[test]
fn division_two_negative() {
let res = FieldPrime::from(-12) / FieldPrime::from(-85);
assert_eq!(FieldPrime::from(-12), FieldPrime::from(-85) * res);
}
#[test]
fn pow_small() {
assert_eq!(
"8".parse::<BigInt>().unwrap(),
(FieldPrime::from("2").pow(FieldPrime::from("3"))).value
);
assert_eq!(
"8".parse::<BigInt>().unwrap(),
(FieldPrime::from("2").pow(&FieldPrime::from("3"))).value
);
}
#[test]
fn pow_usize() {
assert_eq!(
"614787626176508399616".parse::<BigInt>().unwrap(),
(FieldPrime::from("54").pow(12)).value
);
}
#[test]
fn pow() {
assert_eq!(
"614787626176508399616".parse::<BigInt>().unwrap(),
(FieldPrime::from("54").pow(FieldPrime::from("12"))).value
);
assert_eq!(
"614787626176508399616".parse::<BigInt>().unwrap(),
(FieldPrime::from("54").pow(&FieldPrime::from("12"))).value
);
}
#[test]
fn pow_negative() {
assert_eq!(
"21888242871839275222246405745257275088548364400416034343686819230535502784513"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from("-54").pow(FieldPrime::from("11"))).value
);
assert_eq!(
"21888242871839275222246405745257275088548364400416034343686819230535502784513"
.parse::<BigInt>()
.unwrap(),
(FieldPrime::from("-54").pow(&FieldPrime::from("11"))).value
);
}
#[test]
fn serde_ser_deser() {
let serialized = &serialize(&FieldPrime::from("11"), Infinite).unwrap();
let deserialized = deserialize(serialized).unwrap();
assert_eq!(FieldPrime::from("11"), deserialized);
}
#[test]
fn serde_json_ser_deser() {
let serialized = serde_json::to_string(&FieldPrime::from("11")).unwrap();
let deserialized = serde_json::from_str(&serialized).unwrap();
assert_eq!(FieldPrime::from("11"), deserialized);
}
#[test]
fn bytes_ser_deser() {
let fp = FieldPrime::from("101");
let bv = fp.into_byte_vector();
assert_eq!(fp, FieldPrime::from_byte_vector(bv));
}
#[test]
fn dec_string_ser_deser() {
let fp = FieldPrime::from("101");
let bv = fp.to_dec_string();
assert_eq!(fp, FieldPrime::try_from_dec_str(&bv).unwrap());
}
#[test]
fn compact_representation() {
let one = FieldPrime::from(1);
assert_eq!("1", &one.to_compact_dec_string());
let minus_one = FieldPrime::from(0) - one;
assert_eq!("(-1)", &minus_one.to_compact_dec_string());
// (p-1)/2 -> positive notation
let p_minus_one_over_two =
(FieldPrime::from(0) - FieldPrime::from(1)) / FieldPrime::from(2);
assert_eq!(
"10944121435919637611123202872628637544274182200208017171849102093287904247808",
&p_minus_one_over_two.to_compact_dec_string()
);
// (p-1)/2 + 1 -> negative notation (p-1)/2 + 1 - p == (-p+1)/2
let p_minus_one_over_two_plus_one = ((FieldPrime::from(0) - FieldPrime::from(1))
/ FieldPrime::from(2))
+ FieldPrime::from(1);
assert_eq!(
"(-10944121435919637611123202872628637544274182200208017171849102093287904247808)",
&p_minus_one_over_two_plus_one.to_compact_dec_string()
);
}
}
#[test]
fn bigint_assertions() {
let x = BigInt::parse_bytes(b"65", 10).unwrap();
assert_eq!(&x + &x, BigInt::parse_bytes(b"130", 10).unwrap());
assert_eq!(
"1".parse::<BigInt>().unwrap(),
"3".parse::<BigInt>()
.unwrap()
.div_floor(&"2".parse::<BigInt>().unwrap())
);
assert_eq!(
"-2".parse::<BigInt>().unwrap(),
"-3".parse::<BigInt>()
.unwrap()
.div_floor(&"2".parse::<BigInt>().unwrap())
);
}
#[test]
fn test_extended_euclid() {
assert_eq!(
(
ToBigInt::to_bigint(&1).unwrap(),
ToBigInt::to_bigint(&-9).unwrap(),
ToBigInt::to_bigint(&47).unwrap()
),
extended_euclid(
&ToBigInt::to_bigint(&120).unwrap(),
&ToBigInt::to_bigint(&23).unwrap()
)
);
assert_eq!(
(
ToBigInt::to_bigint(&2).unwrap(),
ToBigInt::to_bigint(&2).unwrap(),
ToBigInt::to_bigint(&-11).unwrap()
),
extended_euclid(
&ToBigInt::to_bigint(&122).unwrap(),
&ToBigInt::to_bigint(&22).unwrap()
)
);
assert_eq!(
(
ToBigInt::to_bigint(&2).unwrap(),
ToBigInt::to_bigint(&-9).unwrap(),
ToBigInt::to_bigint(&47).unwrap()
),
extended_euclid(
&ToBigInt::to_bigint(&240).unwrap(),
&ToBigInt::to_bigint(&46).unwrap()
)
);
let (b, s, _) = extended_euclid(&ToBigInt::to_bigint(&253).unwrap(), &*P);
assert_eq!(b, BigInt::one());
let s_field = FieldPrime {
value: &s - s.div_floor(&*P) * &*P,
};
assert_eq!(
FieldPrime::from(
"12717674712096337777352654721552646000065650461901806515903699665717959876900"
),
s_field
);
}
mod bellman {
use super::*;
use ff::Field as FField;
extern crate rand;
use pairing::bn256::Fr;
use rand::{thread_rng, Rng};
use Field;
#[test]
fn fr_to_field_to_fr() {
let rng = &mut thread_rng();
for _ in 0..1000 {
let a: Fr = rng.gen();
assert_eq!(FieldPrime::from_bellman(a).into_bellman(), a);
}
}
#[test]
fn field_to_fr_to_field() {
// use Fr to get a random element
let rng = &mut thread_rng();
for _ in 0..1000 {
let a: Fr = rng.gen();
// now test idempotence
let a = FieldPrime::from_bellman(a);
assert_eq!(FieldPrime::from_bellman(a.clone().into_bellman()), a);
}
}
#[test]
fn one() {
let a = FieldPrime::from(1);
assert_eq!(a.into_bellman(), Fr::one());
}
#[test]
fn zero() {
let a = FieldPrime::from(0);
assert_eq!(a.into_bellman(), Fr::zero());
}
#[test]
fn minus_one() {
let mut a: Fr = Fr::one();
a.negate();
assert_eq!(FieldPrime::from_bellman(a), FieldPrime::from(-1));
}
#[test]
fn add() {
let rng = &mut thread_rng();
let mut a: Fr = rng.gen();
let b: Fr = rng.gen();
let aa = FieldPrime::from_bellman(a);
let bb = FieldPrime::from_bellman(b);
let cc = aa + bb;
a.add_assign(&b);
assert_eq!(FieldPrime::from_bellman(a), cc);
}
}
}

View file

@ -79,7 +79,10 @@ pub trait Field:
fn min_value() -> Self;
/// Returns the largest value that can be represented by this field type.
fn max_value() -> Self;
/// Returns the number of required bits to represent this field type.
/// Returns the largest value `m` such that there exist a number of bits `n` so that any value smaller or equal to
/// m` has a single `n`-bit decomposition
fn max_unique_value() -> Self;
/// Returns the number of bits required to represent any element of this field type.
fn get_required_bits() -> usize;
/// Tries to parse a string into this representation
fn try_from_dec_str<'a>(s: &'a str) -> Result<Self, ()>;
@ -167,6 +170,13 @@ mod prime_field {
value: &*P - ToBigInt::to_bigint(&1).unwrap(),
}
}
fn max_unique_value() -> FieldPrime {
use num_traits::Pow;
FieldPrime {
value: BigInt::from(2u32).pow(Self::get_required_bits() - 1) - 1,
}
}
fn get_required_bits() -> usize {
(*P).bits()
}
@ -414,29 +424,35 @@ mod prime_field {
impl num_traits::CheckedAdd for FieldPrime {
fn checked_add(&self, other: &Self) -> Option<Self> {
use num_traits::Pow;
let res = self.value.clone() + other.value.clone();
let bound = Self::max_unique_value();
let bound = BigInt::from(2u32).pow(Self::get_required_bits() - 1);
assert!(self <= &bound);
assert!(other <= &bound);
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
if res >= bound {
let big_res = self.value.clone() + other.value.clone();
if big_res > bound.value {
None
} else {
Some(FieldPrime { value: res })
Some(FieldPrime { value: big_res })
}
}
}
impl num_traits::CheckedMul for FieldPrime {
fn checked_mul(&self, other: &Self) -> Option<Self> {
use num_traits::Pow;
let res = self.value.clone() * other.value.clone();
let bound = Self::max_unique_value();
assert!(self <= &bound);
assert!(other <= &bound);
let big_res = self.value.clone() * other.value.clone();
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
if res >= BigInt::from(2u32).pow(Self::get_required_bits() - 1) {
if big_res > bound.value {
None
} else {
Some(FieldPrime { value: res })
Some(FieldPrime { value: big_res })
}
}
}