diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index f4b81d8e..0b834ef8 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>( }) .collect(), )], - typed::TypedStatement::PushCallLog(..) => vec![], - typed::TypedStatement::PopCallLog => vec![], typed::TypedStatement::For(..) => unreachable!(), }; diff --git a/zokrates_analysis/src/lib.rs b/zokrates_analysis/src/lib.rs index c628e728..539fe86c 100644 --- a/zokrates_analysis/src/lib.rs +++ b/zokrates_analysis/src/lib.rs @@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>( let r = reduce_program(r).map_err(Error::from)?; log::trace!("\n{}", r); - log::debug!("Static analyser: Propagate"); - let r = Propagator::propagate(r)?; - log::trace!("\n{}", r); - log::debug!("Static analyser: Concretize structs"); let r = StructConcretizer::concretize(r); log::trace!("\n{}", r); diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 8a2f32e0..7d77e86d 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -308,21 +308,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } }; - // particular case of `lhs = rhs` - if TypedExpression::from(assignee.clone()) == expr { - return Ok(vec![]); - } - if expr.is_constant() { match assignee { TypedAssignee::Identifier(var) => { let expr = expr.into_canonical_constant(); - assert!( - self.constants.insert(var.clone().id, expr).is_none(), - "{}", - var - ); + assert!(self.constants.insert(var.id, expr).is_none()); Ok(vec![]) } @@ -629,8 +620,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { _ => Ok(vec![TypedStatement::Assertion(expr, err)]), } } - s @ TypedStatement::PushCallLog(..) => Ok(vec![s]), - s @ TypedStatement::PopCallLog => Ok(vec![s]), s => fold_statement(self, s), } } @@ -1502,7 +1491,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(5))) ); } @@ -1515,7 +1504,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); } @@ -1528,7 +1517,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(6))) ); } @@ -1541,7 +1530,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1554,15 +1543,14 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(8))) ); } #[test] fn left_shift() { - let mut constants = Constants::new(); - let mut propagator = Propagator::with_constants(&mut constants); + let mut propagator = Propagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::LeftShift( @@ -1607,8 +1595,7 @@ mod tests { #[test] fn right_shift() { - let mut constants = Constants::new(); - let mut propagator = Propagator::with_constants(&mut constants); + let mut propagator = Propagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::RightShift( @@ -1676,7 +1663,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(2))) ); } @@ -1691,7 +1678,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1713,7 +1700,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1735,18 +1722,15 @@ mod tests { BooleanExpression::Not(box BooleanExpression::identifier("a".into())); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_default.clone()), + Propagator::default().fold_boolean_expression(e_default.clone()), Ok(e_default) ); } @@ -1776,23 +1760,19 @@ mod tests { )); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_true), + Propagator::default().fold_boolean_expression(e_constant_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_false), + Propagator::default().fold_boolean_expression(e_constant_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_true), + Propagator::default().fold_boolean_expression(e_identifier_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_unchanged.clone()), + Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), Ok(e_identifier_unchanged) ); } @@ -1800,38 +1780,42 @@ mod tests { #[test] fn bool_eq() { assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(false), BooleanExpression::Value(false) - ))), + )) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(true), BooleanExpression::Value(true) - ))), + )) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(true), BooleanExpression::Value(false) - ))), + )) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(false), BooleanExpression::Value(true) - ))), + )) + ), Ok(BooleanExpression::Value(false)) ); } @@ -1933,33 +1917,27 @@ mod tests { )); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_true), + Propagator::default().fold_boolean_expression(e_constant_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_false), + Propagator::default().fold_boolean_expression(e_constant_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_true), + Propagator::default().fold_boolean_expression(e_identifier_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_unchanged.clone()), + Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), Ok(e_identifier_unchanged) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_non_canonical_true), + Propagator::default().fold_boolean_expression(e_non_canonical_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_non_canonical_false), + Propagator::default().fold_boolean_expression(e_non_canonical_false), Ok(BooleanExpression::Value(false)) ); } @@ -1977,13 +1955,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2001,13 +1977,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2025,13 +1999,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2049,13 +2021,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2065,67 +2035,75 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); } @@ -2135,67 +2113,75 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); } diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 7e2436cd..002228f7 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -135,7 +135,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( } }; - let decl = get_canonical_function(&k, program); + let decl = get_canonical_function(k, program); // get an assignment of generics for this call site let assignment: ConcreteGenericsAssignment<'ast> = k @@ -190,7 +190,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .into_iter() .zip(inferred_signature.inputs.clone()) .map(|(p, t)| ConcreteVariable::new(p.id.id, t, false)) - .map(|v| Variable::from(v)) + .map(Variable::from) .collect(); let (statements, mut returns): (Vec<_>, Vec<_>) = f diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 51dcc3aa..b2d5da5c 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -3,13 +3,14 @@ // - free of function calls (except for low level calls) thanks to inlining // - free of for-loops thanks to unrolling -// The process happens in two steps -// 1. Shallow SSA for the `main` function -// We turn the `main` function into SSA form, but ignoring function calls and for loops -// 2. Unroll and inline -// We go through the shallow-SSA program and -// - unroll loops -// - inline function calls. This includes applying shallow-ssa on the target function +// The process happens in a greedy way, starting from the main function +// For each statement: +// * put it in ssa form +// * propagate it +// * inline it (calling this process recursively) +// * propagate again + +// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach mod constants_reader; mod constants_writer; @@ -21,7 +22,6 @@ use std::collections::HashMap; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; -use zokrates_ast::typed::TypedAssemblyStatement; use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::{ ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, @@ -47,23 +47,6 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20); pub type ConstantDefinitions<'ast, T> = HashMap, TypedExpression<'ast, T>>; -// An SSA version map, giving access to the latest version number for each identifier -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct Versions<'ast> { - map: HashMap, usize>>, -} - -impl<'ast> Versions<'ast> { - fn insert_in_frame( - &mut self, - id: CoreIdentifier<'ast>, - version: usize, - frame: usize, - ) -> Option { - self.map.entry(frame).or_default().insert(id, version) - } -} - #[derive(Debug, PartialEq, Eq)] pub enum Error { Incompatible(String), @@ -125,10 +108,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, p: DeclarationParameter<'ast, T>, ) -> Result, Self::Error> { - // this is only used on the entry point - let id = p.id.id.id.id.clone(); - assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none()); - Ok(p) + Ok(self.ssa.fold_parameter(p)) } fn fold_function_call_expression< @@ -138,34 +118,42 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ty: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, Self::Error> { + // generics are already in ssa form + let generics = e .generics .into_iter() .map(|g| { g.map(|g| { - let g = self.ssa.fold_uint_expression(g); let g = self.propagator.fold_uint_expression(g)?; + let g = self.fold_uint_expression(g)?; - self.fold_uint_expression(g) + self.propagator + .fold_uint_expression(g) + .map_err(Self::Error::from) }) .transpose() }) .collect::>()?; + // arguments are already in ssa form + let arguments = e .arguments .into_iter() .map(|e| { - let e = self.ssa.fold_expression(e); let e = self.propagator.fold_expression(e)?; + let e = self.fold_expression(e)?; - self.fold_expression(e) + self.propagator + .fold_expression(e) + .map_err(Self::Error::from) }) .collect::>()?; self.ssa.push_call_frame(); - let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program); + let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); let res = match res { Ok((input_variables, arguments, generics_bindings, statements, expression)) => { @@ -183,7 +171,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { self.statement_buffer.extend(generics_bindings); - // the lhs is from the inner call frame, the rhs is from the outer one, so only fld the lhs + // the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs let input_bindings: Vec<_> = input_variables .into_iter() .zip(arguments) @@ -274,27 +262,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { }) } - fn fold_assembly_statement( - &mut self, - s: TypedAssemblyStatement<'ast, T>, - ) -> Result>, Self::Error> { - Ok(match s { - TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - self.fold_assignee(a)?, - self.fold_expression(e)?, - )] - } - TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - vec![TypedAssemblyStatement::Constraint( - self.fold_field_expression(lhs)?, - self.fold_field_expression(rhs)?, - metadata, - )] - } - }) - } - fn fold_canonical_constant_identifier( &mut self, _: CanonicalConstantIdentifier<'ast>, @@ -307,28 +274,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { let res = match s { - TypedStatement::Definition(a, rhs) => { - // usually we transform and then propagate - // for definitions we need special treatment: we transform and propagate the rhs (which can contain function calls) - // then we reduce the rhs to remove the function calls - // only then we transform and propagate the assignee - - let rhs = self.ssa.fold_definition_rhs(rhs); - let rhs = self.propagator.fold_definition_rhs(rhs)?; - let rhs = self.fold_definition_rhs(rhs)?; - - let a = self.ssa.fold_assignee(a); - - self.propagator - .fold_statement(TypedStatement::Definition(a, rhs))? - } TypedStatement::For(v, from, to, statements) => { let from = self.ssa.fold_uint_expression(from); let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from @@ -345,40 +300,37 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .collect::, _>>()? .into_iter() .flatten() - .collect()), + .collect::>()), _ => Err(Error::NonConstant(format!( "Expected loop bounds to be constant, found {}..{}", from, to ))), }? } - TypedStatement::Return(e) => { - let e = self.ssa.fold_expression(e); - let e = self.propagator.fold_expression(e)?; - vec![TypedStatement::Return(self.fold_expression(e)?)] - } - TypedStatement::Assertion(e, error) => { - let e = self.ssa.fold_boolean_expression(e); - let e = self.propagator.fold_boolean_expression(e)?; + s => { + let statements = self.ssa.fold_statement(s); - vec![TypedStatement::Assertion( - self.fold_boolean_expression(e)?, - error, - )] + let statements = statements + .into_iter() + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + let statements = statements + .map(|s| fold_statement(self, s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + let statements = statements + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + statements.collect() } - s => self - .ssa - .fold_statement(s) - .into_iter() - .map(|s| self.propagator.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .map(|s| fold_statement(self, s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), }; Ok(self.statement_buffer.drain(..).chain(res).collect()) @@ -394,12 +346,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let array = self.ssa.fold_array_expression(array); let array = self.propagator.fold_array_expression(array)?; let array = self.fold_array_expression(array)?; + let array = self.propagator.fold_array_expression(array)?; + let from = self.ssa.fold_uint_expression(from); let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { @@ -503,14 +460,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // a_1 = a_0; - // # PUSH CALL to foo - // a_3 := a_1; // input binding - // #RETURN_AT_INDEX_0_0 := a_3; - // # POP CALL - // a_2 = #RETURN_AT_INDEX_0_0; - // return a_2; + // def main(field a_f0_v0) -> field { + // a_f0_v1 = a_f0_v0; // redef + // a_f1_v0 = a_f0_v1; // input binding + // a_f0_v2 = a_f1_v0; // output binding + // return a_f0_v2; // } let foo: TypedFunction = TypedFunction { @@ -606,30 +560,13 @@ mod tests { Variable::field_element(Identifier::from("a").version(1)).into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo").signature( - DeclarationSignature::new() - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - ), - GGenericsAssignment::default(), - ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(3)).into(), + Variable::field_element(Identifier::from("a").in_frame(1)).into(), FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), - TypedStatement::definition( - Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0)) - .into(), - FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::field_element(Identifier::from("a").version(2)).into(), - FieldElementExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .into(), + FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(), ), TypedStatement::Return( FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), @@ -678,14 +615,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // field[1] b_0 = [42]; - // # PUSH CALL to foo::<1> - // a_0 = b_0; - // #RETURN_AT_INDEX_0_0 := a_0; - // # POP CALL - // b_1 = #RETURN_AT_INDEX_0_0; - // return a_2 + b_1[0]; + // def main(field a_f0_v0) -> field { + // field[1] b_f0_v0 = [a_f0_v0]; + // a_f1_v0 = b_f0_v0; + // b_f0_v1 = a_f1_v0; + // return a_f0_v0 + b_f0_v1[0]; // } let foo_signature = DeclarationSignature::new() @@ -812,42 +746,19 @@ mod tests { .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), TypedStatement::definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(0), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpression::identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), - ArrayExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .annotate(Type::FieldElement, 1u32) - .into(), + ArrayExpression::identifier(Identifier::from("a").in_frame(1)) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Return( (FieldElementExpression::identifier("a".into()) @@ -902,14 +813,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // field[1] b_0 = [42]; - // # PUSH CALL to foo::<1> - // a_0 = b_0; - // #RETURN_AT_INDEX_0_0 := a_0; - // # POP CALL - // b_1 = #RETURN_AT_INDEX_0_0; - // return a_2 + b_1[0]; + // def main(field a) -> field { + // field[1] b = [a]; + // a_f1 = b; + // b_1 = a_f1; + // return a + b_1[0]; // } let foo_signature = DeclarationSignature::new() @@ -1040,47 +948,25 @@ mod tests { TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), + vec![FieldElementExpression::identifier(Identifier::from("a")).into()] + .into(), ) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), TypedStatement::definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(0), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpression::identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), - ArrayExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .annotate(Type::FieldElement, 1u32) - .into(), + ArrayExpression::identifier(Identifier::from("a").in_frame(1)) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Return( (FieldElementExpression::identifier("a".into()) @@ -1306,33 +1192,11 @@ mod tests { let expected_main = TypedFunction { arguments: vec![], - statements: vec![ - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "bar") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 2)] - .into_iter() - .collect(), - ), - ), - TypedStatement::PopCallLog, - TypedStatement::PopCallLog, - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) - .annotate(TupleType::new(vec![])) - .into(), - ), - ], + statements: vec![TypedStatement::Return( + TupleExpressionInner::Value(vec![]) + .annotate(TupleType::new(vec![])) + .into(), + )], signature: DeclarationSignature::new(), }; diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index 0d450950..aaec4fd5 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -1,7 +1,6 @@ -// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can -// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility. -// Function calls are also left unvisited -// Saving the indices is not required for function calls, as they cannot modify their environment +// The SSA transformation +// * introduces new versions if and only if we are assigning to an identifier +// * does not visit the statements of loops // Example: // def main(field a) -> field { @@ -19,21 +18,34 @@ // u32 n_0 = 42; // a_1 = a_0 + 1; // field b_0 = foo(a_1); // we keep the function call as is -// # versions: {n: 0, a: 1, b: 0} // for u32 i_0 in 0..n_0 { // // we keep the loop body as is // } // return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop // } +use std::collections::HashMap; + use zokrates_ast::typed::folder::*; -use zokrates_ast::typed::identifier::FrameIdentifier; use zokrates_ast::typed::*; use zokrates_field::Field; -use super::Versions; +// An SSA version map, giving access to the latest version number for each identifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Versions<'ast> { + map: HashMap, usize>>, +} + +impl<'ast> Default for Versions<'ast> { + fn default() -> Self { + // create a call frame at index 0 + Self { + map: vec![(0, Default::default())].into_iter().collect(), + } + } +} #[derive(Debug, Default)] pub struct ShallowTransformer<'ast> { @@ -45,14 +57,16 @@ pub struct ShallowTransformer<'ast> { impl<'ast> ShallowTransformer<'ast> { pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { - let frame_versions = self.versions.map.entry(self.frame()).or_default(); + let frame = self.frame(); + + let frame_versions = self.versions.map.entry(frame).or_default(); let version = frame_versions .entry(c_id.clone()) .and_modify(|e| *e += 1) // if it was already declared, we increment .or_default(); // otherwise, we start from this version - Identifier::from(c_id).version(*version) + Identifier::from(c_id.in_frame(frame)).version(*version) } fn issue_next_ssa_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { @@ -81,43 +95,69 @@ impl<'ast> ShallowTransformer<'ast> { self.versions.map.remove(&frame); } - pub fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + // fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers, + // but this should not be applied recursively to complex assignees + fn fold_assignee_no_ssa_increase( + &mut self, + a: TypedAssignee<'ast, T>, + ) -> TypedAssignee<'ast, T> { match a { - TypedAssignee::Identifier(v) => { - let v = self.issue_next_ssa_variable(v); - TypedAssignee::Identifier(self.fold_variable(v)) + TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), + TypedAssignee::Select(box a, box index) => TypedAssignee::Select( + box self.fold_assignee_no_ssa_increase(a), + box self.fold_uint_expression(index), + ), + TypedAssignee::Member(box s, m) => { + TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m) + } + TypedAssignee::Element(box s, index) => { + TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index) } - a => fold_assignee(self, a), } } } impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { - fn fold_assembly_statement( + fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { + for g in &f.signature.generics { + let generic_parameter = match g.as_ref().unwrap() { + DeclarationConstant::Generic(g) => g, + _ => unreachable!(), + }; + let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone())); + } + + fold_function(self, f) + } + + fn fold_parameter( &mut self, - s: TypedAssemblyStatement<'ast, T>, - ) -> Vec> { - match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = self.fold_expression(e); - let a = self.fold_assignee(a); - vec![TypedAssemblyStatement::Assignment(a, e)] - } - s => fold_assembly_statement(self, s), + p: DeclarationParameter<'ast, T>, + ) -> DeclarationParameter<'ast, T> { + DeclarationParameter { + id: DeclarationVariable { + id: self.issue_next_identifier(p.id.id.id.id), + ..p.id + }, + ..p } } + + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + match a { + // create a new version for assignments to identifiers + TypedAssignee::Identifier(v) => { + let v = self.issue_next_ssa_variable(v); + TypedAssignee::Identifier(self.fold_variable(v)) + } + // otherwise, simply replace by the current version + a => self.fold_assignee_no_ssa_increase(a), + } + } + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { match s { - TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => { - let e = self.fold_expression(e); - let a = self.fold_assignee(a); - vec![TypedStatement::definition(a, e)] - } - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - let embed_call = self.fold_embed_call(embed_call); - let assignee = self.fold_assignee(assignee); - vec![TypedStatement::embed_call_definition(assignee, embed_call)] - } + // only fold bounds of for loop statements TypedStatement::For(v, from, to, stats) => { let from = self.fold_uint_expression(from); let to = self.fold_uint_expression(to); @@ -127,6 +167,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { } } + // retrieve the latest version fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { let version = self .versions @@ -137,13 +178,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { .cloned() .unwrap_or(0); - let id = FrameIdentifier { - frame: self.frame(), - ..n.id - }; - - let res = Identifier { version, id }; - res + n.in_frame(self.frame()).version(version) } } @@ -156,36 +191,57 @@ mod tests { use super::*; #[test] - fn detect_non_constant_bound() { - let loops: Vec> = vec![TypedStatement::For( - Variable::new("i", Type::Uint(UBitwidth::B32), false), - UExpression::identifier("i".into()).annotate(UBitwidth::B32), - 2u32.into(), - vec![], - )]; + fn ignore_loop_content() { + // field foo = 0 + // u32 i = 4; + // for u32 i in i..2 { + // foo = 5; + // } - let statements = loops; + // should be left unchanged, as we do not visit the loop content nor the index variable let f = TypedFunction { arguments: vec![], - signature: DeclarationSignature::new(), - statements, + statements: vec![ + TypedStatement::definition( + TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))), + FieldElementExpression::Number(Bn128Field::from(4)).into(), + ), + TypedStatement::definition( + TypedAssignee::Identifier(Variable::uint( + Identifier::from("i"), + UBitwidth::B32, + )), + UExpression::from(0u32).into(), + ), + TypedStatement::For( + Variable::new("i", Type::Uint(UBitwidth::B32), false), + UExpression::identifier("i".into()).annotate(UBitwidth::B32), + 2u32.into(), + vec![TypedStatement::definition( + TypedAssignee::Identifier(Variable::field_element(Identifier::from( + "foo", + ))), + FieldElementExpression::Number(Bn128Field::from(5)).into(), + )], + ), + TypedStatement::Return( + TupleExpressionInner::Value(vec![]) + .annotate(TupleType::new(vec![])) + .into(), + ), + ], + signature: DeclarationSignature::default(), }; - match ShallowTransformer::transform( - f, - &ConcreteGenericsAssignment::default(), - &mut Versions::default(), - ) { - Output::Incomplete(..) => {} - _ => unreachable!(), - }; + let mut ssa = ShallowTransformer::default(); + + assert_eq!(ssa.fold_function(f.clone()), f); } #[test] fn definition() { - // field a - // a = 5 + // field a = 5 // a = 6 // a @@ -194,9 +250,7 @@ mod tests { // a_1 = 6 // a_1 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -236,17 +290,14 @@ mod tests { #[test] fn incremental_definition() { - // field a - // a = 5 + // field a = 5 // a = a + 1 // should be turned into // a_0 = 5 // a_1 = a_0 + 1 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -295,9 +346,7 @@ mod tests { // a_0 = 2 // a_1 = foo(a_0) - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -356,9 +405,7 @@ mod tests { // a_0 = [1, 1] // a_0[1] = 2 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), @@ -413,9 +460,7 @@ mod tests { // a_0 = [[0, 1], [2, 3]] // a_0 = [4, 5] - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32)); @@ -510,10 +555,10 @@ mod tests { mod for_loop { use super::*; - use zokrates_ast::typed::types::GGenericsAssignment; + #[test] fn treat_loop() { - // def main(field a) -> field { + // def main(field a) -> field { // u32 n = 42; // n = n; // a = a; @@ -528,24 +573,21 @@ mod tests { // return a; // } - // When called with K := 1, expected: + // expected: // def main(field a_0) -> field { - // u32 K = 1; // u32 n_0 = 42; // n_1 = n_0; // a_1 = a_0; - // # versions: {n: 1, a: 1, K: 0} + // for u32 i_0 in n_1..n_1*n_1 { + // a_0 = a_0; + // } + // a_2 = a_1; // for u32 i_0 in n_1..n_1*n_1 { // a_0 = a_0; // } // a_3 = a_2; - // # versions: {n: 2, a: 3, K: 1} - // for u32 i_0 in n_2..n_2*n_2 { - // a_0 = a_0; - // } - // a_5 = a_4; - // return a_5; - // } # versions: {n: 3, a: 5, K: 2} + // return a_3; + // } let f: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], @@ -595,32 +637,15 @@ mod tests { TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() - .generics(vec![Some( - GenericIdentifier::with_name("K").with_index(0).into(), - )]) .inputs(vec![DeclarationType::FieldElement]) .output(DeclarationType::FieldElement), }; - let mut versions = Versions::default(); - - let ssa = ShallowTransformer::transform( - f, - &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - &mut versions, - ); + let mut ssa = ShallowTransformer::default(); let expected = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ - TypedStatement::definition( - Variable::uint("K", UBitwidth::B32).into(), - TypedExpression::Uint(1u32.into()), - ), TypedStatement::definition( Variable::uint("n", UBitwidth::B32).into(), TypedExpression::Uint(42u32.into()), @@ -649,16 +674,16 @@ mod tests { )], ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(3)).into(), - FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), + Variable::field_element(Identifier::from("a").version(2)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), TypedStatement::For( Variable::uint("i", UBitwidth::B32), - UExpression::identifier(Identifier::from("n").version(2)) + UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), - UExpression::identifier(Identifier::from("n").version(2)) + UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32) - * UExpression::identifier(Identifier::from("n").version(2)) + * UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), vec![TypedStatement::definition( Variable::field_element("a").into(), @@ -666,47 +691,35 @@ mod tests { )], ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(5)).into(), - FieldElementExpression::identifier(Identifier::from("a").version(4)).into(), + Variable::field_element(Identifier::from("a").version(3)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), TypedStatement::Return( - FieldElementExpression::identifier(Identifier::from("a").version(5)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], signature: DeclarationSignature::new() - .generics(vec![Some( - GenericIdentifier::with_name("K").with_index(0).into(), - )]) .inputs(vec![DeclarationType::FieldElement]) .output(DeclarationType::FieldElement), }; + let res = ssa.fold_function(f); + assert_eq!( - versions, - vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)] - .into_iter() - .collect::() + ssa.versions.map, + vec![( + 0, + vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect() + )] + .into_iter() + .collect() ); - let expected = Output::Incomplete( - expected, - vec![ - vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)] - .into_iter() - .collect::(), - vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)] - .into_iter() - .collect::(), - ], - ); - - assert_eq!(ssa, expected); + assert_eq!(res, expected); } } mod shadowing { - use zokrates_ast::typed::types::GGenericsAssignment; - use super::*; #[test] @@ -717,11 +730,11 @@ mod tests { // return; // } - // should become + // should become (only the field variable is affected as shadowing is taken care of in semantics already) - // def main(field a_0) { - // field a_1 = 42; - // bool a_2 = true; + // def main(field a_s0_v0) { + // field a_s0_v1 = 42; + // bool a_s1_v0 = true // return; // } @@ -733,7 +746,11 @@ mod tests { TypedExpression::Uint(42u32.into()), ), TypedStatement::definition( - Variable::boolean("a").into(), + Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow( + "a".into(), + 1, + ))) + .into(), BooleanExpression::Value(true).into(), ), TypedStatement::Return( @@ -742,9 +759,7 @@ mod tests { .into(), ), ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]), + signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]), }; let expected: TypedFunction = TypedFunction { @@ -755,7 +770,11 @@ mod tests { TypedExpression::Uint(42u32.into()), ), TypedStatement::definition( - Variable::boolean(Identifier::from("a").version(2)).into(), + Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow( + "a".into(), + 1, + ))) + .into(), BooleanExpression::Value(true).into(), ), TypedStatement::Return( @@ -764,121 +783,17 @@ mod tests { .into(), ), ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]), + signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]), }; - let mut versions = Versions::default(); + let ssa = ShallowTransformer::default().fold_function(f); - let ssa = - ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions); - - assert_eq!(ssa, Output::Complete(expected)); - } - - #[test] - fn next_scope() { - // def main(field a) { - // for u32 i in 0..1 { - // a = a + 1 - // field a = 42 - // } - // return a - // } - - // should become - - // def main(field a_0) { - // # versions: {a: 0} - // for u32 i in 0..1 { - // a_0 = a_0 - // field a_0 = 42 - // } - // return a_1 - // } - - let f: TypedFunction = TypedFunction { - arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![ - TypedStatement::For( - Variable::uint("i", UBitwidth::B32), - 0u32.into(), - 1u32.into(), - vec![ - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::identifier("a".into()).into(), - ), - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::Number(42usize.into()).into(), - ), - ], - ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![FieldElementExpression::identifier( - "a".into(), - ) - .into()]) - .annotate(TupleType::new(vec![Type::FieldElement])) - .into(), - ), - ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - }; - - let expected: TypedFunction = TypedFunction { - arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![ - TypedStatement::For( - Variable::uint("i", UBitwidth::B32), - 0u32.into(), - 1u32.into(), - vec![ - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::identifier(Identifier::from("a")).into(), - ), - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::Number(42usize.into()).into(), - ), - ], - ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![FieldElementExpression::identifier( - Identifier::from("a").version(1), - ) - .into()]) - .annotate(TupleType::new(vec![Type::FieldElement])) - .into(), - ), - ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - }; - - let mut versions = Versions::default(); - - let ssa = - ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions); - - assert_eq!( - ssa, - Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()]) - ); + assert_eq!(ssa, expected); } } mod function_call { use super::*; - use zokrates_ast::typed::types::GGenericsAssignment; // test that function calls are left in #[test] fn treat_calls() { @@ -892,17 +807,12 @@ mod tests { // return a; // } - // When called with K := 1, expected: // def main(field a_0) -> field { - // K = 1; - // u32 n_0 = 42; - // n_1 = n_0; // a_1 = a_0; - // a_2 = foo::(a_1); - // n_2 = n_1; - // a_3 = a_2 * foo::(a_2); + // a_2 = foo::<42>(a_1); + // a_3 = a_2 * foo::<42>(a_2); // return a_3; - // } # versions: {n: 2, a: 3} + // } let f: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], @@ -960,25 +870,9 @@ mod tests { .output(DeclarationType::FieldElement), }; - let mut versions = Versions::default(); - - let ssa = ShallowTransformer::transform( - f, - &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - &mut versions, - ); - let expected = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ - TypedStatement::definition( - Variable::uint("K", UBitwidth::B32).into(), - TypedExpression::Uint(1u32.into()), - ), TypedStatement::definition( Variable::uint("n", UBitwidth::B32).into(), TypedExpression::Uint(42u32.into()), @@ -1042,14 +936,23 @@ mod tests { .output(DeclarationType::FieldElement), }; + let mut ssa = ShallowTransformer::default(); + + let res = ssa.fold_function(f); + assert_eq!( - versions, - vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] - .into_iter() - .collect::() + ssa.versions.map, + vec![( + 0, + vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] + .into_iter() + .collect() + )] + .into_iter() + .collect() ); - assert_eq!(ssa, Output::Incomplete(expected, vec![],)); + assert_eq!(res, expected); } } } diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 1180874f..989819fc 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -531,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( ) -> Vec> { match s { TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - f.fold_assignee(a), - f.fold_expression(e), - )] + let e = f.fold_expression(e); + vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)] } TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { vec![TypedAssemblyStatement::Constraint( @@ -552,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( ) -> Vec> { let res = match s { TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)), - TypedStatement::Definition(a, e) => { - TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e)) + TypedStatement::Definition(a, rhs) => { + let rhs = f.fold_definition_rhs(rhs); + TypedStatement::Definition(f.fold_assignee(a), rhs) } TypedStatement::Assertion(e, error) => { TypedStatement::Assertion(f.fold_boolean_expression(e), error) @@ -576,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_assembly_statement(s)) .collect(), ), - s => s, }; vec![res] } diff --git a/zokrates_ast/src/typed/identifier.rs b/zokrates_ast/src/typed/identifier.rs index 91aaaa30..d23a9b3b 100644 --- a/zokrates_ast/src/typed/identifier.rs +++ b/zokrates_ast/src/typed/identifier.rs @@ -24,6 +24,21 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { } } +impl<'ast> FrameIdentifier<'ast> { + pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { + FrameIdentifier { frame, ..self } + } +} + +impl<'ast> Identifier<'ast> { + pub fn in_frame(self, frame: usize) -> Identifier<'ast> { + Identifier { + id: self.id.in_frame(frame), + ..self + } + } +} + impl<'ast> CoreIdentifier<'ast> { pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { FrameIdentifier { id: self, frame } diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index df8b0f63..1d4ad0fc 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -27,7 +27,7 @@ pub use self::types::{ UBitwidth, }; use self::types::{ConcreteArrayType, ConcreteStructType}; -use crate::typed::types::{ConcreteGenericsAssignment, IntoType}; +use crate::typed::types::IntoType; pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::marker::PhantomData; @@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { writeln!(f)?; - let mut tab = 0; - for s in &self.statements { - if let TypedStatement::PopCallLog = s { - tab -= 1; - }; - - s.fmt_indented(f, 1 + tab)?; - writeln!(f)?; - - if let TypedStatement::PushCallLog(..) = s { - tab += 1; - }; + writeln!(f, "{}", s)?; } writeln!(f, "}}")?; @@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> { Vec>, ), Log(FormatString, Vec>), - // Aux - PushCallLog( - DeclarationFunctionKey<'ast, T>, - ConcreteGenericsAssignment<'ast>, - ), - PopCallLog, Assembly(Vec>), } @@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> { } } -impl<'ast, T: fmt::Display> TypedStatement<'ast, T> { - fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result { - match self { - TypedStatement::For(variable, from, to, statements) => { - write!(f, "{}", "\t".repeat(depth))?; - writeln!(f, "for {} in {}..{} {{", variable, from, to)?; - for s in statements { - s.fmt_indented(f, depth + 1)?; - writeln!(f)?; - } - write!(f, "{}}}", "\t".repeat(depth)) - } - TypedStatement::Assembly(statements) => { - write!(f, "{}", "\t".repeat(depth))?; - writeln!(f, "asm {{")?; - for s in statements { - writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; - } - write!(f, "{}}}", "\t".repeat(depth)) - } - s => write!(f, "{}{}", "\t".repeat(depth), s), - } - } -} - impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { .collect::>() .join(", ") ), - TypedStatement::PushCallLog(ref key, ref generics) => write!( - f, - "// PUSH CALL TO {}/{}::<{}>", - key.module.display(), - key.id, - generics, - ), - TypedStatement::PopCallLog => write!(f, "// POP CALL",), TypedStatement::Assembly(ref statements) => { writeln!(f, "asm {{")?; for s in statements { diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index 25c84c29..8ed91131 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -532,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result>, F::Error> { Ok(match s { TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - f.fold_assignee(a)?, - f.fold_expression(e)?, - )] + let e = f.fold_expression(e)?; + vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)] } TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { vec![TypedAssemblyStatement::Constraint( @@ -554,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( let res = match s { TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?), TypedStatement::Definition(a, e) => { - TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?) + let rhs = f.fold_definition_rhs(e)?; + TypedStatement::Definition(f.fold_assignee(a)?, rhs) } TypedStatement::Assertion(e, error) => { TypedStatement::Assertion(f.fold_boolean_expression(e)?, error) @@ -586,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( .flatten() .collect(), ), - s => s, }; Ok(vec![res]) } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index 60d3792f..f2bf23d5 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -240,9 +240,9 @@ impl<'ast, T> From for UExpression<'ast, T> { impl<'ast, T: Field> From> for UExpression<'ast, T> { fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { - DeclarationConstant::Generic(_g) => { - // UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32) - unreachable!() + DeclarationConstant::Generic(g) => { + UExpression::identifier(Identifier::from(CoreIdentifier::from(g))) + .annotate(UBitwidth::B32) } DeclarationConstant::Concrete(v) => { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) diff --git a/zokrates_core_test/tests/tests/call_ssa.json b/zokrates_core_test/tests/tests/call_ssa.json new file mode 100644 index 00000000..43675482 --- /dev/null +++ b/zokrates_core_test/tests/tests/call_ssa.json @@ -0,0 +1,16 @@ +{ + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": "4" + } + } + } + ] + } + \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/call_ssa.zok b/zokrates_core_test/tests/tests/call_ssa.zok new file mode 100644 index 00000000..dad61d19 --- /dev/null +++ b/zokrates_core_test/tests/tests/call_ssa.zok @@ -0,0 +1,11 @@ +// main should be x -> x + 4 + +def foo(field mut a) -> field { + a = a + 1; + return a + 1; +} + +def main(field mut a) -> field { + a = foo(a + 1); + return a + 1; +} \ No newline at end of file