From b1c9a171f8666bb7994b64895d9b9c46f61eb8cb Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 14:55:17 +0200 Subject: [PATCH 01/15] add bit lt embed, fail on non constant bound, implement safe unpack --- .../compile_errors/variable_constant_lt.zok | 5 + zokrates_core/src/embed.rs | 30 +++++- zokrates_core/src/flatten/mod.rs | 49 ++++++++- zokrates_core/src/imports.rs | 4 + zokrates_core/src/ir/interpreter.rs | 24 ++++- .../constant_argument_checker.rs | 99 +++++++++++++++++++ zokrates_core/src/static_analysis/mod.rs | 14 +-- .../src/static_analysis/propagation.rs | 1 + .../src/static_analysis/shift_checker.rs | 55 ----------- .../utils/pack/bool/nonStrictUnpack256.zok | 4 +- .../stdlib/utils/pack/bool/unpack.zok | 12 +-- .../stdlib/utils/pack/bool/unpack128.zok | 4 +- .../stdlib/utils/pack/bool/unpack256.zok | 7 ++ .../utils/pack/bool/unpack_unchecked.zok | 9 ++ .../tests/utils/pack/bool/unpack256.json | 16 +++ .../tests/tests/utils/pack/bool/unpack256.zok | 24 +++++ 16 files changed, 277 insertions(+), 80 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/variable_constant_lt.zok create mode 100644 zokrates_core/src/static_analysis/constant_argument_checker.rs delete mode 100644 zokrates_core/src/static_analysis/shift_checker.rs create mode 100644 zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok create mode 100644 zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok create mode 100644 zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json create mode 100644 zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok diff --git a/zokrates_cli/examples/compile_errors/variable_constant_lt.zok b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok new file mode 100644 index 00000000..4527bffc --- /dev/null +++ b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok @@ -0,0 +1,5 @@ +from "EMBED" import bit_array_le + +// Unpack a field element as N big endian bits +def main(bool[1] a, bool[1] b) -> bool: + return bit_array_le::<1>(a, b) \ No newline at end of file diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 9f87b00e..d0c34bab 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -28,6 +28,7 @@ cfg_if::cfg_if! { /// the flattening step when it can be inlined. #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] pub enum FlatEmbed { + BitArrayLe, U32ToField, Unpack, U8ToBits, @@ -47,6 +48,30 @@ pub enum FlatEmbed { impl FlatEmbed { pub fn signature(&self) -> DeclarationSignature<'static> { match self { + FlatEmbed::BitArrayLe => DeclarationSignature::new() + .generics(vec![Some(DeclarationConstant::Generic( + GenericIdentifier { + name: "N", + index: 0, + }, + ))]) + .inputs(vec![ + DeclarationType::array(( + DeclarationType::Boolean, + GenericIdentifier { + name: "N", + index: 0, + }, + )), + DeclarationType::array(( + DeclarationType::Boolean, + GenericIdentifier { + name: "N", + index: 0, + }, + )), + ]) + .outputs(vec![DeclarationType::Boolean]), FlatEmbed::U32ToField => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(32)]) .outputs(vec![DeclarationType::FieldElement]), @@ -172,6 +197,7 @@ impl FlatEmbed { pub fn id(&self) -> &'static str { match self { + &FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", FlatEmbed::U32ToField => "_U32_TO_FIELD", FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", @@ -453,10 +479,6 @@ fn use_variable( /// as we decompose over `log_2(p) + 1 bits, some /// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { - let nbits = T::get_required_bits(); - - assert!(bit_width <= nbits); - let mut counter = 0; let mut layout = HashMap::new(); diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index dd05f3f3..1d230eae 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -222,7 +222,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { b: &[bool], ) -> Vec> { let len = b.len(); - assert_eq!(a.len(), T::get_required_bits()); assert_eq!(a.len(), b.len()); let mut is_not_smaller_run = vec![]; @@ -984,7 +983,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { } // check that the decomposition is in the field with a strict `< p` checks - self.constant_le_check( + self.enforce_constant_le_check( statements_flattened, &sub_bits_be, &T::max_value().bit_vector_be(), @@ -1161,6 +1160,52 @@ impl<'ast, T: Field> Flattener<'ast, T> { crate::embed::FlatEmbed::U8FromBits => { vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] } + crate::embed::FlatEmbed::BitArrayLe => { + let len = generics[0]; + + let (expressions, constants) = ( + param_expressions[..len as usize].to_vec(), + param_expressions[len as usize..].to_vec(), + ); + + let variables: Vec<_> = expressions + .into_iter() + .map(|e| { + let e = self + .flatten_expression(statements_flattened, e) + .get_field_unchecked(); + self.define(e, statements_flattened) + }) + .collect(); + + let constants: Vec<_> = constants + .into_iter() + .map(|e| { + self.flatten_expression(statements_flattened, e) + .get_field_unchecked() + }) + .map(|e| match e { + FlatExpression::Number(n) => n == T::one(), + _ => unreachable!(), + }) + .collect(); + + let conditions = + self.constant_le_check(statements_flattened, &variables, &constants); + + // return `len(conditions) == sum(conditions)` + vec![FlatUExpression::with_field( + self.eq_check( + statements_flattened, + T::from(conditions.len()).into(), + conditions + .into_iter() + .fold(FlatExpression::Number(T::zero()), |acc, e| { + FlatExpression::Add(box acc, box e) + }), + ), + )] + } funct => { let funct = funct.synthetize(&generics); diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 4578b89d..78d0f180 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -148,6 +148,10 @@ impl Importer { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::Unpack), }, + "bit_array_le" => SymbolDeclaration { + id: symbol.get_alias(), + symbol: Symbol::Flat(FlatEmbed::BitArrayLe), + }, "u64_to_bits" => SymbolDeclaration { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::U64ToBits), diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 04f4108e..c6b75dae 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -156,10 +156,18 @@ impl Interpreter { ], }, Solver::Bits(bit_width) => { + let padding = bit_width.saturating_sub(T::get_required_bits()); + + let bit_width = bit_width - padding; + let mut num = inputs[0].clone(); let mut res = vec![]; - for i in (0..*bit_width).rev() { + for _ in 0..padding { + res.push(T::zero()); + } + + for i in (0..bit_width).rev() { if T::from(2).pow(i) <= num { num = num - T::from(2).pow(i); res.push(T::one()); @@ -407,4 +415,18 @@ mod tests { assert_eq!(res[248], Bn128Field::from(1)); assert_eq!(res[247], Bn128Field::from(0)); } + + #[test] + fn five_hundred_bits_of_1() { + let inputs = vec![Bn128Field::from(1)]; + let interpreter = Interpreter::default(); + let res = interpreter + .execute_solver(&Solver::Bits(500), &inputs) + .unwrap(); + + let mut expected = vec![Bn128Field::from(0); 500]; + expected[499] = Bn128Field::from(1); + + assert_eq!(res, expected); + } } diff --git a/zokrates_core/src/static_analysis/constant_argument_checker.rs b/zokrates_core/src/static_analysis/constant_argument_checker.rs new file mode 100644 index 00000000..9e9ddb39 --- /dev/null +++ b/zokrates_core/src/static_analysis/constant_argument_checker.rs @@ -0,0 +1,99 @@ +use crate::embed::FlatEmbed; +use crate::typed_absy::TypedProgram; +use crate::typed_absy::{ + result_folder::ResultFolder, + result_folder::{fold_expression_list_inner, fold_uint_expression_inner}, + ArrayExpressionInner, BooleanExpression, TypedExpression, TypedExpressionListInner, + TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner, +}; +use zokrates_field::Field; +pub struct ConstantArgumentChecker; + +impl ConstantArgumentChecker { + pub fn check(p: TypedProgram) -> Result, Error> { + ConstantArgumentChecker.fold_program(p) + } +} + +pub type Error = String; + +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { + type Error = Error; + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Error> { + match e { + UExpressionInner::LeftShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), + by => Err(format!( + "Cannot shift by a variable value, found `{} << {}`", + e, + by.clone().annotate(UBitwidth::B32) + )), + } + } + UExpressionInner::RightShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), + by => Err(format!( + "Cannot shift by a variable value, found `{} >> {}`", + e, + by.clone().annotate(UBitwidth::B32) + )), + } + } + e => fold_uint_expression_inner(self, bitwidth, e), + } + } + + fn fold_expression_list_inner( + &mut self, + tys: &Types<'ast, T>, + l: TypedExpressionListInner<'ast, T>, + ) -> Result, Error> { + match l { + TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => { + let arguments = arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::, _>>()?; + + match arguments[1] { + TypedExpression::Array(ref a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + if v.0.iter().all(|v| { + matches!( + v, + TypedExpressionOrSpread::Expression(TypedExpression::Boolean( + BooleanExpression::Value(_) + )) + ) + }) { + Ok(TypedExpressionListInner::EmbedCall( + FlatEmbed::BitArrayLe, + generics, + arguments, + )) + } else { + Err(format!("Cannot compare to a variable value, found `{}`", a)) + } + } + v => Err(format!("Cannot compare to a variable value, found `{}`", v)), + }, + _ => unreachable!(), + } + } + l => fold_expression_list_inner(self, tys, l), + } + } +} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index e66f1505..bf658118 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -5,21 +5,21 @@ //! @date 2018 mod branch_isolator; +mod constant_argument_checker; mod constant_inliner; mod flat_propagation; mod flatten_complex_types; mod propagation; mod reducer; -mod shift_checker; mod uint_optimizer; mod unconstrained_vars; mod variable_write_remover; use self::branch_isolator::Isolator; +use self::constant_argument_checker::ConstantArgumentChecker; use self::flatten_complex_types::Flattener; use self::propagation::Propagator; use self::reducer::reduce_program; -use self::shift_checker::ShiftChecker; use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; @@ -39,7 +39,7 @@ pub trait Analyse { pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), - NonConstantShift(self::shift_checker::Error), + NonConstantShift(self::constant_argument_checker::Error), } impl From for Error { @@ -54,8 +54,8 @@ impl From for Error { } } -impl From for Error { - fn from(e: shift_checker::Error) -> Self { +impl From for Error { + fn from(e: constant_argument_checker::Error) -> Self { Error::NonConstantShift(e) } } @@ -90,8 +90,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = Propagator::propagate(r).map_err(Error::from)?; // remove assignment to variable index let r = VariableWriteRemover::apply(r); - // detect non constant shifts - let r = ShiftChecker::check(r).map_err(Error::from)?; + // detect non constant shifts and constant lt bounds + let r = ConstantArgumentChecker::check(r).map_err(Error::from)?; // convert to zir, removing complex types let zir = Flattener::flatten(r); // optimize uint expressions diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 3799755a..34a9c3f3 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -502,6 +502,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { true => { let r: Option> = match embed { FlatEmbed::U32ToField => None, // todo + FlatEmbed::BitArrayLe => None, // todo FlatEmbed::U64FromBits => Some(process_u_from_bits( assignees.clone(), arguments.clone(), diff --git a/zokrates_core/src/static_analysis/shift_checker.rs b/zokrates_core/src/static_analysis/shift_checker.rs deleted file mode 100644 index 7e44ea52..00000000 --- a/zokrates_core/src/static_analysis/shift_checker.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::typed_absy::TypedProgram; -use crate::typed_absy::{ - result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth, - UExpressionInner, -}; -use zokrates_field::Field; -pub struct ShiftChecker; - -impl ShiftChecker { - pub fn check(p: TypedProgram) -> Result, Error> { - ShiftChecker.fold_program(p) - } -} - -pub type Error = String; - -impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker { - type Error = Error; - - fn fold_uint_expression_inner( - &mut self, - bitwidth: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> Result, Error> { - match e { - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - - match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), - by => Err(format!( - "Cannot shift by a variable value, found `{} << {}`", - e, - by.clone().annotate(UBitwidth::B32) - )), - } - } - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - - match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), - by => Err(format!( - "Cannot shift by a variable value, found `{} >> {}`", - e, - by.clone().annotate(UBitwidth::B32) - )), - } - } - e => fold_uint_expression_inner(self, bitwidth, e), - } - } -} diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok index 4e48909f..e31dece4 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok @@ -1,12 +1,12 @@ #pragma curve bn128 -import "./unpack" as unpack +import "./unpack_unchecked" // Unpack a field element as 256 big-endian bits // Note: uniqueness of the output is not guaranteed // For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)` def main(field i) -> bool[256]: - bool[254] b = unpack::<254>(i) + bool[254] b = unpack_unchecked::<254>(i) return [false, false, ...b] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok index d5b7a5cd..bc6d22d1 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok @@ -1,12 +1,12 @@ -#pragma curve bn128 - -from "EMBED" import unpack +import "./unpack_unchecked.zok" +from "field" import FIELD_SIZE_IN_BITS +from "EMBED" import bit_array_le // Unpack a field element as N big endian bits def main(field i) -> bool[N]: - - assert(N <= 254) - bool[N] res = unpack(i) + bool[N] res = unpack_unchecked(i) + + assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::(-1)]) else true fi) return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok index a24a244b..8f0b1203 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok @@ -1,9 +1,7 @@ -#pragma curve bn128 - import "./unpack" as unpack // Unpack a field element as 128 big-endian bits -// Precondition: the input is smaller or equal to `2**128 - 1` +// If the input is larger than `2**128 - 1`, the output is truncated. def main(field i) -> bool[128]: bool[128] res = unpack::<128>(i) return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok new file mode 100644 index 00000000..4c3e3e56 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok @@ -0,0 +1,7 @@ +import "./unpack" as unpack + +// Unpack a field element as 256 big-endian bits +// If the input is larger than `2**256 - 1`, the output is truncated. +def main(field i) -> bool[256]: + bool[256] res = unpack::<256>(i) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok new file mode 100644 index 00000000..2b0babbe --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok @@ -0,0 +1,9 @@ +from "EMBED" import unpack + +// Unpack a field element as N big endian bits without checking for overflows +// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either +def main(field i) -> bool[N]: + + bool[N] res = unpack(i) + + return res \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json new file mode 100644 index 00000000..5739811a --- /dev/null +++ b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/utils/pack/bool/unpack256.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok new file mode 100644 index 00000000..921ccb02 --- /dev/null +++ b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok @@ -0,0 +1,24 @@ +import "utils/pack/bool/unpack256" as unpack256 + +def testFive() -> bool: + + bool[256] b = unpack256(5) + + assert(b == [...[false; 253], true, false, true]) + + return true + +def testZero() -> bool: + + bool[256] b = unpack256(0) + + assert(b == [false; 256]) + + return true + + def main(): + + assert(testFive()) + assert(testZero()) + + return From ca6a3f09e8bc338e47657381fd581fec1df87080 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 20:08:23 +0200 Subject: [PATCH 02/15] apply more aggressive propagation to array values --- .../src/static_analysis/propagation.rs | 28 +++++++++++++++++++ zokrates_core/src/typed_absy/mod.rs | 4 +-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 34a9c3f3..384619b4 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1174,6 +1174,34 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }, None => Ok(ArrayExpressionInner::Identifier(id)), }, + ArrayExpressionInner::Value(exprs) => { + Ok(ArrayExpressionInner::Value( + exprs + .into_iter() + .map(|e| self.fold_expression_or_spread(e)) + .collect::, _>>()? + .into_iter() + .flat_map(|e| { + match e { + // simplify `...[a, b]` to `a, b` + TypedExpressionOrSpread::Spread(TypedSpread { + array: + ArrayExpression { + ty: _, + inner: ArrayExpressionInner::Value(v), + }, + }) => v.0, + e => vec![e], + } + }) + // ignore spreads over empty arrays + .filter_map(|e| match e { + TypedExpressionOrSpread::Spread(s) if s.array.size() == 0 => None, + e => Some(e), + }) + .collect(), + )) + } e => fold_array_expression_inner(self, ty, e), } } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index eb4f272e..07ce3c01 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -1019,8 +1019,8 @@ impl<'ast, T> From for BooleanExpression<'ast, T> { /// type checking #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub struct ArrayExpression<'ast, T> { - ty: Box>, - inner: ArrayExpressionInner<'ast, T>, + pub ty: Box>, + pub inner: ArrayExpressionInner<'ast, T>, } #[derive(Debug, PartialEq, Eq, Hash, Clone)] From eb8e55e137751b49dcaed43a3641dc3dd91d4b8e Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 20:33:10 +0200 Subject: [PATCH 03/15] adjust test --- .../src/static_analysis/reducer/mod.rs | 68 ------------------- 1 file changed, 68 deletions(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index ae26d6cd..84790f29 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -1273,13 +1273,8 @@ mod tests { // def main(): // # PUSH CALL to foo::<1> // # PUSH CALL to bar::<2> - // field[2] a_1 = [...[1]], 0] - // field[2] #RET_0_1 = a_1 // # POP CALL - // field[1] ret := #RET_0_1[0..1] - // field[1] #RET_0 = ret // # POP CALL - // field[1] b_0 := #RET_0 // return let foo_signature = DeclarationSignature::new() @@ -1452,71 +1447,8 @@ mod tests { .collect(), ), ), - TypedStatement::Definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32) - .into(), - ArrayExpressionInner::Value( - vec![ - TypedExpressionOrSpread::Spread( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(1)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - FieldElementExpression::Number(Bn128Field::from(0)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ), - TypedStatement::Definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(1), - Type::FieldElement, - 2u32, - ) - .into(), - ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 2u32) - .into(), - ), TypedStatement::PopCallLog, - TypedStatement::Definition( - Variable::array("ret", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Slice( - box ArrayExpressionInner::Identifier( - Identifier::from(CoreIdentifier::Call(0)).version(1), - ) - .annotate(Type::FieldElement, 2u32), - box 0u32.into(), - box 1u32.into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::Definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpressionInner::Identifier("ret".into()) - .annotate(Type::FieldElement, 1u32) - .into(), - ), TypedStatement::PopCallLog, - TypedStatement::Definition( - Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Identifier(Identifier::from(CoreIdentifier::Call(0))) - .annotate(Type::FieldElement, 1u32) - .into(), - ), TypedStatement::Return(vec![]), ], signature: DeclarationSignature::new(), From 1ca0b3b9e657cee599ca08d9823a0044e6114072 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 20:59:58 +0200 Subject: [PATCH 04/15] changelog, clean --- changelogs/unreleased/955-schaeff | 1 + zokrates_core/src/embed.rs | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 changelogs/unreleased/955-schaeff diff --git a/changelogs/unreleased/955-schaeff b/changelogs/unreleased/955-schaeff new file mode 100644 index 00000000..fff1c6e2 --- /dev/null +++ b/changelogs/unreleased/955-schaeff @@ -0,0 +1 @@ +Make the stdlib `unpack` function safe against overflows of bit decompositions for any size of output, introduce `unpack_unchecked` for cases that do not require determinism \ No newline at end of file diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index d0c34bab..3ee07660 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -197,7 +197,7 @@ impl FlatEmbed { pub fn id(&self) -> &'static str { match self { - &FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", + FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", FlatEmbed::U32ToField => "_U32_TO_FIELD", FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", @@ -475,9 +475,8 @@ fn use_variable( /// * bit_width the number of bits we want to decompose to /// /// # Remarks -/// * the return value of the `FlatFunction` is not deterministic if `bit_width == T::get_required_bits()` -/// as we decompose over `log_2(p) + 1 bits, some -/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` +/// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()` +/// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { let mut counter = 0; From 7c9e31f40b6f00c2aa5efe9d3cd5c3896cb064e2 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 21:20:45 +0200 Subject: [PATCH 05/15] inline repeats of constants --- .../src/static_analysis/propagation.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 384619b4..a4bc85b0 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1187,10 +1187,27 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { TypedExpressionOrSpread::Spread(TypedSpread { array: ArrayExpression { - ty: _, inner: ArrayExpressionInner::Value(v), + .. }, }) => v.0, + // simplify ...[a; N] to [a, ..., a] if a is constant + TypedExpressionOrSpread::Spread(TypedSpread { + array: + ArrayExpression { + inner: + ArrayExpressionInner::Repeat( + box v, + box UExpression { + inner: UExpressionInner::Value(count), + .. + }, + ), + .. + }, + }) if is_constant(&v) => { + vec![TypedExpressionOrSpread::Expression(v); count as usize] + } e => vec![e], } }) From dd126b63e0c96705fa21407daa2ad80c4f388aac Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 22:54:43 +0200 Subject: [PATCH 06/15] refactor is_constant to trait, move to typed_absy --- .../constant_argument_checker.rs | 37 ++-- .../src/static_analysis/propagation.rs | 147 +------------- zokrates_core/src/typed_absy/mod.rs | 191 ++++++++++++++++++ 3 files changed, 208 insertions(+), 167 deletions(-) diff --git a/zokrates_core/src/static_analysis/constant_argument_checker.rs b/zokrates_core/src/static_analysis/constant_argument_checker.rs index 9e9ddb39..91ec166e 100644 --- a/zokrates_core/src/static_analysis/constant_argument_checker.rs +++ b/zokrates_core/src/static_analysis/constant_argument_checker.rs @@ -3,8 +3,7 @@ use crate::typed_absy::TypedProgram; use crate::typed_absy::{ result_folder::ResultFolder, result_folder::{fold_expression_list_inner, fold_uint_expression_inner}, - ArrayExpressionInner, BooleanExpression, TypedExpression, TypedExpressionListInner, - TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner, + Constant, TypedExpressionListInner, Types, UBitwidth, UExpressionInner, }; use zokrates_field::Field; pub struct ConstantArgumentChecker; @@ -68,29 +67,17 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { .map(|a| self.fold_expression(a)) .collect::, _>>()?; - match arguments[1] { - TypedExpression::Array(ref a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => { - if v.0.iter().all(|v| { - matches!( - v, - TypedExpressionOrSpread::Expression(TypedExpression::Boolean( - BooleanExpression::Value(_) - )) - ) - }) { - Ok(TypedExpressionListInner::EmbedCall( - FlatEmbed::BitArrayLe, - generics, - arguments, - )) - } else { - Err(format!("Cannot compare to a variable value, found `{}`", a)) - } - } - v => Err(format!("Cannot compare to a variable value, found `{}`", v)), - }, - _ => unreachable!(), + if arguments[1].is_constant() { + Ok(TypedExpressionListInner::EmbedCall( + FlatEmbed::BitArrayLe, + generics, + arguments, + )) + } else { + Err(format!( + "Cannot compare to a variable value, found `{}`", + arguments[1] + )) } } l => fold_expression_list_inner(self, tys, l), diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index a4bc85b0..cbb2c427 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -124,126 +124,6 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { } } -fn is_constant(e: &TypedExpression) -> bool { - match e { - TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true, - TypedExpression::Boolean(BooleanExpression::Value(..)) => true, - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { - TypedExpressionOrSpread::Expression(e) => is_constant(e), - _ => false, - }), - ArrayExpressionInner::Slice(box a, box from, box to) => { - is_constant(&from.clone().into()) - && is_constant(&to.clone().into()) - && is_constant(&a.clone().into()) - } - ArrayExpressionInner::Repeat(box e, box count) => { - is_constant(&count.clone().into()) && is_constant(&e) - } - _ => false, - }, - TypedExpression::Struct(a) => match a.as_inner() { - StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), - _ => false, - }, - TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)), - _ => false, - } -} - -// in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc -fn to_canonical_constant(e: TypedExpression) -> TypedExpression { - fn to_canonical_constant_aux( - e: TypedExpressionOrSpread, - ) -> Vec> { - match e { - TypedExpressionOrSpread::Expression(e) => vec![e], - TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() { - ArrayExpressionInner::Value(v) => { - v.into_iter().flat_map(to_canonical_constant_aux).collect() - } - _ => unimplemented!(), - }, - } - } - - match e { - TypedExpression::Array(a) => { - let array_ty = a.ty(); - - match a.into_inner() { - ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( - v.into_iter() - .flat_map(to_canonical_constant_aux) - .map(|e| e.into()) - .collect::>() - .into(), - ) - .annotate(*array_ty.ty, array_ty.size) - .into(), - ArrayExpressionInner::Slice(box a, box from, box to) => { - let from = match from.into_inner() { - UExpressionInner::Value(from) => from as usize, - _ => unreachable!("should be a uint value"), - }; - - let to = match to.into_inner() { - UExpressionInner::Value(to) => to as usize, - _ => unreachable!("should be a uint value"), - }; - - let v = match a.into_inner() { - ArrayExpressionInner::Value(v) => v, - _ => unreachable!("should be an array value"), - }; - - ArrayExpressionInner::Value( - v.into_iter() - .flat_map(to_canonical_constant_aux) - .map(|e| e.into()) - .enumerate() - .filter(|(index, _)| index >= &from && index < &to) - .map(|(_, e)| e) - .collect::>() - .into(), - ) - .annotate(*array_ty.ty, array_ty.size) - .into() - } - ArrayExpressionInner::Repeat(box e, box count) => { - let count = match count.into_inner() { - UExpressionInner::Value(from) => from as usize, - _ => unreachable!("should be a uint value"), - }; - - let e = to_canonical_constant(e); - - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression(e); count].into(), - ) - .annotate(*array_ty.ty, array_ty.size) - .into() - } - _ => unreachable!(), - } - } - TypedExpression::Struct(s) => { - let struct_ty = s.ty().clone(); - - match s.into_inner() { - StructExpressionInner::Value(expressions) => StructExpressionInner::Value( - expressions.into_iter().map(to_canonical_constant).collect(), - ) - .annotate(struct_ty) - .into(), - _ => unreachable!(), - } - } - e => e, - } -} - impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { type Error = Error; @@ -341,10 +221,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } }; - if is_constant(&expr) { + if expr.is_constant() { match assignee { TypedAssignee::Identifier(var) => { - let expr = to_canonical_constant(expr); + let expr = expr.into_canonical_constant(); assert!(self.constants.insert(var.id, expr).is_none()); @@ -352,7 +232,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } assignee => match self.try_get_constant_mut(&assignee) { Ok((_, c)) => { - *c = to_canonical_constant(expr); + *c = expr.into_canonical_constant(); Ok(vec![]) } Err(v) => match self.constants.remove(&v.id) { @@ -423,7 +303,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { let argument = arguments.pop().unwrap(); - let argument = to_canonical_constant(argument); + let argument = argument.into_canonical_constant(); match ArrayExpression::try_from(argument) .unwrap() @@ -498,7 +378,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } } - match arguments.iter().all(|a| is_constant(a)) { + match arguments.iter().all(|a| a.is_constant()) { true => { let r: Option> = match embed { FlatEmbed::U32ToField => None, // todo @@ -1191,23 +1071,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { .. }, }) => v.0, - // simplify ...[a; N] to [a, ..., a] if a is constant - TypedExpressionOrSpread::Spread(TypedSpread { - array: - ArrayExpression { - inner: - ArrayExpressionInner::Repeat( - box v, - box UExpression { - inner: UExpressionInner::Value(count), - .. - }, - ), - .. - }, - }) if is_constant(&v) => { - vec![TypedExpressionOrSpread::Expression(v); count as usize] - } e => vec![e], } }) diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 07ce3c01..e4f500b5 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -1970,3 +1970,194 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> { StructExpressionInner::Block(BlockExpression::new(statements, value)).annotate(struct_ty) } } + +pub trait Constant: Sized { + fn is_constant(&self) -> bool; + + fn into_canonical_constant(self) -> Self { + self + } +} + +impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self, FieldElementExpression::Number(..)) + } +} + +impl<'ast, T: Field> Constant for BooleanExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self, BooleanExpression::Value(..)) + } +} + +impl<'ast, T: Field> Constant for UExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self.as_inner(), UExpressionInner::Value(..)) + } +} + +impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self.as_inner() { + ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { + TypedExpressionOrSpread::Expression(e) => e.is_constant(), + TypedExpressionOrSpread::Spread(s) => s.array.is_constant(), + }), + ArrayExpressionInner::Slice(box a, box from, box to) => { + from.is_constant() && to.is_constant() && a.is_constant() + } + ArrayExpressionInner::Repeat(box e, box count) => { + count.is_constant() && e.is_constant() + } + _ => false, + } + } + + fn into_canonical_constant(self) -> Self { + fn into_canonical_constant_aux( + e: TypedExpressionOrSpread, + ) -> Vec> { + match e { + TypedExpressionOrSpread::Expression(e) => vec![e], + TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() { + ArrayExpressionInner::Value(v) => v + .into_iter() + .flat_map(into_canonical_constant_aux) + .collect(), + ArrayExpressionInner::Slice(box v, box from, box to) => { + let from = match from.into_inner() { + UExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + let to = match to.into_inner() { + UExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + let v = match v.into_inner() { + ArrayExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + v.into_iter() + .flat_map(into_canonical_constant_aux) + .skip(from as usize) + .take(to as usize - from as usize) + .collect() + } + a => unreachable!("{}", a), + }, + } + } + + let array_ty = self.ty(); + + match self.into_inner() { + ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( + v.into_iter() + .flat_map(into_canonical_constant_aux) + .map(|e| e.into()) + .collect::>() + .into(), + ) + .annotate(*array_ty.ty, array_ty.size) + .into(), + ArrayExpressionInner::Slice(box a, box from, box to) => { + let from = match from.into_inner() { + UExpressionInner::Value(from) => from as usize, + _ => unreachable!("should be a uint value"), + }; + + let to = match to.into_inner() { + UExpressionInner::Value(to) => to as usize, + _ => unreachable!("should be a uint value"), + }; + + let v = match a.into_inner() { + ArrayExpressionInner::Value(v) => v, + _ => unreachable!("should be an array value"), + }; + + ArrayExpressionInner::Value( + v.into_iter() + .flat_map(into_canonical_constant_aux) + .map(|e| e.into()) + .enumerate() + .filter(|(index, _)| index >= &from && index < &to) + .map(|(_, e)| e) + .collect::>() + .into(), + ) + .annotate(*array_ty.ty, array_ty.size) + .into() + } + ArrayExpressionInner::Repeat(box e, box count) => { + let count = match count.into_inner() { + UExpressionInner::Value(from) => from as usize, + _ => unreachable!("should be a uint value"), + }; + + let e = e.into_canonical_constant(); + + ArrayExpressionInner::Value( + vec![TypedExpressionOrSpread::Expression(e); count].into(), + ) + .annotate(*array_ty.ty, array_ty.size) + .into() + } + _ => unreachable!(), + } + } +} + +impl<'ast, T: Field> Constant for StructExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self.as_inner() { + StructExpressionInner::Value(v) => v.iter().all(|e| e.is_constant()), + _ => false, + } + } + + fn into_canonical_constant(self) -> Self { + let struct_ty = self.ty().clone(); + + match self.into_inner() { + StructExpressionInner::Value(expressions) => StructExpressionInner::Value( + expressions + .into_iter() + .map(|e| e.into_canonical_constant()) + .collect(), + ) + .annotate(struct_ty) + .into(), + _ => unreachable!(), + } + } +} + +impl<'ast, T: Field> Constant for TypedExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self { + TypedExpression::FieldElement(e) => e.is_constant(), + TypedExpression::Boolean(e) => e.is_constant(), + TypedExpression::Array(e) => e.is_constant(), + TypedExpression::Struct(e) => e.is_constant(), + TypedExpression::Uint(e) => e.is_constant(), + _ => unreachable!(), + } + } + + // in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc + fn into_canonical_constant(self) -> Self { + match self { + TypedExpression::FieldElement(e) => e.into_canonical_constant().into(), + TypedExpression::Boolean(e) => e.into_canonical_constant().into(), + TypedExpression::Array(e) => e.into_canonical_constant().into(), + TypedExpression::Struct(e) => e.into_canonical_constant().into(), + TypedExpression::Uint(e) => e.into_canonical_constant().into(), + _ => unreachable!(), + } + } +} From ea0594035a495be61a9f357df2a61b548667cbe4 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 23:00:10 +0200 Subject: [PATCH 07/15] clippy --- zokrates_core/src/static_analysis/uint_optimizer.rs | 4 ++-- zokrates_core/src/typed_absy/mod.rs | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index ecca499a..232c4424 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -738,7 +738,7 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::right_shift(left.clone(), right.clone())), + .fold_uint_expression(UExpression::right_shift(left.clone(), right)), UExpression::right_shift(left_expected, right_expected).with_max(output_max) ); } @@ -761,7 +761,7 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())), + .fold_uint_expression(UExpression::left_shift(left.clone(), right)), UExpression::left_shift(left_expected, right_expected).with_max(output_max) ); } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index e4f500b5..07f18a63 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -2062,8 +2062,7 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .collect::>() .into(), ) - .annotate(*array_ty.ty, array_ty.size) - .into(), + .annotate(*array_ty.ty, array_ty.size), ArrayExpressionInner::Slice(box a, box from, box to) => { let from = match from.into_inner() { UExpressionInner::Value(from) => from as usize, @@ -2091,7 +2090,6 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .into(), ) .annotate(*array_ty.ty, array_ty.size) - .into() } ArrayExpressionInner::Repeat(box e, box count) => { let count = match count.into_inner() { @@ -2105,7 +2103,6 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { vec![TypedExpressionOrSpread::Expression(e); count].into(), ) .annotate(*array_ty.ty, array_ty.size) - .into() } _ => unreachable!(), } @@ -2130,8 +2127,7 @@ impl<'ast, T: Field> Constant for StructExpression<'ast, T> { .map(|e| e.into_canonical_constant()) .collect(), ) - .annotate(struct_ty) - .into(), + .annotate(struct_ty), _ => unreachable!(), } } From 8663ea2bcaac955006c2d701861fdb895c399e62 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 23:52:02 +0200 Subject: [PATCH 08/15] implement repeat case --- zokrates_core/src/typed_absy/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 07f18a63..d262df05 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -2047,6 +2047,14 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .take(to as usize - from as usize) .collect() } + ArrayExpressionInner::Repeat(box e, box count) => { + let count = match count.into_inner() { + UExpressionInner::Value(count) => count, + _ => unreachable!(), + }; + + vec![e; count as usize] + } a => unreachable!("{}", a), }, } From 6e19e7754c4b2b95c8b1e17a0cf540a24b1b4f79 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 11:46:07 +0200 Subject: [PATCH 09/15] clean typed_absy, move common impl to expr trait --- .../static_analysis/flatten_complex_types.rs | 2 +- zokrates_core/src/typed_absy/integer.rs | 4 +- zokrates_core/src/typed_absy/mod.rs | 112 +++++++++++------- 3 files changed, 71 insertions(+), 47 deletions(-) diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 4e2294af..ea024055 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -570,7 +570,7 @@ fn fold_select_expression<'ast, T: Field, E>( statements_buffer: &mut Vec>, select: typed_absy::SelectExpression<'ast, T, E>, ) -> Vec> { - let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().ty) + let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().clone().ty) .unwrap() .get_primitive_count(); diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index 851acf7a..87ea87d2 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -1,7 +1,7 @@ use crate::typed_absy::types::{ArrayType, Type}; use crate::typed_absy::UBitwidth; use crate::typed_absy::{ - ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse, + ArrayExpression, ArrayExpressionInner, BooleanExpression, Expr, FieldElementExpression, IfElse, IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, }; @@ -477,7 +477,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { array: Self, target_inner_ty: Type<'ast, T>, ) -> Result> { - let array_ty = array.ty(); + let array_ty = array.ty().clone(); // elements must fit in the target type match array.into_inner() { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index d262df05..07220a22 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -1139,25 +1139,6 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { pub fn size(&self) -> UExpression<'ast, T> { self.ty.size.clone() } - - pub fn as_inner(&self) -> &ArrayExpressionInner<'ast, T> { - &self.inner - } - - pub fn as_inner_mut(&mut self) -> &mut ArrayExpressionInner<'ast, T> { - &mut self.inner - } - - pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> { - self.inner - } - - pub fn ty(&self) -> ArrayType<'ast, T> { - ArrayType { - size: self.size(), - ty: box self.inner_type().clone(), - } - } } #[derive(Clone, PartialEq, Debug, Hash, Eq)] @@ -1184,24 +1165,6 @@ impl<'ast, T: Field> StructExpression<'ast, T> { } } -impl<'ast, T> StructExpression<'ast, T> { - pub fn ty(&self) -> &StructType<'ast, T> { - &self.ty - } - - pub fn as_inner(&self) -> &StructExpressionInner<'ast, T> { - &self.inner - } - - pub fn as_inner_mut(&mut self) -> &mut StructExpressionInner<'ast, T> { - &mut self.inner - } - - pub fn into_inner(self) -> StructExpressionInner<'ast, T> { - self.inner - } -} - #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub enum StructExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, StructExpression<'ast, T>>), @@ -1503,15 +1466,23 @@ pub trait Expr<'ast, T>: From> { type Inner; type Ty: Clone + IntoTypes<'ast, T>; + fn ty(&self) -> &Self::Ty; + fn into_inner(self) -> Self::Inner; fn as_inner(&self) -> &Self::Inner; + + fn as_inner_mut(&mut self) -> &mut Self::Inner; } impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; + fn ty(&self) -> &Self::Ty { + &Type::FieldElement + } + fn into_inner(self) -> Self::Inner { self } @@ -1519,12 +1490,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + self + } } impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; + fn ty(&self) -> &Self::Ty { + &Type::Boolean + } + fn into_inner(self) -> Self::Inner { self } @@ -1532,12 +1511,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + self + } } impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> { type Inner = UExpressionInner<'ast, T>; type Ty = UBitwidth; + fn ty(&self) -> &Self::Ty { + &self.bitwidth + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1545,12 +1532,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> { type Inner = StructExpressionInner<'ast, T>; type Ty = StructType<'ast, T>; + fn ty(&self) -> &Self::Ty { + &self.ty + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1558,12 +1553,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> { type Inner = ArrayExpressionInner<'ast, T>; type Ty = ArrayType<'ast, T>; + fn ty(&self) -> &Self::Ty { + &self.ty + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1571,12 +1574,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; + fn ty(&self) -> &Self::Ty { + &Type::Int + } + fn into_inner(self) -> Self::Inner { self } @@ -1584,12 +1595,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + self + } } impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> { type Inner = TypedExpressionListInner<'ast, T>; type Ty = Types<'ast, T>; + fn ty(&self) -> &Self::Ty { + &self.types + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1597,6 +1616,10 @@ impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } // Enums types to enable returning e.g a member expression OR another type of expression of this type @@ -1777,7 +1800,7 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> { } } -impl<'ast, T> Member<'ast, T> for UExpression<'ast, T> { +impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { let ty = s.ty().members.iter().find(|member| id == member.id); let bitwidth = match ty { @@ -1957,7 +1980,7 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { - let array_ty = value.ty(); + let array_ty = value.ty().clone(); ArrayExpressionInner::Block(BlockExpression::new(statements, value)) .annotate(*array_ty.ty, array_ty.size) } @@ -1972,8 +1995,11 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> { } pub trait Constant: Sized { + // return whether this is constant fn is_constant(&self) -> bool; + // canonicalize an expression *that we know to be constant* + // for example for [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc fn into_canonical_constant(self) -> Self { self } @@ -2060,7 +2086,7 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { } } - let array_ty = self.ty(); + let array_ty = self.ty().clone(); match self.into_inner() { ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( @@ -2091,9 +2117,8 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { v.into_iter() .flat_map(into_canonical_constant_aux) .map(|e| e.into()) - .enumerate() - .filter(|(index, _)| index >= &from && index < &to) - .map(|(_, e)| e) + .skip(from) + .take(to - from) .collect::>() .into(), ) @@ -2153,7 +2178,6 @@ impl<'ast, T: Field> Constant for TypedExpression<'ast, T> { } } - // in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc fn into_canonical_constant(self) -> Self { match self { TypedExpression::FieldElement(e) => e.into_canonical_constant().into(), From 33c8fba1e1bce86ef85c1f64b9678f75d0140d60 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 13:36:48 +0200 Subject: [PATCH 10/15] add comments, use iterators --- zokrates_core/src/flatten/mod.rs | 8 +++++++- zokrates_core/src/ir/interpreter.rs | 27 ++++++++++++--------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 1d230eae..e096f9b2 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1161,13 +1161,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] } crate::embed::FlatEmbed::BitArrayLe => { + // get the length of the bit arrays let len = generics[0]; + // split the arguments into the two bit arrays of size `len` let (expressions, constants) = ( param_expressions[..len as usize].to_vec(), param_expressions[len as usize..].to_vec(), ); + // define variables for the variable bits let variables: Vec<_> = expressions .into_iter() .map(|e| { @@ -1178,6 +1181,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { }) .collect(); + // get constants for the constant bits let constants: Vec<_> = constants .into_iter() .map(|e| { @@ -1185,11 +1189,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { .get_field_unchecked() }) .map(|e| match e { - FlatExpression::Number(n) => n == T::one(), + FlatExpression::Number(n) if n == T::one() => true, + FlatExpression::Number(n) if n == T::zero() => false, _ => unreachable!(), }) .collect(); + // get the list of conditions which must hold iff the `<=` relation holds let conditions = self.constant_le_check(statements_flattened, &variables, &constants); diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index c6b75dae..b68574f3 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -160,22 +160,19 @@ impl Interpreter { let bit_width = bit_width - padding; - let mut num = inputs[0].clone(); - let mut res = vec![]; + let num = inputs[0].clone(); - for _ in 0..padding { - res.push(T::zero()); - } - - for i in (0..bit_width).rev() { - if T::from(2).pow(i) <= num { - num = num - T::from(2).pow(i); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - res + (0..padding) + .map(|_| T::zero()) + .chain((0..bit_width).rev().scan(num, |state, i| { + if T::from(2).pow(i) <= *state { + *state = (*state).clone() - T::from(2).pow(i); + Some(T::one()) + } else { + Some(T::zero()) + } + })) + .collect() } Solver::Xor => { let x = inputs[0].clone(); From 63be983d745a8551b28642367dcd71319291223f Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 15:22:46 +0200 Subject: [PATCH 11/15] add tests, tweak out of range interpreter --- Cargo.lock | 1 + zokrates_core/Cargo.toml | 1 + zokrates_core/src/ir/interpreter.rs | 51 +++++----- zokrates_core/src/static_analysis/mod.rs | 6 +- zokrates_core/tests/out_of_range.rs | 117 ++++++++++++++++++++++- zokrates_fs_resolver/src/lib.rs | 13 ++- 6 files changed, 151 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c87c1802..cf5b77b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2394,6 +2394,7 @@ dependencies = [ "zokrates_common", "zokrates_embed", "zokrates_field", + "zokrates_fs_resolver", "zokrates_pest_ast", ] diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index e81f04fb..91b31714 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -59,6 +59,7 @@ sha2 = { version = "0.9.3", optional = true } [dev-dependencies] wasm-bindgen-test = "^0.3.0" pretty_assertions = "0.6.1" +zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"} [build-dependencies] cc = { version = "1.0", features = ["parallel"], optional = true } diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index b68574f3..ef344e91 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -65,33 +65,27 @@ impl Interpreter { } } }, - Statement::Directive(ref d) => { - match (&d.solver, &d.inputs, self.should_try_out_of_range) { - (Solver::Bits(bitwidth), inputs, true) - if inputs[0].left.0.len() > 1 - || inputs[0].right.0.len() > 1 - && *bitwidth == T::get_required_bits() => - { - Self::try_solve_out_of_range(&d, &mut witness) - } - _ => { - let inputs: Vec<_> = d - .inputs - .iter() - .map(|i| i.evaluate(&witness).unwrap()) - .collect(); - match self.execute_solver(&d.solver, &inputs) { - Ok(res) => { - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } - continue; - } - Err(_) => return Err(Error::Solver), - }; - } + Statement::Directive(ref d) => match (&d.solver, self.should_try_out_of_range) { + (Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => { + Self::try_solve_out_of_range(&d, &mut witness) } - } + _ => { + let inputs: Vec<_> = d + .inputs + .iter() + .map(|i| i.evaluate(&witness).unwrap()) + .collect(); + match self.execute_solver(&d.solver, &inputs) { + Ok(res) => { + for (i, o) in d.outputs.iter().enumerate() { + witness.insert(*o, res[i].clone()); + } + continue; + } + Err(_) => return Err(Error::Solver), + }; + } + }, } } @@ -104,7 +98,9 @@ impl Interpreter { // we target the `2a - 2b` part of the `<` check by only returning out-of-range results // when the input is not a single summand let value = d.inputs[0].evaluate(&witness).unwrap(); + let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); + let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) { candidate } else { @@ -123,6 +119,9 @@ impl Interpreter { } } assert_eq!(num, T::zero().to_biguint()); + + println!("RES {:?}", res); + for (i, o) in d.outputs.iter().enumerate() { witness.insert(*o, res[i].clone()); } diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index bf658118..b2dab63e 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -39,7 +39,7 @@ pub trait Analyse { pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), - NonConstantShift(self::constant_argument_checker::Error), + NonConstantArgument(self::constant_argument_checker::Error), } impl From for Error { @@ -56,7 +56,7 @@ impl From for Error { impl From for Error { fn from(e: constant_argument_checker::Error) -> Self { - Error::NonConstantShift(e) + Error::NonConstantArgument(e) } } @@ -65,7 +65,7 @@ impl fmt::Display for Error { match self { Error::Reducer(e) => write!(f, "{}", e), Error::Propagation(e) => write!(f, "{}", e), - Error::NonConstantShift(e) => write!(f, "{}", e), + Error::NonConstantArgument(e) => write!(f, "{}", e), } } } diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index 6d2f19ab..95fd5497 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -10,9 +10,10 @@ use zokrates_core::{ ir::Interpreter, }; use zokrates_field::Bn128Field; +use zokrates_fs_resolver::FileSystemResolver; #[test] -fn out_of_range() { +fn lt_field() { let source = r#" def main(private field a, private field b) -> field: field x = if a < b then 3333 else 4444 fi @@ -21,7 +22,7 @@ fn out_of_range() { "# .to_string(); - // let's try to prove that "10000 < 5555" is true by exploiting + // let's try to prove that "10000f < 5555f" is true by exploiting // the fact that `2*10000 - 2*5555` has two distinct bit decompositions // we chose the one which is out of range, ie the sum check features an overflow @@ -42,3 +43,115 @@ fn out_of_range() { ) .is_err()); } + +#[test] +fn lt_uint() { + let source = r#" + def main(private u32 a, private u32 b): + field x = if a < b then 3333 else 4444 fi + assert(x == 3333) + return + "# + .to_string(); + + // let's try to prove that "10000u32 < 5555u32" is true by exploiting + // the fact that `2*10000 - 2*5555` has two distinct bit decompositions + // we chose the one which is out of range, ie the sum check features an overflow + + let res: CompilationArtifacts = compile( + source, + "./path/to/file".into(), + None::<&dyn Resolver>, + &CompileConfig::default(), + ) + .unwrap(); + + let interpreter = Interpreter::try_out_of_range(); + + assert!(interpreter + .execute( + &res.prog(), + &[Bn128Field::from(10000), Bn128Field::from(5555)] + ) + .is_err()); +} + +#[test] +fn unpack256() { + let source = r#" + import "utils/pack/bool/unpack256" + + def main(private field a): + bool[256] bits = unpack256(a) + assert(bits[255]) + return + "# + .to_string(); + + // let's try to prove that the least significant bit of 0 is 1 + // we exploit the fact that the bits of 0 are the bits of p, and p is even + // we want this to still fail + + let stdlib_path = std::fs::canonicalize( + std::env::current_dir() + .unwrap() + .join("../zokrates_stdlib/stdlib"), + ) + .unwrap(); + + let res: CompilationArtifacts = compile( + source, + "./path/to/file".into(), + Some(&FileSystemResolver::with_stdlib_root( + stdlib_path.to_str().unwrap(), + )), + &CompileConfig::default(), + ) + .unwrap(); + + let interpreter = Interpreter::try_out_of_range(); + + assert!(interpreter + .execute(&res.prog(), &[Bn128Field::from(0)]) + .is_err()); +} + +#[test] +fn unpack256_unchecked() { + let source = r#" + import "utils/pack/bool/nonStrictUnpack256" + + def main(private field a): + bool[256] bits = nonStrictUnpack256(a) + assert(bits[255]) + return + "# + .to_string(); + + // let's try to prove that the least significant bit of 0 is 1 + // we exploit the fact that the bits of 0 are the bits of p, and p is even + // we want this to succeed as the non strict version does not enforce the bits to be in range + + let stdlib_path = std::fs::canonicalize( + std::env::current_dir() + .unwrap() + .join("../zokrates_stdlib/stdlib"), + ) + .unwrap(); + + let res: CompilationArtifacts = compile( + source, + "./path/to/file".into(), + Some(&FileSystemResolver::with_stdlib_root( + stdlib_path.to_str().unwrap(), + )), + &CompileConfig::default(), + ) + .unwrap(); + + let interpreter = Interpreter::try_out_of_range(); + + assert!(interpreter + .execute(&res.prog(), &[Bn128Field::from(0)]) + .is_ok()); +} diff --git a/zokrates_fs_resolver/src/lib.rs b/zokrates_fs_resolver/src/lib.rs index b38a0b17..1289bbe3 100644 --- a/zokrates_fs_resolver/src/lib.rs +++ b/zokrates_fs_resolver/src/lib.rs @@ -26,17 +26,16 @@ impl<'a> Resolver for FileSystemResolver<'a> { ) -> Result<(String, PathBuf), io::Error> { let source = Path::new(&import_location); - if !current_location.is_file() { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("{} was expected to be a file", current_location.display()), - )); - } - // paths starting with `./` or `../` are interpreted relative to the current file // other paths `abc/def` are interpreted relative to the standard library root path let base = match source.components().next() { Some(Component::CurDir) | Some(Component::ParentDir) => { + if !current_location.is_file() { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("{} was expected to be a file", current_location.display()), + )); + } current_location.parent().unwrap().into() } _ => PathBuf::from(self.stdlib_root_path.unwrap_or("")), From 7f76505a86771ad3f131c7651d11ab8c205376a4 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 16:03:57 +0200 Subject: [PATCH 12/15] refactor out of range interpreter, accept any size of output --- zokrates_core/src/flatten/mod.rs | 3 +- zokrates_core/src/ir/interpreter.rs | 86 +++++++++------------ zokrates_core/src/optimizer/redefinition.rs | 4 +- zokrates_core/tests/out_of_range.rs | 2 +- 4 files changed, 40 insertions(+), 55 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index e096f9b2..f44c5364 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1982,8 +1982,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::default() - .execute_solver(&Solver::bits(to), &[x.clone()]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) .unwrap() .into_iter() .map(FlatExpression::Number) diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index ef344e91..510c89b2 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -1,5 +1,4 @@ use crate::flat_absy::flat_variable::FlatVariable; -use crate::ir::Directive; use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness}; use crate::solvers::Solver; use serde::{Deserialize, Serialize}; @@ -65,66 +64,59 @@ impl Interpreter { } } }, - Statement::Directive(ref d) => match (&d.solver, self.should_try_out_of_range) { - (Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => { - Self::try_solve_out_of_range(&d, &mut witness) + Statement::Directive(ref d) => { + let mut inputs: Vec<_> = d + .inputs + .iter() + .map(|i| i.evaluate(&witness).unwrap()) + .collect(); + + let res = match (&d.solver, self.should_try_out_of_range) { + (Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => { + Ok(Self::try_solve_with_out_of_range_bits( + *bitwidth, + inputs.pop().unwrap(), + )) + } + _ => Self::execute_solver(&d.solver, &inputs), } - _ => { - let inputs: Vec<_> = d - .inputs - .iter() - .map(|i| i.evaluate(&witness).unwrap()) - .collect(); - match self.execute_solver(&d.solver, &inputs) { - Ok(res) => { - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } - continue; - } - Err(_) => return Err(Error::Solver), - }; + .map_err(|_| Error::Solver)?; + + for (i, o) in d.outputs.iter().enumerate() { + witness.insert(*o, res[i].clone()); } - }, + } } } Ok(Witness(witness)) } - fn try_solve_out_of_range(d: &Directive, witness: &mut BTreeMap) { + fn try_solve_with_out_of_range_bits(bit_width: usize, input: T) -> Vec { use num::traits::Pow; + use num_bigint::BigUint; - // we target the `2a - 2b` part of the `<` check by only returning out-of-range results - // when the input is not a single summand - let value = d.inputs[0].evaluate(&witness).unwrap(); - - let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); + let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) { candidate } else { - value.to_biguint() + input.to_biguint() }; - let mut num = input; - let mut res = vec![]; - let bits = T::get_required_bits(); - for i in (0..bits).rev() { - if T::from(2).to_biguint().pow(i as usize) <= num { - num -= T::from(2).to_biguint().pow(i as usize); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - assert_eq!(num, T::zero().to_biguint()); + let padding = bit_width - T::get_required_bits(); - println!("RES {:?}", res); - - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } + (0..padding) + .map(|_| T::zero()) + .chain((0..T::get_required_bits()).rev().scan(input, |state, i| { + if BigUint::from(2usize).pow(i) <= *state { + *state = (*state).clone() - BigUint::from(2usize).pow(i); + Some(T::one()) + } else { + Some(T::zero()) + } + })) + .collect() } fn check_inputs(&self, program: &Prog, inputs: &[U]) -> Result<(), Error> { @@ -138,11 +130,7 @@ impl Interpreter { } } - pub fn execute_solver( - &self, - solver: &Solver, - inputs: &[T], - ) -> Result, String> { + pub fn execute_solver(solver: &Solver, inputs: &[T]) -> Result, String> { let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index 6b197b1e..4222281c 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -140,9 +140,7 @@ impl Folder for RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::default() - .execute_solver(&d.solver, &inputs) - .unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index 95fd5497..c27e2361 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -129,7 +129,7 @@ fn unpack256_unchecked() { .to_string(); // let's try to prove that the least significant bit of 0 is 1 - // we exploit the fact that the bits of 0 are the bits of p, and p is even + // we exploit the fact that the bits of 0 are the bits of p, and p is odd // we want this to succeed as the non strict version does not enforce the bits to be in range let stdlib_path = std::fs::canonicalize( From a7cff74ea8a82470c4fc7f5327b6431a573d92f3 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 16:13:15 +0200 Subject: [PATCH 13/15] fix tests --- zokrates_core/src/ir/interpreter.rs | 55 ++++++++++++----------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 510c89b2..580e5ee8 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -338,16 +338,14 @@ mod tests { fn execute() { let cond_eq = Solver::ConditionEq; let inputs = vec![0]; - let interpreter = Interpreter::default(); - let r = interpreter - .execute_solver( - &cond_eq, - &inputs - .iter() - .map(|&i| Bn128Field::from(i)) - .collect::>(), - ) - .unwrap(); + let r = Interpreter::execute_solver( + &cond_eq, + &inputs + .iter() + .map(|&i| Bn128Field::from(i)) + .collect::>(), + ) + .unwrap(); let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); assert_eq!(r, &res[..]); } @@ -356,16 +354,14 @@ mod tests { fn execute_non_eq() { let cond_eq = Solver::ConditionEq; let inputs = vec![1]; - let interpreter = Interpreter::default(); - let r = interpreter - .execute_solver( - &cond_eq, - &inputs - .iter() - .map(|&i| Bn128Field::from(i)) - .collect::>(), - ) - .unwrap(); + let r = Interpreter::execute_solver( + &cond_eq, + &inputs + .iter() + .map(|&i| Bn128Field::from(i)) + .collect::>(), + ) + .unwrap(); let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); assert_eq!(r, &res[..]); } @@ -374,10 +370,9 @@ mod tests { #[test] fn bits_of_one() { let inputs = vec![Bn128Field::from(1)]; - let interpreter = Interpreter::default(); - let res = interpreter - .execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = + Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) + .unwrap(); assert_eq!(res[253], Bn128Field::from(1)); for r in &res[0..253] { assert_eq!(*r, Bn128Field::from(0)); @@ -387,10 +382,9 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![Bn128Field::from(42)]; - let interpreter = Interpreter::default(); - let res = interpreter - .execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = + Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) + .unwrap(); assert_eq!(res[253], Bn128Field::from(0)); assert_eq!(res[252], Bn128Field::from(1)); assert_eq!(res[251], Bn128Field::from(0)); @@ -403,10 +397,7 @@ mod tests { #[test] fn five_hundred_bits_of_1() { let inputs = vec![Bn128Field::from(1)]; - let interpreter = Interpreter::default(); - let res = interpreter - .execute_solver(&Solver::Bits(500), &inputs) - .unwrap(); + let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap(); let mut expected = vec![Bn128Field::from(0); 500]; expected[499] = Bn128Field::from(1); From de50a7a0d1171efa3a194c37e3f1ea436ecba357 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 9 Aug 2021 12:02:50 +0200 Subject: [PATCH 14/15] fix test --- zokrates_core/tests/out_of_range.rs | 32 ----------------------------- 1 file changed, 32 deletions(-) diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index 14208d71..c27e2361 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -76,38 +76,6 @@ fn lt_uint() { .is_err()); } -#[test] -fn lt_uint() { - let source = r#" - def main(private u32 a, private u32 b): - field x = if a < b then 3333 else 4444 fi - assert(x == 3333) - return - "# - .to_string(); - - // let's try to prove that "10000u32 < 5555u32" is true by exploiting - // the fact that `2*10000 - 2*5555` has two distinct bit decompositions - // we chose the one which is out of range, ie the sum check features an overflow - - let res: CompilationArtifacts = compile( - source, - "./path/to/file".into(), - None::<&dyn Resolver>, - &CompileConfig::default(), - ) - .unwrap(); - - let interpreter = Interpreter::try_out_of_range(); - - assert!(interpreter - .execute( - &res.prog(), - &[Bn128Field::from(10000), Bn128Field::from(5555)] - ) - .is_err()); -} - #[test] fn unpack256() { let source = r#" From 9ce5827596cfa634d181812860487e25656e9fdc Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 12 Aug 2021 21:14:24 +0200 Subject: [PATCH 15/15] remove left out commented out conflicts --- zokrates_core/src/typed_absy/mod.rs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index cb9d2537..b31dfeee 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -1157,23 +1157,6 @@ pub struct StructExpression<'ast, T> { inner: StructExpressionInner<'ast, T>, } -// <<<<<<< HEAD -// impl<'ast, T: Field> StructExpression<'ast, T> { -// pub fn try_from_typed( -// e: TypedExpression<'ast, T>, -// target_struct_ty: StructType<'ast, T>, -// ) -> Result> { -// match e { -// TypedExpression::Struct(e) => { -// if e.ty() == &target_struct_ty { -// Ok(e) -// } else { -// Err(TypedExpression::Struct(e)) -// } -// } -// e => Err(e), -// } -// ======= impl<'ast, T> StructExpression<'ast, T> { pub fn ty(&self) -> &StructType<'ast, T> { &self.ty @@ -1189,7 +1172,6 @@ impl<'ast, T> StructExpression<'ast, T> { pub fn into_inner(self) -> StructExpressionInner<'ast, T> { self.inner - // >>>>>>> 5a02186fc1d5c8f438a9663112f444497e752ea6 } }