1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

implement any assertion naively

This commit is contained in:
schaeff 2020-06-15 23:50:42 +02:00
parent eb11028703
commit dce41e50e8
15 changed files with 164 additions and 120 deletions

View file

@ -1,6 +1,19 @@
struct Foo {
field a
}
struct Bar {
Foo[1] foo
}
def isEqual(field a, field b) -> (bool):
return a == b
def main(field a) -> (field):
field b = (a + 5) * 6
2 * b == a * 12 + 60
field c = 7 * (b + a)
c == 7 * b + 7 * a
return b + c
isEqual(c, 7 * b + 7 * a)
field k = if [1, 2] == [3, 4] then 1 else 3 fi
[Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}]
return b + c

View file

@ -4,11 +4,7 @@
def isWaldo(field a, field p, field q) -> (field):
// make sure that p and q are both non zero
// we can't check inequalities, so let's create binary
// variables
field p1 = if p == 1 then 0 else 1 fi // "p != 1"
field q1 = if q == 1 then 0 else 1 fi // "q != 1"
q1 * p1 == 1 // "p1 and q1"
!(p == 1) && !(q == 1)
// we know how to factor a
a == p * q
@ -16,10 +12,6 @@ def isWaldo(field a, field p, field q) -> (field):
return 1
// define all
def main(field a0, field a1, field a2, field a3, private field index, private field p, private field q) -> (field):
def main(field[4] a, private field index, private field p, private field q) -> (field):
// prover provides the index of Waldo
field waldo = if index == 0 then a0 else 0 fi
waldo = waldo + if index == 1 then a1 else 0 fi
waldo = waldo + if index == 2 then a2 else 0 fi
waldo = waldo + if index == 3 then a3 else 0 fi
return isWaldo(waldo, p, q)
return isWaldo(a[index], p, q)

View file

@ -262,23 +262,8 @@ impl<'ast, T: Field> From<pest::AssertionStatement<'ast>> for absy::StatementNod
fn from(statement: pest::AssertionStatement<'ast>) -> absy::StatementNode<T> {
use absy::NodeValue;
match statement.expression {
pest::Expression::Binary(e) => match e.op {
pest::BinaryOperator::Eq => absy::Statement::Condition(
absy::ExpressionNode::from(*e.left),
absy::ExpressionNode::from(*e.right),
),
_ => unimplemented!(
"Assertion statements should be an equality check, found {}",
statement.span.as_str()
),
},
_ => unimplemented!(
"Assertion statements should be an equality check, found {}",
statement.span.as_str()
),
}
.span(statement.span)
absy::Statement::Assertion(absy::ExpressionNode::from(statement.expression))
.span(statement.span)
}
}

View file

@ -299,7 +299,7 @@ pub enum Statement<'ast, T> {
Return(ExpressionListNode<'ast, T>),
Declaration(VariableNode<'ast>),
Definition(AssigneeNode<'ast, T>, ExpressionNode<'ast, T>),
Condition(ExpressionNode<'ast, T>, ExpressionNode<'ast, T>),
Assertion(ExpressionNode<'ast, T>),
For(
VariableNode<'ast>,
ExpressionNode<'ast, T>,
@ -317,7 +317,7 @@ impl<'ast, T: fmt::Display> fmt::Display for Statement<'ast, T> {
Statement::Return(ref expr) => write!(f, "return {}", expr),
Statement::Declaration(ref var) => write!(f, "{}", var),
Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
Statement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
Statement::Assertion(ref e) => write!(f, "{}", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {} in {}..{} do\n", var, start, stop)?;
for l in list {
@ -346,7 +346,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for Statement<'ast, T> {
Statement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
Statement::Condition(ref lhs, ref rhs) => write!(f, "Condition({:?}, {:?})", lhs, rhs),
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?;
for l in list {

View file

@ -1674,24 +1674,22 @@ impl<'ast, T: Field> Flattener<'ast, T> {
None => {}
}
}
ZirStatement::Condition(lhs, rhs) => {
// flatten expr1 and expr2 to n flattened expressions with n the number of primitive types for expr1
// add n conditions to check equality of the n expressions
ZirStatement::Assertion(e) => {
// naive approach: flatten the boolean to a single field element and constrain it to 1
let lhs = self
.flatten_expression(symbols, statements_flattened, lhs)
.get_field_unchecked();
let rhs = self
.flatten_expression(symbols, statements_flattened, rhs)
.get_field_unchecked();
let e = self.flatten_boolean_expression(symbols, statements_flattened, e);
if lhs.is_linear() {
statements_flattened.push(FlatStatement::Condition(lhs, rhs));
} else if rhs.is_linear() {
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(rhs, lhs));
if e.is_linear() {
statements_flattened.push(FlatStatement::Condition(
e,
FlatExpression::Number(T::from(1)),
));
} else {
unreachable!()
// swap so that left side is linear
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::from(1)),
e,
));
}
}
ZirStatement::MultipleDefinition(vars, rhs) => {

View file

@ -875,27 +875,21 @@ impl<'ast> Checker<'ast> {
}
.map_err(|e| vec![e])
}
Statement::Condition(lhs, rhs) => {
let checked_lhs = self
.check_expression(lhs, module_id, &types)
.map_err(|e| vec![e])?;
let checked_rhs = self
.check_expression(rhs, module_id, &types)
Statement::Assertion(e) => {
let e = self
.check_expression(e, module_id, &types)
.map_err(|e| vec![e])?;
if checked_lhs.get_type() == checked_rhs.get_type() {
Ok(TypedStatement::Condition(checked_lhs, checked_rhs))
} else {
Err(ErrorInner {
match e {
TypedExpression::Boolean(e) => Ok(TypedStatement::Assertion(e)),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {:?} to {} of type {:?}",
checked_lhs,
checked_lhs.get_type(),
checked_rhs,
checked_rhs.get_type(),
"Expected {} to be of type bool, found {}",
e,
e.get_type(),
),
})
}),
}
.map_err(|e| vec![e])
}
@ -1543,6 +1537,38 @@ impl<'ast> Checker<'ast> {
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
Ok(BooleanExpression::BoolEq(box e1, box e2).into())
}
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(BooleanExpression::ArrayEq(box e1, box e2).into())
} else {
Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {} to {} of type {}",
e1,
e1.get_type(),
e2,
e2.get_type()
),
})
}
}
(TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => {
if e1.get_type() == e2.get_type() {
Ok(BooleanExpression::StructEq(box e1, box e2).into())
} else {
Err(ErrorInner {
pos: Some(pos),
message: format!(
"Cannot compare {} of type {} to {} of type {}",
e1,
e1.get_type(),
e2,
e2.get_type()
),
})
}
}
(e1, e2) => Err(ErrorInner {
pos: Some(pos),
message: format!(
@ -3136,9 +3162,12 @@ mod tests {
// def bar():
// 2 == foo()
// should fail
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Condition(
Expression::FieldConstant(Bn128Field::from(2)).mock(),
Expression::FunctionCall("foo", vec![]).mock(),
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Assertion(
Expression::Eq(
box Expression::FieldConstant(Bn128Field::from(2)).mock(),
box Expression::FunctionCall("foo", vec![]).mock(),
)
.mock(),
)
.mock()];
@ -3535,9 +3564,12 @@ mod tests {
// def bar():
// 1 == foo()
// should fail
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Condition(
Expression::FieldConstant(Bn128Field::from(1)).mock(),
Expression::FunctionCall("foo", vec![]).mock(),
let bar_statements: Vec<StatementNode<Bn128Field>> = vec![Statement::Assertion(
Expression::Eq(
box Expression::FieldConstant(Bn128Field::from(1)).mock(),
box Expression::FunctionCall("foo", vec![]).mock(),
)
.mock(),
)
.mock()];

View file

@ -248,14 +248,9 @@ pub fn fold_statement<'ast, T: Field>(
.map(|v| zir::ZirStatement::Declaration(v))
.collect()
}
typed_absy::TypedStatement::Condition(left, right) => {
let left = f.fold_expression(left);
let right = f.fold_expression(right);
assert_eq!(left.len(), right.len());
left.into_iter()
.zip(right.into_iter())
.map(|(left, right)| zir::ZirStatement::Condition(left, right))
.collect()
typed_absy::TypedStatement::Assertion(e) => {
let e = f.fold_boolean_expression(e);
vec![zir::ZirStatement::Assertion(e)]
}
typed_absy::TypedStatement::For(..) => unreachable!(),
typed_absy::TypedStatement::MultipleDefinition(variables, elist) => {
@ -555,6 +550,39 @@ pub fn fold_boolean_expression<'ast, T: Field>(
let e2 = f.fold_boolean_expression(e2);
zir::BooleanExpression::BoolEq(box e1, box e2)
}
typed_absy::BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1);
let e2 = f.fold_array_expression(e2);
assert_eq!(e1.len(), e2.len());
e1.into_iter().zip(e2.into_iter()).fold(
zir::BooleanExpression::Value(true),
|acc, (e1, e2)| {
zir::BooleanExpression::And(
box acc,
box match (e1, e2) {
(
zir::ZirExpression::FieldElement(e1),
zir::ZirExpression::FieldElement(e2),
) => zir::BooleanExpression::FieldEq(box e1, box e2),
(zir::ZirExpression::Boolean(e1), zir::ZirExpression::Boolean(e2)) => {
zir::BooleanExpression::BoolEq(box e1, box e2)
}
_ => unimplemented!(),
},
)
},
)
}
typed_absy::BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1);
let e2 = f.fold_struct_expression(e2);
assert_eq!(e1.len(), e2.len());
unimplemented!()
}
typed_absy::BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);

View file

@ -112,13 +112,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
TypedStatement::Definition(TypedAssignee::Member(..), _) => {
unreachable!("struct update should have been replaced with full struct redef")
}
// propagate lhs and rhs for conditions
TypedStatement::Condition(e1, e2) => {
// propagate the boolean
TypedStatement::Assertion(e) => {
// could stop execution here if condition is known to fail
Some(TypedStatement::Condition(
self.fold_expression(e1),
self.fold_expression(e2),
))
Some(TypedStatement::Assertion(self.fold_boolean_expression(e)))
}
// only loops with variable bounds are expected here
// we stop propagation here as constants maybe be modified inside the loop body

View file

@ -397,18 +397,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
)],
},
},
// we need to put back in range to assert
ZirStatement::Condition(lhs, rhs) => {
match (self.fold_expression(lhs), self.fold_expression(rhs)) {
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
vec![ZirStatement::Condition(
force_reduce(lhs).into(),
force_reduce(rhs).into(),
)]
}
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)],
}
}
s => fold_statement(self, s),
}
}

View file

@ -84,13 +84,12 @@ impl<'ast> Unroller<'ast> {
match head {
Access::Select(head) => {
statements.insert(TypedStatement::Condition(
statements.insert(TypedStatement::Assertion(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
BooleanExpression::Value(true).into(),
));
ArrayExpressionInner::Value(
@ -1089,13 +1088,12 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![
TypedStatement::Condition(
TypedStatement::Assertion(
BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array(
@ -1227,13 +1225,12 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![
TypedStatement::Condition(
TypedStatement::Assertion(
BooleanExpression::Lt(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(2))
)
.into(),
BooleanExpression::Value(true).into()
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(

View file

@ -39,7 +39,7 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
_ => unreachable!(),
};
self.statements.push(TypedStatement::Condition(
self.statements.push(TypedStatement::Assertion(
(0..size)
.map(|index| {
BooleanExpression::FieldEq(
@ -53,7 +53,6 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
})
.unwrap()
.into(),
BooleanExpression::Value(true).into(),
));
(0..size)

View file

@ -165,9 +165,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
TypedStatement::Definition(f.fold_assignee(a), f.fold_expression(e))
}
TypedStatement::Declaration(v) => TypedStatement::Declaration(f.fold_variable(v)),
TypedStatement::Condition(left, right) => {
TypedStatement::Condition(f.fold_expression(left), f.fold_expression(right))
}
TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)),
TypedStatement::For(v, from, to, statements) => TypedStatement::For(
f.fold_variable(v),
from,
@ -325,6 +323,16 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_boolean_expression(e2);
BooleanExpression::BoolEq(box e1, box e2)
}
BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1);
let e2 = f.fold_array_expression(e2);
BooleanExpression::ArrayEq(box e1, box e2)
}
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1);
let e2 = f.fold_struct_expression(e2);
BooleanExpression::StructEq(box e1, box e2)
}
BooleanExpression::Lt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);

View file

@ -300,7 +300,7 @@ pub enum TypedStatement<'ast, T> {
Return(Vec<TypedExpression<'ast, T>>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
Declaration(Variable<'ast>),
Condition(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
Assertion(BooleanExpression<'ast, T>),
For(
Variable<'ast>,
FieldElementExpression<'ast, T>,
@ -327,9 +327,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> {
TypedStatement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
TypedStatement::Condition(ref lhs, ref rhs) => {
write!(f, "Condition({:?}, {:?})", lhs, rhs)
}
TypedStatement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
TypedStatement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?;
for l in list {
@ -376,7 +374,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
}
TypedStatement::Declaration(ref var) => write!(f, "{}", var),
TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
TypedStatement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
TypedStatement::Assertion(ref e) => write!(f, "{}", e),
TypedStatement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {} in {}..{} do\n", var, start, stop)?;
for l in list {
@ -639,6 +637,11 @@ pub enum BooleanExpression<'ast, T> {
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
ArrayEq(Box<ArrayExpression<'ast, T>>, Box<ArrayExpression<'ast, T>>),
StructEq(
Box<StructExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
),
Ge(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
@ -906,6 +909,8 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::StructEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
@ -985,6 +990,12 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> {
BooleanExpression::BoolEq(ref lhs, ref rhs) => {
write!(f, "BoolEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::ArrayEq(ref lhs, ref rhs) => {
write!(f, "ArrayEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::StructEq(ref lhs, ref rhs) => {
write!(f, "StructEq({:?}, {:?})", lhs, rhs)
}
BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "And({:?}, {:?})", lhs, rhs),

View file

@ -130,9 +130,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
ZirStatement::Definition(f.fold_assignee(a), f.fold_expression(e))
}
ZirStatement::Declaration(v) => ZirStatement::Declaration(f.fold_variable(v)),
ZirStatement::Condition(left, right) => {
ZirStatement::Condition(f.fold_expression(left), f.fold_expression(right))
}
ZirStatement::Assertion(e) => ZirStatement::Assertion(f.fold_boolean_expression(e)),
ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition(
variables.into_iter().map(|v| f.fold_variable(v)).collect(),
f.fold_expression_list(elist),

View file

@ -191,7 +191,7 @@ pub enum ZirStatement<'ast, T> {
Return(Vec<ZirExpression<'ast, T>>),
Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>),
Declaration(Variable<'ast>),
Condition(ZirExpression<'ast, T>, ZirExpression<'ast, T>),
Assertion(BooleanExpression<'ast, T>),
MultipleDefinition(Vec<Variable<'ast>>, ZirExpressionList<'ast, T>),
}
@ -212,9 +212,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirStatement<'ast, T> {
ZirStatement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
ZirStatement::Condition(ref lhs, ref rhs) => {
write!(f, "Condition({:?}, {:?})", lhs, rhs)
}
ZirStatement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
ZirStatement::MultipleDefinition(ref lhs, ref rhs) => {
write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs)
}
@ -237,7 +235,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {
}
ZirStatement::Declaration(ref var) => write!(f, "{}", var),
ZirStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
ZirStatement::Condition(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
ZirStatement::Assertion(ref e) => write!(f, "{}", e),
ZirStatement::MultipleDefinition(ref ids, ref rhs) => {
for (i, id) in ids.iter().enumerate() {
write!(f, "{}", id)?;