add use_parameter for functions in flattening to add boolean constraints
This commit is contained in:
parent
275d22a72c
commit
3e410f84ee
2 changed files with 95 additions and 10 deletions
|
@ -17,7 +17,7 @@ use std::collections::{BTreeMap, HashMap};
|
|||
use std::fmt;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct FlatProg<T: Field> {
|
||||
/// FlatFunctions of the program
|
||||
pub functions: Vec<FlatFunction<T>>,
|
||||
|
|
|
@ -1211,18 +1211,14 @@ impl<'ast> Flattener<'ast> {
|
|||
self.layout = HashMap::new();
|
||||
|
||||
self.next_var_idx = 0;
|
||||
|
||||
let mut arguments_flattened: Vec<FlatParameter> = Vec::new();
|
||||
let mut statements_flattened: Vec<FlatStatement<T>> = Vec::new();
|
||||
|
||||
// push parameters
|
||||
for arg in &funct.arguments {
|
||||
let variables = self.use_variable(&arg.id);
|
||||
arguments_flattened.extend(variables.into_iter().map(|v| FlatParameter {
|
||||
id: v,
|
||||
private: arg.private,
|
||||
}));
|
||||
}
|
||||
let arguments_flattened = funct
|
||||
.arguments
|
||||
.into_iter()
|
||||
.flat_map(|p| self.use_parameter(&p, &mut statements_flattened))
|
||||
.collect();
|
||||
|
||||
// flatten statements in functions and apply substitution
|
||||
for stat in funct.statements {
|
||||
|
@ -1276,6 +1272,26 @@ impl<'ast> Flattener<'ast> {
|
|||
vars
|
||||
}
|
||||
|
||||
fn use_parameter<T: Field>(
|
||||
&mut self,
|
||||
parameter: &Parameter<'ast>,
|
||||
statements: &mut Vec<FlatStatement<T>>,
|
||||
) -> Vec<FlatParameter> {
|
||||
let variables = self.use_variable(¶meter.id);
|
||||
match parameter.id.get_type() {
|
||||
Type::Boolean => statements.extend(Self::boolean_constraint(&variables)),
|
||||
_ => {}
|
||||
};
|
||||
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|v| FlatParameter {
|
||||
id: v,
|
||||
private: parameter.private,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn issue_new_variables(&mut self, count: usize) -> Vec<FlatVariable> {
|
||||
(0..count)
|
||||
.map(|_| {
|
||||
|
@ -1286,6 +1302,21 @@ impl<'ast> Flattener<'ast> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn boolean_constraint<T: Field>(variables: &Vec<FlatVariable>) -> Vec<FlatStatement<T>> {
|
||||
variables
|
||||
.iter()
|
||||
.map(|v| {
|
||||
FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*v),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*v),
|
||||
box FlatExpression::Identifier(*v),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// create an internal variable. We do not register it in the layout
|
||||
fn use_sym(&mut self) -> FlatVariable {
|
||||
let var = self.issue_new_variables(1);
|
||||
|
@ -1312,6 +1343,60 @@ mod tests {
|
|||
use crate::types::Type;
|
||||
use zokrates_field::field::FieldPrime;
|
||||
|
||||
mod boolean_checks {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn boolean_arg() {
|
||||
// def main(bool a):
|
||||
// return a
|
||||
//
|
||||
// -> should flatten to
|
||||
//
|
||||
// def main(_0) -> (1):
|
||||
// _0 * _0 == _0
|
||||
// return _0
|
||||
|
||||
let function: TypedFunction<FieldPrime> = TypedFunction {
|
||||
id: "main",
|
||||
arguments: vec![Parameter::private(Variable::boolean("a".into()))],
|
||||
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
|
||||
"a".into(),
|
||||
)
|
||||
.into()])],
|
||||
signature: Signature::new()
|
||||
.inputs(vec![Type::Boolean])
|
||||
.outputs(vec![Type::Boolean]),
|
||||
};
|
||||
|
||||
let expected = FlatFunction {
|
||||
id: String::from("main"),
|
||||
arguments: vec![FlatParameter::private(FlatVariable::new(0))],
|
||||
statements: vec![
|
||||
FlatStatement::Condition(
|
||||
FlatExpression::Identifier(FlatVariable::new(0)),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(FlatVariable::new(0)),
|
||||
box FlatExpression::Identifier(FlatVariable::new(0)),
|
||||
),
|
||||
),
|
||||
FlatStatement::Return(FlatExpressionList {
|
||||
expressions: vec![FlatExpression::Identifier(FlatVariable::new(0))],
|
||||
}),
|
||||
],
|
||||
signature: Signature::new()
|
||||
.inputs(vec![Type::Boolean])
|
||||
.outputs(vec![Type::Boolean]),
|
||||
};
|
||||
|
||||
let mut flattener = Flattener::new();
|
||||
|
||||
let flat_function = flattener.flatten_function(&mut vec![], function);
|
||||
|
||||
assert_eq!(flat_function, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_definition() {
|
||||
// def foo()
|
||||
|
|
Loading…
Reference in a new issue