1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

make all tests pass, clean

This commit is contained in:
schaeff 2023-02-22 21:30:11 +01:00
parent 1bb524f6a2
commit 30801c86fa
13 changed files with 442 additions and 707 deletions

View file

@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>(
})
.collect(),
)],
typed::TypedStatement::PushCallLog(..) => vec![],
typed::TypedStatement::PopCallLog => vec![],
typed::TypedStatement::For(..) => unreachable!(),
};

View file

@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>(
let r = reduce_program(r).map_err(Error::from)?;
log::trace!("\n{}", r);
log::debug!("Static analyser: Propagate");
let r = Propagator::propagate(r)?;
log::trace!("\n{}", r);
log::debug!("Static analyser: Concretize structs");
let r = StructConcretizer::concretize(r);
log::trace!("\n{}", r);

View file

@ -308,21 +308,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
}
};
// particular case of `lhs = rhs`
if TypedExpression::from(assignee.clone()) == expr {
return Ok(vec![]);
}
if expr.is_constant() {
match assignee {
TypedAssignee::Identifier(var) => {
let expr = expr.into_canonical_constant();
assert!(
self.constants.insert(var.clone().id, expr).is_none(),
"{}",
var
);
assert!(self.constants.insert(var.id, expr).is_none());
Ok(vec![])
}
@ -629,8 +620,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
_ => Ok(vec![TypedStatement::Assertion(expr, err)]),
}
}
s @ TypedStatement::PushCallLog(..) => Ok(vec![s]),
s @ TypedStatement::PopCallLog => Ok(vec![s]),
s => fold_statement(self, s),
}
}
@ -1502,7 +1491,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(5)))
);
}
@ -1515,7 +1504,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
);
}
@ -1528,7 +1517,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(6)))
);
}
@ -1541,7 +1530,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
);
}
@ -1554,15 +1543,14 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
);
}
#[test]
fn left_shift() {
let mut constants = Constants::new();
let mut propagator = Propagator::with_constants(&mut constants);
let mut propagator = Propagator::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::LeftShift(
@ -1607,8 +1595,7 @@ mod tests {
#[test]
fn right_shift() {
let mut constants = Constants::new();
let mut propagator = Propagator::with_constants(&mut constants);
let mut propagator = Propagator::default();
assert_eq!(
propagator.fold_field_expression(FieldElementExpression::RightShift(
@ -1676,7 +1663,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
);
}
@ -1691,7 +1678,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
);
}
@ -1713,7 +1700,7 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
);
}
@ -1735,18 +1722,15 @@ mod tests {
BooleanExpression::Not(box BooleanExpression::identifier("a".into()));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_default.clone()),
Propagator::default().fold_boolean_expression(e_default.clone()),
Ok(e_default)
);
}
@ -1776,23 +1760,19 @@ mod tests {
));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_true),
Propagator::default().fold_boolean_expression(e_constant_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_false),
Propagator::default().fold_boolean_expression(e_constant_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_true),
Propagator::default().fold_boolean_expression(e_identifier_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_unchanged.clone()),
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
Ok(e_identifier_unchanged)
);
}
@ -1800,38 +1780,42 @@ mod tests {
#[test]
fn bool_eq() {
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(false),
BooleanExpression::Value(false)
))),
))
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(true),
BooleanExpression::Value(true)
))),
))
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(true),
BooleanExpression::Value(false)
))),
))
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(false),
BooleanExpression::Value(true)
))),
))
),
Ok(BooleanExpression::Value(false))
);
}
@ -1933,33 +1917,27 @@ mod tests {
));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_true),
Propagator::default().fold_boolean_expression(e_constant_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_false),
Propagator::default().fold_boolean_expression(e_constant_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_true),
Propagator::default().fold_boolean_expression(e_identifier_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_unchanged.clone()),
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
Ok(e_identifier_unchanged)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_non_canonical_true),
Propagator::default().fold_boolean_expression(e_non_canonical_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_non_canonical_false),
Propagator::default().fold_boolean_expression(e_non_canonical_false),
Ok(BooleanExpression::Value(false))
);
}
@ -1977,13 +1955,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2001,13 +1977,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2025,13 +1999,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2049,13 +2021,11 @@ mod tests {
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
Propagator::default().fold_boolean_expression(e_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
Propagator::default().fold_boolean_expression(e_false),
Ok(BooleanExpression::Value(false))
);
}
@ -2065,67 +2035,75 @@ mod tests {
let a_bool: Identifier = "a".into();
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
}
@ -2135,67 +2113,75 @@ mod tests {
let a_bool: Identifier = "a".into();
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::identifier(a_bool.clone())
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::identifier(a_bool.clone()))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
)),
)
),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
Propagator::<Bn128Field>::default().fold_boolean_expression(
BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false),
)),
)
),
Ok(BooleanExpression::Value(false))
);
}

View file

@ -135,7 +135,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
}
};
let decl = get_canonical_function(&k, program);
let decl = get_canonical_function(k, program);
// get an assignment of generics for this call site
let assignment: ConcreteGenericsAssignment<'ast> = k
@ -190,7 +190,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
.into_iter()
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::new(p.id.id, t, false))
.map(|v| Variable::from(v))
.map(Variable::from)
.collect();
let (statements, mut returns): (Vec<_>, Vec<_>) = f

View file

@ -3,13 +3,14 @@
// - free of function calls (except for low level calls) thanks to inlining
// - free of for-loops thanks to unrolling
// The process happens in two steps
// 1. Shallow SSA for the `main` function
// We turn the `main` function into SSA form, but ignoring function calls and for loops
// 2. Unroll and inline
// We go through the shallow-SSA program and
// - unroll loops
// - inline function calls. This includes applying shallow-ssa on the target function
// The process happens in a greedy way, starting from the main function
// For each statement:
// * put it in ssa form
// * propagate it
// * inline it (calling this process recursively)
// * propagate again
// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach
mod constants_reader;
mod constants_writer;
@ -21,7 +22,6 @@ use std::collections::HashMap;
use zokrates_ast::typed::result_folder::*;
use zokrates_ast::typed::DeclarationParameter;
use zokrates_ast::typed::Folder;
use zokrates_ast::typed::TypedAssemblyStatement;
use zokrates_ast::typed::TypedAssignee;
use zokrates_ast::typed::{
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
@ -47,23 +47,6 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
pub type ConstantDefinitions<'ast, T> =
HashMap<CanonicalConstantIdentifier<'ast>, TypedExpression<'ast, T>>;
// An SSA version map, giving access to the latest version number for each identifier
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Versions<'ast> {
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
}
impl<'ast> Versions<'ast> {
fn insert_in_frame(
&mut self,
id: CoreIdentifier<'ast>,
version: usize,
frame: usize,
) -> Option<usize> {
self.map.entry(frame).or_default().insert(id, version)
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
Incompatible(String),
@ -125,10 +108,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
&mut self,
p: DeclarationParameter<'ast, T>,
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
// this is only used on the entry point
let id = p.id.id.id.id.clone();
assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none());
Ok(p)
Ok(self.ssa.fold_parameter(p))
}
fn fold_function_call_expression<
@ -138,34 +118,42 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
ty: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
// generics are already in ssa form
let generics = e
.generics
.into_iter()
.map(|g| {
g.map(|g| {
let g = self.ssa.fold_uint_expression(g);
let g = self.propagator.fold_uint_expression(g)?;
let g = self.fold_uint_expression(g)?;
self.fold_uint_expression(g)
self.propagator
.fold_uint_expression(g)
.map_err(Self::Error::from)
})
.transpose()
})
.collect::<Result<_, _>>()?;
// arguments are already in ssa form
let arguments = e
.arguments
.into_iter()
.map(|e| {
let e = self.ssa.fold_expression(e);
let e = self.propagator.fold_expression(e)?;
let e = self.fold_expression(e)?;
self.fold_expression(e)
self.propagator
.fold_expression(e)
.map_err(Self::Error::from)
})
.collect::<Result<_, _>>()?;
self.ssa.push_call_frame();
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program);
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program);
let res = match res {
Ok((input_variables, arguments, generics_bindings, statements, expression)) => {
@ -183,7 +171,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
self.statement_buffer.extend(generics_bindings);
// the lhs is from the inner call frame, the rhs is from the outer one, so only fld the lhs
// the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs
let input_bindings: Vec<_> = input_variables
.into_iter()
.zip(arguments)
@ -274,27 +262,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
})
}
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
Ok(match s {
TypedAssemblyStatement::Assignment(a, e) => {
vec![TypedAssemblyStatement::Assignment(
self.fold_assignee(a)?,
self.fold_expression(e)?,
)]
}
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
vec![TypedAssemblyStatement::Constraint(
self.fold_field_expression(lhs)?,
self.fold_field_expression(rhs)?,
metadata,
)]
}
})
}
fn fold_canonical_constant_identifier(
&mut self,
_: CanonicalConstantIdentifier<'ast>,
@ -307,28 +274,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
let res = match s {
TypedStatement::Definition(a, rhs) => {
// usually we transform and then propagate
// for definitions we need special treatment: we transform and propagate the rhs (which can contain function calls)
// then we reduce the rhs to remove the function calls
// only then we transform and propagate the assignee
let rhs = self.ssa.fold_definition_rhs(rhs);
let rhs = self.propagator.fold_definition_rhs(rhs)?;
let rhs = self.fold_definition_rhs(rhs)?;
let a = self.ssa.fold_assignee(a);
self.propagator
.fold_statement(TypedStatement::Definition(a, rhs))?
}
TypedStatement::For(v, from, to, statements) => {
let from = self.ssa.fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from)?;
let from = self.fold_uint_expression(from)?;
let from = self.propagator.fold_uint_expression(from)?;
let to = self.ssa.fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to)?;
let to = self.fold_uint_expression(to)?;
let to = self.propagator.fold_uint_expression(to)?;
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from
@ -345,40 +300,37 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect()),
.collect::<Vec<_>>()),
_ => Err(Error::NonConstant(format!(
"Expected loop bounds to be constant, found {}..{}",
from, to
))),
}?
}
TypedStatement::Return(e) => {
let e = self.ssa.fold_expression(e);
let e = self.propagator.fold_expression(e)?;
vec![TypedStatement::Return(self.fold_expression(e)?)]
}
TypedStatement::Assertion(e, error) => {
let e = self.ssa.fold_boolean_expression(e);
let e = self.propagator.fold_boolean_expression(e)?;
s => {
let statements = self.ssa.fold_statement(s);
vec![TypedStatement::Assertion(
self.fold_boolean_expression(e)?,
error,
)]
let statements = statements
.into_iter()
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
let statements = statements
.map(|s| fold_statement(self, s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
let statements = statements
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
statements.collect()
}
s => self
.ssa
.fold_statement(s)
.into_iter()
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.map(|s| fold_statement(self, s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
};
Ok(self.statement_buffer.drain(..).chain(res).collect())
@ -394,12 +346,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let array = self.ssa.fold_array_expression(array);
let array = self.propagator.fold_array_expression(array)?;
let array = self.fold_array_expression(array)?;
let array = self.propagator.fold_array_expression(array)?;
let from = self.ssa.fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from)?;
let from = self.fold_uint_expression(from)?;
let from = self.propagator.fold_uint_expression(from)?;
let to = self.ssa.fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to)?;
let to = self.fold_uint_expression(to)?;
let to = self.propagator.fold_uint_expression(to)?;
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(..), UExpressionInner::Value(..)) => {
@ -503,14 +460,11 @@ mod tests {
// }
// expected:
// def main(field a_0) -> field {
// a_1 = a_0;
// # PUSH CALL to foo
// a_3 := a_1; // input binding
// #RETURN_AT_INDEX_0_0 := a_3;
// # POP CALL
// a_2 = #RETURN_AT_INDEX_0_0;
// return a_2;
// def main(field a_f0_v0) -> field {
// a_f0_v1 = a_f0_v0; // redef
// a_f1_v0 = a_f0_v1; // input binding
// a_f0_v2 = a_f1_v0; // output binding
// return a_f0_v2;
// }
let foo: TypedFunction<Bn128Field> = TypedFunction {
@ -606,30 +560,13 @@ mod tests {
Variable::field_element(Identifier::from("a").version(1)).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
),
GGenericsAssignment::default(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
Variable::field_element(Identifier::from("a").in_frame(1)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
.into(),
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.into(),
FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(),
),
TypedStatement::Return(
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
@ -678,14 +615,11 @@ mod tests {
// }
// expected:
// def main(field a_0) -> field {
// field[1] b_0 = [42];
// # PUSH CALL to foo::<1>
// a_0 = b_0;
// #RETURN_AT_INDEX_0_0 := a_0;
// # POP CALL
// b_1 = #RETURN_AT_INDEX_0_0;
// return a_2 + b_1[0];
// def main(field a_f0_v0) -> field {
// field[1] b_f0_v0 = [a_f0_v0];
// a_f1_v0 = b_f0_v0;
// b_f0_v1 = a_f1_v0;
// return a_f0_v0 + b_f0_v1[0];
// }
let foo_signature = DeclarationSignature::new()
@ -812,42 +746,19 @@ mod tests {
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
),
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
Type::FieldElement,
1u32,
)
.into(),
ArrayExpression::identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.annotate(Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::identifier("a".into())
@ -902,14 +813,11 @@ mod tests {
// }
// expected:
// def main(field a_0) -> field {
// field[1] b_0 = [42];
// # PUSH CALL to foo::<1>
// a_0 = b_0;
// #RETURN_AT_INDEX_0_0 := a_0;
// # POP CALL
// b_1 = #RETURN_AT_INDEX_0_0;
// return a_2 + b_1[0];
// def main(field a) -> field {
// field[1] b = [a];
// a_f1 = b;
// b_1 = a_f1;
// return a + b_1[0];
// }
let foo_signature = DeclarationSignature::new()
@ -1040,47 +948,25 @@ mod tests {
TypedStatement::definition(
Variable::array("b", Type::FieldElement, 1u32).into(),
ArrayExpressionInner::Value(
vec![FieldElementExpression::identifier("a".into()).into()].into(),
vec![FieldElementExpression::identifier(Identifier::from("a")).into()]
.into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
),
TypedStatement::definition(
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::definition(
Variable::array(
Identifier::from(CoreIdentifier::Call(0)).version(0),
Type::FieldElement,
1u32,
)
.into(),
ArrayExpression::identifier(Identifier::from("a").version(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::PopCallLog,
TypedStatement::definition(
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(
Identifier::from(CoreIdentifier::Call(0)).version(0),
)
.annotate(Type::FieldElement, 1u32)
.into(),
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
.annotate(Type::FieldElement, 1u32)
.into(),
),
TypedStatement::Return(
(FieldElementExpression::identifier("a".into())
@ -1306,33 +1192,11 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
),
TypedStatement::PushCallLog(
DeclarationFunctionKey::with_location("main", "bar")
.signature(foo_signature.clone()),
GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 2)]
.into_iter()
.collect(),
),
),
TypedStatement::PopCallLog,
TypedStatement::PopCallLog,
TypedStatement::Return(
TupleExpressionInner::Value(vec![])
.annotate(TupleType::new(vec![]))
.into(),
),
],
statements: vec![TypedStatement::Return(
TupleExpressionInner::Value(vec![])
.annotate(TupleType::new(vec![]))
.into(),
)],
signature: DeclarationSignature::new(),
};

View file

@ -1,7 +1,6 @@
// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can
// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility.
// Function calls are also left unvisited
// Saving the indices is not required for function calls, as they cannot modify their environment
// The SSA transformation
// * introduces new versions if and only if we are assigning to an identifier
// * does not visit the statements of loops
// Example:
// def main(field a) -> field {
@ -19,21 +18,34 @@
// u32 n_0 = 42;
// a_1 = a_0 + 1;
// field b_0 = foo(a_1); // we keep the function call as is
// # versions: {n: 0, a: 1, b: 0}
// for u32 i_0 in 0..n_0 {
// <body> // we keep the loop body as is
// }
// return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop
// }
use std::collections::HashMap;
use zokrates_ast::typed::folder::*;
use zokrates_ast::typed::identifier::FrameIdentifier;
use zokrates_ast::typed::*;
use zokrates_field::Field;
use super::Versions;
// An SSA version map, giving access to the latest version number for each identifier
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Versions<'ast> {
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
}
impl<'ast> Default for Versions<'ast> {
fn default() -> Self {
// create a call frame at index 0
Self {
map: vec![(0, Default::default())].into_iter().collect(),
}
}
}
#[derive(Debug, Default)]
pub struct ShallowTransformer<'ast> {
@ -45,14 +57,16 @@ pub struct ShallowTransformer<'ast> {
impl<'ast> ShallowTransformer<'ast> {
pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
let frame_versions = self.versions.map.entry(self.frame()).or_default();
let frame = self.frame();
let frame_versions = self.versions.map.entry(frame).or_default();
let version = frame_versions
.entry(c_id.clone())
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_default(); // otherwise, we start from this version
Identifier::from(c_id).version(*version)
Identifier::from(c_id.in_frame(frame)).version(*version)
}
fn issue_next_ssa_variable<T: Field>(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> {
@ -81,43 +95,69 @@ impl<'ast> ShallowTransformer<'ast> {
self.versions.map.remove(&frame);
}
pub fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
// fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers,
// but this should not be applied recursively to complex assignees
fn fold_assignee_no_ssa_increase<T: Field>(
&mut self,
a: TypedAssignee<'ast, T>,
) -> TypedAssignee<'ast, T> {
match a {
TypedAssignee::Identifier(v) => {
let v = self.issue_next_ssa_variable(v);
TypedAssignee::Identifier(self.fold_variable(v))
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee_no_ssa_increase(a),
box self.fold_uint_expression(index),
),
TypedAssignee::Member(box s, m) => {
TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m)
}
TypedAssignee::Element(box s, index) => {
TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index)
}
a => fold_assignee(self, a),
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
fn fold_assembly_statement(
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
for g in &f.signature.generics {
let generic_parameter = match g.as_ref().unwrap() {
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
};
let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone()));
}
fold_function(self, f)
}
fn fold_parameter(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Vec<TypedAssemblyStatement<'ast, T>> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
let e = self.fold_expression(e);
let a = self.fold_assignee(a);
vec![TypedAssemblyStatement::Assignment(a, e)]
}
s => fold_assembly_statement(self, s),
p: DeclarationParameter<'ast, T>,
) -> DeclarationParameter<'ast, T> {
DeclarationParameter {
id: DeclarationVariable {
id: self.issue_next_identifier(p.id.id.id.id),
..p.id
},
..p
}
}
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
match a {
// create a new version for assignments to identifiers
TypedAssignee::Identifier(v) => {
let v = self.issue_next_ssa_variable(v);
TypedAssignee::Identifier(self.fold_variable(v))
}
// otherwise, simply replace by the current version
a => self.fold_assignee_no_ssa_increase(a),
}
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
let e = self.fold_expression(e);
let a = self.fold_assignee(a);
vec![TypedStatement::definition(a, e)]
}
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
let embed_call = self.fold_embed_call(embed_call);
let assignee = self.fold_assignee(assignee);
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
}
// only fold bounds of for loop statements
TypedStatement::For(v, from, to, stats) => {
let from = self.fold_uint_expression(from);
let to = self.fold_uint_expression(to);
@ -127,6 +167,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
}
}
// retrieve the latest version
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
let version = self
.versions
@ -137,13 +178,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
.cloned()
.unwrap_or(0);
let id = FrameIdentifier {
frame: self.frame(),
..n.id
};
let res = Identifier { version, id };
res
n.in_frame(self.frame()).version(version)
}
}
@ -156,36 +191,57 @@ mod tests {
use super::*;
#[test]
fn detect_non_constant_bound() {
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
Variable::new("i", Type::Uint(UBitwidth::B32), false),
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
vec![],
)];
fn ignore_loop_content() {
// field foo = 0
// u32 i = 4;
// for u32 i in i..2 {
// foo = 5;
// }
let statements = loops;
// should be left unchanged, as we do not visit the loop content nor the index variable
let f = TypedFunction {
arguments: vec![],
signature: DeclarationSignature::new(),
statements,
statements: vec![
TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))),
FieldElementExpression::Number(Bn128Field::from(4)).into(),
),
TypedStatement::definition(
TypedAssignee::Identifier(Variable::uint(
Identifier::from("i"),
UBitwidth::B32,
)),
UExpression::from(0u32).into(),
),
TypedStatement::For(
Variable::new("i", Type::Uint(UBitwidth::B32), false),
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
vec![TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element(Identifier::from(
"foo",
))),
FieldElementExpression::Number(Bn128Field::from(5)).into(),
)],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![])
.annotate(TupleType::new(vec![]))
.into(),
),
],
signature: DeclarationSignature::default(),
};
match ShallowTransformer::transform(
f,
&ConcreteGenericsAssignment::default(),
&mut Versions::default(),
) {
Output::Incomplete(..) => {}
_ => unreachable!(),
};
let mut ssa = ShallowTransformer::default();
assert_eq!(ssa.fold_function(f.clone()), f);
}
#[test]
fn definition() {
// field a
// a = 5
// field a = 5
// a = 6
// a
@ -194,9 +250,7 @@ mod tests {
// a_1 = 6
// a_1
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
@ -236,17 +290,14 @@ mod tests {
#[test]
fn incremental_definition() {
// field a
// a = 5
// field a = 5
// a = a + 1
// should be turned into
// a_0 = 5
// a_1 = a_0 + 1
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
@ -295,9 +346,7 @@ mod tests {
// a_0 = 2
// a_1 = foo(a_0)
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::field_element("a")),
@ -356,9 +405,7 @@ mod tests {
// a_0 = [1, 1]
// a_0[1] = 2
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let s = TypedStatement::definition(
TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
@ -413,9 +460,7 @@ mod tests {
// a_0 = [[0, 1], [2, 3]]
// a_0 = [4, 5]
let mut versions = Versions::new();
let mut u = ShallowTransformer::with_versions(&mut versions);
let mut u = ShallowTransformer::default();
let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32));
@ -510,10 +555,10 @@ mod tests {
mod for_loop {
use super::*;
use zokrates_ast::typed::types::GGenericsAssignment;
#[test]
fn treat_loop() {
// def main<K>(field a) -> field {
// def main(field a) -> field {
// u32 n = 42;
// n = n;
// a = a;
@ -528,24 +573,21 @@ mod tests {
// return a;
// }
// When called with K := 1, expected:
// expected:
// def main(field a_0) -> field {
// u32 K = 1;
// u32 n_0 = 42;
// n_1 = n_0;
// a_1 = a_0;
// # versions: {n: 1, a: 1, K: 0}
// for u32 i_0 in n_1..n_1*n_1 {
// a_0 = a_0;
// }
// a_2 = a_1;
// for u32 i_0 in n_1..n_1*n_1 {
// a_0 = a_0;
// }
// a_3 = a_2;
// # versions: {n: 2, a: 3, K: 1}
// for u32 i_0 in n_2..n_2*n_2 {
// a_0 = a_0;
// }
// a_5 = a_4;
// return a_5;
// } # versions: {n: 3, a: 5, K: 2}
// return a_3;
// }
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
@ -595,32 +637,15 @@ mod tests {
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
],
signature: DeclarationSignature::new()
.generics(vec![Some(
GenericIdentifier::with_name("K").with_index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let mut versions = Versions::default();
let ssa = ShallowTransformer::transform(
f,
&GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
&mut versions,
);
let mut ssa = ShallowTransformer::default();
let expected = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::definition(
Variable::uint("K", UBitwidth::B32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
@ -649,16 +674,16 @@ mod tests {
)],
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
Variable::field_element(Identifier::from("a").version(2)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpression::identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
UExpression::identifier(Identifier::from("n").version(2))
UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32)
* UExpression::identifier(Identifier::from("n").version(2))
* UExpression::identifier(Identifier::from("n").version(1))
.annotate(UBitwidth::B32),
vec![TypedStatement::definition(
Variable::field_element("a").into(),
@ -666,47 +691,35 @@ mod tests {
)],
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a").version(5)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(4)).into(),
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
),
TypedStatement::Return(
FieldElementExpression::identifier(Identifier::from("a").version(5)).into(),
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![Some(
GenericIdentifier::with_name("K").with_index(0).into(),
)])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let res = ssa.fold_function(f);
assert_eq!(
versions,
vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)]
.into_iter()
.collect::<Versions>()
ssa.versions.map,
vec![(
0,
vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect()
)]
.into_iter()
.collect()
);
let expected = Output::Incomplete(
expected,
vec![
vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)]
.into_iter()
.collect::<Versions>(),
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)]
.into_iter()
.collect::<Versions>(),
],
);
assert_eq!(ssa, expected);
assert_eq!(res, expected);
}
}
mod shadowing {
use zokrates_ast::typed::types::GGenericsAssignment;
use super::*;
#[test]
@ -717,11 +730,11 @@ mod tests {
// return;
// }
// should become
// should become (only the field variable is affected as shadowing is taken care of in semantics already)
// def main(field a_0) {
// field a_1 = 42;
// bool a_2 = true;
// def main(field a_s0_v0) {
// field a_s0_v1 = 42;
// bool a_s1_v0 = true
// return;
// }
@ -733,7 +746,11 @@ mod tests {
TypedExpression::Uint(42u32.into()),
),
TypedStatement::definition(
Variable::boolean("a").into(),
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
"a".into(),
1,
)))
.into(),
BooleanExpression::Value(true).into(),
),
TypedStatement::Return(
@ -742,9 +759,7 @@ mod tests {
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement]),
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
};
let expected: TypedFunction<Bn128Field> = TypedFunction {
@ -755,7 +770,11 @@ mod tests {
TypedExpression::Uint(42u32.into()),
),
TypedStatement::definition(
Variable::boolean(Identifier::from("a").version(2)).into(),
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
"a".into(),
1,
)))
.into(),
BooleanExpression::Value(true).into(),
),
TypedStatement::Return(
@ -764,121 +783,17 @@ mod tests {
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement]),
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
};
let mut versions = Versions::default();
let ssa = ShallowTransformer::default().fold_function(f);
let ssa =
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
assert_eq!(ssa, Output::Complete(expected));
}
#[test]
fn next_scope() {
// def main(field a) {
// for u32 i in 0..1 {
// a = a + 1
// field a = 42
// }
// return a
// }
// should become
// def main(field a_0) {
// # versions: {a: 0}
// for u32 i in 0..1 {
// a_0 = a_0
// field a_0 = 42
// }
// return a_1
// }
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
0u32.into(),
1u32.into(),
vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::identifier("a".into()).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::Number(42usize.into()).into(),
),
],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
"a".into(),
)
.into()])
.annotate(TupleType::new(vec![Type::FieldElement]))
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let expected: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
0u32.into(),
1u32.into(),
vec![
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::identifier(Identifier::from("a")).into(),
),
TypedStatement::definition(
Variable::field_element(Identifier::from("a")).into(),
FieldElementExpression::Number(42usize.into()).into(),
),
],
),
TypedStatement::Return(
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
Identifier::from("a").version(1),
)
.into()])
.annotate(TupleType::new(vec![Type::FieldElement]))
.into(),
),
],
signature: DeclarationSignature::new()
.generics(vec![])
.inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::FieldElement),
};
let mut versions = Versions::default();
let ssa =
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
assert_eq!(
ssa,
Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()])
);
assert_eq!(ssa, expected);
}
}
mod function_call {
use super::*;
use zokrates_ast::typed::types::GGenericsAssignment;
// test that function calls are left in
#[test]
fn treat_calls() {
@ -892,17 +807,12 @@ mod tests {
// return a;
// }
// When called with K := 1, expected:
// def main(field a_0) -> field {
// K = 1;
// u32 n_0 = 42;
// n_1 = n_0;
// a_1 = a_0;
// a_2 = foo::<n_1>(a_1);
// n_2 = n_1;
// a_3 = a_2 * foo::<n_2>(a_2);
// a_2 = foo::<42>(a_1);
// a_3 = a_2 * foo::<42>(a_2);
// return a_3;
// } # versions: {n: 2, a: 3}
// }
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
@ -960,25 +870,9 @@ mod tests {
.output(DeclarationType::FieldElement),
};
let mut versions = Versions::default();
let ssa = ShallowTransformer::transform(
f,
&GGenericsAssignment(
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
.into_iter()
.collect(),
),
&mut versions,
);
let expected = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
statements: vec![
TypedStatement::definition(
Variable::uint("K", UBitwidth::B32).into(),
TypedExpression::Uint(1u32.into()),
),
TypedStatement::definition(
Variable::uint("n", UBitwidth::B32).into(),
TypedExpression::Uint(42u32.into()),
@ -1042,14 +936,23 @@ mod tests {
.output(DeclarationType::FieldElement),
};
let mut ssa = ShallowTransformer::default();
let res = ssa.fold_function(f);
assert_eq!(
versions,
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
.into_iter()
.collect::<Versions>()
ssa.versions.map,
vec![(
0,
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
.into_iter()
.collect()
)]
.into_iter()
.collect()
);
assert_eq!(ssa, Output::Incomplete(expected, vec![],));
assert_eq!(res, expected);
}
}
}

View file

@ -531,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
) -> Vec<TypedAssemblyStatement<'ast, T>> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
vec![TypedAssemblyStatement::Assignment(
f.fold_assignee(a),
f.fold_expression(e),
)]
let e = f.fold_expression(e);
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)]
}
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
vec![TypedAssemblyStatement::Constraint(
@ -552,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
) -> Vec<TypedStatement<'ast, T>> {
let res = match s {
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e))
TypedStatement::Definition(a, rhs) => {
let rhs = f.fold_definition_rhs(rhs);
TypedStatement::Definition(f.fold_assignee(a), rhs)
}
TypedStatement::Assertion(e, error) => {
TypedStatement::Assertion(f.fold_boolean_expression(e), error)
@ -576,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
.flat_map(|s| f.fold_assembly_statement(s))
.collect(),
),
s => s,
};
vec![res]
}

View file

@ -24,6 +24,21 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
}
}
impl<'ast> FrameIdentifier<'ast> {
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
FrameIdentifier { frame, ..self }
}
}
impl<'ast> Identifier<'ast> {
pub fn in_frame(self, frame: usize) -> Identifier<'ast> {
Identifier {
id: self.id.in_frame(frame),
..self
}
}
}
impl<'ast> CoreIdentifier<'ast> {
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
FrameIdentifier { id: self, frame }

View file

@ -27,7 +27,7 @@ pub use self::types::{
UBitwidth,
};
use self::types::{ConcreteArrayType, ConcreteStructType};
use crate::typed::types::{ConcreteGenericsAssignment, IntoType};
use crate::typed::types::IntoType;
pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable};
use std::marker::PhantomData;
@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
writeln!(f)?;
let mut tab = 0;
for s in &self.statements {
if let TypedStatement::PopCallLog = s {
tab -= 1;
};
s.fmt_indented(f, 1 + tab)?;
writeln!(f)?;
if let TypedStatement::PushCallLog(..) = s {
tab += 1;
};
writeln!(f, "{}", s)?;
}
writeln!(f, "}}")?;
@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> {
Vec<TypedStatement<'ast, T>>,
),
Log(FormatString, Vec<TypedExpression<'ast, T>>),
// Aux
PushCallLog(
DeclarationFunctionKey<'ast, T>,
ConcreteGenericsAssignment<'ast>,
),
PopCallLog,
Assembly(Vec<TypedAssemblyStatement<'ast, T>>),
}
@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> {
}
}
impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
match self {
TypedStatement::For(variable, from, to, statements) => {
write!(f, "{}", "\t".repeat(depth))?;
writeln!(f, "for {} in {}..{} {{", variable, from, to)?;
for s in statements {
s.fmt_indented(f, depth + 1)?;
writeln!(f)?;
}
write!(f, "{}}}", "\t".repeat(depth))
}
TypedStatement::Assembly(statements) => {
write!(f, "{}", "\t".repeat(depth))?;
writeln!(f, "asm {{")?;
for s in statements {
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
}
write!(f, "{}}}", "\t".repeat(depth))
}
s => write!(f, "{}{}", "\t".repeat(depth), s),
}
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
.collect::<Vec<_>>()
.join(", ")
),
TypedStatement::PushCallLog(ref key, ref generics) => write!(
f,
"// PUSH CALL TO {}/{}::<{}>",
key.module.display(),
key.id,
generics,
),
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
TypedStatement::Assembly(ref statements) => {
writeln!(f, "asm {{")?;
for s in statements {

View file

@ -532,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, F::Error> {
Ok(match s {
TypedAssemblyStatement::Assignment(a, e) => {
vec![TypedAssemblyStatement::Assignment(
f.fold_assignee(a)?,
f.fold_expression(e)?,
)]
let e = f.fold_expression(e)?;
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)]
}
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
vec![TypedAssemblyStatement::Constraint(
@ -554,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
let res = match s {
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?),
TypedStatement::Definition(a, e) => {
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?)
let rhs = f.fold_definition_rhs(e)?;
TypedStatement::Definition(f.fold_assignee(a)?, rhs)
}
TypedStatement::Assertion(e, error) => {
TypedStatement::Assertion(f.fold_boolean_expression(e)?, error)
@ -586,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
.flatten()
.collect(),
),
s => s,
};
Ok(vec![res])
}

View file

@ -240,9 +240,9 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
fn from(c: DeclarationConstant<'ast, T>) -> Self {
match c {
DeclarationConstant::Generic(_g) => {
// UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32)
unreachable!()
DeclarationConstant::Generic(g) => {
UExpression::identifier(Identifier::from(CoreIdentifier::from(g)))
.annotate(UBitwidth::B32)
}
DeclarationConstant::Concrete(v) => {
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)

View file

@ -0,0 +1,16 @@
{
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"value": "4"
}
}
}
]
}

View file

@ -0,0 +1,11 @@
// main should be x -> x + 4
def foo(field mut a) -> field {
a = a + 1;
return a + 1;
}
def main(field mut a) -> field {
a = foo(a + 1);
return a + 1;
}