1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00
ZoKrates/zokrates_analysis/src/condition_redefiner.rs
2022-11-17 14:38:12 +01:00

398 lines
14 KiB
Rust

use zokrates_ast::typed::{
folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression,
ConditionalOrExpression, CoreIdentifier, Expr, Id, Identifier, Type, TypedExpression,
TypedProgram, TypedStatement, Variable,
};
use zokrates_field::Field;
/// Folder state used to replace the conditions of conditional expressions
/// with identifiers which are defined in separate statements.
#[derive(Default)]
pub struct ConditionRedefiner<'ast, T> {
// number used for the next `CoreIdentifier::Condition` identifier, incremented after each use
index: usize,
// definitions of the condition identifiers created so far, waiting to be inserted
// into the statements returned by `fold_statement` or into the enclosing block
buffer: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, T: Field> ConditionRedefiner<'ast, T> {
    /// Applies the redefiner to a whole program.
    ///
    /// The fold starts from the default state of the redefiner and the
    /// folded program is returned.
    pub fn redefine(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
        let mut redefiner = Self::default();
        redefiner.fold_program(p)
    }
}
impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> {
    /// Folds a statement and returns it preceded by the condition definitions
    /// which were buffered while folding it.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not empty when this is called.
    fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
        assert!(self.buffer.is_empty());

        let folded = fold_statement(self, s);

        // the definitions go first, as the folded statement refers to the identifiers they define
        let mut statements = std::mem::take(&mut self.buffer);
        statements.extend(folded);
        statements
    }

    /// Folds a block expression with a nested redefiner, so that the
    /// definitions created inside the block stay inside the block.
    fn fold_block_expression<E: Fold<'ast, T>>(
        &mut self,
        b: BlockExpression<'ast, T, E>,
    ) -> BlockExpression<'ast, T, E> {
        // the nested redefiner has its own empty buffer but carries on from the current counter
        let mut nested = ConditionRedefiner {
            buffer: Vec::new(),
            index: self.index,
        };

        let mut block = fold_block_expression(&mut nested, b);

        let ConditionRedefiner { index, buffer } = nested;

        // what is still buffered here is referred to by the return value of the block, so it is
        // appended *after* the statements of the block; the definitions needed by these
        // statements are already part of the folded statements
        block.statements.extend(buffer);

        // carry on from the last index used inside the block
        self.index = index;

        block
    }

    /// Folds a conditional expression, replacing its condition with a fresh
    /// `CoreIdentifier::Condition` identifier unless the folded condition
    /// is already a boolean value or an identifier.
    ///
    /// The definition of the identifier is pushed to the buffer rather than
    /// returned.
    fn fold_conditional_expression<E: Expr<'ast, T> + Conditional<'ast, T> + Fold<'ast, T>>(
        &mut self,
        _: &E::Ty,
        e: ConditionalExpression<'ast, T, E>,
    ) -> ConditionalOrExpression<'ast, T, E> {
        let folded = self.fold_boolean_expression(*e.condition);

        let is_value_or_identifier = matches!(
            folded,
            BooleanExpression::Value(_) | BooleanExpression::Identifier(_)
        );

        let condition = if is_value_or_identifier {
            // nothing to redefine
            folded
        } else {
            let id = Identifier::from(CoreIdentifier::Condition(self.index));
            self.index += 1;

            let definition = TypedStatement::definition(
                Variable::immutable(id.clone(), Type::Boolean).into(),
                TypedExpression::from(folded),
            );
            self.buffer.push(definition);

            BooleanExpression::identifier(id)
        };

        // the arguments are evaluated from left to right: the consequence is folded before the
        // alternative, which determines the numbering of the conditions they contain
        ConditionalOrExpression::Conditional(ConditionalExpression::new(
            condition,
            e.consequence.fold(self),
            e.alternative.fold(self),
            e.kind,
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use zokrates_ast::typed::{
        Block, BooleanExpression, Conditional, ConditionalKind, FieldElementExpression, Type,
    };
    use zokrates_field::Bn128Field;

    type Bool = BooleanExpression<'static, Bn128Field>;
    type Felt = FieldElementExpression<'static, Bn128Field>;
    type Statement = TypedStatement<'static, Bn128Field>;

    /// `name`, as a boolean identifier
    fn identifier(name: &'static str) -> Bool {
        BooleanExpression::identifier(name.into())
    }

    /// `lhs && rhs`, both being boolean identifiers
    fn and(lhs: &'static str, rhs: &'static str) -> Bool {
        BooleanExpression::And(box identifier(lhs), box identifier(rhs))
    }

    /// `#CONDITION_index`, as a boolean identifier
    fn condition_ref(index: usize) -> Bool {
        BooleanExpression::identifier(CoreIdentifier::Condition(index).into())
    }

    /// the field constant `n`
    fn number(n: i32) -> Felt {
        FieldElementExpression::Number(Bn128Field::from(n))
    }

    /// `if condition { consequence } else { alternative }`
    fn if_else(condition: Bool, consequence: Felt, alternative: Felt) -> Felt {
        FieldElementExpression::conditional(
            condition,
            consequence,
            alternative,
            ConditionalKind::IfElse,
        )
    }

    /// `field name = value;`
    fn define_field(name: &'static str, value: Felt) -> Statement {
        TypedStatement::definition(Variable::field_element(name).into(), value.into())
    }

    /// `bool #CONDITION_index = value;`
    fn define_condition(index: usize, value: Bool) -> Statement {
        TypedStatement::definition(
            Variable::immutable(CoreIdentifier::Condition(index), Type::Boolean).into(),
            value.into(),
        )
    }

    /// folds `s` with a redefiner in its default state
    fn run(s: Statement) -> Vec<Statement> {
        ConditionRedefiner::default().fold_statement(s)
    }

    #[test]
    fn no_redefine_if_constant() {
        // field foo = if true { 1 } else { 2 };
        //
        // is expected to stay as it is

        let s = define_field(
            "foo",
            if_else(BooleanExpression::Value(true), number(1), number(2)),
        );

        assert_eq!(run(s.clone()), vec![s]);
    }

    #[test]
    fn no_redefine_if_identifier() {
        // field foo = if c { 1 } else { 2 };
        //
        // is expected to stay as it is

        let s = define_field("foo", if_else(identifier("c"), number(1), number(2)));

        assert_eq!(run(s.clone()), vec![s]);
    }

    #[test]
    fn redefine_if_expression() {
        // field foo = if c && d { 1 } else { 2 };
        //
        // is expected to become
        //
        // bool #CONDITION_0 = c && d;
        // field foo = if #CONDITION_0 { 1 } else { 2 };

        let s = define_field("foo", if_else(and("c", "d"), number(1), number(2)));

        let expected = vec![
            // the condition is defined first
            define_condition(0, and("c", "d")),
            // then the statement refers to it
            define_field("foo", if_else(condition_ref(0), number(1), number(2))),
        ];

        assert_eq!(run(s), expected);
    }

    #[test]
    fn redefine_rec() {
        // field foo = if c && d {
        //     if e && f { 1 } else { 2 }
        // } else {
        //     3
        // };
        //
        // is expected to become
        //
        // bool #CONDITION_0 = c && d;
        // bool #CONDITION_1 = e && f;
        // field foo = if #CONDITION_0 {
        //     if #CONDITION_1 { 1 } else { 2 }
        // } else {
        //     3
        // };

        let s = define_field(
            "foo",
            if_else(
                and("c", "d"),
                if_else(and("e", "f"), number(1), number(2)),
                number(3),
            ),
        );

        let expected = vec![
            // both conditions are defined first
            define_condition(0, and("c", "d")),
            define_condition(1, and("e", "f")),
            // then the statement refers to them
            define_field(
                "foo",
                if_else(
                    condition_ref(0),
                    if_else(condition_ref(1), number(1), number(2)),
                    number(3),
                ),
            ),
        ];

        assert_eq!(run(s), expected);
    }

    #[test]
    fn redefine_block() {
        // field foo = if c && d {
        //     field a = 1;
        //     if e && f { 2 } else { 3 }
        // } else {
        //     field b = 2;
        //     if e && f { 2 } else { 3 }
        // };
        //
        // is expected to become
        //
        // bool #CONDITION_0 = c && d;
        // field foo = if #CONDITION_0 {
        //     field a = 1;
        //     bool #CONDITION_1 = e && f;
        //     if #CONDITION_1 { 2 } else { 3 }
        // } else {
        //     field b = 2;
        //     bool #CONDITION_2 = e && f;
        //     if #CONDITION_2 { 2 } else { 3 }
        // };

        let s = define_field(
            "foo",
            if_else(
                and("c", "d"),
                FieldElementExpression::block(
                    vec![define_field("a", number(1))],
                    if_else(and("e", "f"), number(2), number(3)),
                ),
                FieldElementExpression::block(
                    vec![define_field("b", number(2))],
                    if_else(and("e", "f"), number(2), number(3)),
                ),
            ),
        );

        let expected = vec![
            // only the outer condition is defined outside of the blocks
            define_condition(0, and("c", "d")),
            // the inner conditions are defined at the end of their respective blocks
            define_field(
                "foo",
                if_else(
                    condition_ref(0),
                    FieldElementExpression::block(
                        vec![
                            define_field("a", number(1)),
                            define_condition(1, and("e", "f")),
                        ],
                        if_else(condition_ref(1), number(2), number(3)),
                    ),
                    FieldElementExpression::block(
                        vec![
                            define_field("b", number(2)),
                            define_condition(2, and("e", "f")),
                        ],
                        if_else(condition_ref(2), number(2), number(3)),
                    ),
                ),
            ),
        ];

        assert_eq!(run(s), expected);
    }
}