implement sub, clean
This commit is contained in:
parent
2be4859b28
commit
5a145921d4
9 changed files with 224 additions and 173 deletions
4
test.zok
4
test.zok
|
@ -1,2 +1,2 @@
|
|||
def main(u8 a) -> (u8):
|
||||
return a + a
|
||||
def main(u8 a, u8 b, u8 c) -> (u8):
|
||||
return a - b - c
|
|
@ -152,8 +152,6 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
|
|||
// analyse (unroll and constant propagation)
|
||||
let typed_ast = typed_ast.analyse();
|
||||
|
||||
println!("{:#?}", typed_ast);
|
||||
|
||||
// flatten input program
|
||||
let program_flattened = Flattener::flatten(typed_ast);
|
||||
|
||||
|
|
|
@ -9,16 +9,11 @@ use crate::flat_absy::*;
|
|||
use crate::solvers::Solver;
|
||||
use crate::zir::types::{FunctionIdentifier, FunctionKey, Signature, Type};
|
||||
use crate::zir::*;
|
||||
use num_bigint::BigUint;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
fn log2(a: BigUint) -> u32 {
|
||||
a.bits() as u32
|
||||
}
|
||||
|
||||
/// Flattener, computes flattened program.
|
||||
#[derive(Debug)]
|
||||
pub struct Flattener<'ast, T: Field> {
|
||||
|
@ -675,7 +670,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
ZirExpression::Uint(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let from = log2(p.metadata.clone().unwrap().max);
|
||||
let from = p.metadata.clone().unwrap().bitwidth();
|
||||
let p = self.flatten_uint_expression(symbols, statements_flattened, p);
|
||||
let bits = self
|
||||
.get_bits(p, from as usize, 32, statements_flattened)
|
||||
|
@ -827,7 +822,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let metadata = expr.metadata.clone().unwrap().clone();
|
||||
|
||||
let actual_bitwidth = log2(metadata.max) as usize;
|
||||
let actual_bitwidth = metadata.bitwidth() as usize;
|
||||
let should_reduce = metadata.should_reduce.unwrap();
|
||||
|
||||
let should_reduce = should_reduce && actual_bitwidth > target_bitwidth;
|
||||
|
@ -845,7 +840,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatUExpression::with_field(field).bits(bits)
|
||||
}
|
||||
UExpressionInner::Not(box e) => {
|
||||
let from = log2(e.metadata.clone().unwrap().max);
|
||||
let from = e.metadata.clone().unwrap().bitwidth();
|
||||
|
||||
let e_flattened = self.flatten_uint_expression(symbols, statements_flattened, e);
|
||||
|
||||
|
@ -904,7 +899,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
UExpressionInner::Sub(box left, box right) => {
|
||||
let aux = FlatExpression::Number(
|
||||
T::from(2).pow(log2(right.metadata.clone().unwrap().max) as usize),
|
||||
T::from(2).pow(right.metadata.clone().unwrap().bitwidth() as usize),
|
||||
);
|
||||
|
||||
let left_flattened = self
|
||||
|
@ -934,7 +929,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
))
|
||||
}
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let from = log2(e.metadata.clone().unwrap().max);
|
||||
let from = e.metadata.clone().unwrap().bitwidth();
|
||||
|
||||
let e = self.flatten_uint_expression(symbols, statements_flattened, e);
|
||||
|
||||
|
@ -961,7 +956,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let from = log2(e.metadata.clone().unwrap().max);
|
||||
let from = e.metadata.clone().unwrap().bitwidth();
|
||||
|
||||
let e = self.flatten_uint_expression(symbols, statements_flattened, e);
|
||||
|
||||
|
@ -1028,8 +1023,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
.clone(),
|
||||
UExpressionInner::Xor(box left, box right) => {
|
||||
let left_from = log2(left.metadata.clone().unwrap().max);
|
||||
let right_from = log2(right.metadata.clone().unwrap().max);
|
||||
let left_from = left.metadata.clone().unwrap().bitwidth();
|
||||
let right_from = right.metadata.clone().unwrap().bitwidth();
|
||||
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
|
@ -1061,13 +1056,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.flat_map(|(name, (x, y))| match (x, y) {
|
||||
(FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => {
|
||||
if *n == T::from(0) {
|
||||
vec![FlatStatement::Definition(name.clone(), y.clone())]
|
||||
vec![FlatStatement::Definition(name.clone(), e.clone())]
|
||||
} else if *n == T::from(1) {
|
||||
vec![FlatStatement::Definition(
|
||||
name.clone(),
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::from(1)),
|
||||
box y.clone(),
|
||||
box e.clone(),
|
||||
),
|
||||
)]
|
||||
} else {
|
||||
|
@ -1105,8 +1100,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
}
|
||||
UExpressionInner::And(box left, box right) => {
|
||||
let left_from = log2(left.metadata.clone().unwrap().max);
|
||||
let right_from = log2(right.metadata.clone().unwrap().max);
|
||||
let left_from = left.metadata.clone().unwrap().bitwidth();
|
||||
let right_from = right.metadata.clone().unwrap().bitwidth();
|
||||
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
|
@ -1186,10 +1181,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
self.depth -= 1;
|
||||
|
||||
statements_flattened.push(FlatStatement::Log(format!(
|
||||
" {} DONE",
|
||||
" ".repeat(self.depth)
|
||||
)));
|
||||
// statements_flattened.push(FlatStatement::Log(format!(
|
||||
// " {} DONE",
|
||||
// " ".repeat(self.depth)
|
||||
// )));
|
||||
|
||||
res
|
||||
}
|
||||
|
|
|
@ -18,25 +18,22 @@ impl<T: Field> Prog<T> {
|
|||
|
||||
for statement in main.statements.iter() {
|
||||
match statement {
|
||||
Statement::Constraint(quad, lin) => {
|
||||
println!("{}", statement);
|
||||
match lin.is_assignee(&witness) {
|
||||
true => {
|
||||
let val = quad.evaluate(&witness).unwrap();
|
||||
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
|
||||
}
|
||||
false => {
|
||||
let lhs_value = quad.evaluate(&witness).unwrap();
|
||||
let rhs_value = lin.evaluate(&witness).unwrap();
|
||||
if lhs_value != rhs_value {
|
||||
return Err(Error::UnsatisfiedConstraint {
|
||||
left: lhs_value.to_dec_string(),
|
||||
right: rhs_value.to_dec_string(),
|
||||
});
|
||||
}
|
||||
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
|
||||
true => {
|
||||
let val = quad.evaluate(&witness).unwrap();
|
||||
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
|
||||
}
|
||||
false => {
|
||||
let lhs_value = quad.evaluate(&witness).unwrap();
|
||||
let rhs_value = lin.evaluate(&witness).unwrap();
|
||||
if lhs_value != rhs_value {
|
||||
return Err(Error::UnsatisfiedConstraint {
|
||||
left: lhs_value.to_dec_string(),
|
||||
right: rhs_value.to_dec_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Statement::Directive(ref d) => {
|
||||
let input_values: Vec<T> = d
|
||||
.inputs
|
||||
|
|
|
@ -383,8 +383,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
|
||||
UExpressionInner::Value(v1 & v2)
|
||||
}
|
||||
(UExpressionInner::Value(0), e2) => UExpressionInner::Value(0),
|
||||
(e1, UExpressionInner::Value(0)) => UExpressionInner::Value(0),
|
||||
(UExpressionInner::Value(0), _) | (_, UExpressionInner::Value(0)) => {
|
||||
UExpressionInner::Value(0)
|
||||
}
|
||||
(e1, e2) => {
|
||||
UExpressionInner::And(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ use zokrates_field::field::Field;
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct UintOptimizer<'ast, T: Field> {
|
||||
ids: HashMap<ZirAssignee<'ast>, UMetadata>,
|
||||
ids: HashMap<ZirAssignee<'ast>, UMetadata<T>>,
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,7 @@ impl<'ast, T: Field> UintOptimizer<'ast, T> {
|
|||
UintOptimizer::new().fold_program(p)
|
||||
}
|
||||
|
||||
fn register(&mut self, a: ZirAssignee<'ast>, m: UMetadata) {
|
||||
fn register(&mut self, a: ZirAssignee<'ast>, m: UMetadata<T>) {
|
||||
self.ids.insert(a, m);
|
||||
}
|
||||
}
|
||||
|
@ -51,19 +51,18 @@ fn force_no_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast,
|
|||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
||||
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
if e.metadata.is_some() {
|
||||
return e;
|
||||
}
|
||||
|
||||
let max_bitwidth = T::get_required_bits() - 1;
|
||||
|
||||
let range = e.bitwidth;
|
||||
|
||||
let range_max: BigUint = (2_usize.pow(range as u32) - 1).into();
|
||||
let range_max: T = (2_usize.pow(range as u32) - 1).into();
|
||||
|
||||
assert!(range < max_bitwidth / 2);
|
||||
|
||||
if e.metadata.is_some() {
|
||||
unreachable!("{:?} had metadata", e);
|
||||
}
|
||||
|
||||
let metadata = e.metadata;
|
||||
let inner = e.inner;
|
||||
|
||||
use self::UExpressionInner::*;
|
||||
|
@ -71,7 +70,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
let res = match inner {
|
||||
Value(v) => Value(v).annotate(range).metadata(UMetadata {
|
||||
max: v.into(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
}),
|
||||
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
|
||||
|
@ -81,8 +79,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.expect(&format!("identifier should have been defined: {}", id)),
|
||||
),
|
||||
Add(box left, box right) => {
|
||||
use num::CheckedAdd;
|
||||
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
@ -119,113 +115,81 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max,
|
||||
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
Sub(box left, box right) => {
|
||||
unimplemented!()
|
||||
// // reduce the two terms
|
||||
// let left = self.fold_uint_expression(left);
|
||||
// let right = self.fold_uint_expression(right);
|
||||
use num::traits::{CheckedAdd, Pow};
|
||||
|
||||
// let left_metadata = left.metadata.clone().unwrap();
|
||||
// let right_metadata = right.metadata.clone().unwrap();
|
||||
// let `target` the target bitwidth of `left` and `right`
|
||||
// `0 <= left <= max_left`
|
||||
// `0 <= right <= max_right`
|
||||
// `- max_right <= left - right <= max_right`
|
||||
// let `n_bits_left` the number of bits needed to represent `max_left`
|
||||
// let `n = max(n_bits_left, target)`
|
||||
// let offset = 2**n`
|
||||
|
||||
// // determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
// let left_bitwidth = left_metadata
|
||||
// .should_reduce
|
||||
// .map(|should_reduce| {
|
||||
// if should_reduce {
|
||||
// range
|
||||
// } else {
|
||||
// left_metadata.bitwidth.unwrap()
|
||||
// }
|
||||
// })
|
||||
// .unwrap();
|
||||
// let right_bitwidth = right_metadata
|
||||
// .should_reduce
|
||||
// .map(|should_reduce| {
|
||||
// if should_reduce {
|
||||
// range
|
||||
// } else {
|
||||
// right_metadata.bitwidth.unwrap()
|
||||
// }
|
||||
// })
|
||||
// .unwrap();
|
||||
// `2**n - max_left <= a - b + 2 ** n <= bound where bound = max_left + offset`
|
||||
|
||||
// // a(p), b(q) both of target n (p and q their real bitwidth)
|
||||
// // a(p) - b(q) can always underflow
|
||||
// // instead consider s = a(p) - b(q) + 2**q which is always positive
|
||||
// // the min of s is 0 and the max is 2**p + 2**q, which is smaller than 2**(max(p, q) + 1)
|
||||
// If ´bound < N´, we set we return `bound` as the max of ´left - right`
|
||||
// Else we start again, reducing `left`. In this case `max_left` becomes `2**target - 1`
|
||||
// Else we start again, reducing `right`. In this case `offset` becomes `2**target`
|
||||
// Else we start again reducing both. In this case `bound` becomes `2**(target+1) - 1` which is always
|
||||
// smaller or equal to N for target in {8, 16, 32}
|
||||
|
||||
// // so we can use s(max(p, q) + 1) as a representation of a - b if max(p, q) + 1 < max_bitwidth
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
// let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
let left_max = left.metadata.clone().unwrap().max;
|
||||
let right_bitwidth = right.metadata.clone().unwrap().bitwidth();
|
||||
|
||||
// if output_width > max_bitwidth {
|
||||
// // the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
let offset =
|
||||
T::from(2u32).pow(std::cmp::max(right_bitwidth, range as u32) as usize);
|
||||
let target_offset = T::from(2u32).pow(range);
|
||||
|
||||
// let left = UExpression {
|
||||
// metadata: Some(UMetadata {
|
||||
// should_reduce: Some(true),
|
||||
// ..left_metadata
|
||||
// }),
|
||||
// ..left
|
||||
// };
|
||||
let (should_reduce_left, should_reduce_right, max) = left_max
|
||||
.checked_add(&offset)
|
||||
.map(|max| (false, false, max))
|
||||
.unwrap_or_else(|| {
|
||||
range_max
|
||||
.clone()
|
||||
.checked_add(&offset)
|
||||
.map(|max| (true, false, max))
|
||||
.unwrap_or_else(|| {
|
||||
left_max
|
||||
.checked_add(&target_offset.clone())
|
||||
.map(|max| (false, true, max))
|
||||
.unwrap_or_else(|| {
|
||||
(true, true, range_max.clone() + target_offset)
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
// let right = UExpression {
|
||||
// metadata: Some(UMetadata {
|
||||
// should_reduce: Some(true),
|
||||
// ..right_metadata
|
||||
// }),
|
||||
// ..right
|
||||
// };
|
||||
let left = if should_reduce_left {
|
||||
force_reduce(left)
|
||||
} else {
|
||||
left
|
||||
};
|
||||
let right = if should_reduce_right {
|
||||
force_reduce(right)
|
||||
} else {
|
||||
right
|
||||
};
|
||||
|
||||
// UExpression::sub(left, right).metadata(UMetadata {
|
||||
// max: 2_u32 * range_max,
|
||||
// bitwidth: Some(range + 1),
|
||||
// should_reduce: Some(
|
||||
// metadata
|
||||
// .map(|m| m.should_reduce.unwrap_or(false))
|
||||
// .unwrap_or(false),
|
||||
// ),
|
||||
// })
|
||||
// } else {
|
||||
// UExpression::sub(left, right).metadata(UMetadata {
|
||||
// max: None,
|
||||
// bitwidth: Some(output_width),
|
||||
// should_reduce: Some(
|
||||
// metadata
|
||||
// .map(|m| m.should_reduce.unwrap_or(false))
|
||||
// .unwrap_or(false),
|
||||
// ),
|
||||
// })
|
||||
// }
|
||||
UExpression::sub(left, right).metadata(UMetadata {
|
||||
max,
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let left = left.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
});
|
||||
|
||||
let right = right.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
});
|
||||
|
||||
UExpression::xor(left, right).metadata(UMetadata {
|
||||
UExpression::xor(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -236,7 +200,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -247,7 +210,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -300,7 +262,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.annotate(range)
|
||||
.metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -321,7 +282,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
UExpression::right_shift(force_reduce(e), by).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -332,15 +292,15 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
let consequence_max = consequence.metadata.clone().unwrap().max;
|
||||
let alternative_max = alternative.metadata.clone().unwrap().max;
|
||||
|
||||
let max = std::cmp::max(consequence_max, alternative_max);
|
||||
unimplemented!();
|
||||
|
||||
UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
|
||||
max,
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
// let max = std::cmp::max(consequence_max, alternative_max);
|
||||
|
||||
// UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
|
||||
// max,
|
||||
// should_reduce: Some(false),
|
||||
// })
|
||||
}
|
||||
e => unimplemented!("{:?}", e),
|
||||
};
|
||||
|
||||
assert!(res.metadata.is_some());
|
||||
|
@ -392,7 +352,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
self.register(
|
||||
lhs[0].clone(),
|
||||
UMetadata {
|
||||
max: BigUint::from(2_u64.pow(32_u32) - 1),
|
||||
max: T::from(2).pow(32) - T::from(1),
|
||||
should_reduce: Some(false),
|
||||
},
|
||||
);
|
||||
|
@ -446,7 +406,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
self.register(
|
||||
p.id.clone(),
|
||||
UMetadata {
|
||||
max: BigUint::from(2_u64.pow(bitwidth as u32) - 1),
|
||||
max: T::from(2_u32).pow(bitwidth) - T::from(1),
|
||||
should_reduce: Some(false),
|
||||
},
|
||||
);
|
||||
|
@ -465,21 +425,99 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_field::field::FieldPrime;
|
||||
use zokrates_field::field::{FieldPrime, Pow};
|
||||
|
||||
#[should_panic]
|
||||
#[test]
|
||||
fn existing_metadata() {
|
||||
let e = UExpressionInner::Identifier("foo".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata {
|
||||
max: BigUint::from(2_u64.pow(33_u32) - 1),
|
||||
should_reduce: Some(false),
|
||||
});
|
||||
.metadata(UMetadata::with_max(2_u32.pow(33_u32) - 1));
|
||||
|
||||
let mut optimizer: UintOptimizer<FieldPrime> = UintOptimizer::new();
|
||||
|
||||
let optimized = optimizer.fold_uint_expression(e.clone());
|
||||
let _ = optimizer.fold_uint_expression(e.clone());
|
||||
}
|
||||
|
||||
assert_eq!(e, optimized);
|
||||
#[test]
|
||||
fn add() {
|
||||
// max(left + right) = max(left) + max(right)
|
||||
|
||||
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("foo".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(42u32));
|
||||
|
||||
let right = UExpressionInner::Identifier("foo".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(33u32));
|
||||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::add(left, right))
|
||||
.metadata
|
||||
.unwrap()
|
||||
.max,
|
||||
75u32.into()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub() {
|
||||
// `left` and `right` are smaller than the target
|
||||
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("a".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(42u32));
|
||||
|
||||
let right = UExpressionInner::Identifier("b".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(33u32));
|
||||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::sub(left, right))
|
||||
.metadata
|
||||
.unwrap()
|
||||
.max,
|
||||
FieldPrime::from(2u32).pow(32) + FieldPrime::from(42)
|
||||
);
|
||||
|
||||
// `left` and `right` are larger than the target but no readjustment is required
|
||||
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("a".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(u64::MAX as u128));
|
||||
|
||||
let right = UExpressionInner::Identifier("b".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(u64::MAX as u128));
|
||||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::sub(left, right))
|
||||
.metadata
|
||||
.unwrap()
|
||||
.max,
|
||||
FieldPrime::from(2).pow(64) + FieldPrime::from(u64::MAX as u128)
|
||||
);
|
||||
|
||||
// `left` and `right` are larger than the target and needs to be readjusted
|
||||
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("a".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(
|
||||
FieldPrime::from(2u32).pow(FieldPrime::get_required_bits() - 1)
|
||||
- FieldPrime::from(1),
|
||||
));
|
||||
|
||||
let right = UExpressionInner::Identifier("b".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata::with_max(42u32));
|
||||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::sub(left, right))
|
||||
.metadata
|
||||
.unwrap()
|
||||
.max,
|
||||
FieldPrime::from(2u32).pow(32) * FieldPrime::from(2) - FieldPrime::from(1)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ pub use self::parameter::Parameter;
|
|||
pub use self::types::Type;
|
||||
pub use self::variable::Variable;
|
||||
use std::path::PathBuf;
|
||||
pub use zir::uint::{bitwidth, UExpression, UExpressionInner, UMetadata};
|
||||
pub use zir::uint::{UExpression, UExpressionInner, UMetadata};
|
||||
|
||||
use embed::FlatEmbed;
|
||||
use std::collections::HashMap;
|
||||
|
|
|
@ -68,15 +68,28 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct UMetadata {
|
||||
pub max: BigUint,
|
||||
pub struct UMetadata<T> {
|
||||
pub max: T,
|
||||
pub should_reduce: Option<bool>,
|
||||
}
|
||||
|
||||
impl<T: Field> UMetadata<T> {
|
||||
pub fn with_max<U: Into<T>>(max: U) -> Self {
|
||||
UMetadata {
|
||||
max: max.into(),
|
||||
should_reduce: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bitwidth(&self) -> u32 {
|
||||
self.max.bits() as u32
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct UExpression<'ast, T: Field> {
|
||||
pub bitwidth: Bitwidth,
|
||||
pub metadata: Option<UMetadata>,
|
||||
pub metadata: Option<UMetadata<T>>,
|
||||
pub inner: UExpressionInner<'ast, T>,
|
||||
}
|
||||
|
||||
|
@ -117,7 +130,7 @@ impl<'ast, T: Field> UExpressionInner<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
pub fn metadata(self, metadata: UMetadata) -> UExpression<'ast, T> {
|
||||
pub fn metadata(self, metadata: UMetadata<T>) -> UExpression<'ast, T> {
|
||||
UExpression {
|
||||
metadata: Some(metadata),
|
||||
..self
|
||||
|
@ -125,10 +138,6 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn bitwidth(a: u128) -> Bitwidth {
|
||||
(128 - a.leading_zeros()) as Bitwidth
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> UExpression<'ast, T> {
|
||||
pub fn bitwidth(&self) -> Bitwidth {
|
||||
self.bitwidth
|
||||
|
|
|
@ -96,6 +96,8 @@ pub trait Field:
|
|||
fn to_compact_dec_string(&self) -> String;
|
||||
/// Converts to BigUint
|
||||
fn into_big_uint(self) -> BigUint;
|
||||
/// Gets the number of bits
|
||||
fn bits(&self) -> u32;
|
||||
}
|
||||
|
||||
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Serialize, Deserialize)]
|
||||
|
@ -105,8 +107,13 @@ pub struct FieldPrime {
|
|||
|
||||
impl num_traits::CheckedAdd for FieldPrime {
|
||||
fn checked_add(&self, other: &Self) -> Option<Self> {
|
||||
use num_traits::Pow;
|
||||
let res = self.value.clone() + other.value.clone();
|
||||
if res >= *P {
|
||||
|
||||
let bound = BigInt::from(2u32).pow(Self::get_required_bits() - 1);
|
||||
|
||||
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
|
||||
if res >= bound {
|
||||
None
|
||||
} else {
|
||||
Some(FieldPrime { value: res })
|
||||
|
@ -116,8 +123,10 @@ impl num_traits::CheckedAdd for FieldPrime {
|
|||
|
||||
impl num_traits::CheckedMul for FieldPrime {
|
||||
fn checked_mul(&self, other: &Self) -> Option<Self> {
|
||||
use num_traits::Pow;
|
||||
let res = self.value.clone() * other.value.clone();
|
||||
if res >= *P {
|
||||
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
|
||||
if res >= BigInt::from(2u32).pow(Self::get_required_bits() - 1) {
|
||||
None
|
||||
} else {
|
||||
Some(FieldPrime { value: res })
|
||||
|
@ -150,6 +159,10 @@ impl Field for FieldPrime {
|
|||
self.value.to_str_radix(10)
|
||||
}
|
||||
|
||||
fn bits(&self) -> u32 {
|
||||
self.value.bits() as u32
|
||||
}
|
||||
|
||||
fn inverse_mul(&self) -> FieldPrime {
|
||||
let (b, s, _) = extended_euclid(&self.value, &*P);
|
||||
assert_eq!(b, BigInt::one());
|
||||
|
|
Loading…
Reference in a new issue