1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

refactor uint optimizer, remove bitwidth

This commit is contained in:
schaeff 2020-04-29 14:20:13 +02:00
parent d05ee17640
commit 2be4859b28
18 changed files with 795 additions and 959 deletions

View file

@ -1,6 +1,2 @@
def main(u32[1000] a) -> (u32):
u32 res = 0x00000000
for field i in 0..1000 do
res = res + a[i]
endfor
return res
def main(u8 a) -> (u8):
return a + a

View file

@ -152,6 +152,8 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
println!("{:#?}", typed_ast);
// flatten input program
let program_flattened = Flattener::flatten(typed_ast);

File diff suppressed because it is too large Load diff

View file

@ -21,21 +21,22 @@ impl<T: Field> Prog<T> {
Statement::Constraint(quad, lin) => {
println!("{}", statement);
match lin.is_assignee(&witness) {
true => {
let val = quad.evaluate(&witness).unwrap();
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
}
false => {
let lhs_value = quad.evaluate(&witness).unwrap();
let rhs_value = lin.evaluate(&witness).unwrap();
if lhs_value != rhs_value {
return Err(Error::UnsatisfiedConstraint {
left: lhs_value.to_dec_string(),
right: rhs_value.to_dec_string(),
});
true => {
let val = quad.evaluate(&witness).unwrap();
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
}
}}
},
false => {
let lhs_value = quad.evaluate(&witness).unwrap();
let rhs_value = lin.evaluate(&witness).unwrap();
if lhs_value != rhs_value {
return Err(Error::UnsatisfiedConstraint {
left: lhs_value.to_dec_string(),
right: rhs_value.to_dec_string(),
});
}
}
}
}
Statement::Directive(ref d) => {
let input_values: Vec<T> = d
.inputs

View file

@ -4049,32 +4049,28 @@ mod tests {
});
assert_eq!(
checker.check_parameter(
absy::Parameter {
id:
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
.mock(),
private: true,
}
.mock(),
&PathBuf::from(MODULE_ID).into(),
&state.types,
),
Ok(Parameter {
id: Variable::with_id_and_type(
<<<<<<< HEAD
"a",
=======
"a".into(),
>>>>>>> b0382ea64e8df4bbdf363fc6fc4c3862900629e7
Type::Struct(vec![StructMember::new(
"foo".to_string(),
Type::FieldElement
)])
),
private: true
})
);
checker.check_parameter(
absy::Parameter {
id:
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
.mock(),
private: true,
}
.mock(),
&PathBuf::from(MODULE_ID).into(),
&state.types,
),
Ok(Parameter {
id: Variable::with_id_and_type(
"a",
Type::Struct(vec![StructMember::new(
"foo".to_string(),
Type::FieldElement
)])
),
private: true
})
);
assert_eq!(
checker
@ -4112,27 +4108,23 @@ mod tests {
});
assert_eq!(
checker.check_statement::<FieldPrime>(
Statement::Declaration(
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
.mock()
)
.mock(),
&PathBuf::from(MODULE_ID).into(),
&state.types,
),
Ok(TypedStatement::Declaration(Variable::with_id_and_type(
<<<<<<< HEAD
"a",
=======
"a".into(),
>>>>>>> b0382ea64e8df4bbdf363fc6fc4c3862900629e7
Type::Struct(vec![StructMember::new(
"foo".to_string(),
Type::FieldElement
)])
)))
);
checker.check_statement::<FieldPrime>(
Statement::Declaration(
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
.mock()
)
.mock(),
&PathBuf::from(MODULE_ID).into(),
&state.types,
),
Ok(TypedStatement::Declaration(Variable::with_id_and_type(
"a",
Type::Struct(vec![StructMember::new(
"foo".to_string(),
Type::FieldElement
)])
)))
);
assert_eq!(
checker

View file

@ -65,7 +65,7 @@ impl Solver {
.into_iter()
.map(|x| T::from_bellman(x))
.collect()
},
}
Solver::Xor => {
let x = inputs[0].clone();
let y = inputs[1].clone();

View file

@ -133,25 +133,4 @@ impl<'ast, T: Field> Folder<'ast, T> for InputConstrainer<'ast, T> {
..f
}
}
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
match e.inner {
UExpressionInner::Identifier(ref id) => {
if self.uints.contains(id) {
use std::convert::TryInto;
UExpression {
metadata: Some(UMetadata {
max: Some(2_usize.pow(e.bitwidth.try_into().unwrap()).into()),
bitwidth: Some(e.bitwidth),
should_reduce: Some(false),
}),
..e
}
} else {
e
}
}
_ => fold_uint_expression(self, e),
}
}
}

View file

@ -710,7 +710,9 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
zir::UExpressionInner::Not(box e)
}
typed_absy::UExpressionInner::FunctionCall(key, exps) => unreachable!("function calls should have been removed"),
typed_absy::UExpressionInner::FunctionCall(key, exps) => {
unreachable!("function calls should have been removed")
}
typed_absy::UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);

View file

@ -136,10 +136,10 @@ impl<'ast, T: Field> Inliner<'ast, T> {
TypedModule {
functions: vec![
(unpack_key, TypedFunctionSymbol::Flat(unpack)),
// (sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
// (check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
// (check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
// (check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
(sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
(check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
(check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
(check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
(u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)),
(u32_from_bits_key, TypedFunctionSymbol::Flat(u32_from_bits)),
(main_key, main),
@ -398,13 +398,20 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
self.statement_buffer
.push(TypedStatement::MultipleDefinition(
vec![Variable::with_id_and_type(id.clone(), tys[0].clone())],
TypedExpressionList::FunctionCall(key.clone(), expressions.clone(), tys),
TypedExpressionList::FunctionCall(
key.clone(),
expressions.clone(),
tys,
),
));
self.call_cache_mut()
.entry(key.clone())
.or_insert_with(|| HashMap::new())
.insert(expressions, vec![BooleanExpression::Identifier(id.clone()).into()]);
.insert(
expressions,
vec![BooleanExpression::Identifier(id.clone()).into()],
);
BooleanExpression::Identifier(id)
}
@ -431,7 +438,6 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
_ => unreachable!(),
},
Err((embed_key, expressions)) => {
let tys = key.signature.outputs.clone();
let id = Identifier {
id: CoreIdentifier::Call(key.clone()),
@ -444,7 +450,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
self.statement_buffer
.push(TypedStatement::MultipleDefinition(
vec![Variable::with_id_and_type(id.clone(), tys[0].clone())],
TypedExpressionList::FunctionCall(embed_key.clone(), expressions.clone(), tys),
TypedExpressionList::FunctionCall(
embed_key.clone(),
expressions.clone(),
tys,
),
));
let out = ArrayExpressionInner::Identifier(id);
@ -452,7 +462,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
self.call_cache_mut()
.entry(key.clone())
.or_insert_with(|| HashMap::new())
.insert(expressions, vec![out.clone().annotate(ty.clone(), size).into()]);
.insert(
expressions,
vec![out.clone().annotate(ty.clone(), size).into()],
);
out
}
@ -528,7 +541,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
self.statement_buffer
.push(TypedStatement::MultipleDefinition(
vec![Variable::with_id_and_type(id.clone(), tys[0].clone())],
TypedExpressionList::FunctionCall(embed_key.clone(), expressions.clone(), tys),
TypedExpressionList::FunctionCall(
embed_key.clone(),
expressions.clone(),
tys,
),
));
let out = UExpressionInner::Identifier(id);
@ -833,12 +850,12 @@ mod tests {
FunctionKey::with_id("main").signature(signature.clone()),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b".into())),
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Add(
box FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo").signature(signature.clone()),
@ -876,7 +893,7 @@ mod tests {
FunctionKey::with_id("foo").signature(signature.clone()),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![TypedStatement::Return(vec![
@ -911,7 +928,7 @@ mod tests {
.unwrap(),
&TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![
@ -926,7 +943,7 @@ mod tests {
FieldElementExpression::Identifier("a".into()).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b".into())),
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("a").stack(
vec![(
@ -987,12 +1004,12 @@ mod tests {
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b".into())),
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Add(
box FieldElementExpression::FunctionCall(
FunctionKey::with_id("foo").signature(signature.clone()),
@ -1017,7 +1034,7 @@ mod tests {
FunctionKey::with_id("bar").signature(signature.clone()),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![TypedStatement::Return(vec![
@ -1047,7 +1064,7 @@ mod tests {
FunctionKey::with_id("foo").signature(signature.clone()),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![TypedStatement::Return(vec![
@ -1082,7 +1099,7 @@ mod tests {
.unwrap(),
&TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![Parameter {
id: Variable::field_element("a".into()),
id: Variable::field_element("a"),
private: true,
}],
statements: vec![
@ -1129,7 +1146,7 @@ mod tests {
.into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b".into())),
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("a").stack(
vec![(

View file

@ -10,26 +10,26 @@ mod flatten_complex_types;
mod inline;
mod propagate_unroll;
mod propagation;
mod uint_optimizer;
mod redefinition;
mod return_binder;
mod uint_optimizer;
mod unconstrained_vars;
mod unroll;
mod redefinition;
use self::constrain_inputs::InputConstrainer;
use self::flatten_complex_types::Flattener;
use self::inline::Inliner;
use self::propagate_unroll::PropagatedUnroller;
use self::propagation::Propagator;
use self::uint_optimizer::UintOptimizer;
use self::redefinition::RedefinitionOptimizer;
use self::return_binder::ReturnBinder;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use crate::flat_absy::FlatProg;
use crate::ir::Prog;
use crate::typed_absy::TypedProgram;
use zir::ZirProgram;
use self::return_binder::ReturnBinder;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use zokrates_field::field::Field;
use crate::ir::Prog;
pub trait Analyse {
fn analyse(self) -> Self;

View file

@ -10,24 +10,24 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
use std::marker::PhantomData;
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::marker::PhantomData;
use typed_absy::types::{StructMember, Type};
use zokrates_field::field::Field;
pub struct RedefinitionOptimizer<'ast, T: Field> {
identifiers: HashMap<Identifier<'ast>, Identifier<'ast>>,
phantom: PhantomData<T>
phantom: PhantomData<T>,
}
impl<'ast, T: Field> RedefinitionOptimizer<'ast, T> {
fn new() -> Self {
RedefinitionOptimizer {
identifiers: HashMap::new(),
phantom: PhantomData
phantom: PhantomData,
}
}
@ -71,19 +71,16 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast, T> {
Some(id) => {
let target = self.identifiers.get(&id).unwrap_or(&id).clone();
self.identifiers
.insert(var.id, target);
self.identifiers.insert(var.id, target);
vec![]
},
None => {
vec![TypedStatement::Definition(
TypedAssignee::Identifier(var),
expr,
)]
}
None => vec![TypedStatement::Definition(
TypedAssignee::Identifier(var),
expr,
)],
}
},
s => fold_statement(self, s)
}
s => fold_statement(self, s),
}
}

View file

@ -31,7 +31,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
Type::Struct(struct_type) => StructExpressionInner::Identifier(i.clone())
.annotate(struct_type)
.into(),
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone()).annotate(bitwidth).into()
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
.annotate(bitwidth)
.into(),
})
.collect();

View file

@ -1,9 +1,8 @@
use num_bigint::BigUint;
use crate::embed::FlatEmbed;
use crate::zir::*;
use num_bigint::BigUint;
use std::collections::HashMap;
use std::marker::PhantomData;
use zir::bitwidth;
use zir::folder::*;
use zokrates_field::field::Field;
@ -26,12 +25,27 @@ impl<'ast, T: Field> UintOptimizer<'ast, T> {
}
fn register(&mut self, a: ZirAssignee<'ast>, m: UMetadata) {
// match (a, m) {
// (a, ZirExpression::Uint(e)) => {
self.ids.insert(a, m);
// }
// _ => {}
// }
self.ids.insert(a, m);
}
}
fn force_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> {
UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..e.metadata.unwrap()
}),
..e
}
}
fn force_no_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> {
UExpression {
metadata: Some(UMetadata {
should_reduce: Some(false),
..e.metadata.unwrap()
}),
..e
}
}
@ -39,8 +53,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
let max_bitwidth = T::get_required_bits() - 1;
let max = T::max_value().into_big_uint();
let range = e.bitwidth;
let range_max: BigUint = (2_usize.pow(range as u32) - 1).into();
@ -48,7 +60,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
assert!(range < max_bitwidth / 2);
if e.metadata.is_some() {
return e;
unreachable!("{:?} had metadata", e);
}
let metadata = e.metadata;
@ -58,13 +70,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let res = match inner {
Value(v) => Value(v).annotate(range).metadata(UMetadata {
max: Some(v.into()),
bitwidth: Some(range),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
max: v.into(),
should_reduce: Some(false),
}),
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
self.ids
@ -73,181 +81,127 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.expect(&format!("identifier should have been defined: {}", id)),
),
Add(box left, box right) => {
use num::CheckedAdd;
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
let left_max = left.metadata.clone().unwrap().max;
let right_max = right.metadata.clone().unwrap().max;
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_after = left_metadata.after(range);
let (should_reduce_left, should_reduce_right, max) = left_max
.checked_add(&right_max)
.map(|max| (false, false, max))
.unwrap_or_else(|| {
range_max
.clone()
.checked_add(&right_max)
.map(|max| (true, false, max))
.unwrap_or_else(|| {
left_max
.checked_add(&range_max.clone())
.map(|max| (false, true, max))
.unwrap_or_else(|| (true, true, range_max.clone() + range_max))
})
});
let right_after = right_metadata.after(range);
let left_max = left_after.clone().max.unwrap();
let right_max = right_after.clone().max.unwrap();
//let output_width = std::cmp::max(left_after.bitwidth.unwrap(), right_after.bitwidth.unwrap()) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
if left_max.clone() + right_max.clone() > max {
// the addition doesnt fit, we try to reduce one term
if left_max.clone() + range_max.clone() <= max {
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_after
}),
..right
};
UExpression::add(left, right).metadata(UMetadata {
max: Some(left_max.clone() + range_max.clone()),
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), range) + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else if right_max.clone() + range_max.clone() <= max {
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_after
}),
..left
};
UExpression::add(left, right).metadata(UMetadata {
max: Some(left_max.clone() + right_max.clone()),
bitwidth: Some(std::cmp::max(right_after.bitwidth.unwrap(), range) + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_after
}),
..left
};
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_after
}),
..right
};
UExpression::add(left, right).metadata(UMetadata {
max: Some(range_max.clone() * 2_u32),
bitwidth: Some(range + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
let left = if should_reduce_left {
force_reduce(left)
} else {
// the addition fits, so we just add
UExpression::add(left, right).metadata(UMetadata {
max: Some(left_max.clone() + right_max.clone()),
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), right_after.bitwidth.unwrap())),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
left
};
let right = if should_reduce_right {
force_reduce(right)
} else {
right
};
UExpression::add(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
}
Sub(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
unimplemented!()
// // reduce the two terms
// let left = self.fold_uint_expression(left);
// let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// let left_metadata = left.metadata.clone().unwrap();
// let right_metadata = right.metadata.clone().unwrap();
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_bitwidth = left_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
left_metadata.bitwidth.unwrap()
}
})
.unwrap();
let right_bitwidth = right_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
right_metadata.bitwidth.unwrap()
}
})
.unwrap();
// // determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
// let left_bitwidth = left_metadata
// .should_reduce
// .map(|should_reduce| {
// if should_reduce {
// range
// } else {
// left_metadata.bitwidth.unwrap()
// }
// })
// .unwrap();
// let right_bitwidth = right_metadata
// .should_reduce
// .map(|should_reduce| {
// if should_reduce {
// range
// } else {
// right_metadata.bitwidth.unwrap()
// }
// })
// .unwrap();
// a(p), b(q) both of target n (p and q their real bitwidth)
// a(p) - b(q) can always underflow
// instead consider s = a(p) - b(q) + 2**q which is always positive
// the min of s is 0 and the max is 2**p + 2**q, which is smaller than 2**(max(p, q) + 1)
// // a(p), b(q) both of target n (p and q their real bitwidth)
// // a(p) - b(q) can always underflow
// // instead consider s = a(p) - b(q) + 2**q which is always positive
// // the min of s is 0 and the max is 2**p + 2**q, which is smaller than 2**(max(p, q) + 1)
// so we can use s(max(p, q) + 1) as a representation of a - b if max(p, q) + 1 < max_bitwidth
// // so we can use s(max(p, q) + 1) as a representation of a - b if max(p, q) + 1 < max_bitwidth
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
// let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
if output_width > max_bitwidth {
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
// if output_width > max_bitwidth {
// // the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_metadata
}),
..left
};
// let left = UExpression {
// metadata: Some(UMetadata {
// should_reduce: Some(true),
// ..left_metadata
// }),
// ..left
// };
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_metadata
}),
..right
};
// let right = UExpression {
// metadata: Some(UMetadata {
// should_reduce: Some(true),
// ..right_metadata
// }),
// ..right
// };
UExpression::sub(left, right).metadata(UMetadata {
max: Some(2_u32 * range_max),
bitwidth: Some(range + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
UExpression::sub(left, right).metadata(UMetadata {
max: None,
bitwidth: Some(output_width),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
// UExpression::sub(left, right).metadata(UMetadata {
// max: 2_u32 * range_max,
// bitwidth: Some(range + 1),
// should_reduce: Some(
// metadata
// .map(|m| m.should_reduce.unwrap_or(false))
// .unwrap_or(false),
// ),
// })
// } else {
// UExpression::sub(left, right).metadata(UMetadata {
// max: None,
// bitwidth: Some(output_width),
// should_reduce: Some(
// metadata
// .map(|m| m.should_reduce.unwrap_or(false))
// .unwrap_or(false),
// ),
// })
// }
}
Xor(box left, box right) => {
// reduce the two terms
@ -270,8 +224,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
});
UExpression::xor(left, right).metadata(UMetadata {
max: Some(range_max.clone()),
bitwidth: Some(range),
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -280,24 +234,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// for and we need both terms to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let left = left.metadata(UMetadata {
should_reduce: Some(true),
..left_metadata
});
let right = right.metadata(UMetadata {
should_reduce: Some(true),
..right_metadata
});
UExpression::and(left, right).metadata(UMetadata {
max: Some(range_max.clone()),
bitwidth: Some(range),
UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
@ -306,142 +245,63 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let left = left.metadata(UMetadata {
should_reduce: Some(true),
..left_metadata
});
let right = right.metadata(UMetadata {
should_reduce: Some(true),
..right_metadata
});
UExpression::or(left, right).metadata(UMetadata {
max: None,
bitwidth: Some(range),
UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
Mult(box left, box right) => {
use num::CheckedMul;
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
let left_max = left.metadata.clone().unwrap().max;
let right_max = right.metadata.clone().unwrap().max;
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_after = left_metadata.after(range);
let (should_reduce_left, should_reduce_right, max) = left_max
.checked_mul(&right_max)
.map(|max| (false, false, max))
.unwrap_or_else(|| {
range_max
.clone()
.checked_mul(&right_max)
.map(|max| (true, false, max))
.unwrap_or_else(|| {
left_max
.checked_mul(&range_max.clone())
.map(|max| (false, true, max))
.unwrap_or_else(|| (true, true, range_max.clone() * range_max))
})
});
let right_after = right_metadata.after(range);
let output_width = left_after.bitwidth.clone().unwrap() + right_after.bitwidth.clone().unwrap(); // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
let left_max = left_after.clone().max.unwrap();
let right_max = right_after.clone().max.unwrap();
if left_max.clone() * right_max.clone() > max {
// the addition doesnt fit, we try to reduce one term
if left_max.clone() * range_max.clone() <= max {
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_after
}),
..right
};
UExpression::mult(left, right).metadata(UMetadata {
max: Some(left_max.clone() * range_max.clone()),
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), range) + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else if right_max.clone() * range_max.clone() <= max {
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_after
}),
..left
};
UExpression::mult(left, right).metadata(UMetadata {
max: Some(left_max.clone() * right_max.clone()),
bitwidth: Some(std::cmp::max(right_after.bitwidth.unwrap(), range) + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_after
}),
..left
};
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_after
}),
..right
};
UExpression::mult(left, right).metadata(UMetadata {
max: Some(range_max.clone() * range_max.clone()),
bitwidth: Some(range + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
let left = if should_reduce_left {
force_reduce(left)
} else {
// the addition fits, so we just add
UExpression::mult(left, right).metadata(UMetadata {
max: Some(left_max.clone() * right_max.clone()),
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), right_after.bitwidth.unwrap())),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
left
};
let right = if should_reduce_right {
force_reduce(right)
} else {
right
};
UExpression::mult(left, right).metadata(UMetadata {
max,
should_reduce: Some(false),
})
}
Not(box e) => {
let e = self.fold_uint_expression(e);
let e_metadata = e.metadata.clone().unwrap();
let e_bitwidth = range;
let e = e.metadata(UMetadata {
should_reduce: Some(true),
..e_metadata
});
UExpressionInner::Not(box e)
UExpressionInner::Not(box force_reduce(e))
.annotate(range)
.metadata(UMetadata {
max: Some(range_max.clone()),
bitwidth: Some(range),
should_reduce: Some(true),
max: range_max.clone(),
should_reduce: Some(false),
})
}
LeftShift(box e, box by) => {
@ -449,18 +309,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
let e_metadata = e.metadata.clone().unwrap();
// for shift we need the expression to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let e = e.metadata(UMetadata {
should_reduce: Some(true),
..e_metadata
});
UExpression::left_shift(e, by).metadata(UMetadata {
max: Some(range_max.clone()),
bitwidth: Some(range),
UExpression::left_shift(force_reduce(e), by).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(true),
})
}
@ -469,62 +319,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
let e_metadata = e.metadata.clone().unwrap();
// for shift we need the expression to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let e = e.metadata(UMetadata {
should_reduce: Some(true),
..e_metadata
});
UExpression::right_shift(e, by).metadata(UMetadata {
max: Some(range_max.clone()),
bitwidth: Some(range),
should_reduce: Some(true),
UExpression::right_shift(force_reduce(e), by).metadata(UMetadata {
max: range_max.clone(),
should_reduce: Some(false),
})
}
IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_uint_expression(consequence);
let alternative = self.fold_uint_expression(alternative);
let consequence_metadata = consequence.metadata.clone().unwrap();
let alternative_metadata = alternative.metadata.clone().unwrap();
let consequence_max = consequence.metadata.clone().unwrap().max;
let alternative_max = alternative.metadata.clone().unwrap().max;
let consequence_bitwidth = consequence_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
consequence_metadata.bitwidth.unwrap()
}
})
.unwrap();
let alternative_bitwidth = alternative_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
alternative_metadata.bitwidth.unwrap()
}
})
.unwrap();
let output_width = std::cmp::max(consequence_bitwidth, alternative_bitwidth);
let max = std::cmp::max(consequence_max, alternative_max);
UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
max: None,
bitwidth: Some(output_width),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
max,
should_reduce: Some(false),
})
}
e => fold_uint_expression_inner(self, range, e).annotate(range),
e => unimplemented!("{:?}", e),
};
assert!(res.metadata.is_some());
@ -536,11 +352,14 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
match s {
ZirStatement::Definition(a, e) => {
let e = self.fold_expression(e);
match e {
ZirExpression::Uint(ref i) => {
let e = match e {
ZirExpression::Uint(i) => {
let i = force_no_reduce(i);
self.register(a.clone(), i.metadata.clone().unwrap());
},
_ => {}
ZirExpression::Uint(i)
}
e => e,
};
vec![ZirStatement::Definition(a, e)]
}
@ -570,29 +389,29 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
ZirExpressionList::FunctionCall(key, arguments, ty) => match key.clone().id {
"_U32_FROM_BITS" => {
assert_eq!(lhs.len(), 1);
self.register(lhs[0].clone(), UMetadata {
max: Some(BigUint::from(2_u64.pow(32_u32) - 1)),
bitwidth: Some(32),
should_reduce: Some(true),
});
self.register(
lhs[0].clone(),
UMetadata {
max: BigUint::from(2_u64.pow(32_u32) - 1),
should_reduce: Some(false),
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::FunctionCall(key, arguments, ty),
)]
}
_ => {
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::FunctionCall(
key,
arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
ty,
),
)]
}
_ => vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::FunctionCall(
key,
arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
ty,
),
)],
},
},
// we need to put back in range to assert
@ -620,6 +439,27 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
s => fold_statement(self, s),
}
}
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
let id = match p.id.get_type() {
Type::Uint(bitwidth) => {
self.register(
p.id.clone(),
UMetadata {
max: BigUint::from(2_u64.pow(bitwidth as u32) - 1),
should_reduce: Some(false),
},
);
p.id
}
_ => p.id,
};
Parameter {
id: self.fold_variable(id),
..p
}
}
}
#[cfg(test)]
@ -632,8 +472,7 @@ mod tests {
let e = UExpressionInner::Identifier("foo".into())
.annotate(32)
.metadata(UMetadata {
max: None,
bitwidth: Some(33),
max: BigUint::from(2_u64.pow(33_u32) - 1),
should_reduce: Some(false),
});

View file

@ -138,7 +138,7 @@ mod tests {
arguments: vec![_0],
statements: vec![
Statement::Directive(Directive {
inputs: vec![LinComb::summand(-42, one) + LinComb::summand(1, _0)],
inputs: vec![(LinComb::summand(-42, one) + LinComb::summand(1, _0)).into()],
outputs: vec![_1, _2],
solver: Solver::ConditionEq,
}),

View file

@ -231,28 +231,5 @@ pub mod signature {
assert_eq!(s.to_slug(), String::from("i3fofbf"));
}
#[test]
fn array_slug() {
let s = Signature::new()
.inputs(vec![
Type::array(Type::FieldElement, 42),
Type::array(Type::FieldElement, 21),
])
.outputs(vec![]);
assert_eq!(s.to_slug(), String::from("if[42]f[21]o"));
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn array() {
let t = Type::Array(box Type::FieldElement, 42);
assert_eq!(t.get_primitive_count(), 42);
}
}

View file

@ -69,25 +69,10 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UMetadata {
pub bitwidth: Option<Bitwidth>,
pub max: Option<BigUint>,
pub max: BigUint,
pub should_reduce: Option<bool>,
}
impl UMetadata {
pub fn after(self, bitwidth: usize) -> Self {
use std::convert::TryInto;
match self.should_reduce.unwrap() {
true => UMetadata {
should_reduce: Some(false),
bitwidth: Some(bitwidth),
max: Some(2_usize.pow(bitwidth.try_into().unwrap()).into())
},
false => self
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UExpression<'ast, T: Field> {
pub bitwidth: Bitwidth,

View file

@ -56,6 +56,8 @@ pub trait Field:
+ Pow<usize, Output = Self>
+ Pow<Self, Output = Self>
+ for<'a> Pow<&'a Self, Output = Self>
+ num_traits::CheckedAdd
+ num_traits::CheckedMul
{
/// An associated type to be able to operate with Bellman ff traits
type BellmanEngine: Engine;
@ -101,6 +103,28 @@ pub struct FieldPrime {
value: BigInt,
}
impl num_traits::CheckedAdd for FieldPrime {
fn checked_add(&self, other: &Self) -> Option<Self> {
let res = self.value.clone() + other.value.clone();
if res >= *P {
None
} else {
Some(FieldPrime { value: res })
}
}
}
impl num_traits::CheckedMul for FieldPrime {
fn checked_mul(&self, other: &Self) -> Option<Self> {
let res = self.value.clone() * other.value.clone();
if res >= *P {
None
} else {
Some(FieldPrime { value: res })
}
}
}
impl Field for FieldPrime {
type BellmanEngine = Bn256;

View file

@ -92,7 +92,6 @@ pub fn test_inner(test_path: &str) {
match t.max_constraint_count {
Some(target_count) => {
let count = bin.constraint_count();
// assert!(
// count <= target_count,
@ -102,8 +101,12 @@ pub fn test_inner(test_path: &str) {
// bin
// );
println!("{} at {}% of max", test_path, (count as f32)/(target_count as f32) * 100_f32);
},
println!(
"{} at {}% of max",
test_path,
(count as f32) / (target_count as f32) * 100_f32
);
}
_ => {}
};