wip
This commit is contained in:
parent
bb262ddb4c
commit
898b3eb7e8
26 changed files with 1602 additions and 456 deletions
18
test.zok
18
test.zok
|
@ -1,10 +1,10 @@
|
|||
def sum_array<P>(field[P] a) -> field:
|
||||
field res = 0
|
||||
for u32 i in 0..P do
|
||||
res = res + a[i]
|
||||
def main():
|
||||
field a = 3
|
||||
for u32 i in 0..2 do
|
||||
for u32 j in 0..2 do
|
||||
a = a + 1
|
||||
endfor
|
||||
a = a + 1
|
||||
endfor
|
||||
return res
|
||||
|
||||
def main(u32 i) -> (field):
|
||||
u32[2][4] a = [[1, 2, 3, 4], [1, 2, 3, 4]]
|
||||
return 1
|
||||
field b = a
|
||||
return
|
|
@ -52,7 +52,7 @@ impl FlatEmbed {
|
|||
}
|
||||
|
||||
pub fn key<T: Field>(&self) -> ConcreteFunctionKey<'static> {
|
||||
ConcreteFunctionKey::with_id(self.id()).signature(self.signature())
|
||||
ConcreteFunctionKey::with_location("#EMBED#", self.id()).signature(self.signature())
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
|
|
|
@ -889,7 +889,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.inputs(param_expressions.iter().map(|e| e.get_type()).collect())
|
||||
.outputs(return_types);
|
||||
|
||||
let key = FunctionKey::with_id(id).signature(passed_signature);
|
||||
let key = FunctionKey::with_location("#EMBED#", id).signature(passed_signature);
|
||||
|
||||
let funct = self.get_embed(&key, &symbols);
|
||||
|
||||
|
|
|
@ -422,12 +422,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
};
|
||||
|
||||
self.functions.insert(
|
||||
DeclarationFunctionKey::with_id(declaration.id.clone())
|
||||
.signature(funct.signature.clone()),
|
||||
DeclarationFunctionKey::with_location(
|
||||
module_id.clone(),
|
||||
declaration.id.clone(),
|
||||
)
|
||||
.signature(funct.signature.clone()),
|
||||
);
|
||||
functions.insert(
|
||||
DeclarationFunctionKey::with_id(declaration.id.clone())
|
||||
.signature(funct.signature.clone()),
|
||||
DeclarationFunctionKey::with_location(
|
||||
module_id.clone(),
|
||||
declaration.id.clone(),
|
||||
)
|
||||
.signature(funct.signature.clone()),
|
||||
TypedFunctionSymbol::Here(funct),
|
||||
);
|
||||
}
|
||||
|
@ -450,6 +456,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.iter()
|
||||
.filter(|(k, _)| k.id == import.symbol_id)
|
||||
.map(|(_, v)| DeclarationFunctionKey {
|
||||
module: import.module_id.clone(),
|
||||
id: import.symbol_id.clone(),
|
||||
signature: v.signature(&state.typed_modules).clone(),
|
||||
})
|
||||
|
@ -525,12 +532,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
true => {}
|
||||
};
|
||||
|
||||
self.functions.insert(candidate.clone().id(declaration.id));
|
||||
let local_key = candidate.clone().id(declaration.id).module(module_id.clone());
|
||||
|
||||
self.functions.insert(local_key.clone());
|
||||
functions.insert(
|
||||
candidate.clone().id(declaration.id),
|
||||
TypedFunctionSymbol::There(
|
||||
candidate,
|
||||
import.module_id.clone(),
|
||||
local_key,
|
||||
TypedFunctionSymbol::There(candidate,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -562,12 +569,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
};
|
||||
|
||||
self.functions.insert(
|
||||
DeclarationFunctionKey::with_id(declaration.id.clone())
|
||||
.signature(funct.signature().clone().try_into().unwrap()),
|
||||
DeclarationFunctionKey::with_location(
|
||||
module_id.clone(),
|
||||
declaration.id.clone(),
|
||||
)
|
||||
.signature(funct.signature().clone().try_into().unwrap()),
|
||||
);
|
||||
functions.insert(
|
||||
DeclarationFunctionKey::with_id(declaration.id.clone())
|
||||
.signature(funct.signature().clone().try_into().unwrap()),
|
||||
DeclarationFunctionKey::with_location(
|
||||
module_id.clone(),
|
||||
declaration.id.clone(),
|
||||
)
|
||||
.signature(funct.signature().clone().try_into().unwrap()),
|
||||
TypedFunctionSymbol::Flat(funct),
|
||||
);
|
||||
}
|
||||
|
@ -1363,11 +1376,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
let f = functions.pop().unwrap();
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| vec![ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0)
|
||||
}])?;
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type())
|
||||
}])?;
|
||||
|
||||
let call = TypedExpressionList::FunctionCall(f.clone().into(), arguments_checked, Signature::from(f.signature.clone()).outputs);
|
||||
let call = TypedExpressionList::FunctionCall(f.clone().into(), arguments_checked, variables.iter().map(|v| v.get_type()).collect());
|
||||
|
||||
Ok(TypedStatement::MultipleDefinition(variables, call))
|
||||
},
|
||||
|
@ -1914,9 +1927,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let f = functions.pop().unwrap();
|
||||
|
||||
let signature: Signature<T> = f.signature.into();
|
||||
let signature = f.signature;
|
||||
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0)
|
||||
})?;
|
||||
|
@ -1924,25 +1937,28 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
// the return count has to be 1
|
||||
match signature.outputs.len() {
|
||||
1 => match &signature.outputs[0] {
|
||||
Type::Int => unreachable!(),
|
||||
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
|
||||
FunctionKey {
|
||||
DeclarationType::Int => unreachable!(),
|
||||
DeclarationType::FieldElement => Ok(FieldElementExpression::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.clone(),
|
||||
id: f.id.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
arguments_checked,
|
||||
)
|
||||
.into()),
|
||||
Type::Boolean => Ok(BooleanExpression::FunctionCall(
|
||||
FunctionKey {
|
||||
DeclarationType::Boolean => Ok(BooleanExpression::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.clone(),
|
||||
id: f.id.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
arguments_checked,
|
||||
)
|
||||
.into()),
|
||||
Type::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall(
|
||||
FunctionKey {
|
||||
DeclarationType::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.clone(),
|
||||
id: f.id.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
|
@ -1950,23 +1966,25 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
)
|
||||
.annotate(*bitwidth)
|
||||
.into()),
|
||||
Type::Struct(members) => Ok(StructExpressionInner::FunctionCall(
|
||||
FunctionKey {
|
||||
DeclarationType::Struct(members) => Ok(StructExpressionInner::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.clone(),
|
||||
id: f.id.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(members.clone())
|
||||
.annotate(members.clone().into())
|
||||
.into()),
|
||||
Type::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall(
|
||||
FunctionKey {
|
||||
DeclarationType::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall(
|
||||
DeclarationFunctionKey {
|
||||
module: module_id.clone(),
|
||||
id: f.id.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(*array_type.ty.clone(), array_type.size.clone())
|
||||
.annotate(Type::from(*array_type.ty.clone()), array_type.size.clone())
|
||||
.into()),
|
||||
},
|
||||
n => Err(ErrorInner {
|
||||
|
@ -2592,6 +2610,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.next()
|
||||
.unwrap_or(Type::Int);
|
||||
|
||||
println!("INFERRED TYPE {:?}", inferred_type);
|
||||
|
||||
match inferred_type {
|
||||
Type::Int => {
|
||||
// no need to check the expressions have the same type, this is guaranteed above
|
||||
|
@ -2643,6 +2663,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let size = unwrapped_expressions.len() as u32;
|
||||
|
||||
println!("hey");
|
||||
|
||||
Ok(ArrayExpressionInner::Value(unwrapped_expressions)
|
||||
.annotate(Type::Boolean, size as usize)
|
||||
.into())
|
||||
|
@ -3379,12 +3401,11 @@ mod tests {
|
|||
state.typed_modules.get(&PathBuf::from("bar")),
|
||||
Some(&TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_id("main")
|
||||
DeclarationFunctionKey::with_location("bar", "main")
|
||||
.signature(DeclarationSignature::new()),
|
||||
TypedFunctionSymbol::There(
|
||||
DeclarationFunctionKey::with_id("main")
|
||||
DeclarationFunctionKey::with_location("foo", "main")
|
||||
.signature(DeclarationSignature::new()),
|
||||
"foo".into()
|
||||
)
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -3672,16 +3693,19 @@ mod tests {
|
|||
.unwrap()
|
||||
.functions
|
||||
.contains_key(
|
||||
&DeclarationFunctionKey::with_id("foo").signature(DeclarationSignature::new())
|
||||
&DeclarationFunctionKey::with_location(MODULE_ID, "foo")
|
||||
.signature(DeclarationSignature::new())
|
||||
));
|
||||
assert!(state
|
||||
.typed_modules
|
||||
.get(&PathBuf::from(MODULE_ID))
|
||||
.unwrap()
|
||||
.functions
|
||||
.contains_key(&DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement])
|
||||
)))
|
||||
.contains_key(
|
||||
&DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature(
|
||||
DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement])
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -4271,6 +4295,7 @@ mod tests {
|
|||
];
|
||||
|
||||
let foo = DeclarationFunctionKey {
|
||||
module: "main".into(),
|
||||
id: "foo",
|
||||
signature: DeclarationSignature {
|
||||
inputs: vec![],
|
||||
|
@ -4323,6 +4348,7 @@ mod tests {
|
|||
.mock()];
|
||||
|
||||
let foo = DeclarationFunctionKey {
|
||||
module: "main".into(),
|
||||
id: "foo",
|
||||
signature: DeclarationSignature {
|
||||
inputs: vec![],
|
||||
|
@ -4856,7 +4882,7 @@ mod tests {
|
|||
typed_absy::Variable::field_element("b"),
|
||||
],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature(
|
||||
DeclarationSignature::new().outputs(vec![
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationType::FieldElement,
|
||||
|
@ -4874,6 +4900,7 @@ mod tests {
|
|||
];
|
||||
|
||||
let foo = DeclarationFunctionKey {
|
||||
module: "main".into(),
|
||||
id: "foo",
|
||||
signature: DeclarationSignature {
|
||||
inputs: vec![],
|
||||
|
|
|
@ -279,7 +279,7 @@ pub fn fold_statement<'ast, T: Field>(
|
|||
)]
|
||||
}
|
||||
typed_absy::TypedStatement::PushCallLog(..) => vec![],
|
||||
typed_absy::TypedStatement::PopCallLog(..) => vec![],
|
||||
typed_absy::TypedStatement::PopCallLog => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -906,9 +906,8 @@ pub fn fold_function_symbol<'ast, T: Field>(
|
|||
typed_absy::TypedFunctionSymbol::Here(fun) => {
|
||||
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
|
||||
}
|
||||
typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There(
|
||||
typed_absy::TypedFunctionSymbol::There(key) => zir::ZirFunctionSymbol::There(
|
||||
f.fold_function_key(typed_absy::types::ConcreteFunctionKey::try_from(key).unwrap()),
|
||||
module,
|
||||
), // by default, do not fold modules recursively
|
||||
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
|
||||
}
|
||||
|
|
|
@ -41,12 +41,16 @@ pub trait Analyse {
|
|||
impl<'ast, T: Field> TypedProgram<'ast, T> {
|
||||
pub fn analyse(self) -> (ZirProgram<'ast, T>, Abi) {
|
||||
// return binding
|
||||
let r = ReturnBinder::bind(self);
|
||||
//let r = ReturnBinder::bind(self);
|
||||
|
||||
// propagated unrolling
|
||||
//let r = PropagatedUnroller::unroll(r).unwrap_or_else(|e| panic!(e));
|
||||
|
||||
let r = reduce_program(r).unwrap();
|
||||
println!("{:#?}", self);
|
||||
|
||||
let r = reduce_program(self).unwrap();
|
||||
|
||||
println!("reduced");
|
||||
|
||||
let r = Trimmer::trim(r);
|
||||
|
||||
|
|
|
@ -331,7 +331,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
}
|
||||
s @ TypedStatement::PushCallLog(..) => Some(s),
|
||||
s @ TypedStatement::PopCallLog(..) => Some(s),
|
||||
s @ TypedStatement::PopCallLog => Some(s),
|
||||
};
|
||||
|
||||
// In verbose mode, we always return a statement
|
||||
|
|
|
@ -25,23 +25,25 @@
|
|||
// - The body of the function is in SSA form
|
||||
// - The return value(s) are assigned to internal variables
|
||||
|
||||
use embed::FlatEmbed;
|
||||
use static_analysis::reducer::CallCache;
|
||||
use static_analysis::reducer::Output;
|
||||
use static_analysis::reducer::ShallowTransformer;
|
||||
use static_analysis::reducer::Versions;
|
||||
use typed_absy::CoreIdentifier;
|
||||
use typed_absy::Identifier;
|
||||
use typed_absy::TypedAssignee;
|
||||
use typed_absy::{
|
||||
ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, DeclarationSignature, Signature,
|
||||
Type, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedModuleId, TypedModules,
|
||||
TypedStatement, Variable,
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey,
|
||||
DeclarationSignature, Signature, Type, TypedExpression, TypedFunction, TypedFunctionSymbol,
|
||||
TypedModules, TypedStatement, Variable,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub enum InlineError<'ast, T> {
|
||||
Generic(DeclarationSignature<'ast>, ConcreteSignature),
|
||||
Flat,
|
||||
Flat(FlatEmbed, Vec<TypedExpression<'ast, T>>, Vec<Type<'ast, T>>),
|
||||
NonConstant(
|
||||
TypedModuleId,
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
|
@ -49,33 +51,29 @@ pub enum InlineError<'ast, T> {
|
|||
}
|
||||
|
||||
fn get_canonical_function<'ast, T: Field>(
|
||||
module_id: TypedModuleId,
|
||||
function_key: DeclarationFunctionKey<'ast>,
|
||||
modules: &TypedModules<'ast, T>,
|
||||
) -> (
|
||||
TypedModuleId,
|
||||
DeclarationFunctionKey<'ast>,
|
||||
TypedFunction<'ast, T>,
|
||||
) {
|
||||
) -> Result<(DeclarationFunctionKey<'ast>, TypedFunction<'ast, T>), FlatEmbed> {
|
||||
match modules
|
||||
.get(&module_id.clone())
|
||||
.get(&function_key.module)
|
||||
.unwrap()
|
||||
.functions
|
||||
.iter()
|
||||
.find(|(key, _)| function_key == **key)
|
||||
.unwrap()
|
||||
{
|
||||
(key, TypedFunctionSymbol::Here(f)) => (module_id, key.clone(), f.clone()),
|
||||
_ => unimplemented!(),
|
||||
(key, TypedFunctionSymbol::Here(f)) => Ok((key.clone(), f.clone())),
|
||||
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &modules),
|
||||
(_, TypedFunctionSymbol::Flat(f)) => Err(f.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field>(
|
||||
module_id: TypedModuleId,
|
||||
k: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_types: Vec<Type<'ast, T>>,
|
||||
modules: &TypedModules<'ast, T>,
|
||||
cache: &mut CallCache<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
) -> Result<
|
||||
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
|
||||
|
@ -95,16 +93,12 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
|
||||
Ok(s) => s,
|
||||
Err(()) => {
|
||||
return Err(InlineError::NonConstant(
|
||||
module_id,
|
||||
k,
|
||||
arguments,
|
||||
output_types,
|
||||
));
|
||||
return Err(InlineError::NonConstant(k, arguments, output_types));
|
||||
}
|
||||
};
|
||||
|
||||
let (module_id, decl_key, f) = get_canonical_function(module_id, k.clone(), modules);
|
||||
let (decl_key, f) = get_canonical_function(k.clone(), modules)
|
||||
.map_err(|e| InlineError::Flat(e, arguments.clone(), output_types))?;
|
||||
assert_eq!(f.arguments.len(), arguments.len());
|
||||
|
||||
// get an assignment of generics for this call site
|
||||
|
@ -115,23 +109,34 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
InlineError::Generic(decl_key.signature.clone(), inferred_signature.clone())
|
||||
})?;
|
||||
|
||||
let concrete_key = ConcreteFunctionKey {
|
||||
module: decl_key.module.clone(),
|
||||
id: decl_key.id.clone(),
|
||||
signature: inferred_signature.clone(),
|
||||
};
|
||||
|
||||
match cache.get(&(concrete_key.clone(), arguments.clone())) {
|
||||
Some(v) => {
|
||||
return Ok(Output::Complete((vec![], v.clone())));
|
||||
}
|
||||
None => {}
|
||||
};
|
||||
|
||||
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
|
||||
Output::Complete(v) => (v, None),
|
||||
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
|
||||
};
|
||||
|
||||
let call_log = TypedStatement::PushCallLog(
|
||||
module_id,
|
||||
decl_key.clone(),
|
||||
assignment,
|
||||
ssa_f
|
||||
.arguments
|
||||
.into_iter()
|
||||
.zip(inferred_signature.inputs.clone())
|
||||
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
|
||||
.zip(arguments)
|
||||
.collect(),
|
||||
);
|
||||
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment);
|
||||
|
||||
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
|
||||
.arguments
|
||||
.into_iter()
|
||||
.zip(inferred_signature.inputs.clone())
|
||||
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
|
||||
.zip(arguments.clone())
|
||||
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let (statements, returns): (Vec<_>, Vec<_>) =
|
||||
ssa_f.statements.into_iter().partition(|s| match s {
|
||||
|
@ -151,7 +156,15 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
ConcreteVariable::with_id_and_type(Identifier::from(CoreIdentifier::Call(i)), t.clone())
|
||||
ConcreteVariable::with_id_and_type(
|
||||
Identifier::from(CoreIdentifier::Call(i)).version(
|
||||
*versions
|
||||
.entry(CoreIdentifier::Call(i).clone())
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
),
|
||||
t.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -162,13 +175,26 @@ pub fn inline_call<'a, 'ast, T: Field>(
|
|||
|
||||
assert_eq!(res.len(), returns.len());
|
||||
|
||||
let call_pop = TypedStatement::PopCallLog(res.into_iter().zip(returns).collect());
|
||||
let output_bindings: Vec<TypedStatement<'ast, T>> = res
|
||||
.into_iter()
|
||||
.zip(returns)
|
||||
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let pop_log = TypedStatement::PopCallLog;
|
||||
|
||||
let statements: Vec<_> = std::iter::once(call_log)
|
||||
.chain(input_bindings)
|
||||
.chain(statements)
|
||||
.chain(std::iter::once(call_pop))
|
||||
.chain(output_bindings)
|
||||
.chain(std::iter::once(pop_log))
|
||||
.collect();
|
||||
|
||||
cache.insert(
|
||||
(concrete_key.clone(), arguments.clone()),
|
||||
expressions.clone(),
|
||||
);
|
||||
|
||||
Ok(incomplete_data
|
||||
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
|
||||
.unwrap_or_else(|| Output::Complete((statements, expressions))))
|
||||
|
|
|
@ -17,12 +17,20 @@ mod unroll;
|
|||
|
||||
use self::inline::{inline_call, InlineError};
|
||||
use std::collections::HashMap;
|
||||
use typed_absy::result_folder::*;
|
||||
use typed_absy::types::GenericsAssignment;
|
||||
use typed_absy::Folder;
|
||||
|
||||
use typed_absy::{
|
||||
CoreIdentifier, TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule,
|
||||
TypedModules, TypedProgram, TypedStatement,
|
||||
ArrayExpression, ArrayExpressionInner, BooleanExpression, ConcreteFunctionKey, CoreIdentifier,
|
||||
DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression,
|
||||
StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction,
|
||||
TypedFunctionSymbol, TypedModule, TypedModules, TypedProgram, TypedStatement, UExpression,
|
||||
UExpressionInner,
|
||||
};
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use self::shallow_ssa::ShallowTransformer;
|
||||
|
@ -44,6 +52,326 @@ pub enum Error {
|
|||
Incompatible,
|
||||
}
|
||||
|
||||
type CallCache<'ast, T> = HashMap<
|
||||
(ConcreteFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
>;
|
||||
|
||||
type Substitutions<'ast> = HashMap<CoreIdentifier<'ast>, HashMap<usize, usize>>;
|
||||
|
||||
struct Sub<'a, 'ast> {
|
||||
substitutions: &'a mut Substitutions<'ast>,
|
||||
}
|
||||
|
||||
impl<'a, 'ast> Sub<'a, 'ast> {
|
||||
fn new(substitutions: &'a mut Substitutions<'ast>) -> Self {
|
||||
Self { substitutions }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> {
|
||||
fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> {
|
||||
let sub = self.substitutions.entry(id.id.clone()).or_default();
|
||||
let version = sub.get(&id.version).cloned().unwrap_or(id.version);
|
||||
id.version(version)
|
||||
}
|
||||
}
|
||||
|
||||
fn register<'ast>(
|
||||
substitutions: &mut Substitutions<'ast>,
|
||||
substitute: &Versions<'ast>,
|
||||
with: &Versions<'ast>,
|
||||
) {
|
||||
//assert!(substitute.len() <= with.len());
|
||||
for (id, key, value) in substitute
|
||||
.iter()
|
||||
.filter_map(|(id, version)| with.get(&id).clone().map(|to| (id, version, to)))
|
||||
.filter(|(_, key, value)| key != value)
|
||||
{
|
||||
let sub = substitutions.entry(id.clone()).or_default();
|
||||
sub.insert(*key, *value);
|
||||
}
|
||||
}
|
||||
|
||||
struct Reducer<'ast, 'a, T> {
|
||||
statement_buffer: Vec<TypedStatement<'ast, T>>,
|
||||
for_loop_versions: Vec<Versions<'ast>>,
|
||||
for_loop_index: usize,
|
||||
modules: &'a TypedModules<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
substitutions: Substitutions<'ast>,
|
||||
cache: CallCache<'ast, T>,
|
||||
complete: bool,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
||||
fn new(
|
||||
modules: &'a TypedModules<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
for_loop_versions: Vec<Versions<'ast>>,
|
||||
) -> Self {
|
||||
Reducer {
|
||||
statement_buffer: vec![],
|
||||
for_loop_index: 0,
|
||||
for_loop_versions,
|
||||
cache: CallCache::default(),
|
||||
substitutions: Substitutions::default(),
|
||||
modules,
|
||||
versions,
|
||||
complete: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_function_call<E>(
|
||||
&mut self,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_types: Vec<Type<'ast, T>>,
|
||||
) -> Result<E, <Self as ResultFolder<'ast, T>>::Error>
|
||||
where
|
||||
E: FunctionCall<'ast, T> + TryFrom<TypedExpression<'ast, T>, Error = ()> + std::fmt::Debug,
|
||||
{
|
||||
let arguments = arguments
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
let res = inline_call(
|
||||
key.clone(),
|
||||
arguments,
|
||||
output_types,
|
||||
&self.modules,
|
||||
&mut self.cache,
|
||||
&mut self.versions,
|
||||
);
|
||||
|
||||
match res {
|
||||
Ok(Output::Complete((statements, expressions))) => {
|
||||
self.complete &= true;
|
||||
self.statement_buffer.extend(statements);
|
||||
Ok(expressions[0].clone().try_into().unwrap())
|
||||
}
|
||||
Ok(Output::Incomplete((statements, expressions), _delta_for_loop_versions)) => {
|
||||
self.complete = false;
|
||||
self.statement_buffer.extend(statements);
|
||||
Ok(expressions[0].clone().try_into().unwrap())
|
||||
}
|
||||
Err(InlineError::Generic(a, b)) => Err(Error::Incompatible),
|
||||
Err(InlineError::NonConstant(key, arguments, _)) => {
|
||||
self.complete = false;
|
||||
|
||||
Ok(E::function_call(key, arguments))
|
||||
}
|
||||
Err(InlineError::Flat(embed, arguments, output_types)) => {
|
||||
Ok(E::function_call(embed.key::<T>().into(), arguments))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
let res = match s {
|
||||
TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(key, arguments, output_types),
|
||||
) => {
|
||||
let arguments = arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
match inline_call(
|
||||
key,
|
||||
arguments,
|
||||
output_types,
|
||||
&self.modules,
|
||||
&mut self.cache,
|
||||
&mut self.versions,
|
||||
) {
|
||||
Ok(Output::Complete((statements, expressions))) => {
|
||||
assert_eq!(v.len(), expressions.len());
|
||||
|
||||
self.complete &= true;
|
||||
|
||||
Ok(statements
|
||||
.into_iter()
|
||||
.chain(
|
||||
v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
|
||||
)
|
||||
.collect())
|
||||
}
|
||||
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
|
||||
assert_eq!(v.len(), expressions.len());
|
||||
|
||||
self.complete = false;
|
||||
|
||||
Ok(statements
|
||||
.into_iter()
|
||||
.chain(
|
||||
v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
|
||||
)
|
||||
.collect())
|
||||
}
|
||||
Err(InlineError::Generic(..)) => Err(Error::Incompatible),
|
||||
Err(InlineError::NonConstant(key, arguments, output_types)) => {
|
||||
self.complete = false;
|
||||
|
||||
Ok(vec![TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(key, arguments, output_types),
|
||||
)])
|
||||
}
|
||||
Err(InlineError::Flat(embed, arguments, output_types)) => {
|
||||
Ok(vec![TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(
|
||||
embed.key::<T>().into(),
|
||||
arguments,
|
||||
output_types,
|
||||
),
|
||||
)])
|
||||
}
|
||||
}
|
||||
}
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
|
||||
let mut out_statements = vec![];
|
||||
|
||||
let versions_before = self.for_loop_versions[self.for_loop_index].clone();
|
||||
|
||||
self.versions.values_mut().for_each(|v| *v = *v + 1);
|
||||
|
||||
register(&mut self.substitutions, &self.versions, &versions_before);
|
||||
|
||||
let versions_after = versions_before
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v + 2))
|
||||
.collect();
|
||||
|
||||
let mut transformer = ShallowTransformer::with_versions(&mut self.versions);
|
||||
|
||||
let old_index = self.for_loop_index;
|
||||
|
||||
for index in *from..*to {
|
||||
let statements: Vec<TypedStatement<_>> =
|
||||
std::iter::once(TypedStatement::Definition(
|
||||
v.clone().into(),
|
||||
UExpression::from(index as u32).into(),
|
||||
))
|
||||
.chain(statements.clone().into_iter())
|
||||
.map(|s| transformer.fold_statement(s))
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
let new_versions_count = transformer.for_loop_backups.len();
|
||||
|
||||
self.for_loop_index += new_versions_count;
|
||||
|
||||
out_statements.extend(statements);
|
||||
}
|
||||
|
||||
let backups = transformer.for_loop_backups;
|
||||
let blocked = transformer.blocked;
|
||||
|
||||
register(&mut self.substitutions, &versions_after, &self.versions);
|
||||
|
||||
self.for_loop_versions.splice(old_index..old_index, backups);
|
||||
|
||||
self.complete &= !blocked;
|
||||
|
||||
let out_statements = out_statements
|
||||
.into_iter()
|
||||
.map(|s| Sub::new(&mut self.substitutions).fold_statement(s))
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
Ok(out_statements)
|
||||
}
|
||||
_ => {
|
||||
self.complete = false;
|
||||
self.for_loop_index += 1;
|
||||
Ok(vec![TypedStatement::For(v, from, to, statements)])
|
||||
}
|
||||
}
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
};
|
||||
|
||||
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
BooleanExpression::FunctionCall(key, arguments) => {
|
||||
self.fold_function_call(key, arguments, vec![Type::Boolean])
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression(
|
||||
&mut self,
|
||||
e: UExpression<'ast, T>,
|
||||
) -> Result<UExpression<'ast, T>, Self::Error> {
|
||||
match e.as_inner() {
|
||||
UExpressionInner::FunctionCall(key, arguments) => {
|
||||
self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()])
|
||||
}
|
||||
_ => fold_uint_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
FieldElementExpression::FunctionCall(key, arguments) => {
|
||||
self.fold_function_call(key, arguments, vec![Type::FieldElement])
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_array_expression(
|
||||
&mut self,
|
||||
e: ArrayExpression<'ast, T>,
|
||||
) -> Result<ArrayExpression<'ast, T>, Self::Error> {
|
||||
match e.as_inner() {
|
||||
ArrayExpressionInner::FunctionCall(key, arguments) => {
|
||||
self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()])
|
||||
}
|
||||
_ => fold_array_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_struct_expression(
|
||||
&mut self,
|
||||
e: StructExpression<'ast, T>,
|
||||
) -> Result<StructExpression<'ast, T>, Self::Error> {
|
||||
match e.as_inner() {
|
||||
StructExpressionInner::FunctionCall(key, arguments) => {
|
||||
self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()])
|
||||
}
|
||||
_ => fold_struct_expression(self, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
let main_module = p.modules.get(&p.main).unwrap().clone();
|
||||
|
||||
|
@ -93,27 +421,27 @@ fn reduce_function<'ast, T: Field>(
|
|||
let mut f = new_f;
|
||||
|
||||
let statements = loop {
|
||||
match reduce_statements(
|
||||
f.statements,
|
||||
&mut for_loop_versions,
|
||||
&mut versions,
|
||||
modules,
|
||||
) {
|
||||
Ok(Output::Complete(statements)) => {
|
||||
break statements;
|
||||
}
|
||||
Ok(Output::Incomplete(new_statements, new_for_loop_versions)) => {
|
||||
let new_f = TypedFunction {
|
||||
statements: new_statements,
|
||||
..f
|
||||
};
|
||||
println!("{}", f);
|
||||
|
||||
use typed_absy::folder::Folder;
|
||||
let mut reducer = Reducer::new(&modules, &mut versions, for_loop_versions);
|
||||
|
||||
let statements = f
|
||||
.statements
|
||||
.into_iter()
|
||||
.map(|s| reducer.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
match reducer.complete {
|
||||
true => break statements,
|
||||
false => {
|
||||
let new_f = TypedFunction { statements, ..f };
|
||||
|
||||
f = Propagator::verbose().fold_function(new_f);
|
||||
for_loop_versions = new_for_loop_versions;
|
||||
for_loop_versions = reducer.for_loop_versions;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -122,111 +450,47 @@ fn reduce_function<'ast, T: Field>(
|
|||
}
|
||||
}
|
||||
|
||||
fn reduce_statements<'ast, T: Field>(
|
||||
statements: Vec<TypedStatement<'ast, T>>,
|
||||
for_loop_versions: &mut Vec<Versions<'ast>>,
|
||||
versions: &mut Versions<'ast>,
|
||||
modules: &TypedModules<'ast, T>,
|
||||
) -> Result<Output<Vec<TypedStatement<'ast, T>>, Vec<Versions<'ast>>>, Error> {
|
||||
let mut versions = versions;
|
||||
// fn reduce_statements<'ast, T: Field>(
|
||||
// statements: Vec<TypedStatement<'ast, T>>,
|
||||
// for_loop_versions: &mut Vec<Versions<'ast>>,
|
||||
// versions: &mut Versions<'ast>,
|
||||
// modules: &TypedModules<'ast, T>,
|
||||
// ) -> Result<Output<Vec<TypedStatement<'ast, T>>, Vec<Versions<'ast>>>, Error> {
|
||||
// let mut versions = versions;
|
||||
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|s| reduce_statement(s, for_loop_versions, &mut versions, modules));
|
||||
// let statements = statements
|
||||
// .into_iter()
|
||||
// .map(|s| reduce_statement(s, for_loop_versions, &mut versions, modules));
|
||||
|
||||
statements
|
||||
.into_iter()
|
||||
.fold(Ok(Output::Complete(vec![])), |state, e| match (state, e) {
|
||||
(Ok(state), Ok(e)) => {
|
||||
use self::Output::*;
|
||||
match (state, e) {
|
||||
(Complete(mut s), Complete(c)) => {
|
||||
s.extend(c);
|
||||
Ok(Complete(s))
|
||||
}
|
||||
(Complete(mut s), Incomplete(stats, for_loop_versions)) => {
|
||||
s.extend(stats);
|
||||
Ok(Incomplete(s, for_loop_versions))
|
||||
}
|
||||
(Incomplete(mut stats, new_for_loop_versions), Complete(c)) => {
|
||||
stats.extend(c);
|
||||
Ok(Incomplete(stats, new_for_loop_versions))
|
||||
}
|
||||
(Incomplete(mut stats, mut versions), Incomplete(new_stats, new_versions)) => {
|
||||
stats.extend(new_stats);
|
||||
versions.extend(new_versions);
|
||||
Ok(Incomplete(stats, versions))
|
||||
}
|
||||
}
|
||||
}
|
||||
(Err(state), _) => Err(state),
|
||||
(Ok(_), Err(e)) => Err(e),
|
||||
})
|
||||
}
|
||||
|
||||
fn reduce_statement<'ast, T: Field>(
|
||||
statement: TypedStatement<'ast, T>,
|
||||
_for_loop_versions: &mut Vec<Versions<'ast>>,
|
||||
versions: &mut Versions<'ast>,
|
||||
modules: &TypedModules<'ast, T>,
|
||||
) -> Result<Output<Vec<TypedStatement<'ast, T>>, Vec<Versions<'ast>>>, Error> {
|
||||
use self::Output::*;
|
||||
|
||||
match statement {
|
||||
TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(key, arguments, output_types),
|
||||
) => match inline_call(
|
||||
"main".into(),
|
||||
key,
|
||||
arguments,
|
||||
output_types,
|
||||
modules,
|
||||
versions,
|
||||
) {
|
||||
Ok(Output::Complete((statements, expressions))) => {
|
||||
assert_eq!(v.len(), expressions.len());
|
||||
Ok(Output::Complete(
|
||||
statements
|
||||
.into_iter()
|
||||
.chain(
|
||||
v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
|
||||
)
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
|
||||
assert_eq!(v.len(), expressions.len());
|
||||
Ok(Output::Incomplete(
|
||||
statements
|
||||
.into_iter()
|
||||
.chain(
|
||||
v.into_iter()
|
||||
.zip(expressions)
|
||||
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
|
||||
)
|
||||
.collect(),
|
||||
delta_for_loop_versions,
|
||||
))
|
||||
}
|
||||
Err(InlineError::Generic(..)) => Err(Error::Incompatible),
|
||||
Err(InlineError::NonConstant(_module, key, arguments, output_types)) => {
|
||||
Ok(Output::Incomplete(
|
||||
vec![TypedStatement::MultipleDefinition(
|
||||
v,
|
||||
TypedExpressionList::FunctionCall(key, arguments, output_types),
|
||||
)],
|
||||
vec![],
|
||||
))
|
||||
}
|
||||
Err(InlineError::Flat) => unimplemented!(),
|
||||
},
|
||||
TypedStatement::For(..) => unimplemented!(),
|
||||
s => Ok(Complete(vec![s])),
|
||||
}
|
||||
}
|
||||
// statements
|
||||
// .into_iter()
|
||||
// .fold(Ok(Output::Complete(vec![])), |state, e| match (state, e) {
|
||||
// (Ok(state), Ok(e)) => {
|
||||
// use self::Output::*;
|
||||
// match (state, e) {
|
||||
// (Complete(mut s), Complete(c)) => {
|
||||
// s.extend(c);
|
||||
// Ok(Complete(s))
|
||||
// }
|
||||
// (Complete(mut s), Incomplete(stats, for_loop_versions)) => {
|
||||
// s.extend(stats);
|
||||
// Ok(Incomplete(s, for_loop_versions))
|
||||
// }
|
||||
// (Incomplete(mut stats, new_for_loop_versions), Complete(c)) => {
|
||||
// stats.extend(c);
|
||||
// Ok(Incomplete(stats, new_for_loop_versions))
|
||||
// }
|
||||
// (Incomplete(mut stats, mut versions), Incomplete(new_stats, new_versions)) => {
|
||||
// stats.extend(new_stats);
|
||||
// versions.extend(new_versions);
|
||||
// Ok(Incomplete(stats, versions))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// (Err(state), _) => Err(state),
|
||||
// (Ok(_), Err(e)) => Err(e),
|
||||
// })
|
||||
// }
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
@ -257,9 +521,11 @@ mod tests {
|
|||
// u32 n_0 = 42
|
||||
// n_1 = n_0
|
||||
// a_1 = a_0
|
||||
// # PUSH CALL to foo with a_3 := a_1
|
||||
// # POP CALL with foo_ifof_42 := a_3
|
||||
// a_2 = foo_ifof_42
|
||||
// # PUSH CALL to foo
|
||||
// a_3 := a_1 // input binding
|
||||
// #RETURN_AT_INDEX_0_0 := a_3
|
||||
// # POP CALL
|
||||
// a_2 = #RETURN_AT_INDEX_0_0
|
||||
// n_2 = n_1
|
||||
// return a_2
|
||||
|
||||
|
@ -295,7 +561,7 @@ mod tests {
|
|||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a").into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -324,7 +590,7 @@ mod tests {
|
|||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -332,7 +598,7 @@ mod tests {
|
|||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -369,28 +635,23 @@ mod tests {
|
|||
FieldElementExpression::Identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
"main".into(),
|
||||
DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
GenericsAssignment::default(),
|
||||
vec![(
|
||||
ConcreteVariable::with_id_and_type(
|
||||
Identifier::from("a").version(3),
|
||||
ConcreteType::FieldElement,
|
||||
),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::PopCallLog(vec![(
|
||||
ConcreteVariable::with_id_and_type(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
ConcreteType::FieldElement,
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
|
||||
.into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
|
||||
)]),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::Identifier(
|
||||
|
@ -420,7 +681,7 @@ mod tests {
|
|||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -435,6 +696,9 @@ mod tests {
|
|||
.collect(),
|
||||
};
|
||||
|
||||
println!("{}", reduced.clone().unwrap());
|
||||
println!("{}", expected);
|
||||
|
||||
assert_eq!(reduced.unwrap(), expected);
|
||||
}
|
||||
|
||||
|
@ -455,9 +719,12 @@ mod tests {
|
|||
// u32 n_0 = 42
|
||||
// n_1 = n_0
|
||||
// field[1] b_0 = [42]
|
||||
// # PUSH CALL to foo::<1> with a_0 := b_0
|
||||
// # POP CALL with foo_if[1]_of[1]_42 := a_0
|
||||
// b_1 = foo_if[1]_of[1]_42
|
||||
// # PUSH CALL to foo::<1>
|
||||
// a_0 = b_0
|
||||
// K = 1
|
||||
// #RETURN_AT_INDEX_0_0 := a_0
|
||||
// # POP CALL
|
||||
// b_1 = #RETURN_AT_INDEX_0_0
|
||||
// n_2 = n_1
|
||||
// return a_2
|
||||
|
||||
|
@ -512,7 +779,8 @@ mod tests {
|
|||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into()],
|
||||
|
@ -539,11 +807,12 @@ mod tests {
|
|||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -584,35 +853,37 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
"main".into(),
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
vec![(
|
||||
ConcreteVariable::array(
|
||||
Identifier::from("a").version(1),
|
||||
ConcreteType::FieldElement,
|
||||
1,
|
||||
)
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from("a").version(1),
|
||||
Type::FieldElement,
|
||||
1u32.into(),
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
UExpression::from(1u32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::PopCallLog(vec![(
|
||||
ConcreteVariable::array(
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
ConcreteType::FieldElement,
|
||||
1,
|
||||
),
|
||||
Type::FieldElement,
|
||||
1u32.into(),
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
)]),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from("b").version(1),
|
||||
|
@ -645,7 +916,7 @@ mod tests {
|
|||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -660,6 +931,9 @@ mod tests {
|
|||
.collect(),
|
||||
};
|
||||
|
||||
println!("{}", reduced.clone().unwrap());
|
||||
println!("{}", expected);
|
||||
|
||||
assert_eq!(reduced.unwrap(), expected);
|
||||
}
|
||||
|
||||
|
@ -680,9 +954,12 @@ mod tests {
|
|||
// u32 n_0 = 2
|
||||
// n_1 = 2
|
||||
// field[1] b_0 = [42]
|
||||
// # PUSH CALL to foo::<1> with a_3 := b_0
|
||||
// # POP CALL with foo_if[1]of[1]_42 := a_3
|
||||
// b_1 = foo_if[1]of[1]_42
|
||||
// # PUSH CALL to foo::<1>
|
||||
// a_3 = b_0
|
||||
// K = 1
|
||||
// #RETURN_AT_INDEX_0_0 = a_3
|
||||
// # POP CALL
|
||||
// b_1 = #RETURN_AT_INDEX_0_0
|
||||
// n_2 = 2
|
||||
// return a_2
|
||||
|
||||
|
@ -755,7 +1032,8 @@ mod tests {
|
|||
)
|
||||
.into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(
|
||||
Type::FieldElement,
|
||||
|
@ -798,11 +1076,12 @@ mod tests {
|
|||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -841,37 +1120,39 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
"main".into(),
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
vec![(
|
||||
ConcreteVariable::array(
|
||||
Identifier::from("a").version(1),
|
||||
ConcreteType::FieldElement,
|
||||
1,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(1.into()).into()
|
||||
])
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from("a").version(1),
|
||||
Type::FieldElement,
|
||||
1u32.into(),
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(1.into()).into()
|
||||
])
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
UExpression::from(1u32).into(),
|
||||
),
|
||||
TypedStatement::PopCallLog(vec![(
|
||||
ConcreteVariable::array(
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
ConcreteType::FieldElement,
|
||||
1,
|
||||
),
|
||||
Type::FieldElement,
|
||||
1u32.into(),
|
||||
)
|
||||
.into(),
|
||||
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
)]),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Definition(
|
||||
Variable::array(
|
||||
Identifier::from("b").version(1),
|
||||
|
@ -902,7 +1183,7 @@ mod tests {
|
|||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -917,6 +1198,9 @@ mod tests {
|
|||
.collect(),
|
||||
};
|
||||
|
||||
println!("{}", reduced.clone().unwrap());
|
||||
println!("{}", expected);
|
||||
|
||||
assert_eq!(reduced.unwrap(), expected);
|
||||
}
|
||||
|
||||
|
@ -1010,7 +1294,7 @@ mod tests {
|
|||
// TypedStatement::MultipleDefinition(
|
||||
// vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()],
|
||||
// TypedExpressionList::FunctionCall(
|
||||
// DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
// DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()),
|
||||
// vec![ArrayExpressionInner::Identifier("b".into())
|
||||
// .annotate(Type::FieldElement, 1u32)
|
||||
// .into()],
|
||||
|
@ -1029,11 +1313,11 @@ mod tests {
|
|||
// TypedModule {
|
||||
// functions: vec![
|
||||
// (
|
||||
// DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
// DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()),
|
||||
// TypedFunctionSymbol::Here(foo),
|
||||
// ),
|
||||
// (
|
||||
// DeclarationFunctionKey::with_id("main"),
|
||||
// DeclarationFunctionKey::with_location(module_id.clone(), "main"),
|
||||
// TypedFunctionSymbol::Here(main),
|
||||
// ),
|
||||
// ]
|
||||
|
@ -1071,7 +1355,7 @@ mod tests {
|
|||
// ),
|
||||
// TypedStatement::PushCallLog(
|
||||
// "main".into(),
|
||||
// DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
// DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()),
|
||||
// GenericsAssignment(vec![("K", 1)].into_iter().collect()),
|
||||
// vec![(
|
||||
// ConcreteVariable::array("a", ConcreteType::FieldElement, 1).into(),
|
||||
|
@ -1113,7 +1397,7 @@ mod tests {
|
|||
// "main".into(),
|
||||
// TypedModule {
|
||||
// functions: vec![(
|
||||
// DeclarationFunctionKey::with_id("main").signature(
|
||||
// DeclarationFunctionKey::with_location(module_id.clone(), "main").signature(
|
||||
// DeclarationSignature::new()
|
||||
// .inputs(vec![DeclarationType::FieldElement])
|
||||
// .outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -1178,7 +1462,8 @@ mod tests {
|
|||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
vec![ArrayExpressionInner::Value(vec![])
|
||||
.annotate(Type::FieldElement, 0u32)
|
||||
.into()],
|
||||
|
@ -1197,11 +1482,12 @@ mod tests {
|
|||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
DeclarationFunctionKey::with_id("main").signature(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
|
|
|
@ -37,15 +37,15 @@ use super::{Output, Versions};
|
|||
|
||||
pub struct ShallowTransformer<'ast, 'a> {
|
||||
// version index for any variable name
|
||||
versions: &'a mut Versions<'ast>,
|
||||
pub versions: &'a mut Versions<'ast>,
|
||||
// A backup of the versions before each for-loop
|
||||
for_loop_backups: Vec<Versions<'ast>>,
|
||||
pub for_loop_backups: Vec<Versions<'ast>>,
|
||||
// whether all statements could be unrolled so far. Loops with variable bounds cannot.
|
||||
blocked: bool,
|
||||
pub blocked: bool,
|
||||
}
|
||||
|
||||
impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
||||
fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
|
||||
pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
|
||||
ShallowTransformer {
|
||||
versions,
|
||||
for_loop_backups: Vec::default(),
|
||||
|
@ -568,7 +568,8 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
|||
e: TypedExpressionList<'ast, T>,
|
||||
) -> TypedExpressionList<'ast, T> {
|
||||
match e {
|
||||
TypedExpressionList::FunctionCall(ref k, _, _) => {
|
||||
TypedExpressionList::FunctionCall(ref k, ref a, _) => {
|
||||
println!("{:#?}", a.iter().map(|a| a.get_type()).collect::<Vec<_>>());
|
||||
if !k.id.starts_with("_") {
|
||||
self.blocked = true;
|
||||
}
|
||||
|
@ -929,7 +930,7 @@ mod tests {
|
|||
let s: TypedStatement<Bn128Field> = TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a")],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -943,7 +944,7 @@ mod tests {
|
|||
vec![TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element(Identifier::from("a").version(1))],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo").signature(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement])
|
||||
|
@ -1447,7 +1448,7 @@ mod tests {
|
|||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a").into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo"),
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()],
|
||||
vec![Type::FieldElement],
|
||||
),
|
||||
|
@ -1463,7 +1464,7 @@ mod tests {
|
|||
FieldElementExpression::mult(
|
||||
FieldElementExpression::Identifier("a".into()),
|
||||
FieldElementExpression::FunctionCall(
|
||||
FunctionKey::with_id("foo"),
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![FieldElementExpression::Identifier("a".into()).into()],
|
||||
),
|
||||
)
|
||||
|
@ -1511,7 +1512,7 @@ mod tests {
|
|||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element(Identifier::from("a").version(2)).into()],
|
||||
TypedExpressionList::FunctionCall(
|
||||
DeclarationFunctionKey::with_id("foo"),
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![FieldElementExpression::Identifier(
|
||||
Identifier::from("a").version(1),
|
||||
)
|
||||
|
@ -1530,7 +1531,7 @@ mod tests {
|
|||
FieldElementExpression::mult(
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(2)),
|
||||
FieldElementExpression::FunctionCall(
|
||||
FunctionKey::with_id("foo"),
|
||||
DeclarationFunctionKey::with_location("main", "foo"),
|
||||
vec![FieldElementExpression::Identifier(
|
||||
Identifier::from("a").version(2),
|
||||
)
|
||||
|
|
|
@ -43,7 +43,7 @@ mod tests {
|
|||
fn generate_abi_from_typed_ast() {
|
||||
let mut functions = HashMap::new();
|
||||
functions.insert(
|
||||
ConcreteFunctionKey::with_id("main").into(),
|
||||
ConcreteFunctionKey::with_location("main", "main").into(),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
generics: vec![],
|
||||
arguments: vec![
|
||||
|
|
|
@ -381,80 +381,70 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let array_ty = array.get_array_type();
|
||||
|
||||
if array_ty == target_array_ty {
|
||||
if array_ty.weak_eq(&target_array_ty) {
|
||||
return Ok(array);
|
||||
}
|
||||
|
||||
// sizes must be equal
|
||||
let converted = match target_array_ty.size == array_ty.size {
|
||||
true =>
|
||||
// elements must fit in the target type
|
||||
{
|
||||
match array.into_inner() {
|
||||
ArrayExpressionInner::Value(inline_array) => {
|
||||
match *target_array_ty.ty {
|
||||
Type::Int => Ok(inline_array),
|
||||
Type::FieldElement => {
|
||||
// try to convert all elements to field
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
FieldElementExpression::try_from_typed(e)
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Uint(bitwidth) => {
|
||||
// try to convert all elements to uint
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
UExpression::try_from_typed(e, bitwidth)
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Array(ref inner_array_ty) => {
|
||||
// try to convert all elements to array
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
ArrayExpression::try_from_typed(e, inner_array_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(ref struct_ty) => {
|
||||
// try to convert all elements to struct
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
StructExpression::try_from_typed(e, struct_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Boolean => {
|
||||
// try to convert all elements to boolean
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
BooleanExpression::try_from_typed(e)
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
}
|
||||
// elements must fit in the target type
|
||||
let converted = match array.into_inner() {
|
||||
ArrayExpressionInner::Value(inline_array) => {
|
||||
match *target_array_ty.ty {
|
||||
Type::Int => Ok(inline_array),
|
||||
Type::FieldElement => {
|
||||
// try to convert all elements to field
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Uint(bitwidth) => {
|
||||
// try to convert all elements to uint
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Array(ref inner_array_ty) => {
|
||||
// try to convert all elements to array
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
ArrayExpression::try_from_typed(e, inner_array_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(ref struct_ty) => {
|
||||
// try to convert all elements to struct
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
StructExpression::try_from_typed(e, struct_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
Type::Boolean => {
|
||||
// try to convert all elements to boolean
|
||||
inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
BooleanExpression::try_from_typed(e).map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)
|
||||
}
|
||||
_ => unreachable!(""),
|
||||
}
|
||||
}
|
||||
false => Err(array.into()),
|
||||
_ => unreachable!(),
|
||||
}?;
|
||||
|
||||
Ok(ArrayExpressionInner::Value(converted).annotate(*target_array_ty.ty, array_ty.size))
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
pub mod abi;
|
||||
pub mod folder;
|
||||
pub mod identifier;
|
||||
pub mod result_folder;
|
||||
|
||||
mod integer;
|
||||
mod parameter;
|
||||
|
@ -19,8 +20,8 @@ pub use self::identifier::CoreIdentifier;
|
|||
pub use self::parameter::{DeclarationParameter, GParameter};
|
||||
pub use self::types::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey,
|
||||
DeclarationSignature, DeclarationType, GType, GenericIdentifier, Signature, StructType, Type,
|
||||
UBitwidth,
|
||||
DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
|
||||
Signature, StructType, Type, UBitwidth,
|
||||
};
|
||||
use typed_absy::types::GenericsAssignment;
|
||||
|
||||
|
@ -141,7 +142,7 @@ pub struct TypedModule<'ast, T> {
|
|||
#[derive(Clone, PartialEq)]
|
||||
pub enum TypedFunctionSymbol<'ast, T> {
|
||||
Here(TypedFunction<'ast, T>),
|
||||
There(DeclarationFunctionKey<'ast>, TypedModuleId),
|
||||
There(DeclarationFunctionKey<'ast>),
|
||||
Flat(FlatEmbed),
|
||||
}
|
||||
|
||||
|
@ -150,7 +151,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunctionSymbol<'ast, T> {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
TypedFunctionSymbol::Here(s) => write!(f, "Here({:?})", s),
|
||||
TypedFunctionSymbol::There(key, module) => write!(f, "There({:?}, {:?})", key, module),
|
||||
TypedFunctionSymbol::There(key) => write!(f, "There({:?})", key),
|
||||
TypedFunctionSymbol::Flat(s) => write!(f, "Flat({:?})", s),
|
||||
}
|
||||
}
|
||||
|
@ -163,8 +164,8 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
|
|||
) -> DeclarationSignature<'ast> {
|
||||
match self {
|
||||
TypedFunctionSymbol::Here(f) => f.signature.clone(),
|
||||
TypedFunctionSymbol::There(key, module_id) => modules
|
||||
.get(module_id)
|
||||
TypedFunctionSymbol::There(key) => modules
|
||||
.get(&key.module)
|
||||
.unwrap()
|
||||
.functions
|
||||
.get(key)
|
||||
|
@ -183,10 +184,10 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
|
|||
.iter()
|
||||
.map(|(key, symbol)| match symbol {
|
||||
TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
|
||||
TypedFunctionSymbol::There(ref fun_key, ref module_id) => format!(
|
||||
TypedFunctionSymbol::There(ref fun_key) => format!(
|
||||
"import {} from \"{}\" as {} // with signature {}",
|
||||
fun_key.id,
|
||||
module_id.display(),
|
||||
fun_key.module.display(),
|
||||
key.id,
|
||||
key.signature
|
||||
),
|
||||
|
@ -230,12 +231,16 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"<{}>",
|
||||
self.generics
|
||||
.iter()
|
||||
.map(|g| g.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
"{}",
|
||||
if self.generics.len() > 0 {
|
||||
self.generics
|
||||
.iter()
|
||||
.map(|g| g.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
)?;
|
||||
write!(
|
||||
f,
|
||||
|
@ -271,7 +276,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
for s in &self.statements {
|
||||
match s {
|
||||
TypedStatement::PopCallLog(..) => tab -= 1,
|
||||
TypedStatement::PopCallLog => tab -= 1,
|
||||
_ => {}
|
||||
};
|
||||
|
||||
|
@ -380,13 +385,8 @@ pub enum TypedStatement<'ast, T> {
|
|||
),
|
||||
MultipleDefinition(Vec<Variable<'ast, T>>, TypedExpressionList<'ast, T>),
|
||||
// Aux
|
||||
PushCallLog(
|
||||
TypedModuleId,
|
||||
DeclarationFunctionKey<'ast>,
|
||||
GenericsAssignment<'ast>,
|
||||
Vec<(ConcreteVariable<'ast>, TypedExpression<'ast, T>)>,
|
||||
),
|
||||
PopCallLog(Vec<(ConcreteVariable<'ast>, TypedExpression<'ast, T>)>),
|
||||
PushCallLog(DeclarationFunctionKey<'ast>, GenericsAssignment<'ast>),
|
||||
PopCallLog,
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> {
|
||||
|
@ -417,16 +417,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> {
|
|||
TypedStatement::MultipleDefinition(ref lhs, ref rhs) => {
|
||||
write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
TypedStatement::PushCallLog(ref module_id, ref key, ref generics, ref assignments) => {
|
||||
write!(
|
||||
f,
|
||||
"PushCallLog({:?}, {:?}, {:?}, {:?})",
|
||||
module_id, key, generics, assignments
|
||||
)
|
||||
}
|
||||
TypedStatement::PopCallLog(ref assignments) => {
|
||||
write!(f, "PopCallLog({:?})", assignments)
|
||||
TypedStatement::PushCallLog(ref key, ref generics) => {
|
||||
write!(f, "PushCallLog({:?}, {:?})", key, generics)
|
||||
}
|
||||
TypedStatement::PopCallLog => write!(f, "PopCallLog"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -480,29 +474,14 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
|||
}
|
||||
write!(f, " = {}", rhs)
|
||||
}
|
||||
TypedStatement::PushCallLog(ref module_id, ref key, ref generics, ref assignments) => {
|
||||
write!(
|
||||
f,
|
||||
"// PUSH CALL TO {}_{}::<{}> with {}",
|
||||
module_id.display(),
|
||||
key.id,
|
||||
generics,
|
||||
assignments
|
||||
.iter()
|
||||
.map(|(v, e)| format!("{} := {}", v, e))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
}
|
||||
TypedStatement::PopCallLog(ref assignments) => write!(
|
||||
TypedStatement::PushCallLog(ref key, ref generics) => write!(
|
||||
f,
|
||||
"// POP CALL with {}",
|
||||
assignments
|
||||
.iter()
|
||||
.map(|(v, e)| format!("{} := {}", v, e))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
"// PUSH CALL TO {}/{}::<{}>",
|
||||
key.module.display(),
|
||||
key.id,
|
||||
generics,
|
||||
),
|
||||
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -735,7 +714,7 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
}
|
||||
|
@ -822,7 +801,7 @@ pub enum BooleanExpression<'ast, T> {
|
|||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
|
@ -848,7 +827,7 @@ pub struct ArrayExpression<'ast, T> {
|
|||
pub enum ArrayExpressionInner<'ast, T> {
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
IfElse(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
|
@ -939,7 +918,7 @@ impl<'ast, T> StructExpression<'ast, T> {
|
|||
pub enum StructExpressionInner<'ast, T> {
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
IfElse(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<StructExpression<'ast, T>>,
|
||||
|
@ -1328,7 +1307,9 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
|
|||
Type::Array(ty) => ArrayExpressionInner::Identifier(v.id)
|
||||
.annotate(*ty.ty, ty.size)
|
||||
.into(),
|
||||
_ => unimplemented!(),
|
||||
Type::Struct(ty) => StructExpressionInner::Identifier(v.id).annotate(ty).into(),
|
||||
Type::Uint(w) => UExpressionInner::Identifier(v.id).annotate(w).into(),
|
||||
Type::Int => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1536,3 +1517,69 @@ impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> {
|
|||
StructExpressionInner::Member(box s, member_id).annotate(members)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FunctionCall<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self;
|
||||
}
|
||||
|
||||
impl<'ast, T> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
FieldElementExpression::FunctionCall(key, arguments)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
BooleanExpression::FunctionCall(key, arguments)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> FunctionCall<'ast, T> for UExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
let bitwidth = match &key.signature.outputs[0] {
|
||||
DeclarationType::Uint(bitwidth) => bitwidth.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
UExpressionInner::FunctionCall(key, arguments).annotate(bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
let array_ty = match &key.signature.outputs[0] {
|
||||
DeclarationType::Array(array_ty) => array_ty.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ArrayExpressionInner::FunctionCall(key, arguments)
|
||||
.annotate(Type::<T>::from(*array_ty.ty), array_ty.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> FunctionCall<'ast, T> for StructExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
let struct_ty = match &key.signature.outputs[0] {
|
||||
DeclarationType::Struct(struct_ty) => struct_ty.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
StructExpressionInner::FunctionCall(key, arguments).annotate(StructType::from(struct_ty))
|
||||
}
|
||||
}
|
||||
|
|
667
zokrates_core/src/typed_absy/result_folder.rs
Normal file
667
zokrates_core/src/typed_absy/result_folder.rs
Normal file
|
@ -0,0 +1,667 @@
|
|||
// Generic walk through a typed AST. Not mutating in place
|
||||
|
||||
use crate::typed_absy::*;
|
||||
use typed_absy::types::{ArrayType, StructMember, StructType};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub trait ResultFolder<'ast, T: Field>: Sized {
|
||||
type Error;
|
||||
|
||||
fn fold_function_symbol(
|
||||
&mut self,
|
||||
s: TypedFunctionSymbol<'ast, T>,
|
||||
) -> Result<TypedFunctionSymbol<'ast, T>, Self::Error> {
|
||||
fold_function_symbol(self, s)
|
||||
}
|
||||
|
||||
fn fold_function(
|
||||
&mut self,
|
||||
f: TypedFunction<'ast, T>,
|
||||
) -> Result<TypedFunction<'ast, T>, Self::Error> {
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast>,
|
||||
) -> Result<DeclarationParameter<'ast>, Self::Error> {
|
||||
Ok(DeclarationParameter {
|
||||
id: self.fold_declaration_variable(p.id)?,
|
||||
..p
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result<Variable<'ast, T>, Self::Error> {
|
||||
Ok(Variable {
|
||||
id: self.fold_name(v.id)?,
|
||||
_type: self.fold_type(v._type)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_declaration_variable(
|
||||
&mut self,
|
||||
v: DeclarationVariable<'ast>,
|
||||
) -> Result<DeclarationVariable<'ast>, Self::Error> {
|
||||
Ok(DeclarationVariable {
|
||||
id: self.fold_name(v.id)?,
|
||||
_type: self.fold_declaration_type(v._type)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_type(&mut self, t: Type<'ast, T>) -> Result<Type<'ast, T>, Self::Error> {
|
||||
use self::GType::*;
|
||||
|
||||
match t {
|
||||
Array(array_type) => Ok(Array(ArrayType {
|
||||
ty: box self.fold_type(*array_type.ty)?,
|
||||
size: self.fold_uint_expression(array_type.size)?,
|
||||
})),
|
||||
Struct(struct_type) => Ok(Struct(StructType {
|
||||
members: struct_type
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
self.fold_type(*m.ty.clone())
|
||||
.map(|ty| StructMember { ty: box ty, ..m })
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
..struct_type
|
||||
})),
|
||||
t => Ok(t),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_type(
|
||||
&mut self,
|
||||
t: DeclarationType<'ast>,
|
||||
) -> Result<DeclarationType<'ast>, Self::Error> {
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
fn fold_assignee(
|
||||
&mut self,
|
||||
a: TypedAssignee<'ast, T>,
|
||||
) -> Result<TypedAssignee<'ast, T>, Self::Error> {
|
||||
match a {
|
||||
TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(self.fold_variable(v)?)),
|
||||
TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select(
|
||||
box self.fold_assignee(a)?,
|
||||
box self.fold_uint_expression(index)?,
|
||||
)),
|
||||
TypedAssignee::Member(box s, m) => {
|
||||
Ok(TypedAssignee::Member(box self.fold_assignee(s)?, m))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_expression(
|
||||
&mut self,
|
||||
e: TypedExpression<'ast, T>,
|
||||
) -> Result<TypedExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
TypedExpression::FieldElement(e) => Ok(self.fold_field_expression(e)?.into()),
|
||||
TypedExpression::Boolean(e) => Ok(self.fold_boolean_expression(e)?.into()),
|
||||
TypedExpression::Uint(e) => Ok(self.fold_uint_expression(e)?.into()),
|
||||
TypedExpression::Array(e) => Ok(self.fold_array_expression(e)?.into()),
|
||||
TypedExpression::Struct(e) => Ok(self.fold_struct_expression(e)?.into()),
|
||||
TypedExpression::Int(e) => Ok(self.fold_int_expression(e)?.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_array_expression(
|
||||
&mut self,
|
||||
e: ArrayExpression<'ast, T>,
|
||||
) -> Result<ArrayExpression<'ast, T>, Self::Error> {
|
||||
fold_array_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_struct_expression(
|
||||
&mut self,
|
||||
e: StructExpression<'ast, T>,
|
||||
) -> Result<StructExpression<'ast, T>, Self::Error> {
|
||||
fold_struct_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_expression_list(
|
||||
&mut self,
|
||||
es: TypedExpressionList<'ast, T>,
|
||||
) -> Result<TypedExpressionList<'ast, T>, Self::Error> {
|
||||
fold_expression_list(self, es)
|
||||
}
|
||||
|
||||
fn fold_int_expression(
|
||||
&mut self,
|
||||
e: IntExpression<'ast, T>,
|
||||
) -> Result<IntExpression<'ast, T>, Self::Error> {
|
||||
fold_int_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
fold_field_expression(self, e)
|
||||
}
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
|
||||
fold_boolean_expression(self, e)
|
||||
}
|
||||
fn fold_uint_expression(
|
||||
&mut self,
|
||||
e: UExpression<'ast, T>,
|
||||
) -> Result<UExpression<'ast, T>, Self::Error> {
|
||||
fold_uint_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
|
||||
fold_uint_expression_inner(self, bitwidth, e)
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &Type<'ast, T>,
|
||||
size: UExpression<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
|
||||
fold_array_expression_inner(self, ty, size, e)
|
||||
}
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
|
||||
fold_struct_expression_inner(self, ty, e)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, F::Error> {
|
||||
let res = match s {
|
||||
TypedStatement::Return(expressions) => TypedStatement::Return(
|
||||
expressions
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?)
|
||||
}
|
||||
TypedStatement::Declaration(v) => TypedStatement::Declaration(f.fold_variable(v)?),
|
||||
TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)?),
|
||||
TypedStatement::For(v, from, to, statements) => TypedStatement::For(
|
||||
f.fold_variable(v)?,
|
||||
f.fold_uint_expression(from)?,
|
||||
f.fold_uint_expression(to)?,
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
),
|
||||
TypedStatement::MultipleDefinition(variables, elist) => TypedStatement::MultipleDefinition(
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|v| f.fold_variable(v))
|
||||
.collect::<Result<_, _>>()?,
|
||||
f.fold_expression_list(elist)?,
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
_: &Type<'ast, T>,
|
||||
_: UExpression<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> Result<ArrayExpressionInner<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?),
|
||||
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
ArrayExpressionInner::FunctionCall(id, exps) => {
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
ArrayExpressionInner::FunctionCall(id, exps)
|
||||
}
|
||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
ArrayExpressionInner::IfElse(
|
||||
box f.fold_boolean_expression(condition)?,
|
||||
box f.fold_array_expression(consequence)?,
|
||||
box f.fold_array_expression(alternative)?,
|
||||
)
|
||||
}
|
||||
ArrayExpressionInner::Member(box s, id) => {
|
||||
let s = f.fold_struct_expression(s)?;
|
||||
ArrayExpressionInner::Member(box s, id)
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array)?;
|
||||
let index = f.fold_uint_expression(index)?;
|
||||
ArrayExpressionInner::Select(box array, box index)
|
||||
}
|
||||
};
|
||||
Ok(e)
|
||||
}
|
||||
|
||||
pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
_: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> Result<StructExpressionInner<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
StructExpressionInner::Identifier(id) => {
|
||||
StructExpressionInner::Identifier(f.fold_name(id)?)
|
||||
}
|
||||
StructExpressionInner::Value(exprs) => StructExpressionInner::Value(
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
StructExpressionInner::FunctionCall(id, exps) => {
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
StructExpressionInner::FunctionCall(id, exps)
|
||||
}
|
||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
StructExpressionInner::IfElse(
|
||||
box f.fold_boolean_expression(condition)?,
|
||||
box f.fold_struct_expression(consequence)?,
|
||||
box f.fold_struct_expression(alternative)?,
|
||||
)
|
||||
}
|
||||
StructExpressionInner::Member(box s, id) => {
|
||||
let s = f.fold_struct_expression(s)?;
|
||||
StructExpressionInner::Member(box s, id)
|
||||
}
|
||||
StructExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array)?;
|
||||
let index = f.fold_uint_expression(index)?;
|
||||
StructExpressionInner::Select(box array, box index)
|
||||
}
|
||||
};
|
||||
Ok(e)
|
||||
}
|
||||
|
||||
pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
|
||||
FieldElementExpression::Identifier(id) => {
|
||||
FieldElementExpression::Identifier(f.fold_name(id)?)
|
||||
}
|
||||
FieldElementExpression::Add(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
FieldElementExpression::Add(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Sub(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
FieldElementExpression::Sub(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Mult(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
FieldElementExpression::Mult(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Div(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
FieldElementExpression::Div(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::Pow(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
FieldElementExpression::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
let cons = f.fold_field_expression(cons)?;
|
||||
let alt = f.fold_field_expression(alt)?;
|
||||
FieldElementExpression::IfElse(box cond, box cons, box alt)
|
||||
}
|
||||
FieldElementExpression::FunctionCall(key, exps) => {
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
FieldElementExpression::FunctionCall(key, exps)
|
||||
}
|
||||
FieldElementExpression::Member(box s, id) => {
|
||||
let s = f.fold_struct_expression(s)?;
|
||||
FieldElementExpression::Member(box s, id)
|
||||
}
|
||||
FieldElementExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array)?;
|
||||
let index = f.fold_uint_expression(index)?;
|
||||
FieldElementExpression::Select(box array, box index)
|
||||
}
|
||||
};
|
||||
Ok(e)
|
||||
}
|
||||
|
||||
pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
_: &mut F,
|
||||
_: IntExpression<'ast, T>,
|
||||
) -> Result<IntExpression<'ast, T>, F::Error> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::BoolEq(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1)?;
|
||||
let e2 = f.fold_boolean_expression(e2)?;
|
||||
BooleanExpression::BoolEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::ArrayEq(box e1, box e2) => {
|
||||
let e1 = f.fold_array_expression(e1)?;
|
||||
let e2 = f.fold_array_expression(e2)?;
|
||||
BooleanExpression::ArrayEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::StructEq(box e1, box e2) => {
|
||||
let e1 = f.fold_struct_expression(e1)?;
|
||||
let e2 = f.fold_struct_expression(e2)?;
|
||||
BooleanExpression::StructEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1)?;
|
||||
let e2 = f.fold_boolean_expression(e2)?;
|
||||
BooleanExpression::Or(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::And(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1)?;
|
||||
let e2 = f.fold_boolean_expression(e2)?;
|
||||
BooleanExpression::And(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Not(box e) => {
|
||||
let e = f.fold_boolean_expression(e)?;
|
||||
BooleanExpression::Not(box e)
|
||||
}
|
||||
BooleanExpression::FunctionCall(key, exps) => {
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
BooleanExpression::FunctionCall(key, exps)
|
||||
}
|
||||
BooleanExpression::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
let cons = f.fold_boolean_expression(cons)?;
|
||||
let alt = f.fold_boolean_expression(alt)?;
|
||||
BooleanExpression::IfElse(box cond, box cons, box alt)
|
||||
}
|
||||
BooleanExpression::Member(box s, id) => {
|
||||
let s = f.fold_struct_expression(s)?;
|
||||
BooleanExpression::Member(box s, id)
|
||||
}
|
||||
BooleanExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array)?;
|
||||
let index = f.fold_uint_expression(index)?;
|
||||
BooleanExpression::Select(box array, box index)
|
||||
}
|
||||
};
|
||||
Ok(e)
|
||||
}
|
||||
|
||||
pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: UExpression<'ast, T>,
|
||||
) -> Result<UExpression<'ast, T>, F::Error> {
|
||||
Ok(UExpression {
|
||||
inner: f.fold_uint_expression_inner(e.bitwidth, e.inner)?,
|
||||
..e
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
_: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, F::Error> {
|
||||
let e = match e {
|
||||
UExpressionInner::Value(v) => UExpressionInner::Value(v),
|
||||
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
||||
UExpressionInner::Add(box left, box right)
|
||||
}
|
||||
UExpressionInner::Sub(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
||||
UExpressionInner::Sub(box left, box right)
|
||||
}
|
||||
UExpressionInner::Mult(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
||||
UExpressionInner::Mult(box left, box right)
|
||||
}
|
||||
UExpressionInner::Xor(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
||||
UExpressionInner::Xor(box left, box right)
|
||||
}
|
||||
UExpressionInner::And(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
||||
UExpressionInner::And(box left, box right)
|
||||
}
|
||||
UExpressionInner::Or(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left)?;
|
||||
let right = f.fold_uint_expression(right)?;
|
||||
|
||||
UExpressionInner::Or(box left, box right)
|
||||
}
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = f.fold_uint_expression(e)?;
|
||||
let by = f.fold_field_expression(by)?;
|
||||
|
||||
UExpressionInner::LeftShift(box e, box by)
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = f.fold_uint_expression(e)?;
|
||||
let by = f.fold_field_expression(by)?;
|
||||
|
||||
UExpressionInner::RightShift(box e, box by)
|
||||
}
|
||||
UExpressionInner::Not(box e) => {
|
||||
let e = f.fold_uint_expression(e)?;
|
||||
|
||||
UExpressionInner::Not(box e)
|
||||
}
|
||||
UExpressionInner::FunctionCall(key, exps) => {
|
||||
let exps = exps
|
||||
.into_iter()
|
||||
.map(|e| f.fold_expression(e))
|
||||
.collect::<Result<_, _>>()?;
|
||||
UExpressionInner::FunctionCall(key, exps)
|
||||
}
|
||||
UExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array)?;
|
||||
let index = f.fold_uint_expression(index)?;
|
||||
UExpressionInner::Select(box array, box index)
|
||||
}
|
||||
UExpressionInner::IfElse(box cond, box cons, box alt) => {
|
||||
let cond = f.fold_boolean_expression(cond)?;
|
||||
let cons = f.fold_uint_expression(cons)?;
|
||||
let alt = f.fold_uint_expression(alt)?;
|
||||
UExpressionInner::IfElse(box cond, box cons, box alt)
|
||||
}
|
||||
UExpressionInner::Member(box s, id) => {
|
||||
let s = f.fold_struct_expression(s)?;
|
||||
UExpressionInner::Member(box s, id)
|
||||
}
|
||||
};
|
||||
Ok(e)
|
||||
}
|
||||
|
||||
pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
fun: TypedFunction<'ast, T>,
|
||||
) -> Result<TypedFunction<'ast, T>, F::Error> {
|
||||
Ok(TypedFunction {
|
||||
arguments: fun
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|a| f.fold_parameter(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
statements: fun
|
||||
.statements
|
||||
.into_iter()
|
||||
.map(|s| f.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
..fun
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: ArrayExpression<'ast, T>,
|
||||
) -> Result<ArrayExpression<'ast, T>, F::Error> {
|
||||
let size = f.fold_uint_expression(e.size)?;
|
||||
|
||||
Ok(ArrayExpression {
|
||||
inner: f.fold_array_expression_inner(&e.ty, size.clone(), e.inner)?,
|
||||
size,
|
||||
..e
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
es: TypedExpressionList<'ast, T>,
|
||||
) -> Result<TypedExpressionList<'ast, T>, F::Error> {
|
||||
match es {
|
||||
TypedExpressionList::FunctionCall(id, arguments, types) => {
|
||||
Ok(TypedExpressionList::FunctionCall(
|
||||
id,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|a| f.fold_expression(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
types
|
||||
.into_iter()
|
||||
.map(|t| f.fold_type(t))
|
||||
.collect::<Result<_, _>>()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: StructExpression<'ast, T>,
|
||||
) -> Result<StructExpression<'ast, T>, F::Error> {
|
||||
Ok(StructExpression {
|
||||
inner: f.fold_struct_expression_inner(&e.ty, e.inner)?,
|
||||
..e
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedFunctionSymbol<'ast, T>,
|
||||
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
|
||||
match s {
|
||||
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
|
||||
there => Ok(there), // by default, do not fold modules recursively
|
||||
}
|
||||
}
|
|
@ -3,7 +3,7 @@ use std::fmt;
|
|||
use std::hash::{Hash, Hasher};
|
||||
use std::path::{Path, PathBuf};
|
||||
use typed_absy::{TryFrom, TryInto};
|
||||
use typed_absy::{UExpression, UExpressionInner};
|
||||
use typed_absy::{TypedModuleId, UExpression, UExpressionInner};
|
||||
|
||||
pub type GenericIdentifier<'ast> = &'ast str;
|
||||
|
||||
|
@ -15,7 +15,7 @@ pub enum Constant<'ast> {
|
|||
|
||||
// At this stage we want all constants to be equal
|
||||
impl<'ast> PartialEq for Constant<'ast> {
|
||||
fn eq(&self, _: &Self) -> bool {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
@ -193,6 +193,14 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'as
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: PartialEq> ArrayType<'ast, T> {
|
||||
// array type equality with non-strict size checks
|
||||
// sizes always match unless they are different constants
|
||||
pub fn weak_eq(&self, other: &Self) -> bool {
|
||||
self.ty == other.ty
|
||||
}
|
||||
}
|
||||
|
||||
fn try_from_g_array_type<T: TryInto<U>, U>(t: GArrayType<T>) -> Result<GArrayType<U>, ()> {
|
||||
Ok(GArrayType {
|
||||
size: t.size.try_into().map_err(|_| ())?,
|
||||
|
@ -676,6 +684,7 @@ pub type FunctionIdentifier<'ast> = &'ast str;
|
|||
|
||||
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
|
||||
pub struct GFunctionKey<'ast, S> {
|
||||
pub module: TypedModuleId,
|
||||
pub id: FunctionIdentifier<'ast>,
|
||||
pub signature: GSignature<S>,
|
||||
}
|
||||
|
@ -703,12 +712,13 @@ impl<'ast> fmt::Display for GenericsAssignment<'ast> {
|
|||
|
||||
impl<'ast> PartialEq<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
|
||||
fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool {
|
||||
self.id == other.id && self.signature == other.signature
|
||||
self.module == other.module && self.id == other.id && self.signature == other.signature
|
||||
}
|
||||
}
|
||||
|
||||
fn try_from_g_function_key<T: TryInto<U>, U>(k: GFunctionKey<T>) -> Result<GFunctionKey<U>, ()> {
|
||||
Ok(GFunctionKey {
|
||||
module: k.module,
|
||||
signature: signature::try_from_g_signature(k.signature)?,
|
||||
id: k.id,
|
||||
})
|
||||
|
@ -749,8 +759,12 @@ impl<'ast, T> From<DeclarationFunctionKey<'ast>> for FunctionKey<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, S> GFunctionKey<'ast, S> {
|
||||
pub fn with_id<U: Into<FunctionIdentifier<'ast>>>(id: U) -> Self {
|
||||
pub fn with_location<T: Into<TypedModuleId>, U: Into<FunctionIdentifier<'ast>>>(
|
||||
module: T,
|
||||
id: U,
|
||||
) -> Self {
|
||||
GFunctionKey {
|
||||
module: module.into(),
|
||||
id: id.into(),
|
||||
signature: GSignature::new(),
|
||||
}
|
||||
|
@ -765,11 +779,21 @@ impl<'ast, S> GFunctionKey<'ast, S> {
|
|||
self.id = id.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn module<T: Into<TypedModuleId>>(mut self, module: T) -> Self {
|
||||
self.module = module.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> ConcreteFunctionKey<'ast> {
|
||||
pub fn to_slug(&self) -> String {
|
||||
format!("{}_{}", self.id, self.signature.to_slug())
|
||||
format!(
|
||||
"{}/{}_{}",
|
||||
self.module.display(),
|
||||
self.id,
|
||||
self.signature.to_slug()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ pub enum UExpressionInner<'ast, T> {
|
|||
Box<UExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
|
||||
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
IfElse(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
|
|
|
@ -4,6 +4,7 @@ use zir;
|
|||
impl<'ast> From<typed_absy::types::ConcreteFunctionKey<'ast>> for zir::types::FunctionKey<'ast> {
|
||||
fn from(k: typed_absy::types::ConcreteFunctionKey<'ast>) -> zir::types::FunctionKey<'ast> {
|
||||
zir::types::FunctionKey {
|
||||
module: k.module,
|
||||
id: k.id,
|
||||
signature: k.signature.into(),
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ pub struct ZirModule<'ast, T> {
|
|||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ZirFunctionSymbol<'ast, T> {
|
||||
Here(ZirFunction<'ast, T>),
|
||||
There(FunctionKey<'ast>, ZirModuleId),
|
||||
There(FunctionKey<'ast>),
|
||||
Flat(FlatEmbed),
|
||||
}
|
||||
|
||||
|
@ -82,8 +82,8 @@ impl<'ast, T> ZirFunctionSymbol<'ast, T> {
|
|||
pub fn signature<'a>(&'a self, modules: &'a ZirModules<T>) -> Signature {
|
||||
match self {
|
||||
ZirFunctionSymbol::Here(f) => f.signature.clone(),
|
||||
ZirFunctionSymbol::There(key, module_id) => modules
|
||||
.get(module_id)
|
||||
ZirFunctionSymbol::There(key) => modules
|
||||
.get(&key.module)
|
||||
.unwrap()
|
||||
.functions
|
||||
.get(key)
|
||||
|
@ -102,10 +102,10 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirModule<'ast, T> {
|
|||
.iter()
|
||||
.map(|(key, symbol)| match symbol {
|
||||
ZirFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
|
||||
ZirFunctionSymbol::There(ref fun_key, ref module_id) => format!(
|
||||
ZirFunctionSymbol::There(ref fun_key) => format!(
|
||||
"import {} from \"{}\" as {} // with signature {}",
|
||||
fun_key.id,
|
||||
module_id.display(),
|
||||
fun_key.module.display(),
|
||||
key.id,
|
||||
key.signature
|
||||
),
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use std::fmt;
|
||||
use zir::ZirModuleId;
|
||||
|
||||
pub type Identifier<'ast> = &'ast str;
|
||||
|
||||
|
@ -95,13 +96,18 @@ pub type FunctionIdentifier<'ast> = &'ast str;
|
|||
|
||||
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
|
||||
pub struct FunctionKey<'ast> {
|
||||
pub module: ZirModuleId,
|
||||
pub id: FunctionIdentifier<'ast>,
|
||||
pub signature: Signature,
|
||||
}
|
||||
|
||||
impl<'ast> FunctionKey<'ast> {
|
||||
pub fn with_id<S: Into<Identifier<'ast>>>(id: S) -> Self {
|
||||
pub fn with_location<T: Into<ZirModuleId>, S: Into<Identifier<'ast>>>(
|
||||
module: T,
|
||||
id: S,
|
||||
) -> Self {
|
||||
FunctionKey {
|
||||
module: module.into(),
|
||||
id: id.into(),
|
||||
signature: Signature::new(),
|
||||
}
|
||||
|
@ -117,6 +123,11 @@ impl<'ast> FunctionKey<'ast> {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn module<T: Into<ZirModuleId>>(mut self, module: T) -> Self {
|
||||
self.module = module.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn to_slug(&self) -> String {
|
||||
format!("{}_{}", self.id, self.signature.to_slug())
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import "utils/casts/u32_to_field" as to_field
|
||||
|
||||
def main(field x) -> field:
|
||||
field f = 1
|
||||
field counter = 0
|
||||
for field i in 1..5 do
|
||||
f = if counter == x then f else f * i fi
|
||||
for u32 i in 1..5 do
|
||||
f = if counter == x then f else f * to_field(i) fi
|
||||
counter = if counter == x then counter else counter + 1 fi
|
||||
endfor
|
||||
return f
|
4
zokrates_core_test/tests/tests/generics/call.json
Normal file
4
zokrates_core_test/tests/tests/generics/call.json
Normal file
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"curves": ["Bn128", "Bls12"],
|
||||
"tests": []
|
||||
}
|
8
zokrates_core_test/tests/tests/generics/call.zok
Normal file
8
zokrates_core_test/tests/tests/generics/call.zok
Normal file
|
@ -0,0 +1,8 @@
|
|||
def foo<T>(field[T] b) -> field:
|
||||
return 1
|
||||
|
||||
def bar<T>(field[T] b) -> field:
|
||||
return foo(b)
|
||||
|
||||
def main(field[3] a) -> field:
|
||||
return foo(a) + bar(a)
|
23
zokrates_core_test/tests/tests/generics/multidef.json
Normal file
23
zokrates_core_test/tests/tests/generics/multidef.json
Normal file
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"curves": ["Bn128", "Bls12"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": [
|
||||
"1",
|
||||
"2",
|
||||
"3"
|
||||
]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": [
|
||||
"1",
|
||||
"2",
|
||||
"3"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
6
zokrates_core_test/tests/tests/generics/multidef.zok
Normal file
6
zokrates_core_test/tests/tests/generics/multidef.zok
Normal file
|
@ -0,0 +1,6 @@
|
|||
def foo<T>(field[T] b) -> field[T]:
|
||||
return b
|
||||
|
||||
def main(field[3] a) -> field[3]:
|
||||
field[3] res = foo(a)
|
||||
return res
|
10
zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok
Normal file
10
zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok
Normal file
|
@ -0,0 +1,10 @@
|
|||
import "EMBED/u32_to_bits" as to_bits
|
||||
|
||||
def main(u32 i) -> field:
|
||||
bool[32] bits = to_bits(i)
|
||||
field res = 0
|
||||
for u32 j in 0..32 do
|
||||
u32 exponent = 32 - j - 1
|
||||
res = res + 2 ** exponent
|
||||
endfor
|
||||
return res
|
|
@ -13,7 +13,7 @@ enum Curve {
|
|||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
struct Tests {
|
||||
pub entry_point: PathBuf,
|
||||
pub entry_point: Option<PathBuf>,
|
||||
pub curves: Option<Vec<Curve>>,
|
||||
pub max_constraint_count: Option<usize>,
|
||||
pub tests: Vec<Test>,
|
||||
|
@ -92,6 +92,14 @@ pub fn test_inner(test_path: &str) {
|
|||
|
||||
let curves = t.curves.clone().unwrap_or(vec![Curve::Bn128]);
|
||||
|
||||
let t = Tests {
|
||||
entry_point: Some(
|
||||
t.entry_point
|
||||
.unwrap_or(PathBuf::from(String::from(test_path)).with_extension("zok")),
|
||||
),
|
||||
..t
|
||||
};
|
||||
|
||||
for c in &curves {
|
||||
match c {
|
||||
Curve::Bn128 => compile_and_run::<Bn128Field>(t.clone()),
|
||||
|
@ -101,11 +109,13 @@ pub fn test_inner(test_path: &str) {
|
|||
}
|
||||
|
||||
fn compile_and_run<T: Field>(t: Tests) {
|
||||
let code = std::fs::read_to_string(&t.entry_point).unwrap();
|
||||
let entry_point = t.entry_point.unwrap();
|
||||
|
||||
let code = std::fs::read_to_string(&entry_point).unwrap();
|
||||
|
||||
let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap();
|
||||
let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap());
|
||||
let artifacts = compile::<T, _>(code, t.entry_point.clone(), Some(&resolver)).unwrap();
|
||||
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver)).unwrap();
|
||||
|
||||
let bin = artifacts.prog();
|
||||
|
||||
|
@ -115,7 +125,7 @@ fn compile_and_run<T: Field>(t: Tests) {
|
|||
|
||||
println!(
|
||||
"{} at {}% of max",
|
||||
t.entry_point.display(),
|
||||
entry_point.display(),
|
||||
(count as f32) / (target_count as f32) * 100_f32
|
||||
);
|
||||
}
|
||||
|
@ -131,7 +141,7 @@ fn compile_and_run<T: Field>(t: Tests) {
|
|||
|
||||
match compare(output, test.output) {
|
||||
Err(e) => {
|
||||
let mut code = File::open(&t.entry_point).unwrap();
|
||||
let mut code = File::open(&entry_point).unwrap();
|
||||
let mut s = String::new();
|
||||
code.read_to_string(&mut s).unwrap();
|
||||
let context = format!(
|
||||
|
|
Loading…
Reference in a new issue