1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #894 from Zokrates/function-call-expression

implement function call expression
This commit is contained in:
Thibaut Schaeffer 2021-06-04 15:20:45 +02:00 committed by GitHub
commit 498f31e003
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 635 additions and 681 deletions

View file

@ -1548,7 +1548,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type())
}])?;
let call = TypedExpressionList::FunctionCall(f.clone(), generics_checked.unwrap_or_else(|| vec![None; f.signature.generics.len()]), arguments_checked, assignees.iter().map(|a| a.get_type()).collect());
let call = TypedExpressionList::function_call(f.clone(), generics_checked.unwrap_or_else(|| vec![None; f.signature.generics.len()]), arguments_checked).annotate(Types { inner: assignees.iter().map(|a| a.get_type()).collect()});
Ok(TypedStatement::MultipleDefinition(assignees, call))
},
@ -2102,7 +2102,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]);
let output_types = signature.get_output_types(
let mut output_types = signature.get_output_types(
generics_checked.clone(),
arguments_checked.iter().map(|a| a.get_type()).collect()
).map_err(|e| ErrorInner {
@ -2113,63 +2113,41 @@ impl<'ast, T: Field> Checker<'ast, T> {
),
})?;
let function_key = DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
};
// the return count has to be 1
match output_types.len() {
1 => match &output_types[0] {
1 => match output_types.pop().unwrap() {
Type::Int => unreachable!(),
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
},
Type::FieldElement => Ok(FieldElementExpression::function_call(
function_key,
generics_checked,
arguments_checked,
)
.into()),
Type::Boolean => Ok(BooleanExpression::FunctionCall(
DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
},
).into()),
Type::Boolean => Ok(BooleanExpression::function_call(
function_key,
generics_checked,
arguments_checked,
)
.into()),
Type::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall(
DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
},
).into()),
Type::Uint(bitwidth) => Ok(UExpression::function_call(
function_key,
generics_checked,
arguments_checked,
)
.annotate(*bitwidth)
.into()),
Type::Struct(members) => Ok(StructExpressionInner::FunctionCall(
DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
},
).annotate(bitwidth).into()),
Type::Struct(struct_ty) => Ok(StructExpression::function_call(
function_key,
generics_checked,
arguments_checked,
)
.annotate(members.clone())
.into()),
Type::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall(
DeclarationFunctionKey {
module: module_id.to_path_buf(),
id: f.id,
signature: signature.clone(),
},
).annotate(struct_ty).into()),
Type::Array(array_ty) => Ok(ArrayExpression::function_call(
function_key,
generics_checked,
arguments_checked,
)
.annotate(*array_type.ty.clone(), array_type.size.clone())
.into()),
).annotate(*array_ty.ty, array_ty.size).into()),
},
n => Err(ErrorInner {
pos: Some(pos),
@ -4778,7 +4756,7 @@ mod tests {
typed_absy::Variable::field_element("a").into(),
typed_absy::Variable::field_element("b").into(),
],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location((*MODULE_ID).clone(), "foo").signature(
DeclarationSignature::new().outputs(vec![
DeclarationType::FieldElement,
@ -4787,8 +4765,8 @@ mod tests {
),
vec![],
vec![],
vec![Type::FieldElement, Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement, Type::FieldElement])),
),
TypedStatement::Return(vec![FieldElementExpression::Add(
box FieldElementExpression::Identifier("a".into()),

View file

@ -1,5 +1,5 @@
use crate::typed_absy;
use crate::typed_absy::types::UBitwidth;
use crate::typed_absy::{self, Expr};
use crate::zir;
use std::marker::PhantomData;
use zokrates_field::Field;
@ -208,8 +208,8 @@ impl<'ast, T: Field> Flattener<T> {
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
es: typed_absy::TypedExpressionList<'ast, T>,
) -> zir::ZirExpressionList<'ast, T> {
match es {
typed_absy::TypedExpressionList::EmbedCall(embed, generics, arguments, _) => {
match es.into_inner() {
typed_absy::TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
zir::ZirExpressionList::EmbedCall(
embed,
generics,

View file

@ -367,18 +367,23 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
.collect::<Result<_, _>>()?;
let expression_list = self.fold_expression_list(expression_list)?;
let statements = match expression_list {
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
let types = Types {
inner: expression_list
.types
.clone()
.inner
.into_iter()
.map(|t| self.fold_type(t))
.collect::<Result<_, _>>()?,
};
let statements = match expression_list.into_inner() {
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
let arguments: Vec<_> = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
let types = types
.into_iter()
.map(|t| self.fold_type(t))
.collect::<Result<_, _>>()?;
fn process_u_from_bits<'ast, T: Field>(
variables: Vec<TypedAssignee<'ast, T>>,
mut arguments: Vec<TypedExpression<'ast, T>>,
@ -602,16 +607,18 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::MultipleDefinition(
vec![assignee],
TypedExpressionList::EmbedCall(
embed, generics, arguments, types,
),
TypedExpressionListInner::EmbedCall(
embed, generics, arguments,
)
.annotate(types),
),
],
None => vec![TypedStatement::MultipleDefinition(
vec![assignee],
TypedExpressionList::EmbedCall(
embed, generics, arguments, types,
),
TypedExpressionListInner::EmbedCall(
embed, generics, arguments,
)
.annotate(types),
)],
}
}
@ -623,9 +630,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let def = TypedStatement::MultipleDefinition(
assignees.clone(),
TypedExpressionList::EmbedCall(
embed, generics, arguments, types,
),
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
.annotate(types),
);
let invalidations = assignees.iter().flat_map(|assignee| {
@ -645,27 +651,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
}
}
TypedExpressionList::FunctionCall(key, generics, arguments, types) => {
let generics = generics
TypedExpressionListInner::FunctionCall(function_call) => {
let generics = function_call
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let arguments: Vec<_> = arguments
let arguments: Vec<_> = function_call
.arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
let types = types
.into_iter()
.map(|t| self.fold_type(t))
.collect::<Result<_, _>>()?;
// invalidate the cache for the return assignees as this call mutates them
let def = TypedStatement::MultipleDefinition(
assignees.clone(),
TypedExpressionList::FunctionCall(key, generics, arguments, types),
TypedExpressionList::function_call(
function_call.function_key,
generics,
arguments,
)
.annotate(types),
);
let invalidations = assignees.iter().flat_map(|assignee| {

View file

@ -29,14 +29,14 @@ use crate::embed::FlatEmbed;
use crate::static_analysis::reducer::Output;
use crate::static_analysis::reducer::ShallowTransformer;
use crate::static_analysis::reducer::Versions;
use crate::typed_absy::types::ConcreteGenericsAssignment;
use crate::typed_absy::types::{ConcreteGenericsAssignment, IntoTypes};
use crate::typed_absy::CoreIdentifier;
use crate::typed_absy::Identifier;
use crate::typed_absy::TypedAssignee;
use crate::typed_absy::{
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature,
Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression,
UExpressionInner, Variable,
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
Signature, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, Types,
UExpression, UExpressionInner, Variable,
};
use zokrates_field::Field;
@ -46,13 +46,13 @@ pub enum InlineError<'ast, T> {
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
Types<'ast, T>,
),
NonConstant(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
Types<'ast, T>,
),
}
@ -79,11 +79,11 @@ type InlineResult<'ast, T> = Result<
InlineError<'ast, T>,
>;
pub fn inline_call<'a, 'ast, T: Field>(
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
k: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
output: &E::Ty,
program: &TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
) -> InlineResult<'ast, T> {
@ -91,6 +91,8 @@ pub fn inline_call<'a, 'ast, T: Field>(
use crate::typed_absy::Typed;
let output_types = output.clone().into_types();
// we try to get concrete values for explicit generics
let generics_values: Vec<Option<u32>> = generics
.iter()
@ -117,7 +119,7 @@ pub fn inline_call<'a, 'ast, T: Field>(
let inferred_signature = Signature::new()
.generics(generics.clone())
.inputs(arguments.iter().map(|a| a.get_type()).collect())
.outputs(output_types.clone());
.outputs(output_types.clone().inner);
// we try to get concrete values for the whole signature. if this fails we should propagate again
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {

View file

@ -22,15 +22,12 @@ use crate::typed_absy::Folder;
use std::collections::HashMap;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, ArrayType, BlockExpression, BooleanExpression,
CoreIdentifier, DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier,
StructExpression, StructExpressionInner, StructType, Type, TypedExpression,
TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram,
TypedStatement, UBitwidth, UExpression, UExpressionInner, Variable,
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, TypedExpression,
TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule,
TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable,
};
use std::convert::TryInto;
use zokrates_field::Field;
use self::shallow_ssa::ShallowTransformer;
@ -195,32 +192,35 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
complete: true,
}
}
}
fn fold_function_call<E>(
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Result<E, Error>
where
E: FunctionCall<'ast, T> + From<TypedExpression<'ast, T>> + std::fmt::Debug,
{
let generics = generics
ty: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
let generics = e
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let arguments = arguments
let arguments = e
.arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect::<Result<_, _>>()?;
let res = inline_call(
key.clone(),
let res = inline_call::<_, E>(
e.function_key.clone(),
generics,
arguments,
vec![output_type.clone()],
ty,
&self.program,
&mut self.versions,
);
@ -229,30 +229,31 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
Ok(Output::Complete((statements, mut expressions))) => {
self.complete &= true;
self.statement_buffer.extend(statements);
Ok(expressions.pop().unwrap().try_into().unwrap())
Ok(FunctionCallOrExpression::Expression(
E::from(expressions.pop().unwrap()).into_inner(),
))
}
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
self.complete = false;
self.statement_buffer.extend(statements);
self.for_loop_versions_after.extend(delta_for_loop_versions);
Ok(expressions[0].clone().try_into().unwrap())
Ok(FunctionCallOrExpression::Expression(
E::from(expressions[0].clone()).into_inner(),
))
}
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
"Call site `{}` incompatible with declaration `{}`",
conc.to_string(),
decl.to_string()
))),
Err(InlineError::NonConstant(key, generics, arguments, mut output_types)) => {
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
self.complete = false;
Ok(E::function_call(
key,
generics,
arguments,
output_types.pop().unwrap(),
))
Ok(FunctionCallOrExpression::Expression(E::function_call(
key, generics, arguments,
)))
}
Err(InlineError::Flat(embed, generics, arguments, mut output_types)) => {
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
*self
.versions
@ -260,23 +261,25 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
);
let var = Variable::with_id_and_type(identifier, output_types.pop().unwrap());
let var = Variable::with_id_and_type(
identifier.clone(),
output_types.clone().inner.pop().unwrap(),
);
let v = vec![var.clone().into()];
self.statement_buffer
.push(TypedStatement::MultipleDefinition(
v,
TypedExpressionList::EmbedCall(embed, generics, arguments, output_types),
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
.annotate(output_types),
));
Ok(TypedExpression::from(var).try_into().unwrap())
Ok(FunctionCallOrExpression::Expression(E::identifier(
identifier,
)))
}
}
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_block_expression<E: ResultFold<'ast, T>>(
&mut self,
@ -308,23 +311,28 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let res = match s {
TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(key, generics, arguments, output_types),
TypedExpressionList {
inner: TypedExpressionListInner::FunctionCall(function_call),
types,
},
) => {
let generics = generics
let generics = function_call
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let arguments = arguments
let arguments = function_call
.arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
match inline_call(
key,
match inline_call::<_, TypedExpressionList<'ast, T>>(
function_call.function_key,
generics,
arguments,
output_types,
&types,
&self.program,
&mut self.versions,
) {
@ -367,23 +375,15 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
Ok(vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(
key,
generics,
arguments,
output_types,
),
TypedExpressionList::function_call(key, generics, arguments)
.annotate(output_types),
)])
}
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
Ok(vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::EmbedCall(
embed,
generics,
arguments,
output_types,
),
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
.annotate(output_types),
)])
}
}
@ -460,62 +460,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::FunctionCall(key, generics, arguments) => {
self.fold_function_call(key, generics, arguments, Type::Boolean)
}
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
match e {
UExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call::<UExpression<'ast, T>>(
key,
generics,
arguments,
Type::Uint(bitwidth),
)
.map(|e| e.into_inner()),
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::FunctionCall(key, generic, arguments) => {
self.fold_function_call(key, generic, arguments, Type::FieldElement)
}
e => fold_field_expression(self, e),
}
}
fn fold_array_expression_inner(
&mut self,
array_ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call::<ArrayExpression<_>>(
key.clone(),
generics,
arguments.clone(),
Type::array(array_ty.clone()),
)
.map(|e| e.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = self.fold_array_expression(array)?;
let from = self.fold_uint_expression(from)?;
@ -531,25 +481,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
}
}
_ => fold_array_expression_inner(self, &array_ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
struct_ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
match e {
StructExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call::<StructExpression<'ast, T>>(
key,
generics,
arguments,
Type::Struct(struct_ty.clone()),
)
.map(|e| e.into_inner()),
_ => fold_struct_expression_inner(self, struct_ty, e),
_ => fold_array_expression_inner(self, array_ty, e),
}
}
}
@ -685,10 +617,10 @@ mod tests {
use crate::typed_absy::types::Constant;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
FieldElementExpression, GenericIdentifier, Identifier, OwnedTypedModuleId, Select, Type,
TypedExpression, TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner,
Variable,
ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType,
DeclarationVariable, FieldElementExpression, GenericIdentifier, Identifier,
OwnedTypedModuleId, Select, Type, TypedExpression, TypedExpressionList,
TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner, Variable,
};
use zokrates_field::Bn128Field;
@ -749,7 +681,7 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -757,8 +689,8 @@ mod tests {
),
vec![],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -947,15 +879,15 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -1175,15 +1107,15 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -1379,7 +1311,7 @@ mod tests {
)
.into(),
ArrayExpressionInner::Slice(
box ArrayExpressionInner::FunctionCall(
box ArrayExpression::function_call(
DeclarationFunctionKey::with_location("main", "bar")
.signature(foo_signature.clone()),
vec![None],
@ -1450,7 +1382,7 @@ mod tests {
statements: vec![
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
@ -1459,8 +1391,8 @@ mod tests {
)
.annotate(Type::FieldElement, 1u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Return(vec![]),
],
@ -1656,15 +1588,15 @@ mod tests {
statements: vec![
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Value(vec![].into())
.annotate(Type::FieldElement, 0u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Return(vec![]),
],

View file

@ -174,88 +174,18 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
res
}
fn fold_field_expression(
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
if let FieldElementExpression::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
ty: &E::Ty,
c: FunctionCallExpression<'ast, T, E>,
) -> FunctionCallOrExpression<'ast, T, E> {
if !c.function_key.id.starts_with('_') {
self.blocked = true;
}
}
fold_field_expression(self, e)
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
if let BooleanExpression::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_boolean_expression(self, e)
}
fn fold_uint_expression_inner(
&mut self,
b: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
if let UExpressionInner::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_uint_expression_inner(self, b, e)
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
if let ArrayExpressionInner::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_array_expression_inner(self, ty, e)
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
if let StructExpressionInner::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_struct_expression_inner(self, ty, e)
}
fn fold_expression_list(
&mut self,
e: TypedExpressionList<'ast, T>,
) -> TypedExpressionList<'ast, T> {
match e {
TypedExpressionList::FunctionCall(ref k, _, _, _) => {
if !k.id.starts_with('_') {
self.blocked = true;
}
}
_ => unreachable!(),
};
fold_expression_list(self, e)
fold_function_call_expression(self, ty, c)
}
}
@ -440,7 +370,7 @@ mod tests {
let s: TypedStatement<Bn128Field> = TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -448,14 +378,14 @@ mod tests {
),
vec![],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(1)).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -465,9 +395,9 @@ mod tests {
vec![
FieldElementExpression::Identifier(Identifier::from("a").version(0))
.into()
],
vec![Type::FieldElement],
]
)
.annotate(Types::new(vec![Type::FieldElement]))
)]
);
}
@ -887,14 +817,14 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
)],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -905,7 +835,7 @@ mod tests {
TypedStatement::Definition(
Variable::field_element("a").into(),
(FieldElementExpression::Identifier("a".into())
* FieldElementExpression::FunctionCall(
* FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier("n".into())
@ -962,7 +892,7 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(2)).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier(Identifier::from("n").version(1))
@ -972,8 +902,8 @@ mod tests {
Identifier::from("a").version(1),
)
.into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
),
TypedStatement::Definition(
Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(),
@ -984,7 +914,7 @@ mod tests {
TypedStatement::Definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
(FieldElementExpression::Identifier(Identifier::from("a").version(2))
* FieldElementExpression::FunctionCall(
* FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier(Identifier::from("n").version(2))

View file

@ -118,6 +118,10 @@ pub trait Folder<'ast, T: Field>: Sized {
}
}
fn fold_types(&mut self, tys: Types<'ast, T>) -> Types<'ast, T> {
fold_types(self, tys)
}
fn fold_array_type(&mut self, t: ArrayType<'ast, T>) -> ArrayType<'ast, T> {
ArrayType {
ty: box self.fold_type(*t.ty),
@ -199,6 +203,16 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_member_expression(self, ty, e)
}
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
ty: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> FunctionCallOrExpression<'ast, T, E> {
fold_function_call_expression(self, ty, e)
}
fn fold_select_expression<
E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From<TypedExpression<'ast, T>>,
>(
@ -227,6 +241,14 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_expression_list(self, es)
}
fn fold_expression_list_inner(
&mut self,
tys: Types<'ast, T>,
es: TypedExpressionListInner<'ast, T>,
) -> TypedExpressionListInner<'ast, T> {
fold_expression_list_inner(self, tys, es)
}
fn fold_int_expression(&mut self, e: IntExpression<'ast, T>) -> IntExpression<'ast, T> {
fold_int_expression(self, e)
}
@ -345,13 +367,13 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
.map(|e| f.fold_expression_or_spread(e))
.collect(),
),
ArrayExpressionInner::FunctionCall(id, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect();
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
ArrayExpressionInner::FunctionCall(id, generics, exps)
ArrayExpressionInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(ty, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => {
ArrayExpressionInner::FunctionCall(function_call)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
ArrayExpressionInner::IfElse(
@ -395,13 +417,13 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
StructExpressionInner::Value(exprs) => {
StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
}
StructExpressionInner::FunctionCall(id, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect();
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
StructExpressionInner::FunctionCall(id, generics, exps)
StructExpressionInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(ty, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => {
StructExpressionInner::FunctionCall(function_call)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
StructExpressionInner::IfElse(
@ -474,13 +496,13 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
let alt = f.fold_field_expression(alt);
FieldElementExpression::IfElse(box cond, box cons, box alt)
}
FieldElementExpression::FunctionCall(key, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect();
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
FieldElementExpression::FunctionCall(key, generics, exps)
FieldElementExpression::FunctionCall(function_call) => {
match f.fold_function_call_expression(&Type::FieldElement, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => {
FieldElementExpression::FunctionCall(function_call)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
FieldElementExpression::Select(select) => {
match f.fold_select_expression(&Type::FieldElement, select) {
@ -622,13 +644,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e = f.fold_boolean_expression(e);
BooleanExpression::Not(box e)
}
BooleanExpression::FunctionCall(key, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect();
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
BooleanExpression::FunctionCall(key, generics, exps)
BooleanExpression::FunctionCall(function_call) => {
match f.fold_function_call_expression(&Type::Boolean, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => {
BooleanExpression::FunctionCall(function_call)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
BooleanExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
@ -748,13 +770,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Pos(box e)
}
UExpressionInner::FunctionCall(key, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect();
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
UExpressionInner::FunctionCall(key, generics, exps)
UExpressionInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(&ty, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => {
UExpressionInner::FunctionCall(function_call)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
UExpressionInner::Select(select) => match f.fold_select_expression(&ty, select) {
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
@ -797,6 +819,29 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_function_call_expression<
'ast,
T: Field,
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
F: Folder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> FunctionCallOrExpression<'ast, T, E> {
FunctionCallOrExpression::FunctionCall(FunctionCallExpression::new(
e.function_key,
e.generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect(),
e.arguments
.into_iter()
.map(|e| f.fold_expression(e))
.collect(),
))
}
pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
@ -851,30 +896,45 @@ pub fn fold_expression_list<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
es: TypedExpressionList<'ast, T>,
) -> TypedExpressionList<'ast, T> {
match es {
TypedExpressionList::FunctionCall(id, generics, arguments, types) => {
TypedExpressionList::FunctionCall(
id,
generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)))
.collect(),
arguments
.into_iter()
.map(|a| f.fold_expression(a))
.collect(),
types.into_iter().map(|t| f.fold_type(t)).collect(),
)
let types = f.fold_types(es.types);
TypedExpressionList {
inner: f.fold_expression_list_inner(types.clone(), es.inner),
types,
}
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
TypedExpressionList::EmbedCall(
}
pub fn fold_types<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
tys: Types<'ast, T>,
) -> Types<'ast, T> {
Types {
inner: tys.inner.into_iter().map(|t| f.fold_type(t)).collect(),
}
}
pub fn fold_expression_list_inner<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
tys: Types<'ast, T>,
es: TypedExpressionListInner<'ast, T>,
) -> TypedExpressionListInner<'ast, T> {
match es {
TypedExpressionListInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(&tys, function_call) {
FunctionCallOrExpression::FunctionCall(function_call) => {
TypedExpressionListInner::FunctionCall(function_call)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
TypedExpressionListInner::EmbedCall(
embed,
generics,
arguments
.into_iter()
.map(|a| f.fold_expression(a))
.collect(),
types.into_iter().map(|t| f.fold_type(t)).collect(),
)
}
}

View file

@ -21,7 +21,7 @@ pub use self::parameter::{DeclarationParameter, GParameter};
pub use self::types::{
ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey,
DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
Signature, StructType, Type, UBitwidth,
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
};
use crate::typed_absy::types::ConcreteGenericsAssignment;
@ -638,30 +638,8 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
.collect::<Vec<String>>()
.join(", ")
),
StructExpressionInner::FunctionCall(ref key, ref generics, ref p) => {
write!(f, "{}", key.id,)?;
if !generics.is_empty() {
write!(
f,
"::<{}>",
generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, "(")?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
StructExpressionInner::FunctionCall(ref function_call) => {
write!(f, "{}", function_call)
}
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
@ -724,27 +702,33 @@ pub trait MultiTyped<'ast, T> {
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum TypedExpressionList<'ast, T> {
FunctionCall(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
EmbedCall(
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
pub struct TypedExpressionList<'ast, T> {
pub inner: TypedExpressionListInner<'ast, T>,
pub types: Types<'ast, T>,
}
impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.inner)
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum TypedExpressionListInner<'ast, T> {
FunctionCall(FunctionCallExpression<'ast, T, TypedExpressionList<'ast, T>>),
EmbedCall(FlatEmbed, Vec<u32>, Vec<TypedExpression<'ast, T>>),
}
impl<'ast, T> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> {
fn get_types(&self) -> &Vec<Type<'ast, T>> {
match *self {
TypedExpressionList::FunctionCall(_, _, _, ref types) => types,
TypedExpressionList::EmbedCall(_, _, _, ref types) => types,
&self.types.inner
}
}
impl<'ast, T> TypedExpressionListInner<'ast, T> {
pub fn annotate(self, types: Types<'ast, T>) -> TypedExpressionList<'ast, T> {
TypedExpressionList { inner: self, types }
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
@ -808,6 +792,58 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub struct FunctionCallExpression<'ast, T, E> {
pub function_key: DeclarationFunctionKey<'ast>,
pub generics: Vec<Option<UExpression<'ast, T>>>,
pub arguments: Vec<TypedExpression<'ast, T>>,
ty: PhantomData<E>,
}
impl<'ast, T, E> FunctionCallExpression<'ast, T, E> {
pub fn new(
function_key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
FunctionCallExpression {
function_key,
generics,
arguments,
ty: PhantomData,
}
}
}
impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.function_key.id,)?;
if !self.generics.is_empty() {
write!(
f,
"::<{}>",
self.generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(
f,
"({})",
self.arguments
.iter()
.map(|a| a.to_string())
.collect::<Vec<_>>()
.join(",")
)
}
}
/// An expression of type `field`
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
pub enum FieldElementExpression<'ast, T> {
@ -841,11 +877,7 @@ pub enum FieldElementExpression<'ast, T> {
),
Neg(Box<FieldElementExpression<'ast, T>>),
Pos(Box<FieldElementExpression<'ast, T>>),
FunctionCall(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
),
FunctionCall(FunctionCallExpression<'ast, T, Self>),
Member(MemberExpression<'ast, T, Self>),
Select(SelectExpression<'ast, T, Self>),
}
@ -948,11 +980,7 @@ pub enum BooleanExpression<'ast, T> {
Box<BooleanExpression<'ast, T>>,
),
Member(MemberExpression<'ast, T, Self>),
FunctionCall(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
),
FunctionCall(FunctionCallExpression<'ast, T, Self>),
Select(SelectExpression<'ast, T, Self>),
}
@ -1056,11 +1084,7 @@ pub enum ArrayExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(ArrayValue<'ast, T>),
FunctionCall(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
),
FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
@ -1165,11 +1189,7 @@ pub enum StructExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, StructExpression<'ast, T>>),
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
),
FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
@ -1243,6 +1263,15 @@ impl<'ast, T> From<TypedExpression<'ast, T>> for StructExpression<'ast, T> {
}
}
// `TypedExpressionList` can technically not be constructed from `TypedExpression`
// However implementing `From<TypedExpression>` is required for `TypedExpressionList` to be `Expr`, which makes generic treatment of function calls possible
// This could maybe be avoided by splitting the `Expr` trait into many, but I did not find a way
impl<'ast, T> From<TypedExpression<'ast, T>> for TypedExpressionList<'ast, T> {
fn from(_: TypedExpression<'ast, T>) -> TypedExpressionList<'ast, T> {
unreachable!()
}
}
impl<'ast, T> From<TypedConstant<'ast, T>> for FieldElementExpression<'ast, T> {
fn from(tc: TypedConstant<'ast, T>) -> FieldElementExpression<'ast, T> {
tc.expression.into()
@ -1314,30 +1343,8 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
condition, consequent, alternative
)
}
FieldElementExpression::FunctionCall(ref k, ref generics, ref p) => {
write!(f, "{}", k.id,)?;
if !generics.is_empty() {
write!(
f,
"::<{}>",
generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, "(")?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
FieldElementExpression::FunctionCall(ref function_call) => {
write!(f, "{}", function_call)
}
FieldElementExpression::Member(ref m) => write!(f, "{}", m),
FieldElementExpression::Select(ref select) => write!(f, "{}", select),
@ -1368,31 +1375,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::Neg(ref e) => write!(f, "(-{})", e),
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
UExpressionInner::Select(ref select) => write!(f, "{}", select),
UExpressionInner::FunctionCall(ref k, ref generics, ref p) => {
write!(f, "{}", k.id,)?;
if !generics.is_empty() {
write!(
f,
"::<{}>",
generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, "(")?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
UExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
@ -1425,31 +1408,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::FunctionCall(ref k, ref generics, ref p) => {
write!(f, "{}", k.id,)?;
if !generics.is_empty() {
write!(
f,
"::<{}>",
generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, "(")?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
BooleanExpression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
@ -1475,31 +1434,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
.collect::<Vec<String>>()
.join(", ")
),
ArrayExpressionInner::FunctionCall(ref key, ref generics, ref p) => {
write!(f, "{}", key.id,)?;
if !generics.is_empty() {
write!(
f,
"::<{}>",
generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, "(")?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
@ -1517,35 +1452,13 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionListInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TypedExpressionList::FunctionCall(ref k, ref generics, ref p, _) => {
write!(f, "{}", k.id,)?;
if !generics.is_empty() {
write!(
f,
"::<{}>",
generics
.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| '_'.to_string()))
.collect::<Vec<_>>()
.join(", ")
)?;
TypedExpressionListInner::FunctionCall(ref function_call) => {
write!(f, "{}", function_call)
}
write!(f, "(")?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
}
TypedExpressionList::EmbedCall(ref embed, ref generics, ref p, _) => {
TypedExpressionListInner::EmbedCall(ref embed, ref generics, ref p) => {
write!(f, "{}", embed.id())?;
if !generics.is_empty() {
write!(
@ -1590,16 +1503,16 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
// Common behaviour across expressions
pub trait Expr<'ast, T> {
pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
type Inner;
type Ty;
type Ty: Clone + IntoTypes<'ast, T>;
fn into_inner(self) -> Self::Inner;
fn as_inner(&self) -> &Self::Inner;
}
impl<'ast, T> Expr<'ast, T> for FieldElementExpression<'ast, T> {
impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
@ -1612,7 +1525,7 @@ impl<'ast, T> Expr<'ast, T> for FieldElementExpression<'ast, T> {
}
}
impl<'ast, T> Expr<'ast, T> for BooleanExpression<'ast, T> {
impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
@ -1625,7 +1538,7 @@ impl<'ast, T> Expr<'ast, T> for BooleanExpression<'ast, T> {
}
}
impl<'ast, T> Expr<'ast, T> for UExpression<'ast, T> {
impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
type Inner = UExpressionInner<'ast, T>;
type Ty = UBitwidth;
@ -1638,7 +1551,7 @@ impl<'ast, T> Expr<'ast, T> for UExpression<'ast, T> {
}
}
impl<'ast, T> Expr<'ast, T> for StructExpression<'ast, T> {
impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
type Inner = StructExpressionInner<'ast, T>;
type Ty = StructType<'ast, T>;
@ -1651,7 +1564,7 @@ impl<'ast, T> Expr<'ast, T> for StructExpression<'ast, T> {
}
}
impl<'ast, T> Expr<'ast, T> for ArrayExpression<'ast, T> {
impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
type Inner = ArrayExpressionInner<'ast, T>;
type Ty = ArrayType<'ast, T>;
@ -1664,7 +1577,7 @@ impl<'ast, T> Expr<'ast, T> for ArrayExpression<'ast, T> {
}
}
impl<'ast, T> Expr<'ast, T> for IntExpression<'ast, T> {
impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
@ -1677,8 +1590,25 @@ impl<'ast, T> Expr<'ast, T> for IntExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
type Inner = TypedExpressionListInner<'ast, T>;
type Ty = Types<'ast, T>;
fn into_inner(self) -> Self::Inner {
self.inner
}
fn as_inner(&self) -> &Self::Inner {
&self.inner
}
}
// Enums types to enable returning e.g a member expression OR another type of expression of this type
pub enum FunctionCallOrExpression<'ast, T, E: Expr<'ast, T>> {
FunctionCall(FunctionCallExpression<'ast, T, E>),
Expression(E::Inner),
}
pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> {
Select(SelectExpression<'ast, T, E>),
Expression(E::Inner),
@ -1886,13 +1816,55 @@ impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> {
}
}
pub trait FunctionCall<'ast, T> {
pub trait Id<'ast, T>: Expr<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner;
}
impl<'ast, T: Field> Id<'ast, T> for FieldElementExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
FieldElementExpression::Identifier(id)
}
}
impl<'ast, T: Field> Id<'ast, T> for BooleanExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
BooleanExpression::Identifier(id)
}
}
impl<'ast, T: Field> Id<'ast, T> for UExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
UExpressionInner::Identifier(id)
}
}
impl<'ast, T: Field> Id<'ast, T> for ArrayExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
ArrayExpressionInner::Identifier(id)
}
}
impl<'ast, T: Field> Id<'ast, T> for StructExpression<'ast, T> {
fn identifier(id: Identifier<'ast>) -> Self::Inner {
StructExpressionInner::Identifier(id)
}
}
// `TypedExpressionList` does not have an Identifier variant
// However implementing `From<TypedExpression>` is required for `TypedExpressionList` to be `Expr`, which makes generic treatment of function calls possible
// This could maybe be avoided by splitting the `Expr` trait into many, but I did not find a way
impl<'ast, T: Field> Id<'ast, T> for TypedExpressionList<'ast, T> {
fn identifier(_: Identifier<'ast>) -> Self::Inner {
unreachable!()
}
}
pub trait FunctionCall<'ast, T>: Expr<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Self;
) -> Self::Inner;
}
impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
@ -1900,10 +1872,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Self {
assert_eq!(output_type, Type::FieldElement);
FieldElementExpression::FunctionCall(key, generics, arguments)
) -> Self::Inner {
FieldElementExpression::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
}
}
@ -1912,10 +1882,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Self {
assert_eq!(output_type, Type::Boolean);
BooleanExpression::FunctionCall(key, generics, arguments)
) -> Self::Inner {
BooleanExpression::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
}
}
@ -1924,13 +1892,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> {
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Self {
let bitwidth = match output_type {
Type::Uint(bitwidth) => bitwidth,
_ => unreachable!(),
};
UExpressionInner::FunctionCall(key, generics, arguments).annotate(bitwidth)
) -> Self::Inner {
UExpressionInner::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
}
}
@ -1939,14 +1902,8 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Self {
let array_ty = match output_type {
Type::Array(array_ty) => array_ty,
_ => unreachable!(),
};
ArrayExpressionInner::FunctionCall(key, generics, arguments)
.annotate(*array_ty.ty, array_ty.size)
) -> Self::Inner {
ArrayExpressionInner::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
}
}
@ -1955,14 +1912,20 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_type: Type<'ast, T>,
) -> Self {
let struct_ty = match output_type {
Type::Struct(struct_ty) => struct_ty,
_ => unreachable!(),
};
) -> Self::Inner {
StructExpressionInner::FunctionCall(FunctionCallExpression::new(key, generics, arguments))
}
}
StructExpressionInner::FunctionCall(key, generics, arguments).annotate(struct_ty)
impl<'ast, T: Field> FunctionCall<'ast, T> for TypedExpressionList<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {
TypedExpressionListInner::FunctionCall(FunctionCallExpression::new(
key, generics, arguments,
))
}
}

View file

@ -138,6 +138,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
}
}
fn fold_types(&mut self, tys: Types<'ast, T>) -> Result<Types<'ast, T>, Self::Error> {
fold_types(self, tys)
}
fn fold_block_expression<E: ResultFold<'ast, T>>(
&mut self,
block: BlockExpression<'ast, T, E>,
@ -165,6 +169,16 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_select_expression(self, ty, e)
}
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
ty: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
fold_function_call_expression(self, ty, e)
}
fn fold_array_type(
&mut self,
t: ArrayType<'ast, T>,
@ -274,6 +288,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_struct_expression(self, e)
}
fn fold_expression_list_inner(
&mut self,
tys: &Types<'ast, T>,
es: TypedExpressionListInner<'ast, T>,
) -> Result<TypedExpressionListInner<'ast, T>, Self::Error> {
fold_expression_list_inner(self, tys, es)
}
fn fold_expression_list(
&mut self,
es: TypedExpressionList<'ast, T>,
@ -387,16 +409,11 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
.map(|e| f.fold_expression_or_spread(e))
.collect::<Result<_, _>>()?,
),
ArrayExpressionInner::FunctionCall(id, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
ArrayExpressionInner::FunctionCall(id, generics, exps)
ArrayExpressionInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(ty, function_call)? {
FunctionCallOrExpression::FunctionCall(c) => ArrayExpressionInner::FunctionCall(c),
FunctionCallOrExpression::Expression(u) => u,
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
ArrayExpressionInner::IfElse(
@ -446,16 +463,11 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?,
),
StructExpressionInner::FunctionCall(id, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
StructExpressionInner::FunctionCall(id, generics, exps)
StructExpressionInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(ty, function_call)? {
FunctionCallOrExpression::FunctionCall(c) => StructExpressionInner::FunctionCall(c),
FunctionCallOrExpression::Expression(u) => u,
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
StructExpressionInner::IfElse(
@ -529,16 +541,13 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
let alt = f.fold_field_expression(alt)?;
FieldElementExpression::IfElse(box cond, box cons, box alt)
}
FieldElementExpression::FunctionCall(key, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
FieldElementExpression::FunctionCall(key, generics, exps)
FieldElementExpression::FunctionCall(function_call) => {
match f.fold_function_call_expression(&Type::FieldElement, function_call)? {
FunctionCallOrExpression::FunctionCall(c) => {
FieldElementExpression::FunctionCall(c)
}
FunctionCallOrExpression::Expression(u) => u,
}
}
FieldElementExpression::Member(m) => {
match f.fold_member_expression(&Type::FieldElement, m)? {
@ -612,6 +621,29 @@ pub fn fold_select_expression<
)))
}
pub fn fold_function_call_expression<
'ast,
T: Field,
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
F: ResultFolder<'ast, T>,
>(
f: &mut F,
_: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> Result<FunctionCallOrExpression<'ast, T, E>, F::Error> {
Ok(FunctionCallOrExpression::Expression(E::function_call(
e.function_key,
e.generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?,
e.arguments
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?,
)))
}
pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: BooleanExpression<'ast, T>,
@ -701,16 +733,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e = f.fold_boolean_expression(e)?;
BooleanExpression::Not(box e)
}
BooleanExpression::FunctionCall(key, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
BooleanExpression::FunctionCall(key, generics, exps)
BooleanExpression::FunctionCall(function_call) => {
match f.fold_function_call_expression(&Type::Boolean, function_call)? {
FunctionCallOrExpression::FunctionCall(c) => BooleanExpression::FunctionCall(c),
FunctionCallOrExpression::Expression(u) => u,
}
}
BooleanExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
@ -832,16 +859,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
UExpressionInner::Pos(box e)
}
UExpressionInner::FunctionCall(key, generics, exps) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
UExpressionInner::FunctionCall(key, generics, exps)
UExpressionInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(&ty, function_call)? {
FunctionCallOrExpression::FunctionCall(c) => UExpressionInner::FunctionCall(c),
FunctionCallOrExpression::Expression(u) => u,
}
}
UExpressionInner::Select(select) => match f.fold_select_expression(&ty, select)? {
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
@ -928,37 +950,49 @@ pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
es: TypedExpressionList<'ast, T>,
) -> Result<TypedExpressionList<'ast, T>, F::Error> {
match es {
TypedExpressionList::FunctionCall(id, generics, arguments, types) => {
let generics = generics
.into_iter()
.map(|g| g.map(|g| f.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
Ok(TypedExpressionList::FunctionCall(
id,
generics,
arguments
.into_iter()
.map(|a| f.fold_expression(a))
.collect::<Result<_, _>>()?,
types
let types = f.fold_types(es.types)?;
Ok(TypedExpressionList {
inner: f.fold_expression_list_inner(&types, es.inner)?,
types,
})
}
pub fn fold_types<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
tys: Types<'ast, T>,
) -> Result<Types<'ast, T>, F::Error> {
Ok(Types {
inner: tys
.inner
.into_iter()
.map(|t| f.fold_type(t))
.collect::<Result<_, _>>()?,
))
})
}
pub fn fold_expression_list_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
tys: &Types<'ast, T>,
es: TypedExpressionListInner<'ast, T>,
) -> Result<TypedExpressionListInner<'ast, T>, F::Error> {
match es {
TypedExpressionListInner::FunctionCall(function_call) => {
match f.fold_function_call_expression(tys, function_call)? {
FunctionCallOrExpression::FunctionCall(function_call) => {
Ok(TypedExpressionListInner::FunctionCall(function_call))
}
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
Ok(TypedExpressionList::EmbedCall(
FunctionCallOrExpression::Expression(list) => Ok(list),
}
}
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
Ok(TypedExpressionListInner::EmbedCall(
embed,
generics,
arguments
.into_iter()
.map(|a| f.fold_expression(a))
.collect::<Result<_, _>>()?,
types
.into_iter()
.map(|t| f.fold_type(t))
.collect::<Result<_, _>>()?,
))
}
}

View file

@ -6,6 +6,57 @@ use std::fmt;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
pub trait IntoTypes<'ast, T> {
fn into_types(self) -> Types<'ast, T>;
}
impl<'ast, T> IntoTypes<'ast, T> for Type<'ast, T> {
fn into_types(self) -> Types<'ast, T> {
Types { inner: vec![self] }
}
}
impl<'ast, T> IntoTypes<'ast, T> for StructType<'ast, T> {
fn into_types(self) -> Types<'ast, T> {
Types {
inner: vec![Type::Struct(self)],
}
}
}
impl<'ast, T> IntoTypes<'ast, T> for ArrayType<'ast, T> {
fn into_types(self) -> Types<'ast, T> {
Types {
inner: vec![Type::Array(self)],
}
}
}
impl<'ast, T> IntoTypes<'ast, T> for UBitwidth {
fn into_types(self) -> Types<'ast, T> {
Types {
inner: vec![Type::Uint(self)],
}
}
}
impl<'ast, T> IntoTypes<'ast, T> for Types<'ast, T> {
fn into_types(self) -> Types<'ast, T> {
self
}
}
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub struct Types<'ast, T> {
pub inner: Vec<Type<'ast, T>>,
}
impl<'ast, T> Types<'ast, T> {
pub fn new(types: Vec<Type<'ast, T>>) -> Self {
Self { inner: types }
}
}
#[derive(Debug, Clone, Eq, Ord)]
pub struct GenericIdentifier<'ast> {
pub name: &'ast str,

View file

@ -190,11 +190,7 @@ pub enum UExpressionInner<'ast, T> {
Not(Box<UExpression<'ast, T>>),
Neg(Box<UExpression<'ast, T>>),
Pos(Box<UExpression<'ast, T>>),
FunctionCall(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
),
FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>),
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
IfElse(