1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

make tests pass

This commit is contained in:
schaeff 2020-09-21 22:22:38 +02:00
parent 69e682f065
commit be77844e69
25 changed files with 840 additions and 1024 deletions

View file

@ -1,3 +1,3 @@
def main() -> (): def main() -> ():
assert(1 == 2) assert(1f == 2f)
return return

View file

@ -14,6 +14,6 @@ def main(field a) -> field:
assert(2 * b == a * 12 + 60) assert(2 * b == a * 12 + 60)
field c = 7 * (b + a) field c = 7 * (b + a)
assert(isEqual(c, 7 * b + 7 * a)) assert(isEqual(c, 7 * b + 7 * a))
field k = if [1, 2] == [3, 4] then 1 else 3 fi field k = if [1f, 2] == [3f, 4] then 1 else 3 fi
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}]) assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
return b + c return b + c

View file

@ -722,10 +722,9 @@ mod tests {
arguments: vec![], arguments: vec![],
statements: vec![absy::Statement::Return( statements: vec![absy::Statement::Return(
absy::ExpressionList { absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant( expressions: vec![
Bn128Field::from(42), absy::Expression::IntConstant(42usize.into()).into()
) ],
.into()],
} }
.into(), .into(),
) )
@ -803,10 +802,9 @@ mod tests {
], ],
statements: vec![absy::Statement::Return( statements: vec![absy::Statement::Return(
absy::ExpressionList { absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant( expressions: vec![
Bn128Field::from(42), absy::Expression::IntConstant(42usize.into()).into()
) ],
.into()],
} }
.into(), .into(),
) )
@ -870,7 +868,7 @@ mod tests {
"field[2]", "field[2]",
absy::UnresolvedType::Array( absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(), box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(), absy::Expression::IntConstant(2usize.into()).mock(),
), ),
), ),
( (
@ -878,10 +876,10 @@ mod tests {
absy::UnresolvedType::Array( absy::UnresolvedType::Array(
box absy::UnresolvedType::Array( box absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(), box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::FieldConstant(Bn128Field::from(3)).mock(), absy::Expression::IntConstant(3usize.into()).mock(),
) )
.mock(), .mock(),
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(), absy::Expression::IntConstant(2usize.into()).mock(),
), ),
), ),
( (
@ -889,10 +887,10 @@ mod tests {
absy::UnresolvedType::Array( absy::UnresolvedType::Array(
box absy::UnresolvedType::Array( box absy::UnresolvedType::Array(
box absy::UnresolvedType::Boolean.mock(), box absy::UnresolvedType::Boolean.mock(),
absy::Expression::U32Constant(3).mock(), absy::Expression::U32Constant(3u32).mock(),
) )
.mock(), .mock(),
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(), absy::Expression::IntConstant(2usize.into()).mock(),
), ),
), ),
]; ];
@ -943,7 +941,7 @@ mod tests {
absy::Expression::Select( absy::Expression::Select(
box absy::Expression::Identifier("a").into(), box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(3)).into(), absy::Expression::IntConstant(3usize.into()).into(),
) )
.into(), .into(),
), ),
@ -954,13 +952,13 @@ mod tests {
box absy::Expression::Select( box absy::Expression::Select(
box absy::Expression::Identifier("a").into(), box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(3)).into(), absy::Expression::IntConstant(3usize.into()).into(),
) )
.into(), .into(),
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(4)).into(), absy::Expression::IntConstant(4usize.into()).into(),
) )
.into(), .into(),
), ),
@ -970,11 +968,11 @@ mod tests {
absy::Expression::Select( absy::Expression::Select(
box absy::Expression::FunctionCall( box absy::Expression::FunctionCall(
"a", "a",
vec![absy::Expression::FieldConstant(Bn128Field::from(3)).into()], vec![absy::Expression::IntConstant(3usize.into()).into()],
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(4)).into(), absy::Expression::IntConstant(4usize.into()).into(),
) )
.into(), .into(),
), ),
@ -985,17 +983,17 @@ mod tests {
box absy::Expression::Select( box absy::Expression::Select(
box absy::Expression::FunctionCall( box absy::Expression::FunctionCall(
"a", "a",
vec![absy::Expression::FieldConstant(Bn128Field::from(3)).into()], vec![absy::Expression::IntConstant(3usize.into()).into()],
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(4)).into(), absy::Expression::IntConstant(4usize.into()).into(),
) )
.into(), .into(),
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(Bn128Field::from(5)).into(), absy::Expression::IntConstant(5usize.into()).into(),
) )
.into(), .into(),
), ),

View file

@ -377,7 +377,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
BooleanExpression::Identifier(x) => { BooleanExpression::Identifier(x) => {
FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()) FlatExpression::Identifier(self.layout.get(&x).unwrap().clone())
} }
BooleanExpression::FieldLt(box lhs, box rhs) => { BooleanExpression::Lt(box lhs, box rhs) => {
// Get the bit width to know the size of the binary decompositions for this Field // Get the bit width to know the size of the binary decompositions for this Field
let bit_width = T::get_required_bits(); let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // making sure we don't overflow, assert here? let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
@ -660,11 +660,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
res res
} }
BooleanExpression::FieldLe(box lhs, box rhs) => { BooleanExpression::Le(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression( let lt = self.flatten_boolean_expression(
symbols, symbols,
statements_flattened, statements_flattened,
BooleanExpression::FieldLt(box lhs.clone(), box rhs.clone()), BooleanExpression::Lt(box lhs.clone(), box rhs.clone()),
); );
let eq = self.flatten_boolean_expression( let eq = self.flatten_boolean_expression(
symbols, symbols,
@ -673,110 +673,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
FlatExpression::Add(box eq, box lt) FlatExpression::Add(box eq, box lt)
} }
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression( BooleanExpression::Gt(lhs, rhs) => self.flatten_boolean_expression(
symbols, symbols,
statements_flattened, statements_flattened,
BooleanExpression::FieldLt(rhs, lhs), BooleanExpression::Lt(rhs, lhs),
), ),
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression( BooleanExpression::Ge(lhs, rhs) => self.flatten_boolean_expression(
symbols, symbols,
statements_flattened, statements_flattened,
BooleanExpression::FieldLe(rhs, lhs), BooleanExpression::Le(rhs, lhs),
),
BooleanExpression::UintLt(box lhs, box rhs) => {
let lhs_flattened =
self.flatten_uint_expression(symbols, statements_flattened, lhs);
let rhs_flattened =
self.flatten_uint_expression(symbols, statements_flattened, rhs);
// Get the bit width to know the size of the binary decompositions for this Field
// This is not this uint bitwidth
let bit_width = T::get_required_bits();
// lhs
let lhs_id = self.define(lhs_flattened.get_field_unchecked(), statements_flattened);
let rhs_id = self.define(rhs_flattened.get_field_unchecked(), statements_flattened);
// sym := (lhs * 2) - (rhs * 2)
let subtraction_result = FlatExpression::Sub(
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(lhs_id),
),
box FlatExpression::Mult(
box FlatExpression::Number(T::from(2)),
box FlatExpression::Identifier(rhs_id),
),
);
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// bitness checks
for i in 0..bit_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(sub_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(sub_bits_be[i]),
box FlatExpression::Identifier(sub_bits_be[i]),
),
));
}
// check that the decomposition is in the field with a strict `< p` checks
self.strict_le_check(
statements_flattened,
&T::max_value_bit_vector_be(),
sub_bits_be.clone(),
);
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for i in 0..bit_width {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(sub_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
BooleanExpression::UintLe(box lhs, box rhs) => {
let lt = self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintLt(box lhs.clone(), box rhs.clone()),
);
let eq = self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintEq(box lhs.clone(), box rhs.clone()),
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintLt(rhs, lhs),
),
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
symbols,
statements_flattened,
BooleanExpression::UintLe(rhs, lhs),
), ),
BooleanExpression::Or(box lhs, box rhs) => { BooleanExpression::Or(box lhs, box rhs) => {
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs); let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);
@ -2464,7 +2369,7 @@ mod tests {
#[test] #[test]
fn geq_leq() { fn geq_leq() {
let mut flattener = Flattener::new(); let mut flattener = Flattener::new();
let expression_le = BooleanExpression::FieldLe( let expression_le = BooleanExpression::Le(
box FieldElementExpression::Number(Bn128Field::from(32)), box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
); );
@ -2475,7 +2380,7 @@ mod tests {
); );
let mut flattener = Flattener::new(); let mut flattener = Flattener::new();
let expression_ge = BooleanExpression::FieldGe( let expression_ge = BooleanExpression::Ge(
box FieldElementExpression::Number(Bn128Field::from(32)), box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
); );
@ -2496,7 +2401,7 @@ mod tests {
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
), ),
box BooleanExpression::FieldLt( box BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(20)), box FieldElementExpression::Number(Bn128Field::from(20)),
), ),

File diff suppressed because it is too large Load diff

View file

@ -1,33 +1,32 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use typed_absy; use typed_absy;
use typed_absy::types::UBitwidth; use typed_absy::types::{StructType, UBitwidth};
use zir; use zir;
use zokrates_field::Field; use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
pub struct Flattener<T: Field> { pub struct Flattener<T: Field> {
phantom: PhantomData<T>, phantom: PhantomData<T>,
} }
fn flatten_identifier_rec<'ast>( fn flatten_identifier_rec<'a>(
id: zir::SourceIdentifier<'ast>, id: zir::SourceIdentifier<'a>,
ty: typed_absy::types::Type, ty: typed_absy::Type,
) -> Vec<zir::Variable<'ast>> { ) -> Vec<zir::Variable> {
match ty { match ty {
typed_absy::types::Type::FieldElement => vec![zir::Variable { typed_absy::Type::Int => unreachable!(),
typed_absy::Type::FieldElement => vec![zir::Variable {
id: zir::Identifier::Source(id), id: zir::Identifier::Source(id),
_type: zir::Type::FieldElement, _type: zir::Type::FieldElement,
}], }],
typed_absy::types::Type::Boolean => vec![zir::Variable { typed_absy::Type::Boolean => vec![zir::Variable {
id: zir::Identifier::Source(id), id: zir::Identifier::Source(id),
_type: zir::Type::Boolean, _type: zir::Type::Boolean,
}], }],
typed_absy::types::Type::Uint(bitwidth) => vec![zir::Variable { typed_absy::Type::Uint(bitwidth) => vec![zir::Variable {
id: zir::Identifier::Source(id), id: zir::Identifier::Source(id),
_type: zir::Type::uint(bitwidth.to_usize()), _type: zir::Type::uint(bitwidth.to_usize()),
}], }],
typed_absy::types::Type::Array(array_type) => (0..array_type.size) typed_absy::Type::Array(array_type) => (0..array_type.size)
.flat_map(|i| { .flat_map(|i| {
flatten_identifier_rec( flatten_identifier_rec(
zir::SourceIdentifier::Select(box id.clone(), i), zir::SourceIdentifier::Select(box id.clone(), i),
@ -35,7 +34,7 @@ fn flatten_identifier_rec<'ast>(
) )
}) })
.collect(), .collect(),
typed_absy::types::Type::Struct(members) => members typed_absy::Type::Struct(members) => members
.into_iter() .into_iter()
.flat_map(|struct_member| { .flat_map(|struct_member| {
flatten_identifier_rec( flatten_identifier_rec(
@ -44,7 +43,6 @@ fn flatten_identifier_rec<'ast>(
) )
}) })
.collect(), .collect(),
typed_absy::types::Type::Int => unreachable!(),
} }
} }
@ -80,7 +78,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> { fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> {
let private = p.private; let private = p.private;
self.fold_variable(p.id.try_into().unwrap()) self.fold_variable(p.id)
.into_iter() .into_iter()
.map(|v| zir::Parameter { id: v, private }) .map(|v| zir::Parameter { id: v, private })
.collect() .collect()
@ -94,8 +92,6 @@ impl<'ast, T: Field> Flattener<T> {
let id = self.fold_name(v.id.clone()); let id = self.fold_name(v.id.clone());
let ty = v.get_type(); let ty = v.get_type();
let ty = typed_absy::types::Type::try_from(ty).unwrap();
flatten_identifier_rec(id, ty) flatten_identifier_rec(id, ty)
} }
@ -152,8 +148,6 @@ impl<'ast, T: Field> Flattener<T> {
) -> zir::ZirExpressionList<'ast, T> { ) -> zir::ZirExpressionList<'ast, T> {
match es { match es {
typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => { typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => {
let id = typed_absy::types::FunctionKey::try_from(id).unwrap();
zir::ZirExpressionList::FunctionCall( zir::ZirExpressionList::FunctionCall(
self.fold_function_key(id), self.fold_function_key(id),
arguments arguments
@ -202,7 +196,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_array_expression_inner( fn fold_array_expression_inner(
&mut self, &mut self,
ty: &typed_absy::types::Type, ty: &typed_absy::Type,
size: usize, size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>, e: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
@ -210,7 +204,7 @@ impl<'ast, T: Field> Flattener<T> {
} }
fn fold_struct_expression_inner( fn fold_struct_expression_inner(
&mut self, &mut self,
ty: &typed_absy::types::StructType, ty: &StructType,
e: typed_absy::StructExpressionInner<'ast, T>, e: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
fold_struct_expression_inner(self, ty, e) fold_struct_expression_inner(self, ty, e)
@ -225,12 +219,7 @@ pub fn fold_module<'ast, T: Field>(
functions: p functions: p
.functions .functions
.into_iter() .into_iter()
.map(|(key, fun)| { .map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun)))
(
f.fold_function_key(key.try_into().unwrap()),
f.fold_function_symbol(fun),
)
})
.collect(), .collect(),
} }
} }
@ -280,16 +269,14 @@ pub fn fold_statement<'ast, T: Field>(
pub fn fold_array_expression_inner<'ast, T: Field>( pub fn fold_array_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
t: &typed_absy::types::Type, t: &typed_absy::Type,
size: usize, size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>, e: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
match e { match e {
typed_absy::ArrayExpressionInner::Identifier(id) => { typed_absy::ArrayExpressionInner::Identifier(id) => {
let variables = flatten_identifier_rec( let variables =
f.fold_name(id), flatten_identifier_rec(f.fold_name(id), typed_absy::Type::array(t.clone(), size));
typed_absy::types::Type::array(t.clone(), size),
);
variables variables
.into_iter() .into_iter()
.map(|v| match v._type { .map(|v| match v._type {
@ -344,11 +331,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| { .map(|member| member.ty.get_primitive_count())
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
// we also need the size of this member // we also need the size of this member
@ -358,15 +341,12 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
} }
typed_absy::ArrayExpressionInner::Select(box array, box index) => { typed_absy::ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
match index.into_inner() { match index {
zir::UExpressionInner::Value(i) => { zir::FieldElementExpression::Number(i) => {
let size = typed_absy::types::Type::try_from(t.clone()) let size = t.get_primitive_count() * size;
.unwrap() let start = i.to_dec_string().parse::<usize>().unwrap() * size;
.get_primitive_count()
* size;
let start = i as usize * size;
let end = start + size; let end = start + size;
array[start..end].to_vec() array[start..end].to_vec()
} }
@ -378,13 +358,13 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
pub fn fold_struct_expression_inner<'ast, T: Field>( pub fn fold_struct_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
t: &typed_absy::types::StructType, t: &StructType,
e: typed_absy::StructExpressionInner<'ast, T>, e: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
match e { match e {
typed_absy::StructExpressionInner::Identifier(id) => { typed_absy::StructExpressionInner::Identifier(id) => {
let variables = let variables =
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::struc(t.clone())); flatten_identifier_rec(f.fold_name(id), typed_absy::Type::struc(t.clone()));
variables variables
.into_iter() .into_iter()
.map(|v| match v._type { .map(|v| match v._type {
@ -439,33 +419,30 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| { .map(|member| member.ty.get_primitive_count())
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
// we also need the size of this member // we also need the size of this member
let size = typed_absy::types::Type::try_from( let size = t
*t.iter().find(|member| member.id == id).cloned().unwrap().ty, .iter()
) .find(|member| member.id == id)
.unwrap() .unwrap()
.get_primitive_count(); .ty
.get_primitive_count();
s[offset..offset + size].to_vec() s[offset..offset + size].to_vec()
} }
typed_absy::StructExpressionInner::Select(box array, box index) => { typed_absy::StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
match index.into_inner() { match index {
zir::UExpressionInner::Value(i) => { zir::FieldElementExpression::Number(i) => {
let size = t let size = t
.iter() .iter()
.map(|m| m.ty.get_primitive_count()) .map(|m| m.ty.get_primitive_count())
.fold(0, |acc, current| acc + current); .fold(0, |acc, current| acc + current);
let start = i as usize * size; let start = i.to_dec_string().parse::<usize>().unwrap() * size;
let end = start + size; let end = start + size;
array[start..end].to_vec() array[start..end].to_vec()
} }
@ -483,7 +460,7 @@ pub fn fold_field_expression<'ast, T: Field>(
typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n), typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
typed_absy::FieldElementExpression::Identifier(id) => { typed_absy::FieldElementExpression::Identifier(id) => {
zir::FieldElementExpression::Identifier( zir::FieldElementExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::FieldElement)[0] flatten_identifier_rec(f.fold_name(id), typed_absy::Type::FieldElement)[0]
.id .id
.clone(), .clone(),
) )
@ -528,22 +505,26 @@ pub fn fold_field_expression<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| { .map(|member| member.ty.get_primitive_count())
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap() s[offset].clone().try_into().unwrap()
} }
typed_absy::FieldElementExpression::Select(box array, box index) => { typed_absy::FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
match index.into_inner() { use std::convert::TryInto;
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
_ => unreachable!(""), _ => unreachable!(""),
} }
} }
@ -557,7 +538,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
match e { match e {
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v), typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier( typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::Boolean)[0] flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Boolean)[0]
.id .id
.clone(), .clone(),
), ),
@ -633,45 +614,25 @@ pub fn fold_boolean_expression<'ast, T: Field>(
zir::BooleanExpression::UintEq(box e1, box e2) zir::BooleanExpression::UintEq(box e1, box e2)
} }
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => { typed_absy::BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldLt(box e1, box e2) zir::BooleanExpression::Lt(box e1, box e2)
} }
typed_absy::BooleanExpression::FieldLe(box e1, box e2) => { typed_absy::BooleanExpression::Le(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldLe(box e1, box e2) zir::BooleanExpression::Le(box e1, box e2)
} }
typed_absy::BooleanExpression::FieldGt(box e1, box e2) => { typed_absy::BooleanExpression::Gt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldGt(box e1, box e2) zir::BooleanExpression::Gt(box e1, box e2)
} }
typed_absy::BooleanExpression::FieldGe(box e1, box e2) => { typed_absy::BooleanExpression::Ge(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::FieldGe(box e1, box e2) zir::BooleanExpression::Ge(box e1, box e2)
}
typed_absy::BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLt(box e1, box e2)
}
typed_absy::BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLe(box e1, box e2)
}
typed_absy::BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGt(box e1, box e2)
}
typed_absy::BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGe(box e1, box e2)
} }
typed_absy::BooleanExpression::Or(box e1, box e2) => { typed_absy::BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1); let e1 = f.fold_boolean_expression(e1);
@ -702,21 +663,25 @@ pub fn fold_boolean_expression<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| { .map(|member| member.ty.get_primitive_count())
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap() s[offset].clone().try_into().unwrap()
} }
typed_absy::BooleanExpression::Select(box array, box index) => { typed_absy::BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
match index.into_inner() { use std::convert::TryInto;
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -739,7 +704,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
match e { match e {
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v), typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier( typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::Uint(bitwidth))[0] flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Uint(bitwidth))[0]
.id .id
.clone(), .clone(),
), ),
@ -801,11 +766,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
} }
typed_absy::UExpressionInner::Select(box array, box index) => { typed_absy::UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
match index.into_inner() { use std::convert::TryInto;
zir::UExpressionInner::Value(i) => {
let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap(); match index {
zir::FieldElementExpression::Number(i) => {
let e: zir::UExpression<_> = array[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap();
e.into_inner() e.into_inner()
} }
_ => unreachable!(), _ => unreachable!(),
@ -819,13 +789,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| { .map(|member| member.ty.get_primitive_count())
typed_absy::types::Type::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
use std::convert::TryInto;
let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap(); let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap();
res.into_inner() res.into_inner()
@ -854,9 +822,7 @@ pub fn fold_function<'ast, T: Field>(
.into_iter() .into_iter()
.flat_map(|s| f.fold_statement(s)) .flat_map(|s| f.fold_statement(s))
.collect(), .collect(),
signature: typed_absy::types::Signature::try_from(fun.signature) signature: fun.signature.into(),
.unwrap()
.into(),
} }
} }
@ -864,22 +830,14 @@ pub fn fold_array_expression<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
e: typed_absy::ArrayExpression<'ast, T>, e: typed_absy::ArrayExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
let size = e.size(); f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner())
f.fold_array_expression_inner(
&typed_absy::types::Type::try_from(e.inner_type().clone()).unwrap(),
size,
e.into_inner(),
)
} }
pub fn fold_struct_expression<'ast, T: Field>( pub fn fold_struct_expression<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
e: typed_absy::StructExpression<'ast, T>, e: typed_absy::StructExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_struct_expression_inner( f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner())
&typed_absy::types::StructType::try_from(e.ty().clone()).unwrap(),
e.into_inner(),
)
} }
pub fn fold_function_symbol<'ast, T: Field>( pub fn fold_function_symbol<'ast, T: Field>(
@ -890,10 +848,9 @@ pub fn fold_function_symbol<'ast, T: Field>(
typed_absy::TypedFunctionSymbol::Here(fun) => { typed_absy::TypedFunctionSymbol::Here(fun) => {
zir::ZirFunctionSymbol::Here(f.fold_function(fun)) zir::ZirFunctionSymbol::Here(f.fold_function(fun))
} }
typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There( typed_absy::TypedFunctionSymbol::There(key, module) => {
f.fold_function_key(typed_absy::types::FunctionKey::try_from(key).unwrap()), zir::ZirFunctionSymbol::There(f.fold_function_key(key), module)
module, } // by default, do not fold modules recursively
), // by default, do not fold modules recursively
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat), typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
} }
} }

View file

@ -67,10 +67,10 @@ mod tests {
#[test] #[test]
fn detect_non_constant_bound() { fn detect_non_constant_bound() {
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For( let loops = vec![TypedStatement::For(
Variable::field_element("i"), Variable::field_element("i"),
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), FieldElementExpression::Identifier("i".into()),
2u32.into(), FieldElementExpression::Number(Bn128Field::from(2)),
vec![], vec![],
)]; )];
@ -118,12 +118,12 @@ mod tests {
let s = TypedStatement::For( let s = TypedStatement::For(
Variable::field_element("i"), Variable::field_element("i"),
0u32.into(), FieldElementExpression::Number(Bn128Field::from(0)),
2u32.into(), FieldElementExpression::Number(Bn128Field::from(2)),
vec![TypedStatement::For( vec![TypedStatement::For(
Variable::field_element("j"), Variable::field_element("j"),
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), FieldElementExpression::Identifier("i".into()),
2u32.into(), FieldElementExpression::Number(Bn128Field::from(2)),
vec![ vec![
TypedStatement::Declaration(Variable::field_element("foo")), TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition( TypedStatement::Definition(

View file

@ -121,8 +121,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
// we stop propagation here as constants maybe be modified inside the loop body // we stop propagation here as constants maybe be modified inside the loop body
// which we do not visit // which we do not visit
TypedStatement::For(v, from, to, statements) => { TypedStatement::For(v, from, to, statements) => {
let from = self.fold_uint_expression(from); let from = self.fold_field_expression(from);
let to = self.fold_uint_expression(to); let to = self.fold_field_expression(to);
// invalidate the constants map as any constant could be modified inside the loop body, which we don't visit // invalidate the constants map as any constant could be modified inside the loop body, which we don't visit
self.constants.clear(); self.constants.clear();
@ -505,31 +505,33 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
} }
UExpressionInner::Select(box array, box index) => { UExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array); let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index); let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone(); let inner_type = array.inner_type().clone();
let size = array.size(); let size = array.size();
match (array.into_inner(), index.into_inner()) { match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n = n as usize; let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n < size { if n_as_usize < size {
UExpression::try_from(v[n].clone()).unwrap().into_inner() UExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else { } else {
unreachable!( unreachable!(
"out of bounds index ({} >= {}) found during static analysis", "out of bounds index ({} >= {}) found during static analysis",
n, size n_as_usize, size
); );
} }
} }
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select( match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array( box TypedAssignee::Identifier(Variable::array(
id.clone(), id.clone(),
inner_type.clone(), inner_type.clone(),
size, size,
)), )),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n.clone()).into(),
)) { )) {
Some(e) => match e { Some(e) => match e {
TypedExpression::Uint(e) => e.clone().into_inner(), TypedExpression::Uint(e) => e.clone().into_inner(),
@ -537,14 +539,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}, },
None => UExpressionInner::Select( None => UExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n),
), ),
} }
} }
(a, i) => UExpressionInner::Select( (a, i) => UExpressionInner::Select(box a.annotate(inner_type, size), box i),
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
} }
} }
UExpressionInner::FunctionCall(key, arguments) => { UExpressionInner::FunctionCall(key, arguments) => {
@ -648,46 +647,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
} }
FieldElementExpression::Select(box array, box index) => { FieldElementExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array); let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index); let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone(); let inner_type = array.inner_type().clone();
let size = array.size(); let size = array.size();
match (array.into_inner(), index.into_inner()) { match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n = n as usize; let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n < size { if n_as_usize < size {
FieldElementExpression::try_from(v[n].clone()).unwrap() FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
} else { } else {
unreachable!( unreachable!(
"out of bounds index ({} >= {}) found during static analysis", "out of bounds index ({} >= {}) found during static analysis",
n, size n_as_usize, size
); );
} }
} }
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select( match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array( box TypedAssignee::Identifier(Variable::array(
id.clone(), id.clone(),
inner_type.clone(), inner_type.clone(),
size, size,
)), )),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n.clone()).into(),
)) { )) {
Some(e) => match e { Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(), TypedExpression::FieldElement(e) => e.clone(),
_ => unreachable!(""), _ => unreachable!("??"),
}, },
None => FieldElementExpression::Select( None => FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n),
), ),
} }
} }
(a, i) => FieldElementExpression::Select( (a, i) => {
box a.annotate(inner_type, size), FieldElementExpression::Select(box a.annotate(inner_type, size), box i)
box i.annotate(UBitwidth::B32), }
),
} }
} }
FieldElementExpression::Member(box s, m) => { FieldElementExpression::Member(box s, m) => {
@ -749,48 +747,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
} }
ArrayExpressionInner::Select(box array, box index) => { ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array); let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index); let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone(); let inner_type = array.inner_type().clone();
let size = array.size(); let size = array.size();
match (array.into_inner(), index.into_inner()) { match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n = n as usize; let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n < size { if n_as_usize < size {
ArrayExpression::try_from(v[n].clone()) ArrayExpression::try_from(v[n_as_usize].clone())
.unwrap() .unwrap()
.into_inner() .into_inner()
} else { } else {
unreachable!( unreachable!(
"out of bounds index ({} >= {}) found during static analysis", "out of bounds index ({} >= {}) found during static analysis",
n, size n_as_usize, size
); );
} }
} }
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select( match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array( box TypedAssignee::Identifier(Variable::array(
id.clone(), id.clone(),
inner_type.clone(), inner_type.clone(),
size, size,
)), )),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n.clone()).into(),
)) { )) {
Some(e) => match e { Some(e) => match e {
TypedExpression::Array(e) => e.clone().into_inner(), TypedExpression::Array(e) => e.clone().into_inner(),
_ => unreachable!(""), _ => unreachable!("should be an array"),
}, },
None => ArrayExpressionInner::Select( None => ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n),
), ),
} }
} }
(a, i) => ArrayExpressionInner::Select( (a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
} }
} }
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
@ -864,48 +859,47 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
} }
StructExpressionInner::Select(box array, box index) => { StructExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array); let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index); let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone(); let inner_type = array.inner_type().clone();
let size = array.size(); let size = array.size();
match (array.into_inner(), index.into_inner()) { match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n = n as usize; let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n < size { if n_as_usize < size {
StructExpression::try_from(v[n].clone()) StructExpression::try_from(v[n_as_usize].clone())
.unwrap() .unwrap()
.into_inner() .into_inner()
} else { } else {
unreachable!( unreachable!(
"out of bounds index ({} >= {}) found during static analysis", "out of bounds index ({} >= {}) found during static analysis",
n, size n_as_usize, size
); );
} }
} }
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select( match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array( box TypedAssignee::Identifier(Variable::array(
id.clone(), id.clone(),
inner_type.clone(), inner_type.clone(),
size, size,
)), )),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n.clone()).into(),
)) { )) {
Some(e) => match e { Some(e) => match e {
TypedExpression::Struct(e) => e.clone().into_inner(), TypedExpression::Struct(e) => e.clone().into_inner(),
_ => unreachable!(""), _ => unreachable!("should be a struct"),
}, },
None => StructExpressionInner::Select( None => StructExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n),
), ),
} }
} }
(a, i) => StructExpressionInner::Select( (a, i) => {
box a.annotate(inner_type, size), StructExpressionInner::Select(box a.annotate(inner_type, size), box i)
box i.annotate(UBitwidth::B32), }
),
} }
} }
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
@ -998,7 +992,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(e1, e2) => BooleanExpression::BoolEq(box e1, box e2), (e1, e2) => BooleanExpression::BoolEq(box e1, box e2),
} }
} }
BooleanExpression::FieldLt(box e1, box e2) => { BooleanExpression::Lt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1); let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2); let e2 = self.fold_field_expression(e2);
@ -1006,10 +1000,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 < n2) BooleanExpression::Value(n1 < n2)
} }
(e1, e2) => BooleanExpression::FieldLt(box e1, box e2), (e1, e2) => BooleanExpression::Lt(box e1, box e2),
} }
} }
BooleanExpression::FieldLe(box e1, box e2) => { BooleanExpression::Le(box e1, box e2) => {
let e1 = self.fold_field_expression(e1); let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2); let e2 = self.fold_field_expression(e2);
@ -1017,10 +1011,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 <= n2) BooleanExpression::Value(n1 <= n2)
} }
(e1, e2) => BooleanExpression::FieldLe(box e1, box e2), (e1, e2) => BooleanExpression::Le(box e1, box e2),
} }
} }
BooleanExpression::FieldGt(box e1, box e2) => { BooleanExpression::Gt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1); let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2); let e2 = self.fold_field_expression(e2);
@ -1028,10 +1022,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 > n2) BooleanExpression::Value(n1 > n2)
} }
(e1, e2) => BooleanExpression::FieldGt(box e1, box e2), (e1, e2) => BooleanExpression::Gt(box e1, box e2),
} }
} }
BooleanExpression::FieldGe(box e1, box e2) => { BooleanExpression::Ge(box e1, box e2) => {
let e1 = self.fold_field_expression(e1); let e1 = self.fold_field_expression(e1);
let e2 = self.fold_field_expression(e2); let e2 = self.fold_field_expression(e2);
@ -1039,7 +1033,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
BooleanExpression::Value(n1 >= n2) BooleanExpression::Value(n1 >= n2)
} }
(e1, e2) => BooleanExpression::FieldGe(box e1, box e2), (e1, e2) => BooleanExpression::Ge(box e1, box e2),
} }
} }
BooleanExpression::Or(box e1, box e2) => { BooleanExpression::Or(box e1, box e2) => {
@ -1098,46 +1092,43 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
} }
BooleanExpression::Select(box array, box index) => { BooleanExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array); let array = self.fold_array_expression(array);
let index = self.fold_uint_expression(index); let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone(); let inner_type = array.inner_type().clone();
let size = array.size(); let size = array.size();
match (array.into_inner(), index.into_inner()) { match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n = n as usize; let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n < size { if n_as_usize < size {
BooleanExpression::try_from(v[n].clone()).unwrap() BooleanExpression::try_from(v[n_as_usize].clone()).unwrap()
} else { } else {
unreachable!( unreachable!(
"out of bounds index ({} >= {}) found during static analysis", "out of bounds index ({} >= {}) found during static analysis",
n, size n_as_usize, size
); );
} }
} }
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select( match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array( box TypedAssignee::Identifier(Variable::array(
id.clone(), id.clone(),
inner_type.clone(), inner_type.clone(),
size, size,
)), )),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n.clone()).into(),
)) { )) {
Some(e) => match e { Some(e) => match e {
TypedExpression::Boolean(e) => e.clone(), TypedExpression::Boolean(e) => e.clone(),
_ => unreachable!(""), _ => unreachable!("Should be a boolean"),
}, },
None => BooleanExpression::Select( None => BooleanExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box UExpressionInner::Value(n).annotate(UBitwidth::B32), box FieldElementExpression::Number(n),
), ),
} }
} }
(a, i) => BooleanExpression::Select( (a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i),
box a.annotate(inner_type, size),
box i.annotate(UBitwidth::B32),
),
} }
} }
BooleanExpression::Member(box s, m) => { BooleanExpression::Member(box s, m) => {
@ -1291,7 +1282,10 @@ mod tests {
FieldElementExpression::Number(Bn128Field::from(3)).into(), FieldElementExpression::Number(Bn128Field::from(3)).into(),
]) ])
.annotate(Type::FieldElement, 3), .annotate(Type::FieldElement, 3),
box UExpression::add(1u32.into(), 1u32.into()), box FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
); );
assert_eq!( assert_eq!(
@ -1397,12 +1391,12 @@ mod tests {
#[test] #[test]
fn lt() { fn lt() {
let e_true = BooleanExpression::FieldLt( let e_true = BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
); );
let e_false = BooleanExpression::FieldLt( let e_false = BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
); );
@ -1419,12 +1413,12 @@ mod tests {
#[test] #[test]
fn le() { fn le() {
let e_true = BooleanExpression::FieldLe( let e_true = BooleanExpression::Le(
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
); );
let e_false = BooleanExpression::FieldLe( let e_false = BooleanExpression::Le(
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
); );
@ -1441,12 +1435,12 @@ mod tests {
#[test] #[test]
fn gt() { fn gt() {
let e_true = BooleanExpression::FieldGt( let e_true = BooleanExpression::Gt(
box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(5)),
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
); );
let e_false = BooleanExpression::FieldGt( let e_false = BooleanExpression::Gt(
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(5)),
); );
@ -1463,12 +1457,12 @@ mod tests {
#[test] #[test]
fn ge() { fn ge() {
let e_true = BooleanExpression::FieldGe( let e_true = BooleanExpression::Ge(
box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(5)),
box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(5)),
); );
let e_false = BooleanExpression::FieldGe( let e_false = BooleanExpression::Ge(
box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(5)),
); );

View file

@ -23,6 +23,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
.iter() .iter()
.zip(ret_identifiers.iter()) .zip(ret_identifiers.iter())
.map(|(e, i)| match e.get_type() { .map(|(e, i)| match e.get_type() {
Type::Int => unreachable!(),
Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(), Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(),
Type::Boolean => BooleanExpression::Identifier(i.clone()).into(), Type::Boolean => BooleanExpression::Identifier(i.clone()).into(),
Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone()) Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone())
@ -34,7 +35,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone()) Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
.annotate(bitwidth) .annotate(bitwidth)
.into(), .into(),
Type::Int => unreachable!(),
}) })
.collect(); .collect();

View file

@ -85,10 +85,9 @@ impl<'ast> Unroller<'ast> {
match head { match head {
Access::Select(head) => { Access::Select(head) => {
statements.insert(TypedStatement::Assertion( statements.insert(TypedStatement::Assertion(
BooleanExpression::UintLt( BooleanExpression::Lt(
box head.clone(), box head.clone(),
box UExpressionInner::Value(size as u128) box FieldElementExpression::Number(T::from(size)),
.annotate(UBitwidth::B32),
) )
.into(), .into(),
)); ));
@ -98,16 +97,14 @@ impl<'ast> Unroller<'ast> {
.map(|i| match inner_ty { .map(|i| match inner_ty {
Type::Int => unreachable!(), Type::Int => unreachable!(),
Type::Array(..) => ArrayExpression::if_else( Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box UExpressionInner::Value(i as u128) box FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
box head.clone(), box head.clone(),
), ),
match Self::choose_many( match Self::choose_many(
ArrayExpression::select( ArrayExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
) )
.into(), .into(),
tail.clone(), tail.clone(),
@ -122,22 +119,19 @@ impl<'ast> Unroller<'ast> {
}, },
ArrayExpression::select( ArrayExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
), ),
) )
.into(), .into(),
Type::Struct(..) => StructExpression::if_else( Type::Struct(..) => StructExpression::if_else(
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box UExpressionInner::Value(i as u128) box FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
box head.clone(), box head.clone(),
), ),
match Self::choose_many( match Self::choose_many(
StructExpression::select( StructExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
) )
.into(), .into(),
tail.clone(), tail.clone(),
@ -152,22 +146,19 @@ impl<'ast> Unroller<'ast> {
}, },
StructExpression::select( StructExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
), ),
) )
.into(), .into(),
Type::FieldElement => FieldElementExpression::if_else( Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box UExpressionInner::Value(i as u128) box FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
box head.clone(), box head.clone(),
), ),
match Self::choose_many( match Self::choose_many(
FieldElementExpression::select( FieldElementExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
) )
.into(), .into(),
tail.clone(), tail.clone(),
@ -182,22 +173,19 @@ impl<'ast> Unroller<'ast> {
}, },
FieldElementExpression::select( FieldElementExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
), ),
) )
.into(), .into(),
Type::Boolean => BooleanExpression::if_else( Type::Boolean => BooleanExpression::if_else(
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box UExpressionInner::Value(i as u128) box FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
box head.clone(), box head.clone(),
), ),
match Self::choose_many( match Self::choose_many(
BooleanExpression::select( BooleanExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
) )
.into(), .into(),
tail.clone(), tail.clone(),
@ -212,22 +200,19 @@ impl<'ast> Unroller<'ast> {
}, },
BooleanExpression::select( BooleanExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
), ),
) )
.into(), .into(),
Type::Uint(..) => UExpression::if_else( Type::Uint(..) => UExpression::if_else(
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box UExpressionInner::Value(i as u128) box FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
box head.clone(), box head.clone(),
), ),
match Self::choose_many( match Self::choose_many(
UExpression::select( UExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
) )
.into(), .into(),
tail.clone(), tail.clone(),
@ -242,8 +227,7 @@ impl<'ast> Unroller<'ast> {
}, },
UExpression::select( UExpression::select(
base.clone(), base.clone(),
UExpressionInner::Value(i as u128) FieldElementExpression::Number(T::from(i)),
.annotate(UBitwidth::B32),
), ),
) )
.into(), .into(),
@ -376,7 +360,7 @@ impl<'ast> Unroller<'ast> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum Access<'ast, T: Field> { enum Access<'ast, T: Field> {
Select(UExpression<'ast, T>), Select(FieldElementExpression<'ast, T>),
Member(MemberId), Member(MemberId),
} }
/// Turn an assignee into its representation as a base variable and a list accesses /// Turn an assignee into its representation as a base variable and a list accesses
@ -437,7 +421,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
let indices = indices let indices = indices
.into_iter() .into_iter()
.map(|a| match a { .map(|a| match a {
Access::Select(i) => Access::Select(self.fold_uint_expression(i)), Access::Select(i) => Access::Select(self.fold_field_expression(i)),
a => a, a => a,
}) })
.collect(); .collect();
@ -463,16 +447,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
vec![TypedStatement::MultipleDefinition(variables, exprs)] vec![TypedStatement::MultipleDefinition(variables, exprs)]
} }
TypedStatement::For(v, from, to, stats) => { TypedStatement::For(v, from, to, stats) => {
let from = self.fold_uint_expression(from); let from = self.fold_field_expression(from);
let to = self.fold_uint_expression(to); let to = self.fold_field_expression(to);
match (from.into_inner(), to.into_inner()) { match (from, to) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => { (FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => {
let mut values = vec![]; let mut values: Vec<T> = vec![];
let mut current = from; let mut current = from;
while current < to { while current < to {
values.push(current.clone()); values.push(current.clone());
current = 1 + &current; current = T::one() + &current;
} }
let res = values let res = values
@ -483,9 +467,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
TypedStatement::Declaration(v.clone()), TypedStatement::Declaration(v.clone()),
TypedStatement::Definition( TypedStatement::Definition(
TypedAssignee::Identifier(v.clone()), TypedAssignee::Identifier(v.clone()),
UExpressionInner::Value(index) FieldElementExpression::Number(index).into(),
.annotate(UBitwidth::B32)
.into(),
), ),
], ],
stats.clone(), stats.clone(),
@ -501,12 +483,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
} }
(from, to) => { (from, to) => {
self.complete = false; self.complete = false;
vec![TypedStatement::For( vec![TypedStatement::For(v, from, to, stats)]
v,
from.annotate(UBitwidth::B32),
to.annotate(UBitwidth::B32),
stats,
)]
} }
} }
} }
@ -538,11 +515,11 @@ mod tests {
#[test] #[test]
fn ssa_array() { fn ssa_array() {
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3usize); let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
let e = FieldElementExpression::Number(Bn128Field::from(42)).into(); let e = FieldElementExpression::Number(Bn128Field::from(42)).into();
let index = 1u32.into(); let index = FieldElementExpression::Number(Bn128Field::from(1));
let a1 = Unroller::choose_many( let a1 = Unroller::choose_many(
a0.clone().into(), a0.clone().into(),
@ -558,34 +535,52 @@ mod tests {
a1, a1,
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)), FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(a0.clone(), 0u32.into(),) FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
) )
.into(), .into(),
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)), FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(a0.clone(), 1u32.into()) FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
) )
.into(), .into(),
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 2u32.into(), box 1u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)), FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(a0.clone(), 2u32.into()) FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(2))
)
) )
.into() .into()
]) ])
.annotate(Type::FieldElement, 3usize) .annotate(Type::FieldElement, 3)
.into() .into()
); );
let a0: ArrayExpression<Bn128Field> = ArrayExpressionInner::Identifier("a".into()) let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 3), 3); .annotate(Type::array(Type::FieldElement, 3), 3);
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3); let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
let index = 1u32.into(); let index = FieldElementExpression::Number(Bn128Field::from(1));
let a1 = Unroller::choose_many( let a1 = Unroller::choose_many(
a0.clone().into(), a0.clone().into(),
@ -601,21 +596,39 @@ mod tests {
a1, a1,
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
ArrayExpression::if_else( ArrayExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(), e.clone(),
ArrayExpression::select(a0.clone(), 0u32.into()) ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
) )
.into(), .into(),
ArrayExpression::if_else( ArrayExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(), e.clone(),
ArrayExpression::select(a0.clone(), 1u32.into()) ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
) )
.into(), .into(),
ArrayExpression::if_else( ArrayExpression::if_else(
BooleanExpression::UintEq(box 2u32.into(), box 1u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(), e.clone(),
ArrayExpression::select(a0.clone(), 2u32.into()) ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(2))
)
) )
.into() .into()
]) ])
@ -628,7 +641,10 @@ mod tests {
let e = FieldElementExpression::Number(Bn128Field::from(42)); let e = FieldElementExpression::Number(Bn128Field::from(42));
let indices = vec![Access::Select(0u32.into()), Access::Select(0u32.into())]; let indices = vec![
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
];
let a1 = Unroller::choose_many( let a1 = Unroller::choose_many(
a0.clone().into(), a0.clone().into(),
@ -644,55 +660,91 @@ mod tests {
a1, a1,
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
ArrayExpression::if_else( ArrayExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(), e.clone(),
FieldElementExpression::select( FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 0u32.into()), ArrayExpression::select(
0u32.into() a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
),
FieldElementExpression::Number(Bn128Field::from(0))
) )
) )
.into(), .into(),
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(), e.clone(),
FieldElementExpression::select( FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 0u32.into()), ArrayExpression::select(
1u32.into() a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
),
FieldElementExpression::Number(Bn128Field::from(1))
) )
) )
.into() .into()
]) ])
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
ArrayExpression::select(a0.clone(), 0u32.into()) ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
) )
.into(), .into(),
ArrayExpression::if_else( ArrayExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(), e.clone(),
FieldElementExpression::select( FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 1u32.into()), ArrayExpression::select(
0u32.into() a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(0))
) )
) )
.into(), .into(),
FieldElementExpression::if_else( FieldElementExpression::if_else(
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()), BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(), e.clone(),
FieldElementExpression::select( FieldElementExpression::select(
ArrayExpression::select(a0.clone(), 1u32.into()), ArrayExpression::select(
1u32.into() a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(1))
) )
) )
.into() .into()
]) ])
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
ArrayExpression::select(a0.clone(), 1u32.into()) ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
) )
.into(), .into(),
]) ])
@ -721,8 +773,8 @@ mod tests {
let s = TypedStatement::For( let s = TypedStatement::For(
Variable::field_element("i"), Variable::field_element("i"),
2u32.into(), FieldElementExpression::Number(Bn128Field::from(2)),
5u32.into(), FieldElementExpression::Number(Bn128Field::from(5)),
vec![ vec![
TypedStatement::Declaration(Variable::field_element("foo")), TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition( TypedStatement::Definition(
@ -1031,7 +1083,7 @@ mod tests {
let s: TypedStatement<Bn128Field> = TypedStatement::Definition( let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
TypedAssignee::Select( TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::field_array("a", 2)), box TypedAssignee::Identifier(Variable::field_array("a", 2)),
box 1u32.into(), box FieldElementExpression::Number(Bn128Field::from(1)),
), ),
FieldElementExpression::Number(Bn128Field::from(2)).into(), FieldElementExpression::Number(Bn128Field::from(2)).into(),
); );
@ -1040,7 +1092,11 @@ mod tests {
u.fold_statement(s), u.fold_statement(s),
vec![ vec![
TypedStatement::Assertion( TypedStatement::Assertion(
BooleanExpression::UintLt(box 1u32.into(), box 2u32.into()).into(), BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
), ),
TypedStatement::Definition( TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array( TypedAssignee::Identifier(Variable::field_array(
@ -1049,26 +1105,32 @@ mod tests {
)), )),
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
FieldElementExpression::IfElse( FieldElementExpression::IfElse(
box BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()), box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Select( box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier( box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0) Identifier::from("a").version(0)
) )
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
box 0u32.into() box FieldElementExpression::Number(Bn128Field::from(0))
), ),
) )
.into(), .into(),
FieldElementExpression::IfElse( FieldElementExpression::IfElse(
box BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()), box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Select( box FieldElementExpression::Select(
box ArrayExpressionInner::Identifier( box ArrayExpressionInner::Identifier(
Identifier::from("a").version(0) Identifier::from("a").version(0)
) )
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
box 1u32.into() box FieldElementExpression::Number(Bn128Field::from(1))
), ),
) )
.into(), .into(),
@ -1153,7 +1215,7 @@ mod tests {
"a", "a",
array_of_array_ty.clone(), array_of_array_ty.clone(),
)), )),
box 1u32.into(), box FieldElementExpression::Number(Bn128Field::from(1)),
), ),
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(4)).into(), FieldElementExpression::Number(Bn128Field::from(4)).into(),
@ -1167,7 +1229,11 @@ mod tests {
u.fold_statement(s), u.fold_statement(s),
vec![ vec![
TypedStatement::Assertion( TypedStatement::Assertion(
BooleanExpression::UintLt(box 1u32.into(), box 2u32.into()).into(), BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
), ),
TypedStatement::Definition( TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type( TypedAssignee::Identifier(Variable::with_id_and_type(
@ -1176,7 +1242,10 @@ mod tests {
)), )),
ArrayExpressionInner::Value(vec![ ArrayExpressionInner::Value(vec![
ArrayExpressionInner::IfElse( ArrayExpressionInner::IfElse(
box BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()), box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box ArrayExpressionInner::Value(vec![ box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(4)).into(), FieldElementExpression::Number(Bn128Field::from(4)).into(),
FieldElementExpression::Number(Bn128Field::from(5)).into(), FieldElementExpression::Number(Bn128Field::from(5)).into(),
@ -1188,14 +1257,17 @@ mod tests {
Identifier::from("a").version(0) Identifier::from("a").version(0)
) )
.annotate(Type::array(Type::FieldElement, 2), 2), .annotate(Type::array(Type::FieldElement, 2), 2),
box 0u32.into() box FieldElementExpression::Number(Bn128Field::from(0))
) )
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
) )
.annotate(Type::FieldElement, 2) .annotate(Type::FieldElement, 2)
.into(), .into(),
ArrayExpressionInner::IfElse( ArrayExpressionInner::IfElse(
box BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()), box BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
box ArrayExpressionInner::Value(vec![ box ArrayExpressionInner::Value(vec![
FieldElementExpression::Number(Bn128Field::from(4)).into(), FieldElementExpression::Number(Bn128Field::from(4)).into(),
FieldElementExpression::Number(Bn128Field::from(5)).into(), FieldElementExpression::Number(Bn128Field::from(5)).into(),
@ -1207,7 +1279,7 @@ mod tests {
Identifier::from("a").version(0) Identifier::from("a").version(0)
) )
.annotate(Type::array(Type::FieldElement, 2), 2), .annotate(Type::array(Type::FieldElement, 2), 2),
box 1u32.into() box FieldElementExpression::Number(Bn128Field::from(1))
) )
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
) )

View file

@ -29,12 +29,10 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
fn select<U: Select<'ast, T> + IfElse<'ast, T>>( fn select<U: Select<'ast, T> + IfElse<'ast, T>>(
&mut self, &mut self,
a: ArrayExpression<'ast, T>, a: ArrayExpression<'ast, T>,
i: UExpression<'ast, T>, i: FieldElementExpression<'ast, T>,
) -> U { ) -> U {
match i.into_inner() { match i {
UExpressionInner::Value(i) => { FieldElementExpression::Number(i) => U::select(a, FieldElementExpression::Number(i)),
U::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32))
}
i => { i => {
let size = match a.get_type().clone() { let size = match a.get_type().clone() {
Type::Array(array_ty) => array_ty.size, Type::Array(array_ty) => array_ty.size,
@ -44,9 +42,9 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
self.statements.push(TypedStatement::Assertion( self.statements.push(TypedStatement::Assertion(
(0..size) (0..size)
.map(|index| { .map(|index| {
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box i.clone().annotate(UBitwidth::B32), box i.clone(),
box UExpressionInner::Value(index as u128).annotate(UBitwidth::B32), box FieldElementExpression::Number(index.into()).into(),
) )
}) })
.fold(None, |acc, e| match acc { .fold(None, |acc, e| match acc {
@ -58,19 +56,14 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
)); ));
(0..size) (0..size)
.map(|i| { .map(|i| U::select(a.clone(), FieldElementExpression::Number(i.into())))
U::select(
a.clone(),
UExpressionInner::Value(i as u128).annotate(UBitwidth::B32),
)
})
.enumerate() .enumerate()
.rev() .rev()
.fold(None, |acc, (index, res)| match acc { .fold(None, |acc, (index, res)| match acc {
Some(acc) => Some(U::if_else( Some(acc) => Some(U::if_else(
BooleanExpression::UintEq( BooleanExpression::FieldEq(
box i.clone().annotate(UBitwidth::B32), box i.clone(),
box UExpressionInner::Value(index as u128).annotate(UBitwidth::B32), box FieldElementExpression::Number(index.into()),
), ),
res, res,
acc, acc,
@ -168,7 +161,7 @@ mod tests {
TypedAssignee::Identifier(Variable::field_element("b")), TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Select( FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2), box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2),
box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), box FieldElementExpression::Identifier("i".into()),
) )
.into(), .into(),
); );
@ -199,12 +192,12 @@ mod tests {
FieldElementExpression::Select( FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into()) box ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
box 0u32.into(), box FieldElementExpression::Number(0.into()),
), ),
FieldElementExpression::Select( FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into()) box ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2), .annotate(Type::FieldElement, 2),
box 1u32.into(), box FieldElementExpression::Number(1.into()),
) )
) )
.into() .into()

View file

@ -46,7 +46,7 @@ pub trait Folder<'ast, T: Field>: Sized {
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
TypedAssignee::Select(box a, box index) => TypedAssignee::Select( TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
box self.fold_assignee(a), box self.fold_assignee(a),
box self.fold_uint_expression(index), box self.fold_field_expression(index),
), ),
TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m), TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m),
} }
@ -216,7 +216,7 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
} }
ArrayExpressionInner::Select(box array, box index) => { ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
ArrayExpressionInner::Select(box array, box index) ArrayExpressionInner::Select(box array, box index)
} }
} }
@ -249,7 +249,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
} }
StructExpressionInner::Select(box array, box index) => { StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
StructExpressionInner::Select(box array, box index) StructExpressionInner::Select(box array, box index)
} }
} }
@ -305,7 +305,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
} }
FieldElementExpression::Select(box array, box index) => { FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
FieldElementExpression::Select(box array, box index) FieldElementExpression::Select(box array, box index)
} }
} }
@ -350,45 +350,25 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_uint_expression(e2); let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2) BooleanExpression::UintEq(box e1, box e2)
} }
BooleanExpression::FieldLt(box e1, box e2) => { BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLt(box e1, box e2) BooleanExpression::Lt(box e1, box e2)
} }
BooleanExpression::FieldLe(box e1, box e2) => { BooleanExpression::Le(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2) BooleanExpression::Le(box e1, box e2)
} }
BooleanExpression::FieldGt(box e1, box e2) => { BooleanExpression::Gt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGt(box e1, box e2) BooleanExpression::Gt(box e1, box e2)
} }
BooleanExpression::FieldGe(box e1, box e2) => { BooleanExpression::Ge(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGe(box e1, box e2) BooleanExpression::Ge(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGe(box e1, box e2)
} }
BooleanExpression::Or(box e1, box e2) => { BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1); let e1 = f.fold_boolean_expression(e1);
@ -420,7 +400,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
} }
BooleanExpression::Select(box array, box index) => { BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
BooleanExpression::Select(box array, box index) BooleanExpression::Select(box array, box index)
} }
} }
@ -503,7 +483,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
} }
UExpressionInner::Select(box array, box index) => { UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_uint_expression(index); let index = f.fold_field_expression(index);
UExpressionInner::Select(box array, box index) UExpressionInner::Select(box array, box index)
} }
UExpressionInner::IfElse(box cond, box cons, box alt) => { UExpressionInner::IfElse(box cond, box cons, box alt) => {

View file

@ -256,7 +256,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<'ast, T> { pub enum TypedAssignee<'ast, T> {
Identifier(Variable<'ast>), Identifier(Variable<'ast>),
Select(Box<TypedAssignee<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<TypedAssignee<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Member(Box<TypedAssignee<'ast, T>>, MemberId), Member(Box<TypedAssignee<'ast, T>>, MemberId),
} }
@ -316,8 +319,8 @@ pub enum TypedStatement<'ast, T> {
Assertion(BooleanExpression<'ast, T>), Assertion(BooleanExpression<'ast, T>),
For( For(
Variable<'ast>, Variable<'ast>,
UExpression<'ast, T>, FieldElementExpression<'ast, T>,
UExpression<'ast, T>, FieldElementExpression<'ast, T>,
Vec<TypedStatement<'ast, T>>, Vec<TypedStatement<'ast, T>>,
), ),
MultipleDefinition(Vec<Variable<'ast>>, TypedExpressionList<'ast, T>), MultipleDefinition(Vec<Variable<'ast>>, TypedExpressionList<'ast, T>),
@ -424,44 +427,85 @@ pub enum TypedExpression<'ast, T> {
} }
impl<'ast, T: Field> TypedExpression<'ast, T> { impl<'ast, T: Field> TypedExpression<'ast, T> {
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types. // return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions // Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
pub fn align_without_integers( pub fn align_without_integers(
lhs: Self, lhs: Self,
rhs: Self, rhs: Self,
) -> Result<(TypedExpression<'ast, T>, TypedExpression<'ast, T>), String> { ) -> Result<
(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
> {
use self::TypedExpression::*; use self::TypedExpression::*;
match (lhs, rhs) { match (lhs, rhs) {
(Int(lhs), FieldElement(rhs)) => Ok(( (Int(lhs), FieldElement(rhs)) => Ok((
FieldElementExpression::try_from_int(lhs)?.into(), FieldElementExpression::try_from_int(lhs)
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
.into(),
FieldElement(rhs), FieldElement(rhs),
)), )),
(FieldElement(lhs), Int(rhs)) => Ok(( (FieldElement(lhs), Int(rhs)) => Ok((
FieldElement(lhs), FieldElement(lhs.clone()),
FieldElementExpression::try_from_int(rhs)?.into(), FieldElementExpression::try_from_int(rhs)
.map_err(|rhs| (lhs.into(), rhs.into()))?
.into(),
)), )),
(Int(lhs), Uint(rhs)) => Ok(( (Int(lhs), Uint(rhs)) => Ok((
UExpression::try_from_int(lhs, rhs.bitwidth())?.into(), UExpression::try_from_int(lhs, rhs.bitwidth())
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
.into(),
Uint(rhs), Uint(rhs),
)), )),
(Uint(lhs), Int(rhs)) => { (Uint(lhs), Int(rhs)) => {
let bitwidth = lhs.bitwidth(); let bitwidth = lhs.bitwidth();
Ok((Uint(lhs), UExpression::try_from_int(rhs, bitwidth)?.into())) Ok((
Uint(lhs.clone()),
UExpression::try_from_int(rhs, bitwidth)
.map_err(|rhs| (lhs.into(), rhs.into()))?
.into(),
))
} }
(lhs, rhs) => Ok((lhs, rhs)), (Array(lhs), Array(rhs)) => {
if lhs.get_type() == rhs.get_type() {
Ok((Array(lhs), Array(rhs)))
} else {
Err((Array(lhs), Array(rhs)))
}
}
(Struct(lhs), Struct(rhs)) => {
if lhs.get_type() == rhs.get_type() {
Ok((Struct(lhs).into(), Struct(rhs).into()))
} else {
Err((Struct(lhs).into(), Struct(rhs).into()))
}
}
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
(FieldElement(lhs), FieldElement(rhs)) => Ok((lhs.into(), rhs.into())),
(Int(lhs), Int(rhs)) => Ok((lhs.into(), rhs.into())),
(lhs, rhs) => Err((lhs, rhs)),
} }
} }
pub fn align_to_type(e: Self, ty: Type) -> Result<Self, String> { pub fn align_to_type(e: Self, ty: Type) -> Result<Self, (Self, Type)> {
use self::TypedExpression::*; match ty.clone() {
match (e, ty) { Type::FieldElement => {
(Int(e), Type::FieldElement) => Ok(FieldElementExpression::try_from_int(e)?.into()), FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
(Int(e), Type::Uint(bitwidth)) => Ok(UExpression::try_from_int(e, bitwidth)?.into()), }
(Array(a), Type::Array(array_ty)) => Ok(ArrayExpression::try_from_int(a, array_ty) Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
.map_err(|_| String::from("align array to type"))? Type::Uint(bitwidth) => {
.into()), UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
(e, _) => Ok(e), }
Type::Array(array_ty) => {
ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
}
Type::Struct(struct_ty) => {
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
}
Type::Int => Err(e),
} }
.map_err(|e| (e, ty))
} }
} }
@ -478,7 +522,10 @@ pub enum IntExpression<'ast, T> {
Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>,
Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>,
), ),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
Xor(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>), Xor(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
And(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>), And(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
Or(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>), Or(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
@ -734,7 +781,10 @@ pub enum FieldElementExpression<'ast, T> {
), ),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>), FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Member(Box<StructExpression<'ast, T>>, MemberId), Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
} }
impl<'ast, T: Field> BooleanExpression<'ast, T> { impl<'ast, T: Field> BooleanExpression<'ast, T> {
@ -746,6 +796,12 @@ impl<'ast, T: Field> BooleanExpression<'ast, T> {
} }
} }
impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
fn from(n: T) -> Self {
FieldElementExpression::Number(n)
}
}
impl<'ast, T: Field> FieldElementExpression<'ast, T> { impl<'ast, T: Field> FieldElementExpression<'ast, T> {
pub fn try_from_typed(e: TypedExpression<'ast, T>) -> Result<Self, TypedExpression<'ast, T>> { pub fn try_from_typed(e: TypedExpression<'ast, T>) -> Result<Self, TypedExpression<'ast, T>> {
match e { match e {
@ -757,13 +813,13 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
} }
} }
pub fn try_from_int(i: IntExpression<'ast, T>) -> Result<Self, String> { pub fn try_from_int(i: IntExpression<'ast, T>) -> Result<Self, IntExpression<'ast, T>> {
match i { match i {
IntExpression::Value(i) => { IntExpression::Value(i) => {
if i <= T::max_value().to_biguint() { if i <= T::max_value().to_biguint() {
Ok(Self::Number(T::from(i))) Ok(Self::Number(T::from(i)))
} else { } else {
Err(format!("Literal `{} is too large for type `field`", i)) Err(IntExpression::Value(i))
} }
} }
IntExpression::Add(box e1, box e2) => Ok(Self::Add( IntExpression::Add(box e1, box e2) => Ok(Self::Add(
@ -794,7 +850,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
)) ))
} }
IntExpression::Select(..) => unimplemented!(), IntExpression::Select(..) => unimplemented!(),
i => Err(format!("Expected a `field` but found expression `{}`", i)), i => Err(i),
} }
} }
} }
@ -804,26 +860,22 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
pub enum BooleanExpression<'ast, T> { pub enum BooleanExpression<'ast, T> {
Identifier(Identifier<'ast>), Identifier(Identifier<'ast>),
Value(bool), Value(bool),
FieldLt( Lt(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
FieldLe( Le(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
FieldGe( Ge(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
FieldGt( Gt(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
FieldEq( FieldEq(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
@ -854,7 +906,10 @@ pub enum BooleanExpression<'ast, T> {
), ),
Member(Box<StructExpression<'ast, T>>, MemberId), Member(Box<StructExpression<'ast, T>>, MemberId),
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>), FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
} }
/// An expression of type `array` /// An expression of type `array`
@ -880,7 +935,10 @@ pub enum ArrayExpressionInner<'ast, T> {
Box<ArrayExpression<'ast, T>>, Box<ArrayExpression<'ast, T>>,
), ),
Member(Box<StructExpression<'ast, T>>, MemberId), Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
} }
impl<'ast, T> ArrayExpressionInner<'ast, T> { impl<'ast, T> ArrayExpressionInner<'ast, T> {
@ -931,13 +989,16 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
} }
// precondition: `array` is only made of inline arrays // precondition: `array` is only made of inline arrays
pub fn try_from_int(array: Self, target_array_ty: ArrayType) -> Result<Self, ()> { pub fn try_from_int(
if array.get_array_type() == target_array_ty { array: Self,
target_array_ty: ArrayType,
) -> Result<Self, TypedExpression<'ast, T>> {
let array_ty = array.get_array_type();
if array_ty == target_array_ty {
return Ok(array); return Ok(array);
} }
let array_ty = array.get_array_type();
// sizes must be equal // sizes must be equal
match target_array_ty.size == array_ty.size { match target_array_ty.size == array_ty.size {
true => true =>
@ -946,56 +1007,67 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
match array.into_inner() { match array.into_inner() {
ArrayExpressionInner::Value(inline_array) => { ArrayExpressionInner::Value(inline_array) => {
match *target_array_ty.ty { match *target_array_ty.ty {
Type::Int => Ok(ArrayExpressionInner::Value(inline_array)
.annotate(*target_array_ty.ty, array_ty.size)),
Type::FieldElement => { Type::FieldElement => {
// try to convert all elements to field // try to convert all elements to field
let converted = inline_array let converted = inline_array
.into_iter() .into_iter()
.map(|e| { .map(|e| {
let int = IntExpression::try_from(e)?; FieldElementExpression::try_from_typed(e)
let field = FieldElementExpression::try_from_int(int) .map(TypedExpression::from)
.map_err(|_| ())?;
Ok(field.into())
}) })
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?; .collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted) Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty, array_ty.size)) .annotate(*target_array_ty.ty, array_ty.size))
} }
Type::Uint(bitwidth) => { Type::Uint(bitwidth) => {
// try to convert all elements to field // try to convert all elements to uint
let converted = inline_array let converted = inline_array
.into_iter() .into_iter()
.map(|e| { .map(|e| {
let int = IntExpression::try_from(e)?; UExpression::try_from_typed(e, bitwidth)
let field = UExpression::try_from_int(int, bitwidth) .map(TypedExpression::from)
.map_err(|_| ())?;
Ok(field.into())
}) })
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?; .collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted) Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty, array_ty.size)) .annotate(*target_array_ty.ty, array_ty.size))
} }
Type::Array(ref array_ty) => { Type::Array(ref inner_array_ty) => {
// try to convert all elements to field // try to convert all elements to uint
let converted = inline_array let converted = inline_array
.into_iter() .into_iter()
.map(|e| { .map(|e| {
let array = ArrayExpression::try_from(e)?; ArrayExpression::try_from_typed(e, inner_array_ty.clone())
let array = .map(TypedExpression::from)
ArrayExpression::try_from_int(array, array_ty.clone())
.map_err(|_| ())?;
Ok(array.into())
}) })
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?; .collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted) Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty.clone(), array_ty.size.clone())) .annotate(*target_array_ty.ty, array_ty.size))
} }
_ => Err(()), Type::Struct(ref struct_ty) => {
// try to convert all elements to uint
let converted = inline_array
.into_iter()
.map(|e| {
StructExpression::try_from_typed(e, struct_ty.clone())
.map(TypedExpression::from)
})
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
.map_err(TypedExpression::from)?;
Ok(ArrayExpressionInner::Value(converted)
.annotate(*target_array_ty.ty, array_ty.size))
}
Type::Boolean => unreachable!(),
} }
} }
_ => unreachable!(""), _ => unreachable!(""),
} }
} }
false => Err(()), false => unreachable!(),
} }
} }
} }
@ -1047,7 +1119,10 @@ pub enum StructExpressionInner<'ast, T> {
Box<StructExpression<'ast, T>>, Box<StructExpression<'ast, T>>,
), ),
Member(Box<StructExpression<'ast, T>>, MemberId), Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
} }
impl<'ast, T> StructExpressionInner<'ast, T> { impl<'ast, T> StructExpressionInner<'ast, T> {
@ -1165,14 +1240,10 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
BooleanExpression::Identifier(ref var) => write!(f, "{}", var), BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
@ -1247,30 +1318,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
"IfElse({:?}, {:?}, {:?})", "IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative condition, consequent, alternative
), ),
BooleanExpression::FieldLt(ref lhs, ref rhs) => { BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs),
write!(f, "FieldLt({:?}, {:?})", lhs, rhs) BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs),
} BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => { BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
write!(f, "FieldLe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::FieldGe(ref lhs, ref rhs) => {
write!(f, "FieldGe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::FieldGt(ref lhs, ref rhs) => {
write!(f, "FieldGt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintLt(ref lhs, ref rhs) => {
write!(f, "UintLt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintLe(ref lhs, ref rhs) => {
write!(f, "UintLe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintGe(ref lhs, ref rhs) => {
write!(f, "UintGe({:?}, {:?})", lhs, rhs)
}
BooleanExpression::UintGt(ref lhs, ref rhs) => {
write!(f, "UintGt({:?}, {:?})", lhs, rhs)
}
BooleanExpression::FieldEq(ref lhs, ref rhs) => { BooleanExpression::FieldEq(ref lhs, ref rhs) => {
write!(f, "FieldEq({:?}, {:?})", lhs, rhs) write!(f, "FieldEq({:?}, {:?})", lhs, rhs)
} }
@ -1483,23 +1534,23 @@ impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> {
} }
pub trait Select<'ast, T> { pub trait Select<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self; fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
} }
impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> { impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
FieldElementExpression::Select(box array, box index) FieldElementExpression::Select(box array, box index)
} }
} }
impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> { impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
BooleanExpression::Select(box array, box index) BooleanExpression::Select(box array, box index)
} }
} }
impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let bitwidth = match array.inner_type().clone() { let bitwidth = match array.inner_type().clone() {
Type::Uint(bitwidth) => bitwidth, Type::Uint(bitwidth) => bitwidth,
_ => unreachable!(), _ => unreachable!(),
@ -1510,7 +1561,7 @@ impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> {
} }
impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> { impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let (ty, size) = match array.inner_type() { let (ty, size) = match array.inner_type() {
Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()), Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()),
_ => unreachable!(), _ => unreachable!(),
@ -1521,7 +1572,7 @@ impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> {
} }
impl<'ast, T: Clone> Select<'ast, T> for StructExpression<'ast, T> { impl<'ast, T: Clone> Select<'ast, T> for StructExpression<'ast, T> {
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
let members = match array.inner_type().clone() { let members = match array.inner_type().clone() {
Type::Struct(members) => members, Type::Struct(members) => members,
_ => unreachable!(), _ => unreachable!(),

View file

@ -190,6 +190,20 @@ impl Type {
Type::Uint(b.into()) Type::Uint(b.into())
} }
pub fn can_be_specialized_to(&self, other: &Type) -> bool {
use self::Type::*;
if self == other {
true
} else {
match (self, other) {
(Int, FieldElement) | (Int, Uint(..)) => true,
(Array(l), Array(r)) => l.size == r.size && l.ty.can_be_specialized_to(&r.ty),
_ => false,
}
}
}
fn to_slug(&self) -> String { fn to_slug(&self) -> String {
match self { match self {
Type::FieldElement => String::from("f"), Type::FieldElement => String::from("f"),

View file

@ -72,7 +72,10 @@ impl<'ast, T: Field> UExpression<'ast, T> {
} }
} }
pub fn try_from_int(i: IntExpression<'ast, T>, bitwidth: UBitwidth) -> Result<Self, String> { pub fn try_from_int(
i: IntExpression<'ast, T>,
bitwidth: UBitwidth,
) -> Result<Self, IntExpression<'ast, T>> {
use self::IntExpression::*; use self::IntExpression::*;
match i { match i {
@ -83,10 +86,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
) )
.annotate(bitwidth)) .annotate(bitwidth))
} else { } else {
Err(format!( Err(Value(i))
"Literal `{}` is too large for type u{}",
i, bitwidth
))
} }
} }
Add(box e1, box e2) => Ok(UExpression::add( Add(box e1, box e2) => Ok(UExpression::add(
@ -127,10 +127,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
Self::try_from_int(alternative, bitwidth)?, Self::try_from_int(alternative, bitwidth)?,
)), )),
Select(..) => unimplemented!(), Select(..) => unimplemented!(),
i => Err(format!( i => Err(i),
"Expected a `u{}` but found expression `{}`",
bitwidth, i
)),
} }
} }
} }
@ -192,7 +189,10 @@ pub enum UExpressionInner<'ast, T> {
Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>,
), ),
Member(Box<StructExpression<'ast, T>>, MemberId), Member(Box<StructExpression<'ast, T>>, MemberId),
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>), Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
} }
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {

View file

@ -204,45 +204,25 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_uint_expression(e2); let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2) BooleanExpression::UintEq(box e1, box e2)
} }
BooleanExpression::FieldLt(box e1, box e2) => { BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLt(box e1, box e2) BooleanExpression::Lt(box e1, box e2)
} }
BooleanExpression::FieldLe(box e1, box e2) => { BooleanExpression::Le(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2) BooleanExpression::Le(box e1, box e2)
} }
BooleanExpression::FieldGt(box e1, box e2) => { BooleanExpression::Gt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGt(box e1, box e2) BooleanExpression::Gt(box e1, box e2)
} }
BooleanExpression::FieldGe(box e1, box e2) => { BooleanExpression::Ge(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGe(box e1, box e2) BooleanExpression::Ge(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGe(box e1, box e2)
} }
BooleanExpression::Or(box e1, box e2) => { BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1); let e1 = f.fold_boolean_expression(e1);

View file

@ -381,26 +381,22 @@ pub enum FieldElementExpression<'ast, T> {
pub enum BooleanExpression<'ast, T> { pub enum BooleanExpression<'ast, T> {
Identifier(Identifier<'ast>), Identifier(Identifier<'ast>),
Value(bool), Value(bool),
FieldLt( Lt(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
FieldLe( Le(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
FieldGe( Ge(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
FieldGt( Gt(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
), ),
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
FieldEq( FieldEq(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
@ -510,14 +506,10 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
BooleanExpression::Identifier(ref var) => write!(f, "{}", var), BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),

View file

@ -4,12 +4,12 @@ def main():
assert(7 == 2 ** 2 * 2 - 1) assert(7 == 2 ** 2 * 2 - 1)
assert(3 == 2 ** 2 / 2 + 1) assert(3 == 2 ** 2 / 2 + 1)
field a = if 3 == 2 ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators field a = if 3f == 2f ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators
field b = if 3 == 3 && 4 < 5 then 1 else 0 fi // checks precedence of boolean operators field b = if 3f == 3f && 4f < 5f then 1 else 0 fi // checks precedence of boolean operators
field c = if 4 < 5 && 3 == 3 then 1 else 0 fi field c = if 4f < 5f && 3f == 3f then 1 else 0 fi
field d = if 4 > 5 && 2 >= 1 || 1 == 1 then 1 else 0 fi field d = if 4f > 5f && 2f >= 1f || 1f == 1f then 1 else 0 fi
field e = if 2 >= 1 && 4 > 5 || 1 == 1 then 1 else 0 fi field e = if 2f >= 1f && 4f > 5f || 1f == 1f then 1 else 0 fi
field f = if 1 < 2 && false || 4 < 5 && 2 >= 1 then 1 else 0 fi field f = if 1f < 2f && false || 4f < 5f && 2f >= 1f then 1 else 0 fi
assert(0x00 ^ 0x00 == 0x00) assert(0x00 ^ 0x00 == 0x00)

View file

@ -101,7 +101,9 @@ mod tests {
term(43, 44, [ term(43, 44, [
primary_expression(43, 44, [ primary_expression(43, 44, [
literal(43, 44, [ literal(43, 44, [
decimal_number(43, 44) decimal_literal(43, 44, [
decimal_number(43, 44)
])
]) ])
]) ])
]) ])
@ -258,7 +260,9 @@ mod tests {
term(30, 31, [ term(30, 31, [
primary_expression(30, 31, [ primary_expression(30, 31, [
literal(30, 31, [ literal(30, 31, [
decimal_number(30, 31) decimal_literal(30, 31, [
decimal_number(30, 31)
])
]) ])
]) ])
]) ])

View file

@ -4,7 +4,6 @@
// Note: parameters will be updated soon to be more compatible with zCash's implementation // Note: parameters will be updated soon to be more compatible with zCash's implementation
struct BabyJubJubParams { struct BabyJubJubParams {
field JUBJUBE
field JUBJUBC field JUBJUBC
field JUBJUBA field JUBJUBA
field JUBJUBD field JUBJUBD
@ -18,7 +17,6 @@ struct BabyJubJubParams {
def main() -> BabyJubJubParams: def main() -> BabyJubJubParams:
// Order of the curve E // Order of the curve E
field JUBJUBE = 21888242871839275222246405745257275088614511777268538073601725287587578984328
field JUBJUBC = 8 // Cofactor field JUBJUBC = 8 // Cofactor
field JUBJUBA = 168700 // Coefficient A field JUBJUBA = 168700 // Coefficient A
field JUBJUBD = 168696 // Coefficient D field JUBJUBD = 168696 // Coefficient D
@ -40,7 +38,6 @@ return BabyJubJubParams {
INFINITY: INFINITY, INFINITY: INFINITY,
Gu: Gu, Gu: Gu,
Gv: Gv, Gv: Gv,
JUBJUBE: JUBJUBE,
JUBJUBC: JUBJUBC, JUBJUBC: JUBJUBC,
MONTA: MONTA, MONTA: MONTA,
MONTB: MONTB MONTB: MONTB

View file

@ -1,9 +1,6 @@
import "hashes/mimc7/mimc7R10" import "hashes/mimc7/mimc7R10"
def main(): def main():
assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289) assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289f)
assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145) assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145f)
assert(mimc7R10(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 2977550761518141183167168643824354554080911485709001361112529600968315693145)
assert(mimc7R10(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 11476724043755138071320043459606423473319855817296339514744600646762741571430)
assert(mimc7R10(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 6004544488495356385698286530147974336054653445122716140990101827963729149289)
return return

View file

@ -1,9 +1,6 @@
import "hashes/mimc7/mimc7R20" import "hashes/mimc7/mimc7R20"
def main(): def main():
assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829) assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829f)
assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758) assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758f)
assert(mimc7R20(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 8623418512398828792274158979964869393034224267928014534933203776818702139758)
assert(mimc7R20(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 15315177265066649795408805007175121550344555424263995530745989936206840798041)
assert(mimc7R20(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 19139739902058628561064841933381604453445216873412991992755775746150759284829)
return return

View file

@ -1,9 +1,6 @@
import "hashes/mimc7/mimc7R50" import "hashes/mimc7/mimc7R50"
def main(): def main():
assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173) assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173f)
assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789) assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789f)
assert(mimc7R50(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 18511388995652647480418174218630545482006454713617579894396683237092568946789)
assert(mimc7R50(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 9149577627043020462780389988155990926223727917856424056384664564191878439702)
assert(mimc7R50(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 3049953358280347916081509186284461274525472221619157672645224540758481713173)
return return

View file

@ -2,8 +2,5 @@ import "hashes/mimc7/mimc7R90"
def main(): def main():
assert(mimc7R90(0, 0) == 20281265111705407344053532742843085357648991805359414661661476832595822221514) assert(mimc7R90(0, 0) == 20281265111705407344053532742843085357648991805359414661661476832595822221514)
assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636) assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636f)
assert(mimc7R90(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 1010054095264022068840870550831559811104631937745987065544478027572003292636)
assert(mimc7R90(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 8189519586469873426687580455476035992041353456517724932462363814215190642760)
assert(mimc7R90(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 20281265111705407344053532742843085357648991805359414661661476832595822221514)
return return

View file

@ -1,7 +1,6 @@
import "hashes/mimcSponge/mimcSponge" as mimcSponge import "hashes/mimcSponge/mimcSponge" as mimcSponge
def main(): def main():
assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097]) assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097f])
assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279]) assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279f])
assert(mimcSponge([21888242871839275222246405745257275088548364400416034343698204186575808495617, 0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279])
return return