1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00
This commit is contained in:
schaeff 2019-11-11 10:18:44 +01:00
parent 80e3e18de2
commit 505da3cef9
19 changed files with 661 additions and 291 deletions

6
u8.zok
View file

@ -1,2 +1,4 @@
def main(u32 a, u32 b, u16 c, u16 d, u8 e, u8 f) -> (u32, u16, u8):
return a + b, c + d, e + f
def main(u32 a, u32 b) -> (u32):
u32 c = a ^ b ^ a ^ a ^ b
u32 d = c ^ a
return d ^ a

View file

@ -379,6 +379,10 @@ impl<'ast, T: Field> From<pest::BinaryExpression<'ast>> for absy::ExpressionNode
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),
),
pest::BinaryOperator::Xor => absy::Expression::Xor(
box absy::ExpressionNode::from(*expression.left),
box absy::ExpressionNode::from(*expression.right),
),
o => unimplemented!("Operator {:?} not implemented", o),
}
.span(expression.span)

View file

@ -478,6 +478,7 @@ pub enum Expression<'ast, T: Field> {
),
Member(Box<ExpressionNode<'ast, T>>, Box<Identifier<'ast>>),
Or(Box<ExpressionNode<'ast, T>>, Box<ExpressionNode<'ast, T>>),
Xor(Box<ExpressionNode<'ast, T>>, Box<ExpressionNode<'ast, T>>),
}
pub type ExpressionNode<'ast, T> = Node<Expression<'ast, T>>;
@ -538,6 +539,7 @@ impl<'ast, T: Field> fmt::Display for Expression<'ast, T> {
Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index),
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
Expression::Xor(ref lhs, ref rhs) => write!(f, "{} ^ {}", lhs, rhs),
}
}
}
@ -585,6 +587,7 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> {
}
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
Expression::Xor(ref lhs, ref rhs) => write!(f, "{} ^ {}", lhs, rhs),
}
}
}

View file

@ -15,6 +15,9 @@ use zokrates_field::field::Field;
pub enum FlatEmbed {
Sha256Round,
Unpack,
CheckU8,
CheckU16,
CheckU32,
}
impl FlatEmbed {
@ -32,6 +35,15 @@ impl FlatEmbed {
Type::FieldElement,
T::get_required_bits(),
)]),
FlatEmbed::CheckU8 => Signature::new()
.inputs(vec![Type::Uint(8)])
.outputs(vec![Type::array(Type::FieldElement, 8)]),
FlatEmbed::CheckU16 => Signature::new()
.inputs(vec![Type::Uint(16)])
.outputs(vec![Type::array(Type::FieldElement, 16)]),
FlatEmbed::CheckU32 => Signature::new()
.inputs(vec![Type::Uint(32)])
.outputs(vec![Type::array(Type::FieldElement, 32)]),
}
}
@ -43,6 +55,9 @@ impl FlatEmbed {
match self {
FlatEmbed::Sha256Round => "_SHA256_ROUND",
FlatEmbed::Unpack => "_UNPACK",
FlatEmbed::CheckU8 => "_CHECK_U8",
FlatEmbed::CheckU16 => "_CHECK_U16",
FlatEmbed::CheckU32 => "_CHECK_U32",
}
}
@ -50,7 +65,10 @@ impl FlatEmbed {
pub fn synthetize<T: Field>(&self) -> FlatFunction<T> {
match self {
FlatEmbed::Sha256Round => sha256_round(),
FlatEmbed::Unpack => unpack(),
FlatEmbed::Unpack => unpack_to_host_bitwidth(),
FlatEmbed::CheckU8 => unpack_to_bitwidth(8),
FlatEmbed::CheckU16 => unpack_to_bitwidth(16),
FlatEmbed::CheckU32 => unpack_to_bitwidth(32),
}
}
}
@ -210,9 +228,16 @@ fn use_variable(
/// # Remarks
/// * the return value of the `FlatFunction` is not deterministic: as we decompose over log_2(p) + 1 bits, some
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
pub fn unpack<T: Field>() -> FlatFunction<T> {
pub fn unpack_to_host_bitwidth<T: Field>() -> FlatFunction<T> {
unpack_to_bitwidth(T::get_required_bits())
}
/// A `FlatFunction` which checks a u8
pub fn unpack_to_bitwidth<T: Field>(width: usize) -> FlatFunction<T> {
let nbits = T::get_required_bits();
assert!(width <= nbits);
let mut counter = 0;
let mut layout = HashMap::new();
@ -229,28 +254,27 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
format!("i0"),
&mut counter,
))];
let directive_outputs: Vec<FlatVariable> = (0..T::get_required_bits())
let directive_outputs: Vec<FlatVariable> = (0..width)
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
.collect();
let helper = Helper::bits(T::get_required_bits());
let helper = Helper::bits(width);
let signature = Signature {
inputs: vec![Type::FieldElement],
outputs: vec![Type::array(Type::FieldElement, nbits)],
inputs: vec![Type::Uint(width)],
outputs: vec![Type::array(Type::FieldElement, width)],
};
let outputs = directive_outputs
.iter()
.enumerate()
.filter(|(index, _)| *index >= T::get_required_bits() - nbits)
.map(|(_, o)| FlatExpression::Identifier(o.clone()))
.collect();
.collect::<Vec<_>>();
// o253, o252, ... o{253 - (nbits - 1)} are bits
let mut statements: Vec<FlatStatement<T>> = (0..nbits)
// o253, o252, ... o{253 - (width - 1)} are bits
let mut statements: Vec<FlatStatement<T>> = (0..width)
.map(|index| {
let bit = FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - index));
let bit = FlatExpression::Identifier(FlatVariable::new(width - index));
FlatStatement::Condition(
bit.clone(),
FlatExpression::Mult(box bit.clone(), box bit.clone()),
@ -258,14 +282,14 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
})
.collect();
// sum check: o253 + o252 * 2 + ... + o{253 - (nbits - 1)} * 2**(nbits - 1)
// sum check: o253 + o252 * 2 + ... + o{253 - (width - 1)} * 2**(width - 1)
let mut lhs_sum = FlatExpression::Number(T::from(0));
for i in 0..nbits {
for i in 0..width {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - i)),
box FlatExpression::Identifier(FlatVariable::new(width - i)),
box FlatExpression::Number(T::from(2).pow(i)),
),
);

View file

@ -159,7 +159,7 @@ impl<T: Field> FlatStatement<T> {
}
}
#[derive(Clone, PartialEq, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Serialize, Deserialize, Eq, Hash)]
pub enum FlatExpression<T: Field> {
Number(T),
Identifier(FlatVariable),

View file

@ -9,6 +9,7 @@ use crate::flat_absy::*;
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::typed_absy::types::{FunctionIdentifier, FunctionKey, MemberId, Signature, Type};
use crate::typed_absy::*;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::TryFrom;
use zokrates_field::field::Field;
@ -22,6 +23,8 @@ pub struct Flattener<'ast, T: Field> {
layout: HashMap<Identifier<'ast>, Vec<FlatVariable>>,
/// Cached `FlatFunction`s to avoid re-flattening them
flat_cache: HashMap<FunctionKey<'ast>, FlatFunction<T>>,
/// Cached bit decompositions to avoid re-generating them
bits_cache: HashMap<FlatExpression<T>, Vec<FlatVariable>>,
}
// We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e`
@ -121,6 +124,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
impl<'ast, T: Field> Flattener<'ast, T> {
pub fn flatten(p: TypedProgram<'ast, T>) -> FlatProg<T> {
println!("{}", p);
Flattener::new().flatten_program(p)
}
@ -131,6 +135,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
next_var_idx: 0,
layout: HashMap::new(),
flat_cache: HashMap::new(),
bits_cache: HashMap::new(),
}
}
@ -869,6 +874,23 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box FlatExpression::Identifier(name_x_and_y),
)
}
BooleanExpression::Xor(box lhs, box rhs) => {
let x = box self.flatten_boolean_expression(symbols, statements_flattened, lhs);
let y = box self.flatten_boolean_expression(symbols, statements_flattened, rhs);
assert!(x.is_linear() && y.is_linear());
let name_2_x_and_y = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
name_2_x_and_y,
FlatExpression::Mult(
box FlatExpression::Mult(box FlatExpression::Number(T::from(2)), x.clone()),
y.clone(),
),
));
FlatExpression::Sub(
box FlatExpression::Add(x, y),
box FlatExpression::Identifier(name_2_x_and_y),
)
}
BooleanExpression::And(box lhs, box rhs) => {
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);
let y = self.flatten_boolean_expression(symbols, statements_flattened, rhs);
@ -1041,15 +1063,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
vec![self.flatten_boolean_expression(symbols, statements_flattened, e)]
}
TypedExpression::Uint(e) => {
let e = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
bitwidth: None,
}),
..e
};
let e = e.reduce::<T>();
vec![self.flatten_uint_expression(symbols, statements_flattened, e)]
}
TypedExpression::Array(e) => match e.inner_type().clone() {
@ -1132,50 +1145,127 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Add(box new_left, box new_right)
}
UExpressionInner::Mult(box left, box right) => {
if metadata.should_reduce.unwrap() {
unimplemented!()
let left_flattened =
self.flatten_uint_expression(symbols, statements_flattened, left);
let right_flattened =
self.flatten_uint_expression(symbols, statements_flattened, right);
let new_left = if left_flattened.is_linear() {
left_flattened
} else {
let left_flattened =
self.flatten_uint_expression(symbols, statements_flattened, left);
let right_flattened =
self.flatten_uint_expression(symbols, statements_flattened, right);
let new_left = if left_flattened.is_linear() {
left_flattened
} else {
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id)
};
let new_right = if right_flattened.is_linear() {
right_flattened
} else {
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id)
};
FlatExpression::Mult(box new_left, box new_right)
}
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id)
};
let new_right = if right_flattened.is_linear() {
right_flattened
} else {
let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id)
};
FlatExpression::Mult(box new_left, box new_right)
}
UExpressionInner::Xor(box left, box right) => {
let left_flattened =
self.flatten_uint_expression(symbols, statements_flattened, left);
let right_flattened =
self.flatten_uint_expression(symbols, statements_flattened, right);
let left_bits =
self.get_bits(left_flattened, target_bitwidth, statements_flattened);
let right_bits =
self.get_bits(right_flattened, target_bitwidth, statements_flattened);
assert_eq!(left_bits.len(), target_bitwidth);
assert_eq!(right_bits.len(), target_bitwidth);
let name_xor = left_bits.iter().map(|_| self.use_sym()).collect::<Vec<_>>();
statements_flattened.extend(
name_xor
.iter()
.zip(left_bits.iter().zip(right_bits.iter()))
.flat_map(|(name, (x, y))| {
let name_2_x_and_y = self.use_sym();
vec![
FlatStatement::Definition(
name_2_x_and_y,
FlatExpression::Mult(
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box x.clone(),
),
box y.clone(),
),
),
FlatStatement::Definition(
*name,
FlatExpression::Sub(
box FlatExpression::Add(box x.clone(), box y.clone()),
box FlatExpression::Identifier(name_2_x_and_y),
),
),
]
}),
);
name_xor.into_iter().enumerate().fold(
FlatExpression::Number(T::from(0)),
|acc, (i, e)| {
FlatExpression::Add(
box acc,
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
box e.into(),
),
)
},
)
}
UExpressionInner::Xor(box left, box right) => unimplemented!(),
};
match should_reduce {
true => {
let bits = (0..actual_bitwidth)
.map(|_| self.use_sym())
.collect::<Vec<_>>();
let bits = self.get_bits(res.clone(), actual_bitwidth, statements_flattened);
// truncate to the target bitwidth
(0..target_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
FlatExpression::Add(
box acc,
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
box bits[i + actual_bitwidth - target_bitwidth].clone().into(),
),
)
})
}
false => res,
}
}
fn get_bits(
&mut self,
e: FlatExpression<T>,
bitwidth: usize,
statements_flattened: &mut Vec<FlatStatement<T>>,
) -> Vec<FlatExpression<T>> {
println!("{:?}", self.bits_cache);
match self.bits_cache.entry(e.clone()) {
Entry::Occupied(entry) => entry.get().clone().into_iter().map(|e| e.into()).collect(),
Entry::Vacant(_) => {
let bits = (0..bitwidth).map(|_| self.use_sym()).collect::<Vec<_>>();
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
bits.clone(),
Helper::Rust(RustHelper::Bits(actual_bitwidth)),
vec![res.clone()],
Helper::Rust(RustHelper::Bits(bitwidth)),
vec![e.clone()],
)));
use std::convert::TryInto;
// decompose to the actual bitwidth
// bit checks
statements_flattened.extend((0..actual_bitwidth).map(|i| {
statements_flattened.extend((0..bitwidth).map(|i| {
FlatStatement::Condition(
bits[i].clone().into(),
FlatExpression::Mult(
@ -1185,32 +1275,24 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)
}));
// sum check
statements_flattened.push(FlatStatement::Condition(
res.clone(),
(0..actual_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
FlatExpression::Add(
box acc,
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2).pow(actual_bitwidth - i - 1)),
box bits[i].into(),
),
)
}),
));
// truncate to the target bitwidth
(0..target_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
let sum = (0..bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
FlatExpression::Add(
box acc,
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
box bits[i + actual_bitwidth - target_bitwidth].into(),
box FlatExpression::Number(T::from(2).pow(bitwidth - i - 1)),
box bits[i].into(),
),
)
})
});
// sum check
statements_flattened.push(FlatStatement::Condition(e.clone(), sum.clone()));
self.bits_cache.insert(e, bits.clone());
self.bits_cache.insert(sum, bits.clone());
bits.into_iter().map(|v| v.into()).collect()
}
false => res,
}
}

View file

@ -0,0 +1,71 @@
//! Module containing the `RedefinitionOptimizer` to remove code of the form
// ```
// b := Directive(a)
// c := Directive(a)
// ```
// and replace by
// ```
// b := Directive(a)
// c := b
// ```
use helpers::Helper;
use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::folder::*;
use crate::ir::LinComb;
use crate::ir::*;
use num::Zero;
use std::collections::hash_map::{HashMap, Entry};
use zokrates_field::field::Field;
#[derive(Debug)]
pub struct DirectiveOptimizer<T: Field> {
calls: HashMap<(Helper, Vec<LinComb<T>>), Vec<FlatVariable>>,
/// Map of renamings for reassigned variables while processing the program.
substitution: HashMap<FlatVariable, FlatVariable>,
}
impl<T: Field> DirectiveOptimizer<T> {
fn new() -> DirectiveOptimizer<T> {
DirectiveOptimizer {
calls: HashMap::new(),
substitution: HashMap::new(),
}
}
pub fn optimize(p: Prog<T>) -> Prog<T> {
DirectiveOptimizer::new().fold_module(p)
}
}
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
println!("{:?}", s);
match s {
Statement::Directive(d) => {
let d = self.fold_directive(d);
match self.calls.entry((d.helper.clone(), d.inputs.clone())) {
Entry::Vacant(e) => {
e.insert(d.outputs.clone());
vec![Statement::Directive(d)]
},
Entry::Occupied(e) => {
self.substitution.extend(d.outputs.into_iter().zip(e.get().into_iter().cloned()));
vec![]
}
}
},
s => fold_statement(self, s)
}
}
fn fold_variable(&mut self, v: FlatVariable) -> FlatVariable {
*self.substitution.get(&v).unwrap_or(&v)
}
}
#[cfg(test)]
mod tests {
}

View file

@ -4,10 +4,12 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod directive;
mod duplicate;
mod redefinition;
mod tautology;
use self::directive::DirectiveOptimizer;
use self::duplicate::DuplicateOptimizer;
use self::redefinition::RedefinitionOptimizer;
use self::tautology::TautologyOptimizer;
@ -25,6 +27,8 @@ impl<T: Field> Optimize for Prog<T> {
let r = RedefinitionOptimizer::optimize(self);
// remove constraints that are always satisfied
let r = TautologyOptimizer::optimize(r);
// deduplicate directives which take the same input
let r = DirectiveOptimizer::optimize(r);
// remove duplicate constraints
let r = DuplicateOptimizer::optimize(r);
r

View file

@ -738,7 +738,7 @@ impl<'ast> Checker<'ast> {
types: &TypeMap,
) -> Result<Variable<'ast>, Vec<Error>> {
Ok(Variable::with_id_and_type(
v.value.id.into(),
v.value.id,
self.check_type(v.value._type, module_id, types)
.map_err(|e| vec![e])?,
))
@ -933,7 +933,7 @@ impl<'ast> Checker<'ast> {
match assignee.value {
Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) {
Some(var) => Ok(TypedAssignee::Identifier(Variable::with_id_and_type(
variable_name.into(),
variable_name,
var.id._type.clone(),
))),
None => Err(Error {
@ -1903,6 +1903,39 @@ impl<'ast> Checker<'ast> {
}),
}
}
Expression::Xor(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
match (e1_checked, e2_checked) {
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
Ok(BooleanExpression::Xor(box e1, box e2).into())
}
(TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(UExpression::xor(e1, e2).into())
} else {
Err(Error {
pos: Some(pos),
message: format!(
"Cannot apply `^` to {}, {}",
e1.get_type(),
e2.get_type()
),
})
}
}
(e1, e2) => Err(Error {
pos: Some(pos),
message: format!(
"Cannot apply `^` to {}, {}",
e1.get_type(),
e2.get_type()
),
}),
}
}
Expression::Not(box e) => {
let e_checked = self.check_expression(e, module_id, &types)?;
match e_checked {

View file

@ -33,12 +33,14 @@ use crate::typed_absy::*;
use zokrates_field::field::Field;
pub struct InputConstrainer<'ast, T: Field> {
next_var_id: usize,
constraints: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, T: Field> InputConstrainer<'ast, T> {
fn new() -> Self {
InputConstrainer {
next_var_id: 0,
constraints: vec![],
}
}
@ -47,6 +49,28 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
InputConstrainer::new().fold_program(p)
}
fn constrain_bits(&mut self, u: UExpression<'ast>) {
let bitwidth = u.bitwidth;
let bit_input = Variable::with_id_and_type(
CoreIdentifier::Internal("bit_input_array", self.next_var_id),
Type::array(Type::FieldElement, bitwidth),
);
self.next_var_id += 1;
self.constraints.push(TypedStatement::MultipleDefinition(
vec![bit_input],
TypedExpressionList::FunctionCall(
match bitwidth {
8 => crate::embed::FlatEmbed::CheckU8.key::<T>(),
16 => crate::embed::FlatEmbed::CheckU16.key::<T>(),
32 => crate::embed::FlatEmbed::CheckU32.key::<T>(),
_ => unreachable!()
},
vec![u.into()],
vec![Type::array(Type::FieldElement, bitwidth)],
),
));
}
fn constrain_expression(&mut self, e: TypedExpression<'ast, T>) {
match e {
TypedExpression::FieldElement(_) => {}
@ -54,8 +78,8 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
b.clone().into(),
BooleanExpression::And(box b.clone(), box b).into(),
)),
TypedExpression::Uint(bitwidth) => {
// TODO constrain by checking that it decomposes correctly
TypedExpression::Uint(u) => {
self.constrain_bits(u);
}
TypedExpression::Array(a) => {
for i in 0..a.size() {

View file

@ -72,6 +72,18 @@ impl<'ast, T: Field> Inliner<'ast, T> {
let sha256_round = crate::embed::FlatEmbed::Sha256Round;
let sha256_round_key = sha256_round.key::<T>();
// define a function in the main module for the `check_u8` embed
let check_u8 = crate::embed::FlatEmbed::CheckU8;
let check_u8_key = check_u8.key::<T>();
// define a function in the main module for the `check_u8` embed
let check_u16 = crate::embed::FlatEmbed::CheckU16;
let check_u16_key = check_u16.key::<T>();
// define a function in the main module for the `check_u8` embed
let check_u32 = crate::embed::FlatEmbed::CheckU32;
let check_u32_key = check_u32.key::<T>();
// return a program with a single module containing `main`, `_UNPACK`, and `_SHA256_ROUND
TypedProgram {
main: String::from("main"),
@ -81,6 +93,9 @@ impl<'ast, T: Field> Inliner<'ast, T> {
functions: vec![
(unpack_key, TypedFunctionSymbol::Flat(unpack)),
(sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
(check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
(check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
(check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
(main_key, main),
]
.into_iter()

View file

@ -9,11 +9,13 @@ mod flat_propagation;
mod inline;
mod propagation;
mod unroll;
mod uint_optimizer;
use self::constrain_inputs::InputConstrainer;
use self::inline::Inliner;
use self::propagation::Propagator;
use self::unroll::Unroller;
use self::uint_optimizer::UintOptimizer;
use crate::flat_absy::FlatProg;
use crate::typed_absy::TypedProgram;
use zokrates_field::field::Field;
@ -32,6 +34,8 @@ impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> {
let r = Propagator::propagate(r);
// constrain inputs
let r = InputConstrainer::constrain(r);
// optimize uint expressions
let r = UintOptimizer::optimize(r);
r
}
}

View file

@ -0,0 +1,261 @@
use typed_absy::bitwidth;
use std::collections::HashMap;
use zokrates_field::field::Field;
use typed_absy::folder::*;
use crate::typed_absy::*;
use std::marker::PhantomData;
#[derive(Default)]
pub struct UintOptimizer<'ast, T> {
ids: HashMap<Identifier<'ast>, UMetadata>,
phantom: PhantomData<T>
}
impl<'ast, T: Field> UintOptimizer<'ast, T> {
pub fn new() -> Self {
UintOptimizer {
ids: HashMap::new(),
phantom: PhantomData
}
}
pub fn optimize(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
UintOptimizer::new().fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
fn fold_uint_expression(&mut self, e: UExpression<'ast>) -> UExpression<'ast> {
let max_bitwidth = T::get_required_bits() - 1;
let range = e.bitwidth;
assert!(range < max_bitwidth / 2);
let metadata = e.metadata;
let inner = e.inner;
use self::UExpressionInner::*;
match inner {
Value(v) => Value(v).annotate(range).metadata(UMetadata {
bitwidth: Some(bitwidth(v)),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
}),
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(self.ids.get(&id).cloned().unwrap_or(UMetadata {
bitwidth: Some(range),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})),
Add(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
println!("{} + {}", left, right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_bitwidth = left_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
left_metadata.bitwidth.unwrap()
}
})
.unwrap();
let right_bitwidth = right_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
right_metadata.bitwidth.unwrap()
}
})
.unwrap();
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
println!("{}", output_width);
if output_width > max_bitwidth {
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_metadata
}),
..left
};
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_metadata
}),
..right
};
UExpression::add(left, right).metadata(UMetadata {
bitwidth: Some(range + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
// the addition fits, so we just add
UExpression::add(left, right).metadata(UMetadata {
bitwidth: Some(output_width),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
}
Xor(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let left = left.metadata(UMetadata {
should_reduce: Some(true),
..left_metadata
});
let right = right.metadata(UMetadata {
should_reduce: Some(true),
..right_metadata
});
UExpression::xor(left, right).metadata(UMetadata {
bitwidth: Some(range),
should_reduce: Some(true),
})
}
Mult(box left, box right) => {
// reduce the two terms
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_bitwidth = left_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
left_metadata.bitwidth.unwrap()
}
})
.unwrap();
let right_bitwidth = right_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
right_metadata.bitwidth.unwrap()
}
})
.unwrap();
let output_width = left_bitwidth + right_bitwidth; // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
if output_width > max_bitwidth {
// the multiplication doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_metadata
}),
..left
};
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_metadata
}),
..right
};
UExpression::mult(left, right).metadata(UMetadata {
bitwidth: Some(2 * range),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
// the multiplication fits, so we just multiply
UExpression::mult(left, right).metadata(UMetadata {
bitwidth: Some(output_width),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
}
}
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(TypedAssignee::Identifier(id), TypedExpression::Uint(e)) => {
let e = self.fold_uint_expression(e);
self.ids.insert(id.id.clone(), e.metadata.clone().unwrap());
vec![TypedStatement::Definition(TypedAssignee::Identifier(id), TypedExpression::Uint(e))]
},
// we need to put back in range to return
TypedStatement::Return(expressions) => vec![TypedStatement::Return(expressions
.into_iter()
.map(|e| match e {
TypedExpression::Uint(e) => {
let e = self.fold_uint_expression(e);
let e = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
bitwidth: e.metadata.unwrap().bitwidth,
}),
..e
};
TypedExpression::Uint(e)
}
e => self.fold_expression(e)
})
.collect())],
s => fold_statement(self, s)
}
}
}

View file

@ -26,7 +26,7 @@ impl<'ast> Unroller<'ast> {
let res = match self.substitution.get(&v.id) {
Some(i) => Variable {
id: Identifier {
id: v.id.id,
id: v.id.clone().id,
version: i + 1,
stack: vec![],
},

View file

@ -350,6 +350,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_boolean_expression(e2);
BooleanExpression::Or(box e1, box e2)
}
BooleanExpression::Xor(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
let e2 = f.fold_boolean_expression(e2);
BooleanExpression::Xor(box e1, box e2)
}
BooleanExpression::And(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
let e2 = f.fold_boolean_expression(e2);

View file

@ -2,11 +2,26 @@ use std::fmt;
use typed_absy::types::FunctionKey;
use typed_absy::TypedModuleId;
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub enum CoreIdentifier<'ast> {
Source(&'ast str),
Internal(&'static str, usize),
}
impl<'ast> fmt::Display for CoreIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
CoreIdentifier::Source(s) => write!(f, "{}", s),
CoreIdentifier::Internal(s, i) => write!(f, "#INTERNAL#_{}_{}", s, i),
}
}
}
/// A identifier for a variable
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub struct Identifier<'ast> {
/// the id of the variable
pub id: &'ast str,
pub id: CoreIdentifier<'ast>,
/// the version of the variable, used after SSA transformation
pub version: usize,
/// the call stack of the variable, used when inlining
@ -35,6 +50,12 @@ impl<'ast> fmt::Display for Identifier<'ast> {
impl<'ast> From<&'ast str> for Identifier<'ast> {
fn from(id: &'ast str) -> Identifier<'ast> {
Identifier::from(CoreIdentifier::Source(id))
}
}
impl<'ast> From<CoreIdentifier<'ast>> for Identifier<'ast> {
fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> {
Identifier {
id,
version: 0,

View file

@ -12,10 +12,11 @@ pub mod types;
mod uint;
mod variable;
pub use self::identifier::CoreIdentifier;
pub use self::parameter::Parameter;
pub use self::types::Type;
pub use self::variable::Variable;
pub use typed_absy::uint::{UExpression, UExpressionInner, UMetadata};
pub use typed_absy::uint::{UExpression, UExpressionInner, UMetadata, bitwidth};
use crate::typed_absy::types::{FunctionKey, MemberId, Signature};
use embed::FlatEmbed;
@ -571,6 +572,10 @@ pub enum BooleanExpression<'ast, T: Field> {
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Xor(
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
And(
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
@ -784,7 +789,13 @@ impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
impl<'ast> fmt::Display for UExpression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
unimplemented!()
match self.inner {
UExpressionInner::Value(ref v) => write!(f, "{}", v),
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs),
}
}
}
@ -799,6 +810,7 @@ impl<'ast, T: Field> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
BooleanExpression::Xor(ref lhs, ref rhs) => write!(f, "{} ^ {}", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Value(b) => write!(f, "{}", b),

View file

@ -68,7 +68,7 @@ impl<'ast> UExpressionInner<'ast> {
}
impl<'ast> UExpression<'ast> {
fn metadata(self, metadata: UMetadata) -> UExpression<'ast> {
pub fn metadata(self, metadata: UMetadata) -> UExpression<'ast> {
UExpression {
metadata: Some(metadata),
..self
@ -76,208 +76,10 @@ impl<'ast> UExpression<'ast> {
}
}
fn bitwidth(a: u128) -> Bitwidth {
pub fn bitwidth(a: u128) -> Bitwidth {
(128 - a.leading_zeros()) as Bitwidth
}
impl<'ast> UExpression<'ast> {
pub fn reduce<T: Field>(self) -> Self {
let max_bitwidth = T::get_required_bits() - 1;
let range = self.bitwidth;
assert!(range < max_bitwidth / 2);
let metadata = self.metadata;
let inner = self.inner;
use self::UExpressionInner::*;
match inner {
Value(v) => Value(v).annotate(range).metadata(UMetadata {
bitwidth: Some(bitwidth(v)),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
}),
Identifier(id) => Identifier(id).annotate(range).metadata(UMetadata {
bitwidth: Some(range),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
}),
Add(box left, box right) => {
// reduce the two terms
let left = left.reduce::<T>();
let right = right.reduce::<T>();
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_bitwidth = left_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
left_metadata.bitwidth.unwrap()
}
})
.unwrap();
let right_bitwidth = right_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
right_metadata.bitwidth.unwrap()
}
})
.unwrap();
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
if output_width > max_bitwidth {
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_metadata
}),
..left
};
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_metadata
}),
..right
};
UExpression::add(left, right).metadata(UMetadata {
bitwidth: Some(range + 1),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
// the addition fits, so we just add
UExpression::add(left, right).metadata(UMetadata {
bitwidth: Some(output_width),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
}
Xor(box left, box right) => {
// reduce the two terms
let left = left.reduce::<T>();
let right = right.reduce::<T>();
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
// NB: if they are already in range, the flattening process will ignore the reduction
let left = left.metadata(UMetadata {
should_reduce: Some(true),
..left_metadata
});
let right = right.metadata(UMetadata {
should_reduce: Some(true),
..right_metadata
});
UExpression::xor(left, right)
}
Mult(box left, box right) => {
// reduce the two terms
let left = left.reduce::<T>();
let right = right.reduce::<T>();
let left_metadata = left.metadata.clone().unwrap();
let right_metadata = right.metadata.clone().unwrap();
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
let left_bitwidth = left_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
left_metadata.bitwidth.unwrap()
}
})
.unwrap();
let right_bitwidth = right_metadata
.should_reduce
.map(|should_reduce| {
if should_reduce {
range
} else {
right_metadata.bitwidth.unwrap()
}
})
.unwrap();
let output_width = left_bitwidth + right_bitwidth; // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
if output_width > max_bitwidth {
// the multiplication doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
let left = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..left_metadata
}),
..left
};
let right = UExpression {
metadata: Some(UMetadata {
should_reduce: Some(true),
..right_metadata
}),
..right
};
UExpression::mult(left, right).metadata(UMetadata {
bitwidth: Some(2 * range),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
} else {
// the multiplication fits, so we just multiply
UExpression::mult(left, right).metadata(UMetadata {
bitwidth: Some(output_width),
should_reduce: Some(
metadata
.map(|m| m.should_reduce.unwrap_or(false))
.unwrap_or(false),
),
})
}
}
}
}
}
impl<'ast> UExpression<'ast> {
pub fn bitwidth(&self) -> Bitwidth {
self.bitwidth

View file

@ -9,7 +9,7 @@ pub struct Variable<'ast> {
}
impl<'ast> Variable<'ast> {
pub fn field_element(id: Identifier<'ast>) -> Variable<'ast> {
pub fn field_element<I: Into<Identifier<'ast>>>(id: I) -> Variable<'ast> {
Self::with_id_and_type(id, Type::FieldElement)
}
@ -18,20 +18,23 @@ impl<'ast> Variable<'ast> {
}
#[cfg(test)]
pub fn field_array(id: Identifier<'ast>, size: usize) -> Variable<'ast> {
pub fn field_array<I: Into<Identifier<'ast>>>(id: I, size: usize) -> Variable<'ast> {
Self::array(id, Type::FieldElement, size)
}
pub fn array(id: Identifier<'ast>, ty: Type, size: usize) -> Variable<'ast> {
pub fn array<I: Into<Identifier<'ast>>>(id: I, ty: Type, size: usize) -> Variable<'ast> {
Self::with_id_and_type(id, Type::array(ty, size))
}
pub fn struc(id: Identifier<'ast>, ty: Vec<(MemberId, Type)>) -> Variable<'ast> {
pub fn struc<I: Into<Identifier<'ast>>>(id: I, ty: Vec<(MemberId, Type)>) -> Variable<'ast> {
Self::with_id_and_type(id, Type::Struct(ty))
}
pub fn with_id_and_type(id: Identifier<'ast>, _type: Type) -> Variable<'ast> {
Variable { id, _type }
pub fn with_id_and_type<I: Into<Identifier<'ast>>>(id: I, _type: Type) -> Variable<'ast> {
Variable {
id: id.into(),
_type,
}
}
pub fn get_type(&self) -> Type {