1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

generalize arrays. wip, basic boolean array mvp

This commit is contained in:
schaeff 2019-06-22 00:30:33 +02:00
parent ba2e6e32ac
commit be375073d1
16 changed files with 639 additions and 325 deletions

View file

@ -0,0 +1,6 @@
def main(bool[3] a) -> (field[3]):
field[3] result = [0; 3]
for field i in 0..3 do
result[i] = if a[i] then 1 else 0 fi
endfor
return result

View file

@ -517,9 +517,10 @@ impl<'ast> From<pest::Type<'ast>> for Type {
}
};
match t.ty {
pest::BasicType::Field(_) => Type::FieldElementArray(size),
pest::BasicType::Field(_) => Type::array(Type::FieldElement, size),
pest::BasicType::Boolean(_) => Type::array(Type::Boolean, size),
_ => unimplemented!(
"Array elements should be field elements, found {}",
"Array elements of {} are not implemented yet",
t.span.as_str()
),
}

View file

@ -37,7 +37,7 @@ impl<'ast> Variable<'ast> {
pub fn field_array<S: Into<&'ast str>>(id: S, size: usize) -> Variable<'ast> {
Variable {
id: id.into(),
_type: Type::FieldElementArray(size),
_type: Type::array(Type::FieldElement, size),
}
}

View file

@ -12,6 +12,7 @@ use crate::types::conversions::cast;
use crate::types::Signature;
use crate::types::Type;
use std::collections::HashMap;
use std::convert::TryFrom;
use zokrates_field::field::Field;
/// Flattener, computes flattened program.
@ -45,7 +46,7 @@ impl<'ast> Flattener<'ast> {
// Load type casting functions
functions_flattened.push(cast(&Type::Boolean, &Type::FieldElement));
// Load IfElse helper
// Load IfElse helper for fields
let ie = TypedFunction {
id: "_if_else_field",
arguments: vec![
@ -395,6 +396,168 @@ impl<'ast> Flattener<'ast> {
true => T::from(1),
false => T::from(0),
}),
BooleanExpression::IfElse(box condition, box consequent, box alternative) => self
.flatten_function_call(
functions_flattened,
statements_flattened,
&"_if_else_field".to_string(),
vec![Type::Boolean],
&vec![
condition.into(),
FieldElementExpression::FunctionCall(
"_bool_to_field".to_string(),
vec![consequent.into()],
)
.into(),
FieldElementExpression::FunctionCall(
"_bool_to_field".to_string(),
vec![alternative.into()],
)
.into(),
],
)
.expressions[0]
.clone(),
BooleanExpression::Select(box array, box index) => {
let size = array.size();
let ty = array.inner_type();
assert_eq!(ty, &Type::Boolean);
println!("{}[{}]", array, index);
println!("{:?}", self.layout);
match index {
FieldElementExpression::Number(n) => match array.inner {
ArrayExpressionInner::Identifier(id) => {
assert!(n < T::from(size));
FlatExpression::Identifier(
self.layout.get(&id).unwrap().clone()
[n.to_dec_string().parse::<usize>().unwrap()],
)
}
ArrayExpressionInner::Value(expressions) => {
assert!(n < T::from(size));
self.flatten_boolean_expression(
functions_flattened,
statements_flattened,
BooleanExpression::try_from(
expressions[n.to_dec_string().parse::<usize>().unwrap()]
.clone(),
)
.unwrap(),
)
}
ArrayExpressionInner::FunctionCall(..) => {
unimplemented!("please use intermediate variables for now")
}
ArrayExpressionInner::IfElse(condition, consequence, alternative) => {
// [if cond then [a, b] else [c, d]][1] == if cond then [a, b][1] else [c, d][1]
self.flatten_boolean_expression(
functions_flattened,
statements_flattened,
BooleanExpression::IfElse(
condition,
box BooleanExpression::Select(
consequence,
box FieldElementExpression::Number(n.clone()),
),
box BooleanExpression::Select(
alternative,
box FieldElementExpression::Number(n),
),
),
)
}
},
e => {
let size = array.size();
// we have array[e] with e an arbitrary expression
// first we check that e is in 0..array.len(), so we check that sum(if e == i then 1 else 0) == 1
// here depending on the size, we could use a proper range check based on bits
let range_check = (0..size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box e.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box FieldElementExpression::Number(T::from(1)),
box FieldElementExpression::Number(T::from(0)),
)
})
.fold(FieldElementExpression::Number(T::from(0)), |acc, e| {
FieldElementExpression::Add(box acc, box e)
});
let range_check_statement = TypedStatement::Condition(
FieldElementExpression::Number(T::from(1)).into(),
range_check.into(),
);
self.flatten_statement(
functions_flattened,
statements_flattened,
range_check_statement,
);
// now we flatten to sum(if e == i then array[i] else false)
(0..size)
.map(|i| {
self.flatten_boolean_expression(
functions_flattened,
statements_flattened,
BooleanExpression::IfElse(
box BooleanExpression::Eq(
box e.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box match array.clone().inner {
ArrayExpressionInner::Identifier(id) => {
BooleanExpression::Select(
box ArrayExpression {
ty: ty.clone(),
size,
inner: ArrayExpressionInner::Identifier(id),
},
box FieldElementExpression::Number(T::from(i)),
)
}
ArrayExpressionInner::Value(expressions) => {
assert_eq!(size, expressions.len());
BooleanExpression::try_from(expressions[i].clone())
.unwrap()
}
ArrayExpressionInner::FunctionCall(..) => {
unimplemented!(
"please use intermediate variables for now"
)
}
ArrayExpressionInner::IfElse(
condition,
consequence,
alternative,
) => BooleanExpression::IfElse(
condition,
box BooleanExpression::Select(
consequence,
box FieldElementExpression::Number(T::from(i)),
),
box BooleanExpression::Select(
alternative,
box FieldElementExpression::Number(T::from(i)),
),
),
},
box BooleanExpression::Value(false),
),
)
})
.fold(FlatExpression::Number(T::from(0)), |acc, e| {
FlatExpression::Add(box acc, box e)
})
}
}
}
}
}
@ -515,8 +678,8 @@ impl<'ast> Flattener<'ast> {
TypedExpression::Boolean(e) => {
vec![self.flatten_boolean_expression(functions_flattened, statements_flattened, e)]
}
TypedExpression::FieldElementArray(e) => {
self.flatten_field_array_expression(functions_flattened, statements_flattened, e)
TypedExpression::Array(e) => {
self.flatten_array_expression(functions_flattened, statements_flattened, e)
}
}
}
@ -717,31 +880,34 @@ impl<'ast> Flattener<'ast> {
exprs_flattened.expressions[0].clone()
}
FieldElementExpression::Select(box array, box index) => {
let size = array.size();
let ty = array.inner_type();
assert_eq!(ty, &Type::FieldElement);
match index {
FieldElementExpression::Number(n) => match array {
FieldElementArrayExpression::Identifier(size, id) => {
FieldElementExpression::Number(n) => match array.inner {
ArrayExpressionInner::Identifier(id) => {
assert!(n < T::from(size));
FlatExpression::Identifier(
self.layout.get(&id).unwrap().clone()
[n.to_dec_string().parse::<usize>().unwrap()],
)
}
FieldElementArrayExpression::Value(size, expressions) => {
ArrayExpressionInner::Value(expressions) => {
assert!(n < T::from(size));
self.flatten_field_expression(
functions_flattened,
statements_flattened,
expressions[n.to_dec_string().parse::<usize>().unwrap()].clone(),
FieldElementExpression::try_from(
expressions[n.to_dec_string().parse::<usize>().unwrap()]
.clone(),
)
.unwrap(),
)
}
FieldElementArrayExpression::FunctionCall(..) => {
ArrayExpressionInner::FunctionCall(..) => {
unimplemented!("please use intermediate variables for now")
}
FieldElementArrayExpression::IfElse(
condition,
consequence,
alternative,
) => {
ArrayExpressionInner::IfElse(condition, consequence, alternative) => {
// [if cond then [a, b] else [c, d]][1] == if cond then [a, b][1] else [c, d][1]
self.flatten_field_expression(
functions_flattened,
@ -799,25 +965,26 @@ impl<'ast> Flattener<'ast> {
box e.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box match array.clone() {
FieldElementArrayExpression::Identifier(size, id) => {
box match array.clone().inner {
ArrayExpressionInner::Identifier(id) => {
FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
size, id,
),
box ArrayExpression {
ty: ty.clone(),
size,
inner: ArrayExpressionInner::Identifier(id),
},
box FieldElementExpression::Number(T::from(i)),
)
}
FieldElementArrayExpression::Value(size, expressions) => {
ArrayExpressionInner::Value(expressions) => {
assert_eq!(size, expressions.len());
expressions[i].clone()
FieldElementExpression::try_from(expressions[i].clone())
.unwrap()
}
FieldElementArrayExpression::FunctionCall(..) => {
unimplemented!(
"please use intermediate variables for now"
)
}
FieldElementArrayExpression::IfElse(
ArrayExpressionInner::FunctionCall(..) => unimplemented!(
"please use intermediate variables for now"
),
ArrayExpressionInner::IfElse(
condition,
consequence,
alternative,
@ -851,47 +1018,50 @@ impl<'ast> Flattener<'ast> {
}
}
fn flatten_field_array_expression<T: Field>(
fn flatten_array_expression<T: Field>(
&mut self,
functions_flattened: &Vec<FlatFunction<T>>,
statements_flattened: &mut Vec<FlatStatement<T>>,
expr: FieldElementArrayExpression<'ast, T>,
expr: ArrayExpression<'ast, T>,
) -> Vec<FlatExpression<T>> {
match expr {
FieldElementArrayExpression::Identifier(_, x) => self
let ty = expr.get_type();
let size = expr.size();
match expr.inner {
ArrayExpressionInner::Identifier(x) => self
.layout
.get(&x)
.unwrap()
.iter()
.map(|v| FlatExpression::Identifier(v.clone()))
.collect(),
FieldElementArrayExpression::Value(size, values) => {
ArrayExpressionInner::Value(values) => {
assert_eq!(size, values.len());
values
.into_iter()
.map(|v| {
self.flatten_field_expression(functions_flattened, statements_flattened, v)
self.flatten_field_expression(
functions_flattened,
statements_flattened,
FieldElementExpression::try_from(v).unwrap(),
)
})
.collect()
}
FieldElementArrayExpression::FunctionCall(size, ref id, ref param_expressions) => {
ArrayExpressionInner::FunctionCall(ref id, ref param_expressions) => {
let exprs_flattened = self.flatten_function_call(
functions_flattened,
statements_flattened,
id,
vec![Type::FieldElementArray(size)],
vec![Type::array(ty, size)],
param_expressions,
);
assert!(exprs_flattened.expressions.len() == size); // outside of MultipleDefinition, FunctionCalls must return a single value
exprs_flattened.expressions
}
FieldElementArrayExpression::IfElse(
ref condition,
ref consequence,
ref alternative,
) => {
ArrayExpressionInner::IfElse(ref condition, ref consequence, ref alternative) => {
let size = match consequence.get_type() {
Type::FieldElementArray(n) => n,
Type::Array(_, n) => n,
_ => unreachable!(),
};
(0..size)
@ -985,7 +1155,7 @@ impl<'ast> Flattener<'ast> {
// first we check that e is in 0..array.len(), so we check that sum(if e == i then 1 else 0) == 1
// here depending on the size, we could use a proper range check based on bits
let size = match array.get_type() {
Type::FieldElementArray(n) => n,
Type::Array(_, n) => n,
_ => panic!("checker should generate array element based on non array")
};
let range_check = (0..size)
@ -1057,7 +1227,7 @@ impl<'ast> Flattener<'ast> {
}
}
}
Type::FieldElementArray(..) => {
Type::Array(..) => {
let vars = match assignee {
TypedAssignee::Identifier(v) => self.use_variable(&v),
_ => unimplemented!(),
@ -1121,17 +1291,14 @@ impl<'ast> Flattener<'ast> {
unimplemented!()
}
}
(
TypedExpression::FieldElementArray(e1),
TypedExpression::FieldElementArray(e2),
) => {
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
let (lhs, rhs) = (
self.flatten_field_array_expression(
self.flatten_array_expression(
functions_flattened,
statements_flattened,
e1,
),
self.flatten_field_array_expression(
self.flatten_array_expression(
functions_flattened,
statements_flattened,
e2,
@ -1265,7 +1432,7 @@ impl<'ast> Flattener<'ast> {
let vars = match variable.get_type() {
Type::FieldElement => self.issue_new_variables(1),
Type::Boolean => self.issue_new_variables(1),
Type::FieldElementArray(size) => self.issue_new_variables(size),
Type::Array(ty, size) => self.issue_new_variables(ty.get_primitive_count() * size),
};
self.layout.insert(variable.id.clone(), vars.clone());
@ -2001,8 +2168,8 @@ mod tests {
let mut functions_flattened = vec![];
let mut statements_flattened = vec![];
let statement = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("foo".into(), 3)),
FieldElementArrayExpression::Value(
TypedAssignee::Identifier(Variable::array(Type::FieldElement, "foo".into(), 3)),
ArrayExpressionInner::Value(
3,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -2012,7 +2179,7 @@ mod tests {
)
.into(),
);
let expression = FieldElementArrayExpression::Identifier(3, "foo".into());
let expression = ArrayExpressionInner::Identifier(3, "foo".into());
flattener.flatten_statement(
&mut functions_flattened,
@ -2044,8 +2211,8 @@ mod tests {
let mut functions_flattened = vec![];
let mut statements_flattened = vec![];
let statement = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("foo".into(), 3)),
FieldElementArrayExpression::Value(
TypedAssignee::Identifier(Variable::array(Type::FieldElement, "foo".into(), 3)),
ArrayExpressionInner::Value(
3,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -2090,8 +2257,8 @@ mod tests {
let mut functions_flattened = vec![];
let mut statements_flattened = vec![];
let statement = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("foo".into(), 3)),
FieldElementArrayExpression::Value(
TypedAssignee::Identifier(Variable::array(Type::FieldElement, "foo".into(), 3)),
ArrayExpressionInner::Value(
3,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -2103,7 +2270,7 @@ mod tests {
);
let expression = FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(3, "foo".into()),
box ArrayExpressionInner::Identifier(3, "foo".into()),
box FieldElementExpression::Number(FieldPrime::from(1)),
);
@ -2135,8 +2302,8 @@ mod tests {
let mut functions_flattened = vec![];
let mut statements_flattened = vec![];
let def = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("foo".into(), 3)),
FieldElementArrayExpression::Value(
TypedAssignee::Identifier(Variable::array(Type::FieldElement, "foo".into(), 3)),
ArrayExpressionInner::Value(
3,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -2152,16 +2319,16 @@ mod tests {
FieldElementExpression::Add(
box FieldElementExpression::Add(
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(3, "foo".into()),
box ArrayExpressionInner::Identifier(3, "foo".into()),
box FieldElementExpression::Number(FieldPrime::from(0)),
),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(3, "foo".into()),
box ArrayExpressionInner::Identifier(3, "foo".into()),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(3, "foo".into()),
box ArrayExpressionInner::Identifier(3, "foo".into()),
box FieldElementExpression::Number(FieldPrime::from(2)),
),
)
@ -2205,16 +2372,16 @@ mod tests {
flattener.load_corelib(&mut functions_flattened);
let mut statements_flattened = vec![];
let e = FieldElementArrayExpression::IfElse(
let e = ArrayExpressionInner::IfElse(
box BooleanExpression::Eq(
box FieldElementExpression::Number(FieldPrime::from(1)),
box FieldElementExpression::Number(FieldPrime::from(1)),
),
box FieldElementArrayExpression::Value(
box ArrayExpressionInner::Value(
1,
vec![FieldElementExpression::Number(FieldPrime::from(1))],
),
box FieldElementArrayExpression::Value(
box ArrayExpressionInner::Value(
1,
vec![FieldElementExpression::Number(FieldPrime::from(3))],
),

View file

@ -536,7 +536,7 @@ impl<'ast> Checker<'ast> {
let checked_expression = self.check_expression(s.value.expression)?;
match checked_expression {
TypedExpression::FieldElementArray(e) => {
TypedExpression::Array(e) => {
let size = e.size();
Ok((0..size)
.map(|i| {
@ -577,9 +577,12 @@ impl<'ast> Checker<'ast> {
Type::FieldElement => {
Ok(FieldElementExpression::Identifier(name.into()).into())
}
Type::FieldElementArray(n) => {
Ok(FieldElementArrayExpression::Identifier(n, name.into()).into())
Type::Array(ty, size) => Ok(ArrayExpression {
ty: *ty,
size,
inner: ArrayExpressionInner::Identifier(name.into()),
}
.into()),
},
None => Err(Error {
pos: Some(pos),
@ -696,8 +699,16 @@ impl<'ast> Checker<'ast> {
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into())
},
(TypedExpression::FieldElementArray(consequence), TypedExpression::FieldElementArray(alternative)) => {
Ok(FieldElementArrayExpression::IfElse(box condition, box consequence, box alternative).into())
(TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => {
if consequence.get_type() == alternative.get_type() && consequence.size() == alternative.size() {
Ok(ArrayExpression {
ty: consequence.get_type(),
size: consequence.size(),
inner: ArrayExpressionInner::IfElse(box condition, box consequence, box alternative)
}.into())
} else {
unimplemented!("handle consequence alternative inner type mismatch")
}
},
_ => unimplemented!()
}
@ -742,20 +753,21 @@ impl<'ast> Checker<'ast> {
let f = &candidates[0];
// the return count has to be 1
match f.signature.outputs.len() {
1 => match f.signature.outputs[0] {
1 => match f.signature.outputs[0].clone() {
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
f.id.to_string(),
arguments_checked,
)
.into()),
Type::FieldElementArray(size) => {
Ok(FieldElementArrayExpression::FunctionCall(
size,
Type::Array(ty, size) => Ok(ArrayExpression {
ty: *ty,
size,
inner: ArrayExpressionInner::FunctionCall(
f.id.to_string(),
arguments_checked,
)
.into())
),
}
.into()),
_ => unimplemented!(),
},
n => Err(Error {
@ -880,8 +892,9 @@ impl<'ast> Checker<'ast> {
match index {
RangeOrExpression::Range(r) => match array {
TypedExpression::FieldElementArray(array) => {
TypedExpression::Array(array) => {
let array_size = array.size();
let inner_type = array.get_type();
let from = r
.value
@ -895,8 +908,6 @@ impl<'ast> Checker<'ast> {
.map(|v| v.to_dec_string().parse::<usize>().unwrap())
.unwrap_or(array_size);
println!("from {} to {}", from, to);
match (from, to, array_size) {
(f, _, s) if f > s => Err(Error {
pos: Some(pos),
@ -919,27 +930,36 @@ impl<'ast> Checker<'ast> {
f, t,
),
}),
(f, t, _) => Ok(FieldElementArrayExpression::Value(
t - f,
(f..t)
.map(|i| {
FieldElementExpression::Select(
box array.clone(),
box FieldElementExpression::Number(T::from(i)),
)
})
.collect(),
)
(f, t, _) => Ok(ArrayExpression {
ty: inner_type,
size: t - f,
inner: ArrayExpressionInner::Value(
(f..t)
.map(|i| {
FieldElementExpression::Select(
box array.clone(),
box FieldElementExpression::Number(T::from(i)),
)
.into()
})
.collect(),
),
}
.into()),
}
}
_ => panic!(""),
},
RangeOrExpression::Expression(e) => match (array, self.check_expression(e)?) {
(
TypedExpression::FieldElementArray(a),
TypedExpression::FieldElement(i),
) => Ok(FieldElementExpression::Select(box a, box i).into()),
(TypedExpression::Array(a), TypedExpression::FieldElement(i)) => {
match a.inner_type() {
Type::FieldElement => {
Ok(FieldElementExpression::Select(box a, box i).into())
}
Type::Boolean => Ok(BooleanExpression::Select(box a, box i).into()),
Type::Array(_, _) => unimplemented!("multi dim wow"),
}
}
(a, e) => Err(Error {
pos: Some(pos),
message: format!(
@ -984,13 +1004,14 @@ impl<'ast> Checker<'ast> {
),
}),
}?;
unwrapped_expressions.push(unwrapped_e);
unwrapped_expressions.push(unwrapped_e.into());
}
Ok(FieldElementArrayExpression::Value(
unwrapped_expressions.len(),
unwrapped_expressions,
)
Ok(ArrayExpression {
ty: Type::FieldElement,
size: unwrapped_expressions.len(),
inner: ArrayExpressionInner::Value(unwrapped_expressions),
}
.into())
}
_ => Err(Error {

View file

@ -78,10 +78,10 @@ pub fn sha_round<T: Field>() -> FlatFunction<T> {
// define the signature of the resulting function
let signature = Signature {
inputs: vec![
Type::FieldElementArray(input_indices.len()),
Type::FieldElementArray(current_hash_indices.len()),
Type::array(Type::FieldElement, input_indices.len()),
Type::array(Type::FieldElement, current_hash_indices.len()),
],
outputs: vec![Type::FieldElementArray(output_indices.len())],
outputs: vec![Type::array(Type::FieldElement, output_indices.len())],
};
// define parameters to the function based on the variables
@ -160,11 +160,8 @@ mod tests {
assert_eq!(
compiled.signature,
Signature::new()
.inputs(vec![
Type::FieldElementArray(512),
Type::FieldElementArray(256)
])
.outputs(vec![Type::FieldElementArray(256)])
.inputs(vec![Type::Array(512), Type::Array(256)])
.outputs(vec![Type::Array(256)])
);
// function should have 768 inputs

View file

@ -79,23 +79,25 @@ impl<'ast, T: Field> Folder<'ast, T> for DeadCode {
}
}
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::FunctionCall(size, id, exps) => {
ArrayExpressionInner::FunctionCall(id, exps) => {
let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect();
let signature = Signature::new()
.inputs(exps.iter().map(|e| e.get_type()).collect())
.outputs(vec![Type::FieldElementArray(size)]);
.outputs(vec![Type::array(ty.clone(), size)]);
self.called
.insert(format!("{}_{}", id, signature.to_slug()));
FieldElementArrayExpression::FunctionCall(size, id, exps)
ArrayExpressionInner::FunctionCall(id, exps)
}
e => fold_field_array_expression(self, e),
_ => fold_array_expression_inner(self, ty, size, e),
}
}
}

View file

@ -39,7 +39,7 @@ impl<'ast, T: Field> Inliner<'ast, T> {
Some(..) => {
// check whether non-array arguments are constant
arguments.iter().all(|e| match e {
TypedExpression::FieldElementArray(..) => true,
TypedExpression::Array(..) => true,
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
_ => false,
@ -213,17 +213,19 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
}
// inline calls which return a field element array
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::FunctionCall(size, id, exps) => {
ArrayExpressionInner::FunctionCall(id, exps) => {
let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect();
let passed_signature = Signature::new()
.inputs(exps.iter().map(|e| e.get_type()).collect())
.outputs(vec![Type::FieldElementArray(size)]);
.outputs(vec![Type::array(ty.clone(), size)]);
// find the function
let function = self
@ -237,15 +239,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> {
let ret = self.inline_call(function.unwrap(), exps);
// unwrap the result to return a field element
match ret[0].clone() {
TypedExpression::FieldElementArray(e) => e,
TypedExpression::Array(e) => e.inner,
_ => panic!(""),
}
}
false => FieldElementArrayExpression::FunctionCall(size, id, exps),
false => ArrayExpressionInner::FunctionCall(id, exps),
}
}
// default
e => fold_field_array_expression(self, e),
e => fold_array_expression_inner(self, ty, size, e),
}
}
}
@ -269,19 +271,19 @@ mod tests {
],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(3, Identifier::from("b")),
box ArrayExpressionInner::Identifier(3, Identifier::from("b")),
box FieldElementExpression::Identifier(Identifier::from("a")),
)
.into(),
])],
signature: Signature::new()
.inputs(vec![Type::FieldElement, Type::FieldElementArray(3)])
.inputs(vec![Type::FieldElement, Type::Array(3)])
.outputs(vec![Type::FieldElement]),
};
let arguments = vec![
FieldElementExpression::Number(FieldPrime::from(0)).into(),
FieldElementArrayExpression::Identifier(3, Identifier::from("random")).into(),
ArrayExpressionInner::Identifier(3, Identifier::from("random")).into(),
];
let i = Inliner::new();
@ -299,19 +301,19 @@ mod tests {
],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(3, Identifier::from("b")),
box ArrayExpressionInner::Identifier(3, Identifier::from("b")),
box FieldElementExpression::Identifier(Identifier::from("a")),
)
.into(),
])],
signature: Signature::new()
.inputs(vec![Type::FieldElement, Type::FieldElementArray(3)])
.inputs(vec![Type::FieldElement, Type::Array(3)])
.outputs(vec![Type::FieldElement]),
};
let arguments = vec![
FieldElementExpression::Identifier(Identifier::from("notconstant")).into(),
FieldElementArrayExpression::Identifier(3, Identifier::from("random")).into(),
ArrayExpressionInner::Identifier(3, Identifier::from("random")).into(),
];
let i = Inliner::new();

View file

@ -7,6 +7,8 @@
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use types::Type;
use zokrates_field::field::Field;
pub struct Propagator<'ast, T: Field> {
@ -42,20 +44,31 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
self.constants.insert(TypedAssignee::Identifier(var), e);
None
},
TypedExpression::FieldElementArray(FieldElementArrayExpression::Value(size, array)) => {
match array.iter().all(|e| match e {
FieldElementExpression::Number(..) => true,
_ => false
}) {
true => {
// all elements of the array are constants
self.constants.insert(TypedAssignee::Identifier(var), FieldElementArrayExpression::Value(size, array).into());
None
},
false => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), FieldElementArrayExpression::Value(size, array).into()))
}
}
TypedExpression::Array(e) => {
let inner_type = e.inner_type();
match e.inner {
ArrayExpressionInner::Value(array) =>
match array.iter().all(|e| match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
_ => false
}) {
true => {
// all elements of the array are constants
self.constants.insert(TypedAssignee::Identifier(var), ArrayExpression {
inner: ArrayExpressionInner::Value(array),
..e}.into());
None
},
false => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), ArrayExpression {
inner: ArrayExpressionInner::Value(array),
..e}.into()))
}
}
,
e => unimplemented!()
}
},
e => {
Some(TypedStatement::Definition(TypedAssignee::Identifier(var), e))
@ -75,16 +88,22 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
// a[42] = 33
// -> store (a[42] -> 33) in the constants, possibly overwriting the previous entry
self.constants.entry(TypedAssignee::Identifier(var)).and_modify(|e| {
match *e {
TypedExpression::FieldElementArray(FieldElementArrayExpression::Value(size, ref mut v)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize] = expr;
} else {
panic!(format!("out of bounds index ({} >= {}) found during static analysis", n_as_usize, size));
}
},
_ => panic!("constants should only store constants")
match e {
TypedExpression::Array(e) => {
let size = e.size();
match e.inner {
ArrayExpressionInner::Value(ref mut v) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize] = expr.into();
} else {
panic!(format!("out of bounds index ({} >= {}) found during static analysis", n_as_usize, size));
}
},
_ => panic!("constants should only store constants")
}
},
_ => unimplemented!()
}
});
None
@ -201,17 +220,17 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
}
FieldElementExpression::Select(box array, box index) => {
let array = self.fold_field_array_expression(array);
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
match (array, index) {
(
FieldElementArrayExpression::Value(size, v),
FieldElementExpression::Number(n),
) => {
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.inner, index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
v[n_as_usize].clone()
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
} else {
panic!(format!(
"out of bounds index ({} >= {}) found during static analysis",
@ -219,49 +238,66 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
));
}
}
(
FieldElementArrayExpression::Identifier(size, id),
FieldElementExpression::Number(n),
) => match self.constants.get(&TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::field_array(id.clone(), size)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => panic!(""),
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::ArrayElement(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => panic!(""),
},
None => FieldElementExpression::Select(
box ArrayExpression {
ty: inner_type,
size,
inner: ArrayExpressionInner::Identifier(id),
},
box FieldElementExpression::Number(n),
),
}
}
(a, i) => FieldElementExpression::Select(
box ArrayExpression {
ty: inner_type,
size,
inner: a,
},
None => FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(size, id),
box FieldElementExpression::Number(n),
),
},
(a, i) => FieldElementExpression::Select(box a, box i),
box i,
),
}
}
e => fold_field_expression(self, e),
}
}
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::Identifier(size, id) => {
ArrayExpressionInner::Identifier(id) => {
match self
.constants
.get(&TypedAssignee::Identifier(Variable::field_array(
.get(&TypedAssignee::Identifier(Variable::array(
id.clone(),
ty.clone(),
size,
))) {
Some(e) => match e {
TypedExpression::FieldElementArray(e) => e.clone(),
TypedExpression::Array(e) => e.inner.clone(),
_ => panic!("constant stored for an array should be an array"),
},
None => FieldElementArrayExpression::Identifier(size, id),
None => ArrayExpressionInner::Identifier(id),
}
}
e => fold_field_array_expression(self, e),
e => fold_array_expression_inner(self, ty, size, e),
}
}
@ -449,7 +485,7 @@ mod tests {
#[test]
fn select() {
let e = FieldElementExpression::Select(
box FieldElementArrayExpression::Value(
box ArrayExpressionInner::Value(
3,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -604,7 +640,7 @@ mod tests {
let declaration = TypedStatement::Declaration(Variable::field_array("a".into(), 2));
let definition = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
FieldElementArrayExpression::Value(
ArrayExpressionInner::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
@ -625,15 +661,14 @@ mod tests {
p.fold_statement(declaration);
p.fold_statement(definition);
let expected_value: TypedExpression<FieldPrime> =
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(22)),
],
)
.into();
let expected_value: TypedExpression<FieldPrime> = ArrayExpressionInner::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(22)),
],
)
.into();
assert_eq!(
p.constants
@ -646,15 +681,14 @@ mod tests {
);
p.fold_statement(overwrite);
let expected_value: TypedExpression<FieldPrime> =
FieldElementArrayExpression::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(42)),
],
)
.into();
let expected_value: TypedExpression<FieldPrime> = ArrayExpressionInner::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(21)),
FieldElementExpression::Number(FieldPrime::from(42)),
],
)
.into();
assert_eq!(
p.constants

View file

@ -76,7 +76,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
};
let array_size = match original_variable.get_type() {
Type::FieldElementArray(size) => size,
Type::Array(_, size) => size,
_ => panic!("array identifier should be a field element array"),
};
@ -87,27 +87,34 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
let new_variable = self.issue_next_ssa_variable(original_variable);
let new_array = FieldElementArrayExpression::Value(
array_size,
(0..array_size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box expr.clone(),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
array_size,
current_ssa_variable.id.clone(),
let new_array = ArrayExpression {
ty: Type::FieldElement,
size: array_size,
inner: ArrayExpressionInner::Value(
(0..array_size)
.map(|i| {
FieldElementExpression::IfElse(
box BooleanExpression::Eq(
box index.clone(),
box FieldElementExpression::Number(T::from(i)),
),
box FieldElementExpression::Number(T::from(i)),
),
)
})
.collect(),
);
box expr.clone(),
box FieldElementExpression::Select(
box ArrayExpression {
ty: Type::FieldElement,
size: array_size,
inner: ArrayExpressionInner::Identifier(
current_ssa_variable.id.clone(),
),
},
box FieldElementExpression::Number(T::from(i)),
),
)
.into()
})
.collect(),
),
};
vec![TypedStatement::Definition(
TypedAssignee::Identifier(new_variable),
@ -433,7 +440,7 @@ mod tests {
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_array("a".into(), 2)),
FieldElementArrayExpression::Value(
ArrayExpressionInner::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -450,7 +457,7 @@ mod tests {
Identifier::from("a").version(0),
2
)),
FieldElementArrayExpression::Value(
ArrayExpressionInner::Value(
2,
vec![
FieldElementExpression::Number(FieldPrime::from(1)),
@ -476,7 +483,7 @@ mod tests {
Identifier::from("a").version(1),
2
)),
FieldElementArrayExpression::Value(
ArrayExpressionInner::Value(
2,
vec![
FieldElementExpression::IfElse(
@ -486,7 +493,7 @@ mod tests {
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
box ArrayExpressionInner::Identifier(
2,
Identifier::from("a").version(0)
),
@ -500,7 +507,7 @@ mod tests {
),
box FieldElementExpression::Number(FieldPrime::from(2)),
box FieldElementExpression::Select(
box FieldElementArrayExpression::Identifier(
box ArrayExpressionInner::Identifier(
2,
Identifier::from("a").version(0)
),

View file

@ -48,10 +48,14 @@ pub trait Folder<'ast, T: Field>: Sized {
match e {
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
TypedExpression::Boolean(e) => self.fold_boolean_expression(e).into(),
TypedExpression::FieldElementArray(e) => self.fold_field_array_expression(e).into(),
TypedExpression::Array(e) => self.fold_array_expression(e).into(),
}
}
fn fold_array_expression(&mut self, e: ArrayExpression<'ast, T>) -> ArrayExpression<'ast, T> {
fold_array_expression(self, e)
}
fn fold_expression_list(
&mut self,
es: TypedExpressionList<'ast, T>,
@ -82,11 +86,13 @@ pub trait Folder<'ast, T: Field>: Sized {
) -> BooleanExpression<'ast, T> {
fold_boolean_expression(self, e)
}
fn fold_field_array_expression(
fn fold_array_expression_inner(
&mut self,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
fold_field_array_expression(self, e)
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
fold_array_expression_inner(self, ty, size, e)
}
}
@ -139,30 +145,26 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
vec![res]
}
pub fn fold_field_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: FieldElementArrayExpression<'ast, T>,
) -> FieldElementArrayExpression<'ast, T> {
ty: &Type,
size: usize,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
FieldElementArrayExpression::Identifier(size, id) => {
FieldElementArrayExpression::Identifier(size, f.fold_name(id))
ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)),
ArrayExpressionInner::Value(exprs) => {
ArrayExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect())
}
FieldElementArrayExpression::Value(size, exprs) => FieldElementArrayExpression::Value(
size,
exprs
.into_iter()
.map(|e| f.fold_field_expression(e))
.collect(),
),
FieldElementArrayExpression::FunctionCall(size, id, exps) => {
ArrayExpressionInner::FunctionCall(id, exps) => {
let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect();
FieldElementArrayExpression::FunctionCall(size, id, exps)
ArrayExpressionInner::FunctionCall(id, exps)
}
FieldElementArrayExpression::IfElse(box condition, box consequence, box alternative) => {
FieldElementArrayExpression::IfElse(
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
ArrayExpressionInner::IfElse(
box f.fold_boolean_expression(condition),
box f.fold_field_array_expression(consequence),
box f.fold_field_array_expression(alternative),
box f.fold_array_expression(consequence),
box f.fold_array_expression(alternative),
)
}
}
@ -213,7 +215,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
FieldElementExpression::FunctionCall(id, exps)
}
FieldElementExpression::Select(box array, box index) => {
let array = f.fold_field_array_expression(array);
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
FieldElementExpression::Select(box array, box index)
}
@ -266,6 +268,17 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e = f.fold_boolean_expression(e);
BooleanExpression::Not(box e)
}
BooleanExpression::IfElse(box cond, box cons, box alt) => {
let cond = f.fold_boolean_expression(cond);
let cons = f.fold_boolean_expression(cons);
let alt = f.fold_boolean_expression(alt);
BooleanExpression::IfElse(box cond, box cons, box alt)
}
BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
BooleanExpression::Select(box array, box index)
}
}
}
@ -287,3 +300,13 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
..fun
}
}
pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: ArrayExpression<'ast, T>,
) -> ArrayExpression<'ast, T> {
ArrayExpression {
inner: f.fold_array_expression_inner(&e.ty, e.size, e.inner),
..e
}
}

View file

@ -16,6 +16,7 @@ use crate::types::Signature;
use crate::flat_absy::*;
use crate::imports::Import;
use crate::types::Type;
use std::convert::TryFrom;
use std::fmt;
use zokrates_field::field::Field;
@ -191,7 +192,7 @@ impl<'ast, T: Field> Typed for TypedAssignee<'ast, T> {
TypedAssignee::ArrayElement(ref a, _) => {
let a_type = a.get_type();
match a_type {
Type::FieldElementArray(_) => Type::FieldElement,
Type::Array(..) => Type::FieldElement,
_ => panic!("array element has to take array"),
}
}
@ -302,7 +303,7 @@ pub trait Typed {
pub enum TypedExpression<'ast, T: Field> {
Boolean(BooleanExpression<'ast, T>),
FieldElement(FieldElementExpression<'ast, T>),
FieldElementArray(FieldElementArrayExpression<'ast, T>),
Array(ArrayExpression<'ast, T>),
}
impl<'ast, T: Field> From<BooleanExpression<'ast, T>> for TypedExpression<'ast, T> {
@ -317,9 +318,9 @@ impl<'ast, T: Field> From<FieldElementExpression<'ast, T>> for TypedExpression<'
}
}
impl<'ast, T: Field> From<FieldElementArrayExpression<'ast, T>> for TypedExpression<'ast, T> {
fn from(e: FieldElementArrayExpression<'ast, T>) -> TypedExpression<T> {
TypedExpression::FieldElementArray(e)
impl<'ast, T: Field> From<ArrayExpression<'ast, T>> for TypedExpression<'ast, T> {
fn from(e: ArrayExpression<'ast, T>) -> TypedExpression<T> {
TypedExpression::Array(e)
}
}
@ -328,7 +329,7 @@ impl<'ast, T: Field> fmt::Display for TypedExpression<'ast, T> {
match *self {
TypedExpression::Boolean(ref e) => write!(f, "{}", e),
TypedExpression::FieldElement(ref e) => write!(f, "{}", e),
TypedExpression::FieldElementArray(ref e) => write!(f, "{}", e),
TypedExpression::Array(ref e) => write!(f, "{}", e.inner),
}
}
}
@ -338,29 +339,36 @@ impl<'ast, T: Field> fmt::Debug for TypedExpression<'ast, T> {
match *self {
TypedExpression::Boolean(ref e) => write!(f, "{:?}", e),
TypedExpression::FieldElement(ref e) => write!(f, "{:?}", e),
TypedExpression::FieldElementArray(ref e) => write!(f, "{:?}", e),
TypedExpression::Array(ref e) => write!(f, "{:?}", e),
}
}
}
impl<'ast, T: Field> fmt::Display for ArrayExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.inner)
}
}
impl<'ast, T: Field> fmt::Debug for ArrayExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.inner)
}
}
impl<'ast, T: Field> Typed for TypedExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
TypedExpression::Boolean(_) => Type::Boolean,
TypedExpression::FieldElement(_) => Type::FieldElement,
TypedExpression::FieldElementArray(ref e) => e.get_type(),
TypedExpression::Array(ref e) => e.get_type(),
}
}
}
impl<'ast, T: Field> Typed for FieldElementArrayExpression<'ast, T> {
impl<'ast, T: Field> Typed for ArrayExpression<'ast, T> {
fn get_type(&self) -> Type {
match *self {
FieldElementArrayExpression::Identifier(n, _) => Type::FieldElementArray(n),
FieldElementArrayExpression::Value(n, _) => Type::FieldElementArray(n),
FieldElementArrayExpression::FunctionCall(n, _, _) => Type::FieldElementArray(n),
FieldElementArrayExpression::IfElse(_, ref consequence, _) => consequence.get_type(),
}
Type::array(self.ty.clone(), self.size)
}
}
@ -412,7 +420,7 @@ pub enum FieldElementExpression<'ast, T: Field> {
),
FunctionCall(String, Vec<TypedExpression<'ast, T>>),
Select(
Box<FieldElementArrayExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
@ -450,28 +458,68 @@ pub enum BooleanExpression<'ast, T: Field> {
Box<BooleanExpression<'ast, T>>,
),
Not(Box<BooleanExpression<'ast, T>>),
}
// for now we store the array size in the variants
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum FieldElementArrayExpression<'ast, T: Field> {
Identifier(usize, Identifier<'ast>),
Value(usize, Vec<FieldElementExpression<'ast, T>>),
FunctionCall(usize, String, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<FieldElementArrayExpression<'ast, T>>,
Box<FieldElementArrayExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
Select(
Box<ArrayExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
}
impl<'ast, T: Field> FieldElementArrayExpression<'ast, T> {
#[derive(Clone, PartialEq, Hash, Eq)]
pub struct ArrayExpression<'ast, T: Field> {
pub size: usize,
pub ty: Type,
pub inner: ArrayExpressionInner<'ast, T>,
}
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum ArrayExpressionInner<'ast, T: Field> {
Identifier(Identifier<'ast>),
Value(Vec<TypedExpression<'ast, T>>),
FunctionCall(String, Vec<TypedExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
Box<ArrayExpression<'ast, T>>,
),
}
impl<'ast, T: Field> ArrayExpression<'ast, T> {
pub fn inner_type(&self) -> &Type {
&self.ty
}
pub fn size(&self) -> usize {
match *self {
FieldElementArrayExpression::Identifier(s, _)
| FieldElementArrayExpression::Value(s, _)
| FieldElementArrayExpression::FunctionCall(s, ..) => s,
FieldElementArrayExpression::IfElse(_, ref consequence, _) => consequence.size(),
self.size
}
}
// Downcasts
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for FieldElementExpression<'ast, T> {
type Error = ();
fn try_from(
te: TypedExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::FieldElement(e) => Ok(e),
_ => Err(()),
}
}
}
impl<'ast, T: Field> TryFrom<TypedExpression<'ast, T>> for BooleanExpression<'ast, T> {
type Error = ();
fn try_from(te: TypedExpression<'ast, T>) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match te {
TypedExpression::Boolean(e) => Ok(e),
_ => Err(()),
}
}
}
@ -521,15 +569,21 @@ impl<'ast, T: Field> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::Value(b) => write!(f, "{}", b),
BooleanExpression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
),
BooleanExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index),
}
}
}
impl<'ast, T: Field> fmt::Display for FieldElementArrayExpression<'ast, T> {
impl<'ast, T: Field> fmt::Display for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
FieldElementArrayExpression::Identifier(_, ref var) => write!(f, "{}", var),
FieldElementArrayExpression::Value(_, ref values) => write!(
ArrayExpressionInner::Identifier(ref var) => write!(f, "{}", var),
ArrayExpressionInner::Value(ref values) => write!(
f,
"[{}]",
values
@ -538,7 +592,7 @@ impl<'ast, T: Field> fmt::Display for FieldElementArrayExpression<'ast, T> {
.collect::<Vec<String>>()
.join(", ")
),
FieldElementArrayExpression::FunctionCall(_, ref i, ref p) => {
ArrayExpressionInner::FunctionCall(ref i, ref p) => {
r#try!(write!(f, "{}(", i,));
for (i, param) in p.iter().enumerate() {
r#try!(write!(f, "{}", param));
@ -548,13 +602,11 @@ impl<'ast, T: Field> fmt::Display for FieldElementArrayExpression<'ast, T> {
}
write!(f, ")")
}
FieldElementArrayExpression::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
)
}
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"if {} then {} else {} fi",
condition, consequent, alternative
),
}
}
}
@ -596,23 +648,21 @@ impl<'ast, T: Field> fmt::Debug for FieldElementExpression<'ast, T> {
}
}
impl<'ast, T: Field> fmt::Debug for FieldElementArrayExpression<'ast, T> {
impl<'ast, T: Field> fmt::Debug for ArrayExpressionInner<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
FieldElementArrayExpression::Identifier(_, ref var) => write!(f, "{:?}", var),
FieldElementArrayExpression::Value(_, ref values) => write!(f, "{:?}", values),
FieldElementArrayExpression::FunctionCall(_, ref i, ref p) => {
ArrayExpressionInner::Identifier(ref var) => write!(f, "{:?}", var),
ArrayExpressionInner::Value(ref values) => write!(f, "{:?}", values),
ArrayExpressionInner::FunctionCall(ref i, ref p) => {
r#try!(write!(f, "FunctionCall({:?}, (", i));
r#try!(f.debug_list().entries(p.iter()).finish());
write!(f, ")")
}
FieldElementArrayExpression::IfElse(ref condition, ref consequent, ref alternative) => {
write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
)
}
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
}
}
}

View file

@ -18,8 +18,8 @@ impl<'ast> Variable<'ast> {
Self::with_id_and_type(id, Type::Boolean)
}
pub fn field_array(id: Identifier<'ast>, size: usize) -> Variable<'ast> {
Self::with_id_and_type(id, Type::FieldElementArray(size))
pub fn array(id: Identifier<'ast>, ty: Type, size: usize) -> Variable<'ast> {
Self::with_id_and_type(id, Type::array(ty, size))
}
pub fn with_id_and_type(id: Identifier<'ast>, _type: Type) -> Variable<'ast> {

View file

@ -45,7 +45,7 @@ pub fn split<T: Field>() -> FlatProg<T> {
let signature = Signature {
inputs: vec![Type::FieldElement],
outputs: vec![Type::FieldElementArray(nbits)],
outputs: vec![Type::array(Type::FieldElement, nbits)],
};
let outputs = directive_outputs

View file

@ -8,44 +8,48 @@ mod signature;
pub enum Type {
FieldElement,
Boolean,
FieldElementArray(usize),
Array(Box<Type>, usize),
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::FieldElementArray(size) => write!(f, "{}[{}]", Type::FieldElement, size),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
}
}
}
impl fmt::Debug for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
match self {
Type::FieldElement => write!(f, "field"),
Type::Boolean => write!(f, "bool"),
Type::FieldElementArray(size) => write!(f, "{}[{}]", Type::FieldElement, size),
Type::Array(ref ty, ref size) => write!(f, "{}[{}]", ty, size),
}
}
}
impl Type {
pub fn array(ty: Type, size: usize) -> Self {
Type::Array(box ty, size)
}
// the number of field elements the type maps to
pub fn get_primitive_count(&self) -> usize {
match self {
Type::FieldElement => 1,
Type::Boolean => 1,
Type::FieldElementArray(size) => size * Type::FieldElement.get_primitive_count(),
Type::Array(ty, size) => size * ty.get_primitive_count(),
}
}
fn to_slug(&self) -> String {
match *self {
match self {
Type::FieldElement => String::from("f"),
Type::Boolean => String::from("b"),
Type::FieldElementArray(size) => format!("{}[{}]", Type::FieldElement.to_slug(), size), // TODO differentiate types?
Type::Array(ref ty, ref size) => format!("{}[{}]", ty.to_slug(), size), // TODO differentiate types?
}
}
}
@ -56,7 +60,7 @@ mod tests {
#[test]
fn array() {
let t = Type::FieldElementArray(42);
let t = Type::Array(box Type::FieldElement, 42);
assert_eq!(t.get_primitive_count(), 42);
assert_eq!(t.to_slug(), "f[42]");
}

View file

@ -150,8 +150,8 @@ mod tests {
fn array_slug() {
let s = Signature::new()
.inputs(vec![
Type::FieldElementArray(42),
Type::FieldElementArray(21),
Type::Array(Type::FieldElement, 42),
Type::Array(Type::FieldElement, 21),
])
.outputs(vec![]);