1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

reduce ssa gap to 1

This commit is contained in:
schaeff 2021-11-16 15:58:37 +01:00
parent d2e8b905c1
commit d518d2d35e
2 changed files with 33 additions and 22 deletions

View file

@ -412,6 +412,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
TypedStatement::For(v, from, to, statements) => {
let versions_before = self.for_loop_versions.pop().unwrap();
println!("versions before {:#?}", versions_before);
println!("versions {:#?}", self.versions);
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
let mut out_statements = vec![];
@ -422,13 +425,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
// add this set of versions to the substitution, pointing to the versions before the loop
register(self.substitutions, self.versions, &versions_before);
// the versions after the loop are found by applying an offset of 2 to the versions before the loop
// the versions after the loop are found by applying an offset of 1 to the versions before the loop
let versions_after = versions_before
.clone()
.into_iter()
.map(|(k, v)| (k, v + 2))
.map(|(k, v)| (k, v + 1))
.collect();
println!("versions after {:#?}", versions_after);
println!("versions for the loop {:#?}", self.versions);
let mut transformer = ShallowTransformer::with_versions(self.versions);
if to - from > MAX_FOR_LOOP_SIZE {
@ -572,6 +579,8 @@ fn reduce_function<'ast, T: Field>(
let mut hash = None;
loop {
log::trace!("BEFORE REDUCE {}", f);
let mut reducer = Reducer::new(
program,
&mut versions,
@ -591,6 +600,8 @@ fn reduce_function<'ast, T: Field>(
..f
};
log::trace!("AFTER REDUCE {}", new_f);
assert!(reducer.for_loop_versions.is_empty());
match reducer.complete {

View file

@ -51,10 +51,10 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
}
}
// increase all versions by 2 and return the old versions
// increase all versions by 1 and return the old versions
fn create_version_gap(&mut self) -> Versions<'ast> {
let ret = self.versions.clone();
self.versions.values_mut().for_each(|v| *v += 2);
self.versions.values_mut().for_each(|v| *v += 1);
ret
}
@ -598,18 +598,18 @@ mod tests {
// u32 n_0 = 42
// n_1 = n_0
// a_1 = a_0
// # versions: {n: 1, a: 1}
// for u32 i_0 in n_0..n_0*n_0:
// # versions: {n: 1, a: 1, K: 0}
// for u32 i_0 in n_1..n_1*n_1:
// a_0 = a_0
// endfor
// a_4 = a_3
// # versions: {n: 3, a: 4}
// for u32 i_0 in n_0..n_0*n_0:
// a_3 = a_2
// # versions: {n: 2, a: 3, K: 1}
// for u32 i_0 in n_2..n_2*n_2:
// a_0 = a_0
// endfor
// a_7 = a_6
// return a_7
// # versions: {n: 5, a: 7}
// a_5 = a_4
// return a_5
// # versions: {n: 3, a: 5, K: 2}
let f: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![DeclarationVariable::field_element("a").into()],
@ -715,16 +715,16 @@ mod tests {
)],
),
TypedStatement::Definition(
Variable::field_element(Identifier::from("a").version(4)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
Variable::field_element(Identifier::from("a").version(3)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(2)).into(),
),
TypedStatement::For(
Variable::uint("i", UBitwidth::B32),
UExpressionInner::Identifier(Identifier::from("n").version(3))
UExpressionInner::Identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32),
UExpressionInner::Identifier(Identifier::from("n").version(3))
UExpressionInner::Identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32)
* UExpressionInner::Identifier(Identifier::from("n").version(3))
* UExpressionInner::Identifier(Identifier::from("n").version(2))
.annotate(UBitwidth::B32),
vec![TypedStatement::Definition(
Variable::field_element("a").into(),
@ -732,11 +732,11 @@ mod tests {
)],
),
TypedStatement::Definition(
Variable::field_element(Identifier::from("a").version(7)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(6)).into(),
Variable::field_element(Identifier::from("a").version(5)).into(),
FieldElementExpression::Identifier(Identifier::from("a").version(4)).into(),
),
TypedStatement::Return(vec![FieldElementExpression::Identifier(
Identifier::from("a").version(7),
Identifier::from("a").version(5),
)
.into()]),
],
@ -750,7 +750,7 @@ mod tests {
assert_eq!(
versions,
vec![("n".into(), 5), ("a".into(), 7), ("K".into(), 4)]
vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)]
.into_iter()
.collect::<Versions>()
);
@ -761,7 +761,7 @@ mod tests {
vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)]
.into_iter()
.collect::<Versions>(),
vec![("n".into(), 3), ("a".into(), 4), ("K".into(), 2)]
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)]
.into_iter()
.collect::<Versions>(),
],