1
0
Fork 0
mirror of synced 2025-09-23 20:28:36 +00:00

implement better check, add test

This commit is contained in:
schaeff 2022-09-19 18:37:25 +02:00
parent 9d04dca66e
commit 4ab9cdb1b0
8 changed files with 598 additions and 9 deletions

9
Cargo.lock generated
View file

@ -2109,12 +2109,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "reduce"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16d2dc47b68ac15ea328cd7ebe01d7d512ed29787f7d534ad2a3c341328b35d7"
[[package]]
name = "regex"
version = "0.2.11"
@ -3127,7 +3121,6 @@ dependencies = [
"num 0.1.42",
"num-bigint 0.2.6",
"pretty_assertions 0.6.1",
"reduce",
"serde",
"serde_json",
"typed-arena",
@ -3215,7 +3208,7 @@ dependencies = [
[[package]]
name = "zokrates_js"
version = "1.1.2"
version = "1.1.3"
dependencies = [
"console_error_panic_hook",
"indexmap",

View file

@ -18,7 +18,6 @@ num = { version = "0.1.36", default-features = false }
num-bigint = { version = "0.2", default-features = false }
lazy_static = "1.4"
typed-arena = "1.4.1"
reduce = "0.1.1"
# serialization and deserialization
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0", features = ["preserve_order"] }

View file

@ -0,0 +1,320 @@
use zokrates_ast::typed::{
folder::*, ArrayExpressionInner, ArrayValue, BooleanExpression, ConditionalExpression,
ConditionalKind, EqExpression, FieldElementExpression, SelectExpression, Type, TypedExpression,
TypedProgram, UExpressionInner,
};
use zokrates_field::Field;
#[derive(Default)]
pub struct BooleanArrayComparator;
impl BooleanArrayComparator {
pub fn simplify<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
Self::default().fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator {
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::ArrayEq(e) => match e.left.inner_type() {
Type::Boolean => {
let len = e.left.size();
let len = match len.as_inner() {
UExpressionInner::Value(v) => *v as usize,
_ => unreachable!("array size should be known"),
};
let chunk_size = T::get_required_bits() as usize - 1;
let left_elements: Vec<_> = (0..len)
.map(|i| {
BooleanExpression::Select(SelectExpression::new(
*e.left.clone(),
(i as u32).into(),
))
})
.collect();
let right_elements: Vec<_> = (0..len)
.map(|i| {
BooleanExpression::Select(SelectExpression::new(
*e.right.clone(),
(i as u32).into(),
))
})
.collect();
let process = |elements: &[BooleanExpression<'ast, T>]| {
elements
.chunks(chunk_size)
.map(|chunk| {
TypedExpression::from(
chunk
.iter()
.rev()
.enumerate()
.rev()
.map(|(index, c)| {
FieldElementExpression::Conditional(
ConditionalExpression::new(
c.clone(),
FieldElementExpression::Pow(
box FieldElementExpression::Number(
T::from(2),
),
box (index as u32).into(),
),
T::zero().into(),
ConditionalKind::Ternary,
),
)
})
.fold(None, |acc, e| match acc {
Some(acc) => {
Some(FieldElementExpression::Add(box acc, box e))
}
None => Some(e),
})
.unwrap_or_else(|| {
FieldElementExpression::Number(T::zero())
}),
)
.into()
})
.collect()
};
let left: Vec<_> = process(&left_elements);
let right: Vec<_> = process(&right_elements);
let chunk_count = left.len();
BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(ArrayValue(left))
.annotate(Type::FieldElement, chunk_count as u32),
ArrayExpressionInner::Value(ArrayValue(right))
.annotate(Type::FieldElement, chunk_count as u32),
))
}
_ => fold_boolean_expression(self, BooleanExpression::ArrayEq(e)),
},
e => fold_boolean_expression(self, e),
}
}
}
#[cfg(test)]
mod tests {
use num::Zero;
use zokrates_ast::typed::{
ArrayExpressionInner, ArrayValue, BooleanExpression, ConditionalExpression,
ConditionalKind, EqExpression, FieldElementExpression, SelectExpression, Type,
TypedExpression, UBitwidth, UExpressionInner,
};
use zokrates_field::DummyCurveField;
use super::*;
#[test]
fn simplify_short_array_eq() {
// x == y // type bool[2]
// should become
// [x[0] ? 2**1 : 0 + x[1] ? 2**0 : 0] == [y[0] ? 2**1 : 0 + y[1] ? 2**0 : 0]
// a single field is sufficient, as the prime we're working with is 3 bits long, so we can pack up to 2 bits
let e: BooleanExpression<DummyCurveField> = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Identifier("x".into()).annotate(Type::Boolean, 2u32),
ArrayExpressionInner::Identifier("y".into()).annotate(Type::Boolean, 2u32),
));
let expected = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(ArrayValue(vec![TypedExpression::from(
FieldElementExpression::Add(
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("x".into())
.annotate(Type::Boolean, 2u32),
UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("x".into())
.annotate(Type::Boolean, 2u32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
),
)
.into()]))
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Value(ArrayValue(vec![TypedExpression::from(
FieldElementExpression::Add(
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("y".into())
.annotate(Type::Boolean, 2u32),
UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("y".into())
.annotate(Type::Boolean, 2u32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
),
)
.into()]))
.annotate(Type::FieldElement, 1u32),
));
let res = BooleanArrayComparator::default().fold_boolean_expression(e);
assert_eq!(res, expected);
}
#[test]
fn simplify_long_array_eq() {
// x == y // type bool[3]
// should become
// [x[0] ? 2**2 : 0 + x[1] ? 2**1 : 0, x[2] ? 2**0 : 0] == [y[0] ? 2**2 : 0 + y[1] ? 2**1 : 0 y[2] ? 2**0 : 0]
let e: BooleanExpression<DummyCurveField> = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Identifier("x".into()).annotate(Type::Boolean, 3u32),
ArrayExpressionInner::Identifier("y".into()).annotate(Type::Boolean, 3u32),
));
let expected = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(ArrayValue(vec![
TypedExpression::from(FieldElementExpression::Add(
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("x".into())
.annotate(Type::Boolean, 3u32),
UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("x".into())
.annotate(Type::Boolean, 3u32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
))
.into(),
TypedExpression::from(FieldElementExpression::Conditional(
ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("x".into())
.annotate(Type::Boolean, 3u32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
),
))
.into(),
]))
.annotate(Type::FieldElement, 2u32),
ArrayExpressionInner::Value(ArrayValue(vec![
TypedExpression::from(FieldElementExpression::Add(
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("y".into())
.annotate(Type::Boolean, 3u32),
UExpressionInner::Value(0).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(1).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
box FieldElementExpression::Conditional(ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("y".into())
.annotate(Type::Boolean, 3u32),
UExpressionInner::Value(1).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
)),
))
.into(),
TypedExpression::from(FieldElementExpression::Conditional(
ConditionalExpression::new(
BooleanExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier("y".into())
.annotate(Type::Boolean, 3u32),
UExpressionInner::Value(2).annotate(UBitwidth::B32),
)),
FieldElementExpression::Pow(
box FieldElementExpression::Number(DummyCurveField::from(2)),
box UExpressionInner::Value(0).annotate(UBitwidth::B32),
),
FieldElementExpression::Number(DummyCurveField::zero()),
ConditionalKind::Ternary,
),
))
.into(),
]))
.annotate(Type::FieldElement, 2u32),
));
let res = BooleanArrayComparator::default().fold_boolean_expression(e);
assert_eq!(res, expected);
}
}

View file

@ -4,6 +4,7 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod boolean_array_comparator;
mod branch_isolator;
mod condition_redefiner;
mod constant_argument_checker;
@ -21,6 +22,7 @@ mod uint_optimizer;
mod variable_write_remover;
mod zir_propagation;
use self::boolean_array_comparator::BooleanArrayComparator;
use self::branch_isolator::Isolator;
use self::condition_redefiner::ConditionRedefiner;
use self::constant_argument_checker::ConstantArgumentChecker;
@ -146,6 +148,11 @@ pub fn analyse<'ast, T: Field>(
let r = Propagator::propagate(r).map_err(Error::from)?;
log::trace!("\n{}", r);
// simplify boolean array comparisons
log::debug!("Static analyser: Simplify boolean array comparisons");
let r = BooleanArrayComparator::simplify(r);
log::trace!("\n{}", r);
// remove assignment to variable index
log::debug!("Static analyser: Remove variable index");
let r = VariableWriteRemover::apply(r);

View file

@ -0,0 +1,7 @@
{
"entry_point": "./tests/tests/arrays/boolean_array_equality.zok",
"curves": ["Bn128"],
"max_constraint_count": 1005,
"tests": []
}

View file

@ -0,0 +1,8 @@
// this should cost 1000 for input constraint, and then 1 per chunk of 253 booleans (for bn128), so here 5
// total 1005
const u32 SIZE = 1000;
def main(bool[SIZE] a) {
assert(a == [true; SIZE]);
}

View file

@ -0,0 +1,253 @@
use crate::{Field, Pow};
use num_bigint::BigUint;
use num_traits::{CheckedDiv, One, Zero};
use serde_derive::{Deserialize, Serialize};
use std::convert::{From, TryFrom};
use std::fmt;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::{Add, Div, Mul, Sub};
const _PRIME: u8 = 7;
#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq)]
pub struct FieldPrime {
v: u8,
}
impl fmt::Display for FieldPrime {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.v)
}
}
impl From<u128> for FieldPrime {
fn from(_: u128) -> Self {
unimplemented!()
}
}
impl From<u64> for FieldPrime {
fn from(_: u64) -> Self {
unimplemented!()
}
}
impl From<u32> for FieldPrime {
fn from(_: u32) -> Self {
unimplemented!()
}
}
impl From<u16> for FieldPrime {
fn from(_: u16) -> Self {
unimplemented!()
}
}
impl From<u8> for FieldPrime {
fn from(num: u8) -> Self {
FieldPrime { v: num }
}
}
impl From<usize> for FieldPrime {
fn from(_: usize) -> Self {
unimplemented!()
}
}
impl From<bool> for FieldPrime {
fn from(_: bool) -> Self {
unimplemented!()
}
}
impl From<i32> for FieldPrime {
fn from(num: i32) -> Self {
assert!(num < _PRIME as i32);
assert!(num >= 0);
Self::from(num as u8)
}
}
impl TryFrom<BigUint> for FieldPrime {
type Error = ();
fn try_from(_: BigUint) -> Result<Self, ()> {
unimplemented!()
}
}
impl Zero for FieldPrime {
fn zero() -> FieldPrime {
FieldPrime { v: 0 }
}
fn is_zero(&self) -> bool {
self.v.is_zero()
}
}
impl One for FieldPrime {
fn one() -> FieldPrime {
FieldPrime { v: 1 }
}
}
impl Add<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn add(self, _: FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl<'a> Add<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn add(self, _: &FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl Sub<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn sub(self, _: FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl<'a> Sub<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn sub(self, _: &FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl Mul<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn mul(self, _: FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl<'a> Mul<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn mul(self, _: &FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl CheckedDiv for FieldPrime {
fn checked_div(&self, _: &FieldPrime) -> Option<FieldPrime> {
unimplemented!()
}
}
impl Div<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn div(self, _: FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl<'a> Div<&'a FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn div(self, _: &FieldPrime) -> FieldPrime {
unimplemented!()
}
}
impl Pow<usize> for FieldPrime {
type Output = FieldPrime;
fn pow(self, _: usize) -> FieldPrime {
unimplemented!()
}
}
impl num_traits::CheckedAdd for FieldPrime {
fn checked_add(&self, _: &Self) -> Option<Self> {
unimplemented!()
}
}
impl num_traits::CheckedMul for FieldPrime {
fn checked_mul(&self, _: &Self) -> Option<Self> {
unimplemented!()
}
}
impl Field for FieldPrime {
const G2_TYPE: crate::G2Type = crate::G2Type::Fq2;
fn to_byte_vector(&self) -> Vec<u8> {
unimplemented!()
}
fn from_byte_vector(_: Vec<u8>) -> Self {
unimplemented!()
}
fn to_dec_string(&self) -> String {
unimplemented!()
}
fn inverse_mul(&self) -> Option<Self> {
unimplemented!()
}
fn min_value() -> Self {
unimplemented!()
}
fn max_value() -> Self {
unimplemented!()
}
fn max_unique_value() -> Self {
unimplemented!()
}
fn to_bits_be(&self) -> Vec<bool> {
unimplemented!()
}
fn get_required_bits() -> usize {
3 // ceil(log2(7))
}
fn try_from_dec_str(_: &str) -> Result<Self, crate::FieldParseError> {
unimplemented!()
}
fn try_from_str(_: &str, _: u32) -> Result<Self, crate::FieldParseError> {
unimplemented!()
}
fn to_compact_dec_string(&self) -> String {
unimplemented!()
}
fn id() -> [u8; 4] {
unimplemented!()
}
fn name() -> &'static str {
unimplemented!()
}
fn bits(&self) -> u32 {
unimplemented!()
}
fn to_biguint(&self) -> num_bigint::BigUint {
unimplemented!()
}
}

View file

@ -632,8 +632,10 @@ pub mod bls12_377;
pub mod bls12_381;
pub mod bn128;
pub mod bw6_761;
pub mod dummy_curve;
pub use bls12_377::FieldPrime as Bls12_377Field;
pub use bls12_381::FieldPrime as Bls12_381Field;
pub use bn128::FieldPrime as Bn128Field;
pub use bw6_761::FieldPrime as Bw6_761Field;
pub use dummy_curve::FieldPrime as DummyCurveField;