From b31d15d7e4c7e022aa56f5717baf119101c51c72 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 20 Jan 2020 16:35:22 +0100 Subject: [PATCH] fix error message --- zokrates_core/src/absy/mod.rs | 7 +- zokrates_core/src/semantics.rs | 22 +- zokrates_core/src/static_analysis/mod.rs | 8 +- .../src/static_analysis/propagate_unroll.rs | 198 ++++++++++++++++++ .../src/static_analysis/propagation.rs | 43 +++- zokrates_core/src/static_analysis/unroll.rs | 132 ++++++++---- zokrates_core/src/typed_absy/mod.rs | 42 +++- .../tests/tests/nested_loop.json | 14 +- .../tests/tests/nested_loop.zok | 17 +- 9 files changed, 415 insertions(+), 68 deletions(-) create mode 100644 zokrates_core/src/static_analysis/propagate_unroll.rs diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 439fc5a8..3ea31136 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -282,7 +282,12 @@ pub enum Statement<'ast, T: Field> { Declaration(VariableNode<'ast>), Definition(AssigneeNode<'ast, T>, ExpressionNode<'ast, T>), Condition(ExpressionNode<'ast, T>, ExpressionNode<'ast, T>), - For(VariableNode<'ast>, ExpressionNode<'ast, T>, ExpressionNode<'ast, T>, Vec>), + For( + VariableNode<'ast>, + ExpressionNode<'ast, T>, + ExpressionNode<'ast, T>, + Vec>, + ), MultipleDefinition(Vec>, ExpressionNode<'ast, T>), } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index dcb0a8b6..67dd6062 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -845,8 +845,12 @@ impl<'ast> Checker<'ast> { let var = self.check_variable(var, module_id, types).unwrap(); - let from = self.check_expression(from, module_id, &types).map_err(|e| vec![e])?; - let to = self.check_expression(to, module_id, &types).map_err(|e| vec![e])?; + let from = self + .check_expression(from, module_id, &types) + .map_err(|e| vec![e])?; + let to = self + .check_expression(to, module_id, &types) + .map_err(|e| vec![e])?; let from = match from { TypedExpression::FieldElement(e) => Ok(e), @@ -864,7 +868,7 @@ impl<'ast> Checker<'ast> { e => Err(Error { pos: Some(pos), message: format!( - "Expected lower loop bound to be of type field, found {}", + "Expected higher loop bound to be of type field, found {}", e.get_type() ), }) @@ -2668,8 +2672,8 @@ mod tests { let foo_statements = vec![ Statement::For( absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(), - FieldPrime::from(0), - FieldPrime::from(10), + Expression::FieldConstant(FieldPrime::from(0)).mock(), + Expression::FieldConstant(FieldPrime::from(10)).mock(), vec![], ) .mock(), @@ -2726,8 +2730,8 @@ mod tests { let foo_statements = vec![Statement::For( absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(), - FieldPrime::from(0), - FieldPrime::from(10), + Expression::FieldConstant(FieldPrime::from(0)).mock(), + Expression::FieldConstant(FieldPrime::from(10)).mock(), for_statements, ) .mock()]; @@ -2742,8 +2746,8 @@ mod tests { let foo_statements_checked = vec![TypedStatement::For( typed_absy::Variable::field_element("i".into()), - FieldPrime::from(0), - FieldPrime::from(10), + FieldElementExpression::Number(FieldPrime::from(0)), + FieldElementExpression::Number(FieldPrime::from(10)), for_statements_checked, )]; diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index e2c2e10d..c54c9ce5 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -7,13 +7,14 @@ mod constrain_inputs; mod flat_propagation; mod inline; +mod propagate_unroll; mod propagation; mod unroll; use self::constrain_inputs::InputConstrainer; use self::inline::Inliner; +use self::propagate_unroll::PropagatedUnroller; use self::propagation::Propagator; -use self::unroll::Unroller; use crate::flat_absy::FlatProg; use crate::typed_absy::TypedProgram; use zokrates_field::field::Field; @@ -24,14 +25,15 @@ pub trait Analyse { impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> { fn analyse(self) -> Self { - // unroll - let r = Unroller::unroll(self); + // propagated unrolling + let r = PropagatedUnroller::unroll(self); // inline let r = Inliner::inline(r); // propagate let r = Propagator::propagate(r); // constrain inputs let r = InputConstrainer::constrain(r); + r } } diff --git a/zokrates_core/src/static_analysis/propagate_unroll.rs b/zokrates_core/src/static_analysis/propagate_unroll.rs new file mode 100644 index 00000000..2231bb90 --- /dev/null +++ b/zokrates_core/src/static_analysis/propagate_unroll.rs @@ -0,0 +1,198 @@ +//! Module containing iterative unrolling in order to unroll nested loops with variable bounds +//! +//! For example: +//! ```zokrates +//! for field i in 0..5 do +//! for field j in i..5 +//! // +//! endfor +//! enfor +//! ``` +//! +//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop +//! +//! This module does exactly that: +//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant +//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet* +//! - unroll again, this time the 5 inner loops all have constant bounds +//! +//! In the case that a loop bound cannot be reduced to a constant, we rely on a maximum number of passes after which we conclude that a bound is not constant +//! This sets a hard limit on the number of loops with variable bounds in the program +//! +//! @file propagation.rs +//! @author Thibaut Schaeffer +//! @date 2018 + +use static_analysis::propagation::Propagator; +use static_analysis::unroll::Unroller; +use typed_absy::TypedProgram; +use zokrates_field::field::Field; + +pub struct PropagatedUnroller; + +const MAX_DEPTH: usize = 100; + +impl PropagatedUnroller { + pub fn unroll<'ast, T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + let mut p = p; + let mut count = 0; + + // unroll a first time, retrieving whether the unroll is complete + let unrolled = Unroller::unroll(p); + let mut complete = unrolled.1; + p = unrolled.0; + + loop { + + // conditions to exit the loop + if complete { + break; + } + if count > MAX_DEPTH { + panic!("Loop unrolling failed. Most likely this happened because a loop bound is not constant") + } + + // propagate + p = Propagator::propagate_verbose(p); + + // unroll + let unrolled = Unroller::unroll(p); + complete = unrolled.1; + p = unrolled.0; + + count = count + 1; + } + + p + } +} + +#[cfg(test)] +mod tests { + use super::*; + use typed_absy::types::{FunctionKey, Signature}; + use typed_absy::*; + use zokrates_field::field::FieldPrime; + + #[test] + fn for_loop() { + // for field i in 0..2 + // for field j in i..2 + // field foo = i + j + + // should be unrolled to + // i_0 = 0 + // j_0 = 0 + // foo_0 = i_0 + j_0 + // j_1 = 1 + // foo_1 = i_0 + j_1 + // i_1 = 1 + // j_2 = 1 + // foo_2 = i_1 + j_2 + + let s = TypedStatement::For( + Variable::field_element("i".into()), + FieldElementExpression::Number(FieldPrime::from(0)), + FieldElementExpression::Number(FieldPrime::from(2)), + vec![TypedStatement::For( + Variable::field_element("j".into()), + FieldElementExpression::Identifier("i".into()), + FieldElementExpression::Number(FieldPrime::from(2)), + vec![ + TypedStatement::Declaration(Variable::field_element("foo".into())), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("foo".into())), + FieldElementExpression::Add( + box FieldElementExpression::Identifier("i".into()), + box FieldElementExpression::Identifier("j".into()), + ) + .into(), + ), + ], + )], + ); + + let expected_statements = vec![ + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("i").version(0), + )), + FieldElementExpression::Number(FieldPrime::from(0)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("j").version(0), + )), + FieldElementExpression::Number(FieldPrime::from(0)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("foo").version(0), + )), + FieldElementExpression::Number(FieldPrime::from(0)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("j").version(1), + )), + FieldElementExpression::Number(FieldPrime::from(1)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("foo").version(1), + )), + FieldElementExpression::Number(FieldPrime::from(1)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("i").version(1), + )), + FieldElementExpression::Number(FieldPrime::from(1)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("j").version(2), + )), + FieldElementExpression::Number(FieldPrime::from(1)).into(), + ), + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("foo").version(2), + )), + FieldElementExpression::Number(FieldPrime::from(2)).into(), + ), + ]; + + let p = TypedProgram { + modules: vec![( + "main".to_string(), + TypedModule { + functions: vec![( + FunctionKey::with_id("main"), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + signature: Signature::new(), + statements: vec![s], + }), + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + main: "main".to_string(), + }; + + let statements = match PropagatedUnroller::unroll(p).modules["main"].functions + [&FunctionKey::with_id("main")] + .clone() + { + TypedFunctionSymbol::Here(f) => f.statements, + _ => unreachable!(), + }; + + assert_eq!(statements, expected_statements); + } + +} diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 20092c67..62c3ea55 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -12,19 +12,36 @@ use typed_absy::types::{StructMember, Type}; use zokrates_field::field::Field; pub struct Propagator<'ast, T: Field> { + // constants keeps track of constant expressions + // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is constants: HashMap, TypedExpression<'ast, T>>, + // the verbose mode doesn't remove statements which assign constants to variables + // it required when using propagation in combination with unrolling + verbose: bool, } impl<'ast, T: Field> Propagator<'ast, T> { + fn verbose() -> Self { + Propagator { + constants: HashMap::new(), + verbose: true, + } + } + fn new() -> Self { Propagator { constants: HashMap::new(), + verbose: false, } } pub fn propagate(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { Propagator::new().fold_program(p) } + + pub fn propagate_verbose(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + Propagator::verbose().fold_program(p) + } } fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool { @@ -50,6 +67,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + let res = match s { TypedStatement::Declaration(v) => Some(TypedStatement::Declaration(v)), TypedStatement::Return(expressions) => Some(TypedStatement::Return( @@ -63,8 +81,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { let expr = self.fold_expression(expr); if is_constant(&expr) { - self.constants.insert(TypedAssignee::Identifier(var), expr); - None + self.constants + .insert(TypedAssignee::Identifier(var.clone()), expr.clone()); + match self.verbose { + true => Some(TypedStatement::Definition( + TypedAssignee::Identifier(var), + expr, + )), + false => None, + } } else { Some(TypedStatement::Definition( TypedAssignee::Identifier(var), @@ -86,9 +111,17 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { self.fold_expression(e2), )) } - // we unrolled for loops in the previous step - TypedStatement::For(..) => { - unreachable!("for loop is unexpected, it should have been unrolled") + // only loops with variable bounds are expected here + // we stop propagation here as constants maybe be modified inside the loop body + // which we do not visit + TypedStatement::For(v, from, to, statements) => { + let from = self.fold_field_expression(from); + let to = self.fold_field_expression(to); + + // invalidate the constants map as any constant could be modified inside the loop body, which we don't visit + self.constants.clear(); + + Some(TypedStatement::For(v, from, to, statements)) } TypedStatement::MultipleDefinition(variables, expression_list) => { let expression_list = self.fold_expression_list(expression_list); diff --git a/zokrates_core/src/static_analysis/unroll.rs b/zokrates_core/src/static_analysis/unroll.rs index a9bb9440..a1397ab0 100644 --- a/zokrates_core/src/static_analysis/unroll.rs +++ b/zokrates_core/src/static_analysis/unroll.rs @@ -12,18 +12,23 @@ use std::collections::HashSet; use zokrates_field::field::Field; pub struct Unroller<'ast> { - substitution: HashMap, usize>, + // version index for any variable name + substitution: HashMap<&'ast str, usize>, + // whether all statements could be unrolled so far. Loops with variable bounds cannot. + complete: bool, } impl<'ast> Unroller<'ast> { fn new() -> Self { Unroller { substitution: HashMap::new(), + complete: true, } } fn issue_next_ssa_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> { - let res = match self.substitution.get(&v.id) { + + let res = match self.substitution.get(&v.id.id) { Some(i) => Variable { id: Identifier { id: v.id.id, @@ -34,15 +39,18 @@ impl<'ast> Unroller<'ast> { }, None => Variable { ..v.clone() }, }; + self.substitution - .entry(v.id) + .entry(v.id.id) .and_modify(|e| *e += 1) .or_insert(0); res } - pub fn unroll(p: TypedProgram) -> TypedProgram { - Unroller::new().fold_program(p) + pub fn unroll(p: TypedProgram) -> (TypedProgram, bool) { + let mut unroller = Unroller::new(); + let p = unroller.fold_program(p); + (p, unroller.complete) } fn choose_many( @@ -349,6 +357,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { }; let base = self.fold_expression(base); + let indices = indices .into_iter() .map(|a| match a { @@ -378,34 +387,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { vec![TypedStatement::MultipleDefinition(variables, exprs)] } TypedStatement::For(v, from, to, stats) => { - let mut values: Vec = vec![]; - let mut current = from; - while current < to { - values.push(current.clone()); - current = T::one() + ¤t; + let from = self.fold_field_expression(from); + let to = self.fold_field_expression(to); + + match (from, to) { + (FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => { + let mut values: Vec = vec![]; + let mut current = from; + while current < to { + values.push(current.clone()); + current = T::one() + ¤t; + } + + let res = values + .into_iter() + .map(|index| { + vec![ + vec![ + TypedStatement::Declaration(v.clone()), + TypedStatement::Definition( + TypedAssignee::Identifier(v.clone()), + FieldElementExpression::Number(index).into(), + ), + ], + stats.clone(), + ] + .into_iter() + .flat_map(|x| x) + }) + .flat_map(|x| x) + .flat_map(|x| self.fold_statement(x)) + .collect(); + + res + } + (from, to) => { + self.complete = false; + vec![TypedStatement::For(v, from, to, stats)] + } } - - let res = values - .into_iter() - .map(|index| { - vec![ - vec![ - TypedStatement::Declaration(v.clone()), - TypedStatement::Definition( - TypedAssignee::Identifier(v.clone()), - FieldElementExpression::Number(index).into(), - ), - ], - stats.clone(), - ] - .into_iter() - .flat_map(|x| x) - }) - .flat_map(|x| x) - .flat_map(|x| self.fold_statement(x)) - .collect(); - - res } s => fold_statement(self, s), } @@ -414,7 +434,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { self.substitution = HashMap::new(); for arg in &f.arguments { - self.substitution.insert(arg.id.id.clone(), 0); + self.substitution.insert(arg.id.id.id.clone(), 0); } fold_function(self, f) @@ -422,7 +442,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { Identifier { - version: self.substitution.get(&n).unwrap_or(&0).clone(), + version: self.substitution.get(&n.id).unwrap_or(&0).clone(), ..n } } @@ -693,8 +713,8 @@ mod tests { let s = TypedStatement::For( Variable::field_element("i".into()), - FieldPrime::from(2), - FieldPrime::from(5), + FieldElementExpression::Number(FieldPrime::from(2)), + FieldElementExpression::Number(FieldPrime::from(5)), vec![ TypedStatement::Declaration(Variable::field_element("foo".into())), TypedStatement::Definition( @@ -748,6 +768,46 @@ mod tests { assert_eq!(u.fold_statement(s), expected); } + #[test] + fn idempotence() { + // an already unrolled program should not be modified by unrolling again + + // a = 5 + // a_1 = 6 + // a_2 = 7 + + // should be turned into + // a = 5 + // a_1 = 6 + // a_2 = 7 + + let mut u = Unroller::new(); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(0), + )), + FieldElementExpression::Number(FieldPrime::from(5)).into(), + ); + assert_eq!(u.fold_statement(s.clone()), vec![s]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(1), + )), + FieldElementExpression::Number(FieldPrime::from(6)).into(), + ); + assert_eq!(u.fold_statement(s.clone()), vec![s]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(2), + )), + FieldElementExpression::Number(FieldPrime::from(7)).into(), + ); + assert_eq!(u.fold_statement(s.clone()), vec![s]); + } + #[test] fn definition() { // field a diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index add8f74f..dec15823 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -229,7 +229,7 @@ impl<'ast, T: Field> fmt::Display for TypedFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "({}) -> ({}):\n{}", + "({}) -> ({}):", self.arguments .iter() .map(|x| format!("{}", x)) @@ -241,12 +241,16 @@ impl<'ast, T: Field> fmt::Display for TypedFunction<'ast, T> { .map(|x| format!("{}", x)) .collect::>() .join(", "), - self.statements - .iter() - .map(|x| format!("\t{}", x)) - .collect::>() - .join("\n") - ) + )?; + + writeln!(f, "")?; + + for s in &self.statements { + s.fmt_indented(f, 1)?; + writeln!(f, "")?; + } + + Ok(()) } } @@ -326,7 +330,12 @@ pub enum TypedStatement<'ast, T: Field> { Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), Declaration(Variable<'ast>), Condition(TypedExpression<'ast, T>, TypedExpression<'ast, T>), - For(Variable<'ast>, FieldElementExpression<'ast, T>, FieldElementExpression<'ast, T>, Vec>), + For( + Variable<'ast>, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Vec>, + ), MultipleDefinition(Vec>, TypedExpressionList<'ast, T>), } @@ -364,6 +373,23 @@ impl<'ast, T: Field> fmt::Debug for TypedStatement<'ast, T> { } } +impl<'ast, T: Field> TypedStatement<'ast, T> { + fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result { + match self { + TypedStatement::For(variable, from, to, statements) => { + write!(f, "{}", "\t".repeat(depth))?; + writeln!(f, "for {} in {}..{} do", variable, from, to)?; + for s in statements { + s.fmt_indented(f, depth + 1)?; + writeln!(f, "")?; + } + writeln!(f, "{}endfor", "\t".repeat(depth)) + } + s => write!(f, "{}{}", "\t".repeat(depth), s), + } + } +} + impl<'ast, T: Field> fmt::Display for TypedStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { diff --git a/zokrates_core_test/tests/tests/nested_loop.json b/zokrates_core_test/tests/tests/nested_loop.json index c40a0575..82b19745 100644 --- a/zokrates_core_test/tests/tests/nested_loop.json +++ b/zokrates_core_test/tests/tests/nested_loop.json @@ -3,11 +3,21 @@ "tests": [ { "input": { - "values": ["1", "2", "0", "1"] + "values": ["1", "2", "3", "4"] }, "output": { "Ok": { - "values": [] + "values": ["4838400", "10"] + } + } + }, + { + "input": { + "values": ["0", "1", "2", "3"] + }, + "output": { + "Ok": { + "values": ["0", "10"] } } } diff --git a/zokrates_core_test/tests/tests/nested_loop.zok b/zokrates_core_test/tests/tests/nested_loop.zok index e94b3bad..581571b3 100644 --- a/zokrates_core_test/tests/tests/nested_loop.zok +++ b/zokrates_core_test/tests/tests/nested_loop.zok @@ -1,9 +1,18 @@ -def main(field[4] values) -> (): - field acc = 1 +def main(field[4] values) -> (field, field): + field res0 = 1 + field res1 = 0 + + field counter = 0 + for field i in 0..4 do for field j in i..4 do - acc = acc * (values[j] - values[i]) + counter = counter + 1 + res0 = res0 * (values[i] + values[j]) endfor endfor - return \ No newline at end of file + for field i in 0..counter do + res1 = res1 + 1 + endfor + + return res0, res1 \ No newline at end of file