1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
ZoKrates/zokrates_core/src/static_analysis/propagation.rs
2020-12-20 20:30:40 +01:00

1871 lines
86 KiB
Rust

//! Module containing constant propagation for the typed AST
//!
//! Constant propagation on the SSA program. The constants map can be passed by the caller to allow for many passes to use
//! the same constants.
//!
//! @file propagation.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
use crate::typed_absy::folder::*;
use crate::typed_absy::types::Type;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use zokrates_field::Field;
type Constants<'ast, T> = HashMap<Identifier<'ast>, TypedExpression<'ast, T>>;
pub struct Propagator<'ast, 'a, T: Field> {
// constants keeps track of constant expressions
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
constants: &'a mut Constants<'ast, T>,
}
impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self {
Propagator { constants }
}
pub fn propagate(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut constants = Constants::new();
Propagator {
constants: &mut constants,
}
.fold_program(p)
}
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
// return the identifier at the root of this assignee
fn try_get_constant_mut<'b>(
&mut self,
assignee: &'b TypedAssignee<'ast, T>,
) -> Result<(&'b Variable<'ast, T>, &mut TypedExpression<'ast, T>), &'b Variable<'ast, T>> {
match assignee {
TypedAssignee::Identifier(var) => self
.constants
.get_mut(&var.id)
.map(|c| Ok((var, c)))
.unwrap_or(Err(var)),
TypedAssignee::Select(box assignee, box index) => {
match self.try_get_constant_mut(&assignee) {
Ok((v, c)) => match index.as_inner() {
UExpressionInner::Value(n) => match c {
TypedExpression::Array(a) => match a.as_inner_mut() {
ArrayExpressionInner::Value(value) => {
Ok((v, &mut value[*n as usize]))
}
_ => unreachable!(),
},
_ => unreachable!(),
},
_ => Err(v),
},
e => e,
}
}
TypedAssignee::Member(box assignee, m) => match self.try_get_constant_mut(&assignee) {
Ok((v, c)) => {
let ty = assignee.get_type();
let index = match ty {
Type::Struct(struct_ty) => struct_ty
.members
.iter()
.position(|member| *m == member.id)
.unwrap(),
_ => unreachable!(),
};
match c {
TypedExpression::Struct(a) => match a.as_inner_mut() {
StructExpressionInner::Value(value) => Ok((v, &mut value[index])),
_ => unreachable!(),
},
_ => unreachable!(),
}
}
e => e,
},
}
}
}
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
TypedExpression::Struct(a) => match a.as_inner() {
StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)),
_ => false,
}
}
impl<'ast, 'a, T: Field> Folder<'ast, T> for Propagator<'ast, 'a, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let main = p.main.clone();
TypedProgram {
modules: p
.modules
.into_iter()
.map(|(module_id, module)| {
(
module_id.clone(),
if module_id == main {
self.fold_module(module)
} else {
module
},
)
})
.collect(),
main: p.main,
}
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
TypedModule {
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
(
key.clone(),
if key.id == "main" {
self.fold_function_symbol(fun)
} else {
fun
},
)
})
.collect(),
..m
}
}
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
fold_function(self, f)
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let res = match s {
TypedStatement::Declaration(v) => vec![TypedStatement::Declaration(v)],
TypedStatement::Return(expressions) => vec![TypedStatement::Return(
expressions
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
)],
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr);
let assignee = self.fold_assignee(assignee);
if is_constant(&expr) {
match assignee {
TypedAssignee::Identifier(var) => {
assert!(self.constants.insert(var.id, expr).is_none());
vec![]
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr;
vec![]
}
Err(v) => match self.constants.remove(&v.id) {
// invalidate the cache for this identifier, and define the latest
// version of the constant in the program, if any
Some(c) => vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::Definition(assignee, expr),
],
None => vec![TypedStatement::Definition(assignee, expr)],
},
},
}
} else {
// the expression being assigned is not constant, invalidate the cache
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::Definition(assignee, expr),
],
None => vec![TypedStatement::Definition(assignee, expr)],
}
}
}
// propagate the boolean
TypedStatement::Assertion(e) => {
// could stop execution here if condition is known to fail
vec![TypedStatement::Assertion(self.fold_boolean_expression(e))]
}
// we do not visit the for-loop statements
TypedStatement::For(v, from, to, statements) => {
let from = self.fold_uint_expression(from);
let to = self.fold_uint_expression(to);
vec![TypedStatement::For(v, from, to, statements)]
}
TypedStatement::MultipleDefinition(assignees, expression_list) => {
let assignees: Vec<TypedAssignee<'ast, T>> = assignees
.into_iter()
.map(|a| self.fold_assignee(a))
.collect();
let expression_list = self.fold_expression_list(expression_list);
match expression_list {
TypedExpressionList::FunctionCall(key, arguments, types) => {
let arguments: Vec<_> = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect();
let types = types.into_iter().map(|t| self.fold_type(t)).collect();
fn process_u_from_bits<'ast, T: Field>(
variables: Vec<TypedAssignee<'ast, T>>,
arguments: Vec<TypedExpression<'ast, T>>,
bitwidth: UBitwidth,
) -> TypedExpression<'ast, T> {
assert_eq!(variables.len(), 1);
assert_eq!(arguments.len(), 1);
use std::convert::TryInto;
match ArrayExpression::try_from(arguments[0].clone())
.unwrap()
.into_inner()
{
ArrayExpressionInner::Value(v) => {
assert_eq!(v.len(), bitwidth.to_usize());
UExpressionInner::Value(
v.into_iter()
.map(|v| match v {
TypedExpression::Boolean(
BooleanExpression::Value(v),
) => v,
_ => unreachable!("should be a boolean value"),
})
.enumerate()
.fold(0, |acc, (i, v)| {
if v {
acc + 2u128.pow(
(bitwidth.to_usize() - i - 1)
.try_into()
.unwrap(),
)
} else {
acc
}
}),
)
.annotate(bitwidth)
.into()
}
_ => unreachable!("should be an array value"),
}
}
fn process_u_to_bits<'ast, T: Field>(
variables: Vec<TypedAssignee<'ast, T>>,
arguments: Vec<TypedExpression<'ast, T>>,
bitwidth: UBitwidth,
) -> TypedExpression<'ast, T> {
assert_eq!(variables.len(), 1);
assert_eq!(arguments.len(), 1);
match UExpression::try_from(arguments[0].clone())
.unwrap()
.into_inner()
{
UExpressionInner::Value(v) => {
let mut num = v;
let mut res = vec![];
for i in (0..bitwidth as u32).rev() {
if 2u128.pow(i) <= num {
num -= 2u128.pow(i);
res.push(true);
} else {
res.push(false);
}
}
assert_eq!(num, 0);
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| BooleanExpression::Value(v).into())
.collect(),
)
.annotate(Type::Boolean, bitwidth.to_usize() as u32)
.into()
}
_ => unreachable!("should be a uint value"),
}
}
match arguments.iter().all(|a| is_constant(a)) {
true => {
let r: Option<TypedExpression<'ast, T>> = match key.id {
"_U32_FROM_BITS" => Some(process_u_from_bits(
assignees.clone(),
arguments.clone(),
UBitwidth::B32,
)),
"_U16_FROM_BITS" => Some(process_u_from_bits(
assignees.clone(),
arguments.clone(),
UBitwidth::B16,
)),
"_U8_FROM_BITS" => Some(process_u_from_bits(
assignees.clone(),
arguments.clone(),
UBitwidth::B8,
)),
"_U32_TO_BITS" => Some(process_u_to_bits(
assignees.clone(),
arguments.clone(),
UBitwidth::B32,
)),
"_U16_TO_BITS" => Some(process_u_to_bits(
assignees.clone(),
arguments.clone(),
UBitwidth::B16,
)),
"_U8_TO_BITS" => Some(process_u_to_bits(
assignees.clone(),
arguments.clone(),
UBitwidth::B8,
)),
"_UNPACK" => {
assert_eq!(assignees.len(), 1);
assert_eq!(arguments.len(), 1);
match FieldElementExpression::try_from(arguments[0].clone())
.unwrap()
{
FieldElementExpression::Number(num) => {
let mut num = num;
let mut res = vec![];
for i in (0..T::get_required_bits()).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(true);
} else {
res.push(false);
}
}
assert_eq!(num, T::zero());
Some(
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| {
BooleanExpression::Value(v).into()
})
.collect(),
)
.annotate(
Type::Boolean,
T::get_required_bits() as u32,
)
.into(),
)
}
_ => unreachable!("should be a field value"),
}
}
"_SHA256_ROUND" => None,
_ => None,
};
match r {
// if the function call returns a constant
Some(expr) => {
let mut assignees = assignees;
match assignees.pop().unwrap() {
TypedAssignee::Identifier(var) => {
self.constants.insert(var.id, expr);
vec![]
}
assignee => {
match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr;
vec![]
}
Err(v) => match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(
v.clone().into(),
c,
),
TypedStatement::Definition(
assignee, expr,
),
],
None => vec![TypedStatement::Definition(
assignee, expr,
)],
},
}
}
}
}
None => {
// if the function call does not return a constant, invalidate the cache
// this happpens because we only propagate certain calls here
let mut assignees = assignees;
let assignee = assignees.pop().unwrap();
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => vec![
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::MultipleDefinition(
vec![assignee],
TypedExpressionList::FunctionCall(
key, arguments, types,
),
),
],
None => vec![TypedStatement::MultipleDefinition(
vec![assignee],
TypedExpressionList::FunctionCall(
key, arguments, types,
),
)],
}
}
}
}
false => {
// if the function arguments are not constant, invalidate the cache
// for the return assignees
let invalidations = assignees
.iter()
.flat_map(|assignee| {
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => vec![TypedStatement::Definition(
v.clone().into(),
c,
)],
None => vec![],
}
})
.collect::<Vec<_>>();
let l = TypedExpressionList::FunctionCall(key, arguments, types);
invalidations
.into_iter()
.chain(std::iter::once(TypedStatement::MultipleDefinition(
assignees, l,
)))
.collect()
}
}
}
}
}
s @ TypedStatement::PushCallLog(..) => vec![s],
s @ TypedStatement::PopCallLog => vec![s],
};
res
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Uint(e) => e.as_inner().clone(),
_ => unreachable!("constant stored for a uint should be a uint"),
},
None => UExpressionInner::Identifier(id),
},
UExpressionInner::Add(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value(
(v1 + v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v {
0 => e,
_ => UExpressionInner::Add(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
},
(e1, e2) => {
UExpressionInner::Add(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::Sub(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value(
(v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v {
0 => e,
_ => UExpressionInner::Sub(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
},
(e1, e2) => {
UExpressionInner::Sub(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::Mult(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value(
(v1 * v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v {
0 => UExpressionInner::Value(0),
1 => e,
_ => UExpressionInner::Mult(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
},
(e1, e2) => {
UExpressionInner::Mult(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::Div(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value(
(v1 / v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) => match v {
1 => e,
_ => UExpressionInner::Div(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
},
(e1, e2) => {
UExpressionInner::Div(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::Rem(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value(
(v1 % v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) => match v {
1 => UExpressionInner::Value(0),
_ => UExpressionInner::Rem(
box e.annotate(bitwidth),
box UExpressionInner::Value(v).annotate(bitwidth),
),
},
(e1, e2) => {
UExpressionInner::Rem(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::RightShift(box e, box by) => {
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
match (e.into_inner(), by) {
(UExpressionInner::Value(v), FieldElementExpression::Number(by)) => {
let by_as_usize = by.to_dec_string().parse::<usize>().unwrap();
UExpressionInner::Value(v >> by_as_usize)
}
(e, FieldElementExpression::Number(by)) => UExpressionInner::RightShift(
box e.annotate(bitwidth),
box FieldElementExpression::Number(by),
),
(_, e2) => unreachable!(format!(
"non-constant shift {} detected during static analysis",
e2
)),
}
}
UExpressionInner::LeftShift(box e, box by) => {
let e = self.fold_uint_expression(e);
let by = self.fold_field_expression(by);
match (e.into_inner(), by) {
(UExpressionInner::Value(v), FieldElementExpression::Number(by)) => {
let by_as_usize = by.to_dec_string().parse::<usize>().unwrap();
UExpressionInner::Value((v << by_as_usize) & 0xffffffff)
}
(e, FieldElementExpression::Number(by)) => UExpressionInner::LeftShift(
box e.annotate(bitwidth),
box FieldElementExpression::Number(by),
),
(_, e2) => unreachable!(format!(
"non-constant shift {} detected during static analysis",
e2
)),
}
}
UExpressionInner::Xor(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
UExpressionInner::Value(v1 ^ v2)
}
(UExpressionInner::Value(0), e2) => e2,
(e1, UExpressionInner::Value(0)) => e1,
(e1, e2) => {
if e1 == e2 {
UExpressionInner::Value(0)
} else {
UExpressionInner::Xor(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
}
},
UExpressionInner::And(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
UExpressionInner::Value(v1 & v2)
}
(UExpressionInner::Value(0), _) | (_, UExpressionInner::Value(0)) => {
UExpressionInner::Value(0)
}
(e1, e2) => {
UExpressionInner::And(box e1.annotate(bitwidth), box e2.annotate(bitwidth))
}
},
UExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_uint_expression(consequence);
let alternative = self.fold_uint_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => UExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
UExpressionInner::Not(box e) => {
let e = self.fold_uint_expression(e).into_inner();
match e {
UExpressionInner::Value(v) => UExpressionInner::Value((!v) & 0xffffffff),
e => UExpressionInner::Not(box e.annotate(bitwidth)),
}
}
UExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => {
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
UExpression::try_from(v[n as usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
UExpression::try_from(v[n as usize].clone())
.unwrap()
.into_inner()
}
_ => unreachable!(),
},
_ => unreachable!(""),
},
None => UExpressionInner::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
),
}
}
(a, i) => UExpressionInner::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
),
}
}
_ => fold_uint_expression_inner(
self,
bitwidth,
UExpressionInner::Select(box array, box index),
),
}
}
UExpressionInner::FunctionCall(key, arguments) => fold_uint_expression_inner(
self,
bitwidth,
UExpressionInner::FunctionCall(key, arguments),
),
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => unreachable!(
"constant stored for a field element should be a field element"
),
},
None => FieldElementExpression::Identifier(id),
},
FieldElementExpression::Add(box e1, box e2) => match (
self.fold_field_expression(e1),
self.fold_field_expression(e2),
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
FieldElementExpression::Number(n1 + n2)
}
(e1, e2) => FieldElementExpression::Add(box e1, box e2),
},
FieldElementExpression::Sub(box e1, box e2) => match (
self.fold_field_expression(e1),
self.fold_field_expression(e2),
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
FieldElementExpression::Number(n1 - n2)
}
(e1, e2) => FieldElementExpression::Sub(box e1, box e2),
},
FieldElementExpression::Mult(box e1, box e2) => match (
self.fold_field_expression(e1),
self.fold_field_expression(e2),
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
FieldElementExpression::Number(n1 * n2)
}
(e1, e2) => FieldElementExpression::Mult(box e1, box e2),
},
FieldElementExpression::Div(box e1, box e2) => match (
self.fold_field_expression(e1),
self.fold_field_expression(e2),
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
FieldElementExpression::Number(n1 / n2)
}
(e1, e2) => FieldElementExpression::Div(box e1, box e2),
},
FieldElementExpression::Pow(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_uint_expression(e2);
match (e1, e2.into_inner()) {
(_, UExpressionInner::Value(ref n2)) if *n2 == 0 => {
FieldElementExpression::Number(T::from(1))
}
(FieldElementExpression::Number(n1), UExpressionInner::Value(n2)) => {
FieldElementExpression::Number(n1.pow(n2 as usize))
}
(e1, UExpressionInner::Value(n2)) => FieldElementExpression::Pow(
box e1,
box UExpressionInner::Value(n2).annotate(UBitwidth::B32),
),
(_, e2) => unreachable!(format!(
"non-constant exponent {} detected during static analysis",
e2.annotate(UBitwidth::B32)
)),
}
}
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_field_expression(consequence);
let alternative = self.fold_field_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence,
BooleanExpression::Value(false) => alternative,
c => FieldElementExpression::IfElse(box c, box consequence, box alternative),
}
}
FieldElementExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => {
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
FieldElementExpression::try_from(v[n as usize].clone()).unwrap()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
FieldElementExpression::try_from(
v[n as usize].clone(),
)
.unwrap()
}
_ => unreachable!(),
},
_ => unreachable!(""),
},
None => FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
),
}
}
(a, i) => FieldElementExpression::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
),
}
}
_ => fold_field_expression(
self,
FieldElementExpression::Select(box array, box index),
),
}
}
FieldElementExpression::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!("???"),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == m)
.unwrap()
.1
{
TypedExpression::FieldElement(s) => s,
_ => unreachable!("????"),
}
}
inner => FieldElementExpression::Member(box inner.annotate(members), m),
}
}
e => fold_field_expression(self, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &Type<'ast, T>,
size: UExpression<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Array(e) => e.as_inner().clone(),
_ => panic!("constant stored for an array should be an array"),
},
None => ArrayExpressionInner::Identifier(id),
},
ArrayExpressionInner::Select(box array, box index) => {
let outer_size = size;
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner())
{
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
ArrayExpression::try_from(v[n as usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
ArrayExpression::try_from(v[n as usize].clone())
.unwrap()
.into_inner()
}
_ => unreachable!(),
},
_ => unreachable!(""),
},
None => ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
),
}
}
(a, i) => ArrayExpressionInner::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
),
},
_ => fold_array_expression_inner(
self,
ty,
outer_size,
ArrayExpressionInner::Select(box array, box index),
),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_array_expression(consequence);
let alternative = self.fold_array_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
ArrayExpressionInner::Member(box struc, id) => {
let struc = self.fold_struct_expression(struc);
let members = match struc.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct"),
};
match struc.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == id)
.unwrap()
.1
{
TypedExpression::Array(a) => a.into_inner(),
_ => unreachable!("should be an array"),
}
}
inner => ArrayExpressionInner::Member(box inner.annotate(members), id),
}
}
e => fold_array_expression_inner(self, ty, size, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Struct(e) => e.as_inner().clone(),
_ => panic!("constant stored for an array should be an array"),
},
None => StructExpressionInner::Identifier(id),
},
StructExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner())
{
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
StructExpression::try_from(v[n as usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
StructExpression::try_from(v[n as usize].clone())
.unwrap()
.into_inner()
}
_ => unreachable!(),
},
_ => unreachable!(""),
},
None => StructExpressionInner::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
),
}
}
(a, i) => StructExpressionInner::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
),
},
_ => fold_struct_expression_inner(
self,
ty,
StructExpressionInner::Select(box array, box index),
),
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_struct_expression(consequence);
let alternative = self.fold_struct_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence.into_inner(),
BooleanExpression::Value(false) => alternative.into_inner(),
c => StructExpressionInner::IfElse(box c, box consequence, box alternative),
}
}
StructExpressionInner::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct"),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == m)
.unwrap()
.1
{
TypedExpression::Struct(s) => s.into_inner(),
_ => unreachable!("should be a struct"),
}
}
inner => StructExpressionInner::Member(box inner.annotate(members), m),
}
}
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
// Note: we only propagate when we see constants, as comparing of arbitrary expressions would lead to
// a lot of false negatives due to expressions not being in a canonical form
// For example, `2 * a` is equivalent to `a + a`, but our notion of equality would not detect that here
// These kind of reduction rules are easier to apply later in the process, when we have canonical representations
// of expressions, ie `a + a` would always be written `2 * a`
match e {
BooleanExpression::Identifier(id) => match self.constants.get(&id) {
Some(e) => match e {
TypedExpression::Boolean(e) => e.clone(),
_ => panic!("constant stored for a boolean should be a boolean"),
},
None => BooleanExpression::Identifier(id),
},
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 == n2)
}
(e1, e2) => BooleanExpression::FieldEq(box e1, box e2),
}
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1);
let e2 = self.fold_uint_expression(e2);
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
BooleanExpression::Value(v1 == v2)
}
_ => BooleanExpression::UintEq(box e1, box e2),
}
}
BooleanExpression::BoolEq(box e1, box e2) => {
let e1 = self.fold_boolean_expression(e1);
let e2 = self.fold_boolean_expression(e2);
match (e1, e2) {
(BooleanExpression::Value(n1), BooleanExpression::Value(n2)) => {
BooleanExpression::Value(n1 == n2)
}
(e1, e2) => BooleanExpression::BoolEq(box e1, box e2),
}
}
BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 < n2)
}
(e1, e2) => BooleanExpression::FieldLt(box e1, box e2),
}
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 <= n2)
}
(e1, e2) => BooleanExpression::FieldLe(box e1, box e2),
}
}
BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 > n2)
}
(e1, e2) => BooleanExpression::FieldGt(box e1, box e2),
}
}
BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 >= n2)
}
(e1, e2) => BooleanExpression::FieldGe(box e1, box e2),
}
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1);
let e2 = self.fold_uint_expression(e2);
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => {
BooleanExpression::Value(n1 < n2)
}
_ => BooleanExpression::UintLt(box e1, box e2),
}
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1);
let e2 = self.fold_uint_expression(e2);
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => {
BooleanExpression::Value(n1 <= n2)
}
_ => BooleanExpression::UintLe(box e1, box e2),
}
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1);
let e2 = self.fold_uint_expression(e2);
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => {
BooleanExpression::Value(n1 > n2)
}
_ => BooleanExpression::UintGt(box e1, box e2),
}
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1);
let e2 = self.fold_uint_expression(e2);
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => {
BooleanExpression::Value(n1 >= n2)
}
_ => BooleanExpression::UintGe(box e1, box e2),
}
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = self.fold_boolean_expression(e1);
let e2 = self.fold_boolean_expression(e2);
match (e1, e2) {
// reduction of constants
(BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => {
BooleanExpression::Value(v1 || v2)
}
// x || true == true
(_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => {
BooleanExpression::Value(true)
}
// x || false == x
(e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => {
e
}
(e1, e2) => BooleanExpression::Or(box e1, box e2),
}
}
BooleanExpression::And(box e1, box e2) => {
let e1 = self.fold_boolean_expression(e1);
let e2 = self.fold_boolean_expression(e2);
match (e1, e2) {
// reduction of constants
(BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => {
BooleanExpression::Value(v1 && v2)
}
// x && true == x
(e, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), e) => e,
// x && false == false
(_, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), _) => {
BooleanExpression::Value(false)
}
(e1, e2) => BooleanExpression::And(box e1, box e2),
}
}
BooleanExpression::Not(box e) => {
let e = self.fold_boolean_expression(e);
match e {
BooleanExpression::Value(v) => BooleanExpression::Value(!v),
e => BooleanExpression::Not(box e),
}
}
BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_boolean_expression(consequence);
let alternative = self.fold_boolean_expression(alternative);
match self.fold_boolean_expression(condition) {
BooleanExpression::Value(true) => consequence,
BooleanExpression::Value(false) => alternative,
c => BooleanExpression::IfElse(box c, box consequence, box alternative),
}
}
BooleanExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner())
{
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
BooleanExpression::try_from(v[n as usize].clone()).unwrap()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
BooleanExpression::try_from(v[n as usize].clone())
.unwrap()
}
_ => unreachable!(),
},
_ => unreachable!(""),
},
None => BooleanExpression::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
),
}
}
(a, i) => BooleanExpression::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
),
},
_ => fold_boolean_expression(
self,
BooleanExpression::Select(box array, box index),
),
}
}
BooleanExpression::Member(box s, m) => {
let s = self.fold_struct_expression(s);
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct"),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == m)
.unwrap()
.1
{
TypedExpression::Boolean(s) => s,
_ => unreachable!("should be a boolean"),
}
}
inner => BooleanExpression::Member(box inner.annotate(members), m),
}
}
e => fold_boolean_expression(self, e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
#[cfg(test)]
mod expression {
use super::*;
#[cfg(test)]
mod field {
use super::*;
#[test]
fn add() {
let e = FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(3)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(5))
);
}
#[test]
fn sub() {
let e = FieldElementExpression::Sub(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(1))
);
}
#[test]
fn mult() {
let e = FieldElementExpression::Mult(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(6))
);
}
#[test]
fn div() {
let e = FieldElementExpression::Div(
box FieldElementExpression::Number(Bn128Field::from(6)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(3))
);
}
#[test]
fn pow() {
let e = FieldElementExpression::Pow(
box FieldElementExpression::Number(Bn128Field::from(2)),
box 3u32.into(),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(8))
);
}
#[test]
fn if_else_true() {
let e = FieldElementExpression::IfElse(
box BooleanExpression::Value(true),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(3)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(2))
);
}
#[test]
fn if_else_false() {
let e = FieldElementExpression::IfElse(
box BooleanExpression::Value(false),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(3)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(3))
);
}
#[test]
fn select() {
let e = FieldElementExpression::Select(
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(1)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(3)).into(),
])
.annotate(Type::FieldElement, 3usize),
box UExpressionInner::Add(box 1u32.into(), box 1u32.into())
.annotate(UBitwidth::B32),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new()).fold_field_expression(e),
FieldElementExpression::Number(Bn128Field::from(3))
);
}
}
#[cfg(test)]
mod boolean {
use super::*;
#[test]
fn not() {
let e_true: BooleanExpression<Bn128Field> =
BooleanExpression::Not(box BooleanExpression::Value(false));
let e_false: BooleanExpression<Bn128Field> =
BooleanExpression::Not(box BooleanExpression::Value(true));
let e_default: BooleanExpression<Bn128Field> =
BooleanExpression::Not(box BooleanExpression::Identifier("a".into()));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
BooleanExpression::Value(false)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_default.clone()),
e_default
);
}
#[test]
fn field_eq() {
let e_true = BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
let e_false = BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
BooleanExpression::Value(false)
);
}
#[test]
fn bool_eq() {
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false)
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true)
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false)
)),
BooleanExpression::Value(false)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true)
)),
BooleanExpression::Value(false)
);
}
#[test]
fn lt() {
let e_true = BooleanExpression::FieldLt(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
let e_false = BooleanExpression::FieldLt(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
BooleanExpression::Value(false)
);
}
#[test]
fn le() {
let e_true = BooleanExpression::FieldLe(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
let e_false = BooleanExpression::FieldLe(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
BooleanExpression::Value(false)
);
}
#[test]
fn gt() {
let e_true = BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(5)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
let e_false = BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(5)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
BooleanExpression::Value(false)
);
}
#[test]
fn ge() {
let e_true = BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(5)),
box FieldElementExpression::Number(Bn128Field::from(5)),
);
let e_false = BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(5)),
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
BooleanExpression::Value(false)
);
}
#[test]
fn and() {
let a_bool: Identifier = "a".into();
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Identifier(a_bool.clone())
)),
BooleanExpression::Identifier(a_bool.clone())
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
BooleanExpression::Identifier(a_bool.clone())
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Identifier(a_bool.clone())
)),
BooleanExpression::Value(false)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
BooleanExpression::Value(false)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false),
)),
BooleanExpression::Value(false)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
)),
BooleanExpression::Value(false)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::And(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false),
)),
BooleanExpression::Value(false)
);
}
#[test]
fn or() {
let a_bool: Identifier = "a".into();
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Identifier(a_bool.clone())
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::Value(true),
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Identifier(a_bool.clone())
)),
BooleanExpression::Identifier(a_bool.clone())
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Identifier(a_bool.clone()),
box BooleanExpression::Value(false),
)),
BooleanExpression::Identifier(a_bool.clone())
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false),
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true),
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true),
)),
BooleanExpression::Value(true)
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::Or(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false),
)),
BooleanExpression::Value(false)
);
}
}
}
}