simplify inliner
This commit is contained in:
parent
35c1a9686b
commit
c81cc66fd3
2 changed files with 53 additions and 95 deletions
|
@ -42,18 +42,8 @@ use zokrates_field::Field;
|
|||
|
||||
pub enum InlineError<'ast, T> {
|
||||
Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>),
|
||||
Flat(
|
||||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Type<'ast, T>,
|
||||
),
|
||||
NonConstant(
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Type<'ast, T>,
|
||||
),
|
||||
Flat(FlatEmbed, Vec<u32>, Type<'ast, T>),
|
||||
NonConstant,
|
||||
}
|
||||
|
||||
fn get_canonical_function<'ast, T: Field>(
|
||||
|
@ -74,26 +64,26 @@ fn get_canonical_function<'ast, T: Field>(
|
|||
}
|
||||
}
|
||||
|
||||
type InlineResult<'ast, T> = Result<
|
||||
(
|
||||
Vec<Variable<'ast, T>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<TypedStatement<'ast, T>>,
|
||||
Vec<TypedStatement<'ast, T>>,
|
||||
TypedExpression<'ast, T>,
|
||||
),
|
||||
InlineError<'ast, T>,
|
||||
>;
|
||||
pub struct InlineValue<'ast, T> {
|
||||
/// the pre-SSA input variables to assign the arguments to
|
||||
pub input_variables: Vec<Variable<'ast, T>>,
|
||||
/// the pre-SSA statements for this call, including definition of the generic parameters
|
||||
pub statements: Vec<TypedStatement<'ast, T>>,
|
||||
/// the pre-SSA return value for this call
|
||||
pub return_value: TypedExpression<'ast, T>,
|
||||
}
|
||||
|
||||
type InlineResult<'ast, T> = Result<InlineValue<'ast, T>, InlineError<'ast, T>>;
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
||||
k: &DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output: &E::Ty,
|
||||
generics: &[Option<UExpression<'ast, T>>],
|
||||
arguments: &[TypedExpression<'ast, T>],
|
||||
output_ty: &E::Ty,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
) -> InlineResult<'ast, T> {
|
||||
use zokrates_ast::typed::Typed;
|
||||
let output_type = output.clone().into_type();
|
||||
let output_type = output_ty.clone().into_type();
|
||||
|
||||
// we try to get concrete values for explicit generics
|
||||
let generics_values: Vec<Option<u32>> = generics
|
||||
|
@ -107,32 +97,19 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
.map_err(|_| {
|
||||
InlineError::NonConstant(
|
||||
k.clone(),
|
||||
generics.clone(),
|
||||
arguments.clone(),
|
||||
output_type.clone(),
|
||||
)
|
||||
})?;
|
||||
.map_err(|_| InlineError::NonConstant)?;
|
||||
|
||||
// we infer a signature based on inputs and outputs
|
||||
// this is where we could handle explicit annotations
|
||||
let inferred_signature = Signature::new()
|
||||
.generics(generics.clone())
|
||||
.generics(generics.to_vec().clone())
|
||||
.inputs(arguments.iter().map(|a| a.get_type()).collect())
|
||||
.output(output_type.clone());
|
||||
|
||||
// we try to get concrete values for the whole signature. if this fails we should propagate again
|
||||
// we try to get concrete values for the whole signature
|
||||
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Err(InlineError::NonConstant(
|
||||
k.clone(),
|
||||
generics,
|
||||
arguments,
|
||||
output_type,
|
||||
));
|
||||
return Err(InlineError::NonConstant);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -158,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
|
||||
e,
|
||||
e.generics::<T>(&assignment),
|
||||
arguments.clone(),
|
||||
output_type,
|
||||
)),
|
||||
_ => unreachable!(),
|
||||
|
@ -166,19 +142,15 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
|
||||
assert_eq!(f.arguments.len(), arguments.len());
|
||||
|
||||
let generics_bindings: Vec<_> = assignment
|
||||
.0
|
||||
.into_iter()
|
||||
.map(|(identifier, value)| {
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::uint(
|
||||
CoreIdentifier::from(identifier),
|
||||
UBitwidth::B32,
|
||||
)),
|
||||
TypedExpression::from(UExpression::from(value)).into(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| {
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::uint(
|
||||
CoreIdentifier::from(identifier),
|
||||
UBitwidth::B32,
|
||||
)),
|
||||
TypedExpression::from(UExpression::from(value)).into(),
|
||||
)
|
||||
});
|
||||
|
||||
let input_variables: Vec<Variable<'ast, T>> = f
|
||||
.arguments
|
||||
|
@ -188,23 +160,20 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
.map(Variable::from)
|
||||
.collect();
|
||||
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = f
|
||||
.statements
|
||||
.into_iter()
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings
|
||||
.chain(f.statements)
|
||||
.partition(|s| !matches!(s, TypedStatement::Return(..)));
|
||||
|
||||
assert_eq!(returns.len(), 1);
|
||||
|
||||
let return_expression = match returns.pop().unwrap() {
|
||||
let return_value = match returns.pop().unwrap() {
|
||||
TypedStatement::Return(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok((
|
||||
Ok(InlineValue {
|
||||
input_variables,
|
||||
arguments,
|
||||
generics_bindings,
|
||||
statements,
|
||||
return_expression,
|
||||
))
|
||||
return_value,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ mod constants_writer;
|
|||
mod inline;
|
||||
mod shallow_ssa;
|
||||
|
||||
use self::inline::InlineValue;
|
||||
use self::inline::{inline_call, InlineError};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_ast::typed::result_folder::*;
|
||||
|
@ -120,7 +121,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
|
||||
// generics are already in ssa form
|
||||
|
||||
let generics = e
|
||||
let generics: Vec<_> = e
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| {
|
||||
|
@ -138,7 +139,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
|
||||
// arguments are already in ssa form
|
||||
|
||||
let arguments = e
|
||||
let arguments: Vec<_> = e
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
|
@ -153,24 +154,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
|
||||
self.ssa.push_call_frame();
|
||||
|
||||
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program);
|
||||
let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program);
|
||||
|
||||
let res = match res {
|
||||
Ok((input_variables, arguments, generics_bindings, statements, expression)) => {
|
||||
let generics_bindings = generics_bindings
|
||||
.into_iter()
|
||||
.flat_map(|s| self.ssa.fold_statement(s))
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|s| self.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
self.statement_buffer.extend(generics_bindings);
|
||||
|
||||
Ok(InlineValue {
|
||||
input_variables,
|
||||
statements,
|
||||
return_value,
|
||||
}) => {
|
||||
// the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs
|
||||
let input_bindings: Vec<_> = input_variables
|
||||
.into_iter()
|
||||
|
@ -196,27 +187,25 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
|
||||
self.statement_buffer.extend(statements);
|
||||
|
||||
let expression = self.ssa.fold_expression(expression);
|
||||
let return_value = self.ssa.fold_expression(return_value);
|
||||
|
||||
let expression = self.propagator.fold_expression(expression)?;
|
||||
let return_value = self.propagator.fold_expression(return_value)?;
|
||||
|
||||
let expression = self.fold_expression(expression)?;
|
||||
let return_value = self.fold_expression(return_value)?;
|
||||
|
||||
Ok(FunctionCallOrExpression::Expression(
|
||||
E::from(expression).into_inner(),
|
||||
E::from(return_value).into_inner(),
|
||||
))
|
||||
}
|
||||
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
|
||||
"Call site `{}` incompatible with declaration `{}`",
|
||||
conc, decl
|
||||
))),
|
||||
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
|
||||
Err(Error::NonConstant(format!(
|
||||
"Generic parameters must be compile-time constants, found {}",
|
||||
FunctionCallExpression::<_, E>::new(key, generics, arguments)
|
||||
)))
|
||||
}
|
||||
Err(InlineError::Flat(embed, generics, arguments, output_type)) => {
|
||||
Err(InlineError::NonConstant) => Err(Error::NonConstant(format!(
|
||||
"Generic parameters must be compile-time constants, found {}",
|
||||
FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments)
|
||||
))),
|
||||
Err(InlineError::Flat(embed, generics, output_type)) => {
|
||||
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0));
|
||||
|
||||
let var = Variable::immutable(identifier.clone(), output_type);
|
||||
|
|
Loading…
Reference in a new issue