diff --git a/Cargo.lock b/Cargo.lock index 93a022a6..a26a52d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2451,6 +2451,7 @@ dependencies = [ "zokrates_common", "zokrates_embed", "zokrates_field", + "zokrates_fs_resolver", "zokrates_pest_ast", ] diff --git a/changelogs/unreleased/955-schaeff b/changelogs/unreleased/955-schaeff new file mode 100644 index 00000000..fff1c6e2 --- /dev/null +++ b/changelogs/unreleased/955-schaeff @@ -0,0 +1 @@ +Make the stdlib `unpack` function safe against overflows of bit decompositions for any size of output, introduce `unpack_unchecked` for cases that do not require determinism \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/variable_constant_lt.zok b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok new file mode 100644 index 00000000..062d8295 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/variable_constant_lt.zok @@ -0,0 +1,5 @@ +from "EMBED" import bit_array_le + +// Calling the `bit_array_le` embed on a non-constant second argument should fail at compile-time +def main(bool[1] a, bool[1] b) -> bool: + return bit_array_le::<1>(a, b) \ No newline at end of file diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index 204da484..3c1b886f 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -60,6 +60,7 @@ sha2 = { version = "0.9.3", optional = true } [dev-dependencies] wasm-bindgen-test = "^0.3.0" pretty_assertions = "0.6.1" +zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"} [build-dependencies] cc = { version = "1.0", features = ["parallel"], optional = true } diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 9f87b00e..3ee07660 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -28,6 +28,7 @@ cfg_if::cfg_if! { /// the flattening step when it can be inlined. #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] pub enum FlatEmbed { + BitArrayLe, U32ToField, Unpack, U8ToBits, @@ -47,6 +48,30 @@ pub enum FlatEmbed { impl FlatEmbed { pub fn signature(&self) -> DeclarationSignature<'static> { match self { + FlatEmbed::BitArrayLe => DeclarationSignature::new() + .generics(vec![Some(DeclarationConstant::Generic( + GenericIdentifier { + name: "N", + index: 0, + }, + ))]) + .inputs(vec![ + DeclarationType::array(( + DeclarationType::Boolean, + GenericIdentifier { + name: "N", + index: 0, + }, + )), + DeclarationType::array(( + DeclarationType::Boolean, + GenericIdentifier { + name: "N", + index: 0, + }, + )), + ]) + .outputs(vec![DeclarationType::Boolean]), FlatEmbed::U32ToField => DeclarationSignature::new() .inputs(vec![DeclarationType::uint(32)]) .outputs(vec![DeclarationType::FieldElement]), @@ -172,6 +197,7 @@ impl FlatEmbed { pub fn id(&self) -> &'static str { match self { + FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", FlatEmbed::U32ToField => "_U32_TO_FIELD", FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", @@ -449,14 +475,9 @@ fn use_variable( /// * bit_width the number of bits we want to decompose to /// /// # Remarks -/// * the return value of the `FlatFunction` is not deterministic if `bit_width == T::get_required_bits()` -/// as we decompose over `log_2(p) + 1 bits, some -/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` +/// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()` +/// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { - let nbits = T::get_required_bits(); - - assert!(bit_width <= nbits); - let mut counter = 0; let mut layout = HashMap::new(); diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 126190b8..901355eb 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -223,7 +223,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { b: &[bool], ) -> Vec> { let len = b.len(); - assert_eq!(a.len(), T::get_required_bits()); assert_eq!(a.len(), b.len()); let mut is_not_smaller_run = vec![]; @@ -1164,6 +1163,58 @@ impl<'ast, T: Field> Flattener<'ast, T> { crate::embed::FlatEmbed::U8FromBits => { vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] } + crate::embed::FlatEmbed::BitArrayLe => { + // get the length of the bit arrays + let len = generics[0]; + + // split the arguments into the two bit arrays of size `len` + let (expressions, constants) = ( + param_expressions[..len as usize].to_vec(), + param_expressions[len as usize..].to_vec(), + ); + + // define variables for the variable bits + let variables: Vec<_> = expressions + .into_iter() + .map(|e| { + let e = self + .flatten_expression(statements_flattened, e) + .get_field_unchecked(); + self.define(e, statements_flattened) + }) + .collect(); + + // get constants for the constant bits + let constants: Vec<_> = constants + .into_iter() + .map(|e| { + self.flatten_expression(statements_flattened, e) + .get_field_unchecked() + }) + .map(|e| match e { + FlatExpression::Number(n) if n == T::one() => true, + FlatExpression::Number(n) if n == T::zero() => false, + _ => unreachable!(), + }) + .collect(); + + // get the list of conditions which must hold iff the `<=` relation holds + let conditions = + self.constant_le_check(statements_flattened, &variables, &constants); + + // return `len(conditions) == sum(conditions)` + vec![FlatUExpression::with_field( + self.eq_check( + statements_flattened, + T::from(conditions.len()).into(), + conditions + .into_iter() + .fold(FlatExpression::Number(T::zero()), |acc, e| { + FlatExpression::Add(box acc, box e) + }), + ), + )] + } funct => { let funct = funct.synthetize(&generics); @@ -1924,8 +1975,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::default() - .execute_solver(&Solver::bits(to), &[x.clone()]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) .unwrap() .into_iter() .map(FlatExpression::Number) diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index ad0e475a..bd01796e 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -148,6 +148,10 @@ impl Importer { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::Unpack), }, + "bit_array_le" => SymbolDeclaration { + id: symbol.get_alias(), + symbol: Symbol::Flat(FlatEmbed::BitArrayLe), + }, "u64_to_bits" => SymbolDeclaration { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::U64ToBits), diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 04f4108e..580e5ee8 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -1,5 +1,4 @@ use crate::flat_absy::flat_variable::FlatVariable; -use crate::ir::Directive; use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness}; use crate::solvers::Solver; use serde::{Deserialize, Serialize}; @@ -66,30 +65,25 @@ impl Interpreter { } }, Statement::Directive(ref d) => { - match (&d.solver, &d.inputs, self.should_try_out_of_range) { - (Solver::Bits(bitwidth), inputs, true) - if inputs[0].left.0.len() > 1 - || inputs[0].right.0.len() > 1 - && *bitwidth == T::get_required_bits() => - { - Self::try_solve_out_of_range(&d, &mut witness) - } - _ => { - let inputs: Vec<_> = d - .inputs - .iter() - .map(|i| i.evaluate(&witness).unwrap()) - .collect(); - match self.execute_solver(&d.solver, &inputs) { - Ok(res) => { - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } - continue; - } - Err(_) => return Err(Error::Solver), - }; + let mut inputs: Vec<_> = d + .inputs + .iter() + .map(|i| i.evaluate(&witness).unwrap()) + .collect(); + + let res = match (&d.solver, self.should_try_out_of_range) { + (Solver::Bits(bitwidth), true) if *bitwidth >= T::get_required_bits() => { + Ok(Self::try_solve_with_out_of_range_bits( + *bitwidth, + inputs.pop().unwrap(), + )) } + _ => Self::execute_solver(&d.solver, &inputs), + } + .map_err(|_| Error::Solver)?; + + for (i, o) in d.outputs.iter().enumerate() { + witness.insert(*o, res[i].clone()); } } } @@ -98,34 +92,31 @@ impl Interpreter { Ok(Witness(witness)) } - fn try_solve_out_of_range(d: &Directive, witness: &mut BTreeMap) { + fn try_solve_with_out_of_range_bits(bit_width: usize, input: T) -> Vec { use num::traits::Pow; + use num_bigint::BigUint; + + let candidate = input.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); - // we target the `2a - 2b` part of the `<` check by only returning out-of-range results - // when the input is not a single summand - let value = d.inputs[0].evaluate(&witness).unwrap(); - let candidate = value.to_biguint() + T::max_value().to_biguint() + T::from(1).to_biguint(); let input = if candidate < T::from(2).to_biguint().pow(T::get_required_bits()) { candidate } else { - value.to_biguint() + input.to_biguint() }; - let mut num = input; - let mut res = vec![]; - let bits = T::get_required_bits(); - for i in (0..bits).rev() { - if T::from(2).to_biguint().pow(i as usize) <= num { - num -= T::from(2).to_biguint().pow(i as usize); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - assert_eq!(num, T::zero().to_biguint()); - for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); - } + let padding = bit_width - T::get_required_bits(); + + (0..padding) + .map(|_| T::zero()) + .chain((0..T::get_required_bits()).rev().scan(input, |state, i| { + if BigUint::from(2usize).pow(i) <= *state { + *state = (*state).clone() - BigUint::from(2usize).pow(i); + Some(T::one()) + } else { + Some(T::zero()) + } + })) + .collect() } fn check_inputs(&self, program: &Prog, inputs: &[U]) -> Result<(), Error> { @@ -139,11 +130,7 @@ impl Interpreter { } } - pub fn execute_solver( - &self, - solver: &Solver, - inputs: &[T], - ) -> Result, String> { + pub fn execute_solver(solver: &Solver, inputs: &[T]) -> Result, String> { let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); @@ -156,18 +143,23 @@ impl Interpreter { ], }, Solver::Bits(bit_width) => { - let mut num = inputs[0].clone(); - let mut res = vec![]; + let padding = bit_width.saturating_sub(T::get_required_bits()); - for i in (0..*bit_width).rev() { - if T::from(2).pow(i) <= num { - num = num - T::from(2).pow(i); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - res + let bit_width = bit_width - padding; + + let num = inputs[0].clone(); + + (0..padding) + .map(|_| T::zero()) + .chain((0..bit_width).rev().scan(num, |state, i| { + if T::from(2).pow(i) <= *state { + *state = (*state).clone() - T::from(2).pow(i); + Some(T::one()) + } else { + Some(T::zero()) + } + })) + .collect() } Solver::Xor => { let x = inputs[0].clone(); @@ -346,16 +338,14 @@ mod tests { fn execute() { let cond_eq = Solver::ConditionEq; let inputs = vec![0]; - let interpreter = Interpreter::default(); - let r = interpreter - .execute_solver( - &cond_eq, - &inputs - .iter() - .map(|&i| Bn128Field::from(i)) - .collect::>(), - ) - .unwrap(); + let r = Interpreter::execute_solver( + &cond_eq, + &inputs + .iter() + .map(|&i| Bn128Field::from(i)) + .collect::>(), + ) + .unwrap(); let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); assert_eq!(r, &res[..]); } @@ -364,16 +354,14 @@ mod tests { fn execute_non_eq() { let cond_eq = Solver::ConditionEq; let inputs = vec![1]; - let interpreter = Interpreter::default(); - let r = interpreter - .execute_solver( - &cond_eq, - &inputs - .iter() - .map(|&i| Bn128Field::from(i)) - .collect::>(), - ) - .unwrap(); + let r = Interpreter::execute_solver( + &cond_eq, + &inputs + .iter() + .map(|&i| Bn128Field::from(i)) + .collect::>(), + ) + .unwrap(); let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); assert_eq!(r, &res[..]); } @@ -382,10 +370,9 @@ mod tests { #[test] fn bits_of_one() { let inputs = vec![Bn128Field::from(1)]; - let interpreter = Interpreter::default(); - let res = interpreter - .execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = + Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) + .unwrap(); assert_eq!(res[253], Bn128Field::from(1)); for r in &res[0..253] { assert_eq!(*r, Bn128Field::from(0)); @@ -395,10 +382,9 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![Bn128Field::from(42)]; - let interpreter = Interpreter::default(); - let res = interpreter - .execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = + Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) + .unwrap(); assert_eq!(res[253], Bn128Field::from(0)); assert_eq!(res[252], Bn128Field::from(1)); assert_eq!(res[251], Bn128Field::from(0)); @@ -407,4 +393,15 @@ mod tests { assert_eq!(res[248], Bn128Field::from(1)); assert_eq!(res[247], Bn128Field::from(0)); } + + #[test] + fn five_hundred_bits_of_1() { + let inputs = vec![Bn128Field::from(1)]; + let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap(); + + let mut expected = vec![Bn128Field::from(0); 500]; + expected[499] = Bn128Field::from(1); + + assert_eq!(res, expected); + } } diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index 6b197b1e..4222281c 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -140,9 +140,7 @@ impl Folder for RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::default() - .execute_solver(&d.solver, &inputs) - .unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution diff --git a/zokrates_core/src/static_analysis/shift_checker.rs b/zokrates_core/src/static_analysis/constant_argument_checker.rs similarity index 53% rename from zokrates_core/src/static_analysis/shift_checker.rs rename to zokrates_core/src/static_analysis/constant_argument_checker.rs index 7e44ea52..91ec166e 100644 --- a/zokrates_core/src/static_analysis/shift_checker.rs +++ b/zokrates_core/src/static_analysis/constant_argument_checker.rs @@ -1,20 +1,22 @@ +use crate::embed::FlatEmbed; use crate::typed_absy::TypedProgram; use crate::typed_absy::{ - result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth, - UExpressionInner, + result_folder::ResultFolder, + result_folder::{fold_expression_list_inner, fold_uint_expression_inner}, + Constant, TypedExpressionListInner, Types, UBitwidth, UExpressionInner, }; use zokrates_field::Field; -pub struct ShiftChecker; +pub struct ConstantArgumentChecker; -impl ShiftChecker { +impl ConstantArgumentChecker { pub fn check(p: TypedProgram) -> Result, Error> { - ShiftChecker.fold_program(p) + ConstantArgumentChecker.fold_program(p) } } pub type Error = String; -impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker { +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { type Error = Error; fn fold_uint_expression_inner( @@ -52,4 +54,33 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker { e => fold_uint_expression_inner(self, bitwidth, e), } } + + fn fold_expression_list_inner( + &mut self, + tys: &Types<'ast, T>, + l: TypedExpressionListInner<'ast, T>, + ) -> Result, Error> { + match l { + TypedExpressionListInner::EmbedCall(FlatEmbed::BitArrayLe, generics, arguments) => { + let arguments = arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::, _>>()?; + + if arguments[1].is_constant() { + Ok(TypedExpressionListInner::EmbedCall( + FlatEmbed::BitArrayLe, + generics, + arguments, + )) + } else { + Err(format!( + "Cannot compare to a variable value, found `{}`", + arguments[1] + )) + } + } + l => fold_expression_list_inner(self, tys, l), + } + } } diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 36ae9af3..4296c9f9 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -570,7 +570,7 @@ fn fold_select_expression<'ast, T: Field, E>( statements_buffer: &mut Vec>, select: typed_absy::SelectExpression<'ast, T, E>, ) -> Vec> { - let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().ty) + let size = typed_absy::types::ConcreteType::try_from(*select.array.ty().clone().ty) .unwrap() .get_primitive_count(); diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 82b9d6fe..3b8751f1 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -5,21 +5,21 @@ //! @date 2018 mod branch_isolator; +mod constant_argument_checker; mod constant_inliner; mod flat_propagation; mod flatten_complex_types; mod propagation; mod reducer; -mod shift_checker; mod uint_optimizer; mod unconstrained_vars; mod variable_write_remover; use self::branch_isolator::Isolator; +use self::constant_argument_checker::ConstantArgumentChecker; use self::flatten_complex_types::Flattener; use self::propagation::Propagator; use self::reducer::reduce_program; -use self::shift_checker::ShiftChecker; use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; @@ -39,7 +39,7 @@ pub trait Analyse { pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), - NonConstantShift(self::shift_checker::Error), + NonConstantArgument(self::constant_argument_checker::Error), } impl From for Error { @@ -54,9 +54,9 @@ impl From for Error { } } -impl From for Error { - fn from(e: shift_checker::Error) -> Self { - Error::NonConstantShift(e) +impl From for Error { + fn from(e: constant_argument_checker::Error) -> Self { + Error::NonConstantArgument(e) } } @@ -65,7 +65,7 @@ impl fmt::Display for Error { match self { Error::Reducer(e) => write!(f, "{}", e), Error::Propagation(e) => write!(f, "{}", e), - Error::NonConstantShift(e) => write!(f, "{}", e), + Error::NonConstantArgument(e) => write!(f, "{}", e), } } } @@ -107,9 +107,9 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = VariableWriteRemover::apply(r); log::trace!("\n{}", r); - // detect non constant shifts - log::debug!("Static analyser: Detect non constant shifts"); - let r = ShiftChecker::check(r).map_err(Error::from)?; + // detect non constant shifts and constant lt bounds + log::debug!("Static analyser: Detect non constant arguments"); + let r = ConstantArgumentChecker::check(r).map_err(Error::from)?; log::trace!("\n{}", r); // convert to zir, removing complex types diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 6f3f6c89..c66dc60b 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -124,126 +124,6 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { } } -fn is_constant(e: &TypedExpression) -> bool { - match e { - TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true, - TypedExpression::Boolean(BooleanExpression::Value(..)) => true, - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { - TypedExpressionOrSpread::Expression(e) => is_constant(e), - _ => false, - }), - ArrayExpressionInner::Slice(box a, box from, box to) => { - is_constant(&from.clone().into()) - && is_constant(&to.clone().into()) - && is_constant(&a.clone().into()) - } - ArrayExpressionInner::Repeat(box e, box count) => { - is_constant(&count.clone().into()) && is_constant(&e) - } - _ => false, - }, - TypedExpression::Struct(a) => match a.as_inner() { - StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), - _ => false, - }, - TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)), - _ => false, - } -} - -// in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc -fn to_canonical_constant(e: TypedExpression) -> TypedExpression { - fn to_canonical_constant_aux( - e: TypedExpressionOrSpread, - ) -> Vec> { - match e { - TypedExpressionOrSpread::Expression(e) => vec![e], - TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() { - ArrayExpressionInner::Value(v) => { - v.into_iter().flat_map(to_canonical_constant_aux).collect() - } - _ => unimplemented!(), - }, - } - } - - match e { - TypedExpression::Array(a) => { - let array_ty = a.ty(); - - match a.into_inner() { - ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( - v.into_iter() - .flat_map(to_canonical_constant_aux) - .map(|e| e.into()) - .collect::>() - .into(), - ) - .annotate(*array_ty.ty, array_ty.size) - .into(), - ArrayExpressionInner::Slice(box a, box from, box to) => { - let from = match from.into_inner() { - UExpressionInner::Value(from) => from as usize, - _ => unreachable!("should be a uint value"), - }; - - let to = match to.into_inner() { - UExpressionInner::Value(to) => to as usize, - _ => unreachable!("should be a uint value"), - }; - - let v = match a.into_inner() { - ArrayExpressionInner::Value(v) => v, - _ => unreachable!("should be an array value"), - }; - - ArrayExpressionInner::Value( - v.into_iter() - .flat_map(to_canonical_constant_aux) - .map(|e| e.into()) - .enumerate() - .filter(|(index, _)| index >= &from && index < &to) - .map(|(_, e)| e) - .collect::>() - .into(), - ) - .annotate(*array_ty.ty, array_ty.size) - .into() - } - ArrayExpressionInner::Repeat(box e, box count) => { - let count = match count.into_inner() { - UExpressionInner::Value(from) => from as usize, - _ => unreachable!("should be a uint value"), - }; - - let e = to_canonical_constant(e); - - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression(e); count].into(), - ) - .annotate(*array_ty.ty, array_ty.size) - .into() - } - _ => unreachable!(), - } - } - TypedExpression::Struct(s) => { - let struct_ty = s.ty().clone(); - - match s.into_inner() { - StructExpressionInner::Value(expressions) => StructExpressionInner::Value( - expressions.into_iter().map(to_canonical_constant).collect(), - ) - .annotate(struct_ty) - .into(), - _ => unreachable!(), - } - } - e => e, - } -} - impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { type Error = Error; @@ -341,10 +221,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } }; - if is_constant(&expr) { + if expr.is_constant() { match assignee { TypedAssignee::Identifier(var) => { - let expr = to_canonical_constant(expr); + let expr = expr.into_canonical_constant(); assert!(self.constants.insert(var.id, expr).is_none()); @@ -352,7 +232,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } assignee => match self.try_get_constant_mut(&assignee) { Ok((_, c)) => { - *c = to_canonical_constant(expr); + *c = expr.into_canonical_constant(); Ok(vec![]) } Err(v) => match self.constants.remove(&v.id) { @@ -423,7 +303,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { let argument = arguments.pop().unwrap(); - let argument = to_canonical_constant(argument); + let argument = argument.into_canonical_constant(); match ArrayExpression::try_from(argument) .unwrap() @@ -498,10 +378,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } } - match arguments.iter().all(|a| is_constant(a)) { + match arguments.iter().all(|a| a.is_constant()) { true => { let r: Option> = match embed { FlatEmbed::U32ToField => None, // todo + FlatEmbed::BitArrayLe => None, // todo FlatEmbed::U64FromBits => Some(process_u_from_bits( assignees.clone(), arguments.clone(), @@ -1173,6 +1054,34 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }, None => Ok(ArrayExpressionInner::Identifier(id)), }, + ArrayExpressionInner::Value(exprs) => { + Ok(ArrayExpressionInner::Value( + exprs + .into_iter() + .map(|e| self.fold_expression_or_spread(e)) + .collect::, _>>()? + .into_iter() + .flat_map(|e| { + match e { + // simplify `...[a, b]` to `a, b` + TypedExpressionOrSpread::Spread(TypedSpread { + array: + ArrayExpression { + inner: ArrayExpressionInner::Value(v), + .. + }, + }) => v.0, + e => vec![e], + } + }) + // ignore spreads over empty arrays + .filter_map(|e| match e { + TypedExpressionOrSpread::Spread(s) if s.array.size() == 0 => None, + e => Some(e), + }) + .collect(), + )) + } e => fold_array_expression_inner(self, ty, e), } } diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index ae26d6cd..84790f29 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -1273,13 +1273,8 @@ mod tests { // def main(): // # PUSH CALL to foo::<1> // # PUSH CALL to bar::<2> - // field[2] a_1 = [...[1]], 0] - // field[2] #RET_0_1 = a_1 // # POP CALL - // field[1] ret := #RET_0_1[0..1] - // field[1] #RET_0 = ret // # POP CALL - // field[1] b_0 := #RET_0 // return let foo_signature = DeclarationSignature::new() @@ -1452,71 +1447,8 @@ mod tests { .collect(), ), ), - TypedStatement::Definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32) - .into(), - ArrayExpressionInner::Value( - vec![ - TypedExpressionOrSpread::Spread( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(1)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - FieldElementExpression::Number(Bn128Field::from(0)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ), - TypedStatement::Definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(1), - Type::FieldElement, - 2u32, - ) - .into(), - ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 2u32) - .into(), - ), TypedStatement::PopCallLog, - TypedStatement::Definition( - Variable::array("ret", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Slice( - box ArrayExpressionInner::Identifier( - Identifier::from(CoreIdentifier::Call(0)).version(1), - ) - .annotate(Type::FieldElement, 2u32), - box 0u32.into(), - box 1u32.into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::Definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpressionInner::Identifier("ret".into()) - .annotate(Type::FieldElement, 1u32) - .into(), - ), TypedStatement::PopCallLog, - TypedStatement::Definition( - Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Identifier(Identifier::from(CoreIdentifier::Call(0))) - .annotate(Type::FieldElement, 1u32) - .into(), - ), TypedStatement::Return(vec![]), ], signature: DeclarationSignature::new(), diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index ecca499a..232c4424 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -738,7 +738,7 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::right_shift(left.clone(), right.clone())), + .fold_uint_expression(UExpression::right_shift(left.clone(), right)), UExpression::right_shift(left_expected, right_expected).with_max(output_max) ); } @@ -761,7 +761,7 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())), + .fold_uint_expression(UExpression::left_shift(left.clone(), right)), UExpression::left_shift(left_expected, right_expected).with_max(output_max) ); } diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index 59589df2..62eba07c 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -5,7 +5,7 @@ use crate::typed_absy::types::{ }; use crate::typed_absy::UBitwidth; use crate::typed_absy::{ - ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse, + ArrayExpression, ArrayExpressionInner, BooleanExpression, Expr, FieldElementExpression, IfElse, IfElseExpression, Select, SelectExpression, StructExpression, StructExpressionInner, Typed, TypedExpression, TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, }; @@ -585,7 +585,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { array: Self, target_array_ty: &GArrayType, ) -> Result> { - let array_ty = array.ty(); + let array_ty = array.ty().clone(); // elements must fit in the target type match array.into_inner() { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 89f93976..b31dfeee 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -1029,8 +1029,8 @@ impl<'ast, T> From for BooleanExpression<'ast, T> { /// type checking #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub struct ArrayExpression<'ast, T> { - ty: Box>, - inner: ArrayExpressionInner<'ast, T>, + pub ty: Box>, + pub inner: ArrayExpressionInner<'ast, T>, } #[derive(Debug, PartialEq, Eq, Hash, Clone)] @@ -1149,25 +1149,6 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { pub fn size(&self) -> UExpression<'ast, T> { self.ty.size.clone() } - - pub fn as_inner(&self) -> &ArrayExpressionInner<'ast, T> { - &self.inner - } - - pub fn as_inner_mut(&mut self) -> &mut ArrayExpressionInner<'ast, T> { - &mut self.inner - } - - pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> { - self.inner - } - - pub fn ty(&self) -> ArrayType<'ast, T> { - ArrayType { - size: self.size(), - ty: box self.inner_type().clone(), - } - } } #[derive(Clone, PartialEq, Debug, Hash, Eq)] @@ -1495,15 +1476,23 @@ pub trait Expr<'ast, T>: From> { type Inner; type Ty: Clone + IntoTypes<'ast, T>; + fn ty(&self) -> &Self::Ty; + fn into_inner(self) -> Self::Inner; fn as_inner(&self) -> &Self::Inner; + + fn as_inner_mut(&mut self) -> &mut Self::Inner; } impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; + fn ty(&self) -> &Self::Ty { + &Type::FieldElement + } + fn into_inner(self) -> Self::Inner { self } @@ -1511,12 +1500,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + self + } } impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; + fn ty(&self) -> &Self::Ty { + &Type::Boolean + } + fn into_inner(self) -> Self::Inner { self } @@ -1524,12 +1521,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + self + } } impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> { type Inner = UExpressionInner<'ast, T>; type Ty = UBitwidth; + fn ty(&self) -> &Self::Ty { + &self.bitwidth + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1537,12 +1542,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> { type Inner = StructExpressionInner<'ast, T>; type Ty = StructType<'ast, T>; + fn ty(&self) -> &Self::Ty { + &self.ty + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1550,12 +1563,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> { type Inner = ArrayExpressionInner<'ast, T>; type Ty = ArrayType<'ast, T>; + fn ty(&self) -> &Self::Ty { + &self.ty + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1563,12 +1584,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; + fn ty(&self) -> &Self::Ty { + &Type::Int + } + fn into_inner(self) -> Self::Inner { self } @@ -1576,12 +1605,20 @@ impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + self + } } impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> { type Inner = TypedExpressionListInner<'ast, T>; type Ty = Types<'ast, T>; + fn ty(&self) -> &Self::Ty { + &self.types + } + fn into_inner(self) -> Self::Inner { self.inner } @@ -1589,6 +1626,10 @@ impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> { fn as_inner(&self) -> &Self::Inner { &self.inner } + + fn as_inner_mut(&mut self) -> &mut Self::Inner { + &mut self.inner + } } // Enums types to enable returning e.g a member expression OR another type of expression of this type @@ -1769,7 +1810,7 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> { } } -impl<'ast, T> Member<'ast, T> for UExpression<'ast, T> { +impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { let ty = s.ty().members.iter().find(|member| id == member.id); let bitwidth = match ty { @@ -1949,7 +1990,7 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { - let array_ty = value.ty(); + let array_ty = value.ty().clone(); ArrayExpressionInner::Block(BlockExpression::new(statements, value)) .annotate(*array_ty.ty, array_ty.size) } @@ -1962,3 +2003,199 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> { StructExpressionInner::Block(BlockExpression::new(statements, value)).annotate(struct_ty) } } + +pub trait Constant: Sized { + // return whether this is constant + fn is_constant(&self) -> bool; + + // canonicalize an expression *that we know to be constant* + // for example for [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc + fn into_canonical_constant(self) -> Self { + self + } +} + +impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self, FieldElementExpression::Number(..)) + } +} + +impl<'ast, T: Field> Constant for BooleanExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self, BooleanExpression::Value(..)) + } +} + +impl<'ast, T: Field> Constant for UExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self.as_inner(), UExpressionInner::Value(..)) + } +} + +impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self.as_inner() { + ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { + TypedExpressionOrSpread::Expression(e) => e.is_constant(), + TypedExpressionOrSpread::Spread(s) => s.array.is_constant(), + }), + ArrayExpressionInner::Slice(box a, box from, box to) => { + from.is_constant() && to.is_constant() && a.is_constant() + } + ArrayExpressionInner::Repeat(box e, box count) => { + count.is_constant() && e.is_constant() + } + _ => false, + } + } + + fn into_canonical_constant(self) -> Self { + fn into_canonical_constant_aux( + e: TypedExpressionOrSpread, + ) -> Vec> { + match e { + TypedExpressionOrSpread::Expression(e) => vec![e], + TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() { + ArrayExpressionInner::Value(v) => v + .into_iter() + .flat_map(into_canonical_constant_aux) + .collect(), + ArrayExpressionInner::Slice(box v, box from, box to) => { + let from = match from.into_inner() { + UExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + let to = match to.into_inner() { + UExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + let v = match v.into_inner() { + ArrayExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + v.into_iter() + .flat_map(into_canonical_constant_aux) + .skip(from as usize) + .take(to as usize - from as usize) + .collect() + } + ArrayExpressionInner::Repeat(box e, box count) => { + let count = match count.into_inner() { + UExpressionInner::Value(count) => count, + _ => unreachable!(), + }; + + vec![e; count as usize] + } + a => unreachable!("{}", a), + }, + } + } + + let array_ty = self.ty().clone(); + + match self.into_inner() { + ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( + v.into_iter() + .flat_map(into_canonical_constant_aux) + .map(|e| e.into()) + .collect::>() + .into(), + ) + .annotate(*array_ty.ty, array_ty.size), + ArrayExpressionInner::Slice(box a, box from, box to) => { + let from = match from.into_inner() { + UExpressionInner::Value(from) => from as usize, + _ => unreachable!("should be a uint value"), + }; + + let to = match to.into_inner() { + UExpressionInner::Value(to) => to as usize, + _ => unreachable!("should be a uint value"), + }; + + let v = match a.into_inner() { + ArrayExpressionInner::Value(v) => v, + _ => unreachable!("should be an array value"), + }; + + ArrayExpressionInner::Value( + v.into_iter() + .flat_map(into_canonical_constant_aux) + .map(|e| e.into()) + .skip(from) + .take(to - from) + .collect::>() + .into(), + ) + .annotate(*array_ty.ty, array_ty.size) + } + ArrayExpressionInner::Repeat(box e, box count) => { + let count = match count.into_inner() { + UExpressionInner::Value(from) => from as usize, + _ => unreachable!("should be a uint value"), + }; + + let e = e.into_canonical_constant(); + + ArrayExpressionInner::Value( + vec![TypedExpressionOrSpread::Expression(e); count].into(), + ) + .annotate(*array_ty.ty, array_ty.size) + } + _ => unreachable!(), + } + } +} + +impl<'ast, T: Field> Constant for StructExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self.as_inner() { + StructExpressionInner::Value(v) => v.iter().all(|e| e.is_constant()), + _ => false, + } + } + + fn into_canonical_constant(self) -> Self { + let struct_ty = self.ty().clone(); + + match self.into_inner() { + StructExpressionInner::Value(expressions) => StructExpressionInner::Value( + expressions + .into_iter() + .map(|e| e.into_canonical_constant()) + .collect(), + ) + .annotate(struct_ty), + _ => unreachable!(), + } + } +} + +impl<'ast, T: Field> Constant for TypedExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self { + TypedExpression::FieldElement(e) => e.is_constant(), + TypedExpression::Boolean(e) => e.is_constant(), + TypedExpression::Array(e) => e.is_constant(), + TypedExpression::Struct(e) => e.is_constant(), + TypedExpression::Uint(e) => e.is_constant(), + _ => unreachable!(), + } + } + + fn into_canonical_constant(self) -> Self { + match self { + TypedExpression::FieldElement(e) => e.into_canonical_constant().into(), + TypedExpression::Boolean(e) => e.into_canonical_constant().into(), + TypedExpression::Array(e) => e.into_canonical_constant().into(), + TypedExpression::Struct(e) => e.into_canonical_constant().into(), + TypedExpression::Uint(e) => e.into_canonical_constant().into(), + _ => unreachable!(), + } + } +} diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index 2152e45c..c27e2361 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -10,6 +10,7 @@ use zokrates_core::{ ir::Interpreter, }; use zokrates_field::Bn128Field; +use zokrates_fs_resolver::FileSystemResolver; #[test] fn lt_field() { @@ -74,3 +75,83 @@ fn lt_uint() { ) .is_err()); } + +#[test] +fn unpack256() { + let source = r#" + import "utils/pack/bool/unpack256" + + def main(private field a): + bool[256] bits = unpack256(a) + assert(bits[255]) + return + "# + .to_string(); + + // let's try to prove that the least significant bit of 0 is 1 + // we exploit the fact that the bits of 0 are the bits of p, and p is even + // we want this to still fail + + let stdlib_path = std::fs::canonicalize( + std::env::current_dir() + .unwrap() + .join("../zokrates_stdlib/stdlib"), + ) + .unwrap(); + + let res: CompilationArtifacts = compile( + source, + "./path/to/file".into(), + Some(&FileSystemResolver::with_stdlib_root( + stdlib_path.to_str().unwrap(), + )), + &CompileConfig::default(), + ) + .unwrap(); + + let interpreter = Interpreter::try_out_of_range(); + + assert!(interpreter + .execute(&res.prog(), &[Bn128Field::from(0)]) + .is_err()); +} + +#[test] +fn unpack256_unchecked() { + let source = r#" + import "utils/pack/bool/nonStrictUnpack256" + + def main(private field a): + bool[256] bits = nonStrictUnpack256(a) + assert(bits[255]) + return + "# + .to_string(); + + // let's try to prove that the least significant bit of 0 is 1 + // we exploit the fact that the bits of 0 are the bits of p, and p is odd + // we want this to succeed as the non strict version does not enforce the bits to be in range + + let stdlib_path = std::fs::canonicalize( + std::env::current_dir() + .unwrap() + .join("../zokrates_stdlib/stdlib"), + ) + .unwrap(); + + let res: CompilationArtifacts = compile( + source, + "./path/to/file".into(), + Some(&FileSystemResolver::with_stdlib_root( + stdlib_path.to_str().unwrap(), + )), + &CompileConfig::default(), + ) + .unwrap(); + + let interpreter = Interpreter::try_out_of_range(); + + assert!(interpreter + .execute(&res.prog(), &[Bn128Field::from(0)]) + .is_ok()); +} diff --git a/zokrates_fs_resolver/src/lib.rs b/zokrates_fs_resolver/src/lib.rs index b38a0b17..1289bbe3 100644 --- a/zokrates_fs_resolver/src/lib.rs +++ b/zokrates_fs_resolver/src/lib.rs @@ -26,17 +26,16 @@ impl<'a> Resolver for FileSystemResolver<'a> { ) -> Result<(String, PathBuf), io::Error> { let source = Path::new(&import_location); - if !current_location.is_file() { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("{} was expected to be a file", current_location.display()), - )); - } - // paths starting with `./` or `../` are interpreted relative to the current file // other paths `abc/def` are interpreted relative to the standard library root path let base = match source.components().next() { Some(Component::CurDir) | Some(Component::ParentDir) => { + if !current_location.is_file() { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("{} was expected to be a file", current_location.display()), + )); + } current_location.parent().unwrap().into() } _ => PathBuf::from(self.stdlib_root_path.unwrap_or("")), diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok index 4e48909f..e31dece4 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/nonStrictUnpack256.zok @@ -1,12 +1,12 @@ #pragma curve bn128 -import "./unpack" as unpack +import "./unpack_unchecked" // Unpack a field element as 256 big-endian bits // Note: uniqueness of the output is not guaranteed // For example, `0` can map to `[0, 0, ..., 0]` or to `bits(p)` def main(field i) -> bool[256]: - bool[254] b = unpack::<254>(i) + bool[254] b = unpack_unchecked::<254>(i) return [false, false, ...b] \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok index d5b7a5cd..bc6d22d1 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok @@ -1,12 +1,12 @@ -#pragma curve bn128 - -from "EMBED" import unpack +import "./unpack_unchecked.zok" +from "field" import FIELD_SIZE_IN_BITS +from "EMBED" import bit_array_le // Unpack a field element as N big endian bits def main(field i) -> bool[N]: - - assert(N <= 254) - bool[N] res = unpack(i) + bool[N] res = unpack_unchecked(i) + + assert(if N >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; N - FIELD_SIZE_IN_BITS], ...unpack_unchecked::(-1)]) else true fi) return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok index a24a244b..8f0b1203 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok @@ -1,9 +1,7 @@ -#pragma curve bn128 - import "./unpack" as unpack // Unpack a field element as 128 big-endian bits -// Precondition: the input is smaller or equal to `2**128 - 1` +// If the input is larger than `2**128 - 1`, the output is truncated. def main(field i) -> bool[128]: bool[128] res = unpack::<128>(i) return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok new file mode 100644 index 00000000..4c3e3e56 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack256.zok @@ -0,0 +1,7 @@ +import "./unpack" as unpack + +// Unpack a field element as 256 big-endian bits +// If the input is larger than `2**256 - 1`, the output is truncated. +def main(field i) -> bool[256]: + bool[256] res = unpack::<256>(i) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok new file mode 100644 index 00000000..2b0babbe --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack_unchecked.zok @@ -0,0 +1,9 @@ +from "EMBED" import unpack + +// Unpack a field element as N big endian bits without checking for overflows +// This does *not* guarantee a single output: for example, 0 can be decomposed as 0 or as P and this function does not enforce either +def main(field i) -> bool[N]: + + bool[N] res = unpack(i) + + return res \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json new file mode 100644 index 00000000..5739811a --- /dev/null +++ b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.json @@ -0,0 +1,16 @@ +{ + "entry_point": "./tests/tests/utils/pack/bool/unpack256.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok new file mode 100644 index 00000000..921ccb02 --- /dev/null +++ b/zokrates_stdlib/tests/tests/utils/pack/bool/unpack256.zok @@ -0,0 +1,24 @@ +import "utils/pack/bool/unpack256" as unpack256 + +def testFive() -> bool: + + bool[256] b = unpack256(5) + + assert(b == [...[false; 253], true, false, true]) + + return true + +def testZero() -> bool: + + bool[256] b = unpack256(0) + + assert(b == [false; 256]) + + return true + + def main(): + + assert(testFive()) + assert(testZero()) + + return