From c81cc66fd3322eca63ccb130e554667caf24793e Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 27 Feb 2023 13:35:50 +0100 Subject: [PATCH] simplify inliner --- zokrates_analysis/src/reducer/inline.rs | 101 ++++++++---------------- zokrates_analysis/src/reducer/mod.rs | 47 +++++------ 2 files changed, 53 insertions(+), 95 deletions(-) diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index edaa2803..8d3727d3 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -42,18 +42,8 @@ use zokrates_field::Field; pub enum InlineError<'ast, T> { Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>), - Flat( - FlatEmbed, - Vec, - Vec>, - Type<'ast, T>, - ), - NonConstant( - DeclarationFunctionKey<'ast, T>, - Vec>>, - Vec>, - Type<'ast, T>, - ), + Flat(FlatEmbed, Vec, Type<'ast, T>), + NonConstant, } fn get_canonical_function<'ast, T: Field>( @@ -74,26 +64,26 @@ fn get_canonical_function<'ast, T: Field>( } } -type InlineResult<'ast, T> = Result< - ( - Vec>, - Vec>, - Vec>, - Vec>, - TypedExpression<'ast, T>, - ), - InlineError<'ast, T>, ->; +pub struct InlineValue<'ast, T> { + /// the pre-SSA input variables to assign the arguments to + pub input_variables: Vec>, + /// the pre-SSA statements for this call, including definition of the generic parameters + pub statements: Vec>, + /// the pre-SSA return value for this call + pub return_value: TypedExpression<'ast, T>, +} + +type InlineResult<'ast, T> = Result, InlineError<'ast, T>>; pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( k: &DeclarationFunctionKey<'ast, T>, - generics: Vec>>, - arguments: Vec>, - output: &E::Ty, + generics: &[Option>], + arguments: &[TypedExpression<'ast, T>], + output_ty: &E::Ty, program: &TypedProgram<'ast, T>, ) -> InlineResult<'ast, T> { use zokrates_ast::typed::Typed; - let output_type = output.clone().into_type(); + let output_type = output_ty.clone().into_type(); // we try to get concrete values for explicit generics let generics_values: Vec> = generics @@ -107,32 +97,19 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .transpose() }) .collect::>() - .map_err(|_| { - InlineError::NonConstant( - k.clone(), - generics.clone(), - arguments.clone(), - output_type.clone(), - ) - })?; + .map_err(|_| InlineError::NonConstant)?; // we infer a signature based on inputs and outputs - // this is where we could handle explicit annotations let inferred_signature = Signature::new() - .generics(generics.clone()) + .generics(generics.to_vec().clone()) .inputs(arguments.iter().map(|a| a.get_type()).collect()) .output(output_type.clone()); - // we try to get concrete values for the whole signature. if this fails we should propagate again + // we try to get concrete values for the whole signature let inferred_signature = match ConcreteSignature::try_from(inferred_signature) { Ok(s) => s, Err(_) => { - return Err(InlineError::NonConstant( - k.clone(), - generics, - arguments, - output_type, - )); + return Err(InlineError::NonConstant); } }; @@ -158,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat( e, e.generics::(&assignment), - arguments.clone(), output_type, )), _ => unreachable!(), @@ -166,19 +142,15 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); - let generics_bindings: Vec<_> = assignment - .0 - .into_iter() - .map(|(identifier, value)| { - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::uint( - CoreIdentifier::from(identifier), - UBitwidth::B32, - )), - TypedExpression::from(UExpression::from(value)).into(), - ) - }) - .collect(); + let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| { + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::uint( + CoreIdentifier::from(identifier), + UBitwidth::B32, + )), + TypedExpression::from(UExpression::from(value)).into(), + ) + }); let input_variables: Vec> = f .arguments @@ -188,23 +160,20 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .map(Variable::from) .collect(); - let (statements, mut returns): (Vec<_>, Vec<_>) = f - .statements - .into_iter() + let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings + .chain(f.statements) .partition(|s| !matches!(s, TypedStatement::Return(..))); assert_eq!(returns.len(), 1); - let return_expression = match returns.pop().unwrap() { + let return_value = match returns.pop().unwrap() { TypedStatement::Return(e) => e, _ => unreachable!(), }; - Ok(( + Ok(InlineValue { input_variables, - arguments, - generics_bindings, statements, - return_expression, - )) + return_value, + }) } diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 79749d68..826cd666 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -17,6 +17,7 @@ mod constants_writer; mod inline; mod shallow_ssa; +use self::inline::InlineValue; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; use zokrates_ast::typed::result_folder::*; @@ -120,7 +121,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { // generics are already in ssa form - let generics = e + let generics: Vec<_> = e .generics .into_iter() .map(|g| { @@ -138,7 +139,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { // arguments are already in ssa form - let arguments = e + let arguments: Vec<_> = e .arguments .into_iter() .map(|e| { @@ -153,24 +154,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { self.ssa.push_call_frame(); - let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); + let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program); let res = match res { - Ok((input_variables, arguments, generics_bindings, statements, expression)) => { - let generics_bindings = generics_bindings - .into_iter() - .flat_map(|s| self.ssa.fold_statement(s)) - .map(|s| self.propagator.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .map(|s| self.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten(); - - self.statement_buffer.extend(generics_bindings); - + Ok(InlineValue { + input_variables, + statements, + return_value, + }) => { // the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs let input_bindings: Vec<_> = input_variables .into_iter() @@ -196,27 +187,25 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { self.statement_buffer.extend(statements); - let expression = self.ssa.fold_expression(expression); + let return_value = self.ssa.fold_expression(return_value); - let expression = self.propagator.fold_expression(expression)?; + let return_value = self.propagator.fold_expression(return_value)?; - let expression = self.fold_expression(expression)?; + let return_value = self.fold_expression(return_value)?; Ok(FunctionCallOrExpression::Expression( - E::from(expression).into_inner(), + E::from(return_value).into_inner(), )) } Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!( "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => { - Err(Error::NonConstant(format!( - "Generic parameters must be compile-time constants, found {}", - FunctionCallExpression::<_, E>::new(key, generics, arguments) - ))) - } - Err(InlineError::Flat(embed, generics, arguments, output_type)) => { + Err(InlineError::NonConstant) => Err(Error::NonConstant(format!( + "Generic parameters must be compile-time constants, found {}", + FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments) + ))), + Err(InlineError::Flat(embed, generics, output_type)) => { let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); let var = Variable::immutable(identifier.clone(), output_type);