1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

add loop unrolling, post flattening constant propagation

This commit is contained in:
schaeff 2018-10-12 16:40:11 +02:00
parent a21cbe7c30
commit 8b2319b0cb
16 changed files with 692 additions and 88 deletions

View file

@ -2,4 +2,8 @@ def main() -> (field):
field a = 1 + 2 + 3
field b = if 1 < a then 3 else a + 3 fi
field c = if b + a == 2 then 1 else b fi
return c
for field e in 0..2 do
field g = 4
c = c + g
endfor
return c * a

View file

@ -0,0 +1,7 @@
def foo(field a, field b) -> (field, field):
a == b + 2
return a, b
def main() -> (field):
a, b = foo(1, 1)
return a + b

View file

@ -1,8 +1,10 @@
// we can compare numbers up to 2^(pbits - 2) - 1, ie any number which fits in (pbits - 2) bits
// lt should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
def main(field a) -> (field):
field pbits = 254
// maxvalue = 2**252 - 1
field maxvalue = 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
// we added a = 0 to prevent the condition to be evaluated at compile time
return if 0 < (maxvalue + 1) then 1 else 0 fi

View file

@ -1,5 +1,7 @@
// as p - 1 is greater than p/2, comparing to it should fail
// /!\ should be called with a = 0
def main(field a) -> (field):
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495617
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + a
// we added a = 0 to prevent the condition to be evaluated at compile time
return if 0 < p - 1 then 1 else 0 fi

View file

@ -14,7 +14,7 @@ use semantics::{self, Checker};
use optimizer::{Optimizer};
use flatten::Flattener;
use std::io::{self};
use propagation::Propagate;
use static_analysis::Analyse;
#[derive(Debug)]
pub enum CompileError<T: Field> {
@ -81,13 +81,16 @@ pub fn compile_aux<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(re
// check semantics
let typed_ast = Checker::new().check_program(program_ast)?;
// optimize (constant propagation)
let typed_ast = typed_ast.propagate();
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
// flatten input program
let program_flattened =
Flattener::new(T::get_required_bits()).flatten_program(typed_ast);
// analyse (constant propagation after call resolution)
let program_flattened = program_flattened.analyse();
Ok(program_flattened)
}

View file

@ -115,7 +115,7 @@ impl<T: Field> FlatFunction<T> {
}
},
FlatStatement::Directive(ref d) => {
let input_values: Vec<T> = d.inputs.iter().map(|i| witness.get(i).unwrap().clone()).collect();
let input_values: Vec<T> = d.inputs.iter().map(|i| i.solve(&mut witness)).collect();
match d.helper.execute(&input_values) {
Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() {
@ -188,7 +188,7 @@ pub enum FlatStatement<T: Field> {
Return(FlatExpressionList<T>),
Condition(FlatExpression<T>, FlatExpression<T>),
Definition(FlatVariable, FlatExpression<T>),
Directive(DirectiveStatement)
Directive(DirectiveStatement<T>)
}
impl<T: Field> fmt::Display for FlatStatement<T> {
@ -239,7 +239,7 @@ impl<T: Field> FlatStatement<T> {
},
FlatStatement::Directive(d) => {
let outputs = d.outputs.into_iter().map(|o| substitution.get(&o).unwrap_or(o)).collect();
let inputs = d.inputs.into_iter().map(|i| substitution.get(&i).unwrap()).collect();
let inputs = d.inputs.into_iter().map(|i| i.apply_substitution(substitution, should_fallback)).collect();
FlatStatement::Directive(
DirectiveStatement {
@ -295,7 +295,7 @@ impl<T: Field> FlatExpression<T> {
}
}
fn solve(&self, inputs: &mut BTreeMap<FlatVariable, T>) -> T {
pub fn solve(&self, inputs: &mut BTreeMap<FlatVariable, T>) -> T {
match *self {
FlatExpression::Number(ref x) => x.clone(),
FlatExpression::Identifier(ref var) => {

View file

@ -336,11 +336,11 @@ impl Flattener {
statements_flattened.push(FlatStatement::Definition(name_x, x));
statements_flattened.push(
FlatStatement::Directive(DirectiveStatement {
outputs: vec![name_y, name_m],
inputs: vec![name_x],
helper: Helper::Rust(RustHelper::ConditionEq)
})
FlatStatement::Directive(DirectiveStatement::new(
vec![name_y, name_m],
Helper::Rust(RustHelper::ConditionEq),
vec![name_x],
))
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
@ -435,13 +435,13 @@ impl Flattener {
replacement_map.insert(o, new_o);
new_o
}).collect();
let new_inputs = d.inputs.iter().map(|i| replacement_map.get(&i).unwrap()).collect();
let new_inputs = d.inputs.into_iter().map(|i| i.apply_direct_substitution(&replacement_map)).collect();
statements_flattened.push(
FlatStatement::Directive(
DirectiveStatement {
outputs: new_outputs,
helper: d.helper.clone(),
inputs: new_inputs,
helper: d.helper.clone()
}
)
)

View file

@ -7,30 +7,30 @@ pub use self::libsnark_gadget::LibsnarkGadgetHelper;
pub use self::rust::RustHelper;
use std::fmt;
use field::{Field};
use flat_absy::FlatVariable;
use flat_absy::{FlatExpression, FlatVariable};
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct DirectiveStatement {
pub inputs: Vec<FlatVariable>,
pub struct DirectiveStatement<T: Field> {
pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<FlatVariable>,
pub helper: Helper
}
impl DirectiveStatement {
impl<T: Field> DirectiveStatement<T> {
pub fn new(outputs: Vec<FlatVariable>, helper: Helper, inputs: Vec<FlatVariable>) -> Self {
let (in_len, out_len) = helper.get_signature();
assert_eq!(in_len, inputs.len());
assert_eq!(out_len, outputs.len());
DirectiveStatement {
helper,
inputs,
inputs: inputs.into_iter().map(|i| FlatExpression::Identifier(i)).collect(),
outputs,
}
}
}
impl fmt::Display for DirectiveStatement {
impl<T: Field> fmt::Display for DirectiveStatement<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "# {} = {}({})",
self.outputs.iter().map(|o| o.to_string()).collect::<Vec<String>>().join(", "),

View file

@ -24,7 +24,7 @@ mod standard;
mod helpers;
mod types;
mod typed_absy;
mod propagation;
mod static_analysis;
pub mod absy;
pub mod flat_absy;

View file

@ -117,7 +117,7 @@ pub fn parse_statement<T: Field, R: BufRead>(
}
},
Some(Ok(ref x)) if !x.trim().starts_with("return") => match parse_statement(lines, x, &Position { line: current_line, col: 1 }) {
Ok((statement, ..)) => statements.push(statement[0].clone()),
Ok((mut statement, ..)) => statements.append(&mut statement),
Err(err) => return Err(err),
},
Some(Err(err)) => panic!("Error while reading Definitions: {}", err),

View file

@ -80,8 +80,8 @@ impl FunctionQuery {
})
}
fn match_funcs(&self, funcs: Vec<FunctionDeclaration>) -> Vec<FunctionDeclaration> {
funcs.into_iter().filter(|func| self.match_func(func)).collect()
fn match_funcs(&self, funcs: &HashSet<FunctionDeclaration>) -> Vec<FunctionDeclaration> {
funcs.iter().filter(|func| self.match_func(func)).cloned().collect()
}
}
@ -538,7 +538,7 @@ impl Checker {
}
fn find_candidates(&self, query: &FunctionQuery) -> Vec<FunctionDeclaration> {
query.match_funcs(self.functions.clone().into_iter().collect())
query.match_funcs(&self.functions)
}
fn enter_scope(&mut self) -> () {

View file

@ -0,0 +1,174 @@
//! Module containing constant propagation
//!
//! @file propagation.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
use helpers::DirectiveStatement;
use flat_absy::*;
use std::collections::HashMap;
use field::Field;
pub trait Propagate<T: Field> {
fn propagate(self) -> Self;
}
pub trait PropagateWithContext<T: Field> {
fn propagate(self, constants: &mut HashMap<FlatVariable, T>) -> Self;
}
impl<T: Field> PropagateWithContext<T> for FlatExpression<T> {
fn propagate(self, constants: &mut HashMap<FlatVariable, T>) -> FlatExpression<T> {
match self {
FlatExpression::Identifier(id) => {
match constants.get(&id) {
Some(c) => FlatExpression::Number(c.clone()),
None => FlatExpression::Identifier(id)
}
},
FlatExpression::Add(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 + n2),
(e1, e2) => FlatExpression::Add(box e1, box e2),
}
},
FlatExpression::Sub(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 - n2),
(e1, e2) => FlatExpression::Sub(box e1, box e2),
}
},
FlatExpression::Mult(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 * n2),
(e1, e2) => FlatExpression::Mult(box e1, box e2),
}
},
FlatExpression::Div(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
(FlatExpression::Number(n1), FlatExpression::Number(n2)) => FlatExpression::Number(n1 / n2),
(e1, e2) => FlatExpression::Div(box e1, box e2),
}
},
_ => self
}
}
}
impl<T: Field> FlatStatement<T> {
fn propagate(self, constants: &mut HashMap<FlatVariable, T>) -> Option<FlatStatement<T>> {
match self {
FlatStatement::Return(list) => Some(FlatStatement::Return(FlatExpressionList {
expressions: list.expressions.into_iter().map(|e| e.propagate(constants)).collect()
})),
FlatStatement::Definition(var, expr) => {
match expr.propagate(constants) {
FlatExpression::Number(n) => {
constants.insert(var, n);
None
},
e => {
Some(FlatStatement::Definition(var, e))
}
}
},
FlatStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail...
Some(FlatStatement::Condition(e1.propagate(constants), e2.propagate(constants)))
},
FlatStatement::Directive(d) => {
Some(FlatStatement::Directive(
DirectiveStatement {
inputs: d.inputs.into_iter().map(|i| i.propagate(constants)).collect(),
..d
}
))
}
}
}
}
impl<T: Field> Propagate<T> for FlatFunction<T> {
fn propagate(self) -> FlatFunction<T> {
let mut constants = HashMap::new();
FlatFunction {
statements: self.statements.into_iter().filter_map(|s| s.propagate(&mut constants)).collect(),
..self
}
}
}
impl<T: Field> FlatProg<T> {
pub fn propagate(self) -> FlatProg<T> {
let mut functions = vec![];
for f in self.functions {
let fun = f.propagate();
functions.push(fun);
}
FlatProg {
functions,
..self
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
#[cfg(test)]
mod expression {
use super::*;
#[cfg(test)]
mod field {
use super::*;
#[test]
fn add() {
let e = FlatExpression::Add(
box FlatExpression::Number(FieldPrime::from(2)),
box FlatExpression::Number(FieldPrime::from(3))
);
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(5)));
}
#[test]
fn sub() {
let e = FlatExpression::Sub(
box FlatExpression::Number(FieldPrime::from(3)),
box FlatExpression::Number(FieldPrime::from(2))
);
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(1)));
}
#[test]
fn mult() {
let e = FlatExpression::Mult(
box FlatExpression::Number(FieldPrime::from(3)),
box FlatExpression::Number(FieldPrime::from(2))
);
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(6)));
}
#[test]
fn div() {
let e = FlatExpression::Div(
box FlatExpression::Number(FieldPrime::from(6)),
box FlatExpression::Number(FieldPrime::from(2))
);
assert_eq!(e.propagate(&mut HashMap::new()), FlatExpression::Number(FieldPrime::from(3)));
}
}
}
}

View file

@ -0,0 +1,30 @@
//! Module containing static analysis
//!
//! @file mod.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod propagation;
mod unroll;
mod flat_propagation;
use flat_absy::FlatProg;
use field::Field;
use typed_absy::TypedProg;
use self::unroll::Unroll;
pub trait Analyse {
fn analyse(self) -> Self;
}
impl<T: Field> Analyse for TypedProg<T> {
fn analyse(self) -> Self {
self.unroll().propagate()
}
}
impl<T: Field> Analyse for FlatProg<T> {
fn analyse(self) -> Self {
self.propagate()
}
}

View file

@ -1,27 +1,33 @@
//! Module containing constant propagation
//!
//! @file propagation.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
use absy::variable::Variable;
use std::collections::HashMap;
use field::Field;
use typed_absy::*;
pub trait Propagate {
fn propagate(self) -> Self;
pub trait Propagate<T: Field> {
fn propagate(self, functions: &Vec<TypedFunction<T>>) -> Self;
}
pub trait PropagateWithContext<T: Field> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> Self;
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> Self;
}
impl<T: Field> PropagateWithContext<T> for TypedExpression<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> TypedExpression<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> TypedExpression<T> {
match self {
TypedExpression::FieldElement(e) => e.propagate(constants).into(),
TypedExpression::Boolean(e) => e.propagate(constants).into(),
TypedExpression::FieldElement(e) => e.propagate(constants, functions).into(),
TypedExpression::Boolean(e) => e.propagate(constants, functions).into(),
}
}
}
impl<T: Field> PropagateWithContext<T> for FieldElementExpression<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> FieldElementExpression<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> FieldElementExpression<T> {
match self {
FieldElementExpression::Identifier(id) => {
match constants.get(&Variable::field_element(id.clone())) {
@ -33,51 +39,89 @@ impl<T: Field> PropagateWithContext<T> for FieldElementExpression<T> {
}
},
FieldElementExpression::Add(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 + n2),
(e1, e2) => FieldElementExpression::Add(box e1, box e2),
}
},
FieldElementExpression::Sub(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 - n2),
(e1, e2) => FieldElementExpression::Sub(box e1, box e2),
}
},
FieldElementExpression::Mult(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 * n2),
(e1, e2) => FieldElementExpression::Mult(box e1, box e2),
}
},
FieldElementExpression::Div(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1 / n2),
(e1, e2) => FieldElementExpression::Div(box e1, box e2),
}
},
FieldElementExpression::Pow(box e1, box e2) => {
match (e1.propagate(constants), e2.propagate(constants)) {
match (e1.propagate(constants, functions), e2.propagate(constants, functions)) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => FieldElementExpression::Number(n1.pow(n2)),
(e1, e2) => FieldElementExpression::Pow(box e1, box e2),
}
},
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
let consequence = consequence.propagate(constants);
let alternative = alternative.propagate(constants);
match condition.propagate(constants) {
let consequence = consequence.propagate(constants, functions);
let alternative = alternative.propagate(constants, functions);
match condition.propagate(constants, functions) {
BooleanExpression::Value(true) => consequence,
BooleanExpression::Value(false) => alternative,
c => FieldElementExpression::IfElse(box c, box consequence, box alternative)
}
},
FieldElementExpression::FunctionCall(id, arguments) => {
// We only cover the case where all arguments are constants, therefore the call is guaranteed to be constant
// Propagate the arguments, then return Ok if they're constant, Err otherwise
let arguments = arguments.into_iter().map(|a| a.propagate(constants, functions)).collect();
// let f = functions[0]; // TODO find functon based on id
// match f.execute(arguments) {
// Ok(expressions) => expressions,
// _ => FieldElementExpression::FunctionCall(id, arguments),
// }
FieldElementExpression::FunctionCall(id, arguments)
// let each_argument_constant = arguments.into_iter().map(|a| a.propagate(constants, functions)).map(|a| match a {
// a @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => Ok(a),
// a @ TypedExpression::Boolean(BooleanExpression::Value(..)) => Ok(a),
// a => Err(a)
// });
// let all_arguments_constant = each_argument_constant.collect::<Result<Vec<_>, _>>();
// match all_arguments_constant {
// Ok(arguments) => {
// // all arguments are constant, we can execute the function now
// unimplemented!()
// },
// Err(_) => {
// // not all arguments are constant, keep the function call
// let arguments = each_argument_constant.into_iter().map(|a| match a {
// Ok(a) => a,
// Err(a) => a
// }).collect();
//
// }
// }
}
_ => self
}
}
}
impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> BooleanExpression<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> BooleanExpression<T> {
match self {
BooleanExpression::Identifier(id) => {
match constants.get(&Variable::boolean(id.clone())) {
@ -89,8 +133,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
}
},
BooleanExpression::Eq(box e1, box e2) => {
let e1 = e1.propagate(constants);
let e2 = e2.propagate(constants);
let e1 = e1.propagate(constants, functions);
let e2 = e2.propagate(constants, functions);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
@ -100,8 +144,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
}
}
BooleanExpression::Lt(box e1, box e2) => {
let e1 = e1.propagate(constants);
let e2 = e2.propagate(constants);
let e1 = e1.propagate(constants, functions);
let e2 = e2.propagate(constants, functions);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
@ -111,8 +155,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
}
}
BooleanExpression::Le(box e1, box e2) => {
let e1 = e1.propagate(constants);
let e2 = e2.propagate(constants);
let e1 = e1.propagate(constants, functions);
let e2 = e2.propagate(constants, functions);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
@ -122,8 +166,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
}
}
BooleanExpression::Gt(box e1, box e2) => {
let e1 = e1.propagate(constants);
let e2 = e2.propagate(constants);
let e1 = e1.propagate(constants, functions);
let e2 = e2.propagate(constants, functions);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
@ -133,8 +177,8 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
}
}
BooleanExpression::Ge(box e1, box e2) => {
let e1 = e1.propagate(constants);
let e2 = e2.propagate(constants);
let e1 = e1.propagate(constants, functions);
let e2 = e2.propagate(constants, functions);
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
@ -148,12 +192,22 @@ impl<T: Field> PropagateWithContext<T> for BooleanExpression<T> {
}
}
impl<T: Field> TypedStatement<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>) -> Option<TypedStatement<T>> {
impl<T: Field> TypedExpressionList<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> TypedExpressionList<T> {
match self {
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| e.propagate(constants)).collect())),
TypedExpressionList::FunctionCall(id, arguments, types) => {
TypedExpressionList::FunctionCall(id, arguments.into_iter().map(|e| e.propagate(constants, functions)).collect(), types)
}
}
}
}
impl<T: Field> TypedStatement<T> {
fn propagate(self, constants: &mut HashMap<Variable, TypedExpression<T>>, functions: &Vec<TypedFunction<T>>) -> Option<TypedStatement<T>> {
match self {
TypedStatement::Return(expressions) => Some(TypedStatement::Return(expressions.into_iter().map(|e| e.propagate(constants, functions)).collect())),
TypedStatement::Definition(var, expr) => {
match expr.propagate(constants) {
match expr.propagate(constants, functions) {
e @ TypedExpression::Boolean(BooleanExpression::Value(..)) | e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
constants.insert(var, e);
None
@ -165,30 +219,42 @@ impl<T: Field> TypedStatement<T> {
},
TypedStatement::Condition(e1, e2) => {
// could stop execution here if condition is known to fail...
Some(TypedStatement::Condition(e1.propagate(constants), e2.propagate(constants)))
Some(TypedStatement::Condition(e1.propagate(constants, functions), e2.propagate(constants, functions)))
},
TypedStatement::For(v, from, to, stats) => Some(TypedStatement::For(v, from, to, stats.into_iter().filter_map(|s| s.propagate(constants)).collect())),
TypedStatement::For(..) => panic!("no for expected"),
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = expression_list.propagate(constants, functions);
Some(TypedStatement::MultipleDefinition(variables, expression_list))
}
_ => Some(self)
}
}
}
impl<T: Field> Propagate for TypedFunction<T> {
fn propagate(self) -> TypedFunction<T> {
impl<T: Field> Propagate<T> for TypedFunction<T> {
fn propagate(self, functions: &Vec<TypedFunction<T>>) -> TypedFunction<T> {
let mut constants = HashMap::new();
TypedFunction {
statements: self.statements.into_iter().filter_map(|s| s.propagate(&mut constants)).collect(),
statements: self.statements.into_iter().filter_map(|s| s.propagate(&mut constants, functions)).collect(),
..self
}
}
}
impl<T: Field> Propagate for TypedProg<T> {
fn propagate(self) -> TypedProg<T> {
impl<T: Field> TypedProg<T> {
pub fn propagate(self) -> TypedProg<T> {
let mut functions = vec![];
for f in self.functions {
let fun = f.propagate(&mut functions);
functions.push(fun);
}
TypedProg {
functions: self.functions.into_iter().map(|f| f.propagate()).collect(),
functions,
..self
}
}
@ -214,7 +280,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(3))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(5)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(5)));
}
#[test]
@ -224,7 +290,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(2))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(1)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(1)));
}
#[test]
@ -234,7 +300,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(2))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(6)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(6)));
}
#[test]
@ -244,7 +310,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(2))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(3)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(3)));
}
#[test]
@ -254,7 +320,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(3))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(8)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(8)));
}
#[test]
@ -265,7 +331,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(3))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(2)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(2)));
}
#[test]
@ -276,7 +342,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(3))
);
assert_eq!(e.propagate(&mut HashMap::new()), FieldElementExpression::Number(FieldPrime::from(3)));
assert_eq!(e.propagate(&mut HashMap::new(), &mut vec![]), FieldElementExpression::Number(FieldPrime::from(3)));
}
}
@ -296,8 +362,8 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(2))
);
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
}
#[test]
@ -312,8 +378,8 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(2))
);
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
}
#[test]
@ -328,8 +394,8 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(2))
);
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
}
#[test]
@ -344,8 +410,8 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(5))
);
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
}
#[test]
@ -360,8 +426,8 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(5))
);
assert_eq!(e_true.propagate(&mut HashMap::new()), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new()), BooleanExpression::Value(false));
assert_eq!(e_true.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(true));
assert_eq!(e_false.propagate(&mut HashMap::new(), &mut vec![]), BooleanExpression::Value(false));
}
}
}

View file

@ -0,0 +1,316 @@
//! Module containing SSA reduction, including for-loop unrolling
//!
//! @file unroll.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
use absy::parameter::Parameter;
use absy::variable::Variable;
use std::collections::HashMap;
use field::Field;
use typed_absy::*;
pub trait Unroll {
fn unroll(self) -> Self;
}
pub trait UnrollWithContext<T: Field> {
fn unroll(self, substitution: &mut HashMap<String, usize>) -> Self;
}
impl<T: Field> TypedExpression<T> {
fn unroll(self, substitution: &HashMap<String, usize>) -> TypedExpression<T> {
match self {
TypedExpression::FieldElement(e) => e.unroll(substitution).into(),
TypedExpression::Boolean(e) => e.unroll(substitution).into(),
}
}
}
impl<T: Field> FieldElementExpression<T> {
fn unroll(self, substitution: &HashMap<String, usize>) -> FieldElementExpression<T> {
match self {
FieldElementExpression::Identifier(id) => FieldElementExpression::Identifier(format!("{}_{}", id, substitution.get(&id).unwrap().clone())),
FieldElementExpression::Number(n) => FieldElementExpression::Number(n),
FieldElementExpression::Add(box e1, box e2) => FieldElementExpression::Add(box e1.unroll(substitution), box e2.unroll(substitution)),
FieldElementExpression::Sub(box e1, box e2) => FieldElementExpression::Sub(box e1.unroll(substitution), box e2.unroll(substitution)),
FieldElementExpression::Mult(box e1, box e2) => FieldElementExpression::Mult(box e1.unroll(substitution), box e2.unroll(substitution)),
FieldElementExpression::Div(box e1, box e2) => FieldElementExpression::Div(box e1.unroll(substitution), box e2.unroll(substitution)),
FieldElementExpression::Pow(box e1, box e2) => FieldElementExpression::Div(box e1.unroll(substitution), box e2.unroll(substitution)),
FieldElementExpression::IfElse(box cond, box cons, box alt) => FieldElementExpression::IfElse(box cond.unroll(substitution), box cons.unroll(substitution), box alt.unroll(substitution)),
FieldElementExpression::FunctionCall(id, args) => FieldElementExpression::FunctionCall(id, args.into_iter().map(|a| a.unroll(substitution)).collect()),
}
}
}
impl<T: Field> BooleanExpression<T> {
fn unroll(self, substitution: &HashMap<String, usize>) -> BooleanExpression<T> {
match self {
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(format!("{}_{}", id, substitution.get(&id).unwrap().clone())),
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Eq(box e1, box e2) => BooleanExpression::Eq(box e1.unroll(substitution), box e2.unroll(substitution)),
BooleanExpression::Lt(box e1, box e2) => BooleanExpression::Lt(box e1.unroll(substitution), box e2.unroll(substitution)),
BooleanExpression::Le(box e1, box e2) => BooleanExpression::Le(box e1.unroll(substitution), box e2.unroll(substitution)),
BooleanExpression::Gt(box e1, box e2) => BooleanExpression::Gt(box e1.unroll(substitution), box e2.unroll(substitution)),
BooleanExpression::Ge(box e1, box e2) => BooleanExpression::Ge(box e1.unroll(substitution), box e2.unroll(substitution)),
}
}
}
impl<T: Field> TypedExpressionList<T> {
fn unroll(self, substitution: &HashMap<String, usize>) -> TypedExpressionList<T> {
match self {
TypedExpressionList::FunctionCall(id, arguments, types) => {
TypedExpressionList::FunctionCall(id, arguments.into_iter().map(|a| a.unroll(substitution)).collect(), types)
}
}
}
}
impl<T: Field> TypedStatement<T> {
fn unroll(self, substitution: &mut HashMap<String, usize>) -> Vec<TypedStatement<T>> {
match self {
TypedStatement::Declaration(_) => {
vec![]
},
TypedStatement::Definition(variable, expr) => {
let expr = expr.unroll(substitution);
let res = match substitution.get(&variable.id) {
Some(i) => {
vec![TypedStatement::Definition(Variable { id: format!("{}_{}", variable.id, i + 1), ..variable}, expr)]
},
None => {
vec![TypedStatement::Definition(Variable { id: format!("{}_{}", variable.id, 0), ..variable}, expr)]
}
};
substitution.entry(variable.id)
.and_modify(|e| { *e += 1 })
.or_insert(0);
res
},
TypedStatement::MultipleDefinition(variables, exprs) => {
let exprs = exprs.unroll(substitution);
let variables = variables.into_iter().map(|v| {
let res = match substitution.get(&v.id) {
Some(i) => {
Variable { id: format!("{}_{}", v.id, i + 1), ..v}
},
None => {
Variable { id: format!("{}_{}", v.id, 0), ..v}
}
};
substitution.entry(v.id)
.and_modify(|e| { *e += 1 })
.or_insert(0);
res
}).collect();
vec![TypedStatement::MultipleDefinition(variables, exprs)]
},
TypedStatement::Condition(e1, e2) => vec![TypedStatement::Condition(e1.unroll(substitution), e2.unroll(substitution))],
TypedStatement::For(v, from, to, stats) => {
let mut values: Vec<T> = vec![];
let mut current = from;
while current < to {
values.push(current.clone());
current = T::one() + &current;
}
let res = values.into_iter().map(|index| {
vec![
vec![
TypedStatement::Declaration(v.clone()),
TypedStatement::Definition(v.clone(), FieldElementExpression::Number(index).into()),
],
stats.clone()
].into_iter().flat_map(|x| x)
}).flat_map(|x| x).flat_map(|x| x.unroll(substitution)).collect();
res
}
TypedStatement::Return(exprs) => {
vec![TypedStatement::Return(exprs.into_iter().map(|e| e.unroll(substitution)).collect())]
}
}
}
}
impl<T: Field> Unroll for TypedFunction<T> {
fn unroll(self) -> TypedFunction<T> {
let mut substitution = HashMap::new();
let arguments = self.arguments.into_iter().map(|p|
Parameter {
id: Variable {
id: format!("{}_{}", p.id.id.clone(), substitution.entry(p.id.id)
.and_modify(|e| { *e += 1 })
.or_insert(0)),
..p.id
},
..p
}
).collect();
TypedFunction {
arguments: arguments,
statements: self.statements.into_iter().flat_map(|s| s.unroll(&mut substitution)).collect(),
..self
}
}
}
impl<T: Field> Unroll for TypedProg<T> {
fn unroll(self) -> TypedProg<T> {
TypedProg {
functions: self.functions.into_iter().map(|f| f.unroll()).collect(),
..self
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
#[cfg(test)]
mod statement {
use super::*;
#[test]
fn for_loop() {
let s = TypedStatement::For(Variable::field_element("i"), FieldPrime::from(2), FieldPrime::from(5), vec![
TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition(Variable::field_element("foo"), FieldElementExpression::Identifier(String::from("i")).into())]
);
let expected = vec![
TypedStatement::Definition(Variable::field_element("i_0"), FieldElementExpression::Number(FieldPrime::from(2)).into()),
TypedStatement::Definition(Variable::field_element("foo_0"), FieldElementExpression::Identifier(String::from("i_0")).into()),
TypedStatement::Definition(Variable::field_element("i_1"), FieldElementExpression::Number(FieldPrime::from(3)).into()),
TypedStatement::Definition(Variable::field_element("foo_1"), FieldElementExpression::Identifier(String::from("i_1")).into()),
TypedStatement::Definition(Variable::field_element("i_2"), FieldElementExpression::Number(FieldPrime::from(4)).into()),
TypedStatement::Definition(Variable::field_element("foo_2"), FieldElementExpression::Identifier(String::from("i_2")).into()),
];
assert_eq!(s.unroll(&mut HashMap::new()), expected);
}
#[test]
fn definition() {
// field a
// a = 5
// a = 6
// a
// should be turned into
// a_0 = 5
// a_1 = 6
// a_1
let mut substitution = HashMap::new();
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(Variable::field_element("a"));
assert_eq!(s.unroll(&mut substitution), vec![]);
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(5)).into());
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_0"), FieldElementExpression::Number(FieldPrime::from(5)).into())]);
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(6)).into());
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_1"), FieldElementExpression::Number(FieldPrime::from(6)).into())]);
let e: FieldElementExpression<FieldPrime> = FieldElementExpression::Identifier(String::from("a"));
assert_eq!(e.unroll(&mut substitution), FieldElementExpression::Identifier(String::from("a_1")));
}
#[test]
fn incremental_definition() {
// field a
// a = 5
// a = a + 1
// should be turned into
// a_0 = 5
// a_1 = a_0 + 1
let mut substitution = HashMap::new();
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(Variable::field_element("a"));
assert_eq!(s.unroll(&mut substitution), vec![]);
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(5)).into());
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_0"), FieldElementExpression::Number(FieldPrime::from(5)).into())]);
let s = TypedStatement::Definition(
Variable::field_element("a"),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(String::from("a")),
box FieldElementExpression::Number(FieldPrime::from(1))
).into()
);
assert_eq!(
s.unroll(&mut substitution),
vec![
TypedStatement::Definition(
Variable::field_element("a_1"),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(String::from("a_0")),
box FieldElementExpression::Number(FieldPrime::from(1))
).into()
)
]
);
}
#[test]
fn incremental_multiple_definition() {
use types::Type;
// field a
// a = 2
// a = foo(a)
// should be turned into
// a_0 = 2
// a_1 = foo(a_0)
let mut substitution = HashMap::new();
let s: TypedStatement<FieldPrime> = TypedStatement::Declaration(Variable::field_element("a"));
assert_eq!(s.unroll(&mut substitution), vec![]);
let s = TypedStatement::Definition(Variable::field_element("a"), FieldElementExpression::Number(FieldPrime::from(2)).into());
assert_eq!(s.unroll(&mut substitution), vec![TypedStatement::Definition(Variable::field_element("a_0"), FieldElementExpression::Number(FieldPrime::from(2)).into())]);
let s: TypedStatement<FieldPrime> = TypedStatement::MultipleDefinition(
vec![Variable::field_element("a")],
TypedExpressionList::FunctionCall(
String::from("foo"),
vec![FieldElementExpression::Identifier(String::from("a")).into()],
vec![Type::FieldElement],
)
);
assert_eq!(
s.unroll(&mut substitution),
vec![
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a_1")],
TypedExpressionList::FunctionCall(
String::from("foo"),
vec![FieldElementExpression::Identifier(String::from("a_0")).into()],
vec![Type::FieldElement],
)
)
]
);
}
}
}

View file

@ -88,11 +88,11 @@ pub fn cast<T: Field>(from: &Type, to: &Type) -> FlatFunction<T> {
let mut statements = conditions;
statements.insert(0, FlatStatement::Directive(
DirectiveStatement {
inputs: directive_inputs,
outputs: directive_outputs,
helper: helper
}
DirectiveStatement::new(
directive_outputs,
helper,
directive_inputs,
)
));
statements.push(FlatStatement::Return(