fix conflict
This commit is contained in:
commit
66209ccc26
7 changed files with 108 additions and 263 deletions
|
@ -98,7 +98,7 @@ mod integration {
|
|||
"--light",
|
||||
];
|
||||
|
||||
if program_name.contains("libsnark") {
|
||||
if program_name.contains("sha") {
|
||||
// we don't want to test libsnark integrations if libsnark is not available
|
||||
#[cfg(not(feature = "libsnark"))]
|
||||
return;
|
||||
|
|
|
@ -162,7 +162,7 @@ pub fn compile_aux<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
|
|||
let typed_ast = typed_ast.analyse();
|
||||
|
||||
// flatten input program
|
||||
let program_flattened = Flattener::new(T::get_required_bits()).flatten_program(typed_ast);
|
||||
let program_flattened = Flattener::flatten(typed_ast);
|
||||
|
||||
// analyse (constant propagation after call resolution)
|
||||
let program_flattened = program_flattened.analyse();
|
||||
|
|
|
@ -36,7 +36,7 @@ impl fmt::Debug for FlatParameter {
|
|||
}
|
||||
|
||||
impl FlatParameter {
|
||||
pub fn apply_direct_substitution(
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatParameter {
|
||||
|
|
|
@ -50,15 +50,8 @@ impl fmt::Debug for FlatVariable {
|
|||
}
|
||||
|
||||
impl FlatVariable {
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
should_fallback: bool,
|
||||
) -> Self {
|
||||
match should_fallback {
|
||||
true => substitution.get(&self).unwrap_or(&self).clone(),
|
||||
false => substitution.get(&self).unwrap().clone(),
|
||||
}
|
||||
pub fn apply_substitution(self, substitution: &HashMap<FlatVariable, FlatVariable>) -> &Self {
|
||||
substitution.get(&self).unwrap()
|
||||
}
|
||||
|
||||
pub fn is_output(&self) -> bool {
|
||||
|
|
|
@ -213,47 +213,30 @@ impl<T: Field> fmt::Debug for FlatStatement<T> {
|
|||
}
|
||||
|
||||
impl<T: Field> FlatStatement<T> {
|
||||
pub fn apply_recursive_substitution(
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatStatement<T> {
|
||||
self.apply_substitution(substitution, true)
|
||||
}
|
||||
|
||||
pub fn apply_direct_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatStatement<T> {
|
||||
self.apply_substitution(substitution, false)
|
||||
}
|
||||
|
||||
fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
should_fallback: bool,
|
||||
) -> FlatStatement<T> {
|
||||
match self {
|
||||
FlatStatement::Definition(id, x) => FlatStatement::Definition(
|
||||
id.apply_substitution(substitution, should_fallback),
|
||||
x.apply_substitution(substitution, should_fallback),
|
||||
*id.apply_substitution(substitution),
|
||||
x.apply_substitution(substitution),
|
||||
),
|
||||
FlatStatement::Return(x) => {
|
||||
FlatStatement::Return(x.apply_substitution(substitution, should_fallback))
|
||||
}
|
||||
FlatStatement::Return(x) => FlatStatement::Return(x.apply_substitution(substitution)),
|
||||
FlatStatement::Condition(x, y) => FlatStatement::Condition(
|
||||
x.apply_substitution(substitution, should_fallback),
|
||||
y.apply_substitution(substitution, should_fallback),
|
||||
x.apply_substitution(substitution),
|
||||
y.apply_substitution(substitution),
|
||||
),
|
||||
FlatStatement::Directive(d) => {
|
||||
let outputs = d
|
||||
.outputs
|
||||
.into_iter()
|
||||
.map(|o| o.apply_substitution(substitution, should_fallback))
|
||||
.map(|o| *o.apply_substitution(substitution))
|
||||
.collect();
|
||||
let inputs = d
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(|i| i.apply_substitution(substitution, should_fallback))
|
||||
.map(|i| i.apply_substitution(substitution))
|
||||
.collect();
|
||||
|
||||
FlatStatement::Directive(DirectiveStatement {
|
||||
|
@ -276,41 +259,26 @@ pub enum FlatExpression<T: Field> {
|
|||
}
|
||||
|
||||
impl<T: Field> FlatExpression<T> {
|
||||
pub fn apply_recursive_substitution(
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatExpression<T> {
|
||||
self.apply_substitution(substitution, true)
|
||||
}
|
||||
|
||||
pub fn apply_direct_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatExpression<T> {
|
||||
self.apply_substitution(substitution, false)
|
||||
}
|
||||
|
||||
fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
should_fallback: bool,
|
||||
) -> FlatExpression<T> {
|
||||
match self {
|
||||
e @ FlatExpression::Number(_) => e,
|
||||
FlatExpression::Identifier(id) => {
|
||||
FlatExpression::Identifier(id.apply_substitution(substitution, should_fallback))
|
||||
FlatExpression::Identifier(*id.apply_substitution(substitution))
|
||||
}
|
||||
FlatExpression::Add(e1, e2) => FlatExpression::Add(
|
||||
box e1.apply_substitution(substitution, should_fallback),
|
||||
box e2.apply_substitution(substitution, should_fallback),
|
||||
box e1.apply_substitution(substitution),
|
||||
box e2.apply_substitution(substitution),
|
||||
),
|
||||
FlatExpression::Sub(e1, e2) => FlatExpression::Sub(
|
||||
box e1.apply_substitution(substitution, should_fallback),
|
||||
box e2.apply_substitution(substitution, should_fallback),
|
||||
box e1.apply_substitution(substitution),
|
||||
box e2.apply_substitution(substitution),
|
||||
),
|
||||
FlatExpression::Mult(e1, e2) => FlatExpression::Mult(
|
||||
box e1.apply_substitution(substitution, should_fallback),
|
||||
box e2.apply_substitution(substitution, should_fallback),
|
||||
box e1.apply_substitution(substitution),
|
||||
box e2.apply_substitution(substitution),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
@ -392,33 +360,18 @@ impl<T: Field> fmt::Display for FlatExpressionList<T> {
|
|||
}
|
||||
|
||||
impl<T: Field> FlatExpressionList<T> {
|
||||
fn apply_substitution(
|
||||
pub fn apply_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
should_fallback: bool,
|
||||
) -> FlatExpressionList<T> {
|
||||
FlatExpressionList {
|
||||
expressions: self
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|e| e.apply_substitution(substitution, should_fallback))
|
||||
.map(|e| e.apply_substitution(substitution))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply_recursive_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatExpressionList<T> {
|
||||
self.apply_substitution(substitution, true)
|
||||
}
|
||||
|
||||
pub fn apply_direct_substitution(
|
||||
self,
|
||||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatExpressionList<T> {
|
||||
self.apply_substitution(substitution, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Debug for FlatExpressionList<T> {
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
use bimap::BiMap;
|
||||
use flat_absy::*;
|
||||
use helpers::{DirectiveStatement, Helper, RustHelper};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use typed_absy::*;
|
||||
use types::conversions::cast;
|
||||
use types::Signature;
|
||||
|
@ -18,28 +18,24 @@ use zokrates_field::field::Field;
|
|||
/// Flattener, computes flattened program.
|
||||
#[derive(Debug)]
|
||||
pub struct Flattener {
|
||||
/// Number of bits needed to represent the maximum value.
|
||||
bits: usize,
|
||||
/// Vector containing all used variables while processing the program.
|
||||
variables: HashSet<FlatVariable>,
|
||||
/// Map of renamings for reassigned variables while processing the program.
|
||||
substitution: HashMap<FlatVariable, FlatVariable>,
|
||||
/// Index of the next introduced variable while processing the program.
|
||||
next_var_idx: usize,
|
||||
///
|
||||
bijection: BiMap<String, FlatVariable>,
|
||||
}
|
||||
impl Flattener {
|
||||
pub fn flatten<T: Field>(p: TypedProg<T>) -> FlatProg<T> {
|
||||
Flattener::new().flatten_program(p)
|
||||
}
|
||||
|
||||
/// Returns a `Flattener` with fresh a fresh [substitution] and [variables].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `bits` - Number of bits needed to represent the maximum value.
|
||||
pub fn new(bits: usize) -> Flattener {
|
||||
|
||||
fn new() -> Flattener {
|
||||
Flattener {
|
||||
bits: bits,
|
||||
variables: HashSet::new(),
|
||||
substitution: HashMap::new(),
|
||||
next_var_idx: 0,
|
||||
bijection: BiMap::new(),
|
||||
}
|
||||
|
@ -134,6 +130,9 @@ impl Flattener {
|
|||
FlatExpression::Identifier(self.bijection.get_by_left(&x).unwrap().clone())
|
||||
}
|
||||
BooleanExpression::Lt(box lhs, box rhs) => {
|
||||
// Get the bitwidth to know the size of the binary decompsitions for this Field
|
||||
let bitwidth = T::get_required_bits();
|
||||
|
||||
// We know from semantic checking that lhs and rhs have the same type
|
||||
// What the expression will flatten to depends on that type
|
||||
|
||||
|
@ -152,7 +151,7 @@ impl Flattener {
|
|||
{
|
||||
// define variables for the bits
|
||||
let lhs_bits: Vec<FlatVariable> =
|
||||
(0..self.bits).map(|_| self.use_sym()).collect();
|
||||
(0..bitwidth).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
|
@ -162,7 +161,7 @@ impl Flattener {
|
|||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..self.bits - 2 {
|
||||
for i in 0..bitwidth - 2 {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_bits[i + 2]),
|
||||
FlatExpression::Mult(
|
||||
|
@ -175,12 +174,12 @@ impl Flattener {
|
|||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..self.bits - 2 {
|
||||
for i in 0..bitwidth - 2 {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits[i + 2]),
|
||||
box FlatExpression::Number(T::from(2).pow(self.bits - 2 - i - 1)),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - 2 - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -199,7 +198,7 @@ impl Flattener {
|
|||
{
|
||||
// define variables for the bits
|
||||
let rhs_bits: Vec<FlatVariable> =
|
||||
(0..self.bits).map(|_| self.use_sym()).collect();
|
||||
(0..bitwidth).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
|
@ -209,7 +208,7 @@ impl Flattener {
|
|||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..self.bits - 2 {
|
||||
for i in 0..bitwidth - 2 {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_bits[i + 2]),
|
||||
FlatExpression::Mult(
|
||||
|
@ -222,12 +221,12 @@ impl Flattener {
|
|||
// bit decomposition check
|
||||
let mut rhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..self.bits - 2 {
|
||||
for i in 0..bitwidth - 2 {
|
||||
rhs_sum = FlatExpression::Add(
|
||||
box rhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits[i + 2]),
|
||||
box FlatExpression::Number(T::from(2).pow(self.bits - 2 - i - 1)),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - 2 - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -251,7 +250,7 @@ impl Flattener {
|
|||
);
|
||||
|
||||
// define variables for the bits
|
||||
let sub_bits: Vec<FlatVariable> = (0..self.bits).map(|_| self.use_sym()).collect();
|
||||
let sub_bits: Vec<FlatVariable> = (0..bitwidth).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(DirectiveStatement::new(
|
||||
|
@ -261,7 +260,7 @@ impl Flattener {
|
|||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..self.bits {
|
||||
for i in 0..bitwidth {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(sub_bits[i]),
|
||||
FlatExpression::Mult(
|
||||
|
@ -274,12 +273,12 @@ impl Flattener {
|
|||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..self.bits {
|
||||
for i in 0..bitwidth {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(self.bits - i - 1)),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -443,11 +442,6 @@ impl Flattener {
|
|||
.flat_map(|x| x)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let params_flattened = params_flattened
|
||||
.into_iter()
|
||||
.map(|e| e.apply_recursive_substitution(&self.substitution))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (index, r) in params_flattened.into_iter().enumerate() {
|
||||
let new_var = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(new_var, r));
|
||||
|
@ -464,19 +458,19 @@ impl Flattener {
|
|||
expressions: list
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|x| x.apply_direct_substitution(&replacement_map))
|
||||
.map(|x| x.apply_substitution(&replacement_map))
|
||||
.collect(),
|
||||
};
|
||||
}
|
||||
FlatStatement::Definition(var, rhs) => {
|
||||
let new_var = self.issue_new_variable();
|
||||
replacement_map.insert(var, new_var);
|
||||
let new_rhs = rhs.apply_direct_substitution(&replacement_map);
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
statements_flattened.push(FlatStatement::Definition(new_var, new_rhs));
|
||||
}
|
||||
FlatStatement::Condition(lhs, rhs) => {
|
||||
let new_lhs = lhs.apply_direct_substitution(&replacement_map);
|
||||
let new_rhs = rhs.apply_direct_substitution(&replacement_map);
|
||||
let new_lhs = lhs.apply_substitution(&replacement_map);
|
||||
let new_rhs = rhs.apply_substitution(&replacement_map);
|
||||
statements_flattened.push(FlatStatement::Condition(new_lhs, new_rhs));
|
||||
}
|
||||
FlatStatement::Directive(d) => {
|
||||
|
@ -492,7 +486,7 @@ impl Flattener {
|
|||
let new_inputs = d
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(|i| i.apply_direct_substitution(&replacement_map))
|
||||
.map(|i| i.apply_substitution(&replacement_map))
|
||||
.collect();
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
DirectiveStatement {
|
||||
|
@ -659,13 +653,11 @@ impl Flattener {
|
|||
match exponent {
|
||||
FieldElementExpression::Number(ref e) => {
|
||||
// flatten the base expression
|
||||
let base_flattened = self
|
||||
.flatten_field_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
base.clone(),
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution);
|
||||
let base_flattened = self.flatten_field_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
base.clone(),
|
||||
);
|
||||
|
||||
// we require from the base to be linear
|
||||
// TODO change that
|
||||
|
@ -684,18 +676,14 @@ impl Flattener {
|
|||
// flatten(base ** n) = flatten(base) * flatten(base ** (n-1))
|
||||
e => {
|
||||
// flatten(base ** (n-1))
|
||||
let tmp_expression = self
|
||||
.flatten_field_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
FieldElementExpression::Pow(
|
||||
box base,
|
||||
box FieldElementExpression::Number(
|
||||
e.clone() - T::one(),
|
||||
),
|
||||
),
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution);
|
||||
let tmp_expression = self.flatten_field_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
FieldElementExpression::Pow(
|
||||
box base,
|
||||
box FieldElementExpression::Number(e.clone() - T::one()),
|
||||
),
|
||||
);
|
||||
|
||||
let id = self.use_sym();
|
||||
|
||||
|
@ -750,7 +738,6 @@ impl Flattener {
|
|||
statements_flattened,
|
||||
expressions[n.to_dec_string().parse::<usize>().unwrap()].clone(),
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution)
|
||||
}
|
||||
FieldElementArrayExpression::FunctionCall(..) => {
|
||||
unimplemented!("please use intermediate variables for now")
|
||||
|
@ -776,7 +763,6 @@ impl Flattener {
|
|||
),
|
||||
),
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution)
|
||||
}
|
||||
},
|
||||
e => {
|
||||
|
@ -862,7 +848,6 @@ impl Flattener {
|
|||
statements_flattened,
|
||||
lookup,
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -936,7 +921,7 @@ impl Flattener {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn flatten_statement<T: Field>(
|
||||
fn flatten_statement<T: Field>(
|
||||
&mut self,
|
||||
functions_flattened: &Vec<FlatFunction<T>>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
|
@ -952,11 +937,6 @@ impl Flattener {
|
|||
.flat_map(|x| x)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let flat_expressions = flat_expressions
|
||||
.into_iter()
|
||||
.map(|e| e.apply_recursive_substitution(&self.substitution))
|
||||
.collect();
|
||||
|
||||
statements_flattened.push(FlatStatement::Return(FlatExpressionList {
|
||||
expressions: flat_expressions,
|
||||
}));
|
||||
|
@ -975,11 +955,6 @@ impl Flattener {
|
|||
expr.clone(),
|
||||
);
|
||||
|
||||
let rhs = rhs
|
||||
.into_iter()
|
||||
.map(|e| e.apply_recursive_substitution(&self.substitution))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match expr.get_type() {
|
||||
Type::FieldElement | Type::Boolean => {
|
||||
match assignee {
|
||||
|
@ -987,14 +962,6 @@ impl Flattener {
|
|||
let debug_name = v.clone().id;
|
||||
let var = self.use_variable(&debug_name);
|
||||
// handle return of function call
|
||||
let var_to_replace = self.get_latest_var_substitution(&debug_name);
|
||||
if !(var == var_to_replace)
|
||||
&& self.variables.contains(&var_to_replace)
|
||||
&& !self.substitution.contains_key(&var_to_replace)
|
||||
{
|
||||
self.substitution
|
||||
.insert(var_to_replace.clone(), var.clone());
|
||||
}
|
||||
statements_flattened
|
||||
.push(FlatStatement::Definition(var, rhs[0].clone()));
|
||||
}
|
||||
|
@ -1004,32 +971,17 @@ impl Flattener {
|
|||
_ => panic!("not a field element as rhs of array element update, should have been caught at semantic")
|
||||
};
|
||||
match index {
|
||||
box FieldElementExpression::Number(n) => {
|
||||
match array {
|
||||
box TypedAssignee::Identifier(id) => {
|
||||
let debug_name = format!("{}_c{}", id.id, n);
|
||||
let var = self.use_variable(&debug_name);
|
||||
// handle return of function call
|
||||
let var_to_replace =
|
||||
self.get_latest_var_substitution(&debug_name);
|
||||
if !(var == var_to_replace)
|
||||
&& self.variables.contains(&var_to_replace)
|
||||
&& !self
|
||||
.substitution
|
||||
.contains_key(&var_to_replace)
|
||||
{
|
||||
self.substitution.insert(
|
||||
var_to_replace.clone(),
|
||||
var.clone(),
|
||||
);
|
||||
}
|
||||
statements_flattened.push(
|
||||
FlatStatement::Definition(var, rhs[0].clone()),
|
||||
);
|
||||
}
|
||||
_ => panic!("no multidimension array for now"),
|
||||
box FieldElementExpression::Number(n) => match array {
|
||||
box TypedAssignee::Identifier(id) => {
|
||||
let debug_name = format!("{}_c{}", id.id, n);
|
||||
let var = self.use_variable(&debug_name);
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
var,
|
||||
rhs[0].clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => panic!("no multidimension array for now"),
|
||||
},
|
||||
box e => {
|
||||
// we have array[e] with e an arbitrary expression
|
||||
// first we check that e is in 0..array.len(), so we check that sum(if e == i then 1 else 0) == 1
|
||||
|
@ -1107,15 +1059,6 @@ impl Flattener {
|
|||
_ => unimplemented!(),
|
||||
};
|
||||
let var = self.use_variable(&debug_name);
|
||||
// handle return of function call
|
||||
let var_to_replace = self.get_latest_var_substitution(&debug_name);
|
||||
if !(var == var_to_replace)
|
||||
&& self.variables.contains(&var_to_replace)
|
||||
&& !self.substitution.contains_key(&var_to_replace)
|
||||
{
|
||||
self.substitution
|
||||
.insert(var_to_replace.clone(), var.clone());
|
||||
}
|
||||
statements_flattened.push(FlatStatement::Definition(var, r));
|
||||
}
|
||||
}
|
||||
|
@ -1132,14 +1075,12 @@ impl Flattener {
|
|||
functions_flattened,
|
||||
statements_flattened,
|
||||
e1,
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution),
|
||||
),
|
||||
self.flatten_field_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
e2,
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution),
|
||||
),
|
||||
);
|
||||
|
||||
if lhs.is_linear() {
|
||||
|
@ -1157,14 +1098,12 @@ impl Flattener {
|
|||
functions_flattened,
|
||||
statements_flattened,
|
||||
e1,
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution),
|
||||
),
|
||||
self.flatten_boolean_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
e2,
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution),
|
||||
),
|
||||
);
|
||||
|
||||
if lhs.is_linear() {
|
||||
|
@ -1193,13 +1132,6 @@ impl Flattener {
|
|||
),
|
||||
);
|
||||
|
||||
let (lhs, rhs) = (
|
||||
lhs.into_iter()
|
||||
.map(|e| e.apply_recursive_substitution(&self.substitution)),
|
||||
rhs.into_iter()
|
||||
.map(|e| e.apply_recursive_substitution(&self.substitution)),
|
||||
);
|
||||
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
|
||||
for (l, r) in lhs.into_iter().zip(rhs.into_iter()) {
|
||||
|
@ -1239,15 +1171,13 @@ impl Flattener {
|
|||
|
||||
match rhs {
|
||||
TypedExpressionList::FunctionCall(fun_id, exprs, _) => {
|
||||
let rhs_flattened = self
|
||||
.flatten_function_call(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
&fun_id,
|
||||
var_types,
|
||||
&exprs,
|
||||
)
|
||||
.apply_recursive_substitution(&self.substitution);
|
||||
let rhs_flattened = self.flatten_function_call(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
&fun_id,
|
||||
var_types,
|
||||
&exprs,
|
||||
);
|
||||
|
||||
let mut iterator = rhs_flattened.expressions.into_iter();
|
||||
|
||||
|
@ -1259,16 +1189,6 @@ impl Flattener {
|
|||
for index in 0..size {
|
||||
let debug_name = format!("{}_c{}", v.id, index);
|
||||
let var = self.use_variable(&debug_name);
|
||||
// handle return of function call
|
||||
let var_to_replace =
|
||||
self.get_latest_var_substitution(&debug_name);
|
||||
if !(var == var_to_replace)
|
||||
&& self.variables.contains(&var_to_replace)
|
||||
&& !self.substitution.contains_key(&var_to_replace)
|
||||
{
|
||||
self.substitution
|
||||
.insert(var_to_replace.clone(), var.clone());
|
||||
}
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
var,
|
||||
iterator.next().unwrap(),
|
||||
|
@ -1278,16 +1198,6 @@ impl Flattener {
|
|||
Type::Boolean | Type::FieldElement => {
|
||||
let debug_name = v.id;
|
||||
let var = self.use_variable(&debug_name);
|
||||
// handle return of function call
|
||||
let var_to_replace =
|
||||
self.get_latest_var_substitution(&debug_name);
|
||||
if !(var == var_to_replace)
|
||||
&& self.variables.contains(&var_to_replace)
|
||||
&& !self.substitution.contains_key(&var_to_replace)
|
||||
{
|
||||
self.substitution
|
||||
.insert(var_to_replace.clone(), var.clone());
|
||||
}
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
var,
|
||||
iterator.next().unwrap(),
|
||||
|
@ -1310,14 +1220,11 @@ impl Flattener {
|
|||
///
|
||||
/// * `functions_flattened` - Vector where new flattened functions can be added.
|
||||
/// * `funct` - `TypedFunction` that will be flattened.
|
||||
pub fn flatten_function<T: Field>(
|
||||
fn flatten_function<T: Field>(
|
||||
&mut self,
|
||||
functions_flattened: &mut Vec<FlatFunction<T>>,
|
||||
funct: TypedFunction<T>,
|
||||
) -> FlatFunction<T> {
|
||||
self.variables = HashSet::new();
|
||||
self.substitution = HashMap::new();
|
||||
|
||||
self.bijection = BiMap::new();
|
||||
|
||||
self.next_var_idx = 0;
|
||||
|
@ -1371,7 +1278,7 @@ impl Flattener {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `prog` - `Prog`ram that will be flattened.
|
||||
pub fn flatten_program<T: Field>(&mut self, prog: TypedProg<T>) -> FlatProg<T> {
|
||||
fn flatten_program<T: Field>(&mut self, prog: TypedProg<T>) -> FlatProg<T> {
|
||||
let mut functions_flattened = Vec::new();
|
||||
|
||||
self.load_corelib(&mut functions_flattened);
|
||||
|
@ -1417,15 +1324,7 @@ impl Flattener {
|
|||
|
||||
fn get_latest_var_substitution(&mut self, name: &String) -> FlatVariable {
|
||||
// start with the variable name
|
||||
let latest_var = self.bijection.get_by_left(name).unwrap().clone();
|
||||
// loop {
|
||||
// // walk the substitutions
|
||||
// match self.substitution.get(&latest_var) {
|
||||
// Some(x) => latest_var = x,
|
||||
// None => break,
|
||||
// }
|
||||
// }
|
||||
latest_var
|
||||
self.bijection.get_by_left(name).unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1443,7 +1342,7 @@ mod tests {
|
|||
// def main()
|
||||
// a, b = foo()
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![FlatFunction {
|
||||
id: "foo".to_string(),
|
||||
arguments: vec![],
|
||||
|
@ -1493,7 +1392,7 @@ mod tests {
|
|||
|
||||
let a = FlatVariable::new(0);
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![FlatFunction {
|
||||
id: "dup".to_string(),
|
||||
arguments: vec![FlatParameter {
|
||||
|
@ -1548,7 +1447,7 @@ mod tests {
|
|||
// def main()
|
||||
// a = foo()
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![FlatFunction {
|
||||
id: "foo".to_string(),
|
||||
arguments: vec![],
|
||||
|
@ -1593,7 +1492,7 @@ mod tests {
|
|||
// a_0 = a + 1
|
||||
// return 1
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
|
||||
let funct = TypedFunction {
|
||||
|
@ -1676,7 +1575,7 @@ mod tests {
|
|||
},
|
||||
};
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
|
||||
let foo_flattened = flattener.flatten_function(&mut vec![], foo);
|
||||
|
||||
|
@ -1740,7 +1639,7 @@ mod tests {
|
|||
},
|
||||
};
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
|
||||
let expected = FlatFunction {
|
||||
id: String::from("main"),
|
||||
|
@ -1797,7 +1696,7 @@ mod tests {
|
|||
// should not panic
|
||||
//
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let functions = vec![
|
||||
TypedFunction {
|
||||
id: "foo".to_string(),
|
||||
|
@ -1867,7 +1766,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn if_else() {
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let expression = FieldElementExpression::IfElse(
|
||||
box BooleanExpression::Eq(
|
||||
box FieldElementExpression::Number(FieldPrime::from(32)),
|
||||
|
@ -1885,7 +1784,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn geq_leq() {
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let expression_le = BooleanExpression::Le(
|
||||
box FieldElementExpression::Number(FieldPrime::from(32)),
|
||||
box FieldElementExpression::Number(FieldPrime::from(4)),
|
||||
|
@ -1918,7 +1817,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(51)),
|
||||
);
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression);
|
||||
|
@ -1927,7 +1826,7 @@ mod tests {
|
|||
#[test]
|
||||
fn div() {
|
||||
// a = 5 / b / b
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
let mut statements_flattened = vec![];
|
||||
|
||||
|
@ -2040,7 +1939,7 @@ mod tests {
|
|||
fn field_array() {
|
||||
// foo = [ , , ]
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
let mut statements_flattened = vec![];
|
||||
let statement = TypedStatement::Definition(
|
||||
|
@ -2083,7 +1982,7 @@ mod tests {
|
|||
fn array_definition() {
|
||||
// field[3] foo = [1, 2, 3]
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
let mut statements_flattened = vec![];
|
||||
let statement = TypedStatement::Definition(
|
||||
|
@ -2129,7 +2028,7 @@ mod tests {
|
|||
// field[3] foo = [1, 2, 3]
|
||||
// foo[1]
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
let mut statements_flattened = vec![];
|
||||
let statement = TypedStatement::Definition(
|
||||
|
@ -2174,7 +2073,7 @@ mod tests {
|
|||
// bar = foo[0] + foo[1] + foo[2]
|
||||
// we don't optimise detecting constants, this will be done in an optimiser pass
|
||||
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
let mut statements_flattened = vec![];
|
||||
let def = TypedStatement::Definition(
|
||||
|
@ -2243,7 +2142,7 @@ mod tests {
|
|||
// if 1 == 1 then [1] else [3] fi
|
||||
|
||||
let with_arrays = {
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
let mut statements_flattened = vec![];
|
||||
|
@ -2275,7 +2174,7 @@ mod tests {
|
|||
};
|
||||
|
||||
let without_arrays = {
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
let mut statements_flattened = vec![];
|
||||
|
@ -2305,7 +2204,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn next_variable() {
|
||||
let mut flattener = Flattener::new(FieldPrime::get_required_bits());
|
||||
let mut flattener = Flattener::new();
|
||||
assert_eq!(
|
||||
FlatVariable::new(0),
|
||||
flattener.use_variable(&String::from("a"))
|
||||
|
|
|
@ -103,7 +103,7 @@ impl Optimizer {
|
|||
// filter out synonyms definitions
|
||||
FlatStatement::Definition(_, FlatExpression::Identifier(_)) => None,
|
||||
// substitute all other statements
|
||||
_ => Some(statement.apply_direct_substitution(&self.substitution)),
|
||||
_ => Some(statement.apply_substitution(&self.substitution)),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
@ -112,7 +112,7 @@ impl Optimizer {
|
|||
let optimized_arguments = funct
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|arg| arg.apply_direct_substitution(&self.substitution))
|
||||
.map(|arg| arg.apply_substitution(&self.substitution))
|
||||
.collect();
|
||||
|
||||
FlatFunction {
|
||||
|
|
Loading…
Reference in a new issue