From 33c8fba1e1bce86ef85c1f64b9678f75d0140d60 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 5 Aug 2021 13:36:48 +0200 Subject: [PATCH] add comments, use iterators --- zokrates_core/src/flatten/mod.rs | 8 +++++++- zokrates_core/src/ir/interpreter.rs | 27 ++++++++++++--------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 1d230eae..e096f9b2 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1161,13 +1161,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] } crate::embed::FlatEmbed::BitArrayLe => { + // get the length of the bit arrays let len = generics[0]; + // split the arguments into the two bit arrays of size `len` let (expressions, constants) = ( param_expressions[..len as usize].to_vec(), param_expressions[len as usize..].to_vec(), ); + // define variables for the variable bits let variables: Vec<_> = expressions .into_iter() .map(|e| { @@ -1178,6 +1181,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { }) .collect(); + // get constants for the constant bits let constants: Vec<_> = constants .into_iter() .map(|e| { @@ -1185,11 +1189,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { .get_field_unchecked() }) .map(|e| match e { - FlatExpression::Number(n) => n == T::one(), + FlatExpression::Number(n) if n == T::one() => true, + FlatExpression::Number(n) if n == T::zero() => false, _ => unreachable!(), }) .collect(); + // get the list of conditions which must hold iff the `<=` relation holds let conditions = self.constant_le_check(statements_flattened, &variables, &constants); diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index c6b75dae..b68574f3 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -160,22 +160,19 @@ impl Interpreter { let bit_width = bit_width - padding; - let mut num = inputs[0].clone(); - let mut res = vec![]; + let num = inputs[0].clone(); - for _ in 0..padding { - res.push(T::zero()); - } - - for i in (0..bit_width).rev() { - if T::from(2).pow(i) <= num { - num = num - T::from(2).pow(i); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - res + (0..padding) + .map(|_| T::zero()) + .chain((0..bit_width).rev().scan(num, |state, i| { + if T::from(2).pow(i) <= *state { + *state = (*state).clone() - T::from(2).pow(i); + Some(T::one()) + } else { + Some(T::zero()) + } + })) + .collect() } Solver::Xor => { let x = inputs[0].clone();