1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

implement sub, clean

This commit is contained in:
schaeff 2020-05-04 19:39:44 +02:00
parent 2be4859b28
commit 5a145921d4
9 changed files with 224 additions and 173 deletions

View file

@ -1,2 +1,2 @@
def main(u8 a) -> (u8):
return a + a
def main(u8 a, u8 b, u8 c) -> (u8):
return a - b - c

View file

@ -152,8 +152,6 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
println!("{:#?}", typed_ast);
// flatten input program
let program_flattened = Flattener::flatten(typed_ast);

View file

@ -9,16 +9,11 @@ use crate::flat_absy::*;
use crate::solvers::Solver;
use crate::zir::types::{FunctionIdentifier, FunctionKey, Signature, Type};
use crate::zir::*;
use num_bigint::BigUint;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::TryFrom;
use zokrates_field::field::Field;
fn log2(a: BigUint) -> u32 {
a.bits() as u32
}
/// Flattener, computes flattened program.
#[derive(Debug)]
pub struct Flattener<'ast, T: Field> {
@ -675,7 +670,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
ZirExpression::Uint(e) => e,
_ => unreachable!(),
};
let from = log2(p.metadata.clone().unwrap().max);
let from = p.metadata.clone().unwrap().bitwidth();
let p = self.flatten_uint_expression(symbols, statements_flattened, p);
let bits = self
.get_bits(p, from as usize, 32, statements_flattened)
@ -827,7 +822,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let metadata = expr.metadata.clone().unwrap().clone();
let actual_bitwidth = log2(metadata.max) as usize;
let actual_bitwidth = metadata.bitwidth() as usize;
let should_reduce = metadata.should_reduce.unwrap();
let should_reduce = should_reduce && actual_bitwidth > target_bitwidth;
@ -845,7 +840,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatUExpression::with_field(field).bits(bits)
}
UExpressionInner::Not(box e) => {
let from = log2(e.metadata.clone().unwrap().max);
let from = e.metadata.clone().unwrap().bitwidth();
let e_flattened = self.flatten_uint_expression(symbols, statements_flattened, e);
@ -904,7 +899,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
UExpressionInner::Sub(box left, box right) => {
let aux = FlatExpression::Number(
T::from(2).pow(log2(right.metadata.clone().unwrap().max) as usize),
T::from(2).pow(right.metadata.clone().unwrap().bitwidth() as usize),
);
let left_flattened = self
@ -934,7 +929,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
))
}
UExpressionInner::LeftShift(box e, box by) => {
let from = log2(e.metadata.clone().unwrap().max);
let from = e.metadata.clone().unwrap().bitwidth();
let e = self.flatten_uint_expression(symbols, statements_flattened, e);
@ -961,7 +956,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}
UExpressionInner::RightShift(box e, box by) => {
let from = log2(e.metadata.clone().unwrap().max);
let from = e.metadata.clone().unwrap().bitwidth();
let e = self.flatten_uint_expression(symbols, statements_flattened, e);
@ -1028,8 +1023,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
.clone(),
UExpressionInner::Xor(box left, box right) => {
let left_from = log2(left.metadata.clone().unwrap().max);
let right_from = log2(right.metadata.clone().unwrap().max);
let left_from = left.metadata.clone().unwrap().bitwidth();
let right_from = right.metadata.clone().unwrap().bitwidth();
let left_flattened =
self.flatten_uint_expression(symbols, statements_flattened, left);
@ -1061,13 +1056,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.flat_map(|(name, (x, y))| match (x, y) {
(FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => {
if *n == T::from(0) {
vec![FlatStatement::Definition(name.clone(), y.clone())]
vec![FlatStatement::Definition(name.clone(), e.clone())]
} else if *n == T::from(1) {
vec![FlatStatement::Definition(
name.clone(),
FlatExpression::Sub(
box FlatExpression::Number(T::from(1)),
box y.clone(),
box e.clone(),
),
)]
} else {
@ -1105,8 +1100,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}
UExpressionInner::And(box left, box right) => {
let left_from = log2(left.metadata.clone().unwrap().max);
let right_from = log2(right.metadata.clone().unwrap().max);
let left_from = left.metadata.clone().unwrap().bitwidth();
let right_from = right.metadata.clone().unwrap().bitwidth();
let left_flattened =
self.flatten_uint_expression(symbols, statements_flattened, left);
@ -1186,10 +1181,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
self.depth -= 1;
statements_flattened.push(FlatStatement::Log(format!(
" {} DONE",
" ".repeat(self.depth)
)));
// statements_flattened.push(FlatStatement::Log(format!(
// " {} DONE",
// " ".repeat(self.depth)
// )));
res
}

View file

@ -18,25 +18,22 @@ impl<T: Field> Prog<T> {
for statement in main.statements.iter() {
match statement {
Statement::Constraint(quad, lin) => {
println!("{}", statement);
match lin.is_assignee(&witness) {
true => {
let val = quad.evaluate(&witness).unwrap();
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
}
false => {
let lhs_value = quad.evaluate(&witness).unwrap();
let rhs_value = lin.evaluate(&witness).unwrap();
if lhs_value != rhs_value {
return Err(Error::UnsatisfiedConstraint {
left: lhs_value.to_dec_string(),
right: rhs_value.to_dec_string(),
});
}
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
true => {
let val = quad.evaluate(&witness).unwrap();
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
}
false => {
let lhs_value = quad.evaluate(&witness).unwrap();
let rhs_value = lin.evaluate(&witness).unwrap();
if lhs_value != rhs_value {
return Err(Error::UnsatisfiedConstraint {
left: lhs_value.to_dec_string(),
right: rhs_value.to_dec_string(),
});
}
}
}
},
Statement::Directive(ref d) => {
let input_values: Vec<T> = d
.inputs

View file

@ -383,8 +383,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
UExpressionInner::Value(v1 & v2)
}
(UExpressionInner::Value(0), e2) => UExpressionInner::Value(0),
(e1, UExpressionInner::Value(0)) => UExpressionInner::Value(0),
(UExpressionInner::Value(0), _) | (_, UExpressionInner::Value(0)) => {
UExpressionInner::Value(0)
}
(e1, e2) => {
UExpressionInner::And(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}

View file

@ -8,7 +8,7 @@ use zokrates_field::field::Field;
#[derive(Default)]
pub struct UintOptimizer<'ast, T: Field> {
ids: HashMap<ZirAssignee<'ast>, UMetadata>,
ids: HashMap<ZirAssignee<'ast>, UMetadata<T>>,
phantom: PhantomData<T>,
}
@ -24,7 +24,7 @@ impl<'ast, T: Field> UintOptimizer<'ast, T> {
UintOptimizer::new().fold_program(p)
}
fn register(&mut self, a: ZirAssignee<'ast>, m: UMetadata) {
fn register(&mut self, a: ZirAssignee<'ast>, m: UMetadata<T>) {
self.ids.insert(a, m);
}
}
@ -51,19 +51,18 @@ fn force_no_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast,
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
if e.metadata.is_some() {
return e;
}
let max_bitwidth = T::get_required_bits() - 1;
let range = e.bitwidth;
let range_max: BigUint = (2_usize.pow(range as u32) - 1).into();
let range_max: T = (2_usize.pow(range as u32) - 1).into();
assert!(range < max_bitwidth / 2);
if e.metadata.is_some() {
unreachable!("{:?} had metadata", e);
}
let metadata = e.metadata;
let inner = e.inner;
use self::UExpressionInner::*;
@ -71,7 +70,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let res = match inner {
Value(v) => Value(v).annotate(range).metadata(UMetadata {
max: v.into(),
should_reduce: Some(false),
}),
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
@ -81,8 +79,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.expect(&format!("identifier should have been defined: {}", id)),
),
Add(box left, box right) => {
use num::CheckedAdd;
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
@ -119,113 +115,81 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
UExpression::add(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
}
Sub(box left, box right) => {
unimplemented!()
// // reduce the two terms
// let left = self.fold_uint_expression(left);
// let right = self.fold_uint_expression(right);
use num::traits::{CheckedAdd, Pow};
// let left_metadata = left.metadata.clone().unwrap();
// let right_metadata = right.metadata.clone().unwrap();
// let `target` the target bitwidth of `left` and `right`
// `0 <= left <= max_left`
// `0 <= right <= max_right`
// `- max_right <= left - right <= max_right`
// let `n_bits_left` the number of bits needed to represent `max_left`
// let `n = max(n_bits_left, target)`
// let offset = 2**n`
// // determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
// let left_bitwidth = left_metadata
// .should_reduce
// .map(|should_reduce| {
// if should_reduce {
// range
// } else {
// left_metadata.bitwidth.unwrap()
// }
// })
// .unwrap();
// let right_bitwidth = right_metadata
// .should_reduce
// .map(|should_reduce| {
// if should_reduce {
// range
// } else {
// right_metadata.bitwidth.unwrap()
// }
// })
// .unwrap();
// `2**n - max_left <= a - b + 2 ** n <= bound where bound = max_left + offset`
// // a(p), b(q) both of target n (p and q their real bitwidth)
// // a(p) - b(q) can always underflow
// // instead consider s = a(p) - b(q) + 2**q which is always positive
// // the min of s is 0 and the max is 2**p + 2**q, which is smaller than 2**(max(p, q) + 1)
// If ´bound < N´, we set we return `bound` as the max of ´left - right`
// Else we start again, reducing `left`. In this case `max_left` becomes `2**target - 1`
// Else we start again, reducing `right`. In this case `offset` becomes `2**target`
// Else we start again reducing both. In this case `bound` becomes `2**(target+1) - 1` which is always
// smaller or equal to N for target in {8, 16, 32}
// // so we can use s(max(p, q) + 1) as a representation of a - b if max(p, q) + 1 < max_bitwidth
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
// let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
let left_max = left.metadata.clone().unwrap().max;
let right_bitwidth = right.metadata.clone().unwrap().bitwidth();
// if output_width > max_bitwidth {
// // the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
let offset =
T::from(2u32).pow(std::cmp::max(right_bitwidth, range as u32) as usize);
let target_offset = T::from(2u32).pow(range);
// let left = UExpression {
// metadata: Some(UMetadata {
// should_reduce: Some(true),
// ..left_metadata
// }),
// ..left
// };
let (should_reduce_left, should_reduce_right, max) = left_max
.checked_add(&offset)
.map(|max| (false, false, max))
.unwrap_or_else(|| {
range_max
.clone()
.checked_add(&offset)
.map(|max| (true, false, max))
.unwrap_or_else(|| {
left_max
.checked_add(&target_offset.clone())
.map(|max| (false, true, max))
.unwrap_or_else(|| {
(true, true, range_max.clone() + target_offset)
})
})
});
// let right = UExpression {
// metadata: Some(UMetadata {
// should_reduce: Some(true),
// ..right_metadata
// }),
// ..right
// };
let left = if should_reduce_left {
force_reduce(left)
} else {
left
};
let right = if should_reduce_right {
force_reduce(right)
} else {
right
};
// UExpression::sub(left, right).metadata(UMetadata {
// max: 2_u32 * range_max,
// bitwidth: Some(range + 1),
// should_reduce: Some(
// metadata
// .map(|m| m.should_reduce.unwrap_or(false))
// .unwrap_or(false),
// ),
// })
// } else {
// UExpression::sub(left, right).metadata(UMetadata {
// max: None,
// bitwidth: Some(output_width),
// should_reduce: Some(
// metadata
// .map(|m| m.should_reduce.unwrap_or(false))
// .unwrap_or(false),
// ),
// })
// }
UExpression::sub(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
}
Xor(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let left = left.metadata(UMetadata {
should_reduce: Some(true),
..left_metadata
});
let right = right.metadata(UMetadata {
should_reduce: Some(true),
..right_metadata
});
UExpression::xor(left, right).metadata(UMetadata {
UExpression::xor(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -236,7 +200,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -247,7 +210,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -300,7 +262,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.annotate(range)
.metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -321,7 +282,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
UExpression::right_shift(force_reduce(e), by).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -332,15 +292,15 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let consequence_max = consequence.metadata.clone().unwrap().max;
let alternative_max = alternative.metadata.clone().unwrap().max;
let max = std::cmp::max(consequence_max, alternative_max);
unimplemented!();
UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
max,
should_reduce: Some(false),
})
// let max = std::cmp::max(consequence_max, alternative_max);
// UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
// max,
// should_reduce: Some(false),
// })
}
e => unimplemented!("{:?}", e),
};
assert!(res.metadata.is_some());
@ -392,7 +352,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
self.register(
lhs[0].clone(),
UMetadata {
max: BigUint::from(2_u64.pow(32_u32) - 1),
max: T::from(2).pow(32) - T::from(1),
should_reduce: Some(false),
},
);
@ -446,7 +406,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
self.register(
p.id.clone(),
UMetadata {
max: BigUint::from(2_u64.pow(bitwidth as u32) - 1),
max: T::from(2_u32).pow(bitwidth) - T::from(1),
should_reduce: Some(false),
},
);
@ -465,21 +425,99 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
use zokrates_field::field::{FieldPrime, Pow};
#[should_panic]
#[test]
fn existing_metadata() {
let e = UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata {
max: BigUint::from(2_u64.pow(33_u32) - 1),
should_reduce: Some(false),
});
.metadata(UMetadata::with_max(2_u32.pow(33_u32) - 1));
let mut optimizer: UintOptimizer<FieldPrime> = UintOptimizer::new();
let optimized = optimizer.fold_uint_expression(e.clone());
let _ = optimizer.fold_uint_expression(e.clone());
}
assert_eq!(e, optimized);
#[test]
fn add() {
// max(left + right) = max(left) + max(right)
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata::with_max(42u32));
let right = UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata::with_max(33u32));
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::add(left, right))
.metadata
.unwrap()
.max,
75u32.into()
);
}
#[test]
fn sub() {
// `left` and `right` are smaller than the target
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("a".into())
.annotate(32)
.metadata(UMetadata::with_max(42u32));
let right = UExpressionInner::Identifier("b".into())
.annotate(32)
.metadata(UMetadata::with_max(33u32));
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::sub(left, right))
.metadata
.unwrap()
.max,
FieldPrime::from(2u32).pow(32) + FieldPrime::from(42)
);
// `left` and `right` are larger than the target but no readjustment is required
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("a".into())
.annotate(32)
.metadata(UMetadata::with_max(u64::MAX as u128));
let right = UExpressionInner::Identifier("b".into())
.annotate(32)
.metadata(UMetadata::with_max(u64::MAX as u128));
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::sub(left, right))
.metadata
.unwrap()
.max,
FieldPrime::from(2).pow(64) + FieldPrime::from(u64::MAX as u128)
);
// `left` and `right` are larger than the target and needs to be readjusted
let left: UExpression<FieldPrime> = UExpressionInner::Identifier("a".into())
.annotate(32)
.metadata(UMetadata::with_max(
FieldPrime::from(2u32).pow(FieldPrime::get_required_bits() - 1)
- FieldPrime::from(1),
));
let right = UExpressionInner::Identifier("b".into())
.annotate(32)
.metadata(UMetadata::with_max(42u32));
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::sub(left, right))
.metadata
.unwrap()
.max,
FieldPrime::from(2u32).pow(32) * FieldPrime::from(2) - FieldPrime::from(1)
);
}
}

View file

@ -10,7 +10,7 @@ pub use self::parameter::Parameter;
pub use self::types::Type;
pub use self::variable::Variable;
use std::path::PathBuf;
pub use zir::uint::{bitwidth, UExpression, UExpressionInner, UMetadata};
pub use zir::uint::{UExpression, UExpressionInner, UMetadata};
use embed::FlatEmbed;
use std::collections::HashMap;

View file

@ -68,15 +68,28 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> {
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UMetadata {
pub max: BigUint,
pub struct UMetadata<T> {
pub max: T,
pub should_reduce: Option<bool>,
}
impl<T: Field> UMetadata<T> {
pub fn with_max<U: Into<T>>(max: U) -> Self {
UMetadata {
max: max.into(),
should_reduce: None,
}
}
pub fn bitwidth(&self) -> u32 {
self.max.bits() as u32
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UExpression<'ast, T: Field> {
pub bitwidth: Bitwidth,
pub metadata: Option<UMetadata>,
pub metadata: Option<UMetadata<T>>,
pub inner: UExpressionInner<'ast, T>,
}
@ -117,7 +130,7 @@ impl<'ast, T: Field> UExpressionInner<'ast, T> {
}
impl<'ast, T: Field> UExpression<'ast, T> {
pub fn metadata(self, metadata: UMetadata) -> UExpression<'ast, T> {
pub fn metadata(self, metadata: UMetadata<T>) -> UExpression<'ast, T> {
UExpression {
metadata: Some(metadata),
..self
@ -125,10 +138,6 @@ impl<'ast, T: Field> UExpression<'ast, T> {
}
}
pub fn bitwidth(a: u128) -> Bitwidth {
(128 - a.leading_zeros()) as Bitwidth
}
impl<'ast, T: Field> UExpression<'ast, T> {
pub fn bitwidth(&self) -> Bitwidth {
self.bitwidth

View file

@ -96,6 +96,8 @@ pub trait Field:
fn to_compact_dec_string(&self) -> String;
/// Converts to BigUint
fn into_big_uint(self) -> BigUint;
/// Gets the number of bits
fn bits(&self) -> u32;
}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Serialize, Deserialize)]
@ -105,8 +107,13 @@ pub struct FieldPrime {
impl num_traits::CheckedAdd for FieldPrime {
fn checked_add(&self, other: &Self) -> Option<Self> {
use num_traits::Pow;
let res = self.value.clone() + other.value.clone();
if res >= *P {
let bound = BigInt::from(2u32).pow(Self::get_required_bits() - 1);
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
if res >= bound {
None
} else {
Some(FieldPrime { value: res })
@ -116,8 +123,10 @@ impl num_traits::CheckedAdd for FieldPrime {
impl num_traits::CheckedMul for FieldPrime {
fn checked_mul(&self, other: &Self) -> Option<Self> {
use num_traits::Pow;
let res = self.value.clone() * other.value.clone();
if res >= *P {
// we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition
if res >= BigInt::from(2u32).pow(Self::get_required_bits() - 1) {
None
} else {
Some(FieldPrime { value: res })
@ -150,6 +159,10 @@ impl Field for FieldPrime {
self.value.to_str_radix(10)
}
fn bits(&self) -> u32 {
self.value.bits() as u32
}
fn inverse_mul(&self) -> FieldPrime {
let (b, s, _) = extended_euclid(&self.value, &*P);
assert_eq!(b, BigInt::one());