diff --git a/Cargo.lock b/Cargo.lock index 7e67221a..1a74be51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2109,12 +2109,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "reduce" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16d2dc47b68ac15ea328cd7ebe01d7d512ed29787f7d534ad2a3c341328b35d7" - [[package]] name = "regex" version = "0.2.11" @@ -3127,7 +3121,6 @@ dependencies = [ "num 0.1.42", "num-bigint 0.2.6", "pretty_assertions 0.6.1", - "reduce", "serde", "serde_json", "typed-arena", @@ -3215,7 +3208,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.2" +version = "1.1.3" dependencies = [ "console_error_panic_hook", "indexmap", diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index c1f316f2..744b685f 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -18,7 +18,6 @@ num = { version = "0.1.36", default-features = false } num-bigint = { version = "0.2", default-features = false } lazy_static = "1.4" typed-arena = "1.4.1" -reduce = "0.1.1" # serialization and deserialization serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0", features = ["preserve_order"] } diff --git a/zokrates_core/src/static_analysis/boolean_array_comparator.rs b/zokrates_core/src/static_analysis/boolean_array_comparator.rs new file mode 100644 index 00000000..90b83f28 --- /dev/null +++ b/zokrates_core/src/static_analysis/boolean_array_comparator.rs @@ -0,0 +1,320 @@ +use zokrates_ast::typed::{ + folder::*, ArrayExpressionInner, ArrayValue, BooleanExpression, ConditionalExpression, + ConditionalKind, EqExpression, FieldElementExpression, SelectExpression, Type, TypedExpression, + TypedProgram, UExpressionInner, +}; +use zokrates_field::Field; + +#[derive(Default)] +pub struct BooleanArrayComparator; + +impl BooleanArrayComparator { + pub fn simplify(p: TypedProgram) -> TypedProgram { + Self::default().fold_program(p) + } +} + +impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + match e { + BooleanExpression::ArrayEq(e) => match e.left.inner_type() { + Type::Boolean => { + let len = e.left.size(); + let len = match len.as_inner() { + UExpressionInner::Value(v) => *v as usize, + _ => unreachable!("array size should be known"), + }; + + let chunk_size = T::get_required_bits() as usize - 1; + + let left_elements: Vec<_> = (0..len) + .map(|i| { + BooleanExpression::Select(SelectExpression::new( + *e.left.clone(), + (i as u32).into(), + )) + }) + .collect(); + let right_elements: Vec<_> = (0..len) + .map(|i| { + BooleanExpression::Select(SelectExpression::new( + *e.right.clone(), + (i as u32).into(), + )) + }) + .collect(); + + let process = |elements: &[BooleanExpression<'ast, T>]| { + elements + .chunks(chunk_size) + .map(|chunk| { + TypedExpression::from( + chunk + .iter() + .rev() + .enumerate() + .rev() + .map(|(index, c)| { + FieldElementExpression::Conditional( + ConditionalExpression::new( + c.clone(), + FieldElementExpression::Pow( + box FieldElementExpression::Number( + T::from(2), + ), + box (index as u32).into(), + ), + T::zero().into(), + ConditionalKind::Ternary, + ), + ) + }) + .fold(None, |acc, e| match acc { + Some(acc) => { + Some(FieldElementExpression::Add(box acc, box e)) + } + None => Some(e), + }) + .unwrap_or_else(|| { + FieldElementExpression::Number(T::zero()) + }), + ) + .into() + }) + .collect() + }; + + let left: Vec<_> = process(&left_elements); + + let right: Vec<_> = process(&right_elements); + + let chunk_count = left.len(); + + BooleanExpression::ArrayEq(EqExpression::new( + ArrayExpressionInner::Value(ArrayValue(left)) + .annotate(Type::FieldElement, chunk_count as u32), + ArrayExpressionInner::Value(ArrayValue(right)) + .annotate(Type::FieldElement, chunk_count as u32), + )) + } + _ => fold_boolean_expression(self, BooleanExpression::ArrayEq(e)), + }, + e => fold_boolean_expression(self, e), + } + } +} + +#[cfg(test)] +mod tests { + use num::Zero; + use zokrates_ast::typed::{ + ArrayExpressionInner, ArrayValue, BooleanExpression, ConditionalExpression, + ConditionalKind, EqExpression, FieldElementExpression, SelectExpression, Type, + TypedExpression, UBitwidth, UExpressionInner, + }; + use zokrates_field::DummyCurveField; + + use super::*; + + #[test] + fn simplify_short_array_eq() { + // x == y // type bool[2] + // should become + // [x[0] ? 2**1 : 0 + x[1] ? 2**0 : 0] == [y[0] ? 2**1 : 0 + y[1] ? 2**0 : 0] + // a single field is sufficient, as the prime we're working with is 3 bits long, so we can pack up to 2 bits + + let e: BooleanExpression = BooleanExpression::ArrayEq(EqExpression::new( + ArrayExpressionInner::Identifier("x".into()).annotate(Type::Boolean, 2u32), + ArrayExpressionInner::Identifier("y".into()).annotate(Type::Boolean, 2u32), + )); + + let expected = BooleanExpression::ArrayEq(EqExpression::new( + ArrayExpressionInner::Value(ArrayValue(vec![TypedExpression::from( + FieldElementExpression::Add( + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("x".into()) + .annotate(Type::Boolean, 2u32), + UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("x".into()) + .annotate(Type::Boolean, 2u32), + UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + ), + ) + .into()])) + .annotate(Type::FieldElement, 1u32), + ArrayExpressionInner::Value(ArrayValue(vec![TypedExpression::from( + FieldElementExpression::Add( + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("y".into()) + .annotate(Type::Boolean, 2u32), + UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("y".into()) + .annotate(Type::Boolean, 2u32), + UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + ), + ) + .into()])) + .annotate(Type::FieldElement, 1u32), + )); + + let res = BooleanArrayComparator::default().fold_boolean_expression(e); + + assert_eq!(res, expected); + } + + #[test] + fn simplify_long_array_eq() { + // x == y // type bool[3] + // should become + // [x[0] ? 2**2 : 0 + x[1] ? 2**1 : 0, x[2] ? 2**0 : 0] == [y[0] ? 2**2 : 0 + y[1] ? 2**1 : 0 y[2] ? 2**0 : 0] + + let e: BooleanExpression = BooleanExpression::ArrayEq(EqExpression::new( + ArrayExpressionInner::Identifier("x".into()).annotate(Type::Boolean, 3u32), + ArrayExpressionInner::Identifier("y".into()).annotate(Type::Boolean, 3u32), + )); + + let expected = BooleanExpression::ArrayEq(EqExpression::new( + ArrayExpressionInner::Value(ArrayValue(vec![ + TypedExpression::from(FieldElementExpression::Add( + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("x".into()) + .annotate(Type::Boolean, 3u32), + UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("x".into()) + .annotate(Type::Boolean, 3u32), + UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + )) + .into(), + TypedExpression::from(FieldElementExpression::Conditional( + ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("x".into()) + .annotate(Type::Boolean, 3u32), + UExpressionInner::Value(2).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + ), + )) + .into(), + ])) + .annotate(Type::FieldElement, 2u32), + ArrayExpressionInner::Value(ArrayValue(vec![ + TypedExpression::from(FieldElementExpression::Add( + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("y".into()) + .annotate(Type::Boolean, 3u32), + UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + box FieldElementExpression::Conditional(ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("y".into()) + .annotate(Type::Boolean, 3u32), + UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + )), + )) + .into(), + TypedExpression::from(FieldElementExpression::Conditional( + ConditionalExpression::new( + BooleanExpression::Select(SelectExpression::new( + ArrayExpressionInner::Identifier("y".into()) + .annotate(Type::Boolean, 3u32), + UExpressionInner::Value(2).annotate(UBitwidth::B32), + )), + FieldElementExpression::Pow( + box FieldElementExpression::Number(DummyCurveField::from(2)), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ), + FieldElementExpression::Number(DummyCurveField::zero()), + ConditionalKind::Ternary, + ), + )) + .into(), + ])) + .annotate(Type::FieldElement, 2u32), + )); + + let res = BooleanArrayComparator::default().fold_boolean_expression(e); + + assert_eq!(res, expected); + } +} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index e6ed76b2..bf394dc2 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -4,6 +4,7 @@ //! @author Thibaut Schaeffer //! @date 2018 +mod boolean_array_comparator; mod branch_isolator; mod condition_redefiner; mod constant_argument_checker; @@ -21,6 +22,7 @@ mod uint_optimizer; mod variable_write_remover; mod zir_propagation; +use self::boolean_array_comparator::BooleanArrayComparator; use self::branch_isolator::Isolator; use self::condition_redefiner::ConditionRedefiner; use self::constant_argument_checker::ConstantArgumentChecker; @@ -146,6 +148,11 @@ pub fn analyse<'ast, T: Field>( let r = Propagator::propagate(r).map_err(Error::from)?; log::trace!("\n{}", r); + // simplify boolean array comparisons + log::debug!("Static analyser: Simplify boolean array comparisons"); + let r = BooleanArrayComparator::simplify(r); + log::trace!("\n{}", r); + // remove assignment to variable index log::debug!("Static analyser: Remove variable index"); let r = VariableWriteRemover::apply(r); diff --git a/zokrates_core_test/tests/tests/arrays/boolean_array_equality.json b/zokrates_core_test/tests/tests/arrays/boolean_array_equality.json new file mode 100644 index 00000000..64bcaaab --- /dev/null +++ b/zokrates_core_test/tests/tests/arrays/boolean_array_equality.json @@ -0,0 +1,7 @@ +{ + "entry_point": "./tests/tests/arrays/boolean_array_equality.zok", + "curves": ["Bn128"], + "max_constraint_count": 1005, + "tests": [] + } + \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/arrays/boolean_array_equality.zok b/zokrates_core_test/tests/tests/arrays/boolean_array_equality.zok new file mode 100644 index 00000000..a0a864d3 --- /dev/null +++ b/zokrates_core_test/tests/tests/arrays/boolean_array_equality.zok @@ -0,0 +1,8 @@ +// this should cost 1000 for input constraint, and then 1 per chunk of 253 booleans (for bn128), so here 5 +// total 1005 + +const u32 SIZE = 1000; + +def main(bool[SIZE] a) { + assert(a == [true; SIZE]); +} \ No newline at end of file diff --git a/zokrates_field/src/dummy_curve.rs b/zokrates_field/src/dummy_curve.rs new file mode 100644 index 00000000..5d3aed4a --- /dev/null +++ b/zokrates_field/src/dummy_curve.rs @@ -0,0 +1,253 @@ +use crate::{Field, Pow}; +use num_bigint::BigUint; +use num_traits::{CheckedDiv, One, Zero}; +use serde_derive::{Deserialize, Serialize}; +use std::convert::{From, TryFrom}; +use std::fmt; +use std::fmt::Debug; +use std::hash::Hash; +use std::ops::{Add, Div, Mul, Sub}; + +const _PRIME: u8 = 7; + +#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq)] +pub struct FieldPrime { + v: u8, +} + +impl fmt::Display for FieldPrime { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.v) + } +} + +impl From for FieldPrime { + fn from(_: u128) -> Self { + unimplemented!() + } +} + +impl From for FieldPrime { + fn from(_: u64) -> Self { + unimplemented!() + } +} + +impl From for FieldPrime { + fn from(_: u32) -> Self { + unimplemented!() + } +} + +impl From for FieldPrime { + fn from(_: u16) -> Self { + unimplemented!() + } +} + +impl From for FieldPrime { + fn from(num: u8) -> Self { + FieldPrime { v: num } + } +} + +impl From for FieldPrime { + fn from(_: usize) -> Self { + unimplemented!() + } +} + +impl From for FieldPrime { + fn from(_: bool) -> Self { + unimplemented!() + } +} + +impl From for FieldPrime { + fn from(num: i32) -> Self { + assert!(num < _PRIME as i32); + assert!(num >= 0); + Self::from(num as u8) + } +} + +impl TryFrom for FieldPrime { + type Error = (); + + fn try_from(_: BigUint) -> Result { + unimplemented!() + } +} + +impl Zero for FieldPrime { + fn zero() -> FieldPrime { + FieldPrime { v: 0 } + } + fn is_zero(&self) -> bool { + self.v.is_zero() + } +} + +impl One for FieldPrime { + fn one() -> FieldPrime { + FieldPrime { v: 1 } + } +} + +impl Add for FieldPrime { + type Output = FieldPrime; + + fn add(self, _: FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl<'a> Add<&'a FieldPrime> for FieldPrime { + type Output = FieldPrime; + + fn add(self, _: &FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl Sub for FieldPrime { + type Output = FieldPrime; + + fn sub(self, _: FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl<'a> Sub<&'a FieldPrime> for FieldPrime { + type Output = FieldPrime; + + fn sub(self, _: &FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl Mul for FieldPrime { + type Output = FieldPrime; + + fn mul(self, _: FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl<'a> Mul<&'a FieldPrime> for FieldPrime { + type Output = FieldPrime; + + fn mul(self, _: &FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl CheckedDiv for FieldPrime { + fn checked_div(&self, _: &FieldPrime) -> Option { + unimplemented!() + } +} + +impl Div for FieldPrime { + type Output = FieldPrime; + + fn div(self, _: FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl<'a> Div<&'a FieldPrime> for FieldPrime { + type Output = FieldPrime; + + fn div(self, _: &FieldPrime) -> FieldPrime { + unimplemented!() + } +} + +impl Pow for FieldPrime { + type Output = FieldPrime; + + fn pow(self, _: usize) -> FieldPrime { + unimplemented!() + } +} + +impl num_traits::CheckedAdd for FieldPrime { + fn checked_add(&self, _: &Self) -> Option { + unimplemented!() + } +} + +impl num_traits::CheckedMul for FieldPrime { + fn checked_mul(&self, _: &Self) -> Option { + unimplemented!() + } +} + +impl Field for FieldPrime { + const G2_TYPE: crate::G2Type = crate::G2Type::Fq2; + + fn to_byte_vector(&self) -> Vec { + unimplemented!() + } + + fn from_byte_vector(_: Vec) -> Self { + unimplemented!() + } + + fn to_dec_string(&self) -> String { + unimplemented!() + } + + fn inverse_mul(&self) -> Option { + unimplemented!() + } + + fn min_value() -> Self { + unimplemented!() + } + + fn max_value() -> Self { + unimplemented!() + } + + fn max_unique_value() -> Self { + unimplemented!() + } + + fn to_bits_be(&self) -> Vec { + unimplemented!() + } + + fn get_required_bits() -> usize { + 3 // ceil(log2(7)) + } + + fn try_from_dec_str(_: &str) -> Result { + unimplemented!() + } + + fn try_from_str(_: &str, _: u32) -> Result { + unimplemented!() + } + + fn to_compact_dec_string(&self) -> String { + unimplemented!() + } + + fn id() -> [u8; 4] { + unimplemented!() + } + + fn name() -> &'static str { + unimplemented!() + } + + fn bits(&self) -> u32 { + unimplemented!() + } + + fn to_biguint(&self) -> num_bigint::BigUint { + unimplemented!() + } +} diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index dc1e6b90..38f76905 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -632,8 +632,10 @@ pub mod bls12_377; pub mod bls12_381; pub mod bn128; pub mod bw6_761; +pub mod dummy_curve; pub use bls12_377::FieldPrime as Bls12_377Field; pub use bls12_381::FieldPrime as Bls12_381Field; pub use bn128::FieldPrime as Bn128Field; pub use bw6_761::FieldPrime as Bw6_761Field; +pub use dummy_curve::FieldPrime as DummyCurveField;