1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

rename, implement member propagation

This commit is contained in:
schaeff 2020-12-07 16:57:13 +00:00
parent 0d2caddded
commit 1fecb47f0f
10 changed files with 693 additions and 466 deletions

View file

@ -3835,8 +3835,8 @@ mod tests {
TypedStatement::Declaration(typed_absy::Variable::field_element("b")),
TypedStatement::MultipleDefinition(
vec![
typed_absy::Variable::field_element("a"),
typed_absy::Variable::field_element("b"),
typed_absy::Variable::field_element("a").into(),
typed_absy::Variable::field_element("b").into(),
],
TypedExpressionList::FunctionCall(
FunctionKey::with_id("foo").signature(

View file

@ -1257,7 +1257,7 @@ mod tests {
arguments: vec![],
statements: vec![
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a")],
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
FunctionKey::with_id("foo").signature(
Signature::new().outputs(vec![Type::FieldElement]),
@ -1366,7 +1366,7 @@ mod tests {
arguments: vec![],
statements: vec![
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a")],
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
FunctionKey::with_id("foo").signature(
Signature::new().outputs(vec![Type::FieldElement]),

View file

@ -14,7 +14,8 @@ mod return_binder;
mod uint_optimizer;
mod unconstrained_vars;
mod unroll;
mod variable_access_remover;
mod variable_read_remover;
mod variable_write_remover;
use self::flatten_complex_types::Flattener;
use self::inline::Inliner;
@ -24,7 +25,8 @@ use self::redefinition::RedefinitionOptimizer;
use self::return_binder::ReturnBinder;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_access_remover::VariableAccessRemover;
use self::variable_read_remover::VariableReadRemover;
use self::variable_write_remover::VariableWriteRemover;
use crate::flat_absy::FlatProg;
use crate::ir::Prog;
use crate::typed_absy::TypedProgram;
@ -54,8 +56,13 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
let r = RedefinitionOptimizer::optimize(r);
println!("variable access");
// remove variable access to complex types
let r = VariableAccessRemover::apply(r);
println!("flattenerr");
let r = VariableReadRemover::apply(r);
println!("{}", r);
println!("variable index");
// remove assignment to variable index
let r = VariableWriteRemover::apply(r);
println!("{}", r);
println!("flatten complex types");
// convert to zir, removing complex types
let zir = Flattener::flatten(r);
println!("uint opt");

View file

@ -12,7 +12,6 @@
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::TryFrom;
use typed_absy::types::Type;
@ -79,7 +78,29 @@ impl<'ast, T: Field> Propagator<'ast, T> {
e => e,
}
}
TypedAssignee::Member(..) => unimplemented!(),
TypedAssignee::Member(box assignee, m) => match self.get_constant_entry(&assignee) {
Ok((v, c)) => {
let ty = assignee.get_type();
let index = match ty {
Type::Struct(struct_ty) => struct_ty
.members
.iter()
.position(|member| *m == member.id)
.unwrap(),
_ => unreachable!(),
};
match c {
TypedExpression::Struct(a) => match a.as_inner_mut() {
StructExpressionInner::Value(value) => Ok((v, &mut value[index])),
_ => unreachable!(),
},
_ => unreachable!(),
}
}
e => e,
},
}
}
}

View file

@ -163,245 +163,245 @@ mod tests {
use super::*;
use zokrates_field::Bn128Field;
#[test]
fn ssa_array() {
let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
// #[test]
// fn ssa_array() {
// let a0 = ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 3);
let e = FieldElementExpression::Number(Bn128Field::from(42)).into();
// let e = FieldElementExpression::Number(Bn128Field::from(42)).into();
let index = FieldElementExpression::Number(Bn128Field::from(1));
// let index = FieldElementExpression::Number(Bn128Field::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![Access::Select(index)],
e,
&mut HashSet::new(),
);
// let a1 = Unroller::choose_many(
// a0.clone().into(),
// vec![Access::Select(index)],
// e,
// &mut HashSet::new(),
// );
// a[1] = 42
// -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]]
// // a[1] = 42
// // -> a = [0 == 1 ? 42 : a[0], 1 == 1 ? 42 : a[1], 2 == 1 ? 42 : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(42)),
FieldElementExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(2))
)
)
.into()
])
.annotate(Type::FieldElement, 3)
.into()
);
// assert_eq!(
// a1,
// ArrayExpressionInner::Value(vec![
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(0)),
// box FieldElementExpression::Number(Bn128Field::from(1))
// ),
// FieldElementExpression::Number(Bn128Field::from(42)),
// FieldElementExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(0))
// )
// )
// .into(),
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(1)),
// box FieldElementExpression::Number(Bn128Field::from(1))
// ),
// FieldElementExpression::Number(Bn128Field::from(42)),
// FieldElementExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(1))
// )
// )
// .into(),
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(2)),
// box FieldElementExpression::Number(Bn128Field::from(1))
// ),
// FieldElementExpression::Number(Bn128Field::from(42)),
// FieldElementExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(2))
// )
// )
// .into()
// ])
// .annotate(Type::FieldElement, 3)
// .into()
// );
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 3), 3);
// let a0 = ArrayExpressionInner::Identifier("a".into())
// .annotate(Type::array(Type::FieldElement, 3), 3);
let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
// let e = ArrayExpressionInner::Identifier("b".into()).annotate(Type::FieldElement, 3);
let index = FieldElementExpression::Number(Bn128Field::from(1));
// let index = FieldElementExpression::Number(Bn128Field::from(1));
let a1 = Unroller::choose_many(
a0.clone().into(),
vec![Access::Select(index)],
e.clone().into(),
&mut HashSet::new(),
);
// let a1 = Unroller::choose_many(
// a0.clone().into(),
// vec![Access::Select(index)],
// e.clone().into(),
// &mut HashSet::new(),
// );
// a[0] = b
// -> a = [0 == 1 ? b : a[0], 1 == 1 ? b : a[1], 2 == 1 ? b : a[2]]
// // a[0] = b
// // -> a = [0 == 1 ? b : a[0], 1 == 1 ? b : a[1], 2 == 1 ? b : a[2]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(1))
),
e.clone(),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(2))
)
)
.into()
])
.annotate(Type::array(Type::FieldElement, 3), 3)
.into()
);
// assert_eq!(
// a1,
// ArrayExpressionInner::Value(vec![
// ArrayExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(0)),
// box FieldElementExpression::Number(Bn128Field::from(1))
// ),
// e.clone(),
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(0))
// )
// )
// .into(),
// ArrayExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(1)),
// box FieldElementExpression::Number(Bn128Field::from(1))
// ),
// e.clone(),
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(1))
// )
// )
// .into(),
// ArrayExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(2)),
// box FieldElementExpression::Number(Bn128Field::from(1))
// ),
// e.clone(),
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(2))
// )
// )
// .into()
// ])
// .annotate(Type::array(Type::FieldElement, 3), 3)
// .into()
// );
let a0 = ArrayExpressionInner::Identifier("a".into())
.annotate(Type::array(Type::FieldElement, 2), 2);
// let a0 = ArrayExpressionInner::Identifier("a".into())
// .annotate(Type::array(Type::FieldElement, 2), 2);
let e = FieldElementExpression::Number(Bn128Field::from(42));
// let e = FieldElementExpression::Number(Bn128Field::from(42));
let indices = vec![
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
];
// let indices = vec![
// Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
// Access::Select(FieldElementExpression::Number(Bn128Field::from(0))),
// ];
let a1 = Unroller::choose_many(
a0.clone().into(),
indices,
e.clone().into(),
&mut HashSet::new(),
);
// let a1 = Unroller::choose_many(
// a0.clone().into(),
// indices,
// e.clone().into(),
// &mut HashSet::new(),
// );
// a[0][0] = 42
// -> a = [0 == 0 ? [0 == 0 ? 42 : a[0][0], 1 == 0 ? 42 : a[0][1]] : a[0], 1 == 0 ? [0 == 0 ? 42 : a[1][0], 1 == 0 ? 42 : a[1][1]] : a[1]]
// // a[0][0] = 42
// // -> a = [0 == 0 ? [0 == 0 ? 42 : a[0][0], 1 == 0 ? 42 : a[0][1]] : a[0], 1 == 0 ? [0 == 0 ? 42 : a[1][0], 1 == 0 ? 42 : a[1][1]] : a[1]]
assert_eq!(
a1,
ArrayExpressionInner::Value(vec![
ArrayExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
ArrayExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
ArrayExpressionInner::Value(vec![
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(0))
)
)
.into(),
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(0))
),
e.clone(),
FieldElementExpression::select(
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into()
])
.annotate(Type::FieldElement, 2),
ArrayExpression::select(
a0.clone(),
FieldElementExpression::Number(Bn128Field::from(1))
)
)
.into(),
])
.annotate(Type::array(Type::FieldElement, 2), 2)
.into()
);
}
// assert_eq!(
// a1,
// ArrayExpressionInner::Value(vec![
// ArrayExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(0)),
// box FieldElementExpression::Number(Bn128Field::from(0))
// ),
// ArrayExpressionInner::Value(vec![
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(0)),
// box FieldElementExpression::Number(Bn128Field::from(0))
// ),
// e.clone(),
// FieldElementExpression::select(
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(0))
// ),
// FieldElementExpression::Number(Bn128Field::from(0))
// )
// )
// .into(),
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(1)),
// box FieldElementExpression::Number(Bn128Field::from(0))
// ),
// e.clone(),
// FieldElementExpression::select(
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(0))
// ),
// FieldElementExpression::Number(Bn128Field::from(1))
// )
// )
// .into()
// ])
// .annotate(Type::FieldElement, 2),
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(0))
// )
// )
// .into(),
// ArrayExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(1)),
// box FieldElementExpression::Number(Bn128Field::from(0))
// ),
// ArrayExpressionInner::Value(vec![
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(0)),
// box FieldElementExpression::Number(Bn128Field::from(0))
// ),
// e.clone(),
// FieldElementExpression::select(
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(1))
// ),
// FieldElementExpression::Number(Bn128Field::from(0))
// )
// )
// .into(),
// FieldElementExpression::if_else(
// BooleanExpression::FieldEq(
// box FieldElementExpression::Number(Bn128Field::from(1)),
// box FieldElementExpression::Number(Bn128Field::from(0))
// ),
// e.clone(),
// FieldElementExpression::select(
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(1))
// ),
// FieldElementExpression::Number(Bn128Field::from(1))
// )
// )
// .into()
// ])
// .annotate(Type::FieldElement, 2),
// ArrayExpression::select(
// a0.clone(),
// FieldElementExpression::Number(Bn128Field::from(1))
// )
// )
// .into(),
// ])
// .annotate(Type::array(Type::FieldElement, 2), 2)
// .into()
// );
// }
#[cfg(test)]
mod statement {
@ -658,7 +658,7 @@ mod tests {
);
let s: TypedStatement<Bn128Field> = TypedStatement::MultipleDefinition(
vec![Variable::field_element("a")],
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
FunctionKey::with_id("foo").signature(
Signature::new()
@ -672,7 +672,7 @@ mod tests {
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(1))],
vec![Variable::field_element(Identifier::from("a").version(1)).into()],
TypedExpressionList::FunctionCall(
FunctionKey::with_id("foo").signature(
Signature::new()

View file

@ -1,208 +0,0 @@
//! Module containing removal of variable access to complex types
//!
//! For example:
//! ```zokrates
//! a[index]
//! ```
//!
//! Would become
//! ```zokrates
//! if(index == 0, a[0], if(index == 1, a[1], ...))
//! ```
use typed_absy::{folder::*, *};
use zokrates_field::Field;
pub struct VariableAccessRemover<'ast, T: Field> {
statements: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, T: Field> VariableAccessRemover<'ast, T> {
fn new() -> Self {
Self { statements: vec![] }
}
pub fn apply(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
Self::new().fold_program(p)
}
fn select<U: Select<'ast, T> + IfElse<'ast, T>>(
&mut self,
a: ArrayExpression<'ast, T>,
i: FieldElementExpression<'ast, T>,
) -> U {
match i {
FieldElementExpression::Number(i) => U::select(a, FieldElementExpression::Number(i)),
i => {
let size = match a.get_type().clone() {
Type::Array(array_ty) => array_ty.size,
_ => unreachable!(),
};
self.statements.push(TypedStatement::Assertion(
(0..size)
.map(|index| {
BooleanExpression::FieldEq(
box i.clone(),
box FieldElementExpression::Number(index.into()).into(),
)
})
.fold(None, |acc, e| match acc {
Some(acc) => Some(BooleanExpression::Or(box acc, box e)),
None => Some(e),
})
.unwrap()
.into(),
));
(0..size)
.map(|i| U::select(a.clone(), FieldElementExpression::Number(i.into())))
.enumerate()
.rev()
.fold(None, |acc, (index, res)| match acc {
Some(acc) => Some(U::if_else(
BooleanExpression::FieldEq(
box i.clone(),
box FieldElementExpression::Number(index.into()),
),
res,
acc,
)),
None => Some(res),
})
.unwrap()
}
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for VariableAccessRemover<'ast, T> {
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Select(box a, box i) => self.select(a, i),
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Select(box a, box i) => self.select(a, i),
e => fold_boolean_expression(self, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Select(box a, box i) => {
self.select::<ArrayExpression<'ast, T>>(a, i).into_inner()
}
e => fold_array_expression_inner(self, ty, size, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Select(box a, box i) => {
self.select::<StructExpression<'ast, T>>(a, i).into_inner()
}
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Select(box a, box i) => {
self.select::<UExpression<'ast, T>>(a, i).into_inner()
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let s = fold_statement(self, s);
self.statements.drain(..).chain(s).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
#[test]
fn select() {
// b = a[i]
// ->
// i <= 1 == true
// b = if i == 0 then a[0] else if i == 1 then a[1] else 0
let access: TypedStatement<Bn128Field> = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2),
box FieldElementExpression::Identifier("i".into()),
)
.into(),
);
assert_eq!(
VariableAccessRemover::new().fold_statement(access),
vec![
TypedStatement::Assertion(
BooleanExpression::Or(
box BooleanExpression::FieldEq(
box FieldElementExpression::Identifier("i".into()),
box FieldElementExpression::Number(0.into())
),
box BooleanExpression::FieldEq(
box FieldElementExpression::Identifier("i".into()),
box FieldElementExpression::Number(1.into())
)
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Identifier("i".into()),
box FieldElementExpression::Number(0.into())
),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(0.into()),
),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2),
box FieldElementExpression::Number(1.into()),
)
)
.into()
)
]
);
}
}

View file

@ -13,11 +13,11 @@
use typed_absy::{folder::*, *};
use zokrates_field::Field;
pub struct VariableIndexRemover<'ast, T: Field> {
pub struct VariableReadRemover<'ast, T: Field> {
statements: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, T: Field> VariableIndexRemover<'ast, T> {
impl<'ast, T: Field> VariableReadRemover<'ast, T> {
fn new() -> Self {
Self { statements: vec![] }
}
@ -76,7 +76,7 @@ impl<'ast, T: Field> VariableIndexRemover<'ast, T> {
}
}
impl<'ast, T: Field> Folder<'ast, T> for VariableIndexRemover<'ast, T> {
impl<'ast, T: Field> Folder<'ast, T> for VariableReadRemover<'ast, T> {
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
@ -167,7 +167,7 @@ mod tests {
);
assert_eq!(
VariableIndexRemover::new().fold_statement(access),
VariableReadRemover::new().fold_statement(access),
vec![
TypedStatement::Assertion(
BooleanExpression::Or(

View file

@ -0,0 +1,403 @@
//! Module containing SSA reduction, including for-loop unrolling
//!
//! @file unroll.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
use crate::typed_absy::folder::*;
use crate::typed_absy::types::{MemberId, Type};
use crate::typed_absy::*;
use std::collections::HashMap;
use std::collections::HashSet;
use typed_absy::identifier::CoreIdentifier;
use zokrates_field::Field;
pub struct VariableWriteRemover<'ast> {
// version index for any variable name
substitution: HashMap<CoreIdentifier<'ast>, usize>,
}
impl<'ast> VariableWriteRemover<'ast> {
fn new() -> Self {
VariableWriteRemover {
substitution: HashMap::new(),
}
}
pub fn apply<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
let mut remover = VariableWriteRemover::new();
remover.fold_program(p)
}
fn choose_many<T: Field>(
base: TypedExpression<'ast, T>,
indices: Vec<Access<'ast, T>>,
new_expression: TypedExpression<'ast, T>,
statements: &mut HashSet<TypedStatement<'ast, T>>,
) -> TypedExpression<'ast, T> {
let mut indices = indices;
match indices.len() {
0 => new_expression,
_ => match base {
TypedExpression::Array(base) => {
let inner_ty = base.inner_type();
let size = base.size();
let head = indices.remove(0);
let tail = indices;
match head {
Access::Select(head) => {
statements.insert(TypedStatement::Assertion(
BooleanExpression::Lt(
box head.clone(),
box FieldElementExpression::Number(T::from(size)),
)
.into(),
));
ArrayExpressionInner::Value(
(0..size)
.map(|i| match inner_ty {
Type::Array(..) => ArrayExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Array(e) => e,
e => unreachable!(
"the interior was expected to be an array, was {}",
e.get_type()
),
},
ArrayExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Struct(..) => StructExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
StructExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Struct(e) => e,
e => unreachable!(
"the interior was expected to be a struct, was {}",
e.get_type()
),
},
StructExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::FieldElement => FieldElementExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::FieldElement(e) => e,
e => unreachable!(
"the interior was expected to be a field, was {}",
e.get_type()
),
},
FieldElementExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Boolean => BooleanExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Boolean(e) => e,
e => unreachable!(
"the interior was expected to be a boolean, was {}",
e.get_type()
),
},
BooleanExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
Type::Uint(..) => UExpression::if_else(
BooleanExpression::FieldEq(
box FieldElementExpression::Number(T::from(i)),
box head.clone(),
),
match Self::choose_many(
UExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
) {
TypedExpression::Uint(e) => e,
e => unreachable!(
"the interior was expected to be a uint, was {}",
e.get_type()
),
},
UExpression::select(
base.clone(),
FieldElementExpression::Number(T::from(i)),
),
)
.into(),
})
.collect(),
)
.annotate(inner_ty.clone(), size)
.into()
}
Access::Member(..) => unreachable!("can't get a member from an array"),
}
}
TypedExpression::Struct(base) => {
let members = match base.get_type() {
Type::Struct(members) => members.clone(),
_ => unreachable!(),
};
let head = indices.remove(0);
let tail = indices;
match head {
Access::Member(head) => StructExpressionInner::Value(
members
.clone()
.into_iter()
.map(|member| match *member.ty {
Type::FieldElement => {
if member.id == head {
Self::choose_many(
FieldElementExpression::member(
base.clone(),
head.clone(),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
FieldElementExpression::member(
base.clone(),
member.id.clone(),
)
.into()
}
}
Type::Uint(..) => {
if member.id == head {
Self::choose_many(
UExpression::member(base.clone(), head.clone())
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
UExpression::member(base.clone(), member.id.clone())
.into()
}
}
Type::Boolean => {
if member.id == head {
Self::choose_many(
BooleanExpression::member(
base.clone(),
head.clone(),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
BooleanExpression::member(
base.clone(),
member.id.clone(),
)
.into()
}
}
Type::Array(..) => {
if member.id == head {
Self::choose_many(
ArrayExpression::member(base.clone(), head.clone())
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
ArrayExpression::member(base.clone(), member.id.clone())
.into()
}
}
Type::Struct(..) => {
if member.id == head {
Self::choose_many(
StructExpression::member(
base.clone(),
head.clone(),
)
.into(),
tail.clone(),
new_expression.clone(),
statements,
)
} else {
StructExpression::member(
base.clone(),
member.id.clone(),
)
.into()
}
}
})
.collect(),
)
.annotate(members)
.into(),
Access::Select(..) => unreachable!("can't get a element from a struct"),
}
}
e => unreachable!("can't make an access on a {}", e.get_type()),
},
}
}
}
#[derive(Clone, Debug)]
enum Access<'ast, T: Field> {
Select(FieldElementExpression<'ast, T>),
Member(MemberId),
}
/// Turn an assignee into its representation as a base variable and a list accesses
/// a[2][3][4] -> (a, [2, 3, 4])
fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec<Access<'ast, T>>) {
match a {
TypedAssignee::Identifier(v) => (v, vec![]),
TypedAssignee::Select(box array, box index) => {
let (v, mut indices) = linear(array);
indices.push(Access::Select(index));
(v, indices)
}
TypedAssignee::Member(box s, m) => {
let (v, mut indices) = linear(s);
indices.push(Access::Member(m));
(v, indices)
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover<'ast> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr);
let (variable, indices) = linear(assignee);
let base = match variable.get_type() {
Type::FieldElement => {
FieldElementExpression::Identifier(variable.id.clone().into()).into()
}
Type::Boolean => {
BooleanExpression::Identifier(variable.id.clone().into()).into()
}
Type::Uint(bitwidth) => {
UExpressionInner::Identifier(variable.id.clone().into())
.annotate(bitwidth)
.into()
}
Type::Array(array_type) => {
ArrayExpressionInner::Identifier(variable.id.clone().into())
.annotate(*array_type.ty, array_type.size)
.into()
}
Type::Struct(members) => {
StructExpressionInner::Identifier(variable.id.clone().into())
.annotate(members)
.into()
}
};
let base = self.fold_expression(base);
let indices = indices
.into_iter()
.map(|a| match a {
Access::Select(i) => Access::Select(self.fold_field_expression(i)),
a => a,
})
.collect();
let mut range_checks = HashSet::new();
let e = Self::choose_many(base, indices, expr, &mut range_checks);
range_checks
.into_iter()
.chain(std::iter::once(TypedStatement::Definition(
TypedAssignee::Identifier(variable),
e,
)))
.collect()
}
s => fold_statement(self, s),
}
}
}

View file

@ -32,26 +32,26 @@ pub struct Identifier<'ast> {
impl<'ast> fmt::Display for Identifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.stack.len() == 0 && self.version == 0 {
write!(f, "{}", self.id)
} else {
write!(
f,
"{}_{}_{}",
self.stack
.iter()
.map(|(name, key_hash, count)| format!(
"{}_{}_{}",
name.display(),
key_hash,
count
))
.collect::<Vec<_>>()
.join("_"),
self.id,
self.version
)
}
//if self.stack.len() == 0 && self.version == 0 {
write!(f, "{}", self.id)
// } else {
// write!(
// f,
// "{}_{}_{}",
// self.stack
// .iter()
// .map(|(name, key_hash, count)| format!(
// "{}_{}_{}",
// name.display(),
// key_hash,
// count
// ))
// .collect::<Vec<_>>()
// .join("_"),
// self.id,
// self.version
// )
// }
}
}

View file

@ -773,6 +773,10 @@ impl<'ast, T> StructExpression<'ast, T> {
&self.inner
}
pub fn as_inner_mut(&mut self) -> &mut StructExpressionInner<'ast, T> {
&mut self.inner
}
pub fn into_inner(self) -> StructExpressionInner<'ast, T> {
self.inner
}