simplify uint optimizer with with_max helper, change strategy for if_else
This commit is contained in:
parent
9e18975e84
commit
c8111c77c2
4 changed files with 59 additions and 134 deletions
|
@ -31,7 +31,7 @@ pub struct Flattener<'ast, T: Field> {
|
|||
trait FlattenOutput<T: Field>: Sized {
|
||||
// fn branches(self, other: Self) -> (Self, Self);
|
||||
|
||||
fn flat(&self) -> Vec<FlatExpression<T>>;
|
||||
fn flat(&self) -> FlatExpression<T>;
|
||||
}
|
||||
|
||||
impl<T: Field> FlattenOutput<T> for FlatExpression<T> {
|
||||
|
@ -39,8 +39,8 @@ impl<T: Field> FlattenOutput<T> for FlatExpression<T> {
|
|||
// (self, other)
|
||||
// }
|
||||
|
||||
fn flat(&self) -> Vec<FlatExpression<T>> {
|
||||
vec![self.clone()]
|
||||
fn flat(&self) -> FlatExpression<T> {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,14 +71,8 @@ impl<T: Field> FlattenOutput<T> for FlatUExpression<T> {
|
|||
// )
|
||||
// }
|
||||
|
||||
fn flat(&self) -> Vec<FlatExpression<T>> {
|
||||
self.bits
|
||||
.clone()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.into_iter()
|
||||
.chain(std::iter::once(self.field.clone().unwrap()))
|
||||
.collect()
|
||||
fn flat(&self) -> FlatExpression<T> {
|
||||
self.clone().get_field_unchecked()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,71 +235,45 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let consequence = consequence.flat();
|
||||
let alternative = alternative.flat();
|
||||
|
||||
let size = consequence.len();
|
||||
let consequence_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
|
||||
|
||||
let consequence_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(
|
||||
consequence
|
||||
.into_iter()
|
||||
.zip(consequence_ids.iter())
|
||||
.map(|(c, c_id)| FlatStatement::Definition(*c_id, c)),
|
||||
);
|
||||
let alternative_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
|
||||
|
||||
let alternative_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(
|
||||
alternative
|
||||
.into_iter()
|
||||
.zip(alternative_ids.iter())
|
||||
.map(|(a, a_id)| FlatStatement::Definition(*a_id, a)),
|
||||
);
|
||||
|
||||
let term0_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(consequence_ids.iter().zip(term0_ids.iter()).map(
|
||||
|(c_id, t0_id)| {
|
||||
FlatStatement::Definition(
|
||||
*t0_id,
|
||||
FlatExpression::Mult(
|
||||
box condition_id.clone().into(),
|
||||
box FlatExpression::from(*c_id),
|
||||
),
|
||||
)
|
||||
},
|
||||
let term0_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
term0_id,
|
||||
FlatExpression::Mult(
|
||||
box condition_id.clone().into(),
|
||||
box FlatExpression::from(consequence_id),
|
||||
),
|
||||
));
|
||||
|
||||
let term1_ids: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(alternative_ids.iter().zip(term1_ids.iter()).map(
|
||||
|(a_id, t1_id)| {
|
||||
FlatStatement::Definition(
|
||||
*t1_id,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
box FlatExpression::from(*a_id),
|
||||
),
|
||||
)
|
||||
},
|
||||
let term1_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
term1_id,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
box FlatExpression::from(alternative_id),
|
||||
),
|
||||
));
|
||||
|
||||
let res: Vec<_> = (0..size).map(|_| self.use_sym()).collect();
|
||||
statements_flattened.extend(term0_ids.iter().zip(term1_ids).zip(res.iter()).map(
|
||||
|((t0_id, t1_id), r_id)| {
|
||||
FlatStatement::Definition(
|
||||
*r_id,
|
||||
FlatExpression::Add(
|
||||
box FlatExpression::from(*t0_id),
|
||||
box FlatExpression::from(t1_id),
|
||||
),
|
||||
)
|
||||
},
|
||||
let res = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
res,
|
||||
FlatExpression::Add(
|
||||
box FlatExpression::from(term0_id),
|
||||
box FlatExpression::from(term1_id),
|
||||
),
|
||||
));
|
||||
|
||||
let mut res: Vec<_> = res.into_iter().map(|r| r.into()).collect();
|
||||
|
||||
FlatUExpression {
|
||||
field: Some(res.pop().unwrap()),
|
||||
bits: Some(res),
|
||||
field: Some(FlatExpression::Identifier(res)),
|
||||
bits: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -63,10 +63,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
use self::UExpressionInner::*;
|
||||
|
||||
let res = match inner {
|
||||
Value(v) => Value(v).annotate(range).metadata(UMetadata {
|
||||
max: v.into(),
|
||||
should_reduce: Some(false),
|
||||
}),
|
||||
Value(v) => Value(v).annotate(range).with_max(v),
|
||||
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
|
||||
self.ids
|
||||
.get(&Variable::uint(id.clone(), range))
|
||||
|
@ -108,11 +105,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max,
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::add(left, right).with_max(max)
|
||||
}
|
||||
Sub(box left, box right) => {
|
||||
// let `target` the target bitwidth of `left` and `right`
|
||||
|
@ -171,40 +164,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
right
|
||||
};
|
||||
|
||||
UExpression::sub(left, right).metadata(UMetadata {
|
||||
max,
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::sub(left, right).with_max(max)
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
UExpression::xor(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::xor(force_reduce(left), force_reduce(right)).with_max(range_max)
|
||||
}
|
||||
And(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::and(force_reduce(left), force_reduce(right)).with_max(range_max)
|
||||
}
|
||||
Or(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::or(force_reduce(left), force_reduce(right)).with_max(range_max)
|
||||
}
|
||||
Mult(box left, box right) => {
|
||||
// reduce the two terms
|
||||
|
@ -241,40 +222,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
max,
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::mult(left, right).with_max(max)
|
||||
}
|
||||
Not(box e) => {
|
||||
let e = self.fold_uint_expression(e);
|
||||
|
||||
UExpressionInner::Not(box force_reduce(e))
|
||||
.annotate(range)
|
||||
.metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
.with_max(range_max)
|
||||
}
|
||||
LeftShift(box e, box by) => {
|
||||
// reduce the two terms
|
||||
let e = self.fold_uint_expression(e);
|
||||
let by = self.fold_field_expression(by);
|
||||
|
||||
UExpression::left_shift(force_reduce(e), by).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(true),
|
||||
})
|
||||
UExpression::left_shift(force_reduce(e), by).with_max(range_max)
|
||||
}
|
||||
RightShift(box e, box by) => {
|
||||
// reduce the two terms
|
||||
let e = self.fold_uint_expression(e);
|
||||
let by = self.fold_field_expression(by);
|
||||
|
||||
UExpression::right_shift(force_reduce(e), by).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::right_shift(force_reduce(e), by).with_max(range_max)
|
||||
}
|
||||
IfElse(box condition, box consequence, box alternative) => {
|
||||
let consequence = self.fold_uint_expression(consequence);
|
||||
|
@ -288,10 +257,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
alternative_max.into_big_uint(),
|
||||
);
|
||||
|
||||
UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
|
||||
max: max.into(),
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
UExpression::if_else(condition, consequence, alternative).with_max(max)
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -323,13 +289,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
ZirExpression::Uint(e) => {
|
||||
let e = self.fold_uint_expression(e);
|
||||
|
||||
let e = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..e.metadata.unwrap()
|
||||
}),
|
||||
..e
|
||||
};
|
||||
let e = force_reduce(e);
|
||||
|
||||
ZirExpression::Uint(e)
|
||||
}
|
||||
|
@ -370,19 +330,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
ZirStatement::Condition(lhs, rhs) => {
|
||||
match (self.fold_expression(lhs), self.fold_expression(rhs)) {
|
||||
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
|
||||
let lhs_metadata = lhs.metadata.clone().unwrap();
|
||||
let rhs_metadata = rhs.metadata.clone().unwrap();
|
||||
vec![ZirStatement::Condition(
|
||||
lhs.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..lhs_metadata
|
||||
})
|
||||
.into(),
|
||||
rhs.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..rhs_metadata
|
||||
})
|
||||
.into(),
|
||||
force_reduce(lhs).into(),
|
||||
force_reduce(rhs).into(),
|
||||
)]
|
||||
}
|
||||
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)],
|
||||
|
|
|
@ -126,13 +126,20 @@ impl<'ast, T> UExpressionInner<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> UExpression<'ast, T> {
|
||||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
pub fn metadata(self, metadata: UMetadata<T>) -> UExpression<'ast, T> {
|
||||
UExpression {
|
||||
metadata: Some(metadata),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_max<U: Into<T>>(self, max: U) -> Self {
|
||||
UExpression {
|
||||
metadata: Some(UMetadata::with_max(max)),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> UExpression<'ast, T> {
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0xff"]
|
||||
"values": ["0x00"]
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -18,7 +18,7 @@
|
|||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0x00"]
|
||||
"values": ["0xff"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue