reduce ssa gap to 1
This commit is contained in:
parent
d2e8b905c1
commit
d518d2d35e
2 changed files with 33 additions and 22 deletions
|
@ -412,6 +412,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
TypedStatement::For(v, from, to, statements) => {
|
||||
let versions_before = self.for_loop_versions.pop().unwrap();
|
||||
|
||||
println!("versions before {:#?}", versions_before);
|
||||
println!("versions {:#?}", self.versions);
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
|
||||
let mut out_statements = vec![];
|
||||
|
@ -422,13 +425,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
// add this set of versions to the substitution, pointing to the versions before the loop
|
||||
register(self.substitutions, self.versions, &versions_before);
|
||||
|
||||
// the versions after the loop are found by applying an offset of 2 to the versions before the loop
|
||||
// the versions after the loop are found by applying an offset of 1 to the versions before the loop
|
||||
let versions_after = versions_before
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v + 2))
|
||||
.map(|(k, v)| (k, v + 1))
|
||||
.collect();
|
||||
|
||||
println!("versions after {:#?}", versions_after);
|
||||
|
||||
println!("versions for the loop {:#?}", self.versions);
|
||||
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
|
||||
if to - from > MAX_FOR_LOOP_SIZE {
|
||||
|
@ -572,6 +579,8 @@ fn reduce_function<'ast, T: Field>(
|
|||
let mut hash = None;
|
||||
|
||||
loop {
|
||||
log::trace!("BEFORE REDUCE {}", f);
|
||||
|
||||
let mut reducer = Reducer::new(
|
||||
program,
|
||||
&mut versions,
|
||||
|
@ -591,6 +600,8 @@ fn reduce_function<'ast, T: Field>(
|
|||
..f
|
||||
};
|
||||
|
||||
log::trace!("AFTER REDUCE {}", new_f);
|
||||
|
||||
assert!(reducer.for_loop_versions.is_empty());
|
||||
|
||||
match reducer.complete {
|
||||
|
|
|
@ -51,10 +51,10 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
}
|
||||
}
|
||||
|
||||
// increase all versions by 2 and return the old versions
|
||||
// increase all versions by 1 and return the old versions
|
||||
fn create_version_gap(&mut self) -> Versions<'ast> {
|
||||
let ret = self.versions.clone();
|
||||
self.versions.values_mut().for_each(|v| *v += 2);
|
||||
self.versions.values_mut().for_each(|v| *v += 1);
|
||||
ret
|
||||
}
|
||||
|
||||
|
@ -598,18 +598,18 @@ mod tests {
|
|||
// u32 n_0 = 42
|
||||
// n_1 = n_0
|
||||
// a_1 = a_0
|
||||
// # versions: {n: 1, a: 1}
|
||||
// for u32 i_0 in n_0..n_0*n_0:
|
||||
// # versions: {n: 1, a: 1, K: 0}
|
||||
// for u32 i_0 in n_1..n_1*n_1:
|
||||
// a_0 = a_0
|
||||
// endfor
|
||||
// a_4 = a_3
|
||||
// # versions: {n: 3, a: 4}
|
||||
// for u32 i_0 in n_0..n_0*n_0:
|
||||
// a_3 = a_2
|
||||
// # versions: {n: 2, a: 3, K: 1}
|
||||
// for u32 i_0 in n_2..n_2*n_2:
|
||||
// a_0 = a_0
|
||||
// endfor
|
||||
// a_7 = a_6
|
||||
// return a_7
|
||||
// # versions: {n: 5, a: 7}
|
||||
// a_5 = a_4
|
||||
// return a_5
|
||||
// # versions: {n: 3, a: 5, K: 2}
|
||||
|
||||
let f: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![DeclarationVariable::field_element("a").into()],
|
||||
|
@ -715,16 +715,16 @@ mod tests {
|
|||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::field_element(Identifier::from("a").version(4)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(),
|
||||
Variable::field_element(Identifier::from("a").version(3)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(2)).into(),
|
||||
),
|
||||
TypedStatement::For(
|
||||
Variable::uint("i", UBitwidth::B32),
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(3))
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(2))
|
||||
.annotate(UBitwidth::B32),
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(3))
|
||||
UExpressionInner::Identifier(Identifier::from("n").version(2))
|
||||
.annotate(UBitwidth::B32)
|
||||
* UExpressionInner::Identifier(Identifier::from("n").version(3))
|
||||
* UExpressionInner::Identifier(Identifier::from("n").version(2))
|
||||
.annotate(UBitwidth::B32),
|
||||
vec![TypedStatement::Definition(
|
||||
Variable::field_element("a").into(),
|
||||
|
@ -732,11 +732,11 @@ mod tests {
|
|||
)],
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
Variable::field_element(Identifier::from("a").version(7)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(6)).into(),
|
||||
Variable::field_element(Identifier::from("a").version(5)).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from("a").version(4)).into(),
|
||||
),
|
||||
TypedStatement::Return(vec![FieldElementExpression::Identifier(
|
||||
Identifier::from("a").version(7),
|
||||
Identifier::from("a").version(5),
|
||||
)
|
||||
.into()]),
|
||||
],
|
||||
|
@ -750,7 +750,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
versions,
|
||||
vec![("n".into(), 5), ("a".into(), 7), ("K".into(), 4)]
|
||||
vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)]
|
||||
.into_iter()
|
||||
.collect::<Versions>()
|
||||
);
|
||||
|
@ -761,7 +761,7 @@ mod tests {
|
|||
vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)]
|
||||
.into_iter()
|
||||
.collect::<Versions>(),
|
||||
vec![("n".into(), 3), ("a".into(), 4), ("K".into(), 2)]
|
||||
vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)]
|
||||
.into_iter()
|
||||
.collect::<Versions>(),
|
||||
],
|
||||
|
|
Loading…
Reference in a new issue