refactor uint optimizer, remove bitwidth
This commit is contained in:
parent
d05ee17640
commit
2be4859b28
18 changed files with 795 additions and 959 deletions
8
test.zok
8
test.zok
|
@ -1,6 +1,2 @@
|
|||
def main(u32[1000] a) -> (u32):
|
||||
u32 res = 0x00000000
|
||||
for field i in 0..1000 do
|
||||
res = res + a[i]
|
||||
endfor
|
||||
return res
|
||||
def main(u8 a) -> (u8):
|
||||
return a + a
|
|
@ -152,6 +152,8 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
|
|||
// analyse (unroll and constant propagation)
|
||||
let typed_ast = typed_ast.analyse();
|
||||
|
||||
println!("{:#?}", typed_ast);
|
||||
|
||||
// flatten input program
|
||||
let program_flattened = Flattener::flatten(typed_ast);
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -21,21 +21,22 @@ impl<T: Field> Prog<T> {
|
|||
Statement::Constraint(quad, lin) => {
|
||||
println!("{}", statement);
|
||||
match lin.is_assignee(&witness) {
|
||||
true => {
|
||||
let val = quad.evaluate(&witness).unwrap();
|
||||
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
|
||||
}
|
||||
false => {
|
||||
let lhs_value = quad.evaluate(&witness).unwrap();
|
||||
let rhs_value = lin.evaluate(&witness).unwrap();
|
||||
if lhs_value != rhs_value {
|
||||
return Err(Error::UnsatisfiedConstraint {
|
||||
left: lhs_value.to_dec_string(),
|
||||
right: rhs_value.to_dec_string(),
|
||||
});
|
||||
true => {
|
||||
let val = quad.evaluate(&witness).unwrap();
|
||||
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
|
||||
}
|
||||
}}
|
||||
},
|
||||
false => {
|
||||
let lhs_value = quad.evaluate(&witness).unwrap();
|
||||
let rhs_value = lin.evaluate(&witness).unwrap();
|
||||
if lhs_value != rhs_value {
|
||||
return Err(Error::UnsatisfiedConstraint {
|
||||
left: lhs_value.to_dec_string(),
|
||||
right: rhs_value.to_dec_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Statement::Directive(ref d) => {
|
||||
let input_values: Vec<T> = d
|
||||
.inputs
|
||||
|
|
|
@ -4049,32 +4049,28 @@ mod tests {
|
|||
});
|
||||
|
||||
assert_eq!(
|
||||
checker.check_parameter(
|
||||
absy::Parameter {
|
||||
id:
|
||||
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
|
||||
.mock(),
|
||||
private: true,
|
||||
}
|
||||
.mock(),
|
||||
&PathBuf::from(MODULE_ID).into(),
|
||||
&state.types,
|
||||
),
|
||||
Ok(Parameter {
|
||||
id: Variable::with_id_and_type(
|
||||
<<<<<<< HEAD
|
||||
"a",
|
||||
=======
|
||||
"a".into(),
|
||||
>>>>>>> b0382ea64e8df4bbdf363fc6fc4c3862900629e7
|
||||
Type::Struct(vec![StructMember::new(
|
||||
"foo".to_string(),
|
||||
Type::FieldElement
|
||||
)])
|
||||
),
|
||||
private: true
|
||||
})
|
||||
);
|
||||
checker.check_parameter(
|
||||
absy::Parameter {
|
||||
id:
|
||||
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
|
||||
.mock(),
|
||||
private: true,
|
||||
}
|
||||
.mock(),
|
||||
&PathBuf::from(MODULE_ID).into(),
|
||||
&state.types,
|
||||
),
|
||||
Ok(Parameter {
|
||||
id: Variable::with_id_and_type(
|
||||
"a",
|
||||
Type::Struct(vec![StructMember::new(
|
||||
"foo".to_string(),
|
||||
Type::FieldElement
|
||||
)])
|
||||
),
|
||||
private: true
|
||||
})
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
checker
|
||||
|
@ -4112,27 +4108,23 @@ mod tests {
|
|||
});
|
||||
|
||||
assert_eq!(
|
||||
checker.check_statement::<FieldPrime>(
|
||||
Statement::Declaration(
|
||||
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
|
||||
.mock()
|
||||
)
|
||||
.mock(),
|
||||
&PathBuf::from(MODULE_ID).into(),
|
||||
&state.types,
|
||||
),
|
||||
Ok(TypedStatement::Declaration(Variable::with_id_and_type(
|
||||
<<<<<<< HEAD
|
||||
"a",
|
||||
=======
|
||||
"a".into(),
|
||||
>>>>>>> b0382ea64e8df4bbdf363fc6fc4c3862900629e7
|
||||
Type::Struct(vec![StructMember::new(
|
||||
"foo".to_string(),
|
||||
Type::FieldElement
|
||||
)])
|
||||
)))
|
||||
);
|
||||
checker.check_statement::<FieldPrime>(
|
||||
Statement::Declaration(
|
||||
absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),)
|
||||
.mock()
|
||||
)
|
||||
.mock(),
|
||||
&PathBuf::from(MODULE_ID).into(),
|
||||
&state.types,
|
||||
),
|
||||
Ok(TypedStatement::Declaration(Variable::with_id_and_type(
|
||||
"a",
|
||||
Type::Struct(vec![StructMember::new(
|
||||
"foo".to_string(),
|
||||
Type::FieldElement
|
||||
)])
|
||||
)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
checker
|
||||
|
|
|
@ -65,7 +65,7 @@ impl Solver {
|
|||
.into_iter()
|
||||
.map(|x| T::from_bellman(x))
|
||||
.collect()
|
||||
},
|
||||
}
|
||||
Solver::Xor => {
|
||||
let x = inputs[0].clone();
|
||||
let y = inputs[1].clone();
|
||||
|
|
|
@ -133,25 +133,4 @@ impl<'ast, T: Field> Folder<'ast, T> for InputConstrainer<'ast, T> {
|
|||
..f
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
match e.inner {
|
||||
UExpressionInner::Identifier(ref id) => {
|
||||
if self.uints.contains(id) {
|
||||
use std::convert::TryInto;
|
||||
UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
max: Some(2_usize.pow(e.bitwidth.try_into().unwrap()).into()),
|
||||
bitwidth: Some(e.bitwidth),
|
||||
should_reduce: Some(false),
|
||||
}),
|
||||
..e
|
||||
}
|
||||
} else {
|
||||
e
|
||||
}
|
||||
}
|
||||
_ => fold_uint_expression(self, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -710,7 +710,9 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
|
||||
zir::UExpressionInner::Not(box e)
|
||||
}
|
||||
typed_absy::UExpressionInner::FunctionCall(key, exps) => unreachable!("function calls should have been removed"),
|
||||
typed_absy::UExpressionInner::FunctionCall(key, exps) => {
|
||||
unreachable!("function calls should have been removed")
|
||||
}
|
||||
typed_absy::UExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_field_expression(index);
|
||||
|
|
|
@ -136,10 +136,10 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
|||
TypedModule {
|
||||
functions: vec![
|
||||
(unpack_key, TypedFunctionSymbol::Flat(unpack)),
|
||||
// (sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
|
||||
// (check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
|
||||
// (check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
|
||||
// (check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
|
||||
(sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
|
||||
(check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
|
||||
(check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
|
||||
(check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
|
||||
(u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)),
|
||||
(u32_from_bits_key, TypedFunctionSymbol::Flat(u32_from_bits)),
|
||||
(main_key, main),
|
||||
|
@ -398,13 +398,20 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
|||
self.statement_buffer
|
||||
.push(TypedStatement::MultipleDefinition(
|
||||
vec![Variable::with_id_and_type(id.clone(), tys[0].clone())],
|
||||
TypedExpressionList::FunctionCall(key.clone(), expressions.clone(), tys),
|
||||
TypedExpressionList::FunctionCall(
|
||||
key.clone(),
|
||||
expressions.clone(),
|
||||
tys,
|
||||
),
|
||||
));
|
||||
|
||||
self.call_cache_mut()
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| HashMap::new())
|
||||
.insert(expressions, vec![BooleanExpression::Identifier(id.clone()).into()]);
|
||||
.insert(
|
||||
expressions,
|
||||
vec![BooleanExpression::Identifier(id.clone()).into()],
|
||||
);
|
||||
|
||||
BooleanExpression::Identifier(id)
|
||||
}
|
||||
|
@ -431,7 +438,6 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
|||
_ => unreachable!(),
|
||||
},
|
||||
Err((embed_key, expressions)) => {
|
||||
|
||||
let tys = key.signature.outputs.clone();
|
||||
let id = Identifier {
|
||||
id: CoreIdentifier::Call(key.clone()),
|
||||
|
@ -444,7 +450,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
|||
self.statement_buffer
|
||||
.push(TypedStatement::MultipleDefinition(
|
||||
vec![Variable::with_id_and_type(id.clone(), tys[0].clone())],
|
||||
TypedExpressionList::FunctionCall(embed_key.clone(), expressions.clone(), tys),
|
||||
TypedExpressionList::FunctionCall(
|
||||
embed_key.clone(),
|
||||
expressions.clone(),
|
||||
tys,
|
||||
),
|
||||
));
|
||||
|
||||
let out = ArrayExpressionInner::Identifier(id);
|
||||
|
@ -452,7 +462,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
|||
self.call_cache_mut()
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| HashMap::new())
|
||||
.insert(expressions, vec![out.clone().annotate(ty.clone(), size).into()]);
|
||||
.insert(
|
||||
expressions,
|
||||
vec![out.clone().annotate(ty.clone(), size).into()],
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
@ -528,7 +541,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
|
|||
self.statement_buffer
|
||||
.push(TypedStatement::MultipleDefinition(
|
||||
vec![Variable::with_id_and_type(id.clone(), tys[0].clone())],
|
||||
TypedExpressionList::FunctionCall(embed_key.clone(), expressions.clone(), tys),
|
||||
TypedExpressionList::FunctionCall(
|
||||
embed_key.clone(),
|
||||
expressions.clone(),
|
||||
tys,
|
||||
),
|
||||
));
|
||||
|
||||
let out = UExpressionInner::Identifier(id);
|
||||
|
@ -833,12 +850,12 @@ mod tests {
|
|||
FunctionKey::with_id("main").signature(signature.clone()),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("b".into())),
|
||||
TypedAssignee::Identifier(Variable::field_element("b")),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::FunctionCall(
|
||||
FunctionKey::with_id("foo").signature(signature.clone()),
|
||||
|
@ -876,7 +893,7 @@ mod tests {
|
|||
FunctionKey::with_id("foo").signature(signature.clone()),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
|
@ -911,7 +928,7 @@ mod tests {
|
|||
.unwrap(),
|
||||
&TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![
|
||||
|
@ -926,7 +943,7 @@ mod tests {
|
|||
FieldElementExpression::Identifier("a".into()).into()
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("b".into())),
|
||||
TypedAssignee::Identifier(Variable::field_element("b")),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("a").stack(
|
||||
vec![(
|
||||
|
@ -987,12 +1004,12 @@ mod tests {
|
|||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("b".into())),
|
||||
TypedAssignee::Identifier(Variable::field_element("b")),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::FunctionCall(
|
||||
FunctionKey::with_id("foo").signature(signature.clone()),
|
||||
|
@ -1017,7 +1034,7 @@ mod tests {
|
|||
FunctionKey::with_id("bar").signature(signature.clone()),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
|
@ -1047,7 +1064,7 @@ mod tests {
|
|||
FunctionKey::with_id("foo").signature(signature.clone()),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
|
@ -1082,7 +1099,7 @@ mod tests {
|
|||
.unwrap(),
|
||||
&TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![Parameter {
|
||||
id: Variable::field_element("a".into()),
|
||||
id: Variable::field_element("a"),
|
||||
private: true,
|
||||
}],
|
||||
statements: vec![
|
||||
|
@ -1129,7 +1146,7 @@ mod tests {
|
|||
.into()
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("b".into())),
|
||||
TypedAssignee::Identifier(Variable::field_element("b")),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("a").stack(
|
||||
vec![(
|
||||
|
|
|
@ -10,26 +10,26 @@ mod flatten_complex_types;
|
|||
mod inline;
|
||||
mod propagate_unroll;
|
||||
mod propagation;
|
||||
mod uint_optimizer;
|
||||
mod redefinition;
|
||||
mod return_binder;
|
||||
mod uint_optimizer;
|
||||
mod unconstrained_vars;
|
||||
mod unroll;
|
||||
mod redefinition;
|
||||
|
||||
use self::constrain_inputs::InputConstrainer;
|
||||
use self::flatten_complex_types::Flattener;
|
||||
use self::inline::Inliner;
|
||||
use self::propagate_unroll::PropagatedUnroller;
|
||||
use self::propagation::Propagator;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::redefinition::RedefinitionOptimizer;
|
||||
use self::return_binder::ReturnBinder;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||
use crate::flat_absy::FlatProg;
|
||||
use crate::ir::Prog;
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use zir::ZirProgram;
|
||||
use self::return_binder::ReturnBinder;
|
||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||
use zokrates_field::field::Field;
|
||||
use crate::ir::Prog;
|
||||
|
||||
pub trait Analyse {
|
||||
fn analyse(self) -> Self;
|
||||
|
|
|
@ -10,24 +10,24 @@
|
|||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use crate::typed_absy::folder::*;
|
||||
use crate::typed_absy::*;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::marker::PhantomData;
|
||||
use typed_absy::types::{StructMember, Type};
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
pub struct RedefinitionOptimizer<'ast, T: Field> {
|
||||
identifiers: HashMap<Identifier<'ast>, Identifier<'ast>>,
|
||||
phantom: PhantomData<T>
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> RedefinitionOptimizer<'ast, T> {
|
||||
fn new() -> Self {
|
||||
RedefinitionOptimizer {
|
||||
identifiers: HashMap::new(),
|
||||
phantom: PhantomData
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,19 +71,16 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast, T> {
|
|||
Some(id) => {
|
||||
let target = self.identifiers.get(&id).unwrap_or(&id).clone();
|
||||
|
||||
self.identifiers
|
||||
.insert(var.id, target);
|
||||
self.identifiers.insert(var.id, target);
|
||||
vec![]
|
||||
},
|
||||
None => {
|
||||
vec![TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
expr,
|
||||
)]
|
||||
}
|
||||
None => vec![TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
expr,
|
||||
)],
|
||||
}
|
||||
},
|
||||
s => fold_statement(self, s)
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
|
|||
Type::Struct(struct_type) => StructExpressionInner::Identifier(i.clone())
|
||||
.annotate(struct_type)
|
||||
.into(),
|
||||
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone()).annotate(bitwidth).into()
|
||||
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
|
||||
.annotate(bitwidth)
|
||||
.into(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
use num_bigint::BigUint;
|
||||
use crate::embed::FlatEmbed;
|
||||
use crate::zir::*;
|
||||
use num_bigint::BigUint;
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
use zir::bitwidth;
|
||||
use zir::folder::*;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
|
@ -26,12 +25,27 @@ impl<'ast, T: Field> UintOptimizer<'ast, T> {
|
|||
}
|
||||
|
||||
fn register(&mut self, a: ZirAssignee<'ast>, m: UMetadata) {
|
||||
// match (a, m) {
|
||||
// (a, ZirExpression::Uint(e)) => {
|
||||
self.ids.insert(a, m);
|
||||
// }
|
||||
// _ => {}
|
||||
// }
|
||||
self.ids.insert(a, m);
|
||||
}
|
||||
}
|
||||
|
||||
fn force_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..e.metadata.unwrap()
|
||||
}),
|
||||
..e
|
||||
}
|
||||
}
|
||||
|
||||
fn force_no_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(false),
|
||||
..e.metadata.unwrap()
|
||||
}),
|
||||
..e
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,8 +53,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
let max_bitwidth = T::get_required_bits() - 1;
|
||||
|
||||
let max = T::max_value().into_big_uint();
|
||||
|
||||
let range = e.bitwidth;
|
||||
|
||||
let range_max: BigUint = (2_usize.pow(range as u32) - 1).into();
|
||||
|
@ -48,7 +60,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
assert!(range < max_bitwidth / 2);
|
||||
|
||||
if e.metadata.is_some() {
|
||||
return e;
|
||||
unreachable!("{:?} had metadata", e);
|
||||
}
|
||||
|
||||
let metadata = e.metadata;
|
||||
|
@ -58,13 +70,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
let res = match inner {
|
||||
Value(v) => Value(v).annotate(range).metadata(UMetadata {
|
||||
max: Some(v.into()),
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
max: v.into(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
}),
|
||||
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(
|
||||
self.ids
|
||||
|
@ -73,181 +81,127 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.expect(&format!("identifier should have been defined: {}", id)),
|
||||
),
|
||||
Add(box left, box right) => {
|
||||
use num::CheckedAdd;
|
||||
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
let left_max = left.metadata.clone().unwrap().max;
|
||||
let right_max = right.metadata.clone().unwrap().max;
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_after = left_metadata.after(range);
|
||||
let (should_reduce_left, should_reduce_right, max) = left_max
|
||||
.checked_add(&right_max)
|
||||
.map(|max| (false, false, max))
|
||||
.unwrap_or_else(|| {
|
||||
range_max
|
||||
.clone()
|
||||
.checked_add(&right_max)
|
||||
.map(|max| (true, false, max))
|
||||
.unwrap_or_else(|| {
|
||||
left_max
|
||||
.checked_add(&range_max.clone())
|
||||
.map(|max| (false, true, max))
|
||||
.unwrap_or_else(|| (true, true, range_max.clone() + range_max))
|
||||
})
|
||||
});
|
||||
|
||||
let right_after = right_metadata.after(range);
|
||||
|
||||
let left_max = left_after.clone().max.unwrap();
|
||||
|
||||
let right_max = right_after.clone().max.unwrap();
|
||||
|
||||
//let output_width = std::cmp::max(left_after.bitwidth.unwrap(), right_after.bitwidth.unwrap()) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
|
||||
if left_max.clone() + right_max.clone() > max {
|
||||
// the addition doesnt fit, we try to reduce one term
|
||||
|
||||
if left_max.clone() + range_max.clone() <= max {
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_after
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max: Some(left_max.clone() + range_max.clone()),
|
||||
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), range) + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else if right_max.clone() + range_max.clone() <= max {
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_after
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max: Some(left_max.clone() + right_max.clone()),
|
||||
bitwidth: Some(std::cmp::max(right_after.bitwidth.unwrap(), range) + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_after
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_after
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max: Some(range_max.clone() * 2_u32),
|
||||
bitwidth: Some(range + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
let left = if should_reduce_left {
|
||||
force_reduce(left)
|
||||
} else {
|
||||
// the addition fits, so we just add
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max: Some(left_max.clone() + right_max.clone()),
|
||||
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), right_after.bitwidth.unwrap())),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
left
|
||||
};
|
||||
let right = if should_reduce_right {
|
||||
force_reduce(right)
|
||||
} else {
|
||||
right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
max,
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
Sub(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
unimplemented!()
|
||||
// // reduce the two terms
|
||||
// let left = self.fold_uint_expression(left);
|
||||
// let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
// let left_metadata = left.metadata.clone().unwrap();
|
||||
// let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
// // determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
// let left_bitwidth = left_metadata
|
||||
// .should_reduce
|
||||
// .map(|should_reduce| {
|
||||
// if should_reduce {
|
||||
// range
|
||||
// } else {
|
||||
// left_metadata.bitwidth.unwrap()
|
||||
// }
|
||||
// })
|
||||
// .unwrap();
|
||||
// let right_bitwidth = right_metadata
|
||||
// .should_reduce
|
||||
// .map(|should_reduce| {
|
||||
// if should_reduce {
|
||||
// range
|
||||
// } else {
|
||||
// right_metadata.bitwidth.unwrap()
|
||||
// }
|
||||
// })
|
||||
// .unwrap();
|
||||
|
||||
// a(p), b(q) both of target n (p and q their real bitwidth)
|
||||
// a(p) - b(q) can always underflow
|
||||
// instead consider s = a(p) - b(q) + 2**q which is always positive
|
||||
// the min of s is 0 and the max is 2**p + 2**q, which is smaller than 2**(max(p, q) + 1)
|
||||
// // a(p), b(q) both of target n (p and q their real bitwidth)
|
||||
// // a(p) - b(q) can always underflow
|
||||
// // instead consider s = a(p) - b(q) + 2**q which is always positive
|
||||
// // the min of s is 0 and the max is 2**p + 2**q, which is smaller than 2**(max(p, q) + 1)
|
||||
|
||||
// so we can use s(max(p, q) + 1) as a representation of a - b if max(p, q) + 1 < max_bitwidth
|
||||
// // so we can use s(max(p, q) + 1) as a representation of a - b if max(p, q) + 1 < max_bitwidth
|
||||
|
||||
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
// let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
// if output_width > max_bitwidth {
|
||||
// // the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
// let left = UExpression {
|
||||
// metadata: Some(UMetadata {
|
||||
// should_reduce: Some(true),
|
||||
// ..left_metadata
|
||||
// }),
|
||||
// ..left
|
||||
// };
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
// let right = UExpression {
|
||||
// metadata: Some(UMetadata {
|
||||
// should_reduce: Some(true),
|
||||
// ..right_metadata
|
||||
// }),
|
||||
// ..right
|
||||
// };
|
||||
|
||||
UExpression::sub(left, right).metadata(UMetadata {
|
||||
max: Some(2_u32 * range_max),
|
||||
bitwidth: Some(range + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
UExpression::sub(left, right).metadata(UMetadata {
|
||||
max: None,
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
// UExpression::sub(left, right).metadata(UMetadata {
|
||||
// max: 2_u32 * range_max,
|
||||
// bitwidth: Some(range + 1),
|
||||
// should_reduce: Some(
|
||||
// metadata
|
||||
// .map(|m| m.should_reduce.unwrap_or(false))
|
||||
// .unwrap_or(false),
|
||||
// ),
|
||||
// })
|
||||
// } else {
|
||||
// UExpression::sub(left, right).metadata(UMetadata {
|
||||
// max: None,
|
||||
// bitwidth: Some(output_width),
|
||||
// should_reduce: Some(
|
||||
// metadata
|
||||
// .map(|m| m.should_reduce.unwrap_or(false))
|
||||
// .unwrap_or(false),
|
||||
// ),
|
||||
// })
|
||||
// }
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
// reduce the two terms
|
||||
|
@ -270,8 +224,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
});
|
||||
|
||||
UExpression::xor(left, right).metadata(UMetadata {
|
||||
max: Some(range_max.clone()),
|
||||
bitwidth: Some(range),
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -280,24 +234,9 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// for and we need both terms to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let left = left.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
});
|
||||
|
||||
let right = right.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
});
|
||||
|
||||
UExpression::and(left, right).metadata(UMetadata {
|
||||
max: Some(range_max.clone()),
|
||||
bitwidth: Some(range),
|
||||
UExpression::and(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
|
@ -306,142 +245,63 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let left = left.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
});
|
||||
|
||||
let right = right.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
});
|
||||
|
||||
UExpression::or(left, right).metadata(UMetadata {
|
||||
max: None,
|
||||
bitwidth: Some(range),
|
||||
UExpression::or(force_reduce(left), force_reduce(right)).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
Mult(box left, box right) => {
|
||||
use num::CheckedMul;
|
||||
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
let left_max = left.metadata.clone().unwrap().max;
|
||||
let right_max = right.metadata.clone().unwrap().max;
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_after = left_metadata.after(range);
|
||||
let (should_reduce_left, should_reduce_right, max) = left_max
|
||||
.checked_mul(&right_max)
|
||||
.map(|max| (false, false, max))
|
||||
.unwrap_or_else(|| {
|
||||
range_max
|
||||
.clone()
|
||||
.checked_mul(&right_max)
|
||||
.map(|max| (true, false, max))
|
||||
.unwrap_or_else(|| {
|
||||
left_max
|
||||
.checked_mul(&range_max.clone())
|
||||
.map(|max| (false, true, max))
|
||||
.unwrap_or_else(|| (true, true, range_max.clone() * range_max))
|
||||
})
|
||||
});
|
||||
|
||||
let right_after = right_metadata.after(range);
|
||||
let output_width = left_after.bitwidth.clone().unwrap() + right_after.bitwidth.clone().unwrap(); // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
|
||||
|
||||
let left_max = left_after.clone().max.unwrap();
|
||||
|
||||
let right_max = right_after.clone().max.unwrap();
|
||||
|
||||
if left_max.clone() * right_max.clone() > max {
|
||||
// the addition doesnt fit, we try to reduce one term
|
||||
|
||||
if left_max.clone() * range_max.clone() <= max {
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_after
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
max: Some(left_max.clone() * range_max.clone()),
|
||||
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), range) + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else if right_max.clone() * range_max.clone() <= max {
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_after
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
max: Some(left_max.clone() * right_max.clone()),
|
||||
bitwidth: Some(std::cmp::max(right_after.bitwidth.unwrap(), range) + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_after
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_after
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
max: Some(range_max.clone() * range_max.clone()),
|
||||
bitwidth: Some(range + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
let left = if should_reduce_left {
|
||||
force_reduce(left)
|
||||
} else {
|
||||
// the addition fits, so we just add
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
max: Some(left_max.clone() * right_max.clone()),
|
||||
bitwidth: Some(std::cmp::max(left_after.bitwidth.unwrap(), right_after.bitwidth.unwrap())),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
left
|
||||
};
|
||||
let right = if should_reduce_right {
|
||||
force_reduce(right)
|
||||
} else {
|
||||
right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
max,
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
Not(box e) => {
|
||||
let e = self.fold_uint_expression(e);
|
||||
|
||||
let e_metadata = e.metadata.clone().unwrap();
|
||||
|
||||
let e_bitwidth = range;
|
||||
|
||||
let e = e.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..e_metadata
|
||||
});
|
||||
|
||||
UExpressionInner::Not(box e)
|
||||
UExpressionInner::Not(box force_reduce(e))
|
||||
.annotate(range)
|
||||
.metadata(UMetadata {
|
||||
max: Some(range_max.clone()),
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(true),
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
LeftShift(box e, box by) => {
|
||||
|
@ -449,18 +309,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
let e = self.fold_uint_expression(e);
|
||||
let by = self.fold_field_expression(by);
|
||||
|
||||
let e_metadata = e.metadata.clone().unwrap();
|
||||
|
||||
// for shift we need the expression to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let e = e.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..e_metadata
|
||||
});
|
||||
|
||||
UExpression::left_shift(e, by).metadata(UMetadata {
|
||||
max: Some(range_max.clone()),
|
||||
bitwidth: Some(range),
|
||||
UExpression::left_shift(force_reduce(e), by).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
should_reduce: Some(true),
|
||||
})
|
||||
}
|
||||
|
@ -469,62 +319,28 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
let e = self.fold_uint_expression(e);
|
||||
let by = self.fold_field_expression(by);
|
||||
|
||||
let e_metadata = e.metadata.clone().unwrap();
|
||||
|
||||
// for shift we need the expression to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let e = e.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..e_metadata
|
||||
});
|
||||
|
||||
UExpression::right_shift(e, by).metadata(UMetadata {
|
||||
max: Some(range_max.clone()),
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(true),
|
||||
UExpression::right_shift(force_reduce(e), by).metadata(UMetadata {
|
||||
max: range_max.clone(),
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
IfElse(box condition, box consequence, box alternative) => {
|
||||
let consequence = self.fold_uint_expression(consequence);
|
||||
let alternative = self.fold_uint_expression(alternative);
|
||||
|
||||
let consequence_metadata = consequence.metadata.clone().unwrap();
|
||||
let alternative_metadata = alternative.metadata.clone().unwrap();
|
||||
let consequence_max = consequence.metadata.clone().unwrap().max;
|
||||
let alternative_max = alternative.metadata.clone().unwrap().max;
|
||||
|
||||
let consequence_bitwidth = consequence_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
consequence_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let alternative_bitwidth = alternative_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
alternative_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = std::cmp::max(consequence_bitwidth, alternative_bitwidth);
|
||||
let max = std::cmp::max(consequence_max, alternative_max);
|
||||
|
||||
UExpression::if_else(condition, consequence, alternative).metadata(UMetadata {
|
||||
max: None,
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
max,
|
||||
|
||||
should_reduce: Some(false),
|
||||
})
|
||||
}
|
||||
e => fold_uint_expression_inner(self, range, e).annotate(range),
|
||||
e => unimplemented!("{:?}", e),
|
||||
};
|
||||
|
||||
assert!(res.metadata.is_some());
|
||||
|
@ -536,11 +352,14 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
match s {
|
||||
ZirStatement::Definition(a, e) => {
|
||||
let e = self.fold_expression(e);
|
||||
match e {
|
||||
ZirExpression::Uint(ref i) => {
|
||||
|
||||
let e = match e {
|
||||
ZirExpression::Uint(i) => {
|
||||
let i = force_no_reduce(i);
|
||||
self.register(a.clone(), i.metadata.clone().unwrap());
|
||||
},
|
||||
_ => {}
|
||||
ZirExpression::Uint(i)
|
||||
}
|
||||
e => e,
|
||||
};
|
||||
vec![ZirStatement::Definition(a, e)]
|
||||
}
|
||||
|
@ -570,29 +389,29 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
ZirExpressionList::FunctionCall(key, arguments, ty) => match key.clone().id {
|
||||
"_U32_FROM_BITS" => {
|
||||
assert_eq!(lhs.len(), 1);
|
||||
self.register(lhs[0].clone(), UMetadata {
|
||||
max: Some(BigUint::from(2_u64.pow(32_u32) - 1)),
|
||||
bitwidth: Some(32),
|
||||
should_reduce: Some(true),
|
||||
});
|
||||
self.register(
|
||||
lhs[0].clone(),
|
||||
UMetadata {
|
||||
max: BigUint::from(2_u64.pow(32_u32) - 1),
|
||||
should_reduce: Some(false),
|
||||
},
|
||||
);
|
||||
vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(key, arguments, ty),
|
||||
)]
|
||||
}
|
||||
_ => {
|
||||
vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(
|
||||
key,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression(e))
|
||||
.collect(),
|
||||
ty,
|
||||
),
|
||||
)]
|
||||
}
|
||||
_ => vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(
|
||||
key,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression(e))
|
||||
.collect(),
|
||||
ty,
|
||||
),
|
||||
)],
|
||||
},
|
||||
},
|
||||
// we need to put back in range to assert
|
||||
|
@ -620,6 +439,27 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
|
||||
let id = match p.id.get_type() {
|
||||
Type::Uint(bitwidth) => {
|
||||
self.register(
|
||||
p.id.clone(),
|
||||
UMetadata {
|
||||
max: BigUint::from(2_u64.pow(bitwidth as u32) - 1),
|
||||
should_reduce: Some(false),
|
||||
},
|
||||
);
|
||||
p.id
|
||||
}
|
||||
_ => p.id,
|
||||
};
|
||||
|
||||
Parameter {
|
||||
id: self.fold_variable(id),
|
||||
..p
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -632,8 +472,7 @@ mod tests {
|
|||
let e = UExpressionInner::Identifier("foo".into())
|
||||
.annotate(32)
|
||||
.metadata(UMetadata {
|
||||
max: None,
|
||||
bitwidth: Some(33),
|
||||
max: BigUint::from(2_u64.pow(33_u32) - 1),
|
||||
should_reduce: Some(false),
|
||||
});
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ mod tests {
|
|||
arguments: vec![_0],
|
||||
statements: vec![
|
||||
Statement::Directive(Directive {
|
||||
inputs: vec![LinComb::summand(-42, one) + LinComb::summand(1, _0)],
|
||||
inputs: vec![(LinComb::summand(-42, one) + LinComb::summand(1, _0)).into()],
|
||||
outputs: vec![_1, _2],
|
||||
solver: Solver::ConditionEq,
|
||||
}),
|
||||
|
|
|
@ -231,28 +231,5 @@ pub mod signature {
|
|||
|
||||
assert_eq!(s.to_slug(), String::from("i3fofbf"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_slug() {
|
||||
let s = Signature::new()
|
||||
.inputs(vec![
|
||||
Type::array(Type::FieldElement, 42),
|
||||
Type::array(Type::FieldElement, 21),
|
||||
])
|
||||
.outputs(vec![]);
|
||||
|
||||
assert_eq!(s.to_slug(), String::from("if[42]f[21]o"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn array() {
|
||||
let t = Type::Array(box Type::FieldElement, 42);
|
||||
assert_eq!(t.get_primitive_count(), 42);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,25 +69,10 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> {
|
|||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct UMetadata {
|
||||
pub bitwidth: Option<Bitwidth>,
|
||||
pub max: Option<BigUint>,
|
||||
pub max: BigUint,
|
||||
pub should_reduce: Option<bool>,
|
||||
}
|
||||
|
||||
impl UMetadata {
|
||||
pub fn after(self, bitwidth: usize) -> Self {
|
||||
use std::convert::TryInto;
|
||||
match self.should_reduce.unwrap() {
|
||||
true => UMetadata {
|
||||
should_reduce: Some(false),
|
||||
bitwidth: Some(bitwidth),
|
||||
max: Some(2_usize.pow(bitwidth.try_into().unwrap()).into())
|
||||
},
|
||||
false => self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct UExpression<'ast, T: Field> {
|
||||
pub bitwidth: Bitwidth,
|
||||
|
|
|
@ -56,6 +56,8 @@ pub trait Field:
|
|||
+ Pow<usize, Output = Self>
|
||||
+ Pow<Self, Output = Self>
|
||||
+ for<'a> Pow<&'a Self, Output = Self>
|
||||
+ num_traits::CheckedAdd
|
||||
+ num_traits::CheckedMul
|
||||
{
|
||||
/// An associated type to be able to operate with Bellman ff traits
|
||||
type BellmanEngine: Engine;
|
||||
|
@ -101,6 +103,28 @@ pub struct FieldPrime {
|
|||
value: BigInt,
|
||||
}
|
||||
|
||||
impl num_traits::CheckedAdd for FieldPrime {
|
||||
fn checked_add(&self, other: &Self) -> Option<Self> {
|
||||
let res = self.value.clone() + other.value.clone();
|
||||
if res >= *P {
|
||||
None
|
||||
} else {
|
||||
Some(FieldPrime { value: res })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl num_traits::CheckedMul for FieldPrime {
|
||||
fn checked_mul(&self, other: &Self) -> Option<Self> {
|
||||
let res = self.value.clone() * other.value.clone();
|
||||
if res >= *P {
|
||||
None
|
||||
} else {
|
||||
Some(FieldPrime { value: res })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Field for FieldPrime {
|
||||
type BellmanEngine = Bn256;
|
||||
|
||||
|
|
|
@ -92,7 +92,6 @@ pub fn test_inner(test_path: &str) {
|
|||
|
||||
match t.max_constraint_count {
|
||||
Some(target_count) => {
|
||||
|
||||
let count = bin.constraint_count();
|
||||
// assert!(
|
||||
// count <= target_count,
|
||||
|
@ -102,8 +101,12 @@ pub fn test_inner(test_path: &str) {
|
|||
// bin
|
||||
// );
|
||||
|
||||
println!("{} at {}% of max", test_path, (count as f32)/(target_count as f32) * 100_f32);
|
||||
},
|
||||
println!(
|
||||
"{} at {}% of max",
|
||||
test_path,
|
||||
(count as f32) / (target_count as f32) * 100_f32
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in a new issue