From b1c9a171f8666bb7994b64895d9b9c46f61eb8cb Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 4 Aug 2021 14:55:17 +0200 Subject: [PATCH] add bit lt embed, fail on non constant bound, implement safe unpack --- .../compile_errors/variable_constant_lt.zok | 5 + zokrates_core/src/embed.rs | 30 +++++- zokrates_core/src/flatten/mod.rs | 49 ++++++++- zokrates_core/src/imports.rs | 4 + zokrates_core/src/ir/interpreter.rs | 24 ++++- .../constant_argument_checker.rs | 99 +++++++++++++++++++ zokrates_core/src/static_analysis/mod.rs | 14 +-- .../src/static_analysis/propagation.rs | 1 + .../src/static_analysis/shift_checker.rs | 55 ----------- .../utils/pack/bool/nonStrictUnpack256.zok | 4 +- .../stdlib/utils/pack/bool/unpack.zok | 12 +-- .../stdlib/utils/pack/bool/unpack128.zok | 4 +- .../stdlib/utils/pack/bool/unpack256.zok | 7 ++ .../utils/pack/bool/unpack_unchecked.zok | 9 ++ .../tests/utils/pack/bool/unpack256.json | 16 +++ .../tests/tests/utils/pack/bool/unpack256.zok | 24 +++++ 16 files changed, 277 insertions(+), 80 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/variable_constant_lt.zok create mode 100644 zokrates_core/src/static_analysis/constant_argument_checker.rs delete mode 100644 zokrates_core/src/static_analysis/shift_checker.rs create mode 100644 zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok create mode 100644 zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok create mode 100644 zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json create mode 100644 zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok diff --git a/zokrates_cli/examples/compile_errors/variable_constant_lt.zok b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok new file mode 100644 index 00000000..4527bffc --- /dev/null +++ b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok @@ -0,0 +1,5 @@ +from "EMBED" import bit_array_le + +// Unpack a field element as N big endian bits +def main(bool[1] a, bool[1] b) -> bool: + return bit_array_le::<1>(a, b) \ No newline at end of file diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 9f87b00e..d0c34bab 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -28,6 +28,7 @@ cfg_if::cfg_if! { /// the flattening step when it can be inlined. #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] pub enum FlatEmbed { + BitArrayLe, U32ToField, Unpack, U8ToBits, @@ -47,6 +48,30 @@ pub enum FlatEmbed { impl FlatEmbed { pub fn signature(&self) -> DeclarationSignature<'static> { match self { + FlatEmbed::BitArrayLe => DeclarationSignature::new() + .generics(vec![Some(DeclarationConstant::Generic( + GenericIdentifier { + name: "N", + index: 0, + }, + ))]) + .inputs(vec![ + DeclarationType::array(( + DeclarationType::Boolean, + GenericIdentifier { + name: "N", + index: 0, + }, + )), + DeclarationType::array(( + DeclarationType::Boolean, + GenericIdentifier { + name: "N", + index: 0, + }, + )), + ]) + .outputs(vec![DeclarationType::Boolean]), FlatEmbed::U32ToField => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(32)]) .outputs(vec![DeclarationType::FieldElement]), @@ -172,6 +197,7 @@ impl FlatEmbed { pub fn id(&self) -> &'static str { match self { + &FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", FlatEmbed::U32ToField => "_U32_TO_FIELD", FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", @@ -453,10 +479,6 @@ fn use_variable( /// as we decompose over `log_2(p) + 1 bits, some /// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { - let nbits = T::get_required_bits(); - - assert!(bit_width <= nbits); - let mut counter = 0; let mut layout = HashMap::new(); diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index dd05f3f3..1d230eae 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -222,7 +222,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { b: &[bool], ) -> Vec> { let len = b.len(); - assert_eq!(a.len(), T::get_required_bits()); assert_eq!(a.len(), b.len()); let mut is_not_smaller_run = vec![]; @@ -984,7 +983,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { } // check that the decomposition is in the field with a strict `< p` checks - self.constant_le_check( + self.enforce_constant_le_check( statements_flattened, &sub_bits_be, &T::max_value().bit_vector_be(), @@ -1161,6 +1160,52 @@ impl<'ast, T: Field> Flattener<'ast, T> { crate::embed::FlatEmbed::U8FromBits => { vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] } + crate::embed::FlatEmbed::BitArrayLe => { + let len = generics[0]; + + let (expressions, constants) = ( + param_expressions[..len as usize].to_vec(), + param_expressions[len as usize..].to_vec(), + ); + + let variables: Vec<_> = expressions + .into_iter() + .map(|e| { + let e = self + .flatten_expression(statements_flattened, e) + .get_field_unchecked(); + self.define(e, statements_flattened) + }) + .collect(); + + let constants: Vec<_> = constants + .into_iter() + .map(|e| { + self.flatten_expression(statements_flattened, e) + .get_field_unchecked() + }) + .map(|e| match e { + FlatExpression::Number(n) => n == T::one(), + _ => unreachable!(), + }) + .collect(); + + let conditions = + self.constant_le_check(statements_flattened, &variables, &constants); + + // return `len(conditions) == sum(conditions)` + vec![FlatUExpression::with_field( + self.eq_check( + statements_flattened, + T::from(conditions.len()).into(), + conditions + .into_iter() + .fold(FlatExpression::Number(T::zero()), |acc, e| { + FlatExpression::Add(box acc, box e) + }), + ), + )] + } funct => { let funct = funct.synthetize(&generics); diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 4578b89d..78d0f180 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -148,6 +148,10 @@ impl Importer { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::Unpack), }, + "bit_array_le" => SymbolDeclaration { + id: symbol.get_alias(), + symbol: Symbol::Flat(FlatEmbed::BitArrayLe), + }, "u64_to_bits" => SymbolDeclaration { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::U64ToBits), diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 04f4108e..c6b75dae 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -156,10 +156,18 @@ impl Interpreter { ], }, Solver::Bits(bit_width) => { + let padding = bit_width.saturating_sub(T::get_required_bits()); + + let bit_width = bit_width - padding; + let mut num = inputs[0].clone(); let mut res = vec![]; - for i in (0..*bit_width).rev() { + for _ in 0..padding { + res.push(T::zero()); + } + + for i in (0..bit_width).rev() { if T::from(2).pow(i) <= num { num = num - T::from(2).pow(i); res.push(T::one()); @@ -407,4 +415,18 @@ mod tests { assert_eq!(res[248], Bn128Field::from(1)); assert_eq!(res[247], Bn128Field::from(0)); } + + #[test] + fn five_hundred_bits_of_1() { + let inputs = vec![Bn128Field::from(1)]; + let interpreter = Interpreter::default(); + let res = interpreter + .execute_solver(&Solver::Bits(500), &inputs) + .unwrap(); + + let mut expected = vec![Bn128Field::from(0); 500]; + expected[499] = Bn128Field::from(1); + + assert_eq!(res, expected); + } } diff --git a/zokrates_core/src/static_analysis/constant_argument_checker.rs b/zokrates_core/src/static_analysis/constant_argument_checker.rs new file mode 100644 index 00000000..9e9ddb39 --- /dev/null +++ b/zokrates_core/src/static_analysis/constant_argument_checker.rs @@ -0,0 +1,99 @@ +use crate::embed::FlatEmbed; +use crate::typed_absy::TypedProgram; +use crate::typed_absy::{ + result_folder::ResultFolder, + result_folder::{fold_expression_list_inner, fold_uint_expression_inner}, + ArrayExpressionInner, BooleanExpression, TypedExpression, TypedExpressionListInner, + TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner, +}; +use zokrates_field::Field; +pub struct ConstantArgumentChecker; + +impl ConstantArgumentChecker { + pub fn check(p: TypedProgram) -> Result, Error> { + ConstantArgumentChecker.fold_program(p) + } +} + +pub type Error = String; + +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { + type Error = Error; + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Error> { + match e { + UExpressionInner::LeftShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), + by => Err(format!( + "Cannot shift by a variable value, found `{} << {}`", + e, + by.clone().annotate(UBitwidth::B32) + )), + } + } + UExpressionInner::RightShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), + by => Err(format!( + "Cannot shift by a variable value, found `{} >> {}`", + e, + by.clone().annotate(UBitwidth::B32) + )), + } + } + e => fold_uint_expression_inner(self, bitwidth, e), + } + } + + fn fold_expression_list_inner( + &mut self, + tys: &Types<'ast, T>, + l: TypedExpressionListInner<'ast, T>, + ) -> Result, Error> { + match l { + TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => { + let arguments = arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::, _>>()?; + + match arguments[1] { + TypedExpression::Array(ref a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + if v.0.iter().all(|v| { + matches!( + v, + TypedExpressionOrSpread::Expression(TypedExpression::Boolean( + BooleanExpression::Value(_) + )) + ) + }) { + Ok(TypedExpressionListInner::EmbedCall( + FlatEmbed::BitArrayLe, + generics, + arguments, + )) + } else { + Err(format!("Cannot compare to a variable value, found `{}`", a)) + } + } + v => Err(format!("Cannot compare to a variable value, found `{}`", v)), + }, + _ => unreachable!(), + } + } + l => fold_expression_list_inner(self, tys, l), + } + } +} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index e66f1505..bf658118 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -5,21 +5,21 @@ //! @date 2018 mod branch_isolator; +mod constant_argument_checker; mod constant_inliner; mod flat_propagation; mod flatten_complex_types; mod propagation; mod reducer; -mod shift_checker; mod uint_optimizer; mod unconstrained_vars; mod variable_write_remover; use self::branch_isolator::Isolator; +use self::constant_argument_checker::ConstantArgumentChecker; use self::flatten_complex_types::Flattener; use self::propagation::Propagator; use self::reducer::reduce_program; -use self::shift_checker::ShiftChecker; use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; @@ -39,7 +39,7 @@ pub trait Analyse { pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), - NonConstantShift(self::shift_checker::Error), + NonConstantShift(self::constant_argument_checker::Error), } impl From for Error { @@ -54,8 +54,8 @@ impl From for Error { } } -impl From for Error { - fn from(e: shift_checker::Error) -> Self { +impl From for Error { + fn from(e: constant_argument_checker::Error) -> Self { Error::NonConstantShift(e) } } @@ -90,8 +90,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = Propagator::propagate(r).map_err(Error::from)?; // remove assignment to variable index let r = VariableWriteRemover::apply(r); - // detect non constant shifts - let r = ShiftChecker::check(r).map_err(Error::from)?; + // detect non constant shifts and constant lt bounds + let r = ConstantArgumentChecker::check(r).map_err(Error::from)?; // convert to zir, removing complex types let zir = Flattener::flatten(r); // optimize uint expressions diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 3799755a..34a9c3f3 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -502,6 +502,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { true => { let r: Option> = match embed { FlatEmbed::U32ToField => None, // todo + FlatEmbed::BitArrayLe => None, // todo FlatEmbed::U64FromBits => Some(process_u_from_bits( assignees.clone(), arguments.clone(), diff --git a/zokrates_core/src/static_analysis/shift_checker.rs b/zokrates_core/src/static_analysis/shift_checker.rs deleted file mode 100644 index 7e44ea52..00000000 --- a/zokrates_core/src/static_analysis/shift_checker.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::typed_absy::TypedProgram; -use crate::typed_absy::{ - result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth, - UExpressionInner, -}; -use zokrates_field::Field; -pub struct ShiftChecker; - -impl ShiftChecker { - pub fn check(p: TypedProgram) -> Result, Error> { - ShiftChecker.fold_program(p) - } -} - -pub type Error = String; - -impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker { - type Error = Error; - - fn fold_uint_expression_inner( - &mut self, - bitwidth: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> Result, Error> { - match e { - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - - match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), - by => Err(format!( - "Cannot shift by a variable value, found `{} << {}`", - e, - by.clone().annotate(UBitwidth::B32) - )), - } - } - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - - match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), - by => Err(format!( - "Cannot shift by a variable value, found `{} >> {}`", - e, - by.clone().annotate(UBitwidth::B32) - )), - } - } - e => fold_uint_expression_inner(self, bitwidth, e), - } - } -} diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok index 4e48909f..e31dece4 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok @@ -1,12 +1,12 @@ #pragma curve bn128 -import "./unpack" as unpack +import "./unpack_unchecked" // Unpack a field element as 256 big-endian bits // Note: uniqueness of the output is not guaranteed // For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)` def main(field i) -> bool[256]: - bool[254] b = unpack::<254>(i) + bool[254] b = unpack_unchecked::<254>(i) return [false, false, ...b] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok index d5b7a5cd..bc6d22d1 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok @@ -1,12 +1,12 @@ -#pragma curve bn128 - -from "EMBED" import unpack +import "./unpack_unchecked.zok" +from "field" import FIELD_SIZE_IN_BITS +from "EMBED" import bit_array_le // Unpack a field element as N big endian bits def main(field i) -> bool[N]: - - assert(N <= 254) - bool[N] res = unpack(i) + bool[N] res = unpack_unchecked(i) + + assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::(-1)]) else true fi) return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok index a24a244b..8f0b1203 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok @@ -1,9 +1,7 @@ -#pragma curve bn128 - import "./unpack" as unpack // Unpack a field element as 128 big-endian bits -// Precondition: the input is smaller or equal to `2**128 - 1` +// If the input is larger than `2**128 - 1`, the output is truncated. def main(field i) -> bool[128]: bool[128] res = unpack::<128>(i) return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok new file mode 100644 index 00000000..4c3e3e56 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok @@ -0,0 +1,7 @@ +import "./unpack" as unpack + +// Unpack a field element as 256 big-endian bits +// If the input is larger than `2**256 - 1`, the output is truncated. +def main(field i) -> bool[256]: + bool[256] res = unpack::<256>(i) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok new file mode 100644 index 00000000..2b0babbe --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok @@ -0,0 +1,9 @@ +from "EMBED" import unpack + +// Unpack a field element as N big endian bits without checking for overflows +// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either +def main(field i) -> bool[N]: + + bool[N] res = unpack(i) + + return res \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json new file mode 100644 index 00000000..5739811a --- /dev/null +++ b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/utils/pack/bool/unpack256.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok new file mode 100644 index 00000000..921ccb02 --- /dev/null +++ b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok @@ -0,0 +1,24 @@ +import "utils/pack/bool/unpack256" as unpack256 + +def testFive() -> bool: + + bool[256] b = unpack256(5) + + assert(b == [...[false; 253], true, false, true]) + + return true + +def testZero() -> bool: + + bool[256] b = unpack256(0) + + assert(b == [false; 256]) + + return true + + def main(): + + assert(testFive()) + assert(testZero()) + + return