add xor
This commit is contained in:
parent
80e3e18de2
commit
505da3cef9
19 changed files with 661 additions and 291 deletions
6
u8.zok
6
u8.zok
|
@ -1,2 +1,4 @@
|
|||
def main(u32 a, u32 b, u16 c, u16 d, u8 e, u8 f) -> (u32, u16, u8):
|
||||
return a + b, c + d, e + f
|
||||
def main(u32 a, u32 b) -> (u32):
|
||||
u32 c = a ^ b ^ a ^ a ^ b
|
||||
u32 d = c ^ a
|
||||
return d ^ a
|
|
@ -379,6 +379,10 @@ impl<'ast, T: Field> From<pest::BinaryExpression<'ast>> for absy::ExpressionNode
|
|||
box absy::ExpressionNode::from(*expression.left),
|
||||
box absy::ExpressionNode::from(*expression.right),
|
||||
),
|
||||
pest::BinaryOperator::Xor => absy::Expression::Xor(
|
||||
box absy::ExpressionNode::from(*expression.left),
|
||||
box absy::ExpressionNode::from(*expression.right),
|
||||
),
|
||||
o => unimplemented!("Operator {:?} not implemented", o),
|
||||
}
|
||||
.span(expression.span)
|
||||
|
|
|
@ -478,6 +478,7 @@ pub enum Expression<'ast, T: Field> {
|
|||
),
|
||||
Member(Box<ExpressionNode<'ast, T>>, Box<Identifier<'ast>>),
|
||||
Or(Box<ExpressionNode<'ast, T>>, Box<ExpressionNode<'ast, T>>),
|
||||
Xor(Box<ExpressionNode<'ast, T>>, Box<ExpressionNode<'ast, T>>),
|
||||
}
|
||||
|
||||
pub type ExpressionNode<'ast, T> = Node<Expression<'ast, T>>;
|
||||
|
@ -538,6 +539,7 @@ impl<'ast, T: Field> fmt::Display for Expression<'ast, T> {
|
|||
Expression::Select(ref array, ref index) => write!(f, "{}[{}]", array, index),
|
||||
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
|
||||
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
Expression::Xor(ref lhs, ref rhs) => write!(f, "{} ^ {}", lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -585,6 +587,7 @@ impl<'ast, T: Field> fmt::Debug for Expression<'ast, T> {
|
|||
}
|
||||
Expression::Member(ref struc, ref id) => write!(f, "{}.{}", struc, id),
|
||||
Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
Expression::Xor(ref lhs, ref rhs) => write!(f, "{} ^ {}", lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,9 @@ use zokrates_field::field::Field;
|
|||
pub enum FlatEmbed {
|
||||
Sha256Round,
|
||||
Unpack,
|
||||
CheckU8,
|
||||
CheckU16,
|
||||
CheckU32,
|
||||
}
|
||||
|
||||
impl FlatEmbed {
|
||||
|
@ -32,6 +35,15 @@ impl FlatEmbed {
|
|||
Type::FieldElement,
|
||||
T::get_required_bits(),
|
||||
)]),
|
||||
FlatEmbed::CheckU8 => Signature::new()
|
||||
.inputs(vec![Type::Uint(8)])
|
||||
.outputs(vec![Type::array(Type::FieldElement, 8)]),
|
||||
FlatEmbed::CheckU16 => Signature::new()
|
||||
.inputs(vec![Type::Uint(16)])
|
||||
.outputs(vec![Type::array(Type::FieldElement, 16)]),
|
||||
FlatEmbed::CheckU32 => Signature::new()
|
||||
.inputs(vec![Type::Uint(32)])
|
||||
.outputs(vec![Type::array(Type::FieldElement, 32)]),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,6 +55,9 @@ impl FlatEmbed {
|
|||
match self {
|
||||
FlatEmbed::Sha256Round => "_SHA256_ROUND",
|
||||
FlatEmbed::Unpack => "_UNPACK",
|
||||
FlatEmbed::CheckU8 => "_CHECK_U8",
|
||||
FlatEmbed::CheckU16 => "_CHECK_U16",
|
||||
FlatEmbed::CheckU32 => "_CHECK_U32",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,7 +65,10 @@ impl FlatEmbed {
|
|||
pub fn synthetize<T: Field>(&self) -> FlatFunction<T> {
|
||||
match self {
|
||||
FlatEmbed::Sha256Round => sha256_round(),
|
||||
FlatEmbed::Unpack => unpack(),
|
||||
FlatEmbed::Unpack => unpack_to_host_bitwidth(),
|
||||
FlatEmbed::CheckU8 => unpack_to_bitwidth(8),
|
||||
FlatEmbed::CheckU16 => unpack_to_bitwidth(16),
|
||||
FlatEmbed::CheckU32 => unpack_to_bitwidth(32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -210,9 +228,16 @@ fn use_variable(
|
|||
/// # Remarks
|
||||
/// * the return value of the `FlatFunction` is not deterministic: as we decompose over log_2(p) + 1 bits, some
|
||||
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
|
||||
pub fn unpack<T: Field>() -> FlatFunction<T> {
|
||||
pub fn unpack_to_host_bitwidth<T: Field>() -> FlatFunction<T> {
|
||||
unpack_to_bitwidth(T::get_required_bits())
|
||||
}
|
||||
|
||||
/// A `FlatFunction` which checks a u8
|
||||
pub fn unpack_to_bitwidth<T: Field>(width: usize) -> FlatFunction<T> {
|
||||
let nbits = T::get_required_bits();
|
||||
|
||||
assert!(width <= nbits);
|
||||
|
||||
let mut counter = 0;
|
||||
|
||||
let mut layout = HashMap::new();
|
||||
|
@ -229,28 +254,27 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
|
|||
format!("i0"),
|
||||
&mut counter,
|
||||
))];
|
||||
let directive_outputs: Vec<FlatVariable> = (0..T::get_required_bits())
|
||||
let directive_outputs: Vec<FlatVariable> = (0..width)
|
||||
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
|
||||
.collect();
|
||||
|
||||
let helper = Helper::bits(T::get_required_bits());
|
||||
let helper = Helper::bits(width);
|
||||
|
||||
let signature = Signature {
|
||||
inputs: vec![Type::FieldElement],
|
||||
outputs: vec![Type::array(Type::FieldElement, nbits)],
|
||||
inputs: vec![Type::Uint(width)],
|
||||
outputs: vec![Type::array(Type::FieldElement, width)],
|
||||
};
|
||||
|
||||
let outputs = directive_outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(index, _)| *index >= T::get_required_bits() - nbits)
|
||||
.map(|(_, o)| FlatExpression::Identifier(o.clone()))
|
||||
.collect();
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// o253, o252, ... o{253 - (nbits - 1)} are bits
|
||||
let mut statements: Vec<FlatStatement<T>> = (0..nbits)
|
||||
// o253, o252, ... o{253 - (width - 1)} are bits
|
||||
let mut statements: Vec<FlatStatement<T>> = (0..width)
|
||||
.map(|index| {
|
||||
let bit = FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - index));
|
||||
let bit = FlatExpression::Identifier(FlatVariable::new(width - index));
|
||||
FlatStatement::Condition(
|
||||
bit.clone(),
|
||||
FlatExpression::Mult(box bit.clone(), box bit.clone()),
|
||||
|
@ -258,14 +282,14 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
|
|||
})
|
||||
.collect();
|
||||
|
||||
// sum check: o253 + o252 * 2 + ... + o{253 - (nbits - 1)} * 2**(nbits - 1)
|
||||
// sum check: o253 + o252 * 2 + ... + o{253 - (width - 1)} * 2**(width - 1)
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..nbits {
|
||||
for i in 0..width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - i)),
|
||||
box FlatExpression::Identifier(FlatVariable::new(width - i)),
|
||||
box FlatExpression::Number(T::from(2).pow(i)),
|
||||
),
|
||||
);
|
||||
|
|
|
@ -159,7 +159,7 @@ impl<T: Field> FlatStatement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Clone, PartialEq, Serialize, Deserialize, Eq, Hash)]
|
||||
pub enum FlatExpression<T: Field> {
|
||||
Number(T),
|
||||
Identifier(FlatVariable),
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::flat_absy::*;
|
|||
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
|
||||
use crate::typed_absy::types::{FunctionIdentifier, FunctionKey, MemberId, Signature, Type};
|
||||
use crate::typed_absy::*;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use zokrates_field::field::Field;
|
||||
|
@ -22,6 +23,8 @@ pub struct Flattener<'ast, T: Field> {
|
|||
layout: HashMap<Identifier<'ast>, Vec<FlatVariable>>,
|
||||
/// Cached `FlatFunction`s to avoid re-flattening them
|
||||
flat_cache: HashMap<FunctionKey<'ast>, FlatFunction<T>>,
|
||||
/// Cached bit decompositions to avoid re-generating them
|
||||
bits_cache: HashMap<FlatExpression<T>, Vec<FlatVariable>>,
|
||||
}
|
||||
|
||||
// We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e`
|
||||
|
@ -121,6 +124,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for ArrayExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> Flattener<'ast, T> {
|
||||
pub fn flatten(p: TypedProgram<'ast, T>) -> FlatProg<T> {
|
||||
println!("{}", p);
|
||||
Flattener::new().flatten_program(p)
|
||||
}
|
||||
|
||||
|
@ -131,6 +135,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
next_var_idx: 0,
|
||||
layout: HashMap::new(),
|
||||
flat_cache: HashMap::new(),
|
||||
bits_cache: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -869,6 +874,23 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
box FlatExpression::Identifier(name_x_and_y),
|
||||
)
|
||||
}
|
||||
BooleanExpression::Xor(box lhs, box rhs) => {
|
||||
let x = box self.flatten_boolean_expression(symbols, statements_flattened, lhs);
|
||||
let y = box self.flatten_boolean_expression(symbols, statements_flattened, rhs);
|
||||
assert!(x.is_linear() && y.is_linear());
|
||||
let name_2_x_and_y = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
name_2_x_and_y,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Mult(box FlatExpression::Number(T::from(2)), x.clone()),
|
||||
y.clone(),
|
||||
),
|
||||
));
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Add(x, y),
|
||||
box FlatExpression::Identifier(name_2_x_and_y),
|
||||
)
|
||||
}
|
||||
BooleanExpression::And(box lhs, box rhs) => {
|
||||
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);
|
||||
let y = self.flatten_boolean_expression(symbols, statements_flattened, rhs);
|
||||
|
@ -1041,15 +1063,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
vec![self.flatten_boolean_expression(symbols, statements_flattened, e)]
|
||||
}
|
||||
TypedExpression::Uint(e) => {
|
||||
let e = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
bitwidth: None,
|
||||
}),
|
||||
..e
|
||||
};
|
||||
let e = e.reduce::<T>();
|
||||
|
||||
vec![self.flatten_uint_expression(symbols, statements_flattened, e)]
|
||||
}
|
||||
TypedExpression::Array(e) => match e.inner_type().clone() {
|
||||
|
@ -1132,50 +1145,127 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Add(box new_left, box new_right)
|
||||
}
|
||||
UExpressionInner::Mult(box left, box right) => {
|
||||
if metadata.should_reduce.unwrap() {
|
||||
unimplemented!()
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
let right_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, right);
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
} else {
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
let right_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, right);
|
||||
let new_left = if left_flattened.is_linear() {
|
||||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Mult(box new_left, box new_right)
|
||||
}
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Mult(box new_left, box new_right)
|
||||
}
|
||||
UExpressionInner::Xor(box left, box right) => {
|
||||
let left_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, left);
|
||||
let right_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, right);
|
||||
|
||||
let left_bits =
|
||||
self.get_bits(left_flattened, target_bitwidth, statements_flattened);
|
||||
let right_bits =
|
||||
self.get_bits(right_flattened, target_bitwidth, statements_flattened);
|
||||
|
||||
assert_eq!(left_bits.len(), target_bitwidth);
|
||||
assert_eq!(right_bits.len(), target_bitwidth);
|
||||
|
||||
let name_xor = left_bits.iter().map(|_| self.use_sym()).collect::<Vec<_>>();
|
||||
|
||||
statements_flattened.extend(
|
||||
name_xor
|
||||
.iter()
|
||||
.zip(left_bits.iter().zip(right_bits.iter()))
|
||||
.flat_map(|(name, (x, y))| {
|
||||
let name_2_x_and_y = self.use_sym();
|
||||
vec![
|
||||
FlatStatement::Definition(
|
||||
name_2_x_and_y,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box x.clone(),
|
||||
),
|
||||
box y.clone(),
|
||||
),
|
||||
),
|
||||
FlatStatement::Definition(
|
||||
*name,
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Add(box x.clone(), box y.clone()),
|
||||
box FlatExpression::Identifier(name_2_x_and_y),
|
||||
),
|
||||
),
|
||||
]
|
||||
}),
|
||||
);
|
||||
|
||||
name_xor.into_iter().enumerate().fold(
|
||||
FlatExpression::Number(T::from(0)),
|
||||
|acc, (i, e)| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
|
||||
box e.into(),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
UExpressionInner::Xor(box left, box right) => unimplemented!(),
|
||||
};
|
||||
|
||||
match should_reduce {
|
||||
true => {
|
||||
let bits = (0..actual_bitwidth)
|
||||
.map(|_| self.use_sym())
|
||||
.collect::<Vec<_>>();
|
||||
let bits = self.get_bits(res.clone(), actual_bitwidth, statements_flattened);
|
||||
|
||||
// truncate to the target bitwidth
|
||||
(0..target_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
|
||||
box bits[i + actual_bitwidth - target_bitwidth].clone().into(),
|
||||
),
|
||||
)
|
||||
})
|
||||
}
|
||||
false => res,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_bits(
|
||||
&mut self,
|
||||
e: FlatExpression<T>,
|
||||
bitwidth: usize,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
|
||||
println!("{:?}", self.bits_cache);
|
||||
|
||||
match self.bits_cache.entry(e.clone()) {
|
||||
Entry::Occupied(entry) => entry.get().clone().into_iter().map(|e| e.into()).collect(),
|
||||
Entry::Vacant(_) => {
|
||||
let bits = (0..bitwidth).map(|_| self.use_sym()).collect::<Vec<_>>();
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
bits.clone(),
|
||||
Helper::Rust(RustHelper::Bits(actual_bitwidth)),
|
||||
vec![res.clone()],
|
||||
Helper::Rust(RustHelper::Bits(bitwidth)),
|
||||
vec![e.clone()],
|
||||
)));
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
// decompose to the actual bitwidth
|
||||
|
||||
// bit checks
|
||||
statements_flattened.extend((0..actual_bitwidth).map(|i| {
|
||||
statements_flattened.extend((0..bitwidth).map(|i| {
|
||||
FlatStatement::Condition(
|
||||
bits[i].clone().into(),
|
||||
FlatExpression::Mult(
|
||||
|
@ -1185,32 +1275,24 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
)
|
||||
}));
|
||||
|
||||
// sum check
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
res.clone(),
|
||||
(0..actual_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(actual_bitwidth - i - 1)),
|
||||
box bits[i].into(),
|
||||
),
|
||||
)
|
||||
}),
|
||||
));
|
||||
|
||||
// truncate to the target bitwidth
|
||||
(0..target_bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
|
||||
let sum = (0..bitwidth).fold(FlatExpression::Number(T::from(0)), |acc, i| {
|
||||
FlatExpression::Add(
|
||||
box acc,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2).pow(target_bitwidth - i - 1)),
|
||||
box bits[i + actual_bitwidth - target_bitwidth].into(),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - i - 1)),
|
||||
box bits[i].into(),
|
||||
),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
// sum check
|
||||
statements_flattened.push(FlatStatement::Condition(e.clone(), sum.clone()));
|
||||
|
||||
self.bits_cache.insert(e, bits.clone());
|
||||
self.bits_cache.insert(sum, bits.clone());
|
||||
|
||||
bits.into_iter().map(|v| v.into()).collect()
|
||||
}
|
||||
false => res,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
71
zokrates_core/src/optimizer/directive.rs
Normal file
71
zokrates_core/src/optimizer/directive.rs
Normal file
|
@ -0,0 +1,71 @@
|
|||
//! Module containing the `RedefinitionOptimizer` to remove code of the form
|
||||
// ```
|
||||
// b := Directive(a)
|
||||
// c := Directive(a)
|
||||
// ```
|
||||
// and replace by
|
||||
// ```
|
||||
// b := Directive(a)
|
||||
// c := b
|
||||
// ```
|
||||
|
||||
use helpers::Helper;
|
||||
use crate::flat_absy::flat_variable::FlatVariable;
|
||||
use crate::ir::folder::*;
|
||||
use crate::ir::LinComb;
|
||||
use crate::ir::*;
|
||||
use num::Zero;
|
||||
use std::collections::hash_map::{HashMap, Entry};
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DirectiveOptimizer<T: Field> {
|
||||
calls: HashMap<(Helper, Vec<LinComb<T>>), Vec<FlatVariable>>,
|
||||
/// Map of renamings for reassigned variables while processing the program.
|
||||
substitution: HashMap<FlatVariable, FlatVariable>,
|
||||
}
|
||||
|
||||
impl<T: Field> DirectiveOptimizer<T> {
|
||||
fn new() -> DirectiveOptimizer<T> {
|
||||
DirectiveOptimizer {
|
||||
calls: HashMap::new(),
|
||||
substitution: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize(p: Prog<T>) -> Prog<T> {
|
||||
DirectiveOptimizer::new().fold_module(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
println!("{:?}", s);
|
||||
match s {
|
||||
Statement::Directive(d) => {
|
||||
let d = self.fold_directive(d);
|
||||
|
||||
match self.calls.entry((d.helper.clone(), d.inputs.clone())) {
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(d.outputs.clone());
|
||||
vec![Statement::Directive(d)]
|
||||
},
|
||||
Entry::Occupied(e) => {
|
||||
self.substitution.extend(d.outputs.into_iter().zip(e.get().into_iter().cloned()));
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
},
|
||||
s => fold_statement(self, s)
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_variable(&mut self, v: FlatVariable) -> FlatVariable {
|
||||
*self.substitution.get(&v).unwrap_or(&v)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
}
|
|
@ -4,10 +4,12 @@
|
|||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
mod directive;
|
||||
mod duplicate;
|
||||
mod redefinition;
|
||||
mod tautology;
|
||||
|
||||
use self::directive::DirectiveOptimizer;
|
||||
use self::duplicate::DuplicateOptimizer;
|
||||
use self::redefinition::RedefinitionOptimizer;
|
||||
use self::tautology::TautologyOptimizer;
|
||||
|
@ -25,6 +27,8 @@ impl<T: Field> Optimize for Prog<T> {
|
|||
let r = RedefinitionOptimizer::optimize(self);
|
||||
// remove constraints that are always satisfied
|
||||
let r = TautologyOptimizer::optimize(r);
|
||||
// deduplicate directives which take the same input
|
||||
let r = DirectiveOptimizer::optimize(r);
|
||||
// remove duplicate constraints
|
||||
let r = DuplicateOptimizer::optimize(r);
|
||||
r
|
||||
|
|
|
@ -738,7 +738,7 @@ impl<'ast> Checker<'ast> {
|
|||
types: &TypeMap,
|
||||
) -> Result<Variable<'ast>, Vec<Error>> {
|
||||
Ok(Variable::with_id_and_type(
|
||||
v.value.id.into(),
|
||||
v.value.id,
|
||||
self.check_type(v.value._type, module_id, types)
|
||||
.map_err(|e| vec![e])?,
|
||||
))
|
||||
|
@ -933,7 +933,7 @@ impl<'ast> Checker<'ast> {
|
|||
match assignee.value {
|
||||
Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) {
|
||||
Some(var) => Ok(TypedAssignee::Identifier(Variable::with_id_and_type(
|
||||
variable_name.into(),
|
||||
variable_name,
|
||||
var.id._type.clone(),
|
||||
))),
|
||||
None => Err(Error {
|
||||
|
@ -1903,6 +1903,39 @@ impl<'ast> Checker<'ast> {
|
|||
}),
|
||||
}
|
||||
}
|
||||
Expression::Xor(box e1, box e2) => {
|
||||
let e1_checked = self.check_expression(e1, module_id, &types)?;
|
||||
let e2_checked = self.check_expression(e2, module_id, &types)?;
|
||||
match (e1_checked, e2_checked) {
|
||||
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
|
||||
Ok(BooleanExpression::Xor(box e1, box e2).into())
|
||||
}
|
||||
(TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => {
|
||||
if e1.get_type() == e2.get_type() {
|
||||
Ok(UExpression::xor(e1, e2).into())
|
||||
} else {
|
||||
Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Cannot apply `^` to {}, {}",
|
||||
e1.get_type(),
|
||||
e2.get_type()
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
(e1, e2) => Err(Error {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!(
|
||||
"Cannot apply `^` to {}, {}",
|
||||
e1.get_type(),
|
||||
e2.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
Expression::Not(box e) => {
|
||||
let e_checked = self.check_expression(e, module_id, &types)?;
|
||||
match e_checked {
|
||||
|
|
|
@ -33,12 +33,14 @@ use crate::typed_absy::*;
|
|||
use zokrates_field::field::Field;
|
||||
|
||||
pub struct InputConstrainer<'ast, T: Field> {
|
||||
next_var_id: usize,
|
||||
constraints: Vec<TypedStatement<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> InputConstrainer<'ast, T> {
|
||||
fn new() -> Self {
|
||||
InputConstrainer {
|
||||
next_var_id: 0,
|
||||
constraints: vec![],
|
||||
}
|
||||
}
|
||||
|
@ -47,6 +49,28 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
|
|||
InputConstrainer::new().fold_program(p)
|
||||
}
|
||||
|
||||
fn constrain_bits(&mut self, u: UExpression<'ast>) {
|
||||
let bitwidth = u.bitwidth;
|
||||
let bit_input = Variable::with_id_and_type(
|
||||
CoreIdentifier::Internal("bit_input_array", self.next_var_id),
|
||||
Type::array(Type::FieldElement, bitwidth),
|
||||
);
|
||||
self.next_var_id += 1;
|
||||
self.constraints.push(TypedStatement::MultipleDefinition(
|
||||
vec![bit_input],
|
||||
TypedExpressionList::FunctionCall(
|
||||
match bitwidth {
|
||||
8 => crate::embed::FlatEmbed::CheckU8.key::<T>(),
|
||||
16 => crate::embed::FlatEmbed::CheckU16.key::<T>(),
|
||||
32 => crate::embed::FlatEmbed::CheckU32.key::<T>(),
|
||||
_ => unreachable!()
|
||||
},
|
||||
vec![u.into()],
|
||||
vec![Type::array(Type::FieldElement, bitwidth)],
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
fn constrain_expression(&mut self, e: TypedExpression<'ast, T>) {
|
||||
match e {
|
||||
TypedExpression::FieldElement(_) => {}
|
||||
|
@ -54,8 +78,8 @@ impl<'ast, T: Field> InputConstrainer<'ast, T> {
|
|||
b.clone().into(),
|
||||
BooleanExpression::And(box b.clone(), box b).into(),
|
||||
)),
|
||||
TypedExpression::Uint(bitwidth) => {
|
||||
// TODO constrain by checking that it decomposes correctly
|
||||
TypedExpression::Uint(u) => {
|
||||
self.constrain_bits(u);
|
||||
}
|
||||
TypedExpression::Array(a) => {
|
||||
for i in 0..a.size() {
|
||||
|
|
|
@ -72,6 +72,18 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
|||
let sha256_round = crate::embed::FlatEmbed::Sha256Round;
|
||||
let sha256_round_key = sha256_round.key::<T>();
|
||||
|
||||
// define a function in the main module for the `check_u8` embed
|
||||
let check_u8 = crate::embed::FlatEmbed::CheckU8;
|
||||
let check_u8_key = check_u8.key::<T>();
|
||||
|
||||
// define a function in the main module for the `check_u8` embed
|
||||
let check_u16 = crate::embed::FlatEmbed::CheckU16;
|
||||
let check_u16_key = check_u16.key::<T>();
|
||||
|
||||
// define a function in the main module for the `check_u8` embed
|
||||
let check_u32 = crate::embed::FlatEmbed::CheckU32;
|
||||
let check_u32_key = check_u32.key::<T>();
|
||||
|
||||
// return a program with a single module containing `main`, `_UNPACK`, and `_SHA256_ROUND
|
||||
TypedProgram {
|
||||
main: String::from("main"),
|
||||
|
@ -81,6 +93,9 @@ impl<'ast, T: Field> Inliner<'ast, T> {
|
|||
functions: vec![
|
||||
(unpack_key, TypedFunctionSymbol::Flat(unpack)),
|
||||
(sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)),
|
||||
(check_u8_key, TypedFunctionSymbol::Flat(check_u8)),
|
||||
(check_u16_key, TypedFunctionSymbol::Flat(check_u16)),
|
||||
(check_u32_key, TypedFunctionSymbol::Flat(check_u32)),
|
||||
(main_key, main),
|
||||
]
|
||||
.into_iter()
|
||||
|
|
|
@ -9,11 +9,13 @@ mod flat_propagation;
|
|||
mod inline;
|
||||
mod propagation;
|
||||
mod unroll;
|
||||
mod uint_optimizer;
|
||||
|
||||
use self::constrain_inputs::InputConstrainer;
|
||||
use self::inline::Inliner;
|
||||
use self::propagation::Propagator;
|
||||
use self::unroll::Unroller;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use crate::flat_absy::FlatProg;
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use zokrates_field::field::Field;
|
||||
|
@ -32,6 +34,8 @@ impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> {
|
|||
let r = Propagator::propagate(r);
|
||||
// constrain inputs
|
||||
let r = InputConstrainer::constrain(r);
|
||||
// optimize uint expressions
|
||||
let r = UintOptimizer::optimize(r);
|
||||
r
|
||||
}
|
||||
}
|
||||
|
|
261
zokrates_core/src/static_analysis/uint_optimizer.rs
Normal file
261
zokrates_core/src/static_analysis/uint_optimizer.rs
Normal file
|
@ -0,0 +1,261 @@
|
|||
use typed_absy::bitwidth;
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::field::Field;
|
||||
use typed_absy::folder::*;
|
||||
use crate::typed_absy::*;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct UintOptimizer<'ast, T> {
|
||||
ids: HashMap<Identifier<'ast>, UMetadata>,
|
||||
phantom: PhantomData<T>
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> UintOptimizer<'ast, T> {
|
||||
|
||||
pub fn new() -> Self {
|
||||
UintOptimizer {
|
||||
ids: HashMap::new(),
|
||||
phantom: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
UintOptimizer::new().fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
||||
fn fold_uint_expression(&mut self, e: UExpression<'ast>) -> UExpression<'ast> {
|
||||
let max_bitwidth = T::get_required_bits() - 1;
|
||||
|
||||
let range = e.bitwidth;
|
||||
|
||||
assert!(range < max_bitwidth / 2);
|
||||
|
||||
let metadata = e.metadata;
|
||||
let inner = e.inner;
|
||||
|
||||
use self::UExpressionInner::*;
|
||||
|
||||
match inner {
|
||||
Value(v) => Value(v).annotate(range).metadata(UMetadata {
|
||||
bitwidth: Some(bitwidth(v)),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
}),
|
||||
Identifier(id) => Identifier(id.clone()).annotate(range).metadata(self.ids.get(&id).cloned().unwrap_or(UMetadata {
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})),
|
||||
Add(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
println!("{} + {}", left, right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
|
||||
println!("{}", output_width);
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(range + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
// the addition fits, so we just add
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let left = left.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
});
|
||||
|
||||
let right = right.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
});
|
||||
|
||||
UExpression::xor(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(true),
|
||||
})
|
||||
}
|
||||
Mult(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = left_bitwidth + right_bitwidth; // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the multiplication doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(2 * range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
// the multiplication fits, so we just multiply
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(TypedAssignee::Identifier(id), TypedExpression::Uint(e)) => {
|
||||
let e = self.fold_uint_expression(e);
|
||||
self.ids.insert(id.id.clone(), e.metadata.clone().unwrap());
|
||||
vec![TypedStatement::Definition(TypedAssignee::Identifier(id), TypedExpression::Uint(e))]
|
||||
},
|
||||
// we need to put back in range to return
|
||||
TypedStatement::Return(expressions) => vec![TypedStatement::Return(expressions
|
||||
.into_iter()
|
||||
.map(|e| match e {
|
||||
TypedExpression::Uint(e) => {
|
||||
let e = self.fold_uint_expression(e);
|
||||
let e = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
bitwidth: e.metadata.unwrap().bitwidth,
|
||||
}),
|
||||
..e
|
||||
};
|
||||
|
||||
TypedExpression::Uint(e)
|
||||
}
|
||||
e => self.fold_expression(e)
|
||||
})
|
||||
.collect())],
|
||||
s => fold_statement(self, s)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -26,7 +26,7 @@ impl<'ast> Unroller<'ast> {
|
|||
let res = match self.substitution.get(&v.id) {
|
||||
Some(i) => Variable {
|
||||
id: Identifier {
|
||||
id: v.id.id,
|
||||
id: v.id.clone().id,
|
||||
version: i + 1,
|
||||
stack: vec![],
|
||||
},
|
||||
|
|
|
@ -350,6 +350,11 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_boolean_expression(e2);
|
||||
BooleanExpression::Or(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Xor(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
let e2 = f.fold_boolean_expression(e2);
|
||||
BooleanExpression::Xor(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::And(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
let e2 = f.fold_boolean_expression(e2);
|
||||
|
|
|
@ -2,11 +2,26 @@ use std::fmt;
|
|||
use typed_absy::types::FunctionKey;
|
||||
use typed_absy::TypedModuleId;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub enum CoreIdentifier<'ast> {
|
||||
Source(&'ast str),
|
||||
Internal(&'static str, usize),
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for CoreIdentifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
CoreIdentifier::Source(s) => write!(f, "{}", s),
|
||||
CoreIdentifier::Internal(s, i) => write!(f, "#INTERNAL#_{}_{}", s, i),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A identifier for a variable
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub struct Identifier<'ast> {
|
||||
/// the id of the variable
|
||||
pub id: &'ast str,
|
||||
pub id: CoreIdentifier<'ast>,
|
||||
/// the version of the variable, used after SSA transformation
|
||||
pub version: usize,
|
||||
/// the call stack of the variable, used when inlining
|
||||
|
@ -35,6 +50,12 @@ impl<'ast> fmt::Display for Identifier<'ast> {
|
|||
|
||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||
Identifier::from(CoreIdentifier::Source(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CoreIdentifier<'ast>> for Identifier<'ast> {
|
||||
fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
Identifier {
|
||||
id,
|
||||
version: 0,
|
||||
|
|
|
@ -12,10 +12,11 @@ pub mod types;
|
|||
mod uint;
|
||||
mod variable;
|
||||
|
||||
pub use self::identifier::CoreIdentifier;
|
||||
pub use self::parameter::Parameter;
|
||||
pub use self::types::Type;
|
||||
pub use self::variable::Variable;
|
||||
pub use typed_absy::uint::{UExpression, UExpressionInner, UMetadata};
|
||||
pub use typed_absy::uint::{UExpression, UExpressionInner, UMetadata, bitwidth};
|
||||
|
||||
use crate::typed_absy::types::{FunctionKey, MemberId, Signature};
|
||||
use embed::FlatEmbed;
|
||||
|
@ -571,6 +572,10 @@ pub enum BooleanExpression<'ast, T: Field> {
|
|||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
Xor(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
And(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
|
@ -784,7 +789,13 @@ impl<'ast, T: Field> fmt::Display for FieldElementExpression<'ast, T> {
|
|||
|
||||
impl<'ast> fmt::Display for UExpression<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
unimplemented!()
|
||||
match self.inner {
|
||||
UExpressionInner::Value(ref v) => write!(f, "{}", v),
|
||||
UExpressionInner::Identifier(ref var) => write!(f, "{}", var),
|
||||
UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
|
||||
UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs),
|
||||
UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -799,6 +810,7 @@ impl<'ast, T: Field> fmt::Display for BooleanExpression<'ast, T> {
|
|||
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
BooleanExpression::Xor(ref lhs, ref rhs) => write!(f, "{} ^ {}", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||
BooleanExpression::Value(b) => write!(f, "{}", b),
|
||||
|
|
|
@ -68,7 +68,7 @@ impl<'ast> UExpressionInner<'ast> {
|
|||
}
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
fn metadata(self, metadata: UMetadata) -> UExpression<'ast> {
|
||||
pub fn metadata(self, metadata: UMetadata) -> UExpression<'ast> {
|
||||
UExpression {
|
||||
metadata: Some(metadata),
|
||||
..self
|
||||
|
@ -76,208 +76,10 @@ impl<'ast> UExpression<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
fn bitwidth(a: u128) -> Bitwidth {
|
||||
pub fn bitwidth(a: u128) -> Bitwidth {
|
||||
(128 - a.leading_zeros()) as Bitwidth
|
||||
}
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
pub fn reduce<T: Field>(self) -> Self {
|
||||
let max_bitwidth = T::get_required_bits() - 1;
|
||||
|
||||
let range = self.bitwidth;
|
||||
|
||||
assert!(range < max_bitwidth / 2);
|
||||
|
||||
let metadata = self.metadata;
|
||||
let inner = self.inner;
|
||||
|
||||
use self::UExpressionInner::*;
|
||||
|
||||
match inner {
|
||||
Value(v) => Value(v).annotate(range).metadata(UMetadata {
|
||||
bitwidth: Some(bitwidth(v)),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
}),
|
||||
Identifier(id) => Identifier(id).annotate(range).metadata(UMetadata {
|
||||
bitwidth: Some(range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
}),
|
||||
Add(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = left.reduce::<T>();
|
||||
let right = right.reduce::<T>();
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = std::cmp::max(left_bitwidth, right_bitwidth) + 1; // bitwidth(a + b) = max(bitwidth(a), bitwidth(b)) + 1
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the addition doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(range + 1),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
// the addition fits, so we just add
|
||||
UExpression::add(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
Xor(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = left.reduce::<T>();
|
||||
let right = right.reduce::<T>();
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// for xor we need both terms to be in range. Therefore we reduce them to being in range.
|
||||
// NB: if they are already in range, the flattening process will ignore the reduction
|
||||
let left = left.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
});
|
||||
|
||||
let right = right.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
});
|
||||
|
||||
UExpression::xor(left, right)
|
||||
}
|
||||
Mult(box left, box right) => {
|
||||
// reduce the two terms
|
||||
let left = left.reduce::<T>();
|
||||
let right = right.reduce::<T>();
|
||||
|
||||
let left_metadata = left.metadata.clone().unwrap();
|
||||
let right_metadata = right.metadata.clone().unwrap();
|
||||
|
||||
// determine the bitwidth of each term. It's their current bitwidth, unless they are tagged as `should_reduce` in which case they now have bitwidth 8
|
||||
let left_bitwidth = left_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
left_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let right_bitwidth = right_metadata
|
||||
.should_reduce
|
||||
.map(|should_reduce| {
|
||||
if should_reduce {
|
||||
range
|
||||
} else {
|
||||
right_metadata.bitwidth.unwrap()
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let output_width = left_bitwidth + right_bitwidth; // bitwidth(a*b) = bitwidth(a) + bitwidth(b)
|
||||
|
||||
if output_width > max_bitwidth {
|
||||
// the multiplication doesnt fit, we reduce both terms first (TODO maybe one would be enough here)
|
||||
|
||||
let left = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..left_metadata
|
||||
}),
|
||||
..left
|
||||
};
|
||||
|
||||
let right = UExpression {
|
||||
metadata: Some(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..right_metadata
|
||||
}),
|
||||
..right
|
||||
};
|
||||
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(2 * range),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
// the multiplication fits, so we just multiply
|
||||
UExpression::mult(left, right).metadata(UMetadata {
|
||||
bitwidth: Some(output_width),
|
||||
should_reduce: Some(
|
||||
metadata
|
||||
.map(|m| m.should_reduce.unwrap_or(false))
|
||||
.unwrap_or(false),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> UExpression<'ast> {
|
||||
pub fn bitwidth(&self) -> Bitwidth {
|
||||
self.bitwidth
|
||||
|
|
|
@ -9,7 +9,7 @@ pub struct Variable<'ast> {
|
|||
}
|
||||
|
||||
impl<'ast> Variable<'ast> {
|
||||
pub fn field_element(id: Identifier<'ast>) -> Variable<'ast> {
|
||||
pub fn field_element<I: Into<Identifier<'ast>>>(id: I) -> Variable<'ast> {
|
||||
Self::with_id_and_type(id, Type::FieldElement)
|
||||
}
|
||||
|
||||
|
@ -18,20 +18,23 @@ impl<'ast> Variable<'ast> {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn field_array(id: Identifier<'ast>, size: usize) -> Variable<'ast> {
|
||||
pub fn field_array<I: Into<Identifier<'ast>>>(id: I, size: usize) -> Variable<'ast> {
|
||||
Self::array(id, Type::FieldElement, size)
|
||||
}
|
||||
|
||||
pub fn array(id: Identifier<'ast>, ty: Type, size: usize) -> Variable<'ast> {
|
||||
pub fn array<I: Into<Identifier<'ast>>>(id: I, ty: Type, size: usize) -> Variable<'ast> {
|
||||
Self::with_id_and_type(id, Type::array(ty, size))
|
||||
}
|
||||
|
||||
pub fn struc(id: Identifier<'ast>, ty: Vec<(MemberId, Type)>) -> Variable<'ast> {
|
||||
pub fn struc<I: Into<Identifier<'ast>>>(id: I, ty: Vec<(MemberId, Type)>) -> Variable<'ast> {
|
||||
Self::with_id_and_type(id, Type::Struct(ty))
|
||||
}
|
||||
|
||||
pub fn with_id_and_type(id: Identifier<'ast>, _type: Type) -> Variable<'ast> {
|
||||
Variable { id, _type }
|
||||
pub fn with_id_and_type<I: Into<Identifier<'ast>>>(id: I, _type: Type) -> Variable<'ast> {
|
||||
Variable {
|
||||
id: id.into(),
|
||||
_type,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_type(&self) -> Type {
|
||||
|
|
Loading…
Reference in a new issue