1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

simplify inliner

This commit is contained in:
schaeff 2023-02-27 13:35:50 +01:00
parent 35c1a9686b
commit c81cc66fd3
2 changed files with 53 additions and 95 deletions

View file

@ -42,18 +42,8 @@ use zokrates_field::Field;
pub enum InlineError<'ast, T> {
Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>),
Flat(
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Type<'ast, T>,
),
NonConstant(
DeclarationFunctionKey<'ast, T>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Type<'ast, T>,
),
Flat(FlatEmbed, Vec<u32>, Type<'ast, T>),
NonConstant,
}
fn get_canonical_function<'ast, T: Field>(
@ -74,26 +64,26 @@ fn get_canonical_function<'ast, T: Field>(
}
}
type InlineResult<'ast, T> = Result<
(
Vec<Variable<'ast, T>>,
Vec<TypedExpression<'ast, T>>,
Vec<TypedStatement<'ast, T>>,
Vec<TypedStatement<'ast, T>>,
TypedExpression<'ast, T>,
),
InlineError<'ast, T>,
>;
pub struct InlineValue<'ast, T> {
/// the pre-SSA input variables to assign the arguments to
pub input_variables: Vec<Variable<'ast, T>>,
/// the pre-SSA statements for this call, including definition of the generic parameters
pub statements: Vec<TypedStatement<'ast, T>>,
/// the pre-SSA return value for this call
pub return_value: TypedExpression<'ast, T>,
}
type InlineResult<'ast, T> = Result<InlineValue<'ast, T>, InlineError<'ast, T>>;
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
k: &DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output: &E::Ty,
generics: &[Option<UExpression<'ast, T>>],
arguments: &[TypedExpression<'ast, T>],
output_ty: &E::Ty,
program: &TypedProgram<'ast, T>,
) -> InlineResult<'ast, T> {
use zokrates_ast::typed::Typed;
let output_type = output.clone().into_type();
let output_type = output_ty.clone().into_type();
// we try to get concrete values for explicit generics
let generics_values: Vec<Option<u32>> = generics
@ -107,32 +97,19 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
.transpose()
})
.collect::<Result<_, _>>()
.map_err(|_| {
InlineError::NonConstant(
k.clone(),
generics.clone(),
arguments.clone(),
output_type.clone(),
)
})?;
.map_err(|_| InlineError::NonConstant)?;
// we infer a signature based on inputs and outputs
// this is where we could handle explicit annotations
let inferred_signature = Signature::new()
.generics(generics.clone())
.generics(generics.to_vec().clone())
.inputs(arguments.iter().map(|a| a.get_type()).collect())
.output(output_type.clone());
// we try to get concrete values for the whole signature. if this fails we should propagate again
// we try to get concrete values for the whole signature
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
Ok(s) => s,
Err(_) => {
return Err(InlineError::NonConstant(
k.clone(),
generics,
arguments,
output_type,
));
return Err(InlineError::NonConstant);
}
};
@ -158,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
e,
e.generics::<T>(&assignment),
arguments.clone(),
output_type,
)),
_ => unreachable!(),
@ -166,19 +142,15 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
assert_eq!(f.arguments.len(), arguments.len());
let generics_bindings: Vec<_> = assignment
.0
.into_iter()
.map(|(identifier, value)| {
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::uint(
CoreIdentifier::from(identifier),
UBitwidth::B32,
)),
TypedExpression::from(UExpression::from(value)).into(),
)
})
.collect();
let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| {
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::uint(
CoreIdentifier::from(identifier),
UBitwidth::B32,
)),
TypedExpression::from(UExpression::from(value)).into(),
)
});
let input_variables: Vec<Variable<'ast, T>> = f
.arguments
@ -188,23 +160,20 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
.map(Variable::from)
.collect();
let (statements, mut returns): (Vec<_>, Vec<_>) = f
.statements
.into_iter()
let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings
.chain(f.statements)
.partition(|s| !matches!(s, TypedStatement::Return(..)));
assert_eq!(returns.len(), 1);
let return_expression = match returns.pop().unwrap() {
let return_value = match returns.pop().unwrap() {
TypedStatement::Return(e) => e,
_ => unreachable!(),
};
Ok((
Ok(InlineValue {
input_variables,
arguments,
generics_bindings,
statements,
return_expression,
))
return_value,
})
}

View file

@ -17,6 +17,7 @@ mod constants_writer;
mod inline;
mod shallow_ssa;
use self::inline::InlineValue;
use self::inline::{inline_call, InlineError};
use std::collections::HashMap;
use zokrates_ast::typed::result_folder::*;
@ -120,7 +121,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
// generics are already in ssa form
let generics = e
let generics: Vec<_> = e
.generics
.into_iter()
.map(|g| {
@ -138,7 +139,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
// arguments are already in ssa form
let arguments = e
let arguments: Vec<_> = e
.arguments
.into_iter()
.map(|e| {
@ -153,24 +154,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
self.ssa.push_call_frame();
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program);
let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program);
let res = match res {
Ok((input_variables, arguments, generics_bindings, statements, expression)) => {
let generics_bindings = generics_bindings
.into_iter()
.flat_map(|s| self.ssa.fold_statement(s))
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.map(|s| self.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
self.statement_buffer.extend(generics_bindings);
Ok(InlineValue {
input_variables,
statements,
return_value,
}) => {
// the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs
let input_bindings: Vec<_> = input_variables
.into_iter()
@ -196,27 +187,25 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
self.statement_buffer.extend(statements);
let expression = self.ssa.fold_expression(expression);
let return_value = self.ssa.fold_expression(return_value);
let expression = self.propagator.fold_expression(expression)?;
let return_value = self.propagator.fold_expression(return_value)?;
let expression = self.fold_expression(expression)?;
let return_value = self.fold_expression(return_value)?;
Ok(FunctionCallOrExpression::Expression(
E::from(expression).into_inner(),
E::from(return_value).into_inner(),
))
}
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
"Call site `{}` incompatible with declaration `{}`",
conc, decl
))),
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
Err(Error::NonConstant(format!(
"Generic parameters must be compile-time constants, found {}",
FunctionCallExpression::<_, E>::new(key, generics, arguments)
)))
}
Err(InlineError::Flat(embed, generics, arguments, output_type)) => {
Err(InlineError::NonConstant) => Err(Error::NonConstant(format!(
"Generic parameters must be compile-time constants, found {}",
FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments)
))),
Err(InlineError::Flat(embed, generics, output_type)) => {
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0));
let var = Variable::immutable(identifier.clone(), output_type);