1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

make if_else generic

This commit is contained in:
schaeff 2019-07-11 15:54:51 +02:00
commit e89d60be04
2 changed files with 58 additions and 181 deletions

View file

@ -8,7 +8,6 @@
use crate::flat_absy::*;
use crate::helpers::{DirectiveStatement, Helper, RustHelper};
use crate::typed_absy::*;
use crate::types::conversions::cast;
use crate::types::Signature;
use crate::types::Type;
use std::collections::HashMap;
@ -123,68 +122,52 @@ impl<'ast> Flattener<'ast> {
}
}
/// Loads the code library
fn load_corelib<T: Field>(&mut self, functions_flattened: &mut Vec<FlatFunction<T>>) -> () {
// Load type casting functions
functions_flattened.push(cast(&Type::Boolean, &Type::FieldElement));
fn flatten_if_else_expression<U: Flatten<'ast, T>, T: Field>(
&mut self,
functions_flattened: &Vec<FlatFunction<T>>,
statements_flattened: &mut Vec<FlatStatement<T>>,
condition: BooleanExpression<'ast, T>,
consequence: U,
alternative: U,
) -> FlatExpression<T> {
let condition =
self.flatten_boolean_expression(functions_flattened, statements_flattened, condition);
// Load IfElse helper for fields
let ie = TypedFunction {
id: "_if_else_field",
arguments: vec![
Parameter {
id: Variable {
id: "condition".into(),
_type: Type::Boolean,
},
private: true,
},
Parameter {
id: Variable {
id: "consequence".into(),
_type: Type::FieldElement,
},
private: true,
},
Parameter {
id: Variable {
id: "alternative".into(),
_type: Type::FieldElement,
},
private: true,
},
],
statements: vec![
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("condition_as_field".into())),
FieldElementExpression::FunctionCall(
"_bool_to_field".to_string(),
vec![BooleanExpression::Identifier("condition".into()).into()],
)
.into(),
let consequence = consequence.flatten(self, functions_flattened, statements_flattened);
let alternative = alternative.flatten(self, functions_flattened, statements_flattened);
let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition));
let consequence_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
let alternative_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
let term0 = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
term0,
FlatExpression::Mult(box condition_id.clone().into(), box consequence_id.into()),
));
let term1 = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
term1,
FlatExpression::Mult(
box FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
TypedStatement::Return(vec![FieldElementExpression::Add(
box FieldElementExpression::Mult(
box FieldElementExpression::Identifier("condition_as_field".into()),
box FieldElementExpression::Identifier("consequence".into()),
),
box FieldElementExpression::Mult(
box FieldElementExpression::Sub(
box FieldElementExpression::Number(T::one()),
box FieldElementExpression::Identifier("condition_as_field".into()),
),
box FieldElementExpression::Identifier("alternative".into()),
),
)
.into()]),
],
signature: Signature::new()
.inputs(vec![Type::Boolean, Type::FieldElement, Type::FieldElement])
.outputs(vec![Type::FieldElement]),
};
let ief = self.flatten_function(functions_flattened, ie);
functions_flattened.push(ief);
box alternative_id.into(),
),
));
let res = self.use_sym();
statements_flattened.push(FlatStatement::Definition(
res,
FlatExpression::Add(box term0.into(), box term1.into()),
));
res.into()
}
fn flatten_select_expression<U: Flatten<'ast, T>, T: Field>(
@ -603,28 +586,14 @@ impl<'ast> Flattener<'ast> {
true => T::from(1),
false => T::from(0),
}),
BooleanExpression::IfElse(box condition, box consequent, box alternative) => self
.flatten_function_call(
BooleanExpression::IfElse(box condition, box consequence, box alternative) => self
.flatten_if_else_expression(
functions_flattened,
statements_flattened,
&"_if_else_field".to_string(),
vec![Type::FieldElement],
&vec![
condition.into(),
FieldElementExpression::FunctionCall(
"_bool_to_field".to_string(),
vec![consequent.into()],
)
.into(),
FieldElementExpression::FunctionCall(
"_bool_to_field".to_string(),
vec![alternative.into()],
)
.into(),
],
)
.expressions[0]
.clone(),
condition,
consequence,
alternative,
),
BooleanExpression::Select(box array, box index) => self
.flatten_select_expression::<BooleanExpression<'ast, T>, _>(
functions_flattened,
@ -971,16 +940,14 @@ impl<'ast> Flattener<'ast> {
_ => panic!("Expected number as pow exponent"),
}
}
FieldElementExpression::IfElse(box condition, box consequent, box alternative) => self
.flatten_function_call(
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => self
.flatten_if_else_expression(
functions_flattened,
statements_flattened,
&"_if_else_field".to_string(),
vec![Type::FieldElement],
&vec![condition.into(), consequent.into(), alternative.into()],
)
.expressions[0]
.clone(),
condition,
consequence,
alternative,
),
FieldElementExpression::FunctionCall(ref id, ref param_expressions) => {
let exprs_flattened = self.flatten_function_call(
functions_flattened,
@ -1389,8 +1356,6 @@ impl<'ast> Flattener<'ast> {
fn flatten_program<T: Field>(&mut self, prog: TypedProg<'ast, T>) -> FlatProg<T> {
let mut functions_flattened = Vec::new();
self.load_corelib(&mut functions_flattened);
for func in prog.imported_functions {
functions_flattened.push(func);
}
@ -2156,10 +2121,7 @@ mod tests {
box FieldElementExpression::Number(FieldPrime::from(51)),
);
let mut functions_flattened = vec![];
flattener.load_corelib(&mut functions_flattened);
flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression);
flattener.flatten_field_expression(&vec![], &mut vec![], expression);
}
#[test]
@ -2198,9 +2160,7 @@ mod tests {
);
let mut flattener = Flattener::new();
let mut functions_flattened = vec![];
flattener.load_corelib(&mut functions_flattened);
flattener.flatten_field_expression(&functions_flattened, &mut vec![], expression);
flattener.flatten_field_expression(&vec![], &mut vec![], expression);
}
#[test]
@ -2548,7 +2508,6 @@ mod tests {
let with_arrays = {
let mut flattener = Flattener::new();
let mut functions_flattened = vec![];
flattener.load_corelib(&mut functions_flattened);
let mut statements_flattened = vec![];
let e =
@ -2591,7 +2550,6 @@ mod tests {
let without_arrays = {
let mut flattener = Flattener::new();
let mut functions_flattened = vec![];
flattener.load_corelib(&mut functions_flattened);
let mut statements_flattened = vec![];
// if 1 == 1 then 1 else 3 fi

View file

@ -110,92 +110,11 @@ pub fn split<T: Field>() -> FlatProg<T> {
}
}
pub fn cast<T: Field>(from: &Type, to: &Type) -> FlatFunction<T> {
let mut counter = 0;
let mut layout = HashMap::new();
let arguments = (0..from.get_primitive_count())
.enumerate()
.map(|(index, _)| FlatParameter {
id: FlatVariable::new(index),
private: true,
})
.collect();
let binding_inputs: Vec<_> = (0..from.get_primitive_count())
.map(|index| use_variable(&mut layout, format!("i{}", index), &mut counter))
.collect();
let binding_outputs: Vec<FlatVariable> = (0..to.get_primitive_count())
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
.collect();
let outputs = binding_outputs
.iter()
.map(|o| FlatExpression::Identifier(o.clone()))
.collect();
let bindings: Vec<_> = match (from, to) {
(Type::Boolean, Type::FieldElement) => binding_outputs
.into_iter()
.zip(binding_inputs.into_iter())
.map(|(o, i)| FlatStatement::Definition(o, i.into()))
.collect(),
_ => panic!(format!("can't cast {} to {}", from, to)),
};
let signature = Signature {
inputs: vec![from.clone()],
outputs: vec![to.clone()],
};
let statements = bindings
.into_iter()
.chain(std::iter::once(FlatStatement::Return(FlatExpressionList {
expressions: outputs,
})))
.collect();
FlatFunction {
id: format!("_{}_to_{}", from, to),
arguments,
statements,
signature,
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::field::FieldPrime;
#[cfg(test)]
mod cast {
use super::*;
#[test]
fn bool_to_field() {
let b2f: FlatFunction<FieldPrime> = cast(&Type::Boolean, &Type::FieldElement);
assert_eq!(b2f.id, String::from("_bool_to_field"));
assert_eq!(
b2f.arguments,
vec![FlatParameter::private(FlatVariable::new(0))]
);
assert_eq!(b2f.statements.len(), 2); // 1 definition, 1 return
assert_eq!(
b2f.statements[0],
FlatStatement::Definition(FlatVariable::new(1), FlatVariable::new(0).into())
);
assert_eq!(
b2f.statements[1],
FlatStatement::Return(FlatExpressionList {
expressions: vec![FlatExpression::Identifier(FlatVariable::new(1))]
})
);
assert_eq!(b2f.signature.outputs.len(), 1);
}
}
#[cfg(test)]
mod split {
use super::*;