make tests pass
This commit is contained in:
parent
69e682f065
commit
be77844e69
25 changed files with 840 additions and 1024 deletions
|
@ -1,3 +1,3 @@
|
|||
def main() -> ():
|
||||
assert(1 == 2)
|
||||
assert(1f == 2f)
|
||||
return
|
|
@ -14,6 +14,6 @@ def main(field a) -> field:
|
|||
assert(2 * b == a * 12 + 60)
|
||||
field c = 7 * (b + a)
|
||||
assert(isEqual(c, 7 * b + 7 * a))
|
||||
field k = if [1, 2] == [3, 4] then 1 else 3 fi
|
||||
field k = if [1f, 2] == [3f, 4] then 1 else 3 fi
|
||||
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
|
||||
return b + c
|
|
@ -722,10 +722,9 @@ mod tests {
|
|||
arguments: vec![],
|
||||
statements: vec![absy::Statement::Return(
|
||||
absy::ExpressionList {
|
||||
expressions: vec![absy::Expression::FieldConstant(
|
||||
Bn128Field::from(42),
|
||||
)
|
||||
.into()],
|
||||
expressions: vec![
|
||||
absy::Expression::IntConstant(42usize.into()).into()
|
||||
],
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
|
@ -803,10 +802,9 @@ mod tests {
|
|||
],
|
||||
statements: vec![absy::Statement::Return(
|
||||
absy::ExpressionList {
|
||||
expressions: vec![absy::Expression::FieldConstant(
|
||||
Bn128Field::from(42),
|
||||
)
|
||||
.into()],
|
||||
expressions: vec![
|
||||
absy::Expression::IntConstant(42usize.into()).into()
|
||||
],
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
|
@ -870,7 +868,7 @@ mod tests {
|
|||
"field[2]",
|
||||
absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::FieldElement.mock(),
|
||||
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(),
|
||||
absy::Expression::IntConstant(2usize.into()).mock(),
|
||||
),
|
||||
),
|
||||
(
|
||||
|
@ -878,10 +876,10 @@ mod tests {
|
|||
absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::FieldElement.mock(),
|
||||
absy::Expression::FieldConstant(Bn128Field::from(3)).mock(),
|
||||
absy::Expression::IntConstant(3usize.into()).mock(),
|
||||
)
|
||||
.mock(),
|
||||
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(),
|
||||
absy::Expression::IntConstant(2usize.into()).mock(),
|
||||
),
|
||||
),
|
||||
(
|
||||
|
@ -889,10 +887,10 @@ mod tests {
|
|||
absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::Boolean.mock(),
|
||||
absy::Expression::U32Constant(3).mock(),
|
||||
absy::Expression::U32Constant(3u32).mock(),
|
||||
)
|
||||
.mock(),
|
||||
absy::Expression::FieldConstant(Bn128Field::from(2)).mock(),
|
||||
absy::Expression::IntConstant(2usize.into()).mock(),
|
||||
),
|
||||
),
|
||||
];
|
||||
|
@ -943,7 +941,7 @@ mod tests {
|
|||
absy::Expression::Select(
|
||||
box absy::Expression::Identifier("a").into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(Bn128Field::from(3)).into(),
|
||||
absy::Expression::IntConstant(3usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -954,13 +952,13 @@ mod tests {
|
|||
box absy::Expression::Select(
|
||||
box absy::Expression::Identifier("a").into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(Bn128Field::from(3)).into(),
|
||||
absy::Expression::IntConstant(3usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(Bn128Field::from(4)).into(),
|
||||
absy::Expression::IntConstant(4usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -970,11 +968,11 @@ mod tests {
|
|||
absy::Expression::Select(
|
||||
box absy::Expression::FunctionCall(
|
||||
"a",
|
||||
vec![absy::Expression::FieldConstant(Bn128Field::from(3)).into()],
|
||||
vec![absy::Expression::IntConstant(3usize.into()).into()],
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(Bn128Field::from(4)).into(),
|
||||
absy::Expression::IntConstant(4usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -985,17 +983,17 @@ mod tests {
|
|||
box absy::Expression::Select(
|
||||
box absy::Expression::FunctionCall(
|
||||
"a",
|
||||
vec![absy::Expression::FieldConstant(Bn128Field::from(3)).into()],
|
||||
vec![absy::Expression::IntConstant(3usize.into()).into()],
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(Bn128Field::from(4)).into(),
|
||||
absy::Expression::IntConstant(4usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(Bn128Field::from(5)).into(),
|
||||
absy::Expression::IntConstant(5usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
|
|
@ -377,7 +377,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
BooleanExpression::Identifier(x) => {
|
||||
FlatExpression::Identifier(self.layout.get(&x).unwrap().clone())
|
||||
}
|
||||
BooleanExpression::FieldLt(box lhs, box rhs) => {
|
||||
BooleanExpression::Lt(box lhs, box rhs) => {
|
||||
// Get the bit width to know the size of the binary decompositions for this Field
|
||||
let bit_width = T::get_required_bits();
|
||||
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
|
||||
|
@ -660,11 +660,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
res
|
||||
}
|
||||
BooleanExpression::FieldLe(box lhs, box rhs) => {
|
||||
BooleanExpression::Le(box lhs, box rhs) => {
|
||||
let lt = self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLt(box lhs.clone(), box rhs.clone()),
|
||||
BooleanExpression::Lt(box lhs.clone(), box rhs.clone()),
|
||||
);
|
||||
let eq = self.flatten_boolean_expression(
|
||||
symbols,
|
||||
|
@ -673,110 +673,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
FlatExpression::Add(box eq, box lt)
|
||||
}
|
||||
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
BooleanExpression::Gt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLt(rhs, lhs),
|
||||
BooleanExpression::Lt(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression(
|
||||
BooleanExpression::Ge(lhs, rhs) => self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLe(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::UintLt(box lhs, box rhs) => {
|
||||
let lhs_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, lhs);
|
||||
let rhs_flattened =
|
||||
self.flatten_uint_expression(symbols, statements_flattened, rhs);
|
||||
|
||||
// Get the bit width to know the size of the binary decompositions for this Field
|
||||
// This is not this uint bitwidth
|
||||
let bit_width = T::get_required_bits();
|
||||
|
||||
// lhs
|
||||
let lhs_id = self.define(lhs_flattened.get_field_unchecked(), statements_flattened);
|
||||
let rhs_id = self.define(rhs_flattened.get_field_unchecked(), statements_flattened);
|
||||
|
||||
// sym := (lhs * 2) - (rhs * 2)
|
||||
let subtraction_result = FlatExpression::Sub(
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(lhs_id),
|
||||
),
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(rhs_id),
|
||||
),
|
||||
);
|
||||
|
||||
// define variables for the bits
|
||||
let sub_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(sub_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.strict_le_check(
|
||||
statements_flattened,
|
||||
&T::max_value_bit_vector_be(),
|
||||
sub_bits_be.clone(),
|
||||
);
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
|
||||
|
||||
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
|
||||
}
|
||||
BooleanExpression::UintLe(box lhs, box rhs) => {
|
||||
let lt = self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLt(box lhs.clone(), box rhs.clone()),
|
||||
);
|
||||
let eq = self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::UintEq(box lhs.clone(), box rhs.clone()),
|
||||
);
|
||||
FlatExpression::Add(box eq, box lt)
|
||||
}
|
||||
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLt(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
|
||||
symbols,
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLe(rhs, lhs),
|
||||
BooleanExpression::Le(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::Or(box lhs, box rhs) => {
|
||||
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);
|
||||
|
@ -2464,7 +2369,7 @@ mod tests {
|
|||
#[test]
|
||||
fn geq_leq() {
|
||||
let mut flattener = Flattener::new();
|
||||
let expression_le = BooleanExpression::FieldLe(
|
||||
let expression_le = BooleanExpression::Le(
|
||||
box FieldElementExpression::Number(Bn128Field::from(32)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
|
@ -2475,7 +2380,7 @@ mod tests {
|
|||
);
|
||||
|
||||
let mut flattener = Flattener::new();
|
||||
let expression_ge = BooleanExpression::FieldGe(
|
||||
let expression_ge = BooleanExpression::Ge(
|
||||
box FieldElementExpression::Number(Bn128Field::from(32)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
|
@ -2496,7 +2401,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
),
|
||||
box BooleanExpression::FieldLt(
|
||||
box BooleanExpression::Lt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(20)),
|
||||
),
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,33 +1,32 @@
|
|||
use std::marker::PhantomData;
|
||||
use typed_absy;
|
||||
use typed_absy::types::UBitwidth;
|
||||
use typed_absy::types::{StructType, UBitwidth};
|
||||
use zir;
|
||||
use zokrates_field::Field;
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
pub struct Flattener<T: Field> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
fn flatten_identifier_rec<'ast>(
|
||||
id: zir::SourceIdentifier<'ast>,
|
||||
ty: typed_absy::types::Type,
|
||||
) -> Vec<zir::Variable<'ast>> {
|
||||
fn flatten_identifier_rec<'a>(
|
||||
id: zir::SourceIdentifier<'a>,
|
||||
ty: typed_absy::Type,
|
||||
) -> Vec<zir::Variable> {
|
||||
match ty {
|
||||
typed_absy::types::Type::FieldElement => vec![zir::Variable {
|
||||
typed_absy::Type::Int => unreachable!(),
|
||||
typed_absy::Type::FieldElement => vec![zir::Variable {
|
||||
id: zir::Identifier::Source(id),
|
||||
_type: zir::Type::FieldElement,
|
||||
}],
|
||||
typed_absy::types::Type::Boolean => vec![zir::Variable {
|
||||
typed_absy::Type::Boolean => vec![zir::Variable {
|
||||
id: zir::Identifier::Source(id),
|
||||
_type: zir::Type::Boolean,
|
||||
}],
|
||||
typed_absy::types::Type::Uint(bitwidth) => vec![zir::Variable {
|
||||
typed_absy::Type::Uint(bitwidth) => vec![zir::Variable {
|
||||
id: zir::Identifier::Source(id),
|
||||
_type: zir::Type::uint(bitwidth.to_usize()),
|
||||
}],
|
||||
typed_absy::types::Type::Array(array_type) => (0..array_type.size)
|
||||
typed_absy::Type::Array(array_type) => (0..array_type.size)
|
||||
.flat_map(|i| {
|
||||
flatten_identifier_rec(
|
||||
zir::SourceIdentifier::Select(box id.clone(), i),
|
||||
|
@ -35,7 +34,7 @@ fn flatten_identifier_rec<'ast>(
|
|||
)
|
||||
})
|
||||
.collect(),
|
||||
typed_absy::types::Type::Struct(members) => members
|
||||
typed_absy::Type::Struct(members) => members
|
||||
.into_iter()
|
||||
.flat_map(|struct_member| {
|
||||
flatten_identifier_rec(
|
||||
|
@ -44,7 +43,6 @@ fn flatten_identifier_rec<'ast>(
|
|||
)
|
||||
})
|
||||
.collect(),
|
||||
typed_absy::types::Type::Int => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,7 +78,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
|
||||
fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> {
|
||||
let private = p.private;
|
||||
self.fold_variable(p.id.try_into().unwrap())
|
||||
self.fold_variable(p.id)
|
||||
.into_iter()
|
||||
.map(|v| zir::Parameter { id: v, private })
|
||||
.collect()
|
||||
|
@ -94,8 +92,6 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
let id = self.fold_name(v.id.clone());
|
||||
let ty = v.get_type();
|
||||
|
||||
let ty = typed_absy::types::Type::try_from(ty).unwrap();
|
||||
|
||||
flatten_identifier_rec(id, ty)
|
||||
}
|
||||
|
||||
|
@ -152,8 +148,6 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
) -> zir::ZirExpressionList<'ast, T> {
|
||||
match es {
|
||||
typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => {
|
||||
let id = typed_absy::types::FunctionKey::try_from(id).unwrap();
|
||||
|
||||
zir::ZirExpressionList::FunctionCall(
|
||||
self.fold_function_key(id),
|
||||
arguments
|
||||
|
@ -202,7 +196,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &typed_absy::types::Type,
|
||||
ty: &typed_absy::Type,
|
||||
size: usize,
|
||||
e: typed_absy::ArrayExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
|
@ -210,7 +204,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
}
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &typed_absy::types::StructType,
|
||||
ty: &StructType,
|
||||
e: typed_absy::StructExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
fold_struct_expression_inner(self, ty, e)
|
||||
|
@ -225,12 +219,7 @@ pub fn fold_module<'ast, T: Field>(
|
|||
functions: p
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| {
|
||||
(
|
||||
f.fold_function_key(key.try_into().unwrap()),
|
||||
f.fold_function_symbol(fun),
|
||||
)
|
||||
})
|
||||
.map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun)))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
@ -280,16 +269,14 @@ pub fn fold_statement<'ast, T: Field>(
|
|||
|
||||
pub fn fold_array_expression_inner<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
t: &typed_absy::types::Type,
|
||||
t: &typed_absy::Type,
|
||||
size: usize,
|
||||
e: typed_absy::ArrayExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match e {
|
||||
typed_absy::ArrayExpressionInner::Identifier(id) => {
|
||||
let variables = flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
typed_absy::types::Type::array(t.clone(), size),
|
||||
);
|
||||
let variables =
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::array(t.clone(), size));
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|v| match v._type {
|
||||
|
@ -344,11 +331,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| {
|
||||
typed_absy::types::Type::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
// we also need the size of this member
|
||||
|
@ -358,15 +341,12 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => {
|
||||
let size = typed_absy::types::Type::try_from(t.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
* size;
|
||||
let start = i as usize * size;
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => {
|
||||
let size = t.get_primitive_count() * size;
|
||||
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
|
||||
let end = start + size;
|
||||
array[start..end].to_vec()
|
||||
}
|
||||
|
@ -378,13 +358,13 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
|||
|
||||
pub fn fold_struct_expression_inner<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
t: &typed_absy::types::StructType,
|
||||
t: &StructType,
|
||||
e: typed_absy::StructExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match e {
|
||||
typed_absy::StructExpressionInner::Identifier(id) => {
|
||||
let variables =
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::struc(t.clone()));
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::struc(t.clone()));
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|v| match v._type {
|
||||
|
@ -439,33 +419,30 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| {
|
||||
typed_absy::types::Type::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
// we also need the size of this member
|
||||
let size = typed_absy::types::Type::try_from(
|
||||
*t.iter().find(|member| member.id == id).cloned().unwrap().ty,
|
||||
)
|
||||
.unwrap()
|
||||
.get_primitive_count();
|
||||
let size = t
|
||||
.iter()
|
||||
.find(|member| member.id == id)
|
||||
.unwrap()
|
||||
.ty
|
||||
.get_primitive_count();
|
||||
|
||||
s[offset..offset + size].to_vec()
|
||||
}
|
||||
typed_absy::StructExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => {
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => {
|
||||
let size = t
|
||||
.iter()
|
||||
.map(|m| m.ty.get_primitive_count())
|
||||
.fold(0, |acc, current| acc + current);
|
||||
let start = i as usize * size;
|
||||
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
|
||||
let end = start + size;
|
||||
array[start..end].to_vec()
|
||||
}
|
||||
|
@ -483,7 +460,7 @@ pub fn fold_field_expression<'ast, T: Field>(
|
|||
typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
|
||||
typed_absy::FieldElementExpression::Identifier(id) => {
|
||||
zir::FieldElementExpression::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::FieldElement)[0]
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::FieldElement)[0]
|
||||
.id
|
||||
.clone(),
|
||||
)
|
||||
|
@ -528,22 +505,26 @@ pub fn fold_field_expression<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| {
|
||||
typed_absy::types::Type::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
s[offset].clone().try_into().unwrap()
|
||||
}
|
||||
typed_absy::FieldElementExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
|
||||
use std::convert::TryInto;
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => array
|
||||
[i.to_dec_string().parse::<usize>().unwrap()]
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
_ => unreachable!(""),
|
||||
}
|
||||
}
|
||||
|
@ -557,7 +538,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
match e {
|
||||
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
|
||||
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::Boolean)[0]
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Boolean)[0]
|
||||
.id
|
||||
.clone(),
|
||||
),
|
||||
|
@ -633,45 +614,25 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
|
||||
zir::BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::Lt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::FieldLt(box e1, box e2)
|
||||
zir::BooleanExpression::Lt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::Le(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::FieldLe(box e1, box e2)
|
||||
zir::BooleanExpression::Le(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::Gt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::FieldGt(box e1, box e2)
|
||||
zir::BooleanExpression::Gt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::Ge(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintGe(box e1, box e2)
|
||||
zir::BooleanExpression::Ge(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
|
@ -702,21 +663,25 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| {
|
||||
typed_absy::types::Type::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
s[offset].clone().try_into().unwrap()
|
||||
}
|
||||
typed_absy::BooleanExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
|
||||
use std::convert::TryInto;
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => array
|
||||
[i.to_dec_string().parse::<usize>().unwrap()]
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -739,7 +704,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
match e {
|
||||
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
|
||||
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::types::Type::Uint(bitwidth))[0]
|
||||
flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Uint(bitwidth))[0]
|
||||
.id
|
||||
.clone(),
|
||||
),
|
||||
|
@ -801,11 +766,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::UExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => {
|
||||
let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap();
|
||||
use std::convert::TryInto;
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => {
|
||||
let e: zir::UExpression<_> = array[i.to_dec_string().parse::<usize>().unwrap()]
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
|
@ -819,13 +789,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| {
|
||||
typed_absy::types::Type::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.sum();
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap();
|
||||
|
||||
res.into_inner()
|
||||
|
@ -854,9 +822,7 @@ pub fn fold_function<'ast, T: Field>(
|
|||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
signature: typed_absy::types::Signature::try_from(fun.signature)
|
||||
.unwrap()
|
||||
.into(),
|
||||
signature: fun.signature.into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -864,22 +830,14 @@ pub fn fold_array_expression<'ast, T: Field>(
|
|||
f: &mut Flattener<T>,
|
||||
e: typed_absy::ArrayExpression<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
let size = e.size();
|
||||
f.fold_array_expression_inner(
|
||||
&typed_absy::types::Type::try_from(e.inner_type().clone()).unwrap(),
|
||||
size,
|
||||
e.into_inner(),
|
||||
)
|
||||
f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner())
|
||||
}
|
||||
|
||||
pub fn fold_struct_expression<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
e: typed_absy::StructExpression<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
f.fold_struct_expression_inner(
|
||||
&typed_absy::types::StructType::try_from(e.ty().clone()).unwrap(),
|
||||
e.into_inner(),
|
||||
)
|
||||
f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner())
|
||||
}
|
||||
|
||||
pub fn fold_function_symbol<'ast, T: Field>(
|
||||
|
@ -890,10 +848,9 @@ pub fn fold_function_symbol<'ast, T: Field>(
|
|||
typed_absy::TypedFunctionSymbol::Here(fun) => {
|
||||
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
|
||||
}
|
||||
typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There(
|
||||
f.fold_function_key(typed_absy::types::FunctionKey::try_from(key).unwrap()),
|
||||
module,
|
||||
), // by default, do not fold modules recursively
|
||||
typed_absy::TypedFunctionSymbol::There(key, module) => {
|
||||
zir::ZirFunctionSymbol::There(f.fold_function_key(key), module)
|
||||
} // by default, do not fold modules recursively
|
||||
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,10 +67,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn detect_non_constant_bound() {
|
||||
let loops: Vec<TypedStatement<Bn128Field>> = vec![TypedStatement::For(
|
||||
let loops = vec![TypedStatement::For(
|
||||
Variable::field_element("i"),
|
||||
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
|
||||
2u32.into(),
|
||||
FieldElementExpression::Identifier("i".into()),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
vec![],
|
||||
)];
|
||||
|
||||
|
@ -118,12 +118,12 @@ mod tests {
|
|||
|
||||
let s = TypedStatement::For(
|
||||
Variable::field_element("i"),
|
||||
0u32.into(),
|
||||
2u32.into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
vec![TypedStatement::For(
|
||||
Variable::field_element("j"),
|
||||
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
|
||||
2u32.into(),
|
||||
FieldElementExpression::Identifier("i".into()),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
vec![
|
||||
TypedStatement::Declaration(Variable::field_element("foo")),
|
||||
TypedStatement::Definition(
|
||||
|
|
|
@ -121,8 +121,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
// we stop propagation here as constants maybe be modified inside the loop body
|
||||
// which we do not visit
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let from = self.fold_uint_expression(from);
|
||||
let to = self.fold_uint_expression(to);
|
||||
let from = self.fold_field_expression(from);
|
||||
let to = self.fold_field_expression(to);
|
||||
|
||||
// invalidate the constants map as any constant could be modified inside the loop body, which we don't visit
|
||||
self.constants.clear();
|
||||
|
@ -505,31 +505,33 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
UExpressionInner::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_uint_expression(index);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
|
||||
match (array.into_inner(), index.into_inner()) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
let n = n as usize;
|
||||
if n < size {
|
||||
UExpression::try_from(v[n].clone()).unwrap().into_inner()
|
||||
match (array.into_inner(), index) {
|
||||
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
|
||||
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
|
||||
if n_as_usize < size {
|
||||
UExpression::try_from(v[n_as_usize].clone())
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
} else {
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n, size
|
||||
n_as_usize, size
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
match self.constants.get(&TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array(
|
||||
id.clone(),
|
||||
inner_type.clone(),
|
||||
size,
|
||||
)),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n.clone()).into(),
|
||||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::Uint(e) => e.clone().into_inner(),
|
||||
|
@ -537,14 +539,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
},
|
||||
None => UExpressionInner::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n),
|
||||
),
|
||||
}
|
||||
}
|
||||
(a, i) => UExpressionInner::Select(
|
||||
box a.annotate(inner_type, size),
|
||||
box i.annotate(UBitwidth::B32),
|
||||
),
|
||||
(a, i) => UExpressionInner::Select(box a.annotate(inner_type, size), box i),
|
||||
}
|
||||
}
|
||||
UExpressionInner::FunctionCall(key, arguments) => {
|
||||
|
@ -648,46 +647,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
FieldElementExpression::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_uint_expression(index);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
|
||||
match (array.into_inner(), index.into_inner()) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
let n = n as usize;
|
||||
if n < size {
|
||||
FieldElementExpression::try_from(v[n].clone()).unwrap()
|
||||
match (array.into_inner(), index) {
|
||||
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
|
||||
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
|
||||
if n_as_usize < size {
|
||||
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
|
||||
} else {
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n, size
|
||||
n_as_usize, size
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
match self.constants.get(&TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array(
|
||||
id.clone(),
|
||||
inner_type.clone(),
|
||||
size,
|
||||
)),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n.clone()).into(),
|
||||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::FieldElement(e) => e.clone(),
|
||||
_ => unreachable!(""),
|
||||
_ => unreachable!("??"),
|
||||
},
|
||||
None => FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n),
|
||||
),
|
||||
}
|
||||
}
|
||||
(a, i) => FieldElementExpression::Select(
|
||||
box a.annotate(inner_type, size),
|
||||
box i.annotate(UBitwidth::B32),
|
||||
),
|
||||
(a, i) => {
|
||||
FieldElementExpression::Select(box a.annotate(inner_type, size), box i)
|
||||
}
|
||||
}
|
||||
}
|
||||
FieldElementExpression::Member(box s, m) => {
|
||||
|
@ -749,48 +747,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_uint_expression(index);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
|
||||
match (array.into_inner(), index.into_inner()) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
let n = n as usize;
|
||||
if n < size {
|
||||
ArrayExpression::try_from(v[n].clone())
|
||||
match (array.into_inner(), index) {
|
||||
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
|
||||
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
|
||||
if n_as_usize < size {
|
||||
ArrayExpression::try_from(v[n_as_usize].clone())
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
} else {
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n, size
|
||||
n_as_usize, size
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
match self.constants.get(&TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array(
|
||||
id.clone(),
|
||||
inner_type.clone(),
|
||||
size,
|
||||
)),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n.clone()).into(),
|
||||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::Array(e) => e.clone().into_inner(),
|
||||
_ => unreachable!(""),
|
||||
_ => unreachable!("should be an array"),
|
||||
},
|
||||
None => ArrayExpressionInner::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n),
|
||||
),
|
||||
}
|
||||
}
|
||||
(a, i) => ArrayExpressionInner::Select(
|
||||
box a.annotate(inner_type, size),
|
||||
box i.annotate(UBitwidth::B32),
|
||||
),
|
||||
(a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
|
||||
}
|
||||
}
|
||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
|
@ -864,48 +859,47 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
StructExpressionInner::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_uint_expression(index);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
|
||||
match (array.into_inner(), index.into_inner()) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
let n = n as usize;
|
||||
if n < size {
|
||||
StructExpression::try_from(v[n].clone())
|
||||
match (array.into_inner(), index) {
|
||||
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
|
||||
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
|
||||
if n_as_usize < size {
|
||||
StructExpression::try_from(v[n_as_usize].clone())
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
} else {
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n, size
|
||||
n_as_usize, size
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
match self.constants.get(&TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array(
|
||||
id.clone(),
|
||||
inner_type.clone(),
|
||||
size,
|
||||
)),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n.clone()).into(),
|
||||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::Struct(e) => e.clone().into_inner(),
|
||||
_ => unreachable!(""),
|
||||
_ => unreachable!("should be a struct"),
|
||||
},
|
||||
None => StructExpressionInner::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n),
|
||||
),
|
||||
}
|
||||
}
|
||||
(a, i) => StructExpressionInner::Select(
|
||||
box a.annotate(inner_type, size),
|
||||
box i.annotate(UBitwidth::B32),
|
||||
),
|
||||
(a, i) => {
|
||||
StructExpressionInner::Select(box a.annotate(inner_type, size), box i)
|
||||
}
|
||||
}
|
||||
}
|
||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
|
@ -998,7 +992,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(e1, e2) => BooleanExpression::BoolEq(box e1, box e2),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
BooleanExpression::Lt(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1);
|
||||
let e2 = self.fold_field_expression(e2);
|
||||
|
||||
|
@ -1006,10 +1000,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
BooleanExpression::Value(n1 < n2)
|
||||
}
|
||||
(e1, e2) => BooleanExpression::FieldLt(box e1, box e2),
|
||||
(e1, e2) => BooleanExpression::Lt(box e1, box e2),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
BooleanExpression::Le(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1);
|
||||
let e2 = self.fold_field_expression(e2);
|
||||
|
||||
|
@ -1017,10 +1011,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
BooleanExpression::Value(n1 <= n2)
|
||||
}
|
||||
(e1, e2) => BooleanExpression::FieldLe(box e1, box e2),
|
||||
(e1, e2) => BooleanExpression::Le(box e1, box e2),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
BooleanExpression::Gt(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1);
|
||||
let e2 = self.fold_field_expression(e2);
|
||||
|
||||
|
@ -1028,10 +1022,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
BooleanExpression::Value(n1 > n2)
|
||||
}
|
||||
(e1, e2) => BooleanExpression::FieldGt(box e1, box e2),
|
||||
(e1, e2) => BooleanExpression::Gt(box e1, box e2),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
BooleanExpression::Ge(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1);
|
||||
let e2 = self.fold_field_expression(e2);
|
||||
|
||||
|
@ -1039,7 +1033,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
BooleanExpression::Value(n1 >= n2)
|
||||
}
|
||||
(e1, e2) => BooleanExpression::FieldGe(box e1, box e2),
|
||||
(e1, e2) => BooleanExpression::Ge(box e1, box e2),
|
||||
}
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
|
@ -1098,46 +1092,43 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
BooleanExpression::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_uint_expression(index);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
|
||||
match (array.into_inner(), index.into_inner()) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
let n = n as usize;
|
||||
if n < size {
|
||||
BooleanExpression::try_from(v[n].clone()).unwrap()
|
||||
match (array.into_inner(), index) {
|
||||
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
|
||||
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
|
||||
if n_as_usize < size {
|
||||
BooleanExpression::try_from(v[n_as_usize].clone()).unwrap()
|
||||
} else {
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n, size
|
||||
n_as_usize, size
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
match self.constants.get(&TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array(
|
||||
id.clone(),
|
||||
inner_type.clone(),
|
||||
size,
|
||||
)),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n.clone()).into(),
|
||||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::Boolean(e) => e.clone(),
|
||||
_ => unreachable!(""),
|
||||
_ => unreachable!("Should be a boolean"),
|
||||
},
|
||||
None => BooleanExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(n),
|
||||
),
|
||||
}
|
||||
}
|
||||
(a, i) => BooleanExpression::Select(
|
||||
box a.annotate(inner_type, size),
|
||||
box i.annotate(UBitwidth::B32),
|
||||
),
|
||||
(a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i),
|
||||
}
|
||||
}
|
||||
BooleanExpression::Member(box s, m) => {
|
||||
|
@ -1291,7 +1282,10 @@ mod tests {
|
|||
FieldElementExpression::Number(Bn128Field::from(3)).into(),
|
||||
])
|
||||
.annotate(Type::FieldElement, 3),
|
||||
box UExpression::add(1u32.into(), 1u32.into()),
|
||||
box FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
|
@ -1397,12 +1391,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lt() {
|
||||
let e_true = BooleanExpression::FieldLt(
|
||||
let e_true = BooleanExpression::Lt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
|
||||
let e_false = BooleanExpression::FieldLt(
|
||||
let e_false = BooleanExpression::Lt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
);
|
||||
|
@ -1419,12 +1413,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn le() {
|
||||
let e_true = BooleanExpression::FieldLe(
|
||||
let e_true = BooleanExpression::Le(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
);
|
||||
|
||||
let e_false = BooleanExpression::FieldLe(
|
||||
let e_false = BooleanExpression::Le(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
);
|
||||
|
@ -1441,12 +1435,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn gt() {
|
||||
let e_true = BooleanExpression::FieldGt(
|
||||
let e_true = BooleanExpression::Gt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(5)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
|
||||
let e_false = BooleanExpression::FieldGt(
|
||||
let e_false = BooleanExpression::Gt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(5)),
|
||||
);
|
||||
|
@ -1463,12 +1457,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn ge() {
|
||||
let e_true = BooleanExpression::FieldGe(
|
||||
let e_true = BooleanExpression::Ge(
|
||||
box FieldElementExpression::Number(Bn128Field::from(5)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(5)),
|
||||
);
|
||||
|
||||
let e_false = BooleanExpression::FieldGe(
|
||||
let e_false = BooleanExpression::Ge(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(5)),
|
||||
);
|
||||
|
|
|
@ -23,6 +23,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
|
|||
.iter()
|
||||
.zip(ret_identifiers.iter())
|
||||
.map(|(e, i)| match e.get_type() {
|
||||
Type::Int => unreachable!(),
|
||||
Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(),
|
||||
Type::Boolean => BooleanExpression::Identifier(i.clone()).into(),
|
||||
Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone())
|
||||
|
@ -34,7 +35,6 @@ impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
|
|||
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
|
||||
.annotate(bitwidth)
|
||||
.into(),
|
||||
Type::Int => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
|
@ -85,10 +85,9 @@ impl<'ast> Unroller<'ast> {
|
|||
match head {
|
||||
Access::Select(head) => {
|
||||
statements.insert(TypedStatement::Assertion(
|
||||
BooleanExpression::UintLt(
|
||||
BooleanExpression::Lt(
|
||||
box head.clone(),
|
||||
box UExpressionInner::Value(size as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Number(T::from(size)),
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
|
@ -98,16 +97,14 @@ impl<'ast> Unroller<'ast> {
|
|||
.map(|i| match inner_ty {
|
||||
Type::Int => unreachable!(),
|
||||
Type::Array(..) => ArrayExpression::if_else(
|
||||
BooleanExpression::UintEq(
|
||||
box UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
box head.clone(),
|
||||
),
|
||||
match Self::choose_many(
|
||||
ArrayExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
tail.clone(),
|
||||
|
@ -122,22 +119,19 @@ impl<'ast> Unroller<'ast> {
|
|||
},
|
||||
ArrayExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
Type::Struct(..) => StructExpression::if_else(
|
||||
BooleanExpression::UintEq(
|
||||
box UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
box head.clone(),
|
||||
),
|
||||
match Self::choose_many(
|
||||
StructExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
tail.clone(),
|
||||
|
@ -152,22 +146,19 @@ impl<'ast> Unroller<'ast> {
|
|||
},
|
||||
StructExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
Type::FieldElement => FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(
|
||||
box UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
box head.clone(),
|
||||
),
|
||||
match Self::choose_many(
|
||||
FieldElementExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
tail.clone(),
|
||||
|
@ -182,22 +173,19 @@ impl<'ast> Unroller<'ast> {
|
|||
},
|
||||
FieldElementExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
Type::Boolean => BooleanExpression::if_else(
|
||||
BooleanExpression::UintEq(
|
||||
box UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
box head.clone(),
|
||||
),
|
||||
match Self::choose_many(
|
||||
BooleanExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
tail.clone(),
|
||||
|
@ -212,22 +200,19 @@ impl<'ast> Unroller<'ast> {
|
|||
},
|
||||
BooleanExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
Type::Uint(..) => UExpression::if_else(
|
||||
BooleanExpression::UintEq(
|
||||
box UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(T::from(i)),
|
||||
box head.clone(),
|
||||
),
|
||||
match Self::choose_many(
|
||||
UExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
)
|
||||
.into(),
|
||||
tail.clone(),
|
||||
|
@ -242,8 +227,7 @@ impl<'ast> Unroller<'ast> {
|
|||
},
|
||||
UExpression::select(
|
||||
base.clone(),
|
||||
UExpressionInner::Value(i as u128)
|
||||
.annotate(UBitwidth::B32),
|
||||
FieldElementExpression::Number(T::from(i)),
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
|
@ -376,7 +360,7 @@ impl<'ast> Unroller<'ast> {
|
|||
|
||||
#[derive(Clone, Debug)]
|
||||
enum Access<'ast, T: Field> {
|
||||
Select(UExpression<'ast, T>),
|
||||
Select(FieldElementExpression<'ast, T>),
|
||||
Member(MemberId),
|
||||
}
|
||||
/// Turn an assignee into its representation as a base variable and a list accesses
|
||||
|
@ -437,7 +421,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
let indices = indices
|
||||
.into_iter()
|
||||
.map(|a| match a {
|
||||
Access::Select(i) => Access::Select(self.fold_uint_expression(i)),
|
||||
Access::Select(i) => Access::Select(self.fold_field_expression(i)),
|
||||
a => a,
|
||||
})
|
||||
.collect();
|
||||
|
@ -463,16 +447,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
vec![TypedStatement::MultipleDefinition(variables, exprs)]
|
||||
}
|
||||
TypedStatement::For(v, from, to, stats) => {
|
||||
let from = self.fold_uint_expression(from);
|
||||
let to = self.fold_uint_expression(to);
|
||||
let from = self.fold_field_expression(from);
|
||||
let to = self.fold_field_expression(to);
|
||||
|
||||
match (from.into_inner(), to.into_inner()) {
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => {
|
||||
let mut values = vec![];
|
||||
match (from, to) {
|
||||
(FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => {
|
||||
let mut values: Vec<T> = vec![];
|
||||
let mut current = from;
|
||||
while current < to {
|
||||
values.push(current.clone());
|
||||
current = 1 + ¤t;
|
||||
current = T::one() + ¤t;
|
||||
}
|
||||
|
||||
let res = values
|
||||
|
@ -483,9 +467,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
TypedStatement::Declaration(v.clone()),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(v.clone()),
|
||||
UExpressionInner::Value(index)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
FieldElementExpression::Number(index).into(),
|
||||
),
|
||||
],
|
||||
stats.clone(),
|
||||
|
@ -501,12 +483,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
}
|
||||
(from, to) => {
|
||||
self.complete = false;
|
||||
vec![TypedStatement::For(
|
||||
v,
|
||||
from.annotate(UBitwidth::B32),
|
||||
to.annotate(UBitwidth::B32),
|
||||
stats,
|
||||
)]
|
||||
vec![TypedStatement::For(v, from, to, stats)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -538,11 +515,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn ssa_array() {
|
||||
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3usize);
|
||||
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
|
||||
|
||||
let e = FieldElementExpression::Number(Bn128Field::from(42)).into();
|
||||
|
||||
let index = 1u32.into();
|
||||
let index = FieldElementExpression::Number(Bn128Field::from(1));
|
||||
|
||||
let a1 = Unroller::choose_many(
|
||||
a0.clone().into(),
|
||||
|
@ -558,34 +535,52 @@ mod tests {
|
|||
a1,
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
FieldElementExpression::select(a0.clone(), 0u32.into(),)
|
||||
FieldElementExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
FieldElementExpression::select(a0.clone(), 1u32.into())
|
||||
FieldElementExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 2u32.into(), box 1u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
FieldElementExpression::select(a0.clone(), 2u32.into())
|
||||
FieldElementExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2))
|
||||
)
|
||||
)
|
||||
.into()
|
||||
])
|
||||
.annotate(Type::FieldElement, 3usize)
|
||||
.annotate(Type::FieldElement, 3)
|
||||
.into()
|
||||
);
|
||||
|
||||
let a0: ArrayExpression<Bn128Field> = ArrayExpressionInner::Identifier("a".into())
|
||||
let a0 = ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::array(Type::FieldElement, 3), 3);
|
||||
|
||||
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
|
||||
|
||||
let index = 1u32.into();
|
||||
let index = FieldElementExpression::Number(Bn128Field::from(1));
|
||||
|
||||
let a1 = Unroller::choose_many(
|
||||
a0.clone().into(),
|
||||
|
@ -601,21 +596,39 @@ mod tests {
|
|||
a1,
|
||||
ArrayExpressionInner::Value(vec![
|
||||
ArrayExpression::if_else(
|
||||
BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
e.clone(),
|
||||
ArrayExpression::select(a0.clone(), 0u32.into())
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::if_else(
|
||||
BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
e.clone(),
|
||||
ArrayExpression::select(a0.clone(), 1u32.into())
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::if_else(
|
||||
BooleanExpression::UintEq(box 2u32.into(), box 1u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
e.clone(),
|
||||
ArrayExpression::select(a0.clone(), 2u32.into())
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2))
|
||||
)
|
||||
)
|
||||
.into()
|
||||
])
|
||||
|
@ -628,7 +641,10 @@ mod tests {
|
|||
|
||||
let e = FieldElementExpression::Number(Bn128Field::from(42));
|
||||
|
||||
let indices = vec![Access::Select(0u32.into()), Access::Select(0u32.into())];
|
||||
let indices = vec![
|
||||
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
|
||||
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
|
||||
];
|
||||
|
||||
let a1 = Unroller::choose_many(
|
||||
a0.clone().into(),
|
||||
|
@ -644,55 +660,91 @@ mod tests {
|
|||
a1,
|
||||
ArrayExpressionInner::Value(vec![
|
||||
ArrayExpression::if_else(
|
||||
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
e.clone(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpression::select(a0.clone(), 0u32.into()),
|
||||
0u32.into()
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
e.clone(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpression::select(a0.clone(), 0u32.into()),
|
||||
1u32.into()
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
)
|
||||
)
|
||||
.into()
|
||||
])
|
||||
.annotate(Type::FieldElement, 2),
|
||||
ArrayExpression::select(a0.clone(), 0u32.into())
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
ArrayExpression::if_else(
|
||||
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 0u32.into(), box 0u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
e.clone(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpression::select(a0.clone(), 1u32.into()),
|
||||
0u32.into()
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(0))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::if_else(
|
||||
BooleanExpression::UintEq(box 1u32.into(), box 0u32.into()),
|
||||
BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
e.clone(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpression::select(a0.clone(), 1u32.into()),
|
||||
1u32.into()
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
)
|
||||
)
|
||||
.into()
|
||||
])
|
||||
.annotate(Type::FieldElement, 2),
|
||||
ArrayExpression::select(a0.clone(), 1u32.into())
|
||||
ArrayExpression::select(
|
||||
a0.clone(),
|
||||
FieldElementExpression::Number(Bn128Field::from(1))
|
||||
)
|
||||
)
|
||||
.into(),
|
||||
])
|
||||
|
@ -721,8 +773,8 @@ mod tests {
|
|||
|
||||
let s = TypedStatement::For(
|
||||
Variable::field_element("i"),
|
||||
2u32.into(),
|
||||
5u32.into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)),
|
||||
vec![
|
||||
TypedStatement::Declaration(Variable::field_element("foo")),
|
||||
TypedStatement::Definition(
|
||||
|
@ -1031,7 +1083,7 @@ mod tests {
|
|||
let s: TypedStatement<Bn128Field> = TypedStatement::Definition(
|
||||
TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::field_array("a", 2)),
|
||||
box 1u32.into(),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
);
|
||||
|
@ -1040,7 +1092,11 @@ mod tests {
|
|||
u.fold_statement(s),
|
||||
vec![
|
||||
TypedStatement::Assertion(
|
||||
BooleanExpression::UintLt(box 1u32.into(), box 2u32.into()).into(),
|
||||
BooleanExpression::Lt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2))
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_array(
|
||||
|
@ -1049,26 +1105,32 @@ mod tests {
|
|||
)),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::IfElse(
|
||||
box BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
|
||||
box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(
|
||||
Identifier::from("a").version(0)
|
||||
)
|
||||
.annotate(Type::FieldElement, 2),
|
||||
box 0u32.into()
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::IfElse(
|
||||
box BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
|
||||
box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(
|
||||
Identifier::from("a").version(0)
|
||||
)
|
||||
.annotate(Type::FieldElement, 2),
|
||||
box 1u32.into()
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
|
@ -1153,7 +1215,7 @@ mod tests {
|
|||
"a",
|
||||
array_of_array_ty.clone(),
|
||||
)),
|
||||
box 1u32.into(),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(4)).into(),
|
||||
|
@ -1167,7 +1229,11 @@ mod tests {
|
|||
u.fold_statement(s),
|
||||
vec![
|
||||
TypedStatement::Assertion(
|
||||
BooleanExpression::UintLt(box 1u32.into(), box 2u32.into()).into(),
|
||||
BooleanExpression::Lt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2))
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::with_id_and_type(
|
||||
|
@ -1176,7 +1242,10 @@ mod tests {
|
|||
)),
|
||||
ArrayExpressionInner::Value(vec![
|
||||
ArrayExpressionInner::IfElse(
|
||||
box BooleanExpression::UintEq(box 0u32.into(), box 1u32.into()),
|
||||
box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
box ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(4)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)).into(),
|
||||
|
@ -1188,14 +1257,17 @@ mod tests {
|
|||
Identifier::from("a").version(0)
|
||||
)
|
||||
.annotate(Type::array(Type::FieldElement, 2), 2),
|
||||
box 0u32.into()
|
||||
box FieldElementExpression::Number(Bn128Field::from(0))
|
||||
)
|
||||
.annotate(Type::FieldElement, 2),
|
||||
)
|
||||
.annotate(Type::FieldElement, 2)
|
||||
.into(),
|
||||
ArrayExpressionInner::IfElse(
|
||||
box BooleanExpression::UintEq(box 1u32.into(), box 1u32.into()),
|
||||
box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
),
|
||||
box ArrayExpressionInner::Value(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(4)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(5)).into(),
|
||||
|
@ -1207,7 +1279,7 @@ mod tests {
|
|||
Identifier::from("a").version(0)
|
||||
)
|
||||
.annotate(Type::array(Type::FieldElement, 2), 2),
|
||||
box 1u32.into()
|
||||
box FieldElementExpression::Number(Bn128Field::from(1))
|
||||
)
|
||||
.annotate(Type::FieldElement, 2),
|
||||
)
|
||||
|
|
|
@ -29,12 +29,10 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
|
|||
fn select<U: Select<'ast, T> + IfElse<'ast, T>>(
|
||||
&mut self,
|
||||
a: ArrayExpression<'ast, T>,
|
||||
i: UExpression<'ast, T>,
|
||||
i: FieldElementExpression<'ast, T>,
|
||||
) -> U {
|
||||
match i.into_inner() {
|
||||
UExpressionInner::Value(i) => {
|
||||
U::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32))
|
||||
}
|
||||
match i {
|
||||
FieldElementExpression::Number(i) => U::select(a, FieldElementExpression::Number(i)),
|
||||
i => {
|
||||
let size = match a.get_type().clone() {
|
||||
Type::Array(array_ty) => array_ty.size,
|
||||
|
@ -44,9 +42,9 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
|
|||
self.statements.push(TypedStatement::Assertion(
|
||||
(0..size)
|
||||
.map(|index| {
|
||||
BooleanExpression::UintEq(
|
||||
box i.clone().annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(index as u128).annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box i.clone(),
|
||||
box FieldElementExpression::Number(index.into()).into(),
|
||||
)
|
||||
})
|
||||
.fold(None, |acc, e| match acc {
|
||||
|
@ -58,19 +56,14 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
|
|||
));
|
||||
|
||||
(0..size)
|
||||
.map(|i| {
|
||||
U::select(
|
||||
a.clone(),
|
||||
UExpressionInner::Value(i as u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
})
|
||||
.map(|i| U::select(a.clone(), FieldElementExpression::Number(i.into())))
|
||||
.enumerate()
|
||||
.rev()
|
||||
.fold(None, |acc, (index, res)| match acc {
|
||||
Some(acc) => Some(U::if_else(
|
||||
BooleanExpression::UintEq(
|
||||
box i.clone().annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(index as u128).annotate(UBitwidth::B32),
|
||||
BooleanExpression::FieldEq(
|
||||
box i.clone(),
|
||||
box FieldElementExpression::Number(index.into()),
|
||||
),
|
||||
res,
|
||||
acc,
|
||||
|
@ -168,7 +161,7 @@ mod tests {
|
|||
TypedAssignee::Identifier(Variable::field_element("b")),
|
||||
FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2),
|
||||
box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
|
||||
box FieldElementExpression::Identifier("i".into()),
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
|
@ -199,12 +192,12 @@ mod tests {
|
|||
FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 2),
|
||||
box 0u32.into(),
|
||||
box FieldElementExpression::Number(0.into()),
|
||||
),
|
||||
FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 2),
|
||||
box 1u32.into(),
|
||||
box FieldElementExpression::Number(1.into()),
|
||||
)
|
||||
)
|
||||
.into()
|
||||
|
|
|
@ -46,7 +46,7 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)),
|
||||
TypedAssignee::Select(box a, box index) => TypedAssignee::Select(
|
||||
box self.fold_assignee(a),
|
||||
box self.fold_uint_expression(index),
|
||||
box self.fold_field_expression(index),
|
||||
),
|
||||
TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m),
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
ArrayExpressionInner::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
StructExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
StructExpressionInner::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
@ -305,7 +305,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
FieldElementExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
FieldElementExpression::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
@ -350,45 +350,25 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
BooleanExpression::Lt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
BooleanExpression::Lt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
BooleanExpression::Le(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
BooleanExpression::Le(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
BooleanExpression::Gt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
BooleanExpression::Gt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
BooleanExpression::Ge(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
BooleanExpression::Ge(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
|
@ -420,7 +400,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
BooleanExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
BooleanExpression::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
@ -503,7 +483,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
UExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_uint_expression(index);
|
||||
let index = f.fold_field_expression(index);
|
||||
UExpressionInner::Select(box array, box index)
|
||||
}
|
||||
UExpressionInner::IfElse(box cond, box cons, box alt) => {
|
||||
|
|
|
@ -256,7 +256,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> {
|
|||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
pub enum TypedAssignee<'ast, T> {
|
||||
Identifier(Variable<'ast>),
|
||||
Select(Box<TypedAssignee<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<TypedAssignee<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Member(Box<TypedAssignee<'ast, T>>, MemberId),
|
||||
}
|
||||
|
||||
|
@ -316,8 +319,8 @@ pub enum TypedStatement<'ast, T> {
|
|||
Assertion(BooleanExpression<'ast, T>),
|
||||
For(
|
||||
Variable<'ast>,
|
||||
UExpression<'ast, T>,
|
||||
UExpression<'ast, T>,
|
||||
FieldElementExpression<'ast, T>,
|
||||
FieldElementExpression<'ast, T>,
|
||||
Vec<TypedStatement<'ast, T>>,
|
||||
),
|
||||
MultipleDefinition(Vec<Variable<'ast>>, TypedExpressionList<'ast, T>),
|
||||
|
@ -424,44 +427,85 @@ pub enum TypedExpression<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> TypedExpression<'ast, T> {
|
||||
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types.
|
||||
// return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible.
|
||||
// Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions
|
||||
pub fn align_without_integers(
|
||||
lhs: Self,
|
||||
rhs: Self,
|
||||
) -> Result<(TypedExpression<'ast, T>, TypedExpression<'ast, T>), String> {
|
||||
) -> Result<
|
||||
(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
|
||||
(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
|
||||
> {
|
||||
use self::TypedExpression::*;
|
||||
|
||||
match (lhs, rhs) {
|
||||
(Int(lhs), FieldElement(rhs)) => Ok((
|
||||
FieldElementExpression::try_from_int(lhs)?.into(),
|
||||
FieldElementExpression::try_from_int(lhs)
|
||||
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
|
||||
.into(),
|
||||
FieldElement(rhs),
|
||||
)),
|
||||
(FieldElement(lhs), Int(rhs)) => Ok((
|
||||
FieldElement(lhs),
|
||||
FieldElementExpression::try_from_int(rhs)?.into(),
|
||||
FieldElement(lhs.clone()),
|
||||
FieldElementExpression::try_from_int(rhs)
|
||||
.map_err(|rhs| (lhs.into(), rhs.into()))?
|
||||
.into(),
|
||||
)),
|
||||
(Int(lhs), Uint(rhs)) => Ok((
|
||||
UExpression::try_from_int(lhs, rhs.bitwidth())?.into(),
|
||||
UExpression::try_from_int(lhs, rhs.bitwidth())
|
||||
.map_err(|lhs| (lhs.into(), rhs.clone().into()))?
|
||||
.into(),
|
||||
Uint(rhs),
|
||||
)),
|
||||
(Uint(lhs), Int(rhs)) => {
|
||||
let bitwidth = lhs.bitwidth();
|
||||
Ok((Uint(lhs), UExpression::try_from_int(rhs, bitwidth)?.into()))
|
||||
Ok((
|
||||
Uint(lhs.clone()),
|
||||
UExpression::try_from_int(rhs, bitwidth)
|
||||
.map_err(|rhs| (lhs.into(), rhs.into()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(lhs, rhs) => Ok((lhs, rhs)),
|
||||
(Array(lhs), Array(rhs)) => {
|
||||
if lhs.get_type() == rhs.get_type() {
|
||||
Ok((Array(lhs), Array(rhs)))
|
||||
} else {
|
||||
Err((Array(lhs), Array(rhs)))
|
||||
}
|
||||
}
|
||||
(Struct(lhs), Struct(rhs)) => {
|
||||
if lhs.get_type() == rhs.get_type() {
|
||||
Ok((Struct(lhs).into(), Struct(rhs).into()))
|
||||
} else {
|
||||
Err((Struct(lhs).into(), Struct(rhs).into()))
|
||||
}
|
||||
}
|
||||
(Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
(Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
(FieldElement(lhs), FieldElement(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
(Int(lhs), Int(rhs)) => Ok((lhs.into(), rhs.into())),
|
||||
(lhs, rhs) => Err((lhs, rhs)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn align_to_type(e: Self, ty: Type) -> Result<Self, String> {
|
||||
use self::TypedExpression::*;
|
||||
match (e, ty) {
|
||||
(Int(e), Type::FieldElement) => Ok(FieldElementExpression::try_from_int(e)?.into()),
|
||||
(Int(e), Type::Uint(bitwidth)) => Ok(UExpression::try_from_int(e, bitwidth)?.into()),
|
||||
(Array(a), Type::Array(array_ty)) => Ok(ArrayExpression::try_from_int(a, array_ty)
|
||||
.map_err(|_| String::from("align array to type"))?
|
||||
.into()),
|
||||
(e, _) => Ok(e),
|
||||
pub fn align_to_type(e: Self, ty: Type) -> Result<Self, (Self, Type)> {
|
||||
match ty.clone() {
|
||||
Type::FieldElement => {
|
||||
FieldElementExpression::try_from_typed(e).map(TypedExpression::from)
|
||||
}
|
||||
Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from),
|
||||
Type::Uint(bitwidth) => {
|
||||
UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from)
|
||||
}
|
||||
Type::Array(array_ty) => {
|
||||
ArrayExpression::try_from_typed(e, array_ty).map(TypedExpression::from)
|
||||
}
|
||||
Type::Struct(struct_ty) => {
|
||||
StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from)
|
||||
}
|
||||
Type::Int => Err(e),
|
||||
}
|
||||
.map_err(|e| (e, ty))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -478,7 +522,10 @@ pub enum IntExpression<'ast, T> {
|
|||
Box<IntExpression<'ast, T>>,
|
||||
Box<IntExpression<'ast, T>>,
|
||||
),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
Xor(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||
And(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||
Or(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||
|
@ -734,7 +781,10 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
),
|
||||
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> BooleanExpression<'ast, T> {
|
||||
|
@ -746,6 +796,12 @@ impl<'ast, T: Field> BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
|
||||
fn from(n: T) -> Self {
|
||||
FieldElementExpression::Number(n)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
||||
pub fn try_from_typed(e: TypedExpression<'ast, T>) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
match e {
|
||||
|
@ -757,13 +813,13 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn try_from_int(i: IntExpression<'ast, T>) -> Result<Self, String> {
|
||||
pub fn try_from_int(i: IntExpression<'ast, T>) -> Result<Self, IntExpression<'ast, T>> {
|
||||
match i {
|
||||
IntExpression::Value(i) => {
|
||||
if i <= T::max_value().to_biguint() {
|
||||
Ok(Self::Number(T::from(i)))
|
||||
} else {
|
||||
Err(format!("Literal `{} is too large for type `field`", i))
|
||||
Err(IntExpression::Value(i))
|
||||
}
|
||||
}
|
||||
IntExpression::Add(box e1, box e2) => Ok(Self::Add(
|
||||
|
@ -794,7 +850,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
))
|
||||
}
|
||||
IntExpression::Select(..) => unimplemented!(),
|
||||
i => Err(format!("Expected a `field` but found expression `{}`", i)),
|
||||
i => Err(i),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -804,26 +860,22 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
pub enum BooleanExpression<'ast, T> {
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(bool),
|
||||
FieldLt(
|
||||
Lt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldLe(
|
||||
Le(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGe(
|
||||
Ge(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGt(
|
||||
Gt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
FieldEq(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
|
@ -854,7 +906,10 @@ pub enum BooleanExpression<'ast, T> {
|
|||
),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
FunctionCall(FunctionKey<'ast>, Vec<TypedExpression<'ast, T>>),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
/// An expression of type `array`
|
||||
|
@ -880,7 +935,10 @@ pub enum ArrayExpressionInner<'ast, T> {
|
|||
Box<ArrayExpression<'ast, T>>,
|
||||
),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T> ArrayExpressionInner<'ast, T> {
|
||||
|
@ -931,13 +989,16 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
}
|
||||
|
||||
// precondition: `array` is only made of inline arrays
|
||||
pub fn try_from_int(array: Self, target_array_ty: ArrayType) -> Result<Self, ()> {
|
||||
if array.get_array_type() == target_array_ty {
|
||||
pub fn try_from_int(
|
||||
array: Self,
|
||||
target_array_ty: ArrayType,
|
||||
) -> Result<Self, TypedExpression<'ast, T>> {
|
||||
let array_ty = array.get_array_type();
|
||||
|
||||
if array_ty == target_array_ty {
|
||||
return Ok(array);
|
||||
}
|
||||
|
||||
let array_ty = array.get_array_type();
|
||||
|
||||
// sizes must be equal
|
||||
match target_array_ty.size == array_ty.size {
|
||||
true =>
|
||||
|
@ -946,56 +1007,67 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> {
|
|||
match array.into_inner() {
|
||||
ArrayExpressionInner::Value(inline_array) => {
|
||||
match *target_array_ty.ty {
|
||||
Type::Int => Ok(ArrayExpressionInner::Value(inline_array)
|
||||
.annotate(*target_array_ty.ty, array_ty.size)),
|
||||
Type::FieldElement => {
|
||||
// try to convert all elements to field
|
||||
let converted = inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let int = IntExpression::try_from(e)?;
|
||||
let field = FieldElementExpression::try_from_int(int)
|
||||
.map_err(|_| ())?;
|
||||
Ok(field.into())
|
||||
FieldElementExpression::try_from_typed(e)
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?;
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)?;
|
||||
Ok(ArrayExpressionInner::Value(converted)
|
||||
.annotate(*target_array_ty.ty, array_ty.size))
|
||||
}
|
||||
Type::Uint(bitwidth) => {
|
||||
// try to convert all elements to field
|
||||
// try to convert all elements to uint
|
||||
let converted = inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let int = IntExpression::try_from(e)?;
|
||||
let field = UExpression::try_from_int(int, bitwidth)
|
||||
.map_err(|_| ())?;
|
||||
Ok(field.into())
|
||||
UExpression::try_from_typed(e, bitwidth)
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?;
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)?;
|
||||
Ok(ArrayExpressionInner::Value(converted)
|
||||
.annotate(*target_array_ty.ty, array_ty.size))
|
||||
}
|
||||
Type::Array(ref array_ty) => {
|
||||
// try to convert all elements to field
|
||||
Type::Array(ref inner_array_ty) => {
|
||||
// try to convert all elements to uint
|
||||
let converted = inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let array = ArrayExpression::try_from(e)?;
|
||||
let array =
|
||||
ArrayExpression::try_from_int(array, array_ty.clone())
|
||||
.map_err(|_| ())?;
|
||||
Ok(array.into())
|
||||
ArrayExpression::try_from_typed(e, inner_array_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, ()>>()?;
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)?;
|
||||
Ok(ArrayExpressionInner::Value(converted)
|
||||
.annotate(*target_array_ty.ty.clone(), array_ty.size.clone()))
|
||||
.annotate(*target_array_ty.ty, array_ty.size))
|
||||
}
|
||||
_ => Err(()),
|
||||
Type::Struct(ref struct_ty) => {
|
||||
// try to convert all elements to uint
|
||||
let converted = inline_array
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
StructExpression::try_from_typed(e, struct_ty.clone())
|
||||
.map(TypedExpression::from)
|
||||
})
|
||||
.collect::<Result<Vec<TypedExpression<'ast, T>>, _>>()
|
||||
.map_err(TypedExpression::from)?;
|
||||
Ok(ArrayExpressionInner::Value(converted)
|
||||
.annotate(*target_array_ty.ty, array_ty.size))
|
||||
}
|
||||
Type::Boolean => unreachable!(),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(""),
|
||||
}
|
||||
}
|
||||
false => Err(()),
|
||||
false => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1047,7 +1119,10 @@ pub enum StructExpressionInner<'ast, T> {
|
|||
Box<StructExpression<'ast, T>>,
|
||||
),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T> StructExpressionInner<'ast, T> {
|
||||
|
@ -1165,14 +1240,10 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
|
@ -1247,30 +1318,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
|
|||
"IfElse({:?}, {:?}, {:?})",
|
||||
condition, consequent, alternative
|
||||
),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => {
|
||||
write!(f, "FieldLt({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => {
|
||||
write!(f, "FieldLe({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::FieldGe(ref lhs, ref rhs) => {
|
||||
write!(f, "FieldGe({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::FieldGt(ref lhs, ref rhs) => {
|
||||
write!(f, "FieldGt({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => {
|
||||
write!(f, "UintLt({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => {
|
||||
write!(f, "UintLe({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintGe(ref lhs, ref rhs) => {
|
||||
write!(f, "UintGe({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::UintGt(ref lhs, ref rhs) => {
|
||||
write!(f, "UintGt({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs),
|
||||
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs),
|
||||
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
|
||||
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => {
|
||||
write!(f, "FieldEq({:?}, {:?})", lhs, rhs)
|
||||
}
|
||||
|
@ -1483,23 +1534,23 @@ impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> {
|
|||
}
|
||||
|
||||
pub trait Select<'ast, T> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self;
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self;
|
||||
}
|
||||
|
||||
impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
FieldElementExpression::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
BooleanExpression::Select(box array, box index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
let bitwidth = match array.inner_type().clone() {
|
||||
Type::Uint(bitwidth) => bitwidth,
|
||||
_ => unreachable!(),
|
||||
|
@ -1510,7 +1561,7 @@ impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
let (ty, size) = match array.inner_type() {
|
||||
Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()),
|
||||
_ => unreachable!(),
|
||||
|
@ -1521,7 +1572,7 @@ impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Clone> Select<'ast, T> for StructExpression<'ast, T> {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self {
|
||||
fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self {
|
||||
let members = match array.inner_type().clone() {
|
||||
Type::Struct(members) => members,
|
||||
_ => unreachable!(),
|
||||
|
|
|
@ -190,6 +190,20 @@ impl Type {
|
|||
Type::Uint(b.into())
|
||||
}
|
||||
|
||||
pub fn can_be_specialized_to(&self, other: &Type) -> bool {
|
||||
use self::Type::*;
|
||||
|
||||
if self == other {
|
||||
true
|
||||
} else {
|
||||
match (self, other) {
|
||||
(Int, FieldElement) | (Int, Uint(..)) => true,
|
||||
(Array(l), Array(r)) => l.size == r.size && l.ty.can_be_specialized_to(&r.ty),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_slug(&self) -> String {
|
||||
match self {
|
||||
Type::FieldElement => String::from("f"),
|
||||
|
|
|
@ -72,7 +72,10 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn try_from_int(i: IntExpression<'ast, T>, bitwidth: UBitwidth) -> Result<Self, String> {
|
||||
pub fn try_from_int(
|
||||
i: IntExpression<'ast, T>,
|
||||
bitwidth: UBitwidth,
|
||||
) -> Result<Self, IntExpression<'ast, T>> {
|
||||
use self::IntExpression::*;
|
||||
|
||||
match i {
|
||||
|
@ -83,10 +86,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
)
|
||||
.annotate(bitwidth))
|
||||
} else {
|
||||
Err(format!(
|
||||
"Literal `{}` is too large for type u{}",
|
||||
i, bitwidth
|
||||
))
|
||||
Err(Value(i))
|
||||
}
|
||||
}
|
||||
Add(box e1, box e2) => Ok(UExpression::add(
|
||||
|
@ -127,10 +127,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
Self::try_from_int(alternative, bitwidth)?,
|
||||
)),
|
||||
Select(..) => unimplemented!(),
|
||||
i => Err(format!(
|
||||
"Expected a `u{}` but found expression `{}`",
|
||||
bitwidth, i
|
||||
)),
|
||||
i => Err(i),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -192,7 +189,10 @@ pub enum UExpressionInner<'ast, T> {
|
|||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
Member(Box<StructExpression<'ast, T>>, MemberId),
|
||||
Select(Box<ArrayExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Select(
|
||||
Box<ArrayExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
||||
|
|
|
@ -204,45 +204,25 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
BooleanExpression::Lt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
BooleanExpression::Lt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
BooleanExpression::Le(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
BooleanExpression::Le(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
BooleanExpression::Gt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
BooleanExpression::Gt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
BooleanExpression::Ge(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
BooleanExpression::Ge(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
|
|
|
@ -381,26 +381,22 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
pub enum BooleanExpression<'ast, T> {
|
||||
Identifier(Identifier<'ast>),
|
||||
Value(bool),
|
||||
FieldLt(
|
||||
Lt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldLe(
|
||||
Le(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGe(
|
||||
Ge(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGt(
|
||||
Gt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
FieldEq(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
|
@ -510,14 +506,10 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
BooleanExpression::Identifier(ref var) => write!(f, "{}", var),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
|
|
|
@ -4,12 +4,12 @@ def main():
|
|||
assert(7 == 2 ** 2 * 2 - 1)
|
||||
assert(3 == 2 ** 2 / 2 + 1)
|
||||
|
||||
field a = if 3 == 2 ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators
|
||||
field b = if 3 == 3 && 4 < 5 then 1 else 0 fi // checks precedence of boolean operators
|
||||
field c = if 4 < 5 && 3 == 3 then 1 else 0 fi
|
||||
field d = if 4 > 5 && 2 >= 1 || 1 == 1 then 1 else 0 fi
|
||||
field e = if 2 >= 1 && 4 > 5 || 1 == 1 then 1 else 0 fi
|
||||
field f = if 1 < 2 && false || 4 < 5 && 2 >= 1 then 1 else 0 fi
|
||||
field a = if 3f == 2f ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators
|
||||
field b = if 3f == 3f && 4f < 5f then 1 else 0 fi // checks precedence of boolean operators
|
||||
field c = if 4f < 5f && 3f == 3f then 1 else 0 fi
|
||||
field d = if 4f > 5f && 2f >= 1f || 1f == 1f then 1 else 0 fi
|
||||
field e = if 2f >= 1f && 4f > 5f || 1f == 1f then 1 else 0 fi
|
||||
field f = if 1f < 2f && false || 4f < 5f && 2f >= 1f then 1 else 0 fi
|
||||
|
||||
assert(0x00 ^ 0x00 == 0x00)
|
||||
|
||||
|
|
|
@ -101,7 +101,9 @@ mod tests {
|
|||
term(43, 44, [
|
||||
primary_expression(43, 44, [
|
||||
literal(43, 44, [
|
||||
decimal_number(43, 44)
|
||||
decimal_literal(43, 44, [
|
||||
decimal_number(43, 44)
|
||||
])
|
||||
])
|
||||
])
|
||||
])
|
||||
|
@ -258,7 +260,9 @@ mod tests {
|
|||
term(30, 31, [
|
||||
primary_expression(30, 31, [
|
||||
literal(30, 31, [
|
||||
decimal_number(30, 31)
|
||||
decimal_literal(30, 31, [
|
||||
decimal_number(30, 31)
|
||||
])
|
||||
])
|
||||
])
|
||||
])
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
// Note: parameters will be updated soon to be more compatible with zCash's implementation
|
||||
|
||||
struct BabyJubJubParams {
|
||||
field JUBJUBE
|
||||
field JUBJUBC
|
||||
field JUBJUBA
|
||||
field JUBJUBD
|
||||
|
@ -18,7 +17,6 @@ struct BabyJubJubParams {
|
|||
def main() -> BabyJubJubParams:
|
||||
|
||||
// Order of the curve E
|
||||
field JUBJUBE = 21888242871839275222246405745257275088614511777268538073601725287587578984328
|
||||
field JUBJUBC = 8 // Cofactor
|
||||
field JUBJUBA = 168700 // Coefficient A
|
||||
field JUBJUBD = 168696 // Coefficient D
|
||||
|
@ -40,7 +38,6 @@ return BabyJubJubParams {
|
|||
INFINITY: INFINITY,
|
||||
Gu: Gu,
|
||||
Gv: Gv,
|
||||
JUBJUBE: JUBJUBE,
|
||||
JUBJUBC: JUBJUBC,
|
||||
MONTA: MONTA,
|
||||
MONTB: MONTB
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
import "hashes/mimc7/mimc7R10"
|
||||
|
||||
def main():
|
||||
assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289)
|
||||
assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145)
|
||||
assert(mimc7R10(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 2977550761518141183167168643824354554080911485709001361112529600968315693145)
|
||||
assert(mimc7R10(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 11476724043755138071320043459606423473319855817296339514744600646762741571430)
|
||||
assert(mimc7R10(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 6004544488495356385698286530147974336054653445122716140990101827963729149289)
|
||||
assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289f)
|
||||
assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145f)
|
||||
return
|
|
@ -1,9 +1,6 @@
|
|||
import "hashes/mimc7/mimc7R20"
|
||||
|
||||
def main():
|
||||
assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829)
|
||||
assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758)
|
||||
assert(mimc7R20(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 8623418512398828792274158979964869393034224267928014534933203776818702139758)
|
||||
assert(mimc7R20(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 15315177265066649795408805007175121550344555424263995530745989936206840798041)
|
||||
assert(mimc7R20(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 19139739902058628561064841933381604453445216873412991992755775746150759284829)
|
||||
assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829f)
|
||||
assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758f)
|
||||
return
|
|
@ -1,9 +1,6 @@
|
|||
import "hashes/mimc7/mimc7R50"
|
||||
|
||||
def main():
|
||||
assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173)
|
||||
assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789)
|
||||
assert(mimc7R50(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 18511388995652647480418174218630545482006454713617579894396683237092568946789)
|
||||
assert(mimc7R50(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 9149577627043020462780389988155990926223727917856424056384664564191878439702)
|
||||
assert(mimc7R50(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 3049953358280347916081509186284461274525472221619157672645224540758481713173)
|
||||
assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173f)
|
||||
assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789f)
|
||||
return
|
|
@ -2,8 +2,5 @@ import "hashes/mimc7/mimc7R90"
|
|||
|
||||
def main():
|
||||
assert(mimc7R90(0, 0) == 20281265111705407344053532742843085357648991805359414661661476832595822221514)
|
||||
assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636)
|
||||
assert(mimc7R90(100, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 1010054095264022068840870550831559811104631937745987065544478027572003292636)
|
||||
assert(mimc7R90(21888242871839275222246405745257275088548364400416034343698204186575808495618, 1) == 8189519586469873426687580455476035992041353456517724932462363814215190642760)
|
||||
assert(mimc7R90(21888242871839275222246405745257275088548364400416034343698204186575808495617, 21888242871839275222246405745257275088548364400416034343698204186575808495617) == 20281265111705407344053532742843085357648991805359414661661476832595822221514)
|
||||
assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636f)
|
||||
return
|
|
@ -1,7 +1,6 @@
|
|||
import "hashes/mimcSponge/mimcSponge" as mimcSponge
|
||||
|
||||
def main():
|
||||
assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097])
|
||||
assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279])
|
||||
assert(mimcSponge([21888242871839275222246405745257275088548364400416034343698204186575808495617, 0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279])
|
||||
assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097f])
|
||||
assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279f])
|
||||
return
|
Loading…
Reference in a new issue