make if_else generic
This commit is contained in:
commit
e89d60be04
2 changed files with 58 additions and 181 deletions
|
@ -8,7 +8,6 @@
|
|||
use crate::flat_absy::*;
|
||||
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
|
||||
use crate::typed_absy::*;
|
||||
use crate::types::conversions::cast;
|
||||
use crate::types::Signature;
|
||||
use crate::types::Type;
|
||||
use std::collections::HashMap;
|
||||
|
@ -123,68 +122,52 @@ impl<'ast> Flattener<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Loads the code library
|
||||
fn load_corelib<T: Field>(&mut self, functions_flattened: &mut Vec<FlatFunction<T>>) -> () {
|
||||
// Load type casting functions
|
||||
functions_flattened.push(cast(&Type::Boolean, &Type::FieldElement));
|
||||
fn flatten_if_else_expression<U: Flatten<'ast, T>, T: Field>(
|
||||
&mut self,
|
||||
functions_flattened: &Vec<FlatFunction<T>>,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
condition: BooleanExpression<'ast, T>,
|
||||
consequence: U,
|
||||
alternative: U,
|
||||
) -> FlatExpression<T> {
|
||||
let condition =
|
||||
self.flatten_boolean_expression(functions_flattened, statements_flattened, condition);
|
||||
|
||||
// Load IfElse helper for fields
|
||||
let ie = TypedFunction {
|
||||
id: "_if_else_field",
|
||||
arguments: vec![
|
||||
Parameter {
|
||||
id: Variable {
|
||||
id: "condition".into(),
|
||||
_type: Type::Boolean,
|
||||
},
|
||||
private: true,
|
||||
},
|
||||
Parameter {
|
||||
id: Variable {
|
||||
id: "consequence".into(),
|
||||
_type: Type::FieldElement,
|
||||
},
|
||||
private: true,
|
||||
},
|
||||
Parameter {
|
||||
id: Variable {
|
||||
id: "alternative".into(),
|
||||
_type: Type::FieldElement,
|
||||
},
|
||||
private: true,
|
||||
},
|
||||
],
|
||||
statements: vec![
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("condition_as_field".into())),
|
||||
FieldElementExpression::FunctionCall(
|
||||
"_bool_to_field".to_string(),
|
||||
vec![BooleanExpression::Identifier("condition".into()).into()],
|
||||
)
|
||||
.into(),
|
||||
let consequence = consequence.flatten(self, functions_flattened, statements_flattened);
|
||||
|
||||
let alternative = alternative.flatten(self, functions_flattened, statements_flattened);
|
||||
|
||||
let condition_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
|
||||
|
||||
let consequence_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
|
||||
|
||||
let alternative_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
|
||||
|
||||
let term0 = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
term0,
|
||||
FlatExpression::Mult(box condition_id.clone().into(), box consequence_id.into()),
|
||||
));
|
||||
let term1 = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
term1,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box condition_id.into(),
|
||||
),
|
||||
TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Identifier("condition_as_field".into()),
|
||||
box FieldElementExpression::Identifier("consequence".into()),
|
||||
),
|
||||
box FieldElementExpression::Mult(
|
||||
box FieldElementExpression::Sub(
|
||||
box FieldElementExpression::Number(T::one()),
|
||||
box FieldElementExpression::Identifier("condition_as_field".into()),
|
||||
),
|
||||
box FieldElementExpression::Identifier("alternative".into()),
|
||||
),
|
||||
)
|
||||
.into()]),
|
||||
],
|
||||
signature: Signature::new()
|
||||
.inputs(vec![Type::Boolean, Type::FieldElement, Type::FieldElement])
|
||||
.outputs(vec![Type::FieldElement]),
|
||||
};
|
||||
|
||||
let ief = self.flatten_function(functions_flattened, ie);
|
||||
functions_flattened.push(ief);
|
||||
box alternative_id.into(),
|
||||
),
|
||||
));
|
||||
let res = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
res,
|
||||
FlatExpression::Add(box term0.into(), box term1.into()),
|
||||
));
|
||||
res.into()
|
||||
}
|
||||
|
||||
fn flatten_select_expression<U: Flatten<'ast, T>, T: Field>(
|
||||
|
@ -603,28 +586,14 @@ impl<'ast> Flattener<'ast> {
|
|||
true => T::from(1),
|
||||
false => T::from(0),
|
||||
}),
|
||||
BooleanExpression::IfElse(box condition, box consequent, box alternative) => self
|
||||
.flatten_function_call(
|
||||
BooleanExpression::IfElse(box condition, box consequence, box alternative) => self
|
||||
.flatten_if_else_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
&"_if_else_field".to_string(),
|
||||
vec![Type::FieldElement],
|
||||
&vec![
|
||||
condition.into(),
|
||||
FieldElementExpression::FunctionCall(
|
||||
"_bool_to_field".to_string(),
|
||||
vec![consequent.into()],
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::FunctionCall(
|
||||
"_bool_to_field".to_string(),
|
||||
vec![alternative.into()],
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
)
|
||||
.expressions[0]
|
||||
.clone(),
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
),
|
||||
BooleanExpression::Select(box array, box index) => self
|
||||
.flatten_select_expression::<BooleanExpression<'ast, T>, _>(
|
||||
functions_flattened,
|
||||
|
@ -971,16 +940,14 @@ impl<'ast> Flattener<'ast> {
|
|||
_ => panic!("Expected number as pow exponent"),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::IfElse(box condition, box consequent, box alternative) => self
|
||||
.flatten_function_call(
|
||||
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => self
|
||||
.flatten_if_else_expression(
|
||||
functions_flattened,
|
||||
statements_flattened,
|
||||
&"_if_else_field".to_string(),
|
||||
vec![Type::FieldElement],
|
||||
&vec![condition.into(), consequent.into(), alternative.into()],
|
||||
)
|
||||
.expressions[0]
|
||||
.clone(),
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
),
|
||||
FieldElementExpression::FunctionCall(ref id, ref param_expressions) => {
|
||||
let exprs_flattened = self.flatten_function_call(
|
||||
functions_flattened,
|
||||
|
@ -1389,8 +1356,6 @@ impl<'ast> Flattener<'ast> {
|
|||
fn flatten_program<T: Field>(&mut self, prog: TypedProg<'ast, T>) -> FlatProg<T> {
|
||||
let mut functions_flattened = Vec::new();
|
||||
|
||||
self.load_corelib(&mut functions_flattened);
|
||||
|
||||
for func in prog.imported_functions {
|
||||
functions_flattened.push(func);
|
||||
}
|
||||
|
@ -2156,10 +2121,7 @@ mod tests {
|
|||
box FieldElementExpression::Number(FieldPrime::from(51)),
|
||||
);
|
||||
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
|
||||
flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression);
|
||||
flattener.flatten_field_expression(&vec![], &mut vec![], expression);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -2198,9 +2160,7 @@ mod tests {
|
|||
);
|
||||
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression);
|
||||
flattener.flatten_field_expression(&vec![], &mut vec![], expression);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -2548,7 +2508,6 @@ mod tests {
|
|||
let with_arrays = {
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
let mut statements_flattened = vec![];
|
||||
|
||||
let e =
|
||||
|
@ -2591,7 +2550,6 @@ mod tests {
|
|||
let without_arrays = {
|
||||
let mut flattener = Flattener::new();
|
||||
let mut functions_flattened = vec![];
|
||||
flattener.load_corelib(&mut functions_flattened);
|
||||
let mut statements_flattened = vec![];
|
||||
|
||||
// if 1 == 1 then 1 else 3 fi
|
||||
|
|
|
@ -110,92 +110,11 @@ pub fn split<T: Field>() -> FlatProg<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn cast<T: Field>(from: &Type, to: &Type) -> FlatFunction<T> {
|
||||
let mut counter = 0;
|
||||
|
||||
let mut layout = HashMap::new();
|
||||
|
||||
let arguments = (0..from.get_primitive_count())
|
||||
.enumerate()
|
||||
.map(|(index, _)| FlatParameter {
|
||||
id: FlatVariable::new(index),
|
||||
private: true,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let binding_inputs: Vec<_> = (0..from.get_primitive_count())
|
||||
.map(|index| use_variable(&mut layout, format!("i{}", index), &mut counter))
|
||||
.collect();
|
||||
let binding_outputs: Vec<FlatVariable> = (0..to.get_primitive_count())
|
||||
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
|
||||
.collect();
|
||||
|
||||
let outputs = binding_outputs
|
||||
.iter()
|
||||
.map(|o| FlatExpression::Identifier(o.clone()))
|
||||
.collect();
|
||||
|
||||
let bindings: Vec<_> = match (from, to) {
|
||||
(Type::Boolean, Type::FieldElement) => binding_outputs
|
||||
.into_iter()
|
||||
.zip(binding_inputs.into_iter())
|
||||
.map(|(o, i)| FlatStatement::Definition(o, i.into()))
|
||||
.collect(),
|
||||
_ => panic!(format!("can't cast {} to {}", from, to)),
|
||||
};
|
||||
|
||||
let signature = Signature {
|
||||
inputs: vec![from.clone()],
|
||||
outputs: vec![to.clone()],
|
||||
};
|
||||
|
||||
let statements = bindings
|
||||
.into_iter()
|
||||
.chain(std::iter::once(FlatStatement::Return(FlatExpressionList {
|
||||
expressions: outputs,
|
||||
})))
|
||||
.collect();
|
||||
|
||||
FlatFunction {
|
||||
id: format!("_{}_to_{}", from, to),
|
||||
arguments,
|
||||
statements,
|
||||
signature,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_field::field::FieldPrime;
|
||||
|
||||
#[cfg(test)]
|
||||
mod cast {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn bool_to_field() {
|
||||
let b2f: FlatFunction<FieldPrime> = cast(&Type::Boolean, &Type::FieldElement);
|
||||
assert_eq!(b2f.id, String::from("_bool_to_field"));
|
||||
assert_eq!(
|
||||
b2f.arguments,
|
||||
vec![FlatParameter::private(FlatVariable::new(0))]
|
||||
);
|
||||
assert_eq!(b2f.statements.len(), 2); // 1 definition, 1 return
|
||||
assert_eq!(
|
||||
b2f.statements[0],
|
||||
FlatStatement::Definition(FlatVariable::new(1), FlatVariable::new(0).into())
|
||||
);
|
||||
assert_eq!(
|
||||
b2f.statements[1],
|
||||
FlatStatement::Return(FlatExpressionList {
|
||||
expressions: vec![FlatExpression::Identifier(FlatVariable::new(1))]
|
||||
})
|
||||
);
|
||||
assert_eq!(b2f.signature.outputs.len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod split {
|
||||
use super::*;
|
||||
|
|
Loading…
Reference in a new issue