Merge pull request #1283 from Zokrates/greedy-reducer
Refactor reducer to reduce memory usage and runtime
This commit is contained in:
commit
8ca79372df
17 changed files with 775 additions and 1150 deletions
1
changelogs/unreleased/1283-schaeff
Normal file
1
changelogs/unreleased/1283-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining)
|
|
@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>(
|
|||
})
|
||||
.collect(),
|
||||
)],
|
||||
typed::TypedStatement::PushCallLog(..) => vec![],
|
||||
typed::TypedStatement::PopCallLog => vec![],
|
||||
typed::TypedStatement::For(..) => unreachable!(),
|
||||
};
|
||||
|
||||
|
|
|
@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>(
|
|||
let r = reduce_program(r).map_err(Error::from)?;
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Propagate");
|
||||
let r = Propagator::propagate(r)?;
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Concretize structs");
|
||||
let r = StructConcretizer::concretize(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
|
|
@ -44,25 +44,16 @@ impl fmt::Display for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Propagator<'ast, 'a, T: Field> {
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Propagator<'ast, T> {
|
||||
// constants keeps track of constant expressions
|
||||
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
|
||||
constants: &'a mut Constants<'ast, T>,
|
||||
constants: Constants<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
|
||||
pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self {
|
||||
Propagator { constants }
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Propagator<'ast, T> {
|
||||
pub fn propagate(p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
|
||||
let mut constants = Constants::new();
|
||||
|
||||
Propagator {
|
||||
constants: &mut constants,
|
||||
}
|
||||
.fold_program(p)
|
||||
Propagator::default().fold_program(p)
|
||||
}
|
||||
|
||||
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
|
||||
|
@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
|
||||
|
@ -629,8 +620,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
_ => Ok(vec![TypedStatement::Assertion(expr, err)]),
|
||||
}
|
||||
}
|
||||
s @ TypedStatement::PushCallLog(..) => Ok(vec![s]),
|
||||
s @ TypedStatement::PopCallLog => Ok(vec![s]),
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
@ -1502,7 +1491,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(5)))
|
||||
);
|
||||
}
|
||||
|
@ -1515,7 +1504,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(1)))
|
||||
);
|
||||
}
|
||||
|
@ -1528,7 +1517,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(6)))
|
||||
);
|
||||
}
|
||||
|
@ -1541,7 +1530,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
|
||||
);
|
||||
}
|
||||
|
@ -1554,15 +1543,14 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(8)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn left_shift() {
|
||||
let mut constants = Constants::new();
|
||||
let mut propagator = Propagator::with_constants(&mut constants);
|
||||
let mut propagator = Propagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::LeftShift(
|
||||
|
@ -1607,8 +1595,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn right_shift() {
|
||||
let mut constants = Constants::new();
|
||||
let mut propagator = Propagator::with_constants(&mut constants);
|
||||
let mut propagator = Propagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_field_expression(FieldElementExpression::RightShift(
|
||||
|
@ -1676,7 +1663,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(2)))
|
||||
);
|
||||
}
|
||||
|
@ -1691,7 +1678,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
|
||||
);
|
||||
}
|
||||
|
@ -1713,7 +1700,7 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
|
||||
Propagator::default().fold_field_expression(e),
|
||||
Ok(FieldElementExpression::Number(Bn128Field::from(3)))
|
||||
);
|
||||
}
|
||||
|
@ -1735,18 +1722,15 @@ mod tests {
|
|||
BooleanExpression::Not(box BooleanExpression::identifier("a".into()));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_default.clone()),
|
||||
Propagator::default().fold_boolean_expression(e_default.clone()),
|
||||
Ok(e_default)
|
||||
);
|
||||
}
|
||||
|
@ -1776,23 +1760,19 @@ mod tests {
|
|||
));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_true),
|
||||
Propagator::default().fold_boolean_expression(e_constant_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_false),
|
||||
Propagator::default().fold_boolean_expression(e_constant_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_true),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Ok(e_identifier_unchanged)
|
||||
);
|
||||
}
|
||||
|
@ -1800,38 +1780,42 @@ mod tests {
|
|||
#[test]
|
||||
fn bool_eq() {
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(false)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(true)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(false)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(true)
|
||||
))),
|
||||
))
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -1933,33 +1917,27 @@ mod tests {
|
|||
));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_true),
|
||||
Propagator::default().fold_boolean_expression(e_constant_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_false),
|
||||
Propagator::default().fold_boolean_expression(e_constant_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_true),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Ok(e_identifier_unchanged)
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_non_canonical_true),
|
||||
Propagator::default().fold_boolean_expression(e_non_canonical_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_non_canonical_false),
|
||||
Propagator::default().fold_boolean_expression(e_non_canonical_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -1977,13 +1955,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2001,13 +1977,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2025,13 +1999,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2049,13 +2021,11 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
Propagator::default().fold_boolean_expression(e_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
Propagator::default().fold_boolean_expression(e_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2065,67 +2035,75 @@ mod tests {
|
|||
let a_bool: Identifier = "a".into();
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::And(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
@ -2135,67 +2113,75 @@ mod tests {
|
|||
let a_bool: Identifier = "a".into();
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::identifier(a_bool.clone())
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::identifier(a_bool.clone()),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::identifier(a_bool.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::Or(
|
||||
Propagator::<Bn128Field>::default().fold_boolean_expression(
|
||||
BooleanExpression::Or(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(false),
|
||||
)),
|
||||
)
|
||||
),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -2,10 +2,11 @@
|
|||
|
||||
use crate::reducer::ConstantDefinitions;
|
||||
use zokrates_ast::typed::{
|
||||
folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
|
||||
DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression,
|
||||
StructExpression, StructExpressionInner, StructType, TupleExpression, TupleExpressionInner,
|
||||
TupleType, TypedProgram, TypedSymbolDeclaration, UBitwidth, UExpression, UExpressionInner,
|
||||
folder::*, identifier::FrameIdentifier, ArrayExpression, ArrayExpressionInner, ArrayType,
|
||||
BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id,
|
||||
Identifier, IdentifierExpression, StructExpression, StructExpressionInner, StructType,
|
||||
TupleExpression, TupleExpressionInner, TupleType, TypedProgram, TypedSymbolDeclaration,
|
||||
UBitwidth, UExpression, UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -61,7 +62,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
FieldElementExpression::Identifier(IdentifierExpression {
|
||||
id:
|
||||
Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
..
|
||||
|
@ -86,7 +91,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
BooleanExpression::Identifier(IdentifierExpression {
|
||||
id:
|
||||
Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
..
|
||||
|
@ -112,7 +121,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
UExpressionInner::Identifier(IdentifierExpression {
|
||||
id:
|
||||
Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
..
|
||||
|
@ -136,7 +149,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
ArrayExpressionInner::Identifier(IdentifierExpression {
|
||||
id:
|
||||
Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
..
|
||||
|
@ -162,7 +179,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
TupleExpressionInner::Identifier(IdentifierExpression {
|
||||
id:
|
||||
Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
..
|
||||
|
@ -188,7 +209,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
StructExpressionInner::Identifier(IdentifierExpression {
|
||||
id:
|
||||
Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
..
|
||||
|
|
|
@ -5,9 +5,9 @@ use crate::reducer::{
|
|||
};
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use zokrates_ast::typed::{
|
||||
result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed,
|
||||
TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId,
|
||||
TypedProgram, TypedSymbolDeclaration, UExpression,
|
||||
result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol,
|
||||
TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration,
|
||||
UExpression,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -118,11 +118,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> {
|
|||
signature: DeclarationSignature::new().output(c.ty.clone()),
|
||||
};
|
||||
|
||||
let mut inlined_wrapper = reduce_function(
|
||||
wrapper,
|
||||
ConcreteGenericsAssignment::default(),
|
||||
&self.program,
|
||||
)?;
|
||||
let mut inlined_wrapper = reduce_function(wrapper, &self.program)?;
|
||||
|
||||
if let TypedStatement::Return(expression) =
|
||||
inlined_wrapper.statements.pop().unwrap()
|
||||
|
|
|
@ -15,25 +15,24 @@
|
|||
// ```
|
||||
//
|
||||
// Becomes
|
||||
// ```
|
||||
// # Call foo::<42> with a_0 := x
|
||||
// n_0 = 42
|
||||
// a_1 = a_0
|
||||
// n_1 = n_0
|
||||
// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1
|
||||
// inputs: [a]
|
||||
// arguments: [x]
|
||||
// generics_bindings: [n = 42]
|
||||
// statements:
|
||||
// n = 42
|
||||
// a = a
|
||||
// n = n
|
||||
// return_expression: a
|
||||
|
||||
// Notes:
|
||||
// - The body of the function is in SSA form
|
||||
// - The return value(s) are assigned to internal variables
|
||||
|
||||
use crate::reducer::Output;
|
||||
use crate::reducer::ShallowTransformer;
|
||||
use crate::reducer::Versions;
|
||||
// - The body of the function is *not* in SSA form
|
||||
|
||||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType};
|
||||
use zokrates_ast::typed::CoreIdentifier;
|
||||
use zokrates_ast::typed::Identifier;
|
||||
|
||||
use zokrates_ast::typed::TypedAssignee;
|
||||
use zokrates_ast::typed::UBitwidth;
|
||||
use zokrates_ast::typed::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
|
||||
Signature, Type, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration,
|
||||
|
@ -43,22 +42,12 @@ use zokrates_field::Field;
|
|||
|
||||
pub enum InlineError<'ast, T> {
|
||||
Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>),
|
||||
Flat(
|
||||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Type<'ast, T>,
|
||||
),
|
||||
NonConstant(
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Type<'ast, T>,
|
||||
),
|
||||
Flat(FlatEmbed, Vec<u32>, Type<'ast, T>),
|
||||
NonConstant,
|
||||
}
|
||||
|
||||
fn get_canonical_function<'ast, T: Field>(
|
||||
function_key: DeclarationFunctionKey<'ast, T>,
|
||||
function_key: &DeclarationFunctionKey<'ast, T>,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
) -> TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
let s = program
|
||||
|
@ -66,30 +55,35 @@ fn get_canonical_function<'ast, T: Field>(
|
|||
.get(&function_key.module)
|
||||
.unwrap()
|
||||
.functions_iter()
|
||||
.find(|d| d.key == function_key)
|
||||
.find(|d| d.key == *function_key)
|
||||
.unwrap();
|
||||
|
||||
match &s.symbol {
|
||||
TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), program),
|
||||
TypedFunctionSymbol::There(key) => get_canonical_function(key, program),
|
||||
_ => s.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
type InlineResult<'ast, T> = Result<
|
||||
Output<(Vec<TypedStatement<'ast, T>>, TypedExpression<'ast, T>), Vec<Versions<'ast>>>,
|
||||
InlineError<'ast, T>,
|
||||
>;
|
||||
pub struct InlineValue<'ast, T> {
|
||||
/// the pre-SSA input variables to assign the arguments to
|
||||
pub input_variables: Vec<Variable<'ast, T>>,
|
||||
/// the pre-SSA statements for this call, including definition of the generic parameters
|
||||
pub statements: Vec<TypedStatement<'ast, T>>,
|
||||
/// the pre-SSA return value for this call
|
||||
pub return_value: TypedExpression<'ast, T>,
|
||||
}
|
||||
|
||||
type InlineResult<'ast, T> = Result<InlineValue<'ast, T>, InlineError<'ast, T>>;
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
||||
k: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output: &E::Ty,
|
||||
k: &DeclarationFunctionKey<'ast, T>,
|
||||
generics: &[Option<UExpression<'ast, T>>],
|
||||
arguments: &[TypedExpression<'ast, T>],
|
||||
output_ty: &E::Ty,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
) -> InlineResult<'ast, T> {
|
||||
use zokrates_ast::typed::Typed;
|
||||
let output_type = output.clone().into_type();
|
||||
let output_type = output_ty.clone().into_type();
|
||||
|
||||
// we try to get concrete values for explicit generics
|
||||
let generics_values: Vec<Option<u32>> = generics
|
||||
|
@ -103,36 +97,23 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
.map_err(|_| {
|
||||
InlineError::NonConstant(
|
||||
k.clone(),
|
||||
generics.clone(),
|
||||
arguments.clone(),
|
||||
output_type.clone(),
|
||||
)
|
||||
})?;
|
||||
.map_err(|_| InlineError::NonConstant)?;
|
||||
|
||||
// we infer a signature based on inputs and outputs
|
||||
// this is where we could handle explicit annotations
|
||||
let inferred_signature = Signature::new()
|
||||
.generics(generics.clone())
|
||||
.generics(generics.to_vec().clone())
|
||||
.inputs(arguments.iter().map(|a| a.get_type()).collect())
|
||||
.output(output_type.clone());
|
||||
|
||||
// we try to get concrete values for the whole signature. if this fails we should propagate again
|
||||
// we try to get concrete values for the whole signature
|
||||
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Err(InlineError::NonConstant(
|
||||
k,
|
||||
generics,
|
||||
arguments,
|
||||
output_type,
|
||||
));
|
||||
return Err(InlineError::NonConstant);
|
||||
}
|
||||
};
|
||||
|
||||
let decl = get_canonical_function(k.clone(), program);
|
||||
let decl = get_canonical_function(k, program);
|
||||
|
||||
// get an assignment of generics for this call site
|
||||
let assignment: ConcreteGenericsAssignment<'ast> = k
|
||||
|
@ -154,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
|
||||
e,
|
||||
e.generics::<T>(&assignment),
|
||||
arguments.clone(),
|
||||
output_type,
|
||||
)),
|
||||
_ => unreachable!(),
|
||||
|
@ -162,59 +142,38 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
|
||||
assert_eq!(f.arguments.len(), arguments.len());
|
||||
|
||||
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
|
||||
Output::Complete(v) => (v, None),
|
||||
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
|
||||
};
|
||||
let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| {
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::uint(
|
||||
CoreIdentifier::from(identifier),
|
||||
UBitwidth::B32,
|
||||
)),
|
||||
TypedExpression::from(UExpression::from(value)).into(),
|
||||
)
|
||||
});
|
||||
|
||||
let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone());
|
||||
|
||||
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
|
||||
let input_variables: Vec<Variable<'ast, T>> = f
|
||||
.arguments
|
||||
.into_iter()
|
||||
.zip(inferred_signature.inputs.clone())
|
||||
.map(|(p, t)| ConcreteVariable::new(p.id.id, t, false))
|
||||
.zip(arguments.clone())
|
||||
.map(|(v, a)| TypedStatement::definition(Variable::from(v).into(), a))
|
||||
.map(Variable::from)
|
||||
.collect();
|
||||
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
|
||||
.statements
|
||||
.into_iter()
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings
|
||||
.chain(f.statements)
|
||||
.partition(|s| !matches!(s, TypedStatement::Return(..)));
|
||||
|
||||
assert_eq!(returns.len(), 1);
|
||||
|
||||
let return_expression = match returns.pop().unwrap() {
|
||||
let return_value = match returns.pop().unwrap() {
|
||||
TypedStatement::Return(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let v: ConcreteVariable<'ast> = ConcreteVariable::new(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(
|
||||
*versions
|
||||
.entry(CoreIdentifier::Call(0))
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
),
|
||||
*inferred_signature.output.clone(),
|
||||
false,
|
||||
);
|
||||
|
||||
let expression = TypedExpression::from(Variable::from(v.clone()));
|
||||
|
||||
let output_binding = TypedStatement::definition(Variable::from(v).into(), return_expression);
|
||||
|
||||
let pop_log = TypedStatement::PopCallLog;
|
||||
|
||||
let statements: Vec<_> = std::iter::once(call_log)
|
||||
.chain(input_bindings)
|
||||
.chain(statements)
|
||||
.chain(std::iter::once(output_binding))
|
||||
.chain(std::iter::once(pop_log))
|
||||
.collect();
|
||||
|
||||
Ok(incomplete_data
|
||||
.map(|d| Output::Incomplete((statements.clone(), expression.clone()), d))
|
||||
.unwrap_or_else(|| Output::Complete((statements, expression))))
|
||||
Ok(InlineValue {
|
||||
input_variables,
|
||||
statements,
|
||||
return_value,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,40 +3,42 @@
|
|||
// - free of function calls (except for low level calls) thanks to inlining
|
||||
// - free of for-loops thanks to unrolling
|
||||
|
||||
// The process happens in two steps
|
||||
// 1. Shallow SSA for the `main` function
|
||||
// We turn the `main` function into SSA form, but ignoring function calls and for loops
|
||||
// 2. Unroll and inline
|
||||
// We go through the shallow-SSA program and
|
||||
// - unroll loops
|
||||
// - inline function calls. This includes applying shallow-ssa on the target function
|
||||
// The process happens in a greedy way, starting from the main function
|
||||
// For each statement:
|
||||
// * put it in ssa form
|
||||
// * propagate it
|
||||
// * inline it (calling this process recursively)
|
||||
// * propagate again
|
||||
|
||||
// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach
|
||||
|
||||
mod constants_reader;
|
||||
mod constants_writer;
|
||||
mod inline;
|
||||
mod shallow_ssa;
|
||||
|
||||
use self::inline::InlineValue;
|
||||
use self::inline::{inline_call, InlineError};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_ast::typed::result_folder::*;
|
||||
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
use zokrates_ast::typed::DeclarationParameter;
|
||||
use zokrates_ast::typed::Folder;
|
||||
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
|
||||
|
||||
use zokrates_ast::typed::TypedAssignee;
|
||||
use zokrates_ast::typed::{
|
||||
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId,
|
||||
TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration,
|
||||
TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression,
|
||||
TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram,
|
||||
TypedStatement, UExpression, UExpressionInner,
|
||||
};
|
||||
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use self::constants_writer::ConstantsWriter;
|
||||
use self::shallow_ssa::ShallowTransformer;
|
||||
|
||||
use crate::propagation::{Constants, Propagator};
|
||||
use crate::propagation;
|
||||
use crate::propagation::Propagator;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
|
@ -46,25 +48,15 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
|
|||
pub type ConstantDefinitions<'ast, T> =
|
||||
HashMap<CanonicalConstantIdentifier<'ast>, TypedExpression<'ast, T>>;
|
||||
|
||||
// An SSA version map, giving access to the latest version number for each identifier
|
||||
pub type Versions<'ast> = HashMap<CoreIdentifier<'ast>, usize>;
|
||||
|
||||
// A container to represent whether more treatment must be applied to the function
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Output<U, V> {
|
||||
Complete(U),
|
||||
Incomplete(U, V),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
Incompatible(String),
|
||||
GenericsInMain,
|
||||
// TODO: give more details about what's blocking the progress
|
||||
NoProgress,
|
||||
LoopTooLarge(u128),
|
||||
ConstantReduction(String, OwnedTypedModuleId),
|
||||
NonConstant(String),
|
||||
Type(String),
|
||||
Propagation(propagation::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
|
@ -76,133 +68,36 @@ impl fmt::Display for Error {
|
|||
s
|
||||
),
|
||||
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
|
||||
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
|
||||
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
|
||||
Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()),
|
||||
Error::Type(message) => write!(f, "{}", message),
|
||||
Error::NonConstant(s) => write!(f, "{}", s),
|
||||
Error::Type(s) => write!(f, "{}", s),
|
||||
Error::Propagation(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct Substitutions<'ast>(HashMap<CoreIdentifier<'ast>, HashMap<usize, usize>>);
|
||||
|
||||
impl<'ast> Substitutions<'ast> {
|
||||
// create an equivalent substitution map where all paths
|
||||
// are of length 1
|
||||
fn canonicalize(self) -> Self {
|
||||
Substitutions(
|
||||
self.0
|
||||
.into_iter()
|
||||
.map(|(id, sub)| (id, Self::canonicalize_sub(sub)))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
// canonicalize substitutions for a given id
|
||||
fn canonicalize_sub(sub: HashMap<usize, usize>) -> HashMap<usize, usize> {
|
||||
fn add_to_cache(
|
||||
sub: &HashMap<usize, usize>,
|
||||
cache: HashMap<usize, usize>,
|
||||
k: usize,
|
||||
) -> HashMap<usize, usize> {
|
||||
match cache.contains_key(&k) {
|
||||
// `k` is already in the cache, no changes to the cache
|
||||
true => cache,
|
||||
_ => match sub.get(&k) {
|
||||
// `k` does not point to anything, no changes to the cache
|
||||
None => cache,
|
||||
// `k` points to some `v
|
||||
Some(v) => {
|
||||
// add `v` to the cache
|
||||
let cache = add_to_cache(sub, cache, *v);
|
||||
// `k` points to what `v` points to, or to `v`
|
||||
let v = cache.get(v).cloned().unwrap_or(*v);
|
||||
let mut cache = cache;
|
||||
cache.insert(k, v);
|
||||
cache
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
sub.keys()
|
||||
.fold(HashMap::new(), |cache, k| add_to_cache(&sub, cache, *k))
|
||||
}
|
||||
}
|
||||
|
||||
struct Sub<'a, 'ast> {
|
||||
substitutions: &'a Substitutions<'ast>,
|
||||
}
|
||||
|
||||
impl<'a, 'ast> Sub<'a, 'ast> {
|
||||
fn new(substitutions: &'a Substitutions<'ast>) -> Self {
|
||||
Self { substitutions }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> {
|
||||
fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> {
|
||||
let version = self
|
||||
.substitutions
|
||||
.0
|
||||
.get(&id.id)
|
||||
.map(|sub| sub.get(&id.version).cloned().unwrap_or(id.version))
|
||||
.unwrap_or(id.version);
|
||||
id.version(version)
|
||||
}
|
||||
}
|
||||
|
||||
fn register<'ast>(
|
||||
substitutions: &mut Substitutions<'ast>,
|
||||
substitute: &Versions<'ast>,
|
||||
with: &Versions<'ast>,
|
||||
) {
|
||||
for (id, key, value) in substitute
|
||||
.iter()
|
||||
.filter_map(|(id, version)| with.get(id).map(|to| (id, version, to)))
|
||||
.filter(|(_, key, value)| key != value)
|
||||
{
|
||||
let sub = substitutions.0.entry(id.clone()).or_default();
|
||||
|
||||
// redirect `k` to `v`, unless `v` is already redirected to `v0`, in which case we redirect to `v0`
|
||||
|
||||
sub.insert(*key, *sub.get(value).unwrap_or(value));
|
||||
impl From<propagation::Error> for Error {
|
||||
fn from(e: propagation::Error) -> Self {
|
||||
Self::Propagation(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Reducer<'ast, 'a, T> {
|
||||
statement_buffer: Vec<TypedStatement<'ast, T>>,
|
||||
for_loop_versions: Vec<Versions<'ast>>,
|
||||
for_loop_versions_after: Vec<Versions<'ast>>,
|
||||
program: &'a TypedProgram<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
substitutions: &'a mut Substitutions<'ast>,
|
||||
complete: bool,
|
||||
propagator: Propagator<'ast, T>,
|
||||
ssa: ShallowTransformer<'ast>,
|
||||
statement_buffer: Vec<TypedStatement<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
||||
fn new(
|
||||
program: &'a TypedProgram<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
substitutions: &'a mut Substitutions<'ast>,
|
||||
for_loop_versions: Vec<Versions<'ast>>,
|
||||
) -> Self {
|
||||
// we reverse the vector as it's cheaper to `pop` than to take from
|
||||
// the head
|
||||
let mut for_loop_versions = for_loop_versions;
|
||||
|
||||
for_loop_versions.reverse();
|
||||
|
||||
fn new(program: &'a TypedProgram<'ast, T>) -> Self {
|
||||
Reducer {
|
||||
propagator: Propagator::default(),
|
||||
ssa: ShallowTransformer::default(),
|
||||
statement_buffer: vec![],
|
||||
for_loop_versions_after: vec![],
|
||||
for_loop_versions,
|
||||
substitutions,
|
||||
program,
|
||||
versions,
|
||||
complete: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -210,6 +105,13 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
|
||||
Ok(self.ssa.fold_parameter(p))
|
||||
}
|
||||
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
|
@ -217,65 +119,98 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
ty: &E::Ty,
|
||||
e: FunctionCallExpression<'ast, T, E>,
|
||||
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
|
||||
let generics = e
|
||||
// generics are already in ssa form
|
||||
|
||||
let generics: Vec<_> = e
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
|
||||
.map(|g| {
|
||||
g.map(|g| {
|
||||
let g = self.propagator.fold_uint_expression(g)?;
|
||||
let g = self.fold_uint_expression(g)?;
|
||||
|
||||
self.propagator
|
||||
.fold_uint_expression(g)
|
||||
.map_err(Self::Error::from)
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let arguments = e
|
||||
// arguments are already in ssa form
|
||||
|
||||
let arguments: Vec<_> = e
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression(e))
|
||||
.map(|e| {
|
||||
let e = self.propagator.fold_expression(e)?;
|
||||
let e = self.fold_expression(e)?;
|
||||
|
||||
self.propagator
|
||||
.fold_expression(e)
|
||||
.map_err(Self::Error::from)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let res = inline_call::<_, E>(
|
||||
e.function_key.clone(),
|
||||
generics,
|
||||
arguments,
|
||||
ty,
|
||||
self.program,
|
||||
self.versions,
|
||||
);
|
||||
self.ssa.push_call_frame();
|
||||
|
||||
let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program);
|
||||
|
||||
let res = match res {
|
||||
Ok(InlineValue {
|
||||
input_variables,
|
||||
statements,
|
||||
return_value,
|
||||
}) => {
|
||||
// the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs
|
||||
let input_bindings: Vec<_> = input_variables
|
||||
.into_iter()
|
||||
.zip(arguments)
|
||||
.map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let input_bindings = input_bindings
|
||||
.into_iter()
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
self.statement_buffer.extend(input_bindings);
|
||||
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|s| self.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
match res {
|
||||
Ok(Output::Complete((statements, expression))) => {
|
||||
self.complete &= true;
|
||||
self.statement_buffer.extend(statements);
|
||||
|
||||
let return_value = self.ssa.fold_expression(return_value);
|
||||
|
||||
let return_value = self.propagator.fold_expression(return_value)?;
|
||||
|
||||
let return_value = self.fold_expression(return_value)?;
|
||||
|
||||
Ok(FunctionCallOrExpression::Expression(
|
||||
E::from(expression).into_inner(),
|
||||
))
|
||||
}
|
||||
Ok(Output::Incomplete((statements, expression), delta_for_loop_versions)) => {
|
||||
self.complete = false;
|
||||
self.statement_buffer.extend(statements);
|
||||
self.for_loop_versions_after.extend(delta_for_loop_versions);
|
||||
Ok(FunctionCallOrExpression::Expression(
|
||||
E::from(expression.clone()).into_inner(),
|
||||
E::from(return_value).into_inner(),
|
||||
))
|
||||
}
|
||||
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
|
||||
"Call site `{}` incompatible with declaration `{}`",
|
||||
conc, decl
|
||||
))),
|
||||
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
|
||||
self.complete = false;
|
||||
|
||||
Ok(FunctionCallOrExpression::Expression(E::function_call(
|
||||
key, generics, arguments,
|
||||
)))
|
||||
}
|
||||
Err(InlineError::Flat(embed, generics, arguments, output_type)) => {
|
||||
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
|
||||
*self
|
||||
.versions
|
||||
.entry(CoreIdentifier::Call(0))
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
);
|
||||
Err(InlineError::NonConstant) => Err(Error::NonConstant(format!(
|
||||
"Generic parameters must be compile-time constants, found {}",
|
||||
FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments)
|
||||
))),
|
||||
Err(InlineError::Flat(embed, generics, output_type)) => {
|
||||
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0));
|
||||
|
||||
let var = Variable::immutable(identifier.clone(), output_type);
|
||||
let v = var.clone().into();
|
||||
|
||||
let v: TypedAssignee<'ast, T> = var.clone().into();
|
||||
|
||||
self.statement_buffer
|
||||
.push(TypedStatement::embed_call_definition(
|
||||
|
@ -286,7 +221,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
identifier,
|
||||
)))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.ssa.pop_call_frame();
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn fold_block_expression<E: ResultFold<'ast, T>>(
|
||||
|
@ -325,74 +264,70 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
let res = match s {
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let versions_before = self.for_loop_versions.pop().unwrap();
|
||||
let from = self.ssa.fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
|
||||
let to = self.ssa.fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
|
||||
let mut out_statements = vec![];
|
||||
|
||||
// get a fresh set of versions for all variables to use as a starting point inside the loop
|
||||
self.versions.values_mut().for_each(|v| *v += 1);
|
||||
|
||||
// add this set of versions to the substitution, pointing to the versions before the loop
|
||||
register(self.substitutions, self.versions, &versions_before);
|
||||
|
||||
// the versions after the loop are found by applying an offset of 1 to the versions before the loop
|
||||
let versions_after = versions_before
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v + 1))
|
||||
.collect();
|
||||
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
|
||||
if to - from > MAX_FOR_LOOP_SIZE {
|
||||
return Err(Error::LoopTooLarge(to.saturating_sub(*from)));
|
||||
}
|
||||
|
||||
for index in *from..*to {
|
||||
let statements: Vec<TypedStatement<_>> =
|
||||
std::iter::once(TypedStatement::definition(
|
||||
v.clone().into(),
|
||||
UExpression::from(index as u32).into(),
|
||||
))
|
||||
.chain(statements.clone().into_iter())
|
||||
.flat_map(|s| transformer.fold_statement(s))
|
||||
.collect();
|
||||
|
||||
out_statements.extend(statements);
|
||||
}
|
||||
|
||||
let backups = transformer.for_loop_backups;
|
||||
let blocked = transformer.blocked;
|
||||
|
||||
// we know the final versions of the variables after full unrolling of the loop
|
||||
// the versions after the loop need to point to these, so we add to the substitutions
|
||||
register(self.substitutions, &versions_after, self.versions);
|
||||
|
||||
// we may have found new for loops when unrolling this one, which means new backed up versions
|
||||
// we insert these in our backup list and update our cursor
|
||||
|
||||
self.for_loop_versions_after.extend(backups);
|
||||
|
||||
// if the ssa transform got blocked, the reduction is not complete
|
||||
self.complete &= !blocked;
|
||||
|
||||
Ok(out_statements)
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to))
|
||||
if to - from > MAX_FOR_LOOP_SIZE =>
|
||||
{
|
||||
Err(Error::LoopTooLarge(to.saturating_sub(*from)))
|
||||
}
|
||||
_ => {
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
self.complete = false;
|
||||
self.for_loop_versions_after.push(versions_before);
|
||||
Ok(vec![TypedStatement::For(v, from, to, statements)])
|
||||
}
|
||||
}
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from
|
||||
..*to)
|
||||
.flat_map(|index| {
|
||||
std::iter::once(TypedStatement::definition(
|
||||
v.clone().into(),
|
||||
UExpression::from(index as u32).into(),
|
||||
))
|
||||
.chain(statements.clone())
|
||||
.map(|s| self.fold_statement(s))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()),
|
||||
_ => Err(Error::NonConstant(format!(
|
||||
"Expected loop bounds to be constant, found {}..{}",
|
||||
from, to
|
||||
))),
|
||||
}?
|
||||
}
|
||||
s => {
|
||||
let statements = self.ssa.fold_statement(s);
|
||||
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
let statements = statements
|
||||
.map(|s| fold_statement(self, s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
let statements = statements
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
statements.collect()
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
};
|
||||
|
||||
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
|
||||
Ok(self.statement_buffer.drain(..).chain(res).collect())
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
|
@ -402,18 +337,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
ArrayExpressionInner::Slice(box array, box from, box to) => {
|
||||
let array = self.ssa.fold_array_expression(array);
|
||||
let array = self.propagator.fold_array_expression(array)?;
|
||||
let array = self.fold_array_expression(array)?;
|
||||
let array = self.propagator.fold_array_expression(array)?;
|
||||
|
||||
let from = self.ssa.fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
|
||||
let to = self.ssa.fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(..), UExpressionInner::Value(..)) => {
|
||||
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
|
||||
}
|
||||
_ => {
|
||||
self.complete = false;
|
||||
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
|
||||
}
|
||||
_ => Err(Error::NonConstant(format!(
|
||||
"Slice bounds must be compile time constants, found {}",
|
||||
ArrayExpressionInner::Slice(box array, box from, box to)
|
||||
))),
|
||||
}
|
||||
}
|
||||
_ => fold_array_expression_inner(self, array_ty, e),
|
||||
|
@ -443,7 +389,7 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
|
|||
|
||||
match main_function.signature.generics.len() {
|
||||
0 => {
|
||||
let main_function = reduce_function(main_function, GGenericsAssignment::default(), &p)?;
|
||||
let main_function = Reducer::new(&p).fold_function(main_function)?;
|
||||
|
||||
Ok(TypedProgram {
|
||||
main: p.main.clone(),
|
||||
|
@ -467,91 +413,11 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
|
|||
|
||||
fn reduce_function<'ast, T: Field>(
|
||||
f: TypedFunction<'ast, T>,
|
||||
generics: ConcreteGenericsAssignment<'ast>,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
) -> Result<TypedFunction<'ast, T>, Error> {
|
||||
let mut versions = Versions::default();
|
||||
assert!(f.signature.generics.is_empty());
|
||||
|
||||
let mut constants = Constants::default();
|
||||
|
||||
let f = match ShallowTransformer::transform(f, &generics, &mut versions) {
|
||||
Output::Complete(f) => Ok(f),
|
||||
Output::Incomplete(new_f, new_for_loop_versions) => {
|
||||
let mut for_loop_versions = new_for_loop_versions;
|
||||
|
||||
let mut f = new_f;
|
||||
|
||||
let mut substitutions = Substitutions::default();
|
||||
|
||||
let mut hash = None;
|
||||
|
||||
loop {
|
||||
let mut reducer = Reducer::new(
|
||||
program,
|
||||
&mut versions,
|
||||
&mut substitutions,
|
||||
for_loop_versions,
|
||||
);
|
||||
|
||||
let new_f = TypedFunction {
|
||||
statements: f
|
||||
.statements
|
||||
.into_iter()
|
||||
.map(|s| reducer.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
..f
|
||||
};
|
||||
|
||||
assert!(reducer.for_loop_versions.is_empty());
|
||||
|
||||
match reducer.complete {
|
||||
true => {
|
||||
substitutions = substitutions.canonicalize();
|
||||
|
||||
let new_f = Sub::new(&substitutions).fold_function(new_f);
|
||||
|
||||
let new_f = Propagator::with_constants(&mut constants)
|
||||
.fold_function(new_f)
|
||||
.map_err(|e| Error::Incompatible(format!("{}", e)))?;
|
||||
|
||||
break Ok(new_f);
|
||||
}
|
||||
false => {
|
||||
for_loop_versions = reducer.for_loop_versions_after;
|
||||
|
||||
let new_f = Sub::new(&substitutions).fold_function(new_f);
|
||||
|
||||
f = Propagator::with_constants(&mut constants)
|
||||
.fold_function(new_f)
|
||||
.map_err(|e| Error::Incompatible(format!("{}", e)))?;
|
||||
|
||||
let new_hash = Some(compute_hash(&f));
|
||||
|
||||
if new_hash == hash {
|
||||
break Err(Error::NoProgress);
|
||||
} else {
|
||||
hash = new_hash
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}?;
|
||||
|
||||
Propagator::with_constants(&mut constants)
|
||||
.fold_function(f)
|
||||
.map_err(|e| Error::Incompatible(format!("{}", e)))
|
||||
}
|
||||
|
||||
fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut s = DefaultHasher::new();
|
||||
f.hash(&mut s);
|
||||
s.finish()
|
||||
Reducer::new(program).fold_function(f)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -588,14 +454,11 @@ mod tests {
|
|||
// }
|
||||
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// a_1 = a_0;
|
||||
// # PUSH CALL to foo
|
||||
// a_3 := a_1; // input binding
|
||||
// #RETURN_AT_INDEX_0_0 := a_3;
|
||||
// # POP CALL
|
||||
// a_2 = #RETURN_AT_INDEX_0_0;
|
||||
// return a_2;
|
||||
// def main(field a_f0_v0) -> field {
|
||||
// a_f0_v1 = a_f0_v0; // redef
|
||||
// a_f1_v0 = a_f0_v1; // input binding
|
||||
// a_f0_v2 = a_f1_v0; // output binding
|
||||
// return a_f0_v2;
|
||||
// }
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
|
@ -691,30 +554,13 @@ mod tests {
|
|||
Variable::field_element(Identifier::from("a").version(1)).into(),
|
||||
FieldElementExpression::identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
),
|
||||
GGenericsAssignment::default(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
Variable::field_element(Identifier::from("a").in_frame(1)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0))
|
||||
.into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
|
||||
|
@ -763,14 +609,11 @@ mod tests {
|
|||
// }
|
||||
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// field[1] b_0 = [42];
|
||||
// # PUSH CALL to foo::<1>
|
||||
// a_0 = b_0;
|
||||
// #RETURN_AT_INDEX_0_0 := a_0;
|
||||
// # POP CALL
|
||||
// b_1 = #RETURN_AT_INDEX_0_0;
|
||||
// return a_2 + b_1[0];
|
||||
// def main(field a_f0_v0) -> field {
|
||||
// field[1] b_f0_v0 = [a_f0_v0];
|
||||
// a_f1_v0 = b_f0_v0;
|
||||
// b_f0_v1 = a_f1_v0;
|
||||
// return a_f0_v0 + b_f0_v1[0];
|
||||
// }
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
|
@ -897,42 +740,19 @@ mod tests {
|
|||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
Type::FieldElement,
|
||||
1u32,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
(FieldElementExpression::identifier("a".into())
|
||||
|
@ -987,14 +807,11 @@ mod tests {
|
|||
// }
|
||||
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// field[1] b_0 = [42];
|
||||
// # PUSH CALL to foo::<1>
|
||||
// a_0 = b_0;
|
||||
// #RETURN_AT_INDEX_0_0 := a_0;
|
||||
// # POP CALL
|
||||
// b_1 = #RETURN_AT_INDEX_0_0;
|
||||
// return a_2 + b_1[0];
|
||||
// def main(field a) -> field {
|
||||
// field[1] b = [a];
|
||||
// a_f1 = b;
|
||||
// b_1 = a_f1;
|
||||
// return a + b_1[0];
|
||||
// }
|
||||
|
||||
let foo_signature = DeclarationSignature::new()
|
||||
|
@ -1125,47 +942,25 @@ mod tests {
|
|||
TypedStatement::definition(
|
||||
Variable::array("b", Type::FieldElement, 1u32).into(),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![FieldElementExpression::identifier("a".into()).into()].into(),
|
||||
vec![FieldElementExpression::identifier(Identifier::from("a")).into()]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32)
|
||||
Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::array(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
Type::FieldElement,
|
||||
1u32,
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").version(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::definition(
|
||||
Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(
|
||||
Identifier::from(CoreIdentifier::Call(0)).version(0),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
ArrayExpression::identifier(Identifier::from("a").in_frame(1))
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
(FieldElementExpression::identifier("a".into())
|
||||
|
@ -1391,33 +1186,11 @@ mod tests {
|
|||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::PushCallLog(
|
||||
DeclarationFunctionKey::with_location("main", "bar")
|
||||
.signature(foo_signature.clone()),
|
||||
GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 2)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
),
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::PopCallLog,
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![])
|
||||
.annotate(TupleType::new(vec![]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
statements: vec![TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![])
|
||||
.annotate(TupleType::new(vec![]))
|
||||
.into(),
|
||||
)],
|
||||
signature: DeclarationSignature::new(),
|
||||
};
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can
|
||||
// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility.
|
||||
// Function calls are also left unvisited
|
||||
// Saving the indices is not required for function calls, as they cannot modify their environment
|
||||
// The SSA transformation
|
||||
// * introduces new versions if and only if we are assigning to an identifier
|
||||
// * does not visit the statements of loops
|
||||
|
||||
// Example:
|
||||
// def main(field a) -> field {
|
||||
|
@ -19,178 +18,167 @@
|
|||
// u32 n_0 = 42;
|
||||
// a_1 = a_0 + 1;
|
||||
// field b_0 = foo(a_1); // we keep the function call as is
|
||||
// # versions: {n: 0, a: 1, b: 0}
|
||||
// for u32 i_0 in 0..n_0 {
|
||||
// <body> // we keep the loop body as is
|
||||
// }
|
||||
// return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop
|
||||
// }
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use zokrates_ast::typed::folder::*;
|
||||
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
|
||||
use zokrates_ast::typed::types::Type;
|
||||
|
||||
use zokrates_ast::typed::*;
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::{Output, Versions};
|
||||
|
||||
pub struct ShallowTransformer<'ast, 'a> {
|
||||
// version index for any variable name
|
||||
pub versions: &'a mut Versions<'ast>,
|
||||
// A backup of the versions before each for-loop
|
||||
pub for_loop_backups: Vec<Versions<'ast>>,
|
||||
// whether all statements could be unrolled so far. Loops with variable bounds cannot.
|
||||
pub blocked: bool,
|
||||
// An SSA version map, giving access to the latest version number for each identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Versions<'ast> {
|
||||
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
||||
pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
|
||||
ShallowTransformer {
|
||||
versions,
|
||||
for_loop_backups: Vec::default(),
|
||||
blocked: false,
|
||||
impl<'ast> Default for Versions<'ast> {
|
||||
fn default() -> Self {
|
||||
// create a call frame at index 0
|
||||
Self {
|
||||
map: vec![(0, Default::default())].into_iter().collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// increase all versions by 1 and return the old versions
|
||||
fn create_version_gap(&mut self) -> Versions<'ast> {
|
||||
let ret = self.versions.clone();
|
||||
self.versions.values_mut().for_each(|v| *v += 1);
|
||||
ret
|
||||
}
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ShallowTransformer<'ast> {
|
||||
// version index for any variable name
|
||||
pub versions: Versions<'ast>,
|
||||
pub frames: Vec<usize>,
|
||||
pub latest_frame: usize,
|
||||
}
|
||||
|
||||
fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
let version = *self
|
||||
.versions
|
||||
impl<'ast> ShallowTransformer<'ast> {
|
||||
pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
let frame = self.frame();
|
||||
|
||||
let frame_versions = self.versions.map.entry(frame).or_default();
|
||||
|
||||
let version = frame_versions
|
||||
.entry(c_id.clone())
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0); // otherwise, we start from this version
|
||||
.or_default(); // otherwise, we start from this version
|
||||
|
||||
Identifier::from(c_id).version(version)
|
||||
Identifier::from(c_id.in_frame(frame)).version(*version)
|
||||
}
|
||||
|
||||
fn issue_next_ssa_variable<T: Field>(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> {
|
||||
assert_eq!(v.id.version, 0);
|
||||
|
||||
Variable {
|
||||
id: self.issue_next_identifier(v.id.id),
|
||||
id: self.issue_next_identifier(v.id.id.id),
|
||||
..v
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transform<T: Field>(
|
||||
f: TypedFunction<'ast, T>,
|
||||
generics: &ConcreteGenericsAssignment<'ast>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
) -> Output<TypedFunction<'ast, T>, Vec<Versions<'ast>>> {
|
||||
let mut unroller = ShallowTransformer::with_versions(versions);
|
||||
|
||||
let f = unroller.fold_function(f, generics);
|
||||
|
||||
match unroller.blocked {
|
||||
false => Output::Complete(f),
|
||||
true => Output::Incomplete(f, unroller.for_loop_backups),
|
||||
}
|
||||
fn frame(&self) -> usize {
|
||||
*self.frames.last().unwrap_or(&0)
|
||||
}
|
||||
|
||||
fn fold_function<T: Field>(
|
||||
pub fn push_call_frame(&mut self) {
|
||||
self.latest_frame += 1;
|
||||
self.frames.push(self.latest_frame);
|
||||
self.versions
|
||||
.map
|
||||
.insert(self.latest_frame, Default::default());
|
||||
}
|
||||
|
||||
pub fn pop_call_frame(&mut self) {
|
||||
let frame = self.frames.pop().unwrap();
|
||||
self.versions.map.remove(&frame);
|
||||
}
|
||||
|
||||
// fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers,
|
||||
// but this should not be applied recursively to complex assignees
|
||||
fn fold_assignee_no_ssa_increase<T: Field>(
|
||||
&mut self,
|
||||
f: TypedFunction<'ast, T>,
|
||||
generics: &ConcreteGenericsAssignment<'ast>,
|
||||
) -> TypedFunction<'ast, T> {
|
||||
let mut f = f;
|
||||
a: TypedAssignee<'ast, T>,
|
||||
) -> TypedAssignee<'ast, T> {
|
||||
match a {
|
||||
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
|
||||
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
|
||||
box self.fold_assignee_no_ssa_increase(a),
|
||||
box self.fold_uint_expression(index),
|
||||
),
|
||||
TypedAssignee::Member(box s, m) => {
|
||||
TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m)
|
||||
}
|
||||
TypedAssignee::Element(box s, index) => {
|
||||
TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.statements = generics
|
||||
.0
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(g, v)| {
|
||||
TypedStatement::definition(
|
||||
Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false)
|
||||
.into(),
|
||||
UExpression::from(v as u32).into(),
|
||||
)
|
||||
})
|
||||
.chain(f.statements)
|
||||
.collect();
|
||||
|
||||
for arg in &f.arguments {
|
||||
let _ = self.issue_next_identifier(arg.id.id.id.clone());
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
|
||||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
for g in &f.signature.generics {
|
||||
let generic_parameter = match g.as_ref().unwrap() {
|
||||
DeclarationConstant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone()));
|
||||
}
|
||||
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> DeclarationParameter<'ast, T> {
|
||||
DeclarationParameter {
|
||||
id: DeclarationVariable {
|
||||
id: self.issue_next_identifier(p.id.id.id.id),
|
||||
..p.id
|
||||
},
|
||||
..p
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
match a {
|
||||
// create a new version for assignments to identifiers
|
||||
TypedAssignee::Identifier(v) => {
|
||||
let v = self.issue_next_ssa_variable(v);
|
||||
TypedAssignee::Identifier(self.fold_variable(v))
|
||||
}
|
||||
a => fold_assignee(self, a),
|
||||
// otherwise, simply replace by the current version
|
||||
a => self.fold_assignee_no_ssa_increase(a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
let e = self.fold_expression(e);
|
||||
let a = self.fold_assignee(a);
|
||||
vec![TypedAssemblyStatement::Assignment(a, e)]
|
||||
}
|
||||
s => fold_assembly_statement(self, s),
|
||||
}
|
||||
}
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => {
|
||||
let e = self.fold_expression(e);
|
||||
let a = self.fold_assignee(a);
|
||||
vec![TypedStatement::definition(a, e)]
|
||||
}
|
||||
TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => {
|
||||
let embed_call = self.fold_embed_call(embed_call);
|
||||
let assignee = self.fold_assignee(assignee);
|
||||
vec![TypedStatement::embed_call_definition(assignee, embed_call)]
|
||||
}
|
||||
// only fold bounds of for loop statements
|
||||
TypedStatement::For(v, from, to, stats) => {
|
||||
let from = self.fold_uint_expression(from);
|
||||
let to = self.fold_uint_expression(to);
|
||||
self.blocked = true;
|
||||
let versions_before_loop = self.create_version_gap();
|
||||
self.for_loop_backups.push(versions_before_loop);
|
||||
vec![TypedStatement::For(v, from, to, stats)]
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
// retrieve the latest version
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
let res = Identifier {
|
||||
version: *self.versions.get(&(n.id)).unwrap_or(&0),
|
||||
..n
|
||||
};
|
||||
res
|
||||
}
|
||||
let version = self
|
||||
.versions
|
||||
.map
|
||||
.get(&self.frame())
|
||||
.unwrap()
|
||||
.get(&n.id.id)
|
||||
.cloned()
|
||||
.unwrap_or(0);
|
||||
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
c: FunctionCallExpression<'ast, T, E>,
|
||||
) -> FunctionCallOrExpression<'ast, T, E> {
|
||||
if !c.function_key.id.starts_with('_') {
|
||||
self.blocked = true;
|
||||
}
|
||||
|
||||
fold_function_call_expression(self, ty, c)
|
||||
n.in_frame(self.frame()).version(version)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -203,36 +191,57 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn detect_non_constant_bound() {
|
||||
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
|
||||
Variable::new("i", Type::Uint(UBitwidth::B32), false),
|
||||
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
|
||||
2u32.into(),
|
||||
vec![],
|
||||
)];
|
||||
fn ignore_loop_content() {
|
||||
// field foo = 0
|
||||
// u32 i = 4;
|
||||
// for u32 i in i..2 {
|
||||
// foo = 5;
|
||||
// }
|
||||
|
||||
let statements = loops;
|
||||
// should be left unchanged, as we do not visit the loop content nor the index variable
|
||||
|
||||
let f = TypedFunction {
|
||||
arguments: vec![],
|
||||
signature: DeclarationSignature::new(),
|
||||
statements,
|
||||
statements: vec![
|
||||
TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))),
|
||||
FieldElementExpression::Number(Bn128Field::from(4)).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::uint(
|
||||
Identifier::from("i"),
|
||||
UBitwidth::B32,
|
||||
)),
|
||||
UExpression::from(0u32).into(),
|
||||
),
|
||||
TypedStatement::For(
|
||||
Variable::new("i", Type::Uint(UBitwidth::B32), false),
|
||||
UExpression::identifier("i".into()).annotate(UBitwidth::B32),
|
||||
2u32.into(),
|
||||
vec![TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(Identifier::from(
|
||||
"foo",
|
||||
))),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)).into(),
|
||||
)],
|
||||
),
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![])
|
||||
.annotate(TupleType::new(vec![]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::default(),
|
||||
};
|
||||
|
||||
match ShallowTransformer::transform(
|
||||
f,
|
||||
&ConcreteGenericsAssignment::default(),
|
||||
&mut Versions::default(),
|
||||
) {
|
||||
Output::Incomplete(..) => {}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut ssa = ShallowTransformer::default();
|
||||
|
||||
assert_eq!(ssa.fold_function(f.clone()), f);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn definition() {
|
||||
// field a
|
||||
// a = 5
|
||||
// field a = 5
|
||||
// a = 6
|
||||
// a
|
||||
|
||||
|
@ -241,9 +250,7 @@ mod tests {
|
|||
// a_1 = 6
|
||||
// a_1
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
|
@ -283,17 +290,14 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn incremental_definition() {
|
||||
// field a
|
||||
// a = 5
|
||||
// field a = 5
|
||||
// a = a + 1
|
||||
|
||||
// should be turned into
|
||||
// a_0 = 5
|
||||
// a_1 = a_0 + 1
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
|
@ -342,9 +346,7 @@ mod tests {
|
|||
// a_0 = 2
|
||||
// a_1 = foo(a_0)
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("a")),
|
||||
|
@ -403,9 +405,7 @@ mod tests {
|
|||
// a_0 = [1, 1]
|
||||
// a_0[1] = 2
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let s = TypedStatement::definition(
|
||||
TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)),
|
||||
|
@ -460,9 +460,7 @@ mod tests {
|
|||
// a_0 = [[0, 1], [2, 3]]
|
||||
// a_0 = [4, 5]
|
||||
|
||||
let mut versions = Versions::new();
|
||||
|
||||
let mut u = ShallowTransformer::with_versions(&mut versions);
|
||||
let mut u = ShallowTransformer::default();
|
||||
|
||||
let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32));
|
||||
|
||||
|
@ -557,10 +555,10 @@ mod tests {
|
|||
|
||||
mod for_loop {
|
||||
use super::*;
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
|
||||
#[test]
|
||||
fn treat_loop() {
|
||||
// def main<K>(field a) -> field {
|
||||
// def main(field a) -> field {
|
||||
// u32 n = 42;
|
||||
// n = n;
|
||||
// a = a;
|
||||
|
@ -575,24 +573,21 @@ mod tests {
|
|||
// return a;
|
||||
// }
|
||||
|
||||
// When called with K := 1, expected:
|
||||
// expected:
|
||||
// def main(field a_0) -> field {
|
||||
// u32 K = 1;
|
||||
// u32 n_0 = 42;
|
||||
// n_1 = n_0;
|
||||
// a_1 = a_0;
|
||||
// # versions: {n: 1, a: 1, K: 0}
|
||||
// for u32 i_0 in n_1..n_1*n_1 {
|
||||
// a_0 = a_0;
|
||||
// }
|
||||
// a_2 = a_1;
|
||||
// for u32 i_0 in n_1..n_1*n_1 {
|
||||
// a_0 = a_0;
|
||||
// }
|
||||
// a_3 = a_2;
|
||||
// # versions: {n: 2, a: 3, K: 1}
|
||||
// for u32 i_0 in n_2..n_2*n_2 {
|
||||
// a_0 = a_0;
|
||||
// }
|
||||
// a_5 = a_4;
|
||||
// return a_5;
|
||||
// } # versions: {n: 3, a: 5, K: 2}
|
||||
// return a_3;
|
||||
// }
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
|
@ -642,32 +637,15 @@ mod tests {
|
|||
TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").with_index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
|
||||
let ssa = ShallowTransformer::transform(
|
||||
f,
|
||||
&GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
&mut versions,
|
||||
);
|
||||
let mut ssa = ShallowTransformer::default();
|
||||
|
||||
let expected = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
|
@ -696,16 +674,16 @@ mod tests {
|
|||
)],
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
|
||||
Variable::field_element(Identifier::from("a").version(2)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(1)).into(),
|
||||
),
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
UExpression::identifier(Identifier::from("n").version(2))
|
||||
UExpression::identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32),
|
||||
UExpression::identifier(Identifier::from("n").version(2))
|
||||
UExpression::identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32)
|
||||
* UExpression::identifier(Identifier::from("n").version(2))
|
||||
* UExpression::identifier(Identifier::from("n").version(1))
|
||||
.annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::definition(
|
||||
Variable::field_element("a").into(),
|
||||
|
@ -713,47 +691,35 @@ mod tests {
|
|||
)],
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a").version(5)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(4)).into(),
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(2)).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(5)).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a").version(3)).into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").with_index(0).into(),
|
||||
)])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let res = ssa.fold_function(f);
|
||||
|
||||
assert_eq!(
|
||||
versions,
|
||||
vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)]
|
||||
.into_iter()
|
||||
.collect::<Versions>()
|
||||
ssa.versions.map,
|
||||
vec![(
|
||||
0,
|
||||
vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect()
|
||||
)]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
let expected = Output::Incomplete(
|
||||
expected,
|
||||
vec![
|
||||
vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect::<Versions>(),
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)]
|
||||
.into_iter()
|
||||
.collect::<Versions>(),
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(ssa, expected);
|
||||
assert_eq!(res, expected);
|
||||
}
|
||||
}
|
||||
|
||||
mod shadowing {
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -764,11 +730,11 @@ mod tests {
|
|||
// return;
|
||||
// }
|
||||
|
||||
// should become
|
||||
// should become (only the field variable is affected as shadowing is taken care of in semantics already)
|
||||
|
||||
// def main(field a_0) {
|
||||
// field a_1 = 42;
|
||||
// bool a_2 = true;
|
||||
// def main(field a_s0_v0) {
|
||||
// field a_s0_v1 = 42;
|
||||
// bool a_s1_v0 = true
|
||||
// return;
|
||||
// }
|
||||
|
||||
|
@ -780,7 +746,11 @@ mod tests {
|
|||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::boolean("a").into(),
|
||||
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
|
||||
"a".into(),
|
||||
1,
|
||||
)))
|
||||
.into(),
|
||||
BooleanExpression::Value(true).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
|
@ -789,9 +759,7 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement]),
|
||||
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let expected: TypedFunction<Bn128Field> = TypedFunction {
|
||||
|
@ -802,7 +770,11 @@ mod tests {
|
|||
TypedExpression::Uint(42u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::boolean(Identifier::from("a").version(2)).into(),
|
||||
Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow(
|
||||
"a".into(),
|
||||
1,
|
||||
)))
|
||||
.into(),
|
||||
BooleanExpression::Value(true).into(),
|
||||
),
|
||||
TypedStatement::Return(
|
||||
|
@ -811,121 +783,17 @@ mod tests {
|
|||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement]),
|
||||
signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
let ssa = ShallowTransformer::default().fold_function(f);
|
||||
|
||||
let ssa =
|
||||
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
|
||||
|
||||
assert_eq!(ssa, Output::Complete(expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_scope() {
|
||||
// def main(field a) {
|
||||
// for u32 i in 0..1 {
|
||||
// a = a + 1
|
||||
// field a = 42
|
||||
// }
|
||||
// return a
|
||||
// }
|
||||
|
||||
// should become
|
||||
|
||||
// def main(field a_0) {
|
||||
// # versions: {a: 0}
|
||||
// for u32 i in 0..1 {
|
||||
// a_0 = a_0
|
||||
// field a_0 = 42
|
||||
// }
|
||||
// return a_1
|
||||
// }
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
0u32.into(),
|
||||
1u32.into(),
|
||||
vec![
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::identifier("a".into()).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::Number(42usize.into()).into(),
|
||||
),
|
||||
],
|
||||
),
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
|
||||
"a".into(),
|
||||
)
|
||||
.into()])
|
||||
.annotate(TupleType::new(vec![Type::FieldElement]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let expected: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
0u32.into(),
|
||||
1u32.into(),
|
||||
vec![
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::identifier(Identifier::from("a")).into(),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::field_element(Identifier::from("a")).into(),
|
||||
FieldElementExpression::Number(42usize.into()).into(),
|
||||
),
|
||||
],
|
||||
),
|
||||
TypedStatement::Return(
|
||||
TupleExpressionInner::Value(vec![FieldElementExpression::identifier(
|
||||
Identifier::from("a").version(1),
|
||||
)
|
||||
.into()])
|
||||
.annotate(TupleType::new(vec![Type::FieldElement]))
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
signature: DeclarationSignature::new()
|
||||
.generics(vec![])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
|
||||
let ssa =
|
||||
ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions);
|
||||
|
||||
assert_eq!(
|
||||
ssa,
|
||||
Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()])
|
||||
);
|
||||
assert_eq!(ssa, expected);
|
||||
}
|
||||
}
|
||||
|
||||
mod function_call {
|
||||
use super::*;
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
// test that function calls are left in
|
||||
#[test]
|
||||
fn treat_calls() {
|
||||
|
@ -939,17 +807,12 @@ mod tests {
|
|||
// return a;
|
||||
// }
|
||||
|
||||
// When called with K := 1, expected:
|
||||
// def main(field a_0) -> field {
|
||||
// K = 1;
|
||||
// u32 n_0 = 42;
|
||||
// n_1 = n_0;
|
||||
// a_1 = a_0;
|
||||
// a_2 = foo::<n_1>(a_1);
|
||||
// n_2 = n_1;
|
||||
// a_3 = a_2 * foo::<n_2>(a_2);
|
||||
// a_2 = foo::<42>(a_1);
|
||||
// a_3 = a_2 * foo::<42>(a_2);
|
||||
// return a_3;
|
||||
// } # versions: {n: 2, a: 3}
|
||||
// }
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
|
@ -1007,25 +870,9 @@ mod tests {
|
|||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut versions = Versions::default();
|
||||
|
||||
let ssa = ShallowTransformer::transform(
|
||||
f,
|
||||
&GGenericsAssignment(
|
||||
vec![(GenericIdentifier::with_name("K").with_index(0), 1)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
&mut versions,
|
||||
);
|
||||
|
||||
let expected = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
statements: vec![
|
||||
TypedStatement::definition(
|
||||
Variable::uint("K", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(1u32.into()),
|
||||
),
|
||||
TypedStatement::definition(
|
||||
Variable::uint("n", UBitwidth::B32).into(),
|
||||
TypedExpression::Uint(42u32.into()),
|
||||
|
@ -1089,14 +936,23 @@ mod tests {
|
|||
.output(DeclarationType::FieldElement),
|
||||
};
|
||||
|
||||
let mut ssa = ShallowTransformer::default();
|
||||
|
||||
let res = ssa.fold_function(f);
|
||||
|
||||
assert_eq!(
|
||||
versions,
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect::<Versions>()
|
||||
ssa.versions.map,
|
||||
vec![(
|
||||
0,
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect()
|
||||
)]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
assert_eq!(ssa, Output::Incomplete(expected, vec![],));
|
||||
assert_eq!(res, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ use crate::typed::types::*;
|
|||
use crate::typed::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::identifier::FrameIdentifier;
|
||||
|
||||
pub trait Fold<'ast, T: Field>: Sized {
|
||||
fn fold<F: Folder<'ast, T>>(self, f: &mut F) -> Self;
|
||||
}
|
||||
|
@ -128,11 +130,12 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
}
|
||||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
let id = match n.id {
|
||||
CoreIdentifier::Constant(c) => {
|
||||
CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c))
|
||||
}
|
||||
id => id,
|
||||
let id = match n.id.id.clone() {
|
||||
CoreIdentifier::Constant(c) => FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)),
|
||||
frame: 0,
|
||||
},
|
||||
_ => n.id,
|
||||
};
|
||||
|
||||
Identifier { id, ..n }
|
||||
|
@ -528,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> Vec<TypedAssemblyStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
f.fold_assignee(a),
|
||||
f.fold_expression(e),
|
||||
)]
|
||||
let e = f.fold_expression(e);
|
||||
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
|
@ -549,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> Vec<TypedStatement<'ast, T>> {
|
||||
let res = match s {
|
||||
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e))
|
||||
TypedStatement::Definition(a, rhs) => {
|
||||
let rhs = f.fold_definition_rhs(rhs);
|
||||
TypedStatement::Definition(f.fold_assignee(a), rhs)
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
TypedStatement::Assertion(f.fold_boolean_expression(e), error)
|
||||
|
@ -573,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
.flat_map(|s| f.fold_assembly_statement(s))
|
||||
.collect(),
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
vec![res]
|
||||
}
|
||||
|
|
|
@ -24,18 +24,49 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for CoreIdentifier<'ast> {
|
||||
fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> {
|
||||
CoreIdentifier::Constant(s)
|
||||
impl<'ast> FrameIdentifier<'ast> {
|
||||
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
|
||||
FrameIdentifier { frame, ..self }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> Identifier<'ast> {
|
||||
pub fn in_frame(self, frame: usize) -> Identifier<'ast> {
|
||||
Identifier {
|
||||
id: self.id.in_frame(frame),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> CoreIdentifier<'ast> {
|
||||
pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> {
|
||||
FrameIdentifier { id: self, frame }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for FrameIdentifier<'ast> {
|
||||
fn from(s: CanonicalConstantIdentifier<'ast>) -> FrameIdentifier<'ast> {
|
||||
FrameIdentifier::from(CoreIdentifier::Constant(s))
|
||||
}
|
||||
}
|
||||
|
||||
/// A identifier for a variable in a given call frame
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct FrameIdentifier<'ast> {
|
||||
/// the id of the variable
|
||||
#[serde(borrow)]
|
||||
pub id: CoreIdentifier<'ast>,
|
||||
/// the frame of the variable
|
||||
pub frame: usize,
|
||||
}
|
||||
|
||||
/// A identifier for a variable
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Identifier<'ast> {
|
||||
/// the id of the variable
|
||||
#[serde(borrow)]
|
||||
pub id: CoreIdentifier<'ast>,
|
||||
pub id: FrameIdentifier<'ast>,
|
||||
/// the version of the variable, used after SSA transformation
|
||||
pub version: usize,
|
||||
}
|
||||
|
@ -58,7 +89,7 @@ impl<'ast> fmt::Display for ShadowedIdentifier<'ast> {
|
|||
if self.shadow == 0 {
|
||||
write!(f, "{}", self.id)
|
||||
} else {
|
||||
write!(f, "{}_{}", self.id, self.shadow)
|
||||
write!(f, "{}_s{}", self.id, self.shadow)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -68,20 +99,45 @@ impl<'ast> fmt::Display for Identifier<'ast> {
|
|||
if self.version == 0 {
|
||||
write!(f, "{}", self.id)
|
||||
} else {
|
||||
write!(f, "{}_{}", self.id, self.version)
|
||||
write!(f, "{}_v{}", self.id, self.version)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for FrameIdentifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.frame == 0 {
|
||||
write!(f, "{}", self.id)
|
||||
} else {
|
||||
write!(f, "{}_f{}", self.id, self.frame)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for Identifier<'ast> {
|
||||
fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> {
|
||||
Identifier::from(CoreIdentifier::Constant(id))
|
||||
Identifier::from(FrameIdentifier::from(CoreIdentifier::Constant(id)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<FrameIdentifier<'ast>> for Identifier<'ast> {
|
||||
fn from(id: FrameIdentifier<'ast>) -> Identifier<'ast> {
|
||||
Identifier { id, version: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CoreIdentifier<'ast>> for Identifier<'ast> {
|
||||
fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
Identifier { id, version: 0 }
|
||||
Identifier {
|
||||
id: FrameIdentifier::from(id),
|
||||
version: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CoreIdentifier<'ast>> for FrameIdentifier<'ast> {
|
||||
fn from(id: CoreIdentifier<'ast>) -> FrameIdentifier<'ast> {
|
||||
FrameIdentifier { id, frame: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,6 +163,6 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> {
|
|||
|
||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||
Identifier::from(CoreIdentifier::from(id))
|
||||
Identifier::from(FrameIdentifier::from(CoreIdentifier::from(id)))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ pub use self::types::{
|
|||
UBitwidth,
|
||||
};
|
||||
use self::types::{ConcreteArrayType, ConcreteStructType};
|
||||
use crate::typed::types::{ConcreteGenericsAssignment, IntoType};
|
||||
use crate::typed::types::IntoType;
|
||||
|
||||
pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable};
|
||||
use std::marker::PhantomData;
|
||||
|
@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
writeln!(f)?;
|
||||
|
||||
let mut tab = 0;
|
||||
|
||||
for s in &self.statements {
|
||||
if let TypedStatement::PopCallLog = s {
|
||||
tab -= 1;
|
||||
};
|
||||
|
||||
s.fmt_indented(f, 1 + tab)?;
|
||||
writeln!(f)?;
|
||||
|
||||
if let TypedStatement::PushCallLog(..) = s {
|
||||
tab += 1;
|
||||
};
|
||||
writeln!(f, "{}", s)?;
|
||||
}
|
||||
|
||||
writeln!(f, "}}")?;
|
||||
|
@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> {
|
|||
Vec<TypedStatement<'ast, T>>,
|
||||
),
|
||||
Log(FormatString, Vec<TypedExpression<'ast, T>>),
|
||||
// Aux
|
||||
PushCallLog(
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
ConcreteGenericsAssignment<'ast>,
|
||||
),
|
||||
PopCallLog,
|
||||
Assembly(Vec<TypedAssemblyStatement<'ast, T>>),
|
||||
}
|
||||
|
||||
|
@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
|
||||
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||
match self {
|
||||
TypedStatement::For(variable, from, to, statements) => {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
writeln!(f, "for {} in {}..{} {{", variable, from, to)?;
|
||||
for s in statements {
|
||||
s.fmt_indented(f, depth + 1)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
TypedStatement::Assembly(statements) => {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
s => write!(f, "{}{}", "\t".repeat(depth), s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
|
@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
|
|||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
TypedStatement::PushCallLog(ref key, ref generics) => write!(
|
||||
f,
|
||||
"// PUSH CALL TO {}/{}::<{}>",
|
||||
key.module.display(),
|
||||
key.id,
|
||||
generics,
|
||||
),
|
||||
TypedStatement::PopCallLog => write!(f, "// POP CALL",),
|
||||
TypedStatement::Assembly(ref statements) => {
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in statements {
|
||||
|
|
|
@ -4,6 +4,8 @@ use crate::typed::types::*;
|
|||
use crate::typed::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::identifier::FrameIdentifier;
|
||||
|
||||
pub trait ResultFold<'ast, T: Field>: Sized {
|
||||
fn fold<F: ResultFolder<'ast, T>>(self, f: &mut F) -> Result<Self, F::Error>;
|
||||
}
|
||||
|
@ -156,11 +158,12 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
}
|
||||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
|
||||
let id = match n.id {
|
||||
CoreIdentifier::Constant(c) => {
|
||||
CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?)
|
||||
}
|
||||
id => id,
|
||||
let id = match n.id.id.clone() {
|
||||
CoreIdentifier::Constant(c) => FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?),
|
||||
frame: 0,
|
||||
},
|
||||
_ => n.id,
|
||||
};
|
||||
|
||||
Ok(Identifier { id, ..n })
|
||||
|
@ -529,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, F::Error> {
|
||||
Ok(match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
f.fold_assignee(a)?,
|
||||
f.fold_expression(e)?,
|
||||
)]
|
||||
let e = f.fold_expression(e)?;
|
||||
vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
|
@ -551,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let res = match s {
|
||||
TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?),
|
||||
TypedStatement::Definition(a, e) => {
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?)
|
||||
let rhs = f.fold_definition_rhs(e)?;
|
||||
TypedStatement::Definition(f.fold_assignee(a)?, rhs)
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
TypedStatement::Assertion(f.fold_boolean_expression(e)?, error)
|
||||
|
@ -583,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.flatten()
|
||||
.collect(),
|
||||
),
|
||||
s => s,
|
||||
};
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
|
|
@ -241,13 +241,14 @@ impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T>
|
|||
fn from(c: DeclarationConstant<'ast, T>) -> Self {
|
||||
match c {
|
||||
DeclarationConstant::Generic(g) => {
|
||||
UExpression::identifier(CoreIdentifier::from(g).into()).annotate(UBitwidth::B32)
|
||||
UExpression::identifier(Identifier::from(CoreIdentifier::from(g)))
|
||||
.annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Concrete(v) => {
|
||||
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Constant(v) => {
|
||||
UExpression::identifier(CoreIdentifier::from(v).into()).annotate(UBitwidth::B32)
|
||||
UExpression::identifier(FrameIdentifier::from(v).into()).annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Expression(e) => e.try_into().unwrap(),
|
||||
}
|
||||
|
@ -1144,8 +1145,7 @@ pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq<u32>>(
|
|||
|
||||
impl<'ast, T: Field> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
|
||||
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
|
||||
UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c)))
|
||||
.annotate(UBitwidth::B32)
|
||||
UExpression::identifier(Identifier::from(FrameIdentifier::from(c))).annotate(UBitwidth::B32)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1230,6 +1230,7 @@ pub use self::signature::{
|
|||
try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature,
|
||||
};
|
||||
|
||||
use super::identifier::FrameIdentifier;
|
||||
use super::{Id, ShadowedIdentifier};
|
||||
|
||||
pub mod signature {
|
||||
|
|
|
@ -1170,7 +1170,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let id = arg.id.value.id;
|
||||
let info = IdentifierInfo {
|
||||
id: decl_v.id.id.clone(),
|
||||
id: decl_v.id.id.id.clone(),
|
||||
ty,
|
||||
is_mutable,
|
||||
};
|
||||
|
|
15
zokrates_core_test/tests/tests/call_ssa.json
Normal file
15
zokrates_core_test/tests/tests/call_ssa.json
Normal file
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "4"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
11
zokrates_core_test/tests/tests/call_ssa.zok
Normal file
11
zokrates_core_test/tests/tests/call_ssa.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
// main should be x -> x + 4
|
||||
|
||||
def foo(field mut a) -> field {
|
||||
a = a + 1;
|
||||
return a + 1;
|
||||
}
|
||||
|
||||
def main(field mut a) -> field {
|
||||
a = foo(a + 1);
|
||||
return a + 1;
|
||||
}
|
Loading…
Reference in a new issue