From c8111c77c2c8107a95d2a2ae77f1f4e9c0557255 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 20 May 2020 17:45:50 +0200 Subject: [PATCH] simplify uint optimizer with with_max helper, change strategy for if_else --- zokrates_core/src/flatten/mod.rs | 102 ++++++------------ .../src/static_analysis/uint_optimizer.rs | 78 +++----------- zokrates_core/src/zir/uint.rs | 9 +- .../tests/tests/uint/if_else.json | 4 +- 4 files changed, 59 insertions(+), 134 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 2cc95e32..880ed096 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -31,7 +31,7 @@ pub struct Flattener<'ast, T: Field> { trait FlattenOutput: Sized { // fn branches(self, other: Self) -> (Self, Self); - fn flat(&self) -> Vec>; + fn flat(&self) -> FlatExpression; } impl FlattenOutput for FlatExpression { @@ -39,8 +39,8 @@ impl FlattenOutput for FlatExpression { // (self, other) // } - fn flat(&self) -> Vec> { - vec![self.clone()] + fn flat(&self) -> FlatExpression { + self.clone() } } @@ -71,14 +71,8 @@ impl FlattenOutput for FlatUExpression { // ) // } - fn flat(&self) -> Vec> { - self.bits - .clone() - .unwrap() - .clone() - .into_iter() - .chain(std::iter::once(self.field.clone().unwrap())) - .collect() + fn flat(&self) -> FlatExpression { + self.clone().get_field_unchecked() } } @@ -241,71 +235,45 @@ impl<'ast, T: Field> Flattener<'ast, T> { let consequence = consequence.flat(); let alternative = alternative.flat(); - let size = consequence.len(); + let consequence_id = self.use_sym(); + statements_flattened.push(FlatStatement::Definition(consequence_id, consequence)); - let consequence_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect(); - statements_flattened.extend( - consequence - .into_iter() - .zip(consequence_ids.iter()) - .map(|(c, c_id)| FlatStatement::Definition(*c_id, c)), - ); + let alternative_id = self.use_sym(); + statements_flattened.push(FlatStatement::Definition(alternative_id, alternative)); - let alternative_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect(); - statements_flattened.extend( - alternative - .into_iter() - .zip(alternative_ids.iter()) - .map(|(a, a_id)| FlatStatement::Definition(*a_id, a)), - ); - - let term0_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect(); - statements_flattened.extend(consequence_ids.iter().zip(term0_ids.iter()).map( - |(c_id, t0_id)| { - FlatStatement::Definition( - *t0_id, - FlatExpression::Mult( - box condition_id.clone().into(), - box FlatExpression::from(*c_id), - ), - ) - }, + let term0_id = self.use_sym(); + statements_flattened.push(FlatStatement::Definition( + term0_id, + FlatExpression::Mult( + box condition_id.clone().into(), + box FlatExpression::from(consequence_id), + ), )); - let term1_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect(); - statements_flattened.extend(alternative_ids.iter().zip(term1_ids.iter()).map( - |(a_id, t1_id)| { - FlatStatement::Definition( - *t1_id, - FlatExpression::Mult( - box FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), - box FlatExpression::from(*a_id), - ), - ) - }, + let term1_id = self.use_sym(); + statements_flattened.push(FlatStatement::Definition( + term1_id, + FlatExpression::Mult( + box FlatExpression::Sub( + box FlatExpression::Number(T::one()), + box condition_id.into(), + ), + box FlatExpression::from(alternative_id), + ), )); - let res: Vec<_> = (0..size).map(|_| self.use_sym()).collect(); - statements_flattened.extend(term0_ids.iter().zip(term1_ids).zip(res.iter()).map( - |((t0_id, t1_id), r_id)| { - FlatStatement::Definition( - *r_id, - FlatExpression::Add( - box FlatExpression::from(*t0_id), - box FlatExpression::from(t1_id), - ), - ) - }, + let res = self.use_sym(); + statements_flattened.push(FlatStatement::Definition( + res, + FlatExpression::Add( + box FlatExpression::from(term0_id), + box FlatExpression::from(term1_id), + ), )); - let mut res: Vec<_> = res.into_iter().map(|r| r.into()).collect(); - FlatUExpression { - field: Some(res.pop().unwrap()), - bits: Some(res), + field: Some(FlatExpression::Identifier(res)), + bits: None, } } diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index 08d17132..128495a6 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -63,10 +63,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { use self::UExpressionInner::*; let res = match inner { - Value(v) => Value(v).annotate(range).metadata(UMetadata { - max: v.into(), - should_reduce: Some(false), - }), + Value(v) => Value(v).annotate(range).with_max(v), Identifier(id) => Identifier(id.clone()).annotate(range).metadata( self.ids .get(&Variable::uint(id.clone(), range)) @@ -108,11 +105,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { right }; - UExpression::add(left, right).metadata(UMetadata { - max, - - should_reduce: Some(false), - }) + UExpression::add(left, right).with_max(max) } Sub(box left, box right) => { // let `target` the target bitwidth of `left` and `right` @@ -171,40 +164,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { right }; - UExpression::sub(left, right).metadata(UMetadata { - max, - should_reduce: Some(false), - }) + UExpression::sub(left, right).with_max(max) } Xor(box left, box right) => { // reduce the two terms let left = self.fold_uint_expression(left); let right = self.fold_uint_expression(right); - UExpression::xor(force_reduce(left), force_reduce(right)).metadata(UMetadata { - max: range_max.clone(), - should_reduce: Some(false), - }) + UExpression::xor(force_reduce(left), force_reduce(right)).with_max(range_max) } And(box left, box right) => { // reduce the two terms let left = self.fold_uint_expression(left); let right = self.fold_uint_expression(right); - UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata { - max: range_max.clone(), - should_reduce: Some(false), - }) + UExpression::and(force_reduce(left), force_reduce(right)).with_max(range_max) } Or(box left, box right) => { // reduce the two terms let left = self.fold_uint_expression(left); let right = self.fold_uint_expression(right); - UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata { - max: range_max.clone(), - should_reduce: Some(false), - }) + UExpression::or(force_reduce(left), force_reduce(right)).with_max(range_max) } Mult(box left, box right) => { // reduce the two terms @@ -241,40 +222,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { right }; - UExpression::mult(left, right).metadata(UMetadata { - max, - should_reduce: Some(false), - }) + UExpression::mult(left, right).with_max(max) } Not(box e) => { let e = self.fold_uint_expression(e); UExpressionInner::Not(box force_reduce(e)) .annotate(range) - .metadata(UMetadata { - max: range_max.clone(), - should_reduce: Some(false), - }) + .with_max(range_max) } LeftShift(box e, box by) => { // reduce the two terms let e = self.fold_uint_expression(e); let by = self.fold_field_expression(by); - UExpression::left_shift(force_reduce(e), by).metadata(UMetadata { - max: range_max.clone(), - should_reduce: Some(true), - }) + UExpression::left_shift(force_reduce(e), by).with_max(range_max) } RightShift(box e, box by) => { // reduce the two terms let e = self.fold_uint_expression(e); let by = self.fold_field_expression(by); - UExpression::right_shift(force_reduce(e), by).metadata(UMetadata { - max: range_max.clone(), - should_reduce: Some(false), - }) + UExpression::right_shift(force_reduce(e), by).with_max(range_max) } IfElse(box condition, box consequence, box alternative) => { let consequence = self.fold_uint_expression(consequence); @@ -288,10 +257,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { alternative_max.into_big_uint(), ); - UExpression::if_else(condition, consequence, alternative).metadata(UMetadata { - max: max.into(), - should_reduce: Some(false), - }) + UExpression::if_else(condition, consequence, alternative).with_max(max) } }; @@ -323,13 +289,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { ZirExpression::Uint(e) => { let e = self.fold_uint_expression(e); - let e = UExpression { - metadata: Some(UMetadata { - should_reduce: Some(true), - ..e.metadata.unwrap() - }), - ..e - }; + let e = force_reduce(e); ZirExpression::Uint(e) } @@ -370,19 +330,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { ZirStatement::Condition(lhs, rhs) => { match (self.fold_expression(lhs), self.fold_expression(rhs)) { (ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => { - let lhs_metadata = lhs.metadata.clone().unwrap(); - let rhs_metadata = rhs.metadata.clone().unwrap(); vec![ZirStatement::Condition( - lhs.metadata(UMetadata { - should_reduce: Some(true), - ..lhs_metadata - }) - .into(), - rhs.metadata(UMetadata { - should_reduce: Some(true), - ..rhs_metadata - }) - .into(), + force_reduce(lhs).into(), + force_reduce(rhs).into(), )] } (lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)], diff --git a/zokrates_core/src/zir/uint.rs b/zokrates_core/src/zir/uint.rs index 5c903edb..bfd4ccc9 100644 --- a/zokrates_core/src/zir/uint.rs +++ b/zokrates_core/src/zir/uint.rs @@ -126,13 +126,20 @@ impl<'ast, T> UExpressionInner<'ast, T> { } } -impl<'ast, T> UExpression<'ast, T> { +impl<'ast, T: Field> UExpression<'ast, T> { pub fn metadata(self, metadata: UMetadata) -> UExpression<'ast, T> { UExpression { metadata: Some(metadata), ..self } } + + pub fn with_max>(self, max: U) -> Self { + UExpression { + metadata: Some(UMetadata::with_max(max)), + ..self + } + } } impl<'ast, T> UExpression<'ast, T> { diff --git a/zokrates_core_test/tests/tests/uint/if_else.json b/zokrates_core_test/tests/tests/uint/if_else.json index 222f79c0..d4e7af04 100644 --- a/zokrates_core_test/tests/tests/uint/if_else.json +++ b/zokrates_core_test/tests/tests/uint/if_else.json @@ -8,7 +8,7 @@ }, "output": { "Ok": { - "values": ["0xff"] + "values": ["0x00"] } } }, @@ -18,7 +18,7 @@ }, "output": { "Ok": { - "values": ["0x00"] + "values": ["0xff"] } } }