1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

make tests pass

This commit is contained in:
schaeff 2020-09-21 22:22:38 +02:00
parent 69e682f065
commit be77844e69
25 changed files with 840 additions and 1024 deletions

View file

@ -1,3 +1,3 @@
def main() -> ():
assert(1 == 2)
assert(1f == 2f)
return

View file

@ -14,6 +14,6 @@ def main(field a) -> field:
assert(2 * b == a * 12 + 60)
field c = 7 * (b + a)
assert(isEqual(c, 7 * b + 7 * a))
field k = if [1, 2] == [3, 4] then 1 else 3 fi
field k = if [1f, 2] == [3f, 4] then 1 else 3 fi
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
return b + c

View file

@ -722,10 +722,9 @@ mod tests {
arguments: vec![],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant(
Bn128Field::from(42),
)
.into()],
expressions: vec![
absy::Expression::IntConstant(42usize.into()).into()
],
}
.into(),
)
@ -803,10 +802,9 @@ mod tests {
],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant(
Bn128Field::from(42),
)
.into()],
expressions: vec![
absy::Expression::IntConstant(42usize.into()).into()
],
}
.into(),
)
@ -870,7 +868,7 @@ mod tests {
"field[2]",
absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
(
@ -878,10 +876,10 @@ mod tests {
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::FieldConstant(Bn128Field::from(3)).mock(),
absy::Expression::IntConstant(3usize.into()).mock(),
)
.mock(),
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
(
@ -889,10 +887,10 @@ mod tests {
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::Boolean.mock(),
absy::Expression::U32Constant(3).mock(),
absy::Expression::U32Constant(3u32).mock(),
)
.mock(),
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
];
@ -943,7 +941,7 @@ mod tests {
absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(3)).into(),
absy::Expression::IntConstant(3usize.into()).into(),
)
.into(),
),
@ -954,13 +952,13 @@ mod tests {
box absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(3)).into(),
absy::Expression::IntConstant(3usize.into()).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(4)).into(),
absy::Expression::IntConstant(4usize.into()).into(),
)
.into(),
),
@ -970,11 +968,11 @@ mod tests {
absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(Bn128Field::from(3)).into()],
vec![absy::Expression::IntConstant(3usize.into()).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(4)).into(),
absy::Expression::IntConstant(4usize.into()).into(),
)
.into(),
),
@ -985,17 +983,17 @@ mod tests {
box absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(Bn128Field::from(3)).into()],
vec![absy::Expression::IntConstant(3usize.into()).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(4)).into(),
absy::Expression::IntConstant(4usize.into()).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(5)).into(),
absy::Expression::IntConstant(5usize.into()).into(),
)
.into(),
),

View file

@ -377,7 +377,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
BooleanExpression::Identifier(x) => {
FlatExpression::Identifier(self.layout.get(&x).unwrap().clone())
}
BooleanExpression::FieldLt(box lhs, box rhs) => {
BooleanExpression::Lt(box lhs, box rhs) => {
// Get the bit width to know the size of the binary decompositions for this Field
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
@ -660,11 +660,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res
}
BooleanExpression::FieldLe(box lhs, box rhs) => {
BooleanExpression::Le(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::FieldLt(box lhs.clone(), box rhs.clone()),
BooleanExpression::Lt(box lhs.clone(), box rhs.clone()),
);
let eq = self.flatten_boolean_expression(
symbols,
@ -673,110 +673,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression(
BooleanExpression::Gt(lhs, rhs) => self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::FieldLt(rhs, lhs),
BooleanExpression::Lt(rhs, lhs),
),
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression(
BooleanExpression::Ge(lhs, rhs) => self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::FieldLe(rhs, lhs),
),
BooleanExpression::UintLt(box lhs, box rhs) => {
let lhs_flattened =
self.flatten_uint_expression(symbols, statements_flattened, lhs);
let rhs_flattened =
self.flatten_uint_expression(symbols, statements_flattened, rhs);
// Get the bit width to know the size of the binary decompositions for this Field
// This is not this uint bitwidth
let bit_width = T::get_required_bits();
// lhs
let lhs_id = self.define(lhs_flattened.get_field_unchecked(), statements_flattened);
let rhs_id = self.define(rhs_flattened.get_field_unchecked(), statements_flattened);
// sym := (lhs * 2) - (rhs * 2)
let subtraction_result = FlatExpression::Sub(
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(lhs_id),
),
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(rhs_id),
),
);
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// bitness checks
for i in 0..bit_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(sub_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(sub_bits_be[i]),
box FlatExpression::Identifier(sub_bits_be[i]),
),
));
}
// check that the decomposition is in the field with a strict `< p` checks
self.strict_le_check(
statements_flattened,
&T::max_value_bit_vector_be(),
sub_bits_be.clone(),
);
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for i in 0..bit_width {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(sub_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
BooleanExpression::UintLe(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintLt(box lhs.clone(), box rhs.clone()),
);
let eq = self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintEq(box lhs.clone(), box rhs.clone()),
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintLt(rhs, lhs),
),
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintLe(rhs, lhs),
BooleanExpression::Le(rhs, lhs),
),
BooleanExpression::Or(box lhs, box rhs) => {
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);
@ -2464,7 +2369,7 @@ mod tests {
#[test]
fn geq_leq() {
let mut flattener = Flattener::new();
let expression_le = BooleanExpression::FieldLe(
let expression_le = BooleanExpression::Le(
box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
@ -2475,7 +2380,7 @@ mod tests {
);
let mut flattener = Flattener::new();
let expression_ge = BooleanExpression::FieldGe(
let expression_ge = BooleanExpression::Ge(
box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
@ -2496,7 +2401,7 @@ mod tests {
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(4)),
),
box BooleanExpression::FieldLt(
box BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(20)),
),

File diff suppressed because it is too large Load diff

View file

@ -1,33 +1,32 @@
use std::marker::PhantomData;
use typed_absy;
use typed_absy::types::UBitwidth;
use typed_absy::types::{StructType, UBitwidth};
use zir;
use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
pub struct Flattener<T: Field> {
phantom: PhantomData<T>,
}
fn flatten_identifier_rec<'ast>(
id: zir::SourceIdentifier<'ast>,
ty: typed_absy::types::Type,
) -> Vec<zir::Variable<'ast>> {
fn flatten_identifier_rec<'a>(
id: zir::SourceIdentifier<'a>,
ty: typed_absy::Type,
) -> Vec<zir::Variable> {
match ty {
typed_absy::types::Type::FieldElement => vec![zir::Variable {
typed_absy::Type::Int => unreachable!(),
typed_absy::Type::FieldElement => vec![zir::Variable {
id: zir::Identifier::Source(id),
_type: zir::Type::FieldElement,
}],
typed_absy::types::Type::Boolean => vec![zir::Variable {
typed_absy::Type::Boolean => vec![zir::Variable {
id: zir::Identifier::Source(id),
_type: zir::Type::Boolean,
}],
typed_absy::types::Type::Uint(bitwidth) => vec![zir::Variable {
typed_absy::Type::Uint(bitwidth) => vec![zir::Variable {
id: zir::Identifier::Source(id),
_type: zir::Type::uint(bitwidth.to_usize()),
}],
typed_absy::types::Type::Array(array_type) => (0..array_type.size)
typed_absy::Type::Array(array_type) => (0..array_type.size)
.flat_map(|i| {
flatten_identifier_rec(
zir::SourceIdentifier::Select(box id.clone(), i),
@ -35,7 +34,7 @@ fn flatten_identifier_rec<'ast>(
)
})
.collect(),
typed_absy::types::Type::Struct(members) => members
typed_absy::Type::Struct(members) => members
.into_iter()
.flat_map(|struct_member| {
flatten_identifier_rec(
@ -44,7 +43,6 @@ fn flatten_identifier_rec<'ast>(
)
})
.collect(),
typed_absy::types::Type::Int => unreachable!(),
}
}
@ -80,7 +78,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> {
let private = p.private;
self.fold_variable(p.id.try_into().unwrap())
self.fold_variable(p.id)
.into_iter()
.map(|v| zir::Parameter { id: v, private })
.collect()
@ -94,8 +92,6 @@ impl<'ast, T: Field> Flattener<T> {
let id = self.fold_name(v.id.clone());
let ty = v.get_type();
let ty = typed_absy::types::Type::try_from(ty).unwrap();
flatten_identifier_rec(id, ty)
}
@ -152,8 +148,6 @@ impl<'ast, T: Field> Flattener<T> {
) -> zir::ZirExpressionList<'ast, T> {
match es {
typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => {
let id = typed_absy::types::FunctionKey::try_from(id).unwrap();
zir::ZirExpressionList::FunctionCall(
self.fold_function_key(id),
arguments
@ -202,7 +196,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_array_expression_inner(
&mut self,
ty: &typed_absy::types::Type,
ty: &typed_absy::Type,
size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
@ -210,7 +204,7 @@ impl<'ast, T: Field> Flattener<T> {
}
fn fold_struct_expression_inner(
&mut self,
ty: &typed_absy::types::StructType,
ty: &StructType,
e: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
fold_struct_expression_inner(self, ty, e)
@ -225,12 +219,7 @@ pub fn fold_module<'ast, T: Field>(
functions: p
.functions
.into_iter()
.map(|(key, fun)| {
(
f.fold_function_key(key.try_into().unwrap()),
f.fold_function_symbol(fun),
)
})
.map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun)))
.collect(),
}
}
@ -280,16 +269,14 @@ pub fn fold_statement<'ast, T: Field>(
pub fn fold_array_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
t: &typed_absy::types::Type,
t: &typed_absy::Type,
size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match e {
typed_absy::ArrayExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec(
f.fold_name(id),
typed_absy::types::Type::array(t.clone(), size),
);
let variables =
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::array(t.clone(), size));
variables
.into_iter()
.map(|v| match v._type {
@ -344,11 +331,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| {
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.map(|member| member.ty.get_primitive_count())
.sum();
// we also need the size of this member
@ -358,15 +341,12 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
}
typed_absy::ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
match index.into_inner() {
zir::UExpressionInner::Value(i) => {
let size = typed_absy::types::Type::try_from(t.clone())
.unwrap()
.get_primitive_count()
* size;
let start = i as usize * size;
match index {
zir::FieldElementExpression::Number(i) => {
let size = t.get_primitive_count() * size;
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
let end = start + size;
array[start..end].to_vec()
}
@ -378,13 +358,13 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
pub fn fold_struct_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
t: &typed_absy::types::StructType,
t: &StructType,
e: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match e {
typed_absy::StructExpressionInner::Identifier(id) => {
let variables =
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::struc(t.clone()));
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::struc(t.clone()));
variables
.into_iter()
.map(|v| match v._type {
@ -439,33 +419,30 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| {
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.map(|member| member.ty.get_primitive_count())
.sum();
// we also need the size of this member
let size = typed_absy::types::Type::try_from(
*t.iter().find(|member| member.id == id).cloned().unwrap().ty,
)
.unwrap()
.get_primitive_count();
let size = t
.iter()
.find(|member| member.id == id)
.unwrap()
.ty
.get_primitive_count();
s[offset..offset + size].to_vec()
}
typed_absy::StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
match index.into_inner() {
zir::UExpressionInner::Value(i) => {
match index {
zir::FieldElementExpression::Number(i) => {
let size = t
.iter()
.map(|m| m.ty.get_primitive_count())
.fold(0, |acc, current| acc + current);
let start = i as usize * size;
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
let end = start + size;
array[start..end].to_vec()
}
@ -483,7 +460,7 @@ pub fn fold_field_expression<'ast, T: Field>(
typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
typed_absy::FieldElementExpression::Identifier(id) => {
zir::FieldElementExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::FieldElement)[0]
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::FieldElement)[0]
.id
.clone(),
)
@ -528,22 +505,26 @@ pub fn fold_field_expression<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| {
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.map(|member| member.ty.get_primitive_count())
.sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap()
}
typed_absy::FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
match index.into_inner() {
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
use std::convert::TryInto;
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
_ => unreachable!(""),
}
}
@ -557,7 +538,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
match e {
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::Boolean)[0]
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Boolean)[0]
.id
.clone(),
),
@ -633,45 +614,25 @@ pub fn fold_boolean_expression<'ast, T: Field>(
zir::BooleanExpression::UintEq(box e1, box e2)
}
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
typed_absy::BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldLt(box e1, box e2)
zir::BooleanExpression::Lt(box e1, box e2)
}
typed_absy::BooleanExpression::FieldLe(box e1, box e2) => {
typed_absy::BooleanExpression::Le(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldLe(box e1, box e2)
zir::BooleanExpression::Le(box e1, box e2)
}
typed_absy::BooleanExpression::FieldGt(box e1, box e2) => {
typed_absy::BooleanExpression::Gt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldGt(box e1, box e2)
zir::BooleanExpression::Gt(box e1, box e2)
}
typed_absy::BooleanExpression::FieldGe(box e1, box e2) => {
typed_absy::BooleanExpression::Ge(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldGe(box e1, box e2)
}
typed_absy::BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLt(box e1, box e2)
}
typed_absy::BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLe(box e1, box e2)
}
typed_absy::BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGt(box e1, box e2)
}
typed_absy::BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGe(box e1, box e2)
zir::BooleanExpression::Ge(box e1, box e2)
}
typed_absy::BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
@ -702,21 +663,25 @@ pub fn fold_boolean_expression<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| {
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.map(|member| member.ty.get_primitive_count())
.sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap()
}
typed_absy::BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
match index.into_inner() {
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
use std::convert::TryInto;
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
_ => unreachable!(),
}
}
@ -739,7 +704,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
match e {
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::Uint(bitwidth))[0]
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Uint(bitwidth))[0]
.id
.clone(),
),
@ -801,11 +766,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
}
typed_absy::UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
match index.into_inner() {
zir::UExpressionInner::Value(i) => {
let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap();
use std::convert::TryInto;
match index {
zir::FieldElementExpression::Number(i) => {
let e: zir::UExpression<_> = array[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap();
e.into_inner()
}
_ => unreachable!(),
@ -819,13 +789,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| {
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.map(|member| member.ty.get_primitive_count())
.sum();
use std::convert::TryInto;
let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap();
res.into_inner()
@ -854,9 +822,7 @@ pub fn fold_function<'ast, T: Field>(
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
signature: typed_absy::types::Signature::try_from(fun.signature)
.unwrap()
.into(),
signature: fun.signature.into(),
}
}
@ -864,22 +830,14 @@ pub fn fold_array_expression<'ast, T: Field>(
f: &mut Flattener<T>,
e: typed_absy::ArrayExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
let size = e.size();
f.fold_array_expression_inner(
&typed_absy::types::Type::try_from(e.inner_type().clone()).unwrap(),
size,
e.into_inner(),
)
f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner())
}
pub fn fold_struct_expression<'ast, T: Field>(
f: &mut Flattener<T>,
e: typed_absy::StructExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_struct_expression_inner(
&typed_absy::types::StructType::try_from(e.ty().clone()).unwrap(),
e.into_inner(),
)
f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner())
}
pub fn fold_function_symbol<'ast, T: Field>(
@ -890,10 +848,9 @@ pub fn fold_function_symbol<'ast, T: Field>(
typed_absy::TypedFunctionSymbol::Here(fun) => {
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
}
typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There(
f.fold_function_key(typed_absy::types::FunctionKey::try_from(key).unwrap()),
module,
), // by default, do not fold modules recursively
typed_absy::TypedFunctionSymbol::There(key, module) => {
zir::ZirFunctionSymbol::There(f.fold_function_key(key), module)
} // by default, do not fold modules recursively
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
}
}

View file

@ -67,10 +67,10 @@ mod tests {
#[test]
fn detect_non_constant_bound() {
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
let loops = vec![TypedStatement::For(
Variable::field_element("i"),
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![],
)];
@ -118,12 +118,12 @@ mod tests {
let s = TypedStatement::For(
Variable::field_element("i"),
0u32.into(),
2u32.into(),
FieldElementExpression::Number(Bn128Field::from(0)),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![TypedStatement::For(
Variable::field_element("j"),
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
2u32.into(),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![
TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition(

View file

@ -121,8 +121,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
// we stop propagation here as constants maybe be modified inside the loop body
// which we do not visit
TypedStatement::For(v, from, to, statements) => {
let from = self.fold_uint_expression(from);
let to = self.fold_uint_expression(to);
let from = self.fold_field_expression(from);
let to = self.fold_field_expression(to);
// invalidate the constants map as any constant could be modified inside the loop body, which we don't visit
self.constants.clear();
@ -505,31 +505,33 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
UExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
let n = n as usize;
if n < size {
UExpression::try_from(v[n].clone()).unwrap().into_inner()
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
UExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Uint(e) => e.clone().into_inner(),
@ -537,14 +539,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
},
None => UExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => UExpressionInner::Select(
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
(a, i) => UExpressionInner::Select(box a.annotate(inner_type, size), box i),
}
}
UExpressionInner::FunctionCall(key, arguments) => {
@ -648,46 +647,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
FieldElementExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
let n = n as usize;
if n < size {
FieldElementExpression::try_from(v[n].clone()).unwrap()
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => unreachable!(""),
_ => unreachable!("??"),
},
None => FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => FieldElementExpression::Select(
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
(a, i) => {
FieldElementExpression::Select(box a.annotate(inner_type, size), box i)
}
}
}
FieldElementExpression::Member(box s, m) => {
@ -749,48 +747,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
let n = n as usize;
if n < size {
ArrayExpression::try_from(v[n].clone())
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
ArrayExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Array(e) => e.clone().into_inner(),
_ => unreachable!(""),
_ => unreachable!("should be an array"),
},
None => ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => ArrayExpressionInner::Select(
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
(a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
@ -864,48 +859,47 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
StructExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
let n = n as usize;
if n < size {
StructExpression::try_from(v[n].clone())
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
StructExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Struct(e) => e.clone().into_inner(),
_ => unreachable!(""),
_ => unreachable!("should be a struct"),
},
None => StructExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => StructExpressionInner::Select(
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
(a, i) => {
StructExpressionInner::Select(box a.annotate(inner_type, size), box i)
}
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
@ -998,7 +992,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(e1, e2) => BooleanExpression::BoolEq(box e1, box e2),
}
}
BooleanExpression::FieldLt(box e1, box e2) => {
BooleanExpression::Lt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
@ -1006,10 +1000,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 < n2)
}
(e1, e2) => BooleanExpression::FieldLt(box e1, box e2),
(e1, e2) => BooleanExpression::Lt(box e1, box e2),
}
}
BooleanExpression::FieldLe(box e1, box e2) => {
BooleanExpression::Le(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
@ -1017,10 +1011,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 <= n2)
}
(e1, e2) => BooleanExpression::FieldLe(box e1, box e2),
(e1, e2) => BooleanExpression::Le(box e1, box e2),
}
}
BooleanExpression::FieldGt(box e1, box e2) => {
BooleanExpression::Gt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
@ -1028,10 +1022,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 > n2)
}
(e1, e2) => BooleanExpression::FieldGt(box e1, box e2),
(e1, e2) => BooleanExpression::Gt(box e1, box e2),
}
}
BooleanExpression::FieldGe(box e1, box e2) => {
BooleanExpression::Ge(box e1, box e2) => {
let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2);
@ -1039,7 +1033,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 >= n2)
}
(e1, e2) => BooleanExpression::FieldGe(box e1, box e2),
(e1, e2) => BooleanExpression::Ge(box e1, box e2),
}
}
BooleanExpression::Or(box e1, box e2) => {
@ -1098,46 +1092,43 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
BooleanExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
let n = n as usize;
if n < size {
BooleanExpression::try_from(v[n].clone()).unwrap()
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
BooleanExpression::try_from(v[n_as_usize].clone()).unwrap()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n, size
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Boolean(e) => e.clone(),
_ => unreachable!(""),
_ => unreachable!("Should be a boolean"),
},
None => BooleanExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => BooleanExpression::Select(
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
(a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i),
}
}
BooleanExpression::Member(box s, m) => {
@ -1291,7 +1282,10 @@ mod tests {
FieldElementExpression::Number(Bn128Field::from(3)).into(),
])
.annotate(Type::FieldElement, 3),
box UExpression::add(1u32.into(), 1u32.into()),
box FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
);
assert_eq!(
@ -1397,12 +1391,12 @@ mod tests {
#[test]
fn lt() {
let e_true = BooleanExpression::FieldLt(
let e_true = BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
let e_false = BooleanExpression::FieldLt(
let e_false = BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
@ -1419,12 +1413,12 @@ mod tests {
#[test]
fn le() {
let e_true = BooleanExpression::FieldLe(
let e_true = BooleanExpression::Le(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
let e_false = BooleanExpression::FieldLe(
let e_false = BooleanExpression::Le(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
@ -1441,12 +1435,12 @@ mod tests {
#[test]
fn gt() {
let e_true = BooleanExpression::FieldGt(
let e_true = BooleanExpression::Gt(
box FieldElementExpression::Number(Bn128Field::from(5)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
let e_false = BooleanExpression::FieldGt(
let e_false = BooleanExpression::Gt(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(5)),
);
@ -1463,12 +1457,12 @@ mod tests {
#[test]
fn ge() {
let e_true = BooleanExpression::FieldGe(
let e_true = BooleanExpression::Ge(
box FieldElementExpression::Number(Bn128Field::from(5)),
box FieldElementExpression::Number(Bn128Field::from(5)),
);
let e_false = BooleanExpression::FieldGe(
let e_false = BooleanExpression::Ge(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(5)),
);

View file

@ -23,6 +23,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
.iter()
.zip(ret_identifiers.iter())
.map(|(e, i)| match e.get_type() {
Type::Int => unreachable!(),
Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(),
Type::Boolean => BooleanExpression::Identifier(i.clone()).into(),
Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone())
@ -34,7 +35,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
.annotate(bitwidth)
.into(),
Type::Int => unreachable!(),
})
.collect();

View file

@ -85,10 +85,9 @@ impl<'ast> Unroller<'ast> {
match head {
Access::Select(head) => {
statements.insert(TypedStatement::Assertion(
BooleanExpression::UintLt(
BooleanExpression::Lt(
box head.clone(),
box UExpressionInner::Value(size as u128)
.annotate(UBitwidth::B32),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
));
@ -98,16 +97,14 @@ impl<'ast> Unroller<'ast> {
.map(|i| match inner_ty {
Type::Int => unreachable!(),
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::UintEq(
box UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
@ -122,22 +119,19 @@ impl<'ast> Unroller<'ast> {
},
ArrayExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Struct(..) => StructExpression::if_else(
BooleanExpression::UintEq(
box UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
StructExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
@ -152,22 +146,19 @@ impl<'ast> Unroller<'ast> {
},
StructExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::UintEq(
box UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
@ -182,22 +173,19 @@ impl<'ast> Unroller<'ast> {
},
FieldElementExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::UintEq(
box UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
@ -212,22 +200,19 @@ impl<'ast> Unroller<'ast> {
},
BooleanExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Uint(..) => UExpression::if_else(
BooleanExpression::UintEq(
box UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
UExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
@ -242,8 +227,7 @@ impl<'ast> Unroller<'ast> {
},
UExpression::select(
base.clone(),
UExpressionInner::Value(i as u128)
.annotate(UBitwidth::B32),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
@ -376,7 +360,7 @@ impl<'ast> Unroller<'ast> {
#[derive(Clone, Debug)]
enum Access<'ast, T: Field> {
Select(UExpression<'ast, T>),
Select(FieldElementExpression<'ast, T>),
Member(MemberId),
}
/// Turn an assignee into its representation as a base variable and a list accesses
@ -437,7 +421,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
let indices = indices
.into_iter()
.map(|a| match a {
Access::Select(i) => Access::Select(self.fold_uint_expression(i)),
Access::Select(i) => Access::Select(self.fold_field_expression(i)),
a => a,
})
.collect();
@ -463,16 +447,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
vec![TypedStatement::MultipleDefinition(variables, exprs)]
}
TypedStatement::For(v, from, to, stats) => {
let from = self.fold_uint_expression(from);
let to = self.fold_uint_expression(to);
let from = self.fold_field_expression(from);
let to = self.fold_field_expression(to);
match (from.into_inner(), to.into_inner()) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
let mut values = vec![];
match (from, to) {
(FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => {
let mut values: Vec<T> = vec![];
let mut current = from;
while current < to {
values.push(current.clone());
current = 1 + &current;
current = T::one() + &current;
}
let res = values
@ -483,9 +467,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
TypedStatement::Declaration(v.clone()),
TypedStatement::Definition(
TypedAssignee::Identifier(v.clone()),
UExpressionInner::Value(index)
.annotate(UBitwidth::B32)
.into(),
FieldElementExpression::Number(index).into(),
),
],
stats.clone(),
@ -501,12 +483,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
}
(from, to) => {
self.complete = false;
vec![TypedStatement::For(
v,
from.annotate(UBitwidth::B32),
to.annotate(UBitwidth::B32),
stats,
)]
vec![TypedStatement::For(v, from, to, stats)]
}
}
}
@ -538,11 +515,11 @@ mod tests {
#[test]
fn ssa_array() {
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3usize);
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
let e = FieldElementExpression::Number(Bn128Field::from(42)).into();
let index = 1u32.into();
let index = FieldElementExpression::Number(Bn128Field::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
@ -558,34 +535,52 @@ mod tests {
a1,
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(a0.clone(), 0u32.into(),)
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(a0.clone(), 1u32.into())
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 2u32.into(), box 1u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(a0.clone(), 2u32.into())
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(2))
)
)
.into()
])
.annotate(Type::FieldElement, 3usize)
.annotate(Type::FieldElement, 3)
.into()
);
let a0: ArrayExpression<Bn128Field> = ArrayExpressionInner::Identifier("a".into())
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 3), 3);
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
let index = 1u32.into();
let index = FieldElementExpression::Number(Bn128Field::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
@ -601,21 +596,39 @@ mod tests {
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(),
ArrayExpression::select(a0.clone(), 0u32.into())
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(),
ArrayExpression::select(a0.clone(), 1u32.into())
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::UintEq(box 2u32.into(), box 1u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(),
ArrayExpression::select(a0.clone(), 2u32.into())
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(2))
)
)
.into()
])
@ -628,7 +641,10 @@ mod tests {
let e = FieldElementExpression::Number(Bn128Field::from(42));
let indices = vec![Access::Select(0u32.into()), Access::Select(0u32.into())];
let indices = vec![
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
];
let a1 = Unroller::choose_many(
a0.clone().into(),
@ -644,55 +660,91 @@ mod tests {
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 0u32.into()),
0u32.into()
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 0u32.into()),
1u32.into()
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(a0.clone(), 0u32.into())
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 1u32.into()),
0u32.into()
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()),
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 1u32.into()),
1u32.into()
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(a0.clone(), 1u32.into())
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into(),
])
@ -721,8 +773,8 @@ mod tests {
let s = TypedStatement::For(
Variable::field_element("i"),
2u32.into(),
5u32.into(),
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(5)),
vec![
TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition(
@ -1031,7 +1083,7 @@ mod tests {
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::field_array("a", 2)),
box 1u32.into(),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
);
@ -1040,7 +1092,11 @@ mod tests {
u.fold_statement(s),
vec![
TypedStatement::Assertion(
BooleanExpression::UintLt(box 1u32.into(), box 2u32.into()).into(),
BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
@ -1049,26 +1105,32 @@ mod tests {
)),
ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse(
box BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box 0u32.into()
box FieldElementExpression::Number(Bn128Field::from(0))
),
)
.into(),
FieldElementExpression::IfElse(
box BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0)
)
.annotate(Type::FieldElement, 2),
box 1u32.into()
box FieldElementExpression::Number(Bn128Field::from(1))
),
)
.into(),
@ -1153,7 +1215,7 @@ mod tests {
"a",
array_of_array_ty.clone(),
)),
box 1u32.into(),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(4)).into(),
@ -1167,7 +1229,11 @@ mod tests {
u.fold_statement(s),
vec![
TypedStatement::Assertion(
BooleanExpression::UintLt(box 1u32.into(), box 2u32.into()).into(),
BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
@ -1176,7 +1242,10 @@ mod tests {
)),
ArrayExpressionInner::Value(vec![
ArrayExpressionInner::IfElse(
box BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(4)).into(),
FieldElementExpression::Number(Bn128Field::from(5)).into(),
@ -1188,14 +1257,17 @@ mod tests {
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box 0u32.into()
box FieldElementExpression::Number(Bn128Field::from(0))
)
.annotate(Type::FieldElement, 2),
)
.annotate(Type::FieldElement, 2)
.into(),
ArrayExpressionInner::IfElse(
box BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(4)).into(),
FieldElementExpression::Number(Bn128Field::from(5)).into(),
@ -1207,7 +1279,7 @@ mod tests {
Identifier::from("a").version(0)
)
.annotate(Type::array(Type::FieldElement, 2), 2),
box 1u32.into()
box FieldElementExpression::Number(Bn128Field::from(1))
)
.annotate(Type::FieldElement, 2),
)

View file

@ -29,12 +29,10 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
fn select<U: Select<'ast, T> + IfElse<'ast, T>>(
&mut self,
a: ArrayExpression<'ast, T>,
i: UExpression<'ast, T>,
i: FieldElementExpression<'ast, T>,
) -> U {
match i.into_inner() {
UExpressionInner::Value(i) => {
U::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32))
}
match i {
FieldElementExpression::Number(i) => U::select(a, FieldElementExpression::Number(i)),
i => {
let size = match a.get_type().clone() {
Type::Array(array_ty) => array_ty.size,
@ -44,9 +42,9 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
self.statements.push(TypedStatement::Assertion(
(0..size)
.map(|index| {
BooleanExpression::UintEq(
box i.clone().annotate(UBitwidth::B32),
box UExpressionInner::Value(index as u128).annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box i.clone(),
box FieldElementExpression::Number(index.into()).into(),
)
})
.fold(None, |acc, e| match acc {
@ -58,19 +56,14 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
));
(0..size)
.map(|i| {
U::select(
a.clone(),
UExpressionInner::Value(i as u128).annotate(UBitwidth::B32),
)
})
.map(|i| U::select(a.clone(), FieldElementExpression::Number(i.into())))
.enumerate()
.rev()
.fold(None, |acc, (index, res)| match acc {
Some(acc) => Some(U::if_else(
BooleanExpression::UintEq(
box i.clone().annotate(UBitwidth::B32),
box UExpressionInner::Value(index as u128).annotate(UBitwidth::B32),
BooleanExpression::FieldEq(
box i.clone(),
box FieldElementExpression::Number(index.into()),
),
res,
acc,
@ -168,7 +161,7 @@ mod tests {
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2),
box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
box FieldElementExpression::Identifier("i".into()),
)
.into(),
);
@ -199,12 +192,12 @@ mod tests {
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2),
box 0u32.into(),
box FieldElementExpression::Number(0.into()),
),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2),
box 1u32.into(),
box FieldElementExpression::Number(1.into()),
)
)
.into()

View file

@ -46,7 +46,7 @@ pub trait Folder<'ast, T: Field>: Sized {
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee(a),
box self.fold_uint_expression(index),
box self.fold_field_expression(index),
),
TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m),
}
@ -216,7 +216,7 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
ArrayExpressionInner::Select(box array, box index)
}
}
@ -249,7 +249,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
StructExpressionInner::Select(box array, box index)
}
}
@ -305,7 +305,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
}
FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
FieldElementExpression::Select(box array, box index)
}
}
@ -350,45 +350,25 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::FieldLt(box e1, box e2) => {
BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLt(box e1, box e2)
BooleanExpression::Lt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
BooleanExpression::Le(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2)
BooleanExpression::Le(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
BooleanExpression::Gt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGt(box e1, box e2)
BooleanExpression::Gt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
BooleanExpression::Ge(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGe(box e1, box e2)
BooleanExpression::Ge(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
@ -420,7 +400,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
}
BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
BooleanExpression::Select(box array, box index)
}
}
@ -503,7 +483,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index);
let index = f.fold_field_expression(index);
UExpressionInner::Select(box array, box index)
}
UExpressionInner::IfElse(box cond, box cons, box alt) => {

View file

@ -256,7 +256,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<'ast, T> {
Identifier(Variable<'ast>),
Select(Box<TypedAssignee<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<TypedAssignee<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Member(Box<TypedAssignee<'ast, T>>, MemberId),
}
@ -316,8 +319,8 @@ pub enum TypedStatement<'ast, T> {
Assertion(BooleanExpression<'ast, T>),
For(
Variable<'ast>,
UExpression<'ast, T>,
UExpression<'ast, T>,
FieldElementExpression<'ast, T>,
FieldElementExpression<'ast, T>,
Vec<TypedStatement<'ast, T>>,
),
MultipleDefinition(Vec<Variable<'ast>>, TypedExpressionList<'ast, T>),
@ -424,44 +427,85 @@ pub enum TypedExpression<'ast, T> {
}
impl<'ast, T: Field> TypedExpression<'ast, T> {
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types.
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
pub fn align_without_integers(
lhs: Self,
rhs: Self,
) -> Result<(TypedExpression<'ast, T>, TypedExpression<'ast, T>), String> {
) -> Result<
(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
> {
use self::TypedExpression::*;
match (lhs, rhs) {
(Int(lhs), FieldElement(rhs)) => Ok((
FieldElementExpression::try_from_int(lhs)?.into(),
FieldElementExpression::try_from_int(lhs)
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
.into(),
FieldElement(rhs),
)),
(FieldElement(lhs), Int(rhs)) => Ok((
FieldElement(lhs),
FieldElementExpression::try_from_int(rhs)?.into(),
FieldElement(lhs.clone()),
FieldElementExpression::try_from_int(rhs)
.map_err(|rhs| (lhs.into(), rhs.into()))?
.into(),
)),
(Int(lhs), Uint(rhs)) => Ok((
UExpression::try_from_int(lhs, rhs.bitwidth())?.into(),
UExpression::try_from_int(lhs, rhs.bitwidth())
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
.into(),
Uint(rhs),
)),
(Uint(lhs), Int(rhs)) => {
let bitwidth = lhs.bitwidth();
Ok((Uint(lhs), UExpression::try_from_int(rhs, bitwidth)?.into()))
Ok((
Uint(lhs.clone()),
UExpression::try_from_int(rhs, bitwidth)
.map_err(|rhs| (lhs.into(), rhs.into()))?
.into(),
))
}
(lhs, rhs) => Ok((lhs, rhs)),
(Array(lhs), Array(rhs)) => {
if lhs.get_type() == rhs.get_type() {
Ok((Array(lhs), Array(rhs)))
} else {
Err((Array(lhs), Array(rhs)))
}
}
(Struct(lhs), Struct(rhs)) => {
if lhs.get_type() == rhs.get_type() {
Ok((Struct(lhs).into(), Struct(rhs).into()))
} else {
Err((Struct(lhs).into(), Struct(rhs).into()))
}
}
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
(FieldElement(lhs), FieldElement(rhs)) => Ok((lhs.into(), rhs.into())),
(Int(lhs), Int(rhs)) => Ok((lhs.into(), rhs.into())),
(lhs, rhs) => Err((lhs, rhs)),
}
}
pub fn align_to_type(e: Self, ty: Type) -> Result<Self, String> {
use self::TypedExpression::*;
match (e, ty) {
(Int(e), Type::FieldElement) => Ok(FieldElementExpression::try_from_int(e)?.into()),
(Int(e), Type::Uint(bitwidth)) => Ok(UExpression::try_from_int(e, bitwidth)?.into()),
(Array(a), Type::Array(array_ty)) => Ok(ArrayExpression::try_from_int(a, array_ty)
.map_err(|_| String::from("align array to type"))?
.into()),
(e, _) => Ok(e),
pub fn align_to_type(e: Self, ty: Type) -> Result<Self, (Self, Type)> {
match ty.clone() {
Type::FieldElement => {
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
}
Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
Type::Uint(bitwidth) => {
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
}
Type::Array(array_ty) => {
ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
}
Type::Struct(struct_ty) => {
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
}
Type::Int => Err(e),
}
.map_err(|e| (e, ty))
}
}
@ -478,7 +522,10 @@ pub enum IntExpression<'ast, T> {
Box<IntExpression<'ast, T>>,
Box<IntExpression<'ast, T>>,
),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Xor(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
And(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
Or(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
@ -734,7 +781,10 @@ pub enum FieldElementExpression<'ast, T> {
),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: Field> BooleanExpression<'ast, T> {
@ -746,6 +796,12 @@ impl<'ast, T: Field> BooleanExpression<'ast, T> {
}
}
impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
fn from(n: T) -> Self {
FieldElementExpression::Number(n)
}
}
impl<'ast, T: Field> FieldElementExpression<'ast, T> {
pub fn try_from_typed(e: TypedExpression<'ast, T>) -> Result<Self, TypedExpression<'ast, T>> {
match e {
@ -757,13 +813,13 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
}
}
pub fn try_from_int(i: IntExpression<'ast, T>) -> Result<Self, String> {
pub fn try_from_int(i: IntExpression<'ast, T>) -> Result<Self, IntExpression<'ast, T>> {
match i {
IntExpression::Value(i) => {
if i <= T::max_value().to_biguint() {
Ok(Self::Number(T::from(i)))
} else {
Err(format!("Literal `{} is too large for type `field`", i))
Err(IntExpression::Value(i))
}
}
IntExpression::Add(box e1, box e2) => Ok(Self::Add(
@ -794,7 +850,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
))
}
IntExpression::Select(..) => unimplemented!(),
i => Err(format!("Expected a `field` but found expression `{}`", i)),
i => Err(i),
}
}
}
@ -804,26 +860,22 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
pub enum BooleanExpression<'ast, T> {
Identifier(Identifier<'ast>),
Value(bool),
FieldLt(
Lt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldLe(
Le(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGe(
Ge(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGt(
Gt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
FieldEq(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -854,7 +906,10 @@ pub enum BooleanExpression<'ast, T> {
),
Member(Box<StructExpression<'ast, T>>, MemberId),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
/// An expression of type `array`
@ -880,7 +935,10 @@ pub enum ArrayExpressionInner<'ast, T> {
Box<ArrayExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T> ArrayExpressionInner<'ast, T> {
@ -931,13 +989,16 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
}
// precondition: `array` is only made of inline arrays
pub fn try_from_int(array: Self, target_array_ty: ArrayType) -> Result<Self, ()> {
if array.get_array_type() == target_array_ty {
pub fn try_from_int(
array: Self,
target_array_ty: ArrayType,
) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.get_array_type();
if array_ty == target_array_ty {
return Ok(array);
}
let array_ty = array.get_array_type();
// sizes must be equal
match target_array_ty.size == array_ty.size {
true =>
@ -946,56 +1007,67 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
match array.into_inner() {
ArrayExpressionInner::Value(inline_array) => {
match *target_array_ty.ty {
Type::Int => Ok(ArrayExpressionInner::Value(inline_array)
.annotate(*target_array_ty.ty, array_ty.size)),
Type::FieldElement => {
// try to convert all elements to field
let converted = inline_array
.into_iter()
.map(|e| {
let int = IntExpression::try_from(e)?;
let field = FieldElementExpression::try_from_int(int)
.map_err(|_| ())?;
Ok(field.into())
FieldElementExpression::try_from_typed(e)
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?;
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty, array_ty.size))
}
Type::Uint(bitwidth) => {
// try to convert all elements to field
// try to convert all elements to uint
let converted = inline_array
.into_iter()
.map(|e| {
let int = IntExpression::try_from(e)?;
let field = UExpression::try_from_int(int, bitwidth)
.map_err(|_| ())?;
Ok(field.into())
UExpression::try_from_typed(e, bitwidth)
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?;
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty, array_ty.size))
}
Type::Array(ref array_ty) => {
// try to convert all elements to field
Type::Array(ref inner_array_ty) => {
// try to convert all elements to uint
let converted = inline_array
.into_iter()
.map(|e| {
let array = ArrayExpression::try_from(e)?;
let array =
ArrayExpression::try_from_int(array, array_ty.clone())
.map_err(|_| ())?;
Ok(array.into())
ArrayExpression::try_from_typed(e, inner_array_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?;
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty.clone(), array_ty.size.clone()))
.annotate(*target_array_ty.ty, array_ty.size))
}
_ => Err(()),
Type::Struct(ref struct_ty) => {
// try to convert all elements to uint
let converted = inline_array
.into_iter()
.map(|e| {
StructExpression::try_from_typed(e, struct_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty, array_ty.size))
}
Type::Boolean => unreachable!(),
}
}
_ => unreachable!(""),
}
}
false => Err(()),
false => unreachable!(),
}
}
}
@ -1047,7 +1119,10 @@ pub enum StructExpressionInner<'ast, T> {
Box<StructExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T> StructExpressionInner<'ast, T> {
@ -1165,14 +1240,10 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
@ -1247,30 +1318,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
BooleanExpression::FieldLt(ref lhs, ref rhs) => {
write!(f, "FieldLt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::FieldLe(ref lhs, ref rhs) => {
write!(f, "FieldLe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::FieldGe(ref lhs, ref rhs) => {
write!(f, "FieldGe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::FieldGt(ref lhs, ref rhs) => {
write!(f, "FieldGt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintLt(ref lhs, ref rhs) => {
write!(f, "UintLt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintLe(ref lhs, ref rhs) => {
write!(f, "UintLe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintGe(ref lhs, ref rhs) => {
write!(f, "UintGe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintGt(ref lhs, ref rhs) => {
write!(f, "UintGt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs),
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs),
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => {
write!(f, "FieldEq({:?}, {:?})", lhs, rhs)
}
@ -1483,23 +1534,23 @@ impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> {
}
pub trait Select<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self;
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
}
impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
FieldElementExpression::Select(box array, box index)
}
}
impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
BooleanExpression::Select(box array, box index)
}
}
impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let bitwidth = match array.inner_type().clone() {
Type::Uint(bitwidth) => bitwidth,
_ => unreachable!(),
@ -1510,7 +1561,7 @@ impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> {
}
impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let (ty, size) = match array.inner_type() {
Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()),
_ => unreachable!(),
@ -1521,7 +1572,7 @@ impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> {
}
impl<'ast, T: Clone> Select<'ast, T> for StructExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let members = match array.inner_type().clone() {
Type::Struct(members) => members,
_ => unreachable!(),

View file

@ -190,6 +190,20 @@ impl Type {
Type::Uint(b.into())
}
pub fn can_be_specialized_to(&self, other: &Type) -> bool {
use self::Type::*;
if self == other {
true
} else {
match (self, other) {
(Int, FieldElement) | (Int, Uint(..)) => true,
(Array(l), Array(r)) => l.size == r.size && l.ty.can_be_specialized_to(&r.ty),
_ => false,
}
}
}
fn to_slug(&self) -> String {
match self {
Type::FieldElement => String::from("f"),

View file

@ -72,7 +72,10 @@ impl<'ast, T: Field> UExpression<'ast, T> {
}
}
pub fn try_from_int(i: IntExpression<'ast, T>, bitwidth: UBitwidth) -> Result<Self, String> {
pub fn try_from_int(
i: IntExpression<'ast, T>,
bitwidth: UBitwidth,
) -> Result<Self, IntExpression<'ast, T>> {
use self::IntExpression::*;
match i {
@ -83,10 +86,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
)
.annotate(bitwidth))
} else {
Err(format!(
"Literal `{}` is too large for type u{}",
i, bitwidth
))
Err(Value(i))
}
}
Add(box e1, box e2) => Ok(UExpression::add(
@ -127,10 +127,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
Self::try_from_int(alternative, bitwidth)?,
)),
Select(..) => unimplemented!(),
i => Err(format!(
"Expected a `u{}` but found expression `{}`",
bitwidth, i
)),
i => Err(i),
}
}
}
@ -192,7 +189,10 @@ pub enum UExpressionInner<'ast, T> {
Box<UExpression<'ast, T>>,
),
Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {

View file

@ -204,45 +204,25 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::FieldLt(box e1, box e2) => {
BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLt(box e1, box e2)
BooleanExpression::Lt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
BooleanExpression::Le(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2)
BooleanExpression::Le(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
BooleanExpression::Gt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGt(box e1, box e2)
BooleanExpression::Gt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
BooleanExpression::Ge(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGe(box e1, box e2)
BooleanExpression::Ge(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);

View file

@ -381,26 +381,22 @@ pub enum FieldElementExpression<'ast, T> {
pub enum BooleanExpression<'ast, T> {
Identifier(Identifier<'ast>),
Value(bool),
FieldLt(
Lt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldLe(
Le(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGe(
Ge(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGt(
Gt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
FieldEq(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -510,14 +506,10 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),

View file

@ -4,12 +4,12 @@ def main():
assert(7 == 2 ** 2 * 2 - 1)
assert(3 == 2 ** 2 / 2 + 1)
field a = if 3 == 2 ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators
field b = if 3 == 3 && 4 < 5 then 1 else 0 fi // checks precedence of boolean operators
field c = if 4 < 5 && 3 == 3 then 1 else 0 fi
field d = if 4 > 5 && 2 >= 1 || 1 == 1 then 1 else 0 fi
field e = if 2 >= 1 && 4 > 5 || 1 == 1 then 1 else 0 fi
field f = if 1 < 2 && false || 4 < 5 && 2 >= 1 then 1 else 0 fi
field a = if 3f == 2f ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators
field b = if 3f == 3f && 4f < 5f then 1 else 0 fi // checks precedence of boolean operators
field c = if 4f < 5f && 3f == 3f then 1 else 0 fi
field d = if 4f > 5f && 2f >= 1f || 1f == 1f then 1 else 0 fi
field e = if 2f >= 1f && 4f > 5f || 1f == 1f then 1 else 0 fi
field f = if 1f < 2f && false || 4f < 5f && 2f >= 1f then 1 else 0 fi
assert(0x00 ^ 0x00 == 0x00)

View file

@ -101,7 +101,9 @@ mod tests {
term(43, 44, [
primary_expression(43, 44, [
literal(43, 44, [
decimal_number(43, 44)
decimal_literal(43, 44, [
decimal_number(43, 44)
])
])
])
])
@ -258,7 +260,9 @@ mod tests {
term(30, 31, [
primary_expression(30, 31, [
literal(30, 31, [
decimal_number(30, 31)
decimal_literal(30, 31, [
decimal_number(30, 31)
])
])
])
])

View file

@ -4,7 +4,6 @@
// Note: parameters will be updated soon to be more compatible with zCash's implementation
struct BabyJubJubParams {
field JUBJUBE
field JUBJUBC
field JUBJUBA
field JUBJUBD
@ -18,7 +17,6 @@ struct BabyJubJubParams {
def main() -> BabyJubJubParams:
// Order of the curve E
field JUBJUBE = 21888242871839275222246405745257275088614511777268538073601725287587578984328
field JUBJUBC = 8 // Cofactor
field JUBJUBA = 168700 // Coefficient A
field JUBJUBD = 168696 // Coefficient D
@ -40,7 +38,6 @@ return BabyJubJubParams {
INFINITY: INFINITY,
Gu: Gu,
Gv: Gv,
JUBJUBE: JUBJUBE,
JUBJUBC: JUBJUBC,
MONTA: MONTA,
MONTB: MONTB

View file

@ -1,9 +1,6 @@
import "hashes/mimc7/mimc7R10"
def main():
assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289)
assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145)
assert(mimc7R10(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 2977550761518141183167168643824354554080911485709001361112529600968315693145)
assert(mimc7R10(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 11476724043755138071320043459606423473319855817296339514744600646762741571430)
assert(mimc7R10(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 6004544488495356385698286530147974336054653445122716140990101827963729149289)
assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289f)
assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145f)
return

View file

@ -1,9 +1,6 @@
import "hashes/mimc7/mimc7R20"
def main():
assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829)
assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758)
assert(mimc7R20(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 8623418512398828792274158979964869393034224267928014534933203776818702139758)
assert(mimc7R20(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 15315177265066649795408805007175121550344555424263995530745989936206840798041)
assert(mimc7R20(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 19139739902058628561064841933381604453445216873412991992755775746150759284829)
assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829f)
assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758f)
return

View file

@ -1,9 +1,6 @@
import "hashes/mimc7/mimc7R50"
def main():
assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173)
assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789)
assert(mimc7R50(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 18511388995652647480418174218630545482006454713617579894396683237092568946789)
assert(mimc7R50(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 9149577627043020462780389988155990926223727917856424056384664564191878439702)
assert(mimc7R50(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 3049953358280347916081509186284461274525472221619157672645224540758481713173)
assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173f)
assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789f)
return

View file

@ -2,8 +2,5 @@ import "hashes/mimc7/mimc7R90"
def main():
assert(mimc7R90(0, 0) == 20281265111705407344053532742843085357648991805359414661661476832595822221514)
assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636)
assert(mimc7R90(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 1010054095264022068840870550831559811104631937745987065544478027572003292636)
assert(mimc7R90(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 8189519586469873426687580455476035992041353456517724932462363814215190642760)
assert(mimc7R90(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 20281265111705407344053532742843085357648991805359414661661476832595822221514)
assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636f)
return

View file

@ -1,7 +1,6 @@
import "hashes/mimcSponge/mimcSponge" as mimcSponge
def main():
assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097])
assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279])
assert(mimcSponge([21888242871839275222246405745257275088548364400416034343698204186575808495617, 0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279])
assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097f])
assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279f])
return