add loop unrolling, post flattening constant propagation
This commit is contained in:
parent
a21cbe7c30
commit
8b2319b0cb
16 changed files with 692 additions and 88 deletions
|
@ -2,4 +2,8 @@ def main() -> (field):
|
|||
field a = 1 + 2 + 3
|
||||
field b = if 1 < a then 3 else a + 3 fi
|
||||
field c = if b + a == 2 then 1 else b fi
|
||||
return c
|
||||
for field e in 0..2 do
|
||||
field g = 4
|
||||
c = c + g
|
||||
endfor
|
||||
return c * a
|
7
zokrates_cli/examples/propagate_call.code
Normal file
7
zokrates_cli/examples/propagate_call.code
Normal file
|
@ -0,0 +1,7 @@
|
|||
def foo(field a, field b) -> (field, field):
|
||||
a == b + 2
|
||||
return a, b
|
||||
|
||||
def main() -> (field):
|
||||
a, b = foo(1, 1)
|
||||
return a + b
|
|
@ -1,8 +1,10 @@
|
|||
// we can compare numbers up to 2^(pbits - 2) - 1, ie any number which fits in (pbits - 2) bits
|
||||
// lt should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
|
||||
// /!\ should be called with a = 0
|
||||
|
||||
def main(field a) -> (field):
|
||||
field pbits = 254
|
||||
// maxvalue = 2**252 - 1
|
||||
field maxvalue = 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
|
||||
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
return if 0 < (maxvalue + 1) then 1 else 0 fi
|
|
@ -1,5 +1,7 @@
|
|||
// as p - 1 is greater than p/2, comparing to it should fail
|
||||
// /!\ should be called with a = 0
|
||||
|
||||
def main(field a) -> (field):
|
||||
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495617
|
||||
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + a
|
||||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
return if 0 < p - 1 then 1 else 0 fi
|
|
@ -14,7 +14,7 @@ use semantics::{self, Checker};
|
|||
use optimizer::{Optimizer};
|
||||
use flatten::Flattener;
|
||||
use std::io::{self};
|
||||
use propagation::Propagate;
|
||||
use static_analysis::Analyse;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CompileError<T: Field> {
|
||||
|
@ -81,13 +81,16 @@ pub fn compile_aux<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(re
|
|||
// check semantics
|
||||
let typed_ast = Checker::new().check_program(program_ast)?;
|
||||
|
||||
// optimize (constant propagation)
|
||||
let typed_ast = typed_ast.propagate();
|
||||
// analyse (unroll and constant propagation)
|
||||
let typed_ast = typed_ast.analyse();
|
||||
|
||||
// flatten input program
|
||||
let program_flattened =
|
||||
Flattener::new(T::get_required_bits()).flatten_program(typed_ast);
|
||||
|
||||
// analyse (constant propagation after call resolution)
|
||||
let program_flattened = program_flattened.analyse();
|
||||
|
||||
Ok(program_flattened)
|
||||
}
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ impl<T: Field> FlatFunction<T> {
|
|||
}
|
||||
},
|
||||
FlatStatement::Directive(ref d) => {
|
||||
let input_values: Vec<T> = d.inputs.iter().map(|i| witness.get(i).unwrap().clone()).collect();
|
||||
let input_values: Vec<T> = d.inputs.iter().map(|i| i.solve(&mut witness)).collect();
|
||||
match d.helper.execute(&input_values) {
|
||||
Ok(res) => {
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
|
@ -188,7 +188,7 @@ pub enum FlatStatement<T: Field> {
|
|||
Return(FlatExpressionList<T>),
|
||||
Condition(FlatExpression<T>, FlatExpression<T>),
|
||||
Definition(FlatVariable, FlatExpression<T>),
|
||||
Directive(DirectiveStatement)
|
||||
Directive(DirectiveStatement<T>)
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for FlatStatement<T> {
|
||||
|
@ -239,7 +239,7 @@ impl<T: Field> FlatStatement<T> {
|
|||
},
|
||||
FlatStatement::Directive(d) => {
|
||||
let outputs = d.outputs.into_iter().map(|o| substitution.get(&o).unwrap_or(o)).collect();
|
||||
let inputs = d.inputs.into_iter().map(|i| substitution.get(&i).unwrap()).collect();
|
||||
let inputs = d.inputs.into_iter().map(|i| i.apply_substitution(substitution, should_fallback)).collect();
|
||||
|
||||
FlatStatement::Directive(
|
||||
DirectiveStatement {
|
||||
|
@ -295,7 +295,7 @@ impl<T: Field> FlatExpression<T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn solve(&self, inputs: &mut BTreeMap<FlatVariable, T>) -> T {
|
||||
pub fn solve(&self, inputs: &mut BTreeMap<FlatVariable, T>) -> T {
|
||||
match *self {
|
||||
FlatExpression::Number(ref x) => x.clone(),
|
||||
FlatExpression::Identifier(ref var) => {
|
||||
|
|
|
@ -336,11 +336,11 @@ impl Flattener {
|
|||
|
||||
statements_flattened.push(FlatStatement::Definition(name_x, x));
|
||||
statements_flattened.push(
|
||||
FlatStatement::Directive(DirectiveStatement {
|
||||
outputs: vec![name_y, name_m],
|
||||
inputs: vec![name_x],
|
||||
helper: Helper::Rust(RustHelper::ConditionEq)
|
||||
})
|
||||
FlatStatement::Directive(DirectiveStatement::new(
|
||||
vec![name_y, name_m],
|
||||
Helper::Rust(RustHelper::ConditionEq),
|
||||
vec![name_x],
|
||||
))
|
||||
);
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
|
@ -435,13 +435,13 @@ impl Flattener {
|
|||
replacement_map.insert(o, new_o);
|
||||
new_o
|
||||
}).collect();
|
||||
let new_inputs = d.inputs.iter().map(|i| replacement_map.get(&i).unwrap()).collect();
|
||||
let new_inputs = d.inputs.into_iter().map(|i| i.apply_direct_substitution(&replacement_map)).collect();
|
||||
statements_flattened.push(
|
||||
FlatStatement::Directive(
|
||||
DirectiveStatement {
|
||||
outputs: new_outputs,
|
||||
helper: d.helper.clone(),
|
||||
inputs: new_inputs,
|
||||
helper: d.helper.clone()
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
|
@ -7,30 +7,30 @@ pub use self::libsnark_gadget::LibsnarkGadgetHelper;
|
|||
pub use self::rust::RustHelper;
|
||||
use std::fmt;
|
||||
use field::{Field};
|
||||
use flat_absy::FlatVariable;
|
||||
use flat_absy::{FlatExpression, FlatVariable};
|
||||
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
|
||||
pub struct DirectiveStatement {
|
||||
pub inputs: Vec<FlatVariable>,
|
||||
pub struct DirectiveStatement<T: Field> {
|
||||
pub inputs: Vec<FlatExpression<T>>,
|
||||
pub outputs: Vec<FlatVariable>,
|
||||
pub helper: Helper
|
||||
}
|
||||
|
||||
impl DirectiveStatement {
|
||||
impl<T: Field> DirectiveStatement<T> {
|
||||
pub fn new(outputs: Vec<FlatVariable>, helper: Helper, inputs: Vec<FlatVariable>) -> Self {
|
||||
let (in_len, out_len) = helper.get_signature();
|
||||
assert_eq!(in_len, inputs.len());
|
||||
assert_eq!(out_len, outputs.len());
|
||||
DirectiveStatement {
|
||||
helper,
|
||||
inputs,
|
||||
inputs: inputs.into_iter().map(|i| FlatExpression::Identifier(i)).collect(),
|
||||
outputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for DirectiveStatement {
|
||||
impl<T: Field> fmt::Display for DirectiveStatement<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "# {} = {}({})",
|
||||
self.outputs.iter().map(|o| o.to_string()).collect::<Vec<String>>().join(", "),
|
||||
|
|
|
@ -24,7 +24,7 @@ mod standard;
|
|||
mod helpers;
|
||||
mod types;
|
||||
mod typed_absy;
|
||||
mod propagation;
|
||||
mod static_analysis;
|
||||
|
||||
pub mod absy;
|
||||
pub mod flat_absy;
|
||||
|
|
|
@ -117,7 +117,7 @@ pub fn parse_statement<T: Field, R: BufRead>(
|
|||
}
|
||||
},
|
||||
Some(Ok(ref x)) if !x.trim().starts_with("return") => match parse_statement(lines, x, &Position { line: current_line, col: 1 }) {
|
||||
Ok((statement, ..)) => statements.push(statement[0].clone()),
|
||||
Ok((mut statement, ..)) => statements.append(&mut statement),
|
||||
Err(err) => return Err(err),
|
||||
},
|
||||
Some(Err(err)) => panic!("Error while reading Definitions: {}", err),
|
||||
|
|
|
@ -80,8 +80,8 @@ impl FunctionQuery {
|
|||
})
|
||||
}
|
||||
|
||||
fn match_funcs(&self, funcs: Vec<FunctionDeclaration>) -> Vec<FunctionDeclaration> {
|
||||
funcs.into_iter().filter(|func| self.match_func(func)).collect()
|
||||
fn match_funcs(&self, funcs: &HashSet<FunctionDeclaration>) -> Vec<FunctionDeclaration> {
|
||||
funcs.iter().filter(|func| self.match_func(func)).cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -538,7 +538,7 @@ impl Checker {
|
|||
}
|
||||
|
||||
fn find_candidates(&self, query: &FunctionQuery) -> Vec<FunctionDeclaration> {
|
||||
query.match_funcs(self.functions.clone().into_iter().collect())
|
||||
query.match_funcs(&self.functions)
|
||||
}
|
||||
|
||||
fn enter_scope(&mut self) -> () {
|
||||
|
|
174
zokrates_core/src/static_analysis/flat_propagation.rs
Normal file
174
zokrates_core/src/static_analysis/flat_propagation.rs
Normal file
|
@ -0,0 +1,174 @@
|
|||
//! Module containing constant propagation
|
||||
//!
|
||||
//! @file propagation.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
use helpers::DirectiveStatement;
|
||||
use flat_absy::*;
|
||||
use std::collections::HashMap;
|
||||
use field::Field;
|
||||
|
||||
pub trait Propagate<T: Field> {
|
||||
fn propagate(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait PropagateWithContext<T: Field> {
|
||||
fn propagate(self, constants: &mut HashMap<FlatVariable, T>) -> Self;
|
||||
}
|
||||
|
||||
impl<T: Field> PropagateWithContext<T> for FlatExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<FlatVariable, T>) -> FlatExpression<T> {
|
||||
match self {
|
||||
FlatExpression::Identifier(id) => {
|
||||
match constants.get(&id) {
|
||||
Some(c) => FlatExpression::Number(c.clone()),
|
||||
None => FlatExpression::Identifier(id)
|
||||
}
|
||||
},
|
||||
FlatExpression::Add(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 + n2),
|
||||
(e1, e2) => FlatExpression::Add(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FlatExpression::Sub(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 - n2),
|
||||
(e1, e2) => FlatExpression::Sub(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FlatExpression::Mult(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 * n2),
|
||||
(e1, e2) => FlatExpression::Mult(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FlatExpression::Div(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 / n2),
|
||||
(e1, e2) => FlatExpression::Div(box e1, box e2),
|
||||
}
|
||||
},
|
||||
_ => self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlatStatement<T> {
|
||||
fn propagate(self, constants: &mut HashMap<FlatVariable, T>) -> Option<FlatStatement<T>> {
|
||||
match self {
|
||||
FlatStatement::Return(list) => Some(FlatStatement::Return(FlatExpressionList {
|
||||
expressions: list.expressions.into_iter().map(|e| e.propagate(constants)).collect()
|
||||
})),
|
||||
FlatStatement::Definition(var, expr) => {
|
||||
match expr.propagate(constants) {
|
||||
FlatExpression::Number(n) => {
|
||||
constants.insert(var, n);
|
||||
None
|
||||
},
|
||||
e => {
|
||||
Some(FlatStatement::Definition(var, e))
|
||||
}
|
||||
}
|
||||
},
|
||||
FlatStatement::Condition(e1, e2) => {
|
||||
// could stop execution here if condition is known to fail...
|
||||
Some(FlatStatement::Condition(e1.propagate(constants), e2.propagate(constants)))
|
||||
},
|
||||
FlatStatement::Directive(d) => {
|
||||
Some(FlatStatement::Directive(
|
||||
DirectiveStatement {
|
||||
inputs: d.inputs.into_iter().map(|i| i.propagate(constants)).collect(),
|
||||
..d
|
||||
}
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Propagate<T> for FlatFunction<T> {
|
||||
fn propagate(self) -> FlatFunction<T> {
|
||||
|
||||
let mut constants = HashMap::new();
|
||||
|
||||
FlatFunction {
|
||||
statements: self.statements.into_iter().filter_map(|s| s.propagate(&mut constants)).collect(),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlatProg<T> {
|
||||
pub fn propagate(self) -> FlatProg<T> {
|
||||
|
||||
let mut functions = vec![];
|
||||
|
||||
for f in self.functions {
|
||||
let fun = f.propagate();
|
||||
functions.push(fun);
|
||||
}
|
||||
|
||||
FlatProg {
|
||||
functions,
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use field::FieldPrime;
|
||||
|
||||
#[cfg(test)]
|
||||
mod expression {
|
||||
use super::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod field {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn add() {
|
||||
let e = FlatExpression::Add(
|
||||
box FlatExpression::Number(FieldPrime::from(2)),
|
||||
box FlatExpression::Number(FieldPrime::from(3))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(5)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub() {
|
||||
let e = FlatExpression::Sub(
|
||||
box FlatExpression::Number(FieldPrime::from(3)),
|
||||
box FlatExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mult() {
|
||||
let e = FlatExpression::Mult(
|
||||
box FlatExpression::Number(FieldPrime::from(3)),
|
||||
box FlatExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(6)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn div() {
|
||||
let e = FlatExpression::Div(
|
||||
box FlatExpression::Number(FieldPrime::from(6)),
|
||||
box FlatExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(3)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
30
zokrates_core/src/static_analysis/mod.rs
Normal file
30
zokrates_core/src/static_analysis/mod.rs
Normal file
|
@ -0,0 +1,30 @@
|
|||
//! Module containing static analysis
|
||||
//!
|
||||
//! @file mod.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
mod propagation;
|
||||
mod unroll;
|
||||
mod flat_propagation;
|
||||
|
||||
use flat_absy::FlatProg;
|
||||
use field::Field;
|
||||
use typed_absy::TypedProg;
|
||||
use self::unroll::Unroll;
|
||||
|
||||
pub trait Analyse {
|
||||
fn analyse(self) -> Self;
|
||||
}
|
||||
|
||||
impl<T: Field> Analyse for TypedProg<T> {
|
||||
fn analyse(self) -> Self {
|
||||
self.unroll().propagate()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Analyse for FlatProg<T> {
|
||||
fn analyse(self) -> Self {
|
||||
self.propagate()
|
||||
}
|
||||
}
|
|
@ -1,27 +1,33 @@
|
|||
//! Module containing constant propagation
|
||||
//!
|
||||
//! @file propagation.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
use absy::variable::Variable;
|
||||
use std::collections::HashMap;
|
||||
use field::Field;
|
||||
use typed_absy::*;
|
||||
|
||||
pub trait Propagate {
|
||||
fn propagate(self) -> Self;
|
||||
pub trait Propagate<T: Field> {
|
||||
fn propagate(self, functions: &Vec<TypedFunction<T>>) -> Self;
|
||||
}
|
||||
|
||||
pub trait PropagateWithContext<T: Field> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> Self;
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> Self;
|
||||
}
|
||||
|
||||
impl<T: Field> PropagateWithContext<T> for TypedExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> TypedExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> TypedExpression<T> {
|
||||
match self {
|
||||
TypedExpression::FieldElement(e) => e.propagate(constants).into(),
|
||||
TypedExpression::Boolean(e) => e.propagate(constants).into(),
|
||||
TypedExpression::FieldElement(e) => e.propagate(constants, functions).into(),
|
||||
TypedExpression::Boolean(e) => e.propagate(constants, functions).into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> PropagateWithContext<T> for FieldElementExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> FieldElementExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> FieldElementExpression<T> {
|
||||
match self {
|
||||
FieldElementExpression::Identifier(id) => {
|
||||
match constants.get(&Variable::field_element(id.clone())) {
|
||||
|
@ -33,51 +39,89 @@ impl<T: Field> PropagateWithContext<T> for FieldElementExpression<T> {
|
|||
}
|
||||
},
|
||||
FieldElementExpression::Add(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 + n2),
|
||||
(e1, e2) => FieldElementExpression::Add(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FieldElementExpression::Sub(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 - n2),
|
||||
(e1, e2) => FieldElementExpression::Sub(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FieldElementExpression::Mult(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 * n2),
|
||||
(e1, e2) => FieldElementExpression::Mult(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FieldElementExpression::Div(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 / n2),
|
||||
(e1, e2) => FieldElementExpression::Div(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FieldElementExpression::Pow(box e1, box e2) => {
|
||||
match (e1.propagate(constants), e2.propagate(constants)) {
|
||||
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1.pow(n2)),
|
||||
(e1, e2) => FieldElementExpression::Pow(box e1, box e2),
|
||||
}
|
||||
},
|
||||
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
|
||||
let consequence = consequence.propagate(constants);
|
||||
let alternative = alternative.propagate(constants);
|
||||
match condition.propagate(constants) {
|
||||
let consequence = consequence.propagate(constants, functions);
|
||||
let alternative = alternative.propagate(constants, functions);
|
||||
match condition.propagate(constants, functions) {
|
||||
BooleanExpression::Value(true) => consequence,
|
||||
BooleanExpression::Value(false) => alternative,
|
||||
c => FieldElementExpression::IfElse(box c, box consequence, box alternative)
|
||||
}
|
||||
},
|
||||
FieldElementExpression::FunctionCall(id, arguments) => {
|
||||
// We only cover the case where all arguments are constants, therefore the call is guaranteed to be constant
|
||||
// Propagate the arguments, then return Ok if they're constant, Err otherwise
|
||||
|
||||
let arguments = arguments.into_iter().map(|a| a.propagate(constants, functions)).collect();
|
||||
|
||||
// let f = functions[0]; // TODO find functon based on id
|
||||
// match f.execute(arguments) {
|
||||
// Ok(expressions) => expressions,
|
||||
// _ => FieldElementExpression::FunctionCall(id, arguments),
|
||||
// }
|
||||
|
||||
FieldElementExpression::FunctionCall(id, arguments)
|
||||
|
||||
// let each_argument_constant = arguments.into_iter().map(|a| a.propagate(constants, functions)).map(|a| match a {
|
||||
// a @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => Ok(a),
|
||||
// a @ TypedExpression::Boolean(BooleanExpression::Value(..)) => Ok(a),
|
||||
// a => Err(a)
|
||||
// });
|
||||
|
||||
// let all_arguments_constant = each_argument_constant.collect::<Result<Vec<_>, _>>();
|
||||
|
||||
// match all_arguments_constant {
|
||||
// Ok(arguments) => {
|
||||
// // all arguments are constant, we can execute the function now
|
||||
// unimplemented!()
|
||||
// },
|
||||
// Err(_) => {
|
||||
// // not all arguments are constant, keep the function call
|
||||
// let arguments = each_argument_constant.into_iter().map(|a| match a {
|
||||
// Ok(a) => a,
|
||||
// Err(a) => a
|
||||
// }).collect();
|
||||
|
||||
//
|
||||
// }
|
||||
// }
|
||||
}
|
||||
_ => self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> BooleanExpression<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> BooleanExpression<T> {
|
||||
match self {
|
||||
BooleanExpression::Identifier(id) => {
|
||||
match constants.get(&Variable::boolean(id.clone())) {
|
||||
|
@ -89,8 +133,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
|||
}
|
||||
},
|
||||
BooleanExpression::Eq(box e1, box e2) => {
|
||||
let e1 = e1.propagate(constants);
|
||||
let e2 = e2.propagate(constants);
|
||||
let e1 = e1.propagate(constants, functions);
|
||||
let e2 = e2.propagate(constants, functions);
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
|
@ -100,8 +144,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
|||
}
|
||||
}
|
||||
BooleanExpression::Lt(box e1, box e2) => {
|
||||
let e1 = e1.propagate(constants);
|
||||
let e2 = e2.propagate(constants);
|
||||
let e1 = e1.propagate(constants, functions);
|
||||
let e2 = e2.propagate(constants, functions);
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
|
@ -111,8 +155,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
|||
}
|
||||
}
|
||||
BooleanExpression::Le(box e1, box e2) => {
|
||||
let e1 = e1.propagate(constants);
|
||||
let e2 = e2.propagate(constants);
|
||||
let e1 = e1.propagate(constants, functions);
|
||||
let e2 = e2.propagate(constants, functions);
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
|
@ -122,8 +166,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
|||
}
|
||||
}
|
||||
BooleanExpression::Gt(box e1, box e2) => {
|
||||
let e1 = e1.propagate(constants);
|
||||
let e2 = e2.propagate(constants);
|
||||
let e1 = e1.propagate(constants, functions);
|
||||
let e2 = e2.propagate(constants, functions);
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
|
@ -133,8 +177,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
|||
}
|
||||
}
|
||||
BooleanExpression::Ge(box e1, box e2) => {
|
||||
let e1 = e1.propagate(constants);
|
||||
let e2 = e2.propagate(constants);
|
||||
let e1 = e1.propagate(constants, functions);
|
||||
let e2 = e2.propagate(constants, functions);
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
|
@ -148,12 +192,22 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> TypedStatement<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> Option<TypedStatement<T>> {
|
||||
impl<T: Field> TypedExpressionList<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> TypedExpressionList<T> {
|
||||
match self {
|
||||
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| e.propagate(constants)).collect())),
|
||||
TypedExpressionList::FunctionCall(id, arguments, types) => {
|
||||
TypedExpressionList::FunctionCall(id, arguments.into_iter().map(|e| e.propagate(constants, functions)).collect(), types)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> TypedStatement<T> {
|
||||
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> Option<TypedStatement<T>> {
|
||||
match self {
|
||||
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| e.propagate(constants, functions)).collect())),
|
||||
TypedStatement::Definition(var, expr) => {
|
||||
match expr.propagate(constants) {
|
||||
match expr.propagate(constants, functions) {
|
||||
e @ TypedExpression::Boolean(BooleanExpression::Value(..)) | e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
|
||||
constants.insert(var, e);
|
||||
None
|
||||
|
@ -165,30 +219,42 @@ impl<T: Field> TypedStatement<T> {
|
|||
},
|
||||
TypedStatement::Condition(e1, e2) => {
|
||||
// could stop execution here if condition is known to fail...
|
||||
Some(TypedStatement::Condition(e1.propagate(constants), e2.propagate(constants)))
|
||||
Some(TypedStatement::Condition(e1.propagate(constants, functions), e2.propagate(constants, functions)))
|
||||
},
|
||||
TypedStatement::For(v, from, to, stats) => Some(TypedStatement::For(v, from, to, stats.into_iter().filter_map(|s| s.propagate(constants)).collect())),
|
||||
TypedStatement::For(..) => panic!("no for expected"),
|
||||
TypedStatement::MultipleDefinition(variables, expression_list) => {
|
||||
let expression_list = expression_list.propagate(constants, functions);
|
||||
Some(TypedStatement::MultipleDefinition(variables, expression_list))
|
||||
}
|
||||
_ => Some(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Propagate for TypedFunction<T> {
|
||||
fn propagate(self) -> TypedFunction<T> {
|
||||
impl<T: Field> Propagate<T> for TypedFunction<T> {
|
||||
fn propagate(self, functions: &Vec<TypedFunction<T>>) -> TypedFunction<T> {
|
||||
|
||||
let mut constants = HashMap::new();
|
||||
|
||||
TypedFunction {
|
||||
statements: self.statements.into_iter().filter_map(|s| s.propagate(&mut constants)).collect(),
|
||||
statements: self.statements.into_iter().filter_map(|s| s.propagate(&mut constants, functions)).collect(),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Propagate for TypedProg<T> {
|
||||
fn propagate(self) -> TypedProg<T> {
|
||||
impl<T: Field> TypedProg<T> {
|
||||
pub fn propagate(self) -> TypedProg<T> {
|
||||
|
||||
let mut functions = vec![];
|
||||
|
||||
for f in self.functions {
|
||||
let fun = f.propagate(&mut functions);
|
||||
functions.push(fun);
|
||||
}
|
||||
|
||||
TypedProg {
|
||||
functions: self.functions.into_iter().map(|f| f.propagate()).collect(),
|
||||
functions,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
@ -214,7 +280,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(3))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(5)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(5)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -224,7 +290,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(1)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -234,7 +300,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(6)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(6)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -244,7 +310,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(3)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(3)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -254,7 +320,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(3))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(8)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(8)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -265,7 +331,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(3))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(2)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -276,7 +342,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(3))
|
||||
);
|
||||
|
||||
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(3)));
|
||||
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(3)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -296,8 +362,8 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -312,8 +378,8 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -328,8 +394,8 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(2))
|
||||
);
|
||||
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -344,8 +410,8 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(5))
|
||||
);
|
||||
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -360,8 +426,8 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(5))
|
||||
);
|
||||
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
|
||||
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
|
||||
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
|
||||
}
|
||||
}
|
||||
}
|
316
zokrates_core/src/static_analysis/unroll.rs
Normal file
316
zokrates_core/src/static_analysis/unroll.rs
Normal file
|
@ -0,0 +1,316 @@
|
|||
//! Module containing SSA reduction, including for-loop unrolling
|
||||
//!
|
||||
//! @file unroll.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
use absy::parameter::Parameter;
|
||||
use absy::variable::Variable;
|
||||
use std::collections::HashMap;
|
||||
use field::Field;
|
||||
use typed_absy::*;
|
||||
|
||||
pub trait Unroll {
|
||||
fn unroll(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait UnrollWithContext<T: Field> {
|
||||
fn unroll(self, substitution: &mut HashMap<String, usize>) -> Self;
|
||||
}
|
||||
|
||||
impl<T: Field> TypedExpression<T> {
|
||||
fn unroll(self, substitution: &HashMap<String, usize>) -> TypedExpression<T> {
|
||||
match self {
|
||||
TypedExpression::FieldElement(e) => e.unroll(substitution).into(),
|
||||
TypedExpression::Boolean(e) => e.unroll(substitution).into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FieldElementExpression<T> {
|
||||
fn unroll(self, substitution: &HashMap<String, usize>) -> FieldElementExpression<T> {
|
||||
match self {
|
||||
FieldElementExpression::Identifier(id) => FieldElementExpression::Identifier(format!("{}_{}", id, substitution.get(&id).unwrap().clone())),
|
||||
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
|
||||
FieldElementExpression::Add(box e1, box e2) => FieldElementExpression::Add(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
FieldElementExpression::Sub(box e1, box e2) => FieldElementExpression::Sub(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
FieldElementExpression::Mult(box e1, box e2) => FieldElementExpression::Mult(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
FieldElementExpression::Div(box e1, box e2) => FieldElementExpression::Div(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
FieldElementExpression::Pow(box e1, box e2) => FieldElementExpression::Div(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
FieldElementExpression::IfElse(box cond, box cons, box alt) => FieldElementExpression::IfElse(box cond.unroll(substitution), box cons.unroll(substitution), box alt.unroll(substitution)),
|
||||
FieldElementExpression::FunctionCall(id, args) => FieldElementExpression::FunctionCall(id, args.into_iter().map(|a| a.unroll(substitution)).collect()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> BooleanExpression<T> {
|
||||
fn unroll(self, substitution: &HashMap<String, usize>) -> BooleanExpression<T> {
|
||||
match self {
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(format!("{}_{}", id, substitution.get(&id).unwrap().clone())),
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Eq(box e1, box e2) => BooleanExpression::Eq(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
BooleanExpression::Lt(box e1, box e2) => BooleanExpression::Lt(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
BooleanExpression::Le(box e1, box e2) => BooleanExpression::Le(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
BooleanExpression::Gt(box e1, box e2) => BooleanExpression::Gt(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
BooleanExpression::Ge(box e1, box e2) => BooleanExpression::Ge(box e1.unroll(substitution), box e2.unroll(substitution)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> TypedExpressionList<T> {
|
||||
fn unroll(self, substitution: &HashMap<String, usize>) -> TypedExpressionList<T> {
|
||||
match self {
|
||||
TypedExpressionList::FunctionCall(id, arguments, types) => {
|
||||
TypedExpressionList::FunctionCall(id, arguments.into_iter().map(|a| a.unroll(substitution)).collect(), types)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<T: Field> TypedStatement<T> {
|
||||
fn unroll(self, substitution: &mut HashMap<String, usize>) -> Vec<TypedStatement<T>> {
|
||||
match self {
|
||||
TypedStatement::Declaration(_) => {
|
||||
vec![]
|
||||
},
|
||||
TypedStatement::Definition(variable, expr) => {
|
||||
let expr = expr.unroll(substitution);
|
||||
|
||||
let res = match substitution.get(&variable.id) {
|
||||
Some(i) => {
|
||||
vec![TypedStatement::Definition(Variable { id: format!("{}_{}", variable.id, i + 1), ..variable}, expr)]
|
||||
},
|
||||
None => {
|
||||
vec![TypedStatement::Definition(Variable { id: format!("{}_{}", variable.id, 0), ..variable}, expr)]
|
||||
}
|
||||
};
|
||||
substitution.entry(variable.id)
|
||||
.and_modify(|e| { *e += 1 })
|
||||
.or_insert(0);
|
||||
res
|
||||
},
|
||||
TypedStatement::MultipleDefinition(variables, exprs) => {
|
||||
let exprs = exprs.unroll(substitution);
|
||||
let variables = variables.into_iter().map(|v| {
|
||||
let res = match substitution.get(&v.id) {
|
||||
Some(i) => {
|
||||
Variable { id: format!("{}_{}", v.id, i + 1), ..v}
|
||||
},
|
||||
None => {
|
||||
Variable { id: format!("{}_{}", v.id, 0), ..v}
|
||||
}
|
||||
};
|
||||
substitution.entry(v.id)
|
||||
.and_modify(|e| { *e += 1 })
|
||||
.or_insert(0);
|
||||
res
|
||||
}).collect();
|
||||
|
||||
vec![TypedStatement::MultipleDefinition(variables, exprs)]
|
||||
},
|
||||
TypedStatement::Condition(e1, e2) => vec![TypedStatement::Condition(e1.unroll(substitution), e2.unroll(substitution))],
|
||||
TypedStatement::For(v, from, to, stats) => {
|
||||
let mut values: Vec<T> = vec![];
|
||||
let mut current = from;
|
||||
while current < to {
|
||||
values.push(current.clone());
|
||||
current = T::one() + ¤t;
|
||||
}
|
||||
|
||||
let res = values.into_iter().map(|index| {
|
||||
vec![
|
||||
vec![
|
||||
TypedStatement::Declaration(v.clone()),
|
||||
TypedStatement::Definition(v.clone(), FieldElementExpression::Number(index).into()),
|
||||
],
|
||||
stats.clone()
|
||||
].into_iter().flat_map(|x| x)
|
||||
}).flat_map(|x| x).flat_map(|x| x.unroll(substitution)).collect();
|
||||
|
||||
res
|
||||
}
|
||||
TypedStatement::Return(exprs) => {
|
||||
vec![TypedStatement::Return(exprs.into_iter().map(|e| e.unroll(substitution)).collect())]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Unroll for TypedFunction<T> {
|
||||
fn unroll(self) -> TypedFunction<T> {
|
||||
|
||||
let mut substitution = HashMap::new();
|
||||
|
||||
let arguments = self.arguments.into_iter().map(|p|
|
||||
Parameter {
|
||||
id: Variable {
|
||||
id: format!("{}_{}", p.id.id.clone(), substitution.entry(p.id.id)
|
||||
.and_modify(|e| { *e += 1 })
|
||||
.or_insert(0)),
|
||||
..p.id
|
||||
},
|
||||
..p
|
||||
}
|
||||
).collect();
|
||||
|
||||
TypedFunction {
|
||||
arguments: arguments,
|
||||
statements: self.statements.into_iter().flat_map(|s| s.unroll(&mut substitution)).collect(),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Unroll for TypedProg<T> {
|
||||
fn unroll(self) -> TypedProg<T> {
|
||||
TypedProg {
|
||||
functions: self.functions.into_iter().map(|f| f.unroll()).collect(),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use field::FieldPrime;
|
||||
|
||||
#[cfg(test)]
|
||||
mod statement {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn for_loop() {
|
||||
let s = TypedStatement::For(Variable::field_element("i"), FieldPrime::from(2), FieldPrime::from(5), vec![
|
||||
TypedStatement::Declaration(Variable::field_element("foo")),
|
||||
TypedStatement::Definition(Variable::field_element("foo"), FieldElementExpression::Identifier(String::from("i")).into())]
|
||||
);
|
||||
|
||||
let expected = vec![
|
||||
TypedStatement::Definition(Variable::field_element("i_0"), FieldElementExpression::Number(FieldPrime::from(2)).into()),
|
||||
TypedStatement::Definition(Variable::field_element("foo_0"), FieldElementExpression::Identifier(String::from("i_0")).into()),
|
||||
|
||||
TypedStatement::Definition(Variable::field_element("i_1"), FieldElementExpression::Number(FieldPrime::from(3)).into()),
|
||||
TypedStatement::Definition(Variable::field_element("foo_1"), FieldElementExpression::Identifier(String::from("i_1")).into()),
|
||||
|
||||
TypedStatement::Definition(Variable::field_element("i_2"), FieldElementExpression::Number(FieldPrime::from(4)).into()),
|
||||
TypedStatement::Definition(Variable::field_element("foo_2"), FieldElementExpression::Identifier(String::from("i_2")).into()),
|
||||
];
|
||||
|
||||
assert_eq!(s.unroll(&mut HashMap::new()), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn definition() {
|
||||
|
||||
// field a
|
||||
// a = 5
|
||||
// a = 6
|
||||
// a
|
||||
|
||||
// should be turned into
|
||||
// a_0 = 5
|
||||
// a_1 = 6
|
||||
// a_1
|
||||
|
||||
let mut substitution = HashMap::new();
|
||||
|
||||
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(Variable::field_element("a"));
|
||||
assert_eq!(s.unroll(&mut substitution), vec![]);
|
||||
|
||||
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(5)).into());
|
||||
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_0"), FieldElementExpression::Number(FieldPrime::from(5)).into())]);
|
||||
|
||||
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(6)).into());
|
||||
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_1"), FieldElementExpression::Number(FieldPrime::from(6)).into())]);
|
||||
|
||||
let e: FieldElementExpression<FieldPrime> = FieldElementExpression::Identifier(String::from("a"));
|
||||
assert_eq!(e.unroll(&mut substitution), FieldElementExpression::Identifier(String::from("a_1")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn incremental_definition() {
|
||||
|
||||
// field a
|
||||
// a = 5
|
||||
// a = a + 1
|
||||
|
||||
// should be turned into
|
||||
// a_0 = 5
|
||||
// a_1 = a_0 + 1
|
||||
|
||||
let mut substitution = HashMap::new();
|
||||
|
||||
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(Variable::field_element("a"));
|
||||
assert_eq!(s.unroll(&mut substitution), vec![]);
|
||||
|
||||
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(5)).into());
|
||||
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_0"), FieldElementExpression::Number(FieldPrime::from(5)).into())]);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
Variable::field_element("a"),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(String::from("a")),
|
||||
box FieldElementExpression::Number(FieldPrime::from(1))
|
||||
).into()
|
||||
);
|
||||
assert_eq!(
|
||||
s.unroll(&mut substitution),
|
||||
vec![
|
||||
TypedStatement::Definition(
|
||||
Variable::field_element("a_1"),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(String::from("a_0")),
|
||||
box FieldElementExpression::Number(FieldPrime::from(1))
|
||||
).into()
|
||||
)
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn incremental_multiple_definition() {
|
||||
|
||||
use types::Type;
|
||||
|
||||
// field a
|
||||
// a = 2
|
||||
// a = foo(a)
|
||||
|
||||
// should be turned into
|
||||
// a_0 = 2
|
||||
// a_1 = foo(a_0)
|
||||
|
||||
let mut substitution = HashMap::new();
|
||||
|
||||
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(Variable::field_element("a"));
|
||||
assert_eq!(s.unroll(&mut substitution), vec![]);
|
||||
|
||||
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(2)).into());
|
||||
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_0"), FieldElementExpression::Number(FieldPrime::from(2)).into())]);
|
||||
|
||||
let s: TypedStatement<FieldPrime> = TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a")],
|
||||
TypedExpressionList::FunctionCall(
|
||||
String::from("foo"),
|
||||
vec![FieldElementExpression::Identifier(String::from("a")).into()],
|
||||
vec![Type::FieldElement],
|
||||
)
|
||||
);
|
||||
assert_eq!(
|
||||
s.unroll(&mut substitution),
|
||||
vec![
|
||||
TypedStatement::MultipleDefinition(
|
||||
vec![Variable::field_element("a_1")],
|
||||
TypedExpressionList::FunctionCall(
|
||||
String::from("foo"),
|
||||
vec![FieldElementExpression::Identifier(String::from("a_0")).into()],
|
||||
vec![Type::FieldElement],
|
||||
)
|
||||
)
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -88,11 +88,11 @@ pub fn cast<T: Field>(from: &Type, to: &Type) -> FlatFunction<T> {
|
|||
let mut statements = conditions;
|
||||
|
||||
statements.insert(0, FlatStatement::Directive(
|
||||
DirectiveStatement {
|
||||
inputs: directive_inputs,
|
||||
outputs: directive_outputs,
|
||||
helper: helper
|
||||
}
|
||||
DirectiveStatement::new(
|
||||
directive_outputs,
|
||||
helper,
|
||||
directive_inputs,
|
||||
)
|
||||
));
|
||||
|
||||
statements.push(FlatStatement::Return(
|
||||
|
|
Loading…
Reference in a new issue