1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

simplify uint optimizer with with_max helper, change strategy for if_else

This commit is contained in:
schaeff 2020-05-20 17:45:50 +02:00
parent 9e18975e84
commit c8111c77c2
4 changed files with 59 additions and 134 deletions

View file

@ -31,7 +31,7 @@ pub struct Flattener<'ast, T: Field> {
trait FlattenOutput<T: Field>: Sized {
// fn branches(self, other: Self) -> (Self, Self);
fn flat(&self) -> Vec<FlatExpression<T>>;
fn flat(&self) -> FlatExpression<T>;
}
impl<T: Field> FlattenOutput<T> for FlatExpression<T> {
@ -39,8 +39,8 @@ impl<T: Field> FlattenOutput<T> for FlatExpression<T> {
// (self, other)
// }
fn flat(&self) -> Vec<FlatExpression<T>> {
vec![self.clone()]
fn flat(&self) -> FlatExpression<T> {
self.clone()
}
}
@ -71,14 +71,8 @@ impl<T: Field> FlattenOutput<T> for FlatUExpression<T> {
// )
// }
fn flat(&self) -> Vec<FlatExpression<T>> {
self.bits
.clone()
.unwrap()
.clone()
.into_iter()
.chain(std::iter::once(self.field.clone().unwrap()))
.collect()
fn flat(&self) -> FlatExpression<T> {
self.clone().get_field_unchecked()
}
}
@ -241,71 +235,45 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let consequence = consequence.flat();
let alternative = alternative.flat();
let size = consequence.len();
let consequence_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
let consequence_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(
consequence
.into_iter()
.zip(consequence_ids.iter())
.map(|(c, c_id)| FlatStatement::Definition(*c_id, c)),
);
let alternative_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
let alternative_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(
alternative
.into_iter()
.zip(alternative_ids.iter())
.map(|(a, a_id)| FlatStatement::Definition(*a_id, a)),
);
let term0_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(consequence_ids.iter().zip(term0_ids.iter()).map(
|(c_id, t0_id)| {
FlatStatement::Definition(
*t0_id,
FlatExpression::Mult(
box condition_id.clone().into(),
box FlatExpression::from(*c_id),
),
)
},
let term0_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
term0_id,
FlatExpression::Mult(
box condition_id.clone().into(),
box FlatExpression::from(consequence_id),
),
));
let term1_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(alternative_ids.iter().zip(term1_ids.iter()).map(
|(a_id, t1_id)| {
FlatStatement::Definition(
*t1_id,
FlatExpression::Mult(
box FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
box FlatExpression::from(*a_id),
),
)
},
let term1_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
term1_id,
FlatExpression::Mult(
box FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
box FlatExpression::from(alternative_id),
),
));
let res: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
statements_flattened.extend(term0_ids.iter().zip(term1_ids).zip(res.iter()).map(
|((t0_id, t1_id), r_id)| {
FlatStatement::Definition(
*r_id,
FlatExpression::Add(
box FlatExpression::from(*t0_id),
box FlatExpression::from(t1_id),
),
)
},
let res = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
res,
FlatExpression::Add(
box FlatExpression::from(term0_id),
box FlatExpression::from(term1_id),
),
));
let mut res: Vec<_> = res.into_iter().map(|r| r.into()).collect();
FlatUExpression {
field: Some(res.pop().unwrap()),
bits: Some(res),
field: Some(FlatExpression::Identifier(res)),
bits: None,
}
}

View file

@ -63,10 +63,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
use self::UExpressionInner::*;
let res = match inner {
Value(v) => Value(v).annotate(range).metadata(UMetadata {
max: v.into(),
should_reduce: Some(false),
}),
Value(v) => Value(v).annotate(range).with_max(v),
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
self.ids
.get(&Variable::uint(id.clone(), range))
@ -108,11 +105,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
right
};
UExpression::add(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
UExpression::add(left, right).with_max(max)
}
Sub(box left, box right) => {
// let `target` the target bitwidth of `left` and `right`
@ -171,40 +164,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
right
};
UExpression::sub(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
UExpression::sub(left, right).with_max(max)
}
Xor(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
UExpression::xor(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
UExpression::xor(force_reduce(left), force_reduce(right)).with_max(range_max)
}
And(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
UExpression::and(force_reduce(left), force_reduce(right)).with_max(range_max)
}
Or(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
UExpression::or(force_reduce(left), force_reduce(right)).with_max(range_max)
}
Mult(box left, box right) => {
// reduce the two terms
@ -241,40 +222,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
right
};
UExpression::mult(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
UExpression::mult(left, right).with_max(max)
}
Not(box e) => {
let e = self.fold_uint_expression(e);
UExpressionInner::Not(box force_reduce(e))
.annotate(range)
.metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
.with_max(range_max)
}
LeftShift(box e, box by) => {
// reduce the two terms
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
UExpression::left_shift(force_reduce(e), by).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(true),
})
UExpression::left_shift(force_reduce(e), by).with_max(range_max)
}
RightShift(box e, box by) => {
// reduce the two terms
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
UExpression::right_shift(force_reduce(e), by).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
UExpression::right_shift(force_reduce(e), by).with_max(range_max)
}
IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_uint_expression(consequence);
@ -288,10 +257,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
alternative_max.into_big_uint(),
);
UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
max: max.into(),
should_reduce: Some(false),
})
UExpression::if_else(condition, consequence, alternative).with_max(max)
}
};
@ -323,13 +289,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
ZirExpression::Uint(e) => {
let e = self.fold_uint_expression(e);
let e = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..e.metadata.unwrap()
}),
..e
};
let e = force_reduce(e);
ZirExpression::Uint(e)
}
@ -370,19 +330,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
ZirStatement::Condition(lhs, rhs) => {
match (self.fold_expression(lhs), self.fold_expression(rhs)) {
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
let lhs_metadata = lhs.metadata.clone().unwrap();
let rhs_metadata = rhs.metadata.clone().unwrap();
vec![ZirStatement::Condition(
lhs.metadata(UMetadata {
should_reduce: Some(true),
..lhs_metadata
})
.into(),
rhs.metadata(UMetadata {
should_reduce: Some(true),
..rhs_metadata
})
.into(),
force_reduce(lhs).into(),
force_reduce(rhs).into(),
)]
}
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)],

View file

@ -126,13 +126,20 @@ impl<'ast, T> UExpressionInner<'ast, T> {
}
}
impl<'ast, T> UExpression<'ast, T> {
impl<'ast, T: Field> UExpression<'ast, T> {
pub fn metadata(self, metadata: UMetadata<T>) -> UExpression<'ast, T> {
UExpression {
metadata: Some(metadata),
..self
}
}
pub fn with_max<U: Into<T>>(self, max: U) -> Self {
UExpression {
metadata: Some(UMetadata::with_max(max)),
..self
}
}
}
impl<'ast, T> UExpression<'ast, T> {

View file

@ -8,7 +8,7 @@
},
"output": {
"Ok": {
"values": ["0xff"]
"values": ["0x00"]
}
}
},
@ -18,7 +18,7 @@
},
"output": {
"Ok": {
"values": ["0x00"]
"values": ["0xff"]
}
}
}