1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
schaeff 2020-11-04 21:50:44 +00:00
parent bb262ddb4c
commit 898b3eb7e8
26 changed files with 1602 additions and 456 deletions

View file

@ -1,10 +1,10 @@
def sum_array<P>(field[P] a) -> field:
field res = 0
for u32 i in 0..P do
res = res + a[i]
def main():
field a = 3
for u32 i in 0..2 do
for u32 j in 0..2 do
a = a + 1
endfor
a = a + 1
endfor
return res
def main(u32 i) -> (field):
u32[2][4] a = [[1, 2, 3, 4], [1, 2, 3, 4]]
return 1
field b = a
return

View file

@ -52,7 +52,7 @@ impl FlatEmbed {
}
pub fn key<T: Field>(&self) -> ConcreteFunctionKey<'static> {
ConcreteFunctionKey::with_id(self.id()).signature(self.signature())
ConcreteFunctionKey::with_location("#EMBED#", self.id()).signature(self.signature())
}
pub fn id(&self) -> &'static str {

View file

@ -889,7 +889,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.inputs(param_expressions.iter().map(|e| e.get_type()).collect())
.outputs(return_types);
let key = FunctionKey::with_id(id).signature(passed_signature);
let key = FunctionKey::with_location("#EMBED#", id).signature(passed_signature);
let funct = self.get_embed(&key, &symbols);

View file

@ -422,12 +422,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
};
self.functions.insert(
DeclarationFunctionKey::with_id(declaration.id.clone())
.signature(funct.signature.clone()),
DeclarationFunctionKey::with_location(
module_id.clone(),
declaration.id.clone(),
)
.signature(funct.signature.clone()),
);
functions.insert(
DeclarationFunctionKey::with_id(declaration.id.clone())
.signature(funct.signature.clone()),
DeclarationFunctionKey::with_location(
module_id.clone(),
declaration.id.clone(),
)
.signature(funct.signature.clone()),
TypedFunctionSymbol::Here(funct),
);
}
@ -450,6 +456,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.iter()
.filter(|(k, _)| k.id == import.symbol_id)
.map(|(_, v)| DeclarationFunctionKey {
module: import.module_id.clone(),
id: import.symbol_id.clone(),
signature: v.signature(&state.typed_modules).clone(),
})
@ -525,12 +532,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
true => {}
};
self.functions.insert(candidate.clone().id(declaration.id));
let local_key = candidate.clone().id(declaration.id).module(module_id.clone());
self.functions.insert(local_key.clone());
functions.insert(
candidate.clone().id(declaration.id),
TypedFunctionSymbol::There(
candidate,
import.module_id.clone(),
local_key,
TypedFunctionSymbol::There(candidate,
),
);
}
@ -562,12 +569,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
};
self.functions.insert(
DeclarationFunctionKey::with_id(declaration.id.clone())
.signature(funct.signature().clone().try_into().unwrap()),
DeclarationFunctionKey::with_location(
module_id.clone(),
declaration.id.clone(),
)
.signature(funct.signature().clone().try_into().unwrap()),
);
functions.insert(
DeclarationFunctionKey::with_id(declaration.id.clone())
.signature(funct.signature().clone().try_into().unwrap()),
DeclarationFunctionKey::with_location(
module_id.clone(),
declaration.id.clone(),
)
.signature(funct.signature().clone().try_into().unwrap()),
TypedFunctionSymbol::Flat(funct),
);
}
@ -1363,11 +1376,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
let f = functions.pop().unwrap();
let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| vec![ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0)
}])?;
pos: Some(pos),
message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type())
}])?;
let call = TypedExpressionList::FunctionCall(f.clone().into(), arguments_checked, Signature::from(f.signature.clone()).outputs);
let call = TypedExpressionList::FunctionCall(f.clone().into(), arguments_checked, variables.iter().map(|v| v.get_type()).collect());
Ok(TypedStatement::MultipleDefinition(variables, call))
},
@ -1914,9 +1927,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
let f = functions.pop().unwrap();
let signature: Signature<T> = f.signature.into();
let signature = f.signature;
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::<Result<Vec<_>, _>>().map_err(|e| ErrorInner {
pos: Some(pos),
message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0)
})?;
@ -1924,25 +1937,28 @@ impl<'ast, T: Field> Checker<'ast, T> {
// the return count has to be 1
match signature.outputs.len() {
1 => match &signature.outputs[0] {
Type::Int => unreachable!(),
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
FunctionKey {
DeclarationType::Int => unreachable!(),
DeclarationType::FieldElement => Ok(FieldElementExpression::FunctionCall(
DeclarationFunctionKey {
module: module_id.clone(),
id: f.id.clone(),
signature: signature.clone(),
},
arguments_checked,
)
.into()),
Type::Boolean => Ok(BooleanExpression::FunctionCall(
FunctionKey {
DeclarationType::Boolean => Ok(BooleanExpression::FunctionCall(
DeclarationFunctionKey {
module: module_id.clone(),
id: f.id.clone(),
signature: signature.clone(),
},
arguments_checked,
)
.into()),
Type::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall(
FunctionKey {
DeclarationType::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall(
DeclarationFunctionKey {
module: module_id.clone(),
id: f.id.clone(),
signature: signature.clone(),
},
@ -1950,23 +1966,25 @@ impl<'ast, T: Field> Checker<'ast, T> {
)
.annotate(*bitwidth)
.into()),
Type::Struct(members) => Ok(StructExpressionInner::FunctionCall(
FunctionKey {
DeclarationType::Struct(members) => Ok(StructExpressionInner::FunctionCall(
DeclarationFunctionKey {
module: module_id.clone(),
id: f.id.clone(),
signature: signature.clone(),
},
arguments_checked,
)
.annotate(members.clone())
.annotate(members.clone().into())
.into()),
Type::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall(
FunctionKey {
DeclarationType::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall(
DeclarationFunctionKey {
module: module_id.clone(),
id: f.id.clone(),
signature: signature.clone(),
},
arguments_checked,
)
.annotate(*array_type.ty.clone(), array_type.size.clone())
.annotate(Type::from(*array_type.ty.clone()), array_type.size.clone())
.into()),
},
n => Err(ErrorInner {
@ -2592,6 +2610,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
.next()
.unwrap_or(Type::Int);
println!("INFERRED TYPE {:?}", inferred_type);
match inferred_type {
Type::Int => {
// no need to check the expressions have the same type, this is guaranteed above
@ -2643,6 +2663,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
let size = unwrapped_expressions.len() as u32;
println!("hey");
Ok(ArrayExpressionInner::Value(unwrapped_expressions)
.annotate(Type::Boolean, size as usize)
.into())
@ -3379,12 +3401,11 @@ mod tests {
state.typed_modules.get(&PathBuf::from("bar")),
Some(&TypedModule {
functions: vec![(
DeclarationFunctionKey::with_id("main")
DeclarationFunctionKey::with_location("bar", "main")
.signature(DeclarationSignature::new()),
TypedFunctionSymbol::There(
DeclarationFunctionKey::with_id("main")
DeclarationFunctionKey::with_location("foo", "main")
.signature(DeclarationSignature::new()),
"foo".into()
)
)]
.into_iter()
@ -3672,16 +3693,19 @@ mod tests {
.unwrap()
.functions
.contains_key(
&DeclarationFunctionKey::with_id("foo").signature(DeclarationSignature::new())
&DeclarationFunctionKey::with_location(MODULE_ID, "foo")
.signature(DeclarationSignature::new())
));
assert!(state
.typed_modules
.get(&PathBuf::from(MODULE_ID))
.unwrap()
.functions
.contains_key(&DeclarationFunctionKey::with_id("foo").signature(
DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement])
)))
.contains_key(
&DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature(
DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement])
)
))
}
#[test]
@ -4271,6 +4295,7 @@ mod tests {
];
let foo = DeclarationFunctionKey {
module: "main".into(),
id: "foo",
signature: DeclarationSignature {
inputs: vec![],
@ -4323,6 +4348,7 @@ mod tests {
.mock()];
let foo = DeclarationFunctionKey {
module: "main".into(),
id: "foo",
signature: DeclarationSignature {
inputs: vec![],
@ -4856,7 +4882,7 @@ mod tests {
typed_absy::Variable::field_element("b"),
],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(
DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature(
DeclarationSignature::new().outputs(vec![
DeclarationType::FieldElement,
DeclarationType::FieldElement,
@ -4874,6 +4900,7 @@ mod tests {
];
let foo = DeclarationFunctionKey {
module: "main".into(),
id: "foo",
signature: DeclarationSignature {
inputs: vec![],

View file

@ -279,7 +279,7 @@ pub fn fold_statement<'ast, T: Field>(
)]
}
typed_absy::TypedStatement::PushCallLog(..) => vec![],
typed_absy::TypedStatement::PopCallLog(..) => vec![],
typed_absy::TypedStatement::PopCallLog => vec![],
}
}
@ -906,9 +906,8 @@ pub fn fold_function_symbol<'ast, T: Field>(
typed_absy::TypedFunctionSymbol::Here(fun) => {
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
}
typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There(
typed_absy::TypedFunctionSymbol::There(key) => zir::ZirFunctionSymbol::There(
f.fold_function_key(typed_absy::types::ConcreteFunctionKey::try_from(key).unwrap()),
module,
), // by default, do not fold modules recursively
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
}

View file

@ -41,12 +41,16 @@ pub trait Analyse {
impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self) -> (ZirProgram<'ast, T>, Abi) {
// return binding
let r = ReturnBinder::bind(self);
//let r = ReturnBinder::bind(self);
// propagated unrolling
//let r = PropagatedUnroller::unroll(r).unwrap_or_else(|e| panic!(e));
let r = reduce_program(r).unwrap();
println!("{:#?}", self);
let r = reduce_program(self).unwrap();
println!("reduced");
let r = Trimmer::trim(r);

View file

@ -331,7 +331,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
}
s @ TypedStatement::PushCallLog(..) => Some(s),
s @ TypedStatement::PopCallLog(..) => Some(s),
s @ TypedStatement::PopCallLog => Some(s),
};
// In verbose mode, we always return a statement

View file

@ -25,23 +25,25 @@
// - The body of the function is in SSA form
// - The return value(s) are assigned to internal variables
use embed::FlatEmbed;
use static_analysis::reducer::CallCache;
use static_analysis::reducer::Output;
use static_analysis::reducer::ShallowTransformer;
use static_analysis::reducer::Versions;
use typed_absy::CoreIdentifier;
use typed_absy::Identifier;
use typed_absy::TypedAssignee;
use typed_absy::{
ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, DeclarationSignature, Signature,
Type, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedModuleId, TypedModules,
TypedStatement, Variable,
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey,
DeclarationSignature, Signature, Type, TypedExpression, TypedFunction, TypedFunctionSymbol,
TypedModules, TypedStatement, Variable,
};
use zokrates_field::Field;
pub enum InlineError<'ast, T> {
Generic(DeclarationSignature<'ast>, ConcreteSignature),
Flat,
Flat(FlatEmbed, Vec<TypedExpression<'ast, T>>, Vec<Type<'ast, T>>),
NonConstant(
TypedModuleId,
DeclarationFunctionKey<'ast>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
@ -49,33 +51,29 @@ pub enum InlineError<'ast, T> {
}
fn get_canonical_function<'ast, T: Field>(
module_id: TypedModuleId,
function_key: DeclarationFunctionKey<'ast>,
modules: &TypedModules<'ast, T>,
) -> (
TypedModuleId,
DeclarationFunctionKey<'ast>,
TypedFunction<'ast, T>,
) {
) -> Result<(DeclarationFunctionKey<'ast>, TypedFunction<'ast, T>), FlatEmbed> {
match modules
.get(&module_id.clone())
.get(&function_key.module)
.unwrap()
.functions
.iter()
.find(|(key, _)| function_key == **key)
.unwrap()
{
(key, TypedFunctionSymbol::Here(f)) => (module_id, key.clone(), f.clone()),
_ => unimplemented!(),
(key, TypedFunctionSymbol::Here(f)) => Ok((key.clone(), f.clone())),
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &modules),
(_, TypedFunctionSymbol::Flat(f)) => Err(f.clone()),
}
}
pub fn inline_call<'a, 'ast, T: Field>(
module_id: TypedModuleId,
k: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
modules: &TypedModules<'ast, T>,
cache: &mut CallCache<'ast, T>,
versions: &'a mut Versions<'ast>,
) -> Result<
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
@ -95,16 +93,12 @@ pub fn inline_call<'a, 'ast, T: Field>(
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
Ok(s) => s,
Err(()) => {
return Err(InlineError::NonConstant(
module_id,
k,
arguments,
output_types,
));
return Err(InlineError::NonConstant(k, arguments, output_types));
}
};
let (module_id, decl_key, f) = get_canonical_function(module_id, k.clone(), modules);
let (decl_key, f) = get_canonical_function(k.clone(), modules)
.map_err(|e| InlineError::Flat(e, arguments.clone(), output_types))?;
assert_eq!(f.arguments.len(), arguments.len());
// get an assignment of generics for this call site
@ -115,23 +109,34 @@ pub fn inline_call<'a, 'ast, T: Field>(
InlineError::Generic(decl_key.signature.clone(), inferred_signature.clone())
})?;
let concrete_key = ConcreteFunctionKey {
module: decl_key.module.clone(),
id: decl_key.id.clone(),
signature: inferred_signature.clone(),
};
match cache.get(&(concrete_key.clone(), arguments.clone())) {
Some(v) => {
return Ok(Output::Complete((vec![], v.clone())));
}
None => {}
};
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
Output::Complete(v) => (v, None),
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
};
let call_log = TypedStatement::PushCallLog(
module_id,
decl_key.clone(),
assignment,
ssa_f
.arguments
.into_iter()
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
.zip(arguments)
.collect(),
);
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment);
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
.arguments
.into_iter()
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
.zip(arguments.clone())
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let (statements, returns): (Vec<_>, Vec<_>) =
ssa_f.statements.into_iter().partition(|s| match s {
@ -151,7 +156,15 @@ pub fn inline_call<'a, 'ast, T: Field>(
.iter()
.enumerate()
.map(|(i, t)| {
ConcreteVariable::with_id_and_type(Identifier::from(CoreIdentifier::Call(i)), t.clone())
ConcreteVariable::with_id_and_type(
Identifier::from(CoreIdentifier::Call(i)).version(
*versions
.entry(CoreIdentifier::Call(i).clone())
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
),
t.clone(),
)
})
.collect();
@ -162,13 +175,26 @@ pub fn inline_call<'a, 'ast, T: Field>(
assert_eq!(res.len(), returns.len());
let call_pop = TypedStatement::PopCallLog(res.into_iter().zip(returns).collect());
let output_bindings: Vec<TypedStatement<'ast, T>> = res
.into_iter()
.zip(returns)
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let pop_log = TypedStatement::PopCallLog;
let statements: Vec<_> = std::iter::once(call_log)
.chain(input_bindings)
.chain(statements)
.chain(std::iter::once(call_pop))
.chain(output_bindings)
.chain(std::iter::once(pop_log))
.collect();
cache.insert(
(concrete_key.clone(), arguments.clone()),
expressions.clone(),
);
Ok(incomplete_data
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
.unwrap_or_else(|| Output::Complete((statements, expressions))))

View file

@ -17,12 +17,20 @@ mod unroll;
use self::inline::{inline_call, InlineError};
use std::collections::HashMap;
use typed_absy::result_folder::*;
use typed_absy::types::GenericsAssignment;
use typed_absy::Folder;
use typed_absy::{
CoreIdentifier, TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule,
TypedModules, TypedProgram, TypedStatement,
ArrayExpression, ArrayExpressionInner, BooleanExpression, ConcreteFunctionKey, CoreIdentifier,
DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression,
StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction,
TypedFunctionSymbol, TypedModule, TypedModules, TypedProgram, TypedStatement, UExpression,
UExpressionInner,
};
use std::convert::{TryFrom, TryInto};
use zokrates_field::Field;
use self::shallow_ssa::ShallowTransformer;
@ -44,6 +52,326 @@ pub enum Error {
Incompatible,
}
type CallCache<'ast, T> = HashMap<
(ConcreteFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Vec<TypedExpression<'ast, T>>,
>;
type Substitutions<'ast> = HashMap<CoreIdentifier<'ast>, HashMap<usize, usize>>;
struct Sub<'a, 'ast> {
substitutions: &'a mut Substitutions<'ast>,
}
impl<'a, 'ast> Sub<'a, 'ast> {
fn new(substitutions: &'a mut Substitutions<'ast>) -> Self {
Self { substitutions }
}
}
impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> {
fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> {
let sub = self.substitutions.entry(id.id.clone()).or_default();
let version = sub.get(&id.version).cloned().unwrap_or(id.version);
id.version(version)
}
}
fn register<'ast>(
substitutions: &mut Substitutions<'ast>,
substitute: &Versions<'ast>,
with: &Versions<'ast>,
) {
//assert!(substitute.len() <= with.len());
for (id, key, value) in substitute
.iter()
.filter_map(|(id, version)| with.get(&id).clone().map(|to| (id, version, to)))
.filter(|(_, key, value)| key != value)
{
let sub = substitutions.entry(id.clone()).or_default();
sub.insert(*key, *value);
}
}
struct Reducer<'ast, 'a, T> {
statement_buffer: Vec<TypedStatement<'ast, T>>,
for_loop_versions: Vec<Versions<'ast>>,
for_loop_index: usize,
modules: &'a TypedModules<'ast, T>,
versions: &'a mut Versions<'ast>,
substitutions: Substitutions<'ast>,
cache: CallCache<'ast, T>,
complete: bool,
}
impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
fn new(
modules: &'a TypedModules<'ast, T>,
versions: &'a mut Versions<'ast>,
for_loop_versions: Vec<Versions<'ast>>,
) -> Self {
Reducer {
statement_buffer: vec![],
for_loop_index: 0,
for_loop_versions,
cache: CallCache::default(),
substitutions: Substitutions::default(),
modules,
versions,
complete: true,
}
}
fn fold_function_call<E>(
&mut self,
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
) -> Result<E, <Self as ResultFolder<'ast, T>>::Error>
where
E: FunctionCall<'ast, T> + TryFrom<TypedExpression<'ast, T>, Error = ()> + std::fmt::Debug,
{
let arguments = arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect::<Result<_, _>>()?;
let res = inline_call(
key.clone(),
arguments,
output_types,
&self.modules,
&mut self.cache,
&mut self.versions,
);
match res {
Ok(Output::Complete((statements, expressions))) => {
self.complete &= true;
self.statement_buffer.extend(statements);
Ok(expressions[0].clone().try_into().unwrap())
}
Ok(Output::Incomplete((statements, expressions), _delta_for_loop_versions)) => {
self.complete = false;
self.statement_buffer.extend(statements);
Ok(expressions[0].clone().try_into().unwrap())
}
Err(InlineError::Generic(a, b)) => Err(Error::Incompatible),
Err(InlineError::NonConstant(key, arguments, _)) => {
self.complete = false;
Ok(E::function_call(key, arguments))
}
Err(InlineError::Flat(embed, arguments, output_types)) => {
Ok(E::function_call(embed.key::<T>().into(), arguments))
}
}
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_statement(
&mut self,
s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
let res = match s {
TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(key, arguments, output_types),
) => {
let arguments = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
match inline_call(
key,
arguments,
output_types,
&self.modules,
&mut self.cache,
&mut self.versions,
) {
Ok(Output::Complete((statements, expressions))) => {
assert_eq!(v.len(), expressions.len());
self.complete &= true;
Ok(statements
.into_iter()
.chain(
v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
)
.collect())
}
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
assert_eq!(v.len(), expressions.len());
self.complete = false;
Ok(statements
.into_iter()
.chain(
v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
)
.collect())
}
Err(InlineError::Generic(..)) => Err(Error::Incompatible),
Err(InlineError::NonConstant(key, arguments, output_types)) => {
self.complete = false;
Ok(vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(key, arguments, output_types),
)])
}
Err(InlineError::Flat(embed, arguments, output_types)) => {
Ok(vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(
embed.key::<T>().into(),
arguments,
output_types,
),
)])
}
}
}
TypedStatement::For(v, from, to, statements) => {
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
let mut out_statements = vec![];
let versions_before = self.for_loop_versions[self.for_loop_index].clone();
self.versions.values_mut().for_each(|v| *v = *v + 1);
register(&mut self.substitutions, &self.versions, &versions_before);
let versions_after = versions_before
.clone()
.into_iter()
.map(|(k, v)| (k, v + 2))
.collect();
let mut transformer = ShallowTransformer::with_versions(&mut self.versions);
let old_index = self.for_loop_index;
for index in *from..*to {
let statements: Vec<TypedStatement<_>> =
std::iter::once(TypedStatement::Definition(
v.clone().into(),
UExpression::from(index as u32).into(),
))
.chain(statements.clone().into_iter())
.map(|s| transformer.fold_statement(s))
.flatten()
.collect();
let new_versions_count = transformer.for_loop_backups.len();
self.for_loop_index += new_versions_count;
out_statements.extend(statements);
}
let backups = transformer.for_loop_backups;
let blocked = transformer.blocked;
register(&mut self.substitutions, &versions_after, &self.versions);
self.for_loop_versions.splice(old_index..old_index, backups);
self.complete &= !blocked;
let out_statements = out_statements
.into_iter()
.map(|s| Sub::new(&mut self.substitutions).fold_statement(s))
.flatten()
.collect();
Ok(out_statements)
}
_ => {
self.complete = false;
self.for_loop_index += 1;
Ok(vec![TypedStatement::For(v, from, to, statements)])
}
}
}
s => fold_statement(self, s),
};
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::FunctionCall(key, arguments) => {
self.fold_function_call(key, arguments, vec![Type::Boolean])
}
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression(
&mut self,
e: UExpression<'ast, T>,
) -> Result<UExpression<'ast, T>, Self::Error> {
match e.as_inner() {
UExpressionInner::FunctionCall(key, arguments) => {
self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()])
}
_ => fold_uint_expression(self, e),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::FunctionCall(key, arguments) => {
self.fold_function_call(key, arguments, vec![Type::FieldElement])
}
e => fold_field_expression(self, e),
}
}
fn fold_array_expression(
&mut self,
e: ArrayExpression<'ast, T>,
) -> Result<ArrayExpression<'ast, T>, Self::Error> {
match e.as_inner() {
ArrayExpressionInner::FunctionCall(key, arguments) => {
self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()])
}
_ => fold_array_expression(self, e),
}
}
fn fold_struct_expression(
&mut self,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, Self::Error> {
match e.as_inner() {
StructExpressionInner::FunctionCall(key, arguments) => {
self.fold_function_call(key.clone(), arguments.clone(), vec![e.get_type()])
}
_ => fold_struct_expression(self, e),
}
}
}
pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
let main_module = p.modules.get(&p.main).unwrap().clone();
@ -93,27 +421,27 @@ fn reduce_function<'ast, T: Field>(
let mut f = new_f;
let statements = loop {
match reduce_statements(
f.statements,
&mut for_loop_versions,
&mut versions,
modules,
) {
Ok(Output::Complete(statements)) => {
break statements;
}
Ok(Output::Incomplete(new_statements, new_for_loop_versions)) => {
let new_f = TypedFunction {
statements: new_statements,
..f
};
println!("{}", f);
use typed_absy::folder::Folder;
let mut reducer = Reducer::new(&modules, &mut versions, for_loop_versions);
let statements = f
.statements
.into_iter()
.map(|s| reducer.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();
match reducer.complete {
true => break statements,
false => {
let new_f = TypedFunction { statements, ..f };
f = Propagator::verbose().fold_function(new_f);
for_loop_versions = new_for_loop_versions;
for_loop_versions = reducer.for_loop_versions;
}
Err(e) => return Err(e),
}
};
@ -122,111 +450,47 @@ fn reduce_function<'ast, T: Field>(
}
}
fn reduce_statements<'ast, T: Field>(
statements: Vec<TypedStatement<'ast, T>>,
for_loop_versions: &mut Vec<Versions<'ast>>,
versions: &mut Versions<'ast>,
modules: &TypedModules<'ast, T>,
) -> Result<Output<Vec<TypedStatement<'ast, T>>, Vec<Versions<'ast>>>, Error> {
let mut versions = versions;
// fn reduce_statements<'ast, T: Field>(
// statements: Vec<TypedStatement<'ast, T>>,
// for_loop_versions: &mut Vec<Versions<'ast>>,
// versions: &mut Versions<'ast>,
// modules: &TypedModules<'ast, T>,
// ) -> Result<Output<Vec<TypedStatement<'ast, T>>, Vec<Versions<'ast>>>, Error> {
// let mut versions = versions;
let statements = statements
.into_iter()
.map(|s| reduce_statement(s, for_loop_versions, &mut versions, modules));
// let statements = statements
// .into_iter()
// .map(|s| reduce_statement(s, for_loop_versions, &mut versions, modules));
statements
.into_iter()
.fold(Ok(Output::Complete(vec![])), |state, e| match (state, e) {
(Ok(state), Ok(e)) => {
use self::Output::*;
match (state, e) {
(Complete(mut s), Complete(c)) => {
s.extend(c);
Ok(Complete(s))
}
(Complete(mut s), Incomplete(stats, for_loop_versions)) => {
s.extend(stats);
Ok(Incomplete(s, for_loop_versions))
}
(Incomplete(mut stats, new_for_loop_versions), Complete(c)) => {
stats.extend(c);
Ok(Incomplete(stats, new_for_loop_versions))
}
(Incomplete(mut stats, mut versions), Incomplete(new_stats, new_versions)) => {
stats.extend(new_stats);
versions.extend(new_versions);
Ok(Incomplete(stats, versions))
}
}
}
(Err(state), _) => Err(state),
(Ok(_), Err(e)) => Err(e),
})
}
fn reduce_statement<'ast, T: Field>(
statement: TypedStatement<'ast, T>,
_for_loop_versions: &mut Vec<Versions<'ast>>,
versions: &mut Versions<'ast>,
modules: &TypedModules<'ast, T>,
) -> Result<Output<Vec<TypedStatement<'ast, T>>, Vec<Versions<'ast>>>, Error> {
use self::Output::*;
match statement {
TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(key, arguments, output_types),
) => match inline_call(
"main".into(),
key,
arguments,
output_types,
modules,
versions,
) {
Ok(Output::Complete((statements, expressions))) => {
assert_eq!(v.len(), expressions.len());
Ok(Output::Complete(
statements
.into_iter()
.chain(
v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
)
.collect(),
))
}
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
assert_eq!(v.len(), expressions.len());
Ok(Output::Incomplete(
statements
.into_iter()
.chain(
v.into_iter()
.zip(expressions)
.map(|(v, e)| TypedStatement::Definition(v.into(), e)),
)
.collect(),
delta_for_loop_versions,
))
}
Err(InlineError::Generic(..)) => Err(Error::Incompatible),
Err(InlineError::NonConstant(_module, key, arguments, output_types)) => {
Ok(Output::Incomplete(
vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(key, arguments, output_types),
)],
vec![],
))
}
Err(InlineError::Flat) => unimplemented!(),
},
TypedStatement::For(..) => unimplemented!(),
s => Ok(Complete(vec![s])),
}
}
// statements
// .into_iter()
// .fold(Ok(Output::Complete(vec![])), |state, e| match (state, e) {
// (Ok(state), Ok(e)) => {
// use self::Output::*;
// match (state, e) {
// (Complete(mut s), Complete(c)) => {
// s.extend(c);
// Ok(Complete(s))
// }
// (Complete(mut s), Incomplete(stats, for_loop_versions)) => {
// s.extend(stats);
// Ok(Incomplete(s, for_loop_versions))
// }
// (Incomplete(mut stats, new_for_loop_versions), Complete(c)) => {
// stats.extend(c);
// Ok(Incomplete(stats, new_for_loop_versions))
// }
// (Incomplete(mut stats, mut versions), Incomplete(new_stats, new_versions)) => {
// stats.extend(new_stats);
// versions.extend(new_versions);
// Ok(Incomplete(stats, versions))
// }
// }
// }
// (Err(state), _) => Err(state),
// (Ok(_), Err(e)) => Err(e),
// })
// }
#[cfg(test)]
mod tests {
@ -257,9 +521,11 @@ mod tests {
// u32 n_0 = 42
// n_1 = n_0
// a_1 = a_0
// # PUSH CALL to foo with a_3 := a_1
// # POP CALL with foo_ifof_42 := a_3
// a_2 = foo_ifof_42
// # PUSH CALL to foo
// a_3 := a_1 // input binding
// #RETURN_AT_INDEX_0_0 := a_3
// # POP CALL
// a_2 = #RETURN_AT_INDEX_0_0
// n_2 = n_1
// return a_2
@ -295,7 +561,7 @@ mod tests {
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -324,7 +590,7 @@ mod tests {
TypedModule {
functions: vec![
(
DeclarationFunctionKey::with_id("foo").signature(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -332,7 +598,7 @@ mod tests {
TypedFunctionSymbol::Here(foo),
),
(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -369,28 +635,23 @@ mod tests {
FieldElementExpression::Identifier("a".into()).into(),
),
TypedStatement::PushCallLog(
"main".into(),
DeclarationFunctionKey::with_id("foo").signature(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
GenericsAssignment::default(),
vec![(
ConcreteVariable::with_id_and_type(
Identifier::from("a").version(3),
ConcreteType::FieldElement,
),
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
)],
),
TypedStatement::PopCallLog(vec![(
ConcreteVariable::with_id_and_type(
Identifier::from(CoreIdentifier::Call(0)).version(0),
ConcreteType::FieldElement,
),
TypedStatement::Definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::Definition(
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
.into(),
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
)]),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::Identifier(
@ -420,7 +681,7 @@ mod tests {
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -435,6 +696,9 @@ mod tests {
.collect(),
};
println!("{}", reduced.clone().unwrap());
println!("{}", expected);
assert_eq!(reduced.unwrap(), expected);
}
@ -455,9 +719,12 @@ mod tests {
// u32 n_0 = 42
// n_1 = n_0
// field[1] b_0 = [42]
// # PUSH CALL to foo::<1> with a_0 := b_0
// # POP CALL with foo_if[1]_of[1]_42 := a_0
// b_1 = foo_if[1]_of[1]_42
// # PUSH CALL to foo::<1>
// a_0 = b_0
// K = 1
// #RETURN_AT_INDEX_0_0 := a_0
// # POP CALL
// b_1 = #RETURN_AT_INDEX_0_0
// n_2 = n_1
// return a_2
@ -512,7 +779,8 @@ mod tests {
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
@ -539,11 +807,12 @@ mod tests {
TypedModule {
functions: vec![
(
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -584,35 +853,37 @@ mod tests {
.into(),
),
TypedStatement::PushCallLog(
"main".into(),
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GenericsAssignment(vec![("K", 1)].into_iter().collect()),
vec![(
ConcreteVariable::array(
Identifier::from("a").version(1),
ConcreteType::FieldElement,
1,
)
),
TypedStatement::Definition(
Variable::array(
Identifier::from("a").version(1),
Type::FieldElement,
1u32.into(),
)
.into(),
ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
)],
),
TypedStatement::Definition(
Variable::uint("K", UBitwidth::B32).into(),
UExpression::from(1u32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::PopCallLog(vec![(
ConcreteVariable::array(
TypedStatement::Definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
ConcreteType::FieldElement,
1,
),
Type::FieldElement,
1u32.into(),
)
.into(),
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
)]),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
Variable::array(
Identifier::from("b").version(1),
@ -645,7 +916,7 @@ mod tests {
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -660,6 +931,9 @@ mod tests {
.collect(),
};
println!("{}", reduced.clone().unwrap());
println!("{}", expected);
assert_eq!(reduced.unwrap(), expected);
}
@ -680,9 +954,12 @@ mod tests {
// u32 n_0 = 2
// n_1 = 2
// field[1] b_0 = [42]
// # PUSH CALL to foo::<1> with a_3 := b_0
// # POP CALL with foo_if[1]of[1]_42 := a_3
// b_1 = foo_if[1]of[1]_42
// # PUSH CALL to foo::<1>
// a_3 = b_0
// K = 1
// #RETURN_AT_INDEX_0_0 = a_3
// # POP CALL
// b_1 = #RETURN_AT_INDEX_0_0
// n_2 = 2
// return a_2
@ -755,7 +1032,8 @@ mod tests {
)
.into()],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![ArrayExpressionInner::Identifier("b".into())
.annotate(
Type::FieldElement,
@ -798,11 +1076,12 @@ mod tests {
TypedModule {
functions: vec![
(
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -841,37 +1120,39 @@ mod tests {
.into(),
),
TypedStatement::PushCallLog(
"main".into(),
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GenericsAssignment(vec![("K", 1)].into_iter().collect()),
vec![(
ConcreteVariable::array(
Identifier::from("a").version(1),
ConcreteType::FieldElement,
1,
)
.into(),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(1.into()).into()
])
.annotate(Type::FieldElement, 1u32)
.into(),
)],
),
TypedStatement::Definition(
Variable::array(
Identifier::from("a").version(1),
Type::FieldElement,
1u32.into(),
)
.into(),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(1.into()).into()
])
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Definition(
Variable::uint("K", UBitwidth::B32).into(),
UExpression::from(1u32).into(),
),
TypedStatement::PopCallLog(vec![(
ConcreteVariable::array(
TypedStatement::Definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
ConcreteType::FieldElement,
1,
),
Type::FieldElement,
1u32.into(),
)
.into(),
ArrayExpressionInner::Identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
)]),
),
TypedStatement::PopCallLog,
TypedStatement::Definition(
Variable::array(
Identifier::from("b").version(1),
@ -902,7 +1183,7 @@ mod tests {
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -917,6 +1198,9 @@ mod tests {
.collect(),
};
println!("{}", reduced.clone().unwrap());
println!("{}", expected);
assert_eq!(reduced.unwrap(), expected);
}
@ -1010,7 +1294,7 @@ mod tests {
// TypedStatement::MultipleDefinition(
// vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()],
// TypedExpressionList::FunctionCall(
// DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
// DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()),
// vec![ArrayExpressionInner::Identifier("b".into())
// .annotate(Type::FieldElement, 1u32)
// .into()],
@ -1029,11 +1313,11 @@ mod tests {
// TypedModule {
// functions: vec![
// (
// DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
// DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()),
// TypedFunctionSymbol::Here(foo),
// ),
// (
// DeclarationFunctionKey::with_id("main"),
// DeclarationFunctionKey::with_location(module_id.clone(), "main"),
// TypedFunctionSymbol::Here(main),
// ),
// ]
@ -1071,7 +1355,7 @@ mod tests {
// ),
// TypedStatement::PushCallLog(
// "main".into(),
// DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
// DeclarationFunctionKey::with_location(module_id.clone(), "foo").signature(foo_signature.clone()),
// GenericsAssignment(vec![("K", 1)].into_iter().collect()),
// vec![(
// ConcreteVariable::array("a", ConcreteType::FieldElement, 1).into(),
@ -1113,7 +1397,7 @@ mod tests {
// "main".into(),
// TypedModule {
// functions: vec![(
// DeclarationFunctionKey::with_id("main").signature(
// DeclarationFunctionKey::with_location(module_id.clone(), "main").signature(
// DeclarationSignature::new()
// .inputs(vec![DeclarationType::FieldElement])
// .outputs(vec![DeclarationType::FieldElement]),
@ -1178,7 +1462,8 @@ mod tests {
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32.into()).into()],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![ArrayExpressionInner::Value(vec![])
.annotate(Type::FieldElement, 0u32)
.into()],
@ -1197,11 +1482,12 @@ mod tests {
TypedModule {
functions: vec![
(
DeclarationFunctionKey::with_id("foo").signature(foo_signature.clone()),
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
DeclarationFunctionKey::with_id("main").signature(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
),
TypedFunctionSymbol::Here(main),

View file

@ -37,15 +37,15 @@ use super::{Output, Versions};
pub struct ShallowTransformer<'ast, 'a> {
// version index for any variable name
versions: &'a mut Versions<'ast>,
pub versions: &'a mut Versions<'ast>,
// A backup of the versions before each for-loop
for_loop_backups: Vec<Versions<'ast>>,
pub for_loop_backups: Vec<Versions<'ast>>,
// whether all statements could be unrolled so far. Loops with variable bounds cannot.
blocked: bool,
pub blocked: bool,
}
impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
ShallowTransformer {
versions,
for_loop_backups: Vec::default(),
@ -568,7 +568,8 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
e: TypedExpressionList<'ast, T>,
) -> TypedExpressionList<'ast, T> {
match e {
TypedExpressionList::FunctionCall(ref k, _, _) => {
TypedExpressionList::FunctionCall(ref k, ref a, _) => {
println!("{:#?}", a.iter().map(|a| a.get_type()).collect::<Vec<_>>());
if !k.id.starts_with("_") {
self.blocked = true;
}
@ -929,7 +930,7 @@ mod tests {
let s: TypedStatement<Bn128Field> = TypedStatement::MultipleDefinition(
vec![Variable::field_element("a")],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
@ -943,7 +944,7 @@ mod tests {
vec![TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(1))],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo").signature(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement])
@ -1447,7 +1448,7 @@ mod tests {
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo"),
DeclarationFunctionKey::with_location("main", "foo"),
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
@ -1463,7 +1464,7 @@ mod tests {
FieldElementExpression::mult(
FieldElementExpression::Identifier("a".into()),
FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo"),
DeclarationFunctionKey::with_location("main", "foo"),
vec![FieldElementExpression::Identifier("a".into()).into()],
),
)
@ -1511,7 +1512,7 @@ mod tests {
TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(2)).into()],
TypedExpressionList::FunctionCall(
DeclarationFunctionKey::with_id("foo"),
DeclarationFunctionKey::with_location("main", "foo"),
vec![FieldElementExpression::Identifier(
Identifier::from("a").version(1),
)
@ -1530,7 +1531,7 @@ mod tests {
FieldElementExpression::mult(
FieldElementExpression::Identifier(Identifier::from("a").version(2)),
FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo"),
DeclarationFunctionKey::with_location("main", "foo"),
vec![FieldElementExpression::Identifier(
Identifier::from("a").version(2),
)

View file

@ -43,7 +43,7 @@ mod tests {
fn generate_abi_from_typed_ast() {
let mut functions = HashMap::new();
functions.insert(
ConcreteFunctionKey::with_id("main").into(),
ConcreteFunctionKey::with_location("main", "main").into(),
TypedFunctionSymbol::Here(TypedFunction {
generics: vec![],
arguments: vec![

View file

@ -381,80 +381,70 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.get_array_type();
if array_ty == target_array_ty {
if array_ty.weak_eq(&target_array_ty) {
return Ok(array);
}
// sizes must be equal
let converted = match target_array_ty.size == array_ty.size {
true =>
// elements must fit in the target type
{
match array.into_inner() {
ArrayExpressionInner::Value(inline_array) => {
match *target_array_ty.ty {
Type::Int => Ok(inline_array),
Type::FieldElement => {
// try to convert all elements to field
inline_array
.into_iter()
.map(|e| {
FieldElementExpression::try_from_typed(e)
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Uint(bitwidth) => {
// try to convert all elements to uint
inline_array
.into_iter()
.map(|e| {
UExpression::try_from_typed(e, bitwidth)
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Array(ref inner_array_ty) => {
// try to convert all elements to array
inline_array
.into_iter()
.map(|e| {
ArrayExpression::try_from_typed(e, inner_array_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Struct(ref struct_ty) => {
// try to convert all elements to struct
inline_array
.into_iter()
.map(|e| {
StructExpression::try_from_typed(e, struct_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Boolean => {
// try to convert all elements to boolean
inline_array
.into_iter()
.map(|e| {
BooleanExpression::try_from_typed(e)
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
}
// elements must fit in the target type
let converted = match array.into_inner() {
ArrayExpressionInner::Value(inline_array) => {
match *target_array_ty.ty {
Type::Int => Ok(inline_array),
Type::FieldElement => {
// try to convert all elements to field
inline_array
.into_iter()
.map(|e| {
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Uint(bitwidth) => {
// try to convert all elements to uint
inline_array
.into_iter()
.map(|e| {
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Array(ref inner_array_ty) => {
// try to convert all elements to array
inline_array
.into_iter()
.map(|e| {
ArrayExpression::try_from_typed(e, inner_array_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Struct(ref struct_ty) => {
// try to convert all elements to struct
inline_array
.into_iter()
.map(|e| {
StructExpression::try_from_typed(e, struct_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
Type::Boolean => {
// try to convert all elements to boolean
inline_array
.into_iter()
.map(|e| {
BooleanExpression::try_from_typed(e).map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)
}
_ => unreachable!(""),
}
}
false => Err(array.into()),
_ => unreachable!(),
}?;
Ok(ArrayExpressionInner::Value(converted).annotate(*target_array_ty.ty, array_ty.size))

View file

@ -8,6 +8,7 @@
pub mod abi;
pub mod folder;
pub mod identifier;
pub mod result_folder;
mod integer;
mod parameter;
@ -19,8 +20,8 @@ pub use self::identifier::CoreIdentifier;
pub use self::parameter::{DeclarationParameter, GParameter};
pub use self::types::{
ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey,
DeclarationSignature, DeclarationType, GType, GenericIdentifier, Signature, StructType, Type,
UBitwidth,
DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
Signature, StructType, Type, UBitwidth,
};
use typed_absy::types::GenericsAssignment;
@ -141,7 +142,7 @@ pub struct TypedModule<'ast, T> {
#[derive(Clone, PartialEq)]
pub enum TypedFunctionSymbol<'ast, T> {
Here(TypedFunction<'ast, T>),
There(DeclarationFunctionKey<'ast>, TypedModuleId),
There(DeclarationFunctionKey<'ast>),
Flat(FlatEmbed),
}
@ -150,7 +151,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunctionSymbol<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TypedFunctionSymbol::Here(s) => write!(f, "Here({:?})", s),
TypedFunctionSymbol::There(key, module) => write!(f, "There({:?}, {:?})", key, module),
TypedFunctionSymbol::There(key) => write!(f, "There({:?})", key),
TypedFunctionSymbol::Flat(s) => write!(f, "Flat({:?})", s),
}
}
@ -163,8 +164,8 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
) -> DeclarationSignature<'ast> {
match self {
TypedFunctionSymbol::Here(f) => f.signature.clone(),
TypedFunctionSymbol::There(key, module_id) => modules
.get(module_id)
TypedFunctionSymbol::There(key) => modules
.get(&key.module)
.unwrap()
.functions
.get(key)
@ -183,10 +184,10 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
.iter()
.map(|(key, symbol)| match symbol {
TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
TypedFunctionSymbol::There(ref fun_key, ref module_id) => format!(
TypedFunctionSymbol::There(ref fun_key) => format!(
"import {} from \"{}\" as {} // with signature {}",
fun_key.id,
module_id.display(),
fun_key.module.display(),
key.id,
key.signature
),
@ -230,12 +231,16 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"<{}>",
self.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
"{}",
if self.generics.len() > 0 {
self.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
} else {
"".to_string()
}
)?;
write!(
f,
@ -271,7 +276,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
for s in &self.statements {
match s {
TypedStatement::PopCallLog(..) => tab -= 1,
TypedStatement::PopCallLog => tab -= 1,
_ => {}
};
@ -380,13 +385,8 @@ pub enum TypedStatement<'ast, T> {
),
MultipleDefinition(Vec<Variable<'ast, T>>, TypedExpressionList<'ast, T>),
// Aux
PushCallLog(
TypedModuleId,
DeclarationFunctionKey<'ast>,
GenericsAssignment<'ast>,
Vec<(ConcreteVariable<'ast>, TypedExpression<'ast, T>)>,
),
PopCallLog(Vec<(ConcreteVariable<'ast>, TypedExpression<'ast, T>)>),
PushCallLog(DeclarationFunctionKey<'ast>, GenericsAssignment<'ast>),
PopCallLog,
}
impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> {
@ -417,16 +417,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> {
TypedStatement::MultipleDefinition(ref lhs, ref rhs) => {
write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs)
}
TypedStatement::PushCallLog(ref module_id, ref key, ref generics, ref assignments) => {
write!(
f,
"PushCallLog({:?}, {:?}, {:?}, {:?})",
module_id, key, generics, assignments
)
}
TypedStatement::PopCallLog(ref assignments) => {
write!(f, "PopCallLog({:?})", assignments)
TypedStatement::PushCallLog(ref key, ref generics) => {
write!(f, "PushCallLog({:?}, {:?})", key, generics)
}
TypedStatement::PopCallLog => write!(f, "PopCallLog"),
}
}
}
@ -480,29 +474,14 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
}
write!(f, " = {}", rhs)
}
TypedStatement::PushCallLog(ref module_id, ref key, ref generics, ref assignments) => {
write!(
f,
"// PUSH CALL TO {}_{}::<{}> with {}",
module_id.display(),
key.id,
generics,
assignments
.iter()
.map(|(v, e)| format!("{} := {}", v, e))
.collect::<Vec<_>>()
.join(", ")
)
}
TypedStatement::PopCallLog(ref assignments) => write!(
TypedStatement::PushCallLog(ref key, ref generics) => write!(
f,
"// POP CALL with {}",
assignments
.iter()
.map(|(v, e)| format!("{} := {}", v, e))
.collect::<Vec<_>>()
.join(", ")
"// PUSH CALL TO {}/{}::<{}>",
key.module.display(),
key.id,
generics,
),
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
}
}
}
@ -735,7 +714,7 @@ pub enum FieldElementExpression<'ast, T> {
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
}
@ -822,7 +801,7 @@ pub enum BooleanExpression<'ast, T> {
Box<BooleanExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
}
@ -848,7 +827,7 @@ pub struct ArrayExpression<'ast, T> {
pub enum ArrayExpressionInner<'ast, T> {
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
@ -939,7 +918,7 @@ impl<'ast, T> StructExpression<'ast, T> {
pub enum StructExpressionInner<'ast, T> {
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
@ -1328,7 +1307,9 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
Type::Array(ty) => ArrayExpressionInner::Identifier(v.id)
.annotate(*ty.ty, ty.size)
.into(),
_ => unimplemented!(),
Type::Struct(ty) => StructExpressionInner::Identifier(v.id).annotate(ty).into(),
Type::Uint(w) => UExpressionInner::Identifier(v.id).annotate(w).into(),
Type::Int => unreachable!(),
}
}
}
@ -1536,3 +1517,69 @@ impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> {
StructExpressionInner::Member(box s, member_id).annotate(members)
}
}
pub trait FunctionCall<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self;
}
impl<'ast, T> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
FieldElementExpression::FunctionCall(key, arguments)
}
}
impl<'ast, T> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
BooleanExpression::FunctionCall(key, arguments)
}
}
impl<'ast, T: Clone> FunctionCall<'ast, T> for UExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
let bitwidth = match &key.signature.outputs[0] {
DeclarationType::Uint(bitwidth) => bitwidth.clone(),
_ => unreachable!(),
};
UExpressionInner::FunctionCall(key, arguments).annotate(bitwidth)
}
}
impl<'ast, T: Clone> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
let array_ty = match &key.signature.outputs[0] {
DeclarationType::Array(array_ty) => array_ty.clone(),
_ => unreachable!(),
};
ArrayExpressionInner::FunctionCall(key, arguments)
.annotate(Type::<T>::from(*array_ty.ty), array_ty.size)
}
}
impl<'ast, T: Clone> FunctionCall<'ast, T> for StructExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
let struct_ty = match &key.signature.outputs[0] {
DeclarationType::Struct(struct_ty) => struct_ty.clone(),
_ => unreachable!(),
};
StructExpressionInner::FunctionCall(key, arguments).annotate(StructType::from(struct_ty))
}
}

View file

@ -0,0 +1,667 @@
// Generic walk through a typed AST. Not mutating in place
use crate::typed_absy::*;
use typed_absy::types::{ArrayType, StructMember, StructType};
use zokrates_field::Field;
pub trait ResultFolder<'ast, T: Field>: Sized {
type Error;
fn fold_function_symbol(
&mut self,
s: TypedFunctionSymbol<'ast, T>,
) -> Result<TypedFunctionSymbol<'ast, T>, Self::Error> {
fold_function_symbol(self, s)
}
fn fold_function(
&mut self,
f: TypedFunction<'ast, T>,
) -> Result<TypedFunction<'ast, T>, Self::Error> {
fold_function(self, f)
}
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast>,
) -> Result<DeclarationParameter<'ast>, Self::Error> {
Ok(DeclarationParameter {
id: self.fold_declaration_variable(p.id)?,
..p
})
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
Ok(n)
}
fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result<Variable<'ast, T>, Self::Error> {
Ok(Variable {
id: self.fold_name(v.id)?,
_type: self.fold_type(v._type)?,
})
}
fn fold_declaration_variable(
&mut self,
v: DeclarationVariable<'ast>,
) -> Result<DeclarationVariable<'ast>, Self::Error> {
Ok(DeclarationVariable {
id: self.fold_name(v.id)?,
_type: self.fold_declaration_type(v._type)?,
})
}
fn fold_type(&mut self, t: Type<'ast, T>) -> Result<Type<'ast, T>, Self::Error> {
use self::GType::*;
match t {
Array(array_type) => Ok(Array(ArrayType {
ty: box self.fold_type(*array_type.ty)?,
size: self.fold_uint_expression(array_type.size)?,
})),
Struct(struct_type) => Ok(Struct(StructType {
members: struct_type
.members
.into_iter()
.map(|m| {
self.fold_type(*m.ty.clone())
.map(|ty| StructMember { ty: box ty, ..m })
})
.collect::<Result<_, _>>()?,
..struct_type
})),
t => Ok(t),
}
}
fn fold_declaration_type(
&mut self,
t: DeclarationType<'ast>,
) -> Result<DeclarationType<'ast>, Self::Error> {
Ok(t)
}
fn fold_assignee(
&mut self,
a: TypedAssignee<'ast, T>,
) -> Result<TypedAssignee<'ast, T>, Self::Error> {
match a {
TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(self.fold_variable(v)?)),
TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select(
box self.fold_assignee(a)?,
box self.fold_uint_expression(index)?,
)),
TypedAssignee::Member(box s, m) => {
Ok(TypedAssignee::Member(box self.fold_assignee(s)?, m))
}
}
}
fn fold_statement(
&mut self,
s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
fold_statement(self, s)
}
fn fold_expression(
&mut self,
e: TypedExpression<'ast, T>,
) -> Result<TypedExpression<'ast, T>, Self::Error> {
match e {
TypedExpression::FieldElement(e) => Ok(self.fold_field_expression(e)?.into()),
TypedExpression::Boolean(e) => Ok(self.fold_boolean_expression(e)?.into()),
TypedExpression::Uint(e) => Ok(self.fold_uint_expression(e)?.into()),
TypedExpression::Array(e) => Ok(self.fold_array_expression(e)?.into()),
TypedExpression::Struct(e) => Ok(self.fold_struct_expression(e)?.into()),
TypedExpression::Int(e) => Ok(self.fold_int_expression(e)?.into()),
}
}
fn fold_array_expression(
&mut self,
e: ArrayExpression<'ast, T>,
) -> Result<ArrayExpression<'ast, T>, Self::Error> {
fold_array_expression(self, e)
}
fn fold_struct_expression(
&mut self,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, Self::Error> {
fold_struct_expression(self, e)
}
fn fold_expression_list(
&mut self,
es: TypedExpressionList<'ast, T>,
) -> Result<TypedExpressionList<'ast, T>, Self::Error> {
fold_expression_list(self, es)
}
fn fold_int_expression(
&mut self,
e: IntExpression<'ast, T>,
) -> Result<IntExpression<'ast, T>, Self::Error> {
fold_int_expression(self, e)
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
fold_field_expression(self, e)
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
fold_boolean_expression(self, e)
}
fn fold_uint_expression(
&mut self,
e: UExpression<'ast, T>,
) -> Result<UExpression<'ast, T>, Self::Error> {
fold_uint_expression(self, e)
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
fold_uint_expression_inner(self, bitwidth, e)
}
fn fold_array_expression_inner(
&mut self,
ty: &Type<'ast, T>,
size: UExpression<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
fold_array_expression_inner(self, ty, size, e)
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
fold_struct_expression_inner(self, ty, e)
}
}
pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, F::Error> {
let res = match s {
TypedStatement::Return(expressions) => TypedStatement::Return(
expressions
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?,
),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?)
}
TypedStatement::Declaration(v) => TypedStatement::Declaration(f.fold_variable(v)?),
TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)?),
TypedStatement::For(v, from, to, statements) => TypedStatement::For(
f.fold_variable(v)?,
f.fold_uint_expression(from)?,
f.fold_uint_expression(to)?,
statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
),
TypedStatement::MultipleDefinition(variables, elist) => TypedStatement::MultipleDefinition(
variables
.into_iter()
.map(|v| f.fold_variable(v))
.collect::<Result<_, _>>()?,
f.fold_expression_list(elist)?,
),
s => s,
};
Ok(vec![res])
}
pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
_: &Type<'ast, T>,
_: UExpression<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, F::Error> {
let e = match e {
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?),
ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value(
exprs
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?,
),
ArrayExpressionInner::FunctionCall(id, exps) => {
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
ArrayExpressionInner::FunctionCall(id, exps)
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
ArrayExpressionInner::IfElse(
box f.fold_boolean_expression(condition)?,
box f.fold_array_expression(consequence)?,
box f.fold_array_expression(alternative)?,
)
}
ArrayExpressionInner::Member(box s, id) => {
let s = f.fold_struct_expression(s)?;
ArrayExpressionInner::Member(box s, id)
}
ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array)?;
let index = f.fold_uint_expression(index)?;
ArrayExpressionInner::Select(box array, box index)
}
};
Ok(e)
}
pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
_: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, F::Error> {
let e = match e {
StructExpressionInner::Identifier(id) => {
StructExpressionInner::Identifier(f.fold_name(id)?)
}
StructExpressionInner::Value(exprs) => StructExpressionInner::Value(
exprs
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?,
),
StructExpressionInner::FunctionCall(id, exps) => {
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
StructExpressionInner::FunctionCall(id, exps)
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
StructExpressionInner::IfElse(
box f.fold_boolean_expression(condition)?,
box f.fold_struct_expression(consequence)?,
box f.fold_struct_expression(alternative)?,
)
}
StructExpressionInner::Member(box s, id) => {
let s = f.fold_struct_expression(s)?;
StructExpressionInner::Member(box s, id)
}
StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array)?;
let index = f.fold_uint_expression(index)?;
StructExpressionInner::Select(box array, box index)
}
};
Ok(e)
}
pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, F::Error> {
let e = match e {
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
FieldElementExpression::Identifier(id) => {
FieldElementExpression::Identifier(f.fold_name(id)?)
}
FieldElementExpression::Add(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
FieldElementExpression::Add(box e1, box e2)
}
FieldElementExpression::Sub(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
FieldElementExpression::Sub(box e1, box e2)
}
FieldElementExpression::Mult(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
FieldElementExpression::Mult(box e1, box e2)
}
FieldElementExpression::Div(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
FieldElementExpression::Div(box e1, box e2)
}
FieldElementExpression::Pow(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
FieldElementExpression::Pow(box e1, box e2)
}
FieldElementExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
let cons = f.fold_field_expression(cons)?;
let alt = f.fold_field_expression(alt)?;
FieldElementExpression::IfElse(box cond, box cons, box alt)
}
FieldElementExpression::FunctionCall(key, exps) => {
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
FieldElementExpression::FunctionCall(key, exps)
}
FieldElementExpression::Member(box s, id) => {
let s = f.fold_struct_expression(s)?;
FieldElementExpression::Member(box s, id)
}
FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array)?;
let index = f.fold_uint_expression(index)?;
FieldElementExpression::Select(box array, box index)
}
};
Ok(e)
}
pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
_: &mut F,
_: IntExpression<'ast, T>,
) -> Result<IntExpression<'ast, T>, F::Error> {
unreachable!()
}
pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, F::Error> {
let e = match e {
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldEq(box e1, box e2)
}
BooleanExpression::BoolEq(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1)?;
let e2 = f.fold_boolean_expression(e2)?;
BooleanExpression::BoolEq(box e1, box e2)
}
BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1)?;
let e2 = f.fold_array_expression(e2)?;
BooleanExpression::ArrayEq(box e1, box e2)
}
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1)?;
let e2 = f.fold_struct_expression(e2)?;
BooleanExpression::StructEq(box e1, box e2)
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldGt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintGe(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1)?;
let e2 = f.fold_boolean_expression(e2)?;
BooleanExpression::Or(box e1, box e2)
}
BooleanExpression::And(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1)?;
let e2 = f.fold_boolean_expression(e2)?;
BooleanExpression::And(box e1, box e2)
}
BooleanExpression::Not(box e) => {
let e = f.fold_boolean_expression(e)?;
BooleanExpression::Not(box e)
}
BooleanExpression::FunctionCall(key, exps) => {
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
BooleanExpression::FunctionCall(key, exps)
}
BooleanExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
let cons = f.fold_boolean_expression(cons)?;
let alt = f.fold_boolean_expression(alt)?;
BooleanExpression::IfElse(box cond, box cons, box alt)
}
BooleanExpression::Member(box s, id) => {
let s = f.fold_struct_expression(s)?;
BooleanExpression::Member(box s, id)
}
BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array)?;
let index = f.fold_uint_expression(index)?;
BooleanExpression::Select(box array, box index)
}
};
Ok(e)
}
pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: UExpression<'ast, T>,
) -> Result<UExpression<'ast, T>, F::Error> {
Ok(UExpression {
inner: f.fold_uint_expression_inner(e.bitwidth, e.inner)?,
..e
})
}
pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
_: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, F::Error> {
let e = match e {
UExpressionInner::Value(v) => UExpressionInner::Value(v),
UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?),
UExpressionInner::Add(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
UExpressionInner::Add(box left, box right)
}
UExpressionInner::Sub(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
UExpressionInner::Sub(box left, box right)
}
UExpressionInner::Mult(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
UExpressionInner::Mult(box left, box right)
}
UExpressionInner::Xor(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
UExpressionInner::Xor(box left, box right)
}
UExpressionInner::And(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
UExpressionInner::And(box left, box right)
}
UExpressionInner::Or(box left, box right) => {
let left = f.fold_uint_expression(left)?;
let right = f.fold_uint_expression(right)?;
UExpressionInner::Or(box left, box right)
}
UExpressionInner::LeftShift(box e, box by) => {
let e = f.fold_uint_expression(e)?;
let by = f.fold_field_expression(by)?;
UExpressionInner::LeftShift(box e, box by)
}
UExpressionInner::RightShift(box e, box by) => {
let e = f.fold_uint_expression(e)?;
let by = f.fold_field_expression(by)?;
UExpressionInner::RightShift(box e, box by)
}
UExpressionInner::Not(box e) => {
let e = f.fold_uint_expression(e)?;
UExpressionInner::Not(box e)
}
UExpressionInner::FunctionCall(key, exps) => {
let exps = exps
.into_iter()
.map(|e| f.fold_expression(e))
.collect::<Result<_, _>>()?;
UExpressionInner::FunctionCall(key, exps)
}
UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array)?;
let index = f.fold_uint_expression(index)?;
UExpressionInner::Select(box array, box index)
}
UExpressionInner::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond)?;
let cons = f.fold_uint_expression(cons)?;
let alt = f.fold_uint_expression(alt)?;
UExpressionInner::IfElse(box cond, box cons, box alt)
}
UExpressionInner::Member(box s, id) => {
let s = f.fold_struct_expression(s)?;
UExpressionInner::Member(box s, id)
}
};
Ok(e)
}
pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
) -> Result<TypedFunction<'ast, T>, F::Error> {
Ok(TypedFunction {
arguments: fun
.arguments
.into_iter()
.map(|a| f.fold_parameter(a))
.collect::<Result<_, _>>()?,
statements: fun
.statements
.into_iter()
.map(|s| f.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
..fun
})
}
pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: ArrayExpression<'ast, T>,
) -> Result<ArrayExpression<'ast, T>, F::Error> {
let size = f.fold_uint_expression(e.size)?;
Ok(ArrayExpression {
inner: f.fold_array_expression_inner(&e.ty, size.clone(), e.inner)?,
size,
..e
})
}
pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
es: TypedExpressionList<'ast, T>,
) -> Result<TypedExpressionList<'ast, T>, F::Error> {
match es {
TypedExpressionList::FunctionCall(id, arguments, types) => {
Ok(TypedExpressionList::FunctionCall(
id,
arguments
.into_iter()
.map(|a| f.fold_expression(a))
.collect::<Result<_, _>>()?,
types
.into_iter()
.map(|t| f.fold_type(t))
.collect::<Result<_, _>>()?,
))
}
}
}
pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, F::Error> {
Ok(StructExpression {
inner: f.fold_struct_expression_inner(&e.ty, e.inner)?,
..e
})
}
pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: TypedFunctionSymbol<'ast, T>,
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
match s {
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
there => Ok(there), // by default, do not fold modules recursively
}
}

View file

@ -3,7 +3,7 @@ use std::fmt;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use typed_absy::{TryFrom, TryInto};
use typed_absy::{UExpression, UExpressionInner};
use typed_absy::{TypedModuleId, UExpression, UExpressionInner};
pub type GenericIdentifier<'ast> = &'ast str;
@ -15,7 +15,7 @@ pub enum Constant<'ast> {
// At this stage we want all constants to be equal
impl<'ast> PartialEq for Constant<'ast> {
fn eq(&self, _: &Self) -> bool {
fn eq(&self, other: &Self) -> bool {
true
}
}
@ -193,6 +193,14 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'as
}
}
impl<'ast, T: PartialEq> ArrayType<'ast, T> {
// array type equality with non-strict size checks
// sizes always match unless they are different constants
pub fn weak_eq(&self, other: &Self) -> bool {
self.ty == other.ty
}
}
fn try_from_g_array_type<T: TryInto<U>, U>(t: GArrayType<T>) -> Result<GArrayType<U>, ()> {
Ok(GArrayType {
size: t.size.try_into().map_err(|_| ())?,
@ -676,6 +684,7 @@ pub type FunctionIdentifier<'ast> = &'ast str;
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
pub struct GFunctionKey<'ast, S> {
pub module: TypedModuleId,
pub id: FunctionIdentifier<'ast>,
pub signature: GSignature<S>,
}
@ -703,12 +712,13 @@ impl<'ast> fmt::Display for GenericsAssignment<'ast> {
impl<'ast> PartialEq<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool {
self.id == other.id && self.signature == other.signature
self.module == other.module && self.id == other.id && self.signature == other.signature
}
}
fn try_from_g_function_key<T: TryInto<U>, U>(k: GFunctionKey<T>) -> Result<GFunctionKey<U>, ()> {
Ok(GFunctionKey {
module: k.module,
signature: signature::try_from_g_signature(k.signature)?,
id: k.id,
})
@ -749,8 +759,12 @@ impl<'ast, T> From<DeclarationFunctionKey<'ast>> for FunctionKey<'ast, T> {
}
impl<'ast, S> GFunctionKey<'ast, S> {
pub fn with_id<U: Into<FunctionIdentifier<'ast>>>(id: U) -> Self {
pub fn with_location<T: Into<TypedModuleId>, U: Into<FunctionIdentifier<'ast>>>(
module: T,
id: U,
) -> Self {
GFunctionKey {
module: module.into(),
id: id.into(),
signature: GSignature::new(),
}
@ -765,11 +779,21 @@ impl<'ast, S> GFunctionKey<'ast, S> {
self.id = id.into();
self
}
pub fn module<T: Into<TypedModuleId>>(mut self, module: T) -> Self {
self.module = module.into();
self
}
}
impl<'ast> ConcreteFunctionKey<'ast> {
pub fn to_slug(&self) -> String {
format!("{}_{}", self.id, self.signature.to_slug())
format!(
"{}/{}_{}",
self.module.display(),
self.id,
self.signature.to_slug()
)
}
}

View file

@ -119,7 +119,7 @@ pub enum UExpressionInner<'ast, T> {
Box<UExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FunctionCall(FunctionKey<'ast, T>, Vec<TypedExpression<'ast, T>>),
FunctionCall(DeclarationFunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<UExpression<'ast, T>>,

View file

@ -4,6 +4,7 @@ use zir;
impl<'ast> From<typed_absy::types::ConcreteFunctionKey<'ast>> for zir::types::FunctionKey<'ast> {
fn from(k: typed_absy::types::ConcreteFunctionKey<'ast>) -> zir::types::FunctionKey<'ast> {
zir::types::FunctionKey {
module: k.module,
id: k.id,
signature: k.signature.into(),
}

View file

@ -74,7 +74,7 @@ pub struct ZirModule<'ast, T> {
#[derive(Debug, Clone, PartialEq)]
pub enum ZirFunctionSymbol<'ast, T> {
Here(ZirFunction<'ast, T>),
There(FunctionKey<'ast>, ZirModuleId),
There(FunctionKey<'ast>),
Flat(FlatEmbed),
}
@ -82,8 +82,8 @@ impl<'ast, T> ZirFunctionSymbol<'ast, T> {
pub fn signature<'a>(&'a self, modules: &'a ZirModules<T>) -> Signature {
match self {
ZirFunctionSymbol::Here(f) => f.signature.clone(),
ZirFunctionSymbol::There(key, module_id) => modules
.get(module_id)
ZirFunctionSymbol::There(key) => modules
.get(&key.module)
.unwrap()
.functions
.get(key)
@ -102,10 +102,10 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirModule<'ast, T> {
.iter()
.map(|(key, symbol)| match symbol {
ZirFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
ZirFunctionSymbol::There(ref fun_key, ref module_id) => format!(
ZirFunctionSymbol::There(ref fun_key) => format!(
"import {} from \"{}\" as {} // with signature {}",
fun_key.id,
module_id.display(),
fun_key.module.display(),
key.id,
key.signature
),

View file

@ -1,4 +1,5 @@
use std::fmt;
use zir::ZirModuleId;
pub type Identifier<'ast> = &'ast str;
@ -95,13 +96,18 @@ pub type FunctionIdentifier<'ast> = &'ast str;
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
pub struct FunctionKey<'ast> {
pub module: ZirModuleId,
pub id: FunctionIdentifier<'ast>,
pub signature: Signature,
}
impl<'ast> FunctionKey<'ast> {
pub fn with_id<S: Into<Identifier<'ast>>>(id: S) -> Self {
pub fn with_location<T: Into<ZirModuleId>, S: Into<Identifier<'ast>>>(
module: T,
id: S,
) -> Self {
FunctionKey {
module: module.into(),
id: id.into(),
signature: Signature::new(),
}
@ -117,6 +123,11 @@ impl<'ast> FunctionKey<'ast> {
self
}
pub fn module<T: Into<ZirModuleId>>(mut self, module: T) -> Self {
self.module = module.into();
self
}
pub fn to_slug(&self) -> String {
format!("{}_{}", self.id, self.signature.to_slug())
}

View file

@ -1,8 +1,10 @@
import "utils/casts/u32_to_field" as to_field
def main(field x) -> field:
field f = 1
field counter = 0
for field i in 1..5 do
f = if counter == x then f else f * i fi
for u32 i in 1..5 do
f = if counter == x then f else f * to_field(i) fi
counter = if counter == x then counter else counter + 1 fi
endfor
return f

View file

@ -0,0 +1,4 @@
{
"curves": ["Bn128", "Bls12"],
"tests": []
}

View file

@ -0,0 +1,8 @@
def foo<T>(field[T] b) -> field:
return 1
def bar<T>(field[T] b) -> field:
return foo(b)
def main(field[3] a) -> field:
return foo(a) + bar(a)

View file

@ -0,0 +1,23 @@
{
"curves": ["Bn128", "Bls12"],
"tests": [
{
"input": {
"values": [
"1",
"2",
"3"
]
},
"output": {
"Ok": {
"values": [
"1",
"2",
"3"
]
}
}
}
]
}

View file

@ -0,0 +1,6 @@
def foo<T>(field[T] b) -> field[T]:
return b
def main(field[3] a) -> field[3]:
field[3] res = foo(a)
return res

View file

@ -0,0 +1,10 @@
import "EMBED/u32_to_bits" as to_bits
def main(u32 i) -> field:
bool[32] bits = to_bits(i)
field res = 0
for u32 j in 0..32 do
u32 exponent = 32 - j - 1
res = res + 2 ** exponent
endfor
return res

View file

@ -13,7 +13,7 @@ enum Curve {
#[derive(Serialize, Deserialize, Clone)]
struct Tests {
pub entry_point: PathBuf,
pub entry_point: Option<PathBuf>,
pub curves: Option<Vec<Curve>>,
pub max_constraint_count: Option<usize>,
pub tests: Vec<Test>,
@ -92,6 +92,14 @@ pub fn test_inner(test_path: &str) {
let curves = t.curves.clone().unwrap_or(vec![Curve::Bn128]);
let t = Tests {
entry_point: Some(
t.entry_point
.unwrap_or(PathBuf::from(String::from(test_path)).with_extension("zok")),
),
..t
};
for c in &curves {
match c {
Curve::Bn128 => compile_and_run::<Bn128Field>(t.clone()),
@ -101,11 +109,13 @@ pub fn test_inner(test_path: &str) {
}
fn compile_and_run<T: Field>(t: Tests) {
let code = std::fs::read_to_string(&t.entry_point).unwrap();
let entry_point = t.entry_point.unwrap();
let code = std::fs::read_to_string(&entry_point).unwrap();
let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap();
let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap());
let artifacts = compile::<T, _>(code, t.entry_point.clone(), Some(&resolver)).unwrap();
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver)).unwrap();
let bin = artifacts.prog();
@ -115,7 +125,7 @@ fn compile_and_run<T: Field>(t: Tests) {
println!(
"{} at {}% of max",
t.entry_point.display(),
entry_point.display(),
(count as f32) / (target_count as f32) * 100_f32
);
}
@ -131,7 +141,7 @@ fn compile_and_run<T: Field>(t: Tests) {
match compare(output, test.output) {
Err(e) => {
let mut code = File::open(&t.entry_point).unwrap();
let mut code = File::open(&entry_point).unwrap();
let mut s = String::new();
code.read_to_string(&mut s).unwrap();
let context = format!(