make all tests pass, clean
This commit is contained in:
parent
1bb524f6a2
commit
30801c86fa
13 changed files with 442 additions and 707 deletions
|
@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>(
|
|||
})
|
||||
.collect(),
|
||||
)],
|
||||
typed::TypedStatement::PushCallLog(..) => vec![],
|
||||
typed::TypedStatement::PopCallLog => vec![],
|
||||
typed::TypedStatement::For(..) => unreachable!(),
|
||||
};
|
||||
|
||||
|
|
|
@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>(
|
|||
let r = reduce_program(r).map_err(Error::from)?;
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Propagate");
|
||||
let r = Propagator::propagate(r)?;
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Concretize structs");
|
||||
let r = StructConcretizer::concretize(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
|
|
@ -308,21 +308,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
};
|
||||
|
||||
// particular case of `lhs = rhs`
|
||||
if TypedExpression::from(assignee.clone()) == expr {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
if expr.is_constant() {
|
||||
match assignee {
|
||||
TypedAssignee::Identifier(var) => {
|
||||
let expr = expr.into_canonical_constant();
|
||||
|
||||
assert!(
|
||||
self.constants.insert(var.clone().id, expr).is_none(),
|
||||
"{}",
|
||||
var
|
||||
);
|
||||
assert!(self.constants.insert(var.id, expr).is_none());
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
|
@ -629,8 +620,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
|
|||
_ => Ok(vec![TypedStatement::Assertion(expr, err)]),
|
||||
}
|
||||
}
|
||||
s @ TypedStatement::PushCallLog(..) => Ok(vec![s]),
|
||||
s @ TypedStatement::PopCallLog => Ok(vec![s]),
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
@ -1502,7 +1491,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(5)))
|
||||
);
|
||||
}
|
||||
|
@ -1515,7 +1504,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
}
|
||||
|
@ -1528,7 +1517,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(6)))
|
||||
);
|
||||
}
|
||||
|
@ -1541,7 +1530,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
|
||||
);
|
||||
}
|
||||
|
@ -1554,15 +1543,14 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn left_shift() {
|
||||
let mut constants = Constants::new();
|
||||
let mut propagator = Propagator::with_constants(&mut constants);
|
||||
let mut propagator = Propagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
|
@ -1607,8 +1595,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn right_shift() {
|
||||
let mut constants = Constants::new();
|
||||
let mut propagator = Propagator::with_constants(&mut constants);
|
||||
let mut propagator = Propagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
|
@ -1676,7 +1663,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
|
||||
);
|
||||
}
|
||||
|
@ -1691,7 +1678,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
|
||||
);
|
||||
}
|
||||
|
@ -1713,7 +1700,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
|
||||
);
|
||||
}
|
||||
|
@ -1735,18 +1722,15 @@ mod tests {
|
|||
BooleanExpression::Not(box BooleanExpression::identifier("a".into()));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_default.clone()),
|
||||
Propagator::default().fold_boolean_expression(e_default.clone()),
|
||||
Ok(e_default)
|
||||
);
|
||||
}
|
||||
|
@ -1776,23 +1760,19 @@ mod tests {
|
|||
));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_true),
|
||||
Propagator::default().fold_boolean_expression(e_constant_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_false),
|
||||
Propagator::default().fold_boolean_expression(e_constant_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_true),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Ok(e_identifier_unchanged)
|
||||
);
|
||||
}
|
||||
|
@ -1800,38 +1780,42 @@ mod tests {
|
|||
#[test]
|
||||
fn bool_eq() {
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(false)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(true)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(false)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(true)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -1933,33 +1917,27 @@ mod tests {
|
|||
));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_true),
|
||||
Propagator::default().fold_boolean_expression(e_constant_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_false),
|
||||
Propagator::default().fold_boolean_expression(e_constant_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_true),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Ok(e_identifier_unchanged)
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_non_canonical_true),
|
||||
Propagator::default().fold_boolean_expression(e_non_canonical_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_non_canonical_false),
|
||||
Propagator::default().fold_boolean_expression(e_non_canonical_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -1977,13 +1955,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2001,13 +1977,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2025,13 +1999,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2049,13 +2021,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2065,67 +2035,75 @@ mod tests {
|
|||
let a_bool: Identifier = "a".into();
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2135,67 +2113,75 @@ mod tests {
|
|||
let a_bool: Identifier = "a".into();
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -135,7 +135,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
}
|
||||
};
|
||||
|
||||
let decl = get_canonical_function(&k, program);
|
||||
let decl = get_canonical_function(k, program);
|
||||
|
||||
// get an assignment of generics for this call site
|
||||
let assignment: ConcreteGenericsAssignment<'ast> = k
|
||||
|
@ -190,7 +190,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
.into_iter()
|
||||
.zip(inferred_signature.inputs.clone())
|
||||
.map(|(p, t)| ConcreteVariable::new(p.id.id, t, false))
|
||||
.map(|v| Variable::from(v))
|
||||
.map(Variable::from)
|
||||
.collect();
|
||||
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = f
|
||||
|
|
|
@ -3,13 +3,14 @@
|
|||
// - free of function calls (except for low level calls) thanks to inlining
|
||||
// - free of for-loops thanks to unrolling
|
||||
|
||||
// The process happens in two steps
|
||||
// 1. Shallow SSA for the `main` function
|
||||
// We turn the `main` function into SSA form, but ignoring function calls and for loops
|
||||
// 2. Unroll and inline
|
||||
// We go through the shallow-SSA program and
|
||||
// - unroll loops
|
||||
// - inline function calls. This includes applying shallow-ssa on the target function
|
||||
// The process happens in a greedy way, starting from the main function
|
||||
// For each statement:
|
||||
// * put it in ssa form
|
||||
// * propagate it
|
||||
// * inline it (calling this process recursively)
|
||||
// * propagate again
|
||||
|
||||
// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach
|
||||
|
||||
mod constants_reader;
|
||||
mod constants_writer;
|
||||
|
@ -21,7 +22,6 @@ use std::collections::HashMap;
|
|||
use zokrates_ast::typed::result_folder::*;
|
||||
use zokrates_ast::typed::DeclarationParameter;
|
||||
use zokrates_ast::typed::Folder;
|
||||
use zokrates_ast::typed::TypedAssemblyStatement;
|
||||
use zokrates_ast::typed::TypedAssignee;
|
||||
use zokrates_ast::typed::{
|
||||
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
|
||||
|
@ -47,23 +47,6 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
|
|||
pub type ConstantDefinitions<'ast, T> =
|
||||
HashMap<CanonicalConstantIdentifier<'ast>, TypedExpression<'ast, T>>;
|
||||
|
||||
// An SSA version map, giving access to the latest version number for each identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct Versions<'ast> {
|
||||
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
|
||||
}
|
||||
|
||||
impl<'ast> Versions<'ast> {
|
||||
fn insert_in_frame(
|
||||
&mut self,
|
||||
id: CoreIdentifier<'ast>,
|
||||
version: usize,
|
||||
frame: usize,
|
||||
) -> Option<usize> {
|
||||
self.map.entry(frame).or_default().insert(id, version)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
Incompatible(String),
|
||||
|
@ -125,10 +108,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
&mut self,
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
|
||||
// this is only used on the entry point
|
||||
let id = p.id.id.id.id.clone();
|
||||
assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none());
|
||||
Ok(p)
|
||||
Ok(self.ssa.fold_parameter(p))
|
||||
}
|
||||
|
||||
fn fold_function_call_expression<
|
||||
|
@ -138,34 +118,42 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
ty: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
|
||||
// generics are already in ssa form
|
||||
|
||||
let generics = e
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| {
|
||||
g.map(|g| {
|
||||
let g = self.ssa.fold_uint_expression(g);
|
||||
let g = self.propagator.fold_uint_expression(g)?;
|
||||
let g = self.fold_uint_expression(g)?;
|
||||
|
||||
self.fold_uint_expression(g)
|
||||
self.propagator
|
||||
.fold_uint_expression(g)
|
||||
.map_err(Self::Error::from)
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
// arguments are already in ssa form
|
||||
|
||||
let arguments = e
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let e = self.ssa.fold_expression(e);
|
||||
let e = self.propagator.fold_expression(e)?;
|
||||
let e = self.fold_expression(e)?;
|
||||
|
||||
self.fold_expression(e)
|
||||
self.propagator
|
||||
.fold_expression(e)
|
||||
.map_err(Self::Error::from)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
self.ssa.push_call_frame();
|
||||
|
||||
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program);
|
||||
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program);
|
||||
|
||||
let res = match res {
|
||||
Ok((input_variables, arguments, generics_bindings, statements, expression)) => {
|
||||
|
@ -183,7 +171,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
|
||||
self.statement_buffer.extend(generics_bindings);
|
||||
|
||||
// the lhs is from the inner call frame, the rhs is from the outer one, so only fld the lhs
|
||||
// the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs
|
||||
let input_bindings: Vec<_> = input_variables
|
||||
.into_iter()
|
||||
.zip(arguments)
|
||||
|
@ -274,27 +262,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
Ok(match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
self.fold_assignee(a)?,
|
||||
self.fold_expression(e)?,
|
||||
)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
self.fold_field_expression(lhs)?,
|
||||
self.fold_field_expression(rhs)?,
|
||||
metadata,
|
||||
)]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_canonical_constant_identifier(
|
||||
&mut self,
|
||||
_: CanonicalConstantIdentifier<'ast>,
|
||||
|
@ -307,28 +274,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
s: TypedStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
let res = match s {
|
||||
TypedStatement::Definition(a, rhs) => {
|
||||
// usually we transform and then propagate
|
||||
// for definitions we need special treatment: we transform and propagate the rhs (which can contain function calls)
|
||||
// then we reduce the rhs to remove the function calls
|
||||
// only then we transform and propagate the assignee
|
||||
|
||||
let rhs = self.ssa.fold_definition_rhs(rhs);
|
||||
let rhs = self.propagator.fold_definition_rhs(rhs)?;
|
||||
let rhs = self.fold_definition_rhs(rhs)?;
|
||||
|
||||
let a = self.ssa.fold_assignee(a);
|
||||
|
||||
self.propagator
|
||||
.fold_statement(TypedStatement::Definition(a, rhs))?
|
||||
}
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let from = self.ssa.fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
|
||||
let to = self.ssa.fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from
|
||||
|
@ -345,40 +300,37 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect()),
|
||||
.collect::<Vec<_>>()),
|
||||
_ => Err(Error::NonConstant(format!(
|
||||
"Expected loop bounds to be constant, found {}..{}",
|
||||
from, to
|
||||
))),
|
||||
}?
|
||||
}
|
||||
TypedStatement::Return(e) => {
|
||||
let e = self.ssa.fold_expression(e);
|
||||
let e = self.propagator.fold_expression(e)?;
|
||||
vec![TypedStatement::Return(self.fold_expression(e)?)]
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
let e = self.ssa.fold_boolean_expression(e);
|
||||
let e = self.propagator.fold_boolean_expression(e)?;
|
||||
s => {
|
||||
let statements = self.ssa.fold_statement(s);
|
||||
|
||||
vec![TypedStatement::Assertion(
|
||||
self.fold_boolean_expression(e)?,
|
||||
error,
|
||||
)]
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
let statements = statements
|
||||
.map(|s| fold_statement(self, s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
let statements = statements
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
statements.collect()
|
||||
}
|
||||
s => self
|
||||
.ssa
|
||||
.fold_statement(s)
|
||||
.into_iter()
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|s| fold_statement(self, s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
Ok(self.statement_buffer.drain(..).chain(res).collect())
|
||||
|
@ -394,12 +346,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
let array = self.ssa.fold_array_expression(array);
|
||||
let array = self.propagator.fold_array_expression(array)?;
|
||||
let array = self.fold_array_expression(array)?;
|
||||
let array = self.propagator.fold_array_expression(array)?;
|
||||
|
||||
let from = self.ssa.fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
|
||||
let to = self.ssa.fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(..), UExpressionInner::Value(..)) => {
|
||||
|
@ -503,14 +460,11 @@ mod tests {
|
|||
// }
|
||||
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// a_1 = a_0;
|
||||
// # PUSH CALL to foo
|
||||
// a_3 := a_1; // input binding
|
||||
// #RETURN_AT_INDEX_0_0 := a_3;
|
||||
// # POP CALL
|
||||
// a_2 = #RETURN_AT_INDEX_0_0;
|
||||
// return a_2;
|
||||
// def main(field a_f0_v0) -> field {
|
||||
// a_f0_v1 = a_f0_v0; // redef
|
||||
// a_f1_v0 = a_f0_v1; // input binding
|
||||
// a_f0_v2 = a_f1_v0; // output binding
|
||||
// return a_f0_v2;
|
||||
// }
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
|
@ -606,30 +560,13 @@ mod tests {
|
|||
Variable::field_element(Identifier::from("a").version(1)).into(),
|
||||
FieldElementExpression::identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
),
|
||||
GGenericsAssignment::default(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
Variable::field_element(Identifier::from("a").in_frame(1)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
|
||||
.into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
|
||||
|
@ -678,14 +615,11 @@ mod tests {
|
|||
// }
|
||||
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// field[1] b_0 = [42];
|
||||
// # PUSH CALL to foo::<1>
|
||||
// a_0 = b_0;
|
||||
// #RETURN_AT_INDEX_0_0 := a_0;
|
||||
// # POP CALL
|
||||
// b_1 = #RETURN_AT_INDEX_0_0;
|
||||
// return a_2 + b_1[0];
|
||||
// def main(field a_f0_v0) -> field {
|
||||
// field[1] b_f0_v0 = [a_f0_v0];
|
||||
// a_f1_v0 = b_f0_v0;
|
||||
// b_f0_v1 = a_f1_v0;
|
||||
// return a_f0_v0 + b_f0_v1[0];
|
||||
// }
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
|
@ -812,42 +746,19 @@ mod tests {
|
|||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
Type::FieldElement,
|
||||
1u32,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
(FieldElementExpression::identifier("a".into())
|
||||
|
@ -902,14 +813,11 @@ mod tests {
|
|||
// }
|
||||
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// field[1] b_0 = [42];
|
||||
// # PUSH CALL to foo::<1>
|
||||
// a_0 = b_0;
|
||||
// #RETURN_AT_INDEX_0_0 := a_0;
|
||||
// # POP CALL
|
||||
// b_1 = #RETURN_AT_INDEX_0_0;
|
||||
// return a_2 + b_1[0];
|
||||
// def main(field a) -> field {
|
||||
// field[1] b = [a];
|
||||
// a_f1 = b;
|
||||
// b_1 = a_f1;
|
||||
// return a + b_1[0];
|
||||
// }
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
|
@ -1040,47 +948,25 @@ mod tests {
|
|||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![FieldElementExpression::identifier("a".into()).into()].into(),
|
||||
vec![FieldElementExpression::identifier(Identifier::from("a")).into()]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
Type::FieldElement,
|
||||
1u32,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
(FieldElementExpression::identifier("a".into())
|
||||
|
@ -1306,33 +1192,11 @@ mod tests {
|
|||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "bar")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 2)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![])
|
||||
.annotate(TupleType::new(vec![]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
statements: vec![TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![])
|
||||
.annotate(TupleType::new(vec![]))
|
||||
.into(),
|
||||
)],
|
||||
signature: DeclarationSignature::new(),
|
||||
};
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can
|
||||
// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility.
|
||||
// Function calls are also left unvisited
|
||||
// Saving the indices is not required for function calls, as they cannot modify their environment
|
||||
// The SSA transformation
|
||||
// * introduces new versions if and only if we are assigning to an identifier
|
||||
// * does not visit the statements of loops
|
||||
|
||||
// Example:
|
||||
// def main(field a) -> field {
|
||||
|
@ -19,21 +18,34 @@
|
|||
// u32 n_0 = 42;
|
||||
// a_1 = a_0 + 1;
|
||||
// field b_0 = foo(a_1); // we keep the function call as is
|
||||
// # versions: {n: 0, a: 1, b: 0}
|
||||
// for u32 i_0 in 0..n_0 {
|
||||
// <body> // we keep the loop body as is
|
||||
// }
|
||||
// return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop
|
||||
// }
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use zokrates_ast::typed::folder::*;
|
||||
use zokrates_ast::typed::identifier::FrameIdentifier;
|
||||
|
||||
use zokrates_ast::typed::*;
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::Versions;
|
||||
// An SSA version map, giving access to the latest version number for each identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Versions<'ast> {
|
||||
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
|
||||
}
|
||||
|
||||
impl<'ast> Default for Versions<'ast> {
|
||||
fn default() -> Self {
|
||||
// create a call frame at index 0
|
||||
Self {
|
||||
map: vec![(0, Default::default())].into_iter().collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ShallowTransformer<'ast> {
|
||||
|
@ -45,14 +57,16 @@ pub struct ShallowTransformer<'ast> {
|
|||
|
||||
impl<'ast> ShallowTransformer<'ast> {
|
||||
pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
let frame_versions = self.versions.map.entry(self.frame()).or_default();
|
||||
let frame = self.frame();
|
||||
|
||||
let frame_versions = self.versions.map.entry(frame).or_default();
|
||||
|
||||
let version = frame_versions
|
||||
.entry(c_id.clone())
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_default(); // otherwise, we start from this version
|
||||
|
||||
Identifier::from(c_id).version(*version)
|
||||
Identifier::from(c_id.in_frame(frame)).version(*version)
|
||||
}
|
||||
|
||||
fn issue_next_ssa_variable<T: Field>(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> {
|
||||
|
@ -81,43 +95,69 @@ impl<'ast> ShallowTransformer<'ast> {
|
|||
self.versions.map.remove(&frame);
|
||||
}
|
||||
|
||||
pub fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
// fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers,
|
||||
// but this should not be applied recursively to complex assignees
|
||||
fn fold_assignee_no_ssa_increase<T: Field>(
|
||||
&mut self,
|
||||
a: TypedAssignee<'ast, T>,
|
||||
) -> TypedAssignee<'ast, T> {
|
||||
match a {
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
TypedAssignee::Identifier(self.fold_variable(v))
|
||||
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
|
||||
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
|
||||
box self.fold_assignee_no_ssa_increase(a),
|
||||
box self.fold_uint_expression(index),
|
||||
),
|
||||
TypedAssignee::Member(box s, m) => {
|
||||
TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m)
|
||||
}
|
||||
TypedAssignee::Element(box s, index) => {
|
||||
TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index)
|
||||
}
|
||||
a => fold_assignee(self, a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
|
||||
fn fold_assembly_statement(
|
||||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
for g in &f.signature.generics {
|
||||
let generic_parameter = match g.as_ref().unwrap() {
|
||||
DeclarationConstant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone()));
|
||||
}
|
||||
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
let e = self.fold_expression(e);
|
||||
let a = self.fold_assignee(a);
|
||||
vec![TypedAssemblyStatement::Assignment(a, e)]
|
||||
}
|
||||
s => fold_assembly_statement(self, s),
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> DeclarationParameter<'ast, T> {
|
||||
DeclarationParameter {
|
||||
id: DeclarationVariable {
|
||||
id: self.issue_next_identifier(p.id.id.id.id),
|
||||
..p.id
|
||||
},
|
||||
..p
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
match a {
|
||||
// create a new version for assignments to identifiers
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
TypedAssignee::Identifier(self.fold_variable(v))
|
||||
}
|
||||
// otherwise, simply replace by the current version
|
||||
a => self.fold_assignee_no_ssa_increase(a),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
|
||||
let e = self.fold_expression(e);
|
||||
let a = self.fold_assignee(a);
|
||||
vec![TypedStatement::definition(a, e)]
|
||||
}
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
|
||||
let embed_call = self.fold_embed_call(embed_call);
|
||||
let assignee = self.fold_assignee(assignee);
|
||||
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
|
||||
}
|
||||
// only fold bounds of for loop statements
|
||||
TypedStatement::For(v, from, to, stats) => {
|
||||
let from = self.fold_uint_expression(from);
|
||||
let to = self.fold_uint_expression(to);
|
||||
|
@ -127,6 +167,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
// retrieve the latest version
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
let version = self
|
||||
.versions
|
||||
|
@ -137,13 +178,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
|
|||
.cloned()
|
||||
.unwrap_or(0);
|
||||
|
||||
let id = FrameIdentifier {
|
||||
frame: self.frame(),
|
||||
..n.id
|
||||
};
|
||||
|
||||
let res = Identifier { version, id };
|
||||
res
|
||||
n.in_frame(self.frame()).version(version)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,36 +191,57 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn detect_non_constant_bound() {
|
||||
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
|
||||
Variable::new("i", Type::Uint(UBitwidth::B32), false),
|
||||
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
|
||||
2u32.into(),
|
||||
vec![],
|
||||
)];
|
||||
fn ignore_loop_content() {
|
||||
// field foo = 0
|
||||
// u32 i = 4;
|
||||
// for u32 i in i..2 {
|
||||
// foo = 5;
|
||||
// }
|
||||
|
||||
let statements = loops;
|
||||
// should be left unchanged, as we do not visit the loop content nor the index variable
|
||||
|
||||
let f = TypedFunction {
|
||||
arguments: vec![],
|
||||
signature: DeclarationSignature::new(),
|
||||
statements,
|
||||
statements: vec![
|
||||
TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))),
|
||||
FieldElementExpression::Number(Bn128Field::from(4)).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::uint(
|
||||
Identifier::from("i"),
|
||||
UBitwidth::B32,
|
||||
)),
|
||||
UExpression::from(0u32).into(),
|
||||
),
|
||||
TypedStatement::For(
|
||||
Variable::new("i", Type::Uint(UBitwidth::B32), false),
|
||||
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
|
||||
2u32.into(),
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(Identifier::from(
|
||||
"foo",
|
||||
))),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![])
|
||||
.annotate(TupleType::new(vec![]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::default(),
|
||||
};
|
||||
|
||||
match ShallowTransformer::transform(
|
||||
f,
|
||||
&ConcreteGenericsAssignment::default(),
|
||||
&mut Versions::default(),
|
||||
) {
|
||||
Output::Incomplete(..) => {}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut ssa = ShallowTransformer::default();
|
||||
|
||||
assert_eq!(ssa.fold_function(f.clone()), f);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn definition() {
|
||||
// field a
|
||||
// a = 5
|
||||
// field a = 5
|
||||
// a = 6
|
||||
// a
|
||||
|
||||
|
@ -194,9 +250,7 @@ mod tests {
|
|||
// a_1 = 6
|
||||
// a_1
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
|
@ -236,17 +290,14 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn incremental_definition() {
|
||||
// field a
|
||||
// a = 5
|
||||
// field a = 5
|
||||
// a = a + 1
|
||||
|
||||
// should be turned into
|
||||
// a_0 = 5
|
||||
// a_1 = a_0 + 1
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
|
@ -295,9 +346,7 @@ mod tests {
|
|||
// a_0 = 2
|
||||
// a_1 = foo(a_0)
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
|
@ -356,9 +405,7 @@ mod tests {
|
|||
// a_0 = [1, 1]
|
||||
// a_0[1] = 2
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
|
||||
|
@ -413,9 +460,7 @@ mod tests {
|
|||
// a_0 = [[0, 1], [2, 3]]
|
||||
// a_0 = [4, 5]
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32));
|
||||
|
||||
|
@ -510,10 +555,10 @@ mod tests {
|
|||
|
||||
mod for_loop {
|
||||
use super::*;
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
|
||||
#[test]
|
||||
fn treat_loop() {
|
||||
// def main<K>(field a) -> field {
|
||||
// def main(field a) -> field {
|
||||
// u32 n = 42;
|
||||
// n = n;
|
||||
// a = a;
|
||||
|
@ -528,24 +573,21 @@ mod tests {
|
|||
// return a;
|
||||
// }
|
||||
|
||||
// When called with K := 1, expected:
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// u32 K = 1;
|
||||
// u32 n_0 = 42;
|
||||
// n_1 = n_0;
|
||||
// a_1 = a_0;
|
||||
// # versions: {n: 1, a: 1, K: 0}
|
||||
// for u32 i_0 in n_1..n_1*n_1 {
|
||||
// a_0 = a_0;
|
||||
// }
|
||||
// a_2 = a_1;
|
||||
// for u32 i_0 in n_1..n_1*n_1 {
|
||||
// a_0 = a_0;
|
||||
// }
|
||||
// a_3 = a_2;
|
||||
// # versions: {n: 2, a: 3, K: 1}
|
||||
// for u32 i_0 in n_2..n_2*n_2 {
|
||||
// a_0 = a_0;
|
||||
// }
|
||||
// a_5 = a_4;
|
||||
// return a_5;
|
||||
// } # versions: {n: 3, a: 5, K: 2}
|
||||
// return a_3;
|
||||
// }
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
|
@ -595,32 +637,15 @@ mod tests {
|
|||
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").with_index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
|
||||
let ssa = ShallowTransformer::transform(
|
||||
f,
|
||||
&GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
&mut versions,
|
||||
);
|
||||
let mut ssa = ShallowTransformer::default();
|
||||
|
||||
let expected = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
|
@ -649,16 +674,16 @@ mod tests {
|
|||
)],
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
|
||||
),
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
UExpression::identifier(Identifier::from("n").version(2))
|
||||
UExpression::identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32),
|
||||
UExpression::identifier(Identifier::from("n").version(2))
|
||||
UExpression::identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32)
|
||||
* UExpression::identifier(Identifier::from("n").version(2))
|
||||
* UExpression::identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
|
@ -666,47 +691,35 @@ mod tests {
|
|||
)],
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(5)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(4)).into(),
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(5)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").with_index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let res = ssa.fold_function(f);
|
||||
|
||||
assert_eq!(
|
||||
versions,
|
||||
vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)]
|
||||
.into_iter()
|
||||
.collect::<Versions>()
|
||||
ssa.versions.map,
|
||||
vec![(
|
||||
0,
|
||||
vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect()
|
||||
)]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
let expected = Output::Incomplete(
|
||||
expected,
|
||||
vec![
|
||||
vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect::<Versions>(),
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)]
|
||||
.into_iter()
|
||||
.collect::<Versions>(),
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(ssa, expected);
|
||||
assert_eq!(res, expected);
|
||||
}
|
||||
}
|
||||
|
||||
mod shadowing {
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -717,11 +730,11 @@ mod tests {
|
|||
// return;
|
||||
// }
|
||||
|
||||
// should become
|
||||
// should become (only the field variable is affected as shadowing is taken care of in semantics already)
|
||||
|
||||
// def main(field a_0) {
|
||||
// field a_1 = 42;
|
||||
// bool a_2 = true;
|
||||
// def main(field a_s0_v0) {
|
||||
// field a_s0_v1 = 42;
|
||||
// bool a_s1_v0 = true
|
||||
// return;
|
||||
// }
|
||||
|
||||
|
@ -733,7 +746,11 @@ mod tests {
|
|||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::boolean("a").into(),
|
||||
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
|
||||
"a".into(),
|
||||
1,
|
||||
)))
|
||||
.into(),
|
||||
BooleanExpression::Value(true).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
|
@ -742,9 +759,7 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement]),
|
||||
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let expected: TypedFunction<Bn128Field> = TypedFunction {
|
||||
|
@ -755,7 +770,11 @@ mod tests {
|
|||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::boolean(Identifier::from("a").version(2)).into(),
|
||||
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
|
||||
"a".into(),
|
||||
1,
|
||||
)))
|
||||
.into(),
|
||||
BooleanExpression::Value(true).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
|
@ -764,121 +783,17 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement]),
|
||||
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
let ssa = ShallowTransformer::default().fold_function(f);
|
||||
|
||||
let ssa =
|
||||
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
|
||||
|
||||
assert_eq!(ssa, Output::Complete(expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_scope() {
|
||||
// def main(field a) {
|
||||
// for u32 i in 0..1 {
|
||||
// a = a + 1
|
||||
// field a = 42
|
||||
// }
|
||||
// return a
|
||||
// }
|
||||
|
||||
// should become
|
||||
|
||||
// def main(field a_0) {
|
||||
// # versions: {a: 0}
|
||||
// for u32 i in 0..1 {
|
||||
// a_0 = a_0
|
||||
// field a_0 = 42
|
||||
// }
|
||||
// return a_1
|
||||
// }
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
0u32.into(),
|
||||
1u32.into(),
|
||||
vec![
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::Number(42usize.into()).into(),
|
||||
),
|
||||
],
|
||||
),
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
|
||||
"a".into(),
|
||||
)
|
||||
.into()])
|
||||
.annotate(TupleType::new(vec![Type::FieldElement]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let expected: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
0u32.into(),
|
||||
1u32.into(),
|
||||
vec![
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a")).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::Number(42usize.into()).into(),
|
||||
),
|
||||
],
|
||||
),
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
|
||||
Identifier::from("a").version(1),
|
||||
)
|
||||
.into()])
|
||||
.annotate(TupleType::new(vec![Type::FieldElement]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
|
||||
let ssa =
|
||||
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
|
||||
|
||||
assert_eq!(
|
||||
ssa,
|
||||
Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()])
|
||||
);
|
||||
assert_eq!(ssa, expected);
|
||||
}
|
||||
}
|
||||
|
||||
mod function_call {
|
||||
use super::*;
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
// test that function calls are left in
|
||||
#[test]
|
||||
fn treat_calls() {
|
||||
|
@ -892,17 +807,12 @@ mod tests {
|
|||
// return a;
|
||||
// }
|
||||
|
||||
// When called with K := 1, expected:
|
||||
// def main(field a_0) -> field {
|
||||
// K = 1;
|
||||
// u32 n_0 = 42;
|
||||
// n_1 = n_0;
|
||||
// a_1 = a_0;
|
||||
// a_2 = foo::<n_1>(a_1);
|
||||
// n_2 = n_1;
|
||||
// a_3 = a_2 * foo::<n_2>(a_2);
|
||||
// a_2 = foo::<42>(a_1);
|
||||
// a_3 = a_2 * foo::<42>(a_2);
|
||||
// return a_3;
|
||||
// } # versions: {n: 2, a: 3}
|
||||
// }
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
|
@ -960,25 +870,9 @@ mod tests {
|
|||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
|
||||
let ssa = ShallowTransformer::transform(
|
||||
f,
|
||||
&GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
&mut versions,
|
||||
);
|
||||
|
||||
let expected = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
|
@ -1042,14 +936,23 @@ mod tests {
|
|||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut ssa = ShallowTransformer::default();
|
||||
|
||||
let res = ssa.fold_function(f);
|
||||
|
||||
assert_eq!(
|
||||
versions,
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect::<Versions>()
|
||||
ssa.versions.map,
|
||||
vec![(
|
||||
0,
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect()
|
||||
)]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
assert_eq!(ssa, Output::Incomplete(expected, vec![],));
|
||||
assert_eq!(res, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -531,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
f.fold_assignee(a),
|
||||
f.fold_expression(e),
|
||||
)]
|
||||
let e = f.fold_expression(e);
|
||||
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
|
@ -552,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> Vec<TypedStatement<'ast, T>> {
|
||||
let res = match s {
|
||||
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e))
|
||||
TypedStatement::Definition(a, rhs) => {
|
||||
let rhs = f.fold_definition_rhs(rhs);
|
||||
TypedStatement::Definition(f.fold_assignee(a), rhs)
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
TypedStatement::Assertion(f.fold_boolean_expression(e), error)
|
||||
|
@ -576,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
.flat_map(|s| f.fold_assembly_statement(s))
|
||||
.collect(),
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
vec![res]
|
||||
}
|
||||
|
|
|
@ -24,6 +24,21 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> FrameIdentifier<'ast> {
|
||||
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
|
||||
FrameIdentifier { frame, ..self }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> Identifier<'ast> {
|
||||
pub fn in_frame(self, frame: usize) -> Identifier<'ast> {
|
||||
Identifier {
|
||||
id: self.id.in_frame(frame),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> CoreIdentifier<'ast> {
|
||||
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
|
||||
FrameIdentifier { id: self, frame }
|
||||
|
|
|
@ -27,7 +27,7 @@ pub use self::types::{
|
|||
UBitwidth,
|
||||
};
|
||||
use self::types::{ConcreteArrayType, ConcreteStructType};
|
||||
use crate::typed::types::{ConcreteGenericsAssignment, IntoType};
|
||||
use crate::typed::types::IntoType;
|
||||
|
||||
pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable};
|
||||
use std::marker::PhantomData;
|
||||
|
@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
writeln!(f)?;
|
||||
|
||||
let mut tab = 0;
|
||||
|
||||
for s in &self.statements {
|
||||
if let TypedStatement::PopCallLog = s {
|
||||
tab -= 1;
|
||||
};
|
||||
|
||||
s.fmt_indented(f, 1 + tab)?;
|
||||
writeln!(f)?;
|
||||
|
||||
if let TypedStatement::PushCallLog(..) = s {
|
||||
tab += 1;
|
||||
};
|
||||
writeln!(f, "{}", s)?;
|
||||
}
|
||||
|
||||
writeln!(f, "}}")?;
|
||||
|
@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> {
|
|||
Vec<TypedStatement<'ast, T>>,
|
||||
),
|
||||
Log(FormatString, Vec<TypedExpression<'ast, T>>),
|
||||
// Aux
|
||||
PushCallLog(
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
ConcreteGenericsAssignment<'ast>,
|
||||
),
|
||||
PopCallLog,
|
||||
Assembly(Vec<TypedAssemblyStatement<'ast, T>>),
|
||||
}
|
||||
|
||||
|
@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
|
||||
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||
match self {
|
||||
TypedStatement::For(variable, from, to, statements) => {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
writeln!(f, "for {} in {}..{} {{", variable, from, to)?;
|
||||
for s in statements {
|
||||
s.fmt_indented(f, depth + 1)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
TypedStatement::Assembly(statements) => {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
s => write!(f, "{}{}", "\t".repeat(depth), s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
|
@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
|||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
TypedStatement::PushCallLog(ref key, ref generics) => write!(
|
||||
f,
|
||||
"// PUSH CALL TO {}/{}::<{}>",
|
||||
key.module.display(),
|
||||
key.id,
|
||||
generics,
|
||||
),
|
||||
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
|
||||
TypedStatement::Assembly(ref statements) => {
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
|
|
|
@ -532,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, F::Error> {
|
||||
Ok(match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
f.fold_assignee(a)?,
|
||||
f.fold_expression(e)?,
|
||||
)]
|
||||
let e = f.fold_expression(e)?;
|
||||
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
|
@ -554,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let res = match s {
|
||||
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?)
|
||||
let rhs = f.fold_definition_rhs(e)?;
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, rhs)
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
TypedStatement::Assertion(f.fold_boolean_expression(e)?, error)
|
||||
|
@ -586,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.flatten()
|
||||
.collect(),
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
|
|
@ -240,9 +240,9 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
|
|||
impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
|
||||
fn from(c: DeclarationConstant<'ast, T>) -> Self {
|
||||
match c {
|
||||
DeclarationConstant::Generic(_g) => {
|
||||
// UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32)
|
||||
unreachable!()
|
||||
DeclarationConstant::Generic(g) => {
|
||||
UExpression::identifier(Identifier::from(CoreIdentifier::from(g)))
|
||||
.annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Concrete(v) => {
|
||||
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
|
||||
|
|
16
zokrates_core_test/tests/tests/call_ssa.json
Normal file
16
zokrates_core_test/tests/tests/call_ssa.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "4"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
11
zokrates_core_test/tests/tests/call_ssa.zok
Normal file
11
zokrates_core_test/tests/tests/call_ssa.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
// main should be x -> x + 4
|
||||
|
||||
def foo(field mut a) -> field {
|
||||
a = a + 1;
|
||||
return a + 1;
|
||||
}
|
||||
|
||||
def main(field mut a) -> field {
|
||||
a = foo(a + 1);
|
||||
return a + 1;
|
||||
}
|
Loading…
Reference in a new issue