Merge pull request #894 from Zokrates/function-call-expression
implement function call expression
This commit is contained in:
commit
498f31e003
11 changed files with 635 additions and 681 deletions
|
@ -1548,7 +1548,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type())
|
||||
}])?;
|
||||
|
||||
let call = TypedExpressionList::FunctionCall(f.clone(), generics_checked.unwrap_or_else(|| vec![None; f.signature.generics.len()]), arguments_checked, assignees.iter().map(|a| a.get_type()).collect());
|
||||
let call = TypedExpressionList::function_call(f.clone(), generics_checked.unwrap_or_else(|| vec![None; f.signature.generics.len()]), arguments_checked).annotate(Types { inner: assignees.iter().map(|a| a.get_type()).collect()});
|
||||
|
||||
Ok(TypedStatement::MultipleDefinition(assignees, call))
|
||||
},
|
||||
|
@ -2102,7 +2102,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
|
||||
|
||||
let output_types = signature.get_output_types(
|
||||
let mut output_types = signature.get_output_types(
|
||||
generics_checked.clone(),
|
||||
arguments_checked.iter().map(|a| a.get_type()).collect()
|
||||
).map_err(|e| ErrorInner {
|
||||
|
@ -2113,63 +2113,41 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
),
|
||||
})?;
|
||||
|
||||
let function_key = DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
};
|
||||
|
||||
// the return count has to be 1
|
||||
match output_types.len() {
|
||||
1 => match &output_types[0] {
|
||||
1 => match output_types.pop().unwrap() {
|
||||
Type::Int => unreachable!(),
|
||||
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
},
|
||||
Type::FieldElement => Ok(FieldElementExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
)
|
||||
.into()),
|
||||
Type::Boolean => Ok(BooleanExpression::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
},
|
||||
).into()),
|
||||
Type::Boolean => Ok(BooleanExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
)
|
||||
.into()),
|
||||
Type::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
},
|
||||
).into()),
|
||||
Type::Uint(bitwidth) => Ok(UExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(*bitwidth)
|
||||
.into()),
|
||||
Type::Struct(members) => Ok(StructExpressionInner::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
},
|
||||
).annotate(bitwidth).into()),
|
||||
Type::Struct(struct_ty) => Ok(StructExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(members.clone())
|
||||
.into()),
|
||||
Type::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.to_path_buf(),
|
||||
id: f.id,
|
||||
signature: signature.clone(),
|
||||
},
|
||||
).annotate(struct_ty).into()),
|
||||
Type::Array(array_ty) => Ok(ArrayExpression::function_call(
|
||||
function_key,
|
||||
generics_checked,
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(*array_type.ty.clone(), array_type.size.clone())
|
||||
.into()),
|
||||
).annotate(*array_ty.ty, array_ty.size).into()),
|
||||
},
|
||||
n => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
|
@ -4778,7 +4756,7 @@ mod tests {
|
|||
typed_absy::Variable::field_element("a").into(),
|
||||
typed_absy::Variable::field_element("b").into(),
|
||||
],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location((*MODULE_ID).clone(), "foo").signature(
|
||||
DeclarationSignature::new().outputs(vec![
|
||||
DeclarationType::FieldElement,
|
||||
|
@ -4787,8 +4765,8 @@ mod tests {
|
|||
),
|
||||
vec![],
|
||||
vec![],
|
||||
vec![Type::FieldElement, Type::FieldElement],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::FieldElement, Type::FieldElement])),
|
||||
),
|
||||
TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier("a".into()),
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::typed_absy;
|
||||
use crate::typed_absy::types::UBitwidth;
|
||||
use crate::typed_absy::{self, Expr};
|
||||
use crate::zir;
|
||||
use std::marker::PhantomData;
|
||||
use zokrates_field::Field;
|
||||
|
@ -208,8 +208,8 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
es: typed_absy::TypedExpressionList<'ast, T>,
|
||||
) -> zir::ZirExpressionList<'ast, T> {
|
||||
match es {
|
||||
typed_absy::TypedExpressionList::EmbedCall(embed, generics, arguments, _) => {
|
||||
match es.into_inner() {
|
||||
typed_absy::TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
|
||||
zir::ZirExpressionList::EmbedCall(
|
||||
embed,
|
||||
generics,
|
||||
|
|
|
@ -367,18 +367,23 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
.collect::<Result<_, _>>()?;
|
||||
let expression_list = self.fold_expression_list(expression_list)?;
|
||||
|
||||
let statements = match expression_list {
|
||||
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
|
||||
let types = Types {
|
||||
inner: expression_list
|
||||
.types
|
||||
.clone()
|
||||
.inner
|
||||
.into_iter()
|
||||
.map(|t| self.fold_type(t))
|
||||
.collect::<Result<_, _>>()?,
|
||||
};
|
||||
|
||||
let statements = match expression_list.into_inner() {
|
||||
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
|
||||
let arguments: Vec<_> = arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let types = types
|
||||
.into_iter()
|
||||
.map(|t| self.fold_type(t))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
fn process_u_from_bits<'ast, T: Field>(
|
||||
variables: Vec<TypedAssignee<'ast, T>>,
|
||||
mut arguments: Vec<TypedExpression<'ast, T>>,
|
||||
|
@ -602,16 +607,18 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
TypedStatement::Definition(v.clone().into(), c),
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![assignee],
|
||||
TypedExpressionList::EmbedCall(
|
||||
embed, generics, arguments, types,
|
||||
),
|
||||
TypedExpressionListInner::EmbedCall(
|
||||
embed, generics, arguments,
|
||||
)
|
||||
.annotate(types),
|
||||
),
|
||||
],
|
||||
None => vec![TypedStatement::MultipleDefinition(
|
||||
vec![assignee],
|
||||
TypedExpressionList::EmbedCall(
|
||||
embed, generics, arguments, types,
|
||||
),
|
||||
TypedExpressionListInner::EmbedCall(
|
||||
embed, generics, arguments,
|
||||
)
|
||||
.annotate(types),
|
||||
)],
|
||||
}
|
||||
}
|
||||
|
@ -623,9 +630,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
|
||||
let def = TypedStatement::MultipleDefinition(
|
||||
assignees.clone(),
|
||||
TypedExpressionList::EmbedCall(
|
||||
embed, generics, arguments, types,
|
||||
),
|
||||
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
|
||||
.annotate(types),
|
||||
);
|
||||
|
||||
let invalidations = assignees.iter().flat_map(|assignee| {
|
||||
|
@ -645,27 +651,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
TypedExpressionList::FunctionCall(key, generics, arguments, types) => {
|
||||
let generics = generics
|
||||
TypedExpressionListInner::FunctionCall(function_call) => {
|
||||
let generics = function_call
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let arguments: Vec<_> = arguments
|
||||
let arguments: Vec<_> = function_call
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let types = types
|
||||
.into_iter()
|
||||
.map(|t| self.fold_type(t))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
// invalidate the cache for the return assignees as this call mutates them
|
||||
|
||||
let def = TypedStatement::MultipleDefinition(
|
||||
assignees.clone(),
|
||||
TypedExpressionList::FunctionCall(key, generics, arguments, types),
|
||||
TypedExpressionList::function_call(
|
||||
function_call.function_key,
|
||||
generics,
|
||||
arguments,
|
||||
)
|
||||
.annotate(types),
|
||||
);
|
||||
|
||||
let invalidations = assignees.iter().flat_map(|assignee| {
|
||||
|
|
|
@ -29,14 +29,14 @@ use crate::embed::FlatEmbed;
|
|||
use crate::static_analysis::reducer::Output;
|
||||
use crate::static_analysis::reducer::ShallowTransformer;
|
||||
use crate::static_analysis::reducer::Versions;
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
use crate::typed_absy::types::{ConcreteGenericsAssignment, IntoTypes};
|
||||
use crate::typed_absy::CoreIdentifier;
|
||||
use crate::typed_absy::Identifier;
|
||||
use crate::typed_absy::TypedAssignee;
|
||||
use crate::typed_absy::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature,
|
||||
Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression,
|
||||
UExpressionInner, Variable,
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
|
||||
Signature, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, Types,
|
||||
UExpression, UExpressionInner, Variable,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -46,13 +46,13 @@ pub enum InlineError<'ast, T> {
|
|||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
Types<'ast, T>,
|
||||
),
|
||||
NonConstant(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
Types<'ast, T>,
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -79,11 +79,11 @@ type InlineResult<'ast, T> = Result<
|
|||
InlineError<'ast, T>,
|
||||
>;
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field>(
|
||||
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
||||
k: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_types: Vec<Type<'ast, T>>,
|
||||
output: &E::Ty,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
) -> InlineResult<'ast, T> {
|
||||
|
@ -91,6 +91,8 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
|
||||
use crate::typed_absy::Typed;
|
||||
|
||||
let output_types = output.clone().into_types();
|
||||
|
||||
// we try to get concrete values for explicit generics
|
||||
let generics_values: Vec<Option<u32>> = generics
|
||||
.iter()
|
||||
|
@ -117,7 +119,7 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
let inferred_signature = Signature::new()
|
||||
.generics(generics.clone())
|
||||
.inputs(arguments.iter().map(|a| a.get_type()).collect())
|
||||
.outputs(output_types.clone());
|
||||
.outputs(output_types.clone().inner);
|
||||
|
||||
// we try to get concrete values for the whole signature. if this fails we should propagate again
|
||||
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
|
||||
|
|
|
@ -22,15 +22,12 @@ use crate::typed_absy::Folder;
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, ArrayType, BlockExpression, BooleanExpression,
|
||||
CoreIdentifier, DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier,
|
||||
StructExpression, StructExpressionInner, StructType, Type, TypedExpression,
|
||||
TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram,
|
||||
TypedStatement, UBitwidth, UExpression, UExpressionInner, Variable,
|
||||
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, TypedExpression,
|
||||
TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule,
|
||||
TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable,
|
||||
};
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use self::shallow_ssa::ShallowTransformer;
|
||||
|
@ -195,32 +192,35 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
complete: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_function_call<E>(
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Result<E, Error>
|
||||
where
|
||||
E: FunctionCall<'ast, T> + From<TypedExpression<'ast, T>> + std::fmt::Debug,
|
||||
{
|
||||
let generics = generics
|
||||
ty: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
|
||||
let generics = e
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let arguments = arguments
|
||||
let arguments = e
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let res = inline_call(
|
||||
key.clone(),
|
||||
let res = inline_call::<_, E>(
|
||||
e.function_key.clone(),
|
||||
generics,
|
||||
arguments,
|
||||
vec![output_type.clone()],
|
||||
ty,
|
||||
&self.program,
|
||||
&mut self.versions,
|
||||
);
|
||||
|
@ -229,30 +229,31 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
Ok(Output::Complete((statements, mut expressions))) => {
|
||||
self.complete &= true;
|
||||
self.statement_buffer.extend(statements);
|
||||
Ok(expressions.pop().unwrap().try_into().unwrap())
|
||||
Ok(FunctionCallOrExpression::Expression(
|
||||
E::from(expressions.pop().unwrap()).into_inner(),
|
||||
))
|
||||
}
|
||||
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
|
||||
self.complete = false;
|
||||
self.statement_buffer.extend(statements);
|
||||
self.for_loop_versions_after.extend(delta_for_loop_versions);
|
||||
Ok(expressions[0].clone().try_into().unwrap())
|
||||
Ok(FunctionCallOrExpression::Expression(
|
||||
E::from(expressions[0].clone()).into_inner(),
|
||||
))
|
||||
}
|
||||
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
|
||||
"Call site `{}` incompatible with declaration `{}`",
|
||||
conc.to_string(),
|
||||
decl.to_string()
|
||||
))),
|
||||
Err(InlineError::NonConstant(key, generics, arguments, mut output_types)) => {
|
||||
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
|
||||
self.complete = false;
|
||||
|
||||
Ok(E::function_call(
|
||||
key,
|
||||
generics,
|
||||
arguments,
|
||||
output_types.pop().unwrap(),
|
||||
))
|
||||
Ok(FunctionCallOrExpression::Expression(E::function_call(
|
||||
key, generics, arguments,
|
||||
)))
|
||||
}
|
||||
Err(InlineError::Flat(embed, generics, arguments, mut output_types)) => {
|
||||
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
|
||||
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
|
||||
*self
|
||||
.versions
|
||||
|
@ -260,23 +261,25 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
);
|
||||
let var = Variable::with_id_and_type(identifier, output_types.pop().unwrap());
|
||||
let var = Variable::with_id_and_type(
|
||||
identifier.clone(),
|
||||
output_types.clone().inner.pop().unwrap(),
|
||||
);
|
||||
|
||||
let v = vec![var.clone().into()];
|
||||
|
||||
self.statement_buffer
|
||||
.push(TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::EmbedCall(embed, generics, arguments, output_types),
|
||||
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
|
||||
.annotate(output_types),
|
||||
));
|
||||
Ok(TypedExpression::from(var).try_into().unwrap())
|
||||
Ok(FunctionCallOrExpression::Expression(E::identifier(
|
||||
identifier,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_block_expression<E: ResultFold<'ast, T>>(
|
||||
&mut self,
|
||||
|
@ -308,23 +311,28 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
let res = match s {
|
||||
TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(key, generics, arguments, output_types),
|
||||
TypedExpressionList {
|
||||
inner: TypedExpressionListInner::FunctionCall(function_call),
|
||||
types,
|
||||
},
|
||||
) => {
|
||||
let generics = generics
|
||||
let generics = function_call
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let arguments = arguments
|
||||
let arguments = function_call
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
match inline_call(
|
||||
key,
|
||||
match inline_call::<_, TypedExpressionList<'ast, T>>(
|
||||
function_call.function_key,
|
||||
generics,
|
||||
arguments,
|
||||
output_types,
|
||||
&types,
|
||||
&self.program,
|
||||
&mut self.versions,
|
||||
) {
|
||||
|
@ -367,23 +375,15 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
|
||||
Ok(vec![TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(
|
||||
key,
|
||||
generics,
|
||||
arguments,
|
||||
output_types,
|
||||
),
|
||||
TypedExpressionList::function_call(key, generics, arguments)
|
||||
.annotate(output_types),
|
||||
)])
|
||||
}
|
||||
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
|
||||
Ok(vec![TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::EmbedCall(
|
||||
embed,
|
||||
generics,
|
||||
arguments,
|
||||
output_types,
|
||||
),
|
||||
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
|
||||
.annotate(output_types),
|
||||
)])
|
||||
}
|
||||
}
|
||||
|
@ -460,62 +460,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
BooleanExpression::FunctionCall(key, generics, arguments) => {
|
||||
self.fold_function_call(key, generics, arguments, Type::Boolean)
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
UExpressionInner::FunctionCall(key, generics, arguments) => self
|
||||
.fold_function_call::<UExpression<'ast, T>>(
|
||||
key,
|
||||
generics,
|
||||
arguments,
|
||||
Type::Uint(bitwidth),
|
||||
)
|
||||
.map(|e| e.into_inner()),
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
FieldElementExpression::FunctionCall(key, generic, arguments) => {
|
||||
self.fold_function_call(key, generic, arguments, Type::FieldElement)
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
array_ty: &ArrayType<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
ArrayExpressionInner::FunctionCall(key, generics, arguments) => self
|
||||
.fold_function_call::<ArrayExpression<_>>(
|
||||
key.clone(),
|
||||
generics,
|
||||
arguments.clone(),
|
||||
Type::array(array_ty.clone()),
|
||||
)
|
||||
.map(|e| e.into_inner()),
|
||||
ArrayExpressionInner::Slice(box array, box from, box to) => {
|
||||
let array = self.fold_array_expression(array)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
|
@ -531,25 +481,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
_ => fold_array_expression_inner(self, &array_ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
struct_ty: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
StructExpressionInner::FunctionCall(key, generics, arguments) => self
|
||||
.fold_function_call::<StructExpression<'ast, T>>(
|
||||
key,
|
||||
generics,
|
||||
arguments,
|
||||
Type::Struct(struct_ty.clone()),
|
||||
)
|
||||
.map(|e| e.into_inner()),
|
||||
_ => fold_struct_expression_inner(self, struct_ty, e),
|
||||
_ => fold_array_expression_inner(self, array_ty, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -685,10 +617,10 @@ mod tests {
|
|||
use crate::typed_absy::types::Constant;
|
||||
use crate::typed_absy::types::DeclarationSignature;
|
||||
use crate::typed_absy::{
|
||||
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
|
||||
FieldElementExpression, GenericIdentifier, Identifier, OwnedTypedModuleId, Select, Type,
|
||||
TypedExpression, TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner,
|
||||
Variable,
|
||||
ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType,
|
||||
DeclarationVariable, FieldElementExpression, GenericIdentifier, Identifier,
|
||||
OwnedTypedModuleId, Select, Type, TypedExpression, TypedExpressionList,
|
||||
TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner, Variable,
|
||||
};
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
|
@ -749,7 +681,7 @@ mod tests {
|
|||
),
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a").into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
|
@ -757,8 +689,8 @@ mod tests {
|
|||
),
|
||||
vec![],
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()],
|
||||
vec![Type::FieldElement],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::FieldElement])),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
|
@ -947,15 +879,15 @@ mod tests {
|
|||
),
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![None],
|
||||
vec![ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into()],
|
||||
vec![Type::array((Type::FieldElement, 1u32))],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
|
@ -1175,15 +1107,15 @@ mod tests {
|
|||
),
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![None],
|
||||
vec![ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into()],
|
||||
vec![Type::array((Type::FieldElement, 1u32))],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
|
@ -1379,7 +1311,7 @@ mod tests {
|
|||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Slice(
|
||||
box ArrayExpressionInner::FunctionCall(
|
||||
box ArrayExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "bar")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![None],
|
||||
|
@ -1450,7 +1382,7 @@ mod tests {
|
|||
statements: vec![
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![None],
|
||||
|
@ -1459,8 +1391,8 @@ mod tests {
|
|||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into()],
|
||||
vec![Type::array((Type::FieldElement, 1u32))],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
|
||||
),
|
||||
TypedStatement::Return(vec![]),
|
||||
],
|
||||
|
@ -1656,15 +1588,15 @@ mod tests {
|
|||
statements: vec![
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![None],
|
||||
vec![ArrayExpressionInner::Value(vec![].into())
|
||||
.annotate(Type::FieldElement, 0u32)
|
||||
.into()],
|
||||
vec![Type::array((Type::FieldElement, 1u32))],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
|
||||
),
|
||||
TypedStatement::Return(vec![]),
|
||||
],
|
||||
|
|
|
@ -174,88 +174,18 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
|||
res
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> FieldElementExpression<'ast, T> {
|
||||
if let FieldElementExpression::FunctionCall(ref k, _, _) = e {
|
||||
if !k.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
ty: &E::Ty,
|
||||
c: FunctionCallExpression<'ast, T, E>,
|
||||
) -> FunctionCallOrExpression<'ast, T, E> {
|
||||
if !c.function_key.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
|
||||
fold_field_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
if let BooleanExpression::FunctionCall(ref k, _, _) = e {
|
||||
if !k.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
};
|
||||
|
||||
fold_boolean_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
b: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
if let UExpressionInner::FunctionCall(ref k, _, _) = e {
|
||||
if !k.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
};
|
||||
|
||||
fold_uint_expression_inner(self, b, e)
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &ArrayType<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> ArrayExpressionInner<'ast, T> {
|
||||
if let ArrayExpressionInner::FunctionCall(ref k, _, _) = e {
|
||||
if !k.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
};
|
||||
|
||||
fold_array_expression_inner(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> StructExpressionInner<'ast, T> {
|
||||
if let StructExpressionInner::FunctionCall(ref k, _, _) = e {
|
||||
if !k.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
};
|
||||
|
||||
fold_struct_expression_inner(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_expression_list(
|
||||
&mut self,
|
||||
e: TypedExpressionList<'ast, T>,
|
||||
) -> TypedExpressionList<'ast, T> {
|
||||
match e {
|
||||
TypedExpressionList::FunctionCall(ref k, _, _, _) => {
|
||||
if !k.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
fold_expression_list(self, e)
|
||||
fold_function_call_expression(self, ty, c)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -440,7 +370,7 @@ mod tests {
|
|||
|
||||
let s: TypedStatement<Bn128Field> = TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a").into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
|
@ -448,14 +378,14 @@ mod tests {
|
|||
),
|
||||
vec![],
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()],
|
||||
vec![Type::FieldElement],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::FieldElement])),
|
||||
);
|
||||
assert_eq!(
|
||||
u.fold_statement(s),
|
||||
vec![TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element(Identifier::from("a").version(1)).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
|
@ -465,9 +395,9 @@ mod tests {
|
|||
vec![
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(0))
|
||||
.into()
|
||||
],
|
||||
vec![Type::FieldElement],
|
||||
]
|
||||
)
|
||||
.annotate(Types::new(vec![Type::FieldElement]))
|
||||
)]
|
||||
);
|
||||
}
|
||||
|
@ -887,14 +817,14 @@ mod tests {
|
|||
),
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a").into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![Some(
|
||||
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
|
||||
)],
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()],
|
||||
vec![Type::FieldElement],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::FieldElement])),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
|
@ -905,7 +835,7 @@ mod tests {
|
|||
TypedStatement::Definition(
|
||||
Variable::field_element("a").into(),
|
||||
(FieldElementExpression::Identifier("a".into())
|
||||
* FieldElementExpression::FunctionCall(
|
||||
* FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![Some(
|
||||
UExpressionInner::Identifier("n".into())
|
||||
|
@ -962,7 +892,7 @@ mod tests {
|
|||
),
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element(Identifier::from("a").version(2)).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
TypedExpressionList::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![Some(
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(1))
|
||||
|
@ -972,8 +902,8 @@ mod tests {
|
|||
Identifier::from("a").version(1),
|
||||
)
|
||||
.into()],
|
||||
vec![Type::FieldElement],
|
||||
),
|
||||
)
|
||||
.annotate(Types::new(vec![Type::FieldElement])),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(),
|
||||
|
@ -984,7 +914,7 @@ mod tests {
|
|||
TypedStatement::Definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
(FieldElementExpression::Identifier(Identifier::from("a").version(2))
|
||||
* FieldElementExpression::FunctionCall(
|
||||
* FieldElementExpression::function_call(
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![Some(
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(2))
|
||||
|
|
|
@ -118,6 +118,10 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_types(&mut self, tys: Types<'ast, T>) -> Types<'ast, T> {
|
||||
fold_types(self, tys)
|
||||
}
|
||||
|
||||
fn fold_array_type(&mut self, t: ArrayType<'ast, T>) -> ArrayType<'ast, T> {
|
||||
ArrayType {
|
||||
ty: box self.fold_type(*t.ty),
|
||||
|
@ -199,6 +203,16 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_member_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> FunctionCallOrExpression<'ast, T, E> {
|
||||
fold_function_call_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_select_expression<
|
||||
E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From<TypedExpression<'ast, T>>,
|
||||
>(
|
||||
|
@ -227,6 +241,14 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_expression_list(self, es)
|
||||
}
|
||||
|
||||
fn fold_expression_list_inner(
|
||||
&mut self,
|
||||
tys: Types<'ast, T>,
|
||||
es: TypedExpressionListInner<'ast, T>,
|
||||
) -> TypedExpressionListInner<'ast, T> {
|
||||
fold_expression_list_inner(self, tys, es)
|
||||
}
|
||||
|
||||
fn fold_int_expression(&mut self, e: IntExpression<'ast, T>) -> IntExpression<'ast, T> {
|
||||
fold_int_expression(self, e)
|
||||
}
|
||||
|
@ -345,13 +367,13 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
.map(|e| f.fold_expression_or_spread(e))
|
||||
.collect(),
|
||||
),
|
||||
ArrayExpressionInner::FunctionCall(id, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect();
|
||||
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
|
||||
ArrayExpressionInner::FunctionCall(id, generics, exps)
|
||||
ArrayExpressionInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(ty, function_call) {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
ArrayExpressionInner::FunctionCall(function_call)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
ArrayExpressionInner::IfElse(
|
||||
|
@ -395,13 +417,13 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
StructExpressionInner::Value(exprs) => {
|
||||
StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
|
||||
}
|
||||
StructExpressionInner::FunctionCall(id, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect();
|
||||
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
|
||||
StructExpressionInner::FunctionCall(id, generics, exps)
|
||||
StructExpressionInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(ty, function_call) {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
StructExpressionInner::FunctionCall(function_call)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
StructExpressionInner::IfElse(
|
||||
|
@ -474,13 +496,13 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let alt = f.fold_field_expression(alt);
|
||||
FieldElementExpression::IfElse(box cond, box cons, box alt)
|
||||
}
|
||||
FieldElementExpression::FunctionCall(key, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect();
|
||||
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
|
||||
FieldElementExpression::FunctionCall(key, generics, exps)
|
||||
FieldElementExpression::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&Type::FieldElement, function_call) {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
FieldElementExpression::FunctionCall(function_call)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Select(select) => {
|
||||
match f.fold_select_expression(&Type::FieldElement, select) {
|
||||
|
@ -622,13 +644,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e = f.fold_boolean_expression(e);
|
||||
BooleanExpression::Not(box e)
|
||||
}
|
||||
BooleanExpression::FunctionCall(key, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect();
|
||||
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
|
||||
BooleanExpression::FunctionCall(key, generics, exps)
|
||||
BooleanExpression::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&Type::Boolean, function_call) {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
BooleanExpression::FunctionCall(function_call)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
BooleanExpression::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond);
|
||||
|
@ -748,13 +770,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
UExpressionInner::Pos(box e)
|
||||
}
|
||||
UExpressionInner::FunctionCall(key, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect();
|
||||
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
|
||||
UExpressionInner::FunctionCall(key, generics, exps)
|
||||
UExpressionInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&ty, function_call) {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
UExpressionInner::FunctionCall(function_call)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
UExpressionInner::Select(select) => match f.fold_select_expression(&ty, select) {
|
||||
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
||||
|
@ -797,6 +819,29 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_function_call_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
F: Folder<'ast, T>,
|
||||
>(
|
||||
f: &mut F,
|
||||
_: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> FunctionCallOrExpression<'ast, T, E> {
|
||||
FunctionCallOrExpression::FunctionCall(FunctionCallExpression::new(
|
||||
e.function_key,
|
||||
e.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect(),
|
||||
e.arguments
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
fun: TypedFunction<'ast, T>,
|
||||
|
@ -851,30 +896,45 @@ pub fn fold_expression_list<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
f: &mut F,
|
||||
es: TypedExpressionList<'ast, T>,
|
||||
) -> TypedExpressionList<'ast, T> {
|
||||
let types = f.fold_types(es.types);
|
||||
|
||||
TypedExpressionList {
|
||||
inner: f.fold_expression_list_inner(types.clone(), es.inner),
|
||||
types,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_types<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
tys: Types<'ast, T>,
|
||||
) -> Types<'ast, T> {
|
||||
Types {
|
||||
inner: tys.inner.into_iter().map(|t| f.fold_type(t)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_expression_list_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
tys: Types<'ast, T>,
|
||||
es: TypedExpressionListInner<'ast, T>,
|
||||
) -> TypedExpressionListInner<'ast, T> {
|
||||
match es {
|
||||
TypedExpressionList::FunctionCall(id, generics, arguments, types) => {
|
||||
TypedExpressionList::FunctionCall(
|
||||
id,
|
||||
generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)))
|
||||
.collect(),
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|a| f.fold_expression(a))
|
||||
.collect(),
|
||||
types.into_iter().map(|t| f.fold_type(t)).collect(),
|
||||
)
|
||||
TypedExpressionListInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&tys, function_call) {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
TypedExpressionListInner::FunctionCall(function_call)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
|
||||
TypedExpressionList::EmbedCall(
|
||||
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
|
||||
TypedExpressionListInner::EmbedCall(
|
||||
embed,
|
||||
generics,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|a| f.fold_expression(a))
|
||||
.collect(),
|
||||
types.into_iter().map(|t| f.fold_type(t)).collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ pub use self::parameter::{DeclarationParameter, GParameter};
|
|||
pub use self::types::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey,
|
||||
DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
|
||||
Signature, StructType, Type, UBitwidth,
|
||||
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
|
||||
};
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
|
||||
|
@ -638,30 +638,8 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
|
|||
.collect::<Vec<String>>()
|
||||
.join(", ")
|
||||
),
|
||||
StructExpressionInner::FunctionCall(ref key, ref generics, ref p) => {
|
||||
write!(f, "{}", key.id,)?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, "(")?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
StructExpressionInner::FunctionCall(ref function_call) => {
|
||||
write!(f, "{}", function_call)
|
||||
}
|
||||
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
|
@ -724,27 +702,33 @@ pub trait MultiTyped<'ast, T> {
|
|||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub enum TypedExpressionList<'ast, T> {
|
||||
FunctionCall(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
),
|
||||
EmbedCall(
|
||||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
),
|
||||
|
||||
pub struct TypedExpressionList<'ast, T> {
|
||||
pub inner: TypedExpressionListInner<'ast, T>,
|
||||
pub types: Types<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub enum TypedExpressionListInner<'ast, T> {
|
||||
FunctionCall(FunctionCallExpression<'ast, T, TypedExpressionList<'ast, T>>),
|
||||
EmbedCall(FlatEmbed, Vec<u32>, Vec<TypedExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
impl<'ast, T> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
fn get_types(&self) -> &Vec<Type<'ast, T>> {
|
||||
match *self {
|
||||
TypedExpressionList::FunctionCall(_, _, _, ref types) => types,
|
||||
TypedExpressionList::EmbedCall(_, _, _, ref types) => types,
|
||||
}
|
||||
&self.types.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedExpressionListInner<'ast, T> {
|
||||
pub fn annotate(self, types: Types<'ast, T>) -> TypedExpressionList<'ast, T> {
|
||||
TypedExpressionList { inner: self, types }
|
||||
}
|
||||
}
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
|
@ -808,6 +792,58 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub struct FunctionCallExpression<'ast, T, E> {
|
||||
pub function_key: DeclarationFunctionKey<'ast>,
|
||||
pub generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
pub arguments: Vec<TypedExpression<'ast, T>>,
|
||||
ty: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<'ast, T, E> FunctionCallExpression<'ast, T, E> {
|
||||
pub fn new(
|
||||
function_key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
FunctionCallExpression {
|
||||
function_key,
|
||||
generics,
|
||||
arguments,
|
||||
ty: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.function_key.id,)?;
|
||||
if !self.generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
self.generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(
|
||||
f,
|
||||
"({})",
|
||||
self.arguments
|
||||
.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// An expression of type `field`
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
pub enum FieldElementExpression<'ast, T> {
|
||||
|
@ -841,11 +877,7 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
),
|
||||
Neg(Box<FieldElementExpression<'ast, T>>),
|
||||
Pos(Box<FieldElementExpression<'ast, T>>),
|
||||
FunctionCall(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionCallExpression<'ast, T, Self>),
|
||||
Member(MemberExpression<'ast, T, Self>),
|
||||
Select(SelectExpression<'ast, T, Self>),
|
||||
}
|
||||
|
@ -948,11 +980,7 @@ pub enum BooleanExpression<'ast, T> {
|
|||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
Member(MemberExpression<'ast, T, Self>),
|
||||
FunctionCall(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionCallExpression<'ast, T, Self>),
|
||||
Select(SelectExpression<'ast, T, Self>),
|
||||
}
|
||||
|
||||
|
@ -1056,11 +1084,7 @@ pub enum ArrayExpressionInner<'ast, T> {
|
|||
Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(ArrayValue<'ast, T>),
|
||||
FunctionCall(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||
IfElse(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
|
@ -1165,11 +1189,7 @@ pub enum StructExpressionInner<'ast, T> {
|
|||
Block(BlockExpression<'ast, T, StructExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>),
|
||||
IfElse(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<StructExpression<'ast, T>>,
|
||||
|
@ -1243,6 +1263,15 @@ impl<'ast, T> From<TypedExpression<'ast, T>> for StructExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
// `TypedExpressionList` can technically not be constructed from `TypedExpression`
|
||||
// However implementing `From<TypedExpression>` is required for `TypedExpressionList` to be `Expr`, which makes generic treatment of function calls possible
|
||||
// This could maybe be avoided by splitting the `Expr` trait into many, but I did not find a way
|
||||
impl<'ast, T> From<TypedExpression<'ast, T>> for TypedExpressionList<'ast, T> {
|
||||
fn from(_: TypedExpression<'ast, T>) -> TypedExpressionList<'ast, T> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<TypedConstant<'ast, T>> for FieldElementExpression<'ast, T> {
|
||||
fn from(tc: TypedConstant<'ast, T>) -> FieldElementExpression<'ast, T> {
|
||||
tc.expression.into()
|
||||
|
@ -1314,30 +1343,8 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
condition, consequent, alternative
|
||||
)
|
||||
}
|
||||
FieldElementExpression::FunctionCall(ref k, ref generics, ref p) => {
|
||||
write!(f, "{}", k.id,)?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, "(")?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
FieldElementExpression::FunctionCall(ref function_call) => {
|
||||
write!(f, "{}", function_call)
|
||||
}
|
||||
FieldElementExpression::Member(ref m) => write!(f, "{}", m),
|
||||
FieldElementExpression::Select(ref select) => write!(f, "{}", select),
|
||||
|
@ -1368,31 +1375,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
UExpressionInner::Neg(ref e) => write!(f, "(-{})", e),
|
||||
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
|
||||
UExpressionInner::Select(ref select) => write!(f, "{}", select),
|
||||
UExpressionInner::FunctionCall(ref k, ref generics, ref p) => {
|
||||
write!(f, "{}", k.id,)?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, "(")?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||
UExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
||||
f,
|
||||
"if {} then {} else {} fi",
|
||||
|
@ -1425,31 +1408,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||
BooleanExpression::Value(b) => write!(f, "{}", b),
|
||||
BooleanExpression::FunctionCall(ref k, ref generics, ref p) => {
|
||||
write!(f, "{}", k.id,)?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, "(")?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||
BooleanExpression::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
||||
f,
|
||||
"if {} then {} else {} fi",
|
||||
|
@ -1475,31 +1434,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
|
|||
.collect::<Vec<String>>()
|
||||
.join(", ")
|
||||
),
|
||||
ArrayExpressionInner::FunctionCall(ref key, ref generics, ref p) => {
|
||||
write!(f, "{}", key.id,)?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, "(")?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
||||
f,
|
||||
"if {} then {} else {} fi",
|
||||
|
@ -1517,35 +1452,13 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionListInner<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
TypedExpressionList::FunctionCall(ref k, ref generics, ref p, _) => {
|
||||
write!(f, "{}", k.id,)?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
generics
|
||||
.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| '_'.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
write!(f, "(")?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
TypedExpressionListInner::FunctionCall(ref function_call) => {
|
||||
write!(f, "{}", function_call)
|
||||
}
|
||||
TypedExpressionList::EmbedCall(ref embed, ref generics, ref p, _) => {
|
||||
TypedExpressionListInner::EmbedCall(ref embed, ref generics, ref p) => {
|
||||
write!(f, "{}", embed.id())?;
|
||||
if !generics.is_empty() {
|
||||
write!(
|
||||
|
@ -1590,16 +1503,16 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
|
|||
|
||||
// Common behaviour across expressions
|
||||
|
||||
pub trait Expr<'ast, T> {
|
||||
pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
|
||||
type Inner;
|
||||
type Ty;
|
||||
type Ty: Clone + IntoTypes<'ast, T>;
|
||||
|
||||
fn into_inner(self) -> Self::Inner;
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner;
|
||||
}
|
||||
|
||||
impl<'ast, T> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
|
@ -1612,7 +1525,7 @@ impl<'ast, T> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
|
@ -1625,7 +1538,7 @@ impl<'ast, T> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Expr<'ast, T> for UExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
|
||||
type Inner = UExpressionInner<'ast, T>;
|
||||
type Ty = UBitwidth;
|
||||
|
||||
|
@ -1638,7 +1551,7 @@ impl<'ast, T> Expr<'ast, T> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Expr<'ast, T> for StructExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
|
||||
type Inner = StructExpressionInner<'ast, T>;
|
||||
type Ty = StructType<'ast, T>;
|
||||
|
||||
|
@ -1651,7 +1564,7 @@ impl<'ast, T> Expr<'ast, T> for StructExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
||||
type Inner = ArrayExpressionInner<'ast, T>;
|
||||
type Ty = ArrayType<'ast, T>;
|
||||
|
||||
|
@ -1664,7 +1577,7 @@ impl<'ast, T> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Expr<'ast, T> for IntExpression<'ast, T> {
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
|
@ -1677,8 +1590,25 @@ impl<'ast, T> Expr<'ast, T> for IntExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
type Inner = TypedExpressionListInner<'ast, T>;
|
||||
type Ty = Types<'ast, T>;
|
||||
|
||||
fn into_inner(self) -> Self::Inner {
|
||||
self.inner
|
||||
}
|
||||
|
||||
fn as_inner(&self) -> &Self::Inner {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
// Enums types to enable returning e.g a member expression OR another type of expression of this type
|
||||
|
||||
pub enum FunctionCallOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||
FunctionCall(FunctionCallExpression<'ast, T, E>),
|
||||
Expression(E::Inner),
|
||||
}
|
||||
pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||
Select(SelectExpression<'ast, T, E>),
|
||||
Expression(E::Inner),
|
||||
|
@ -1886,13 +1816,55 @@ impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait FunctionCall<'ast, T> {
|
||||
pub trait Id<'ast, T>: Expr<'ast, T> {
|
||||
fn identifier(id: Identifier<'ast>) -> Self::Inner;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Id<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn identifier(id: Identifier<'ast>) -> Self::Inner {
|
||||
FieldElementExpression::Identifier(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Id<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn identifier(id: Identifier<'ast>) -> Self::Inner {
|
||||
BooleanExpression::Identifier(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Id<'ast, T> for UExpression<'ast, T> {
|
||||
fn identifier(id: Identifier<'ast>) -> Self::Inner {
|
||||
UExpressionInner::Identifier(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Id<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn identifier(id: Identifier<'ast>) -> Self::Inner {
|
||||
ArrayExpressionInner::Identifier(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Id<'ast, T> for StructExpression<'ast, T> {
|
||||
fn identifier(id: Identifier<'ast>) -> Self::Inner {
|
||||
StructExpressionInner::Identifier(id)
|
||||
}
|
||||
}
|
||||
|
||||
// `TypedExpressionList` does not have an Identifier variant
|
||||
// However implementing `From<TypedExpression>` is required for `TypedExpressionList` to be `Expr`, which makes generic treatment of function calls possible
|
||||
// This could maybe be avoided by splitting the `Expr` trait into many, but I did not find a way
|
||||
impl<'ast, T: Field> Id<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
fn identifier(_: Identifier<'ast>) -> Self::Inner {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FunctionCall<'ast, T>: Expr<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self;
|
||||
) -> Self::Inner;
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
|
@ -1900,10 +1872,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
assert_eq!(output_type, Type::FieldElement);
|
||||
FieldElementExpression::FunctionCall(key, generics, arguments)
|
||||
) -> Self::Inner {
|
||||
FieldElementExpression::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1912,10 +1882,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
assert_eq!(output_type, Type::Boolean);
|
||||
BooleanExpression::FunctionCall(key, generics, arguments)
|
||||
) -> Self::Inner {
|
||||
BooleanExpression::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1924,13 +1892,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> {
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
let bitwidth = match output_type {
|
||||
Type::Uint(bitwidth) => bitwidth,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
UExpressionInner::FunctionCall(key, generics, arguments).annotate(bitwidth)
|
||||
) -> Self::Inner {
|
||||
UExpressionInner::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1939,14 +1902,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
let array_ty = match output_type {
|
||||
Type::Array(array_ty) => array_ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ArrayExpressionInner::FunctionCall(key, generics, arguments)
|
||||
.annotate(*array_ty.ty, array_ty.size)
|
||||
) -> Self::Inner {
|
||||
ArrayExpressionInner::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1955,14 +1912,20 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
|
|||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_type: Type<'ast, T>,
|
||||
) -> Self {
|
||||
let struct_ty = match output_type {
|
||||
Type::Struct(struct_ty) => struct_ty,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
) -> Self::Inner {
|
||||
StructExpressionInner::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
|
||||
}
|
||||
}
|
||||
|
||||
StructExpressionInner::FunctionCall(key, generics, arguments).annotate(struct_ty)
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
TypedExpressionListInner::FunctionCall(FunctionCallExpression::new(
|
||||
key, generics, arguments,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -138,6 +138,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_types(&mut self, tys: Types<'ast, T>) -> Result<Types<'ast, T>, Self::Error> {
|
||||
fold_types(self, tys)
|
||||
}
|
||||
|
||||
fn fold_block_expression<E: ResultFold<'ast, T>>(
|
||||
&mut self,
|
||||
block: BlockExpression<'ast, T, E>,
|
||||
|
@ -165,6 +169,16 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_select_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
|
||||
fold_function_call_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_array_type(
|
||||
&mut self,
|
||||
t: ArrayType<'ast, T>,
|
||||
|
@ -274,6 +288,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_struct_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_expression_list_inner(
|
||||
&mut self,
|
||||
tys: &Types<'ast, T>,
|
||||
es: TypedExpressionListInner<'ast, T>,
|
||||
) -> Result<TypedExpressionListInner<'ast, T>, Self::Error> {
|
||||
fold_expression_list_inner(self, tys, es)
|
||||
}
|
||||
|
||||
fn fold_expression_list(
|
||||
&mut self,
|
||||
es: TypedExpressionList<'ast, T>,
|
||||
|
@ -387,16 +409,11 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.map(|e| f.fold_expression_or_spread(e))
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
ArrayExpressionInner::FunctionCall(id, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
ArrayExpressionInner::FunctionCall(id, generics, exps)
|
||||
ArrayExpressionInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(ty, function_call)? {
|
||||
FunctionCallOrExpression::FunctionCall(c) => ArrayExpressionInner::FunctionCall(c),
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
ArrayExpressionInner::IfElse(
|
||||
|
@ -446,16 +463,11 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
StructExpressionInner::FunctionCall(id, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
StructExpressionInner::FunctionCall(id, generics, exps)
|
||||
StructExpressionInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(ty, function_call)? {
|
||||
FunctionCallOrExpression::FunctionCall(c) => StructExpressionInner::FunctionCall(c),
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
StructExpressionInner::IfElse(
|
||||
|
@ -529,16 +541,13 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let alt = f.fold_field_expression(alt)?;
|
||||
FieldElementExpression::IfElse(box cond, box cons, box alt)
|
||||
}
|
||||
FieldElementExpression::FunctionCall(key, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
FieldElementExpression::FunctionCall(key, generics, exps)
|
||||
FieldElementExpression::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&Type::FieldElement, function_call)? {
|
||||
FunctionCallOrExpression::FunctionCall(c) => {
|
||||
FieldElementExpression::FunctionCall(c)
|
||||
}
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Member(m) => {
|
||||
match f.fold_member_expression(&Type::FieldElement, m)? {
|
||||
|
@ -612,6 +621,29 @@ pub fn fold_select_expression<
|
|||
)))
|
||||
}
|
||||
|
||||
pub fn fold_function_call_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
F: ResultFolder<'ast, T>,
|
||||
>(
|
||||
f: &mut F,
|
||||
_: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> Result<FunctionCallOrExpression<'ast, T, E>, F::Error> {
|
||||
Ok(FunctionCallOrExpression::Expression(E::function_call(
|
||||
e.function_key,
|
||||
e.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?,
|
||||
e.arguments
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?,
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
|
@ -701,16 +733,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let e = f.fold_boolean_expression(e)?;
|
||||
BooleanExpression::Not(box e)
|
||||
}
|
||||
BooleanExpression::FunctionCall(key, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
BooleanExpression::FunctionCall(key, generics, exps)
|
||||
BooleanExpression::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&Type::Boolean, function_call)? {
|
||||
FunctionCallOrExpression::FunctionCall(c) => BooleanExpression::FunctionCall(c),
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
BooleanExpression::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
|
@ -832,16 +859,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
UExpressionInner::Pos(box e)
|
||||
}
|
||||
UExpressionInner::FunctionCall(key, generics, exps) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
UExpressionInner::FunctionCall(key, generics, exps)
|
||||
UExpressionInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(&ty, function_call)? {
|
||||
FunctionCallOrExpression::FunctionCall(c) => UExpressionInner::FunctionCall(c),
|
||||
FunctionCallOrExpression::Expression(u) => u,
|
||||
}
|
||||
}
|
||||
UExpressionInner::Select(select) => match f.fold_select_expression(&ty, select)? {
|
||||
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
||||
|
@ -928,37 +950,49 @@ pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
f: &mut F,
|
||||
es: TypedExpressionList<'ast, T>,
|
||||
) -> Result<TypedExpressionList<'ast, T>, F::Error> {
|
||||
let types = f.fold_types(es.types)?;
|
||||
|
||||
Ok(TypedExpressionList {
|
||||
inner: f.fold_expression_list_inner(&types, es.inner)?,
|
||||
types,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_types<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
tys: Types<'ast, T>,
|
||||
) -> Result<Types<'ast, T>, F::Error> {
|
||||
Ok(Types {
|
||||
inner: tys
|
||||
.inner
|
||||
.into_iter()
|
||||
.map(|t| f.fold_type(t))
|
||||
.collect::<Result<_, _>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_expression_list_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
tys: &Types<'ast, T>,
|
||||
es: TypedExpressionListInner<'ast, T>,
|
||||
) -> Result<TypedExpressionListInner<'ast, T>, F::Error> {
|
||||
match es {
|
||||
TypedExpressionList::FunctionCall(id, generics, arguments, types) => {
|
||||
let generics = generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(TypedExpressionList::FunctionCall(
|
||||
id,
|
||||
generics,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|a| f.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
types
|
||||
.into_iter()
|
||||
.map(|t| f.fold_type(t))
|
||||
.collect::<Result<_, _>>()?,
|
||||
))
|
||||
TypedExpressionListInner::FunctionCall(function_call) => {
|
||||
match f.fold_function_call_expression(tys, function_call)? {
|
||||
FunctionCallOrExpression::FunctionCall(function_call) => {
|
||||
Ok(TypedExpressionListInner::FunctionCall(function_call))
|
||||
}
|
||||
FunctionCallOrExpression::Expression(list) => Ok(list),
|
||||
}
|
||||
}
|
||||
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
|
||||
Ok(TypedExpressionList::EmbedCall(
|
||||
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
|
||||
Ok(TypedExpressionListInner::EmbedCall(
|
||||
embed,
|
||||
generics,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|a| f.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
types
|
||||
.into_iter()
|
||||
.map(|t| f.fold_type(t))
|
||||
.collect::<Result<_, _>>()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,57 @@ use std::fmt;
|
|||
use std::hash::{Hash, Hasher};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
pub trait IntoTypes<'ast, T> {
|
||||
fn into_types(self) -> Types<'ast, T>;
|
||||
}
|
||||
|
||||
impl<'ast, T> IntoTypes<'ast, T> for Type<'ast, T> {
|
||||
fn into_types(self) -> Types<'ast, T> {
|
||||
Types { inner: vec![self] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> IntoTypes<'ast, T> for StructType<'ast, T> {
|
||||
fn into_types(self) -> Types<'ast, T> {
|
||||
Types {
|
||||
inner: vec![Type::Struct(self)],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> IntoTypes<'ast, T> for ArrayType<'ast, T> {
|
||||
fn into_types(self) -> Types<'ast, T> {
|
||||
Types {
|
||||
inner: vec![Type::Array(self)],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> IntoTypes<'ast, T> for UBitwidth {
|
||||
fn into_types(self) -> Types<'ast, T> {
|
||||
Types {
|
||||
inner: vec![Type::Uint(self)],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> IntoTypes<'ast, T> for Types<'ast, T> {
|
||||
fn into_types(self) -> Types<'ast, T> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
pub struct Types<'ast, T> {
|
||||
pub inner: Vec<Type<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, T> Types<'ast, T> {
|
||||
pub fn new(types: Vec<Type<'ast, T>>) -> Self {
|
||||
Self { inner: types }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, Ord)]
|
||||
pub struct GenericIdentifier<'ast> {
|
||||
pub name: &'ast str,
|
||||
|
|
|
@ -190,11 +190,7 @@ pub enum UExpressionInner<'ast, T> {
|
|||
Not(Box<UExpression<'ast, T>>),
|
||||
Neg(Box<UExpression<'ast, T>>),
|
||||
Pos(Box<UExpression<'ast, T>>),
|
||||
FunctionCall(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>),
|
||||
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
IfElse(
|
||||
|
|
Loading…
Reference in a new issue