Merge pull request #864 from Zokrates/fix-constants
Support the use of constants in declaration types
This commit is contained in:
commit
2ed9e6a972
19 changed files with 523 additions and 224 deletions
1
changelogs/unreleased/864-dark64
Normal file
1
changelogs/unreleased/864-dark64
Normal file
|
@ -0,0 +1 @@
|
|||
Support the use of constants in struct and function declarations
|
|
@ -0,0 +1,4 @@
|
|||
const field SIZE = 2
|
||||
|
||||
def main(field[SIZE] n):
|
||||
return
|
|
@ -0,0 +1,4 @@
|
|||
const u8 SIZE = 0x02
|
||||
|
||||
def main(field[SIZE] n):
|
||||
return
|
|
@ -0,0 +1,7 @@
|
|||
const u32 N = 42
|
||||
|
||||
def foo<N>(field[N] a) -> bool:
|
||||
return true
|
||||
|
||||
def main():
|
||||
return
|
File diff suppressed because it is too large
Load diff
|
@ -1,20 +1,37 @@
|
|||
use crate::static_analysis::propagation::Propagator;
|
||||
use crate::typed_absy::folder::*;
|
||||
use crate::typed_absy::result_folder::ResultFolder;
|
||||
use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember};
|
||||
use crate::typed_absy::*;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ConstantInliner<'ast, T: Field> {
|
||||
pub struct ConstantInliner<'ast, 'a, T: Field> {
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
propagator: Propagator<'ast, 'a, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ConstantInliner<'ast, T> {
|
||||
pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self {
|
||||
ConstantInliner { modules, location }
|
||||
impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
|
||||
pub fn new(
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
propagator: Propagator<'ast, 'a, T>,
|
||||
) -> Self {
|
||||
ConstantInliner {
|
||||
modules,
|
||||
location,
|
||||
propagator,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone());
|
||||
let mut constants = HashMap::new();
|
||||
let mut inliner = ConstantInliner::new(
|
||||
p.modules.clone(),
|
||||
p.main.clone(),
|
||||
Propagator::with_constants(&mut constants),
|
||||
);
|
||||
inliner.fold_program(p)
|
||||
}
|
||||
|
||||
|
@ -51,12 +68,18 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> {
|
|||
let _ = self.change_location(location);
|
||||
symbol
|
||||
}
|
||||
TypedConstantSymbol::Here(tc) => self.fold_constant(tc),
|
||||
TypedConstantSymbol::Here(tc) => {
|
||||
let tc: TypedConstant<T> = self.fold_constant(tc);
|
||||
TypedConstant {
|
||||
expression: self.propagator.fold_expression(tc.expression).unwrap(),
|
||||
..tc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||
impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
TypedProgram {
|
||||
modules: p
|
||||
|
@ -71,6 +94,62 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
|
||||
match t {
|
||||
DeclarationType::Array(ref array_ty) => match array_ty.size {
|
||||
Constant::Identifier(name, _) => {
|
||||
let tc = self.get_constant(&name.into()).unwrap();
|
||||
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
|
||||
match expression.inner {
|
||||
UExpressionInner::Value(v) => DeclarationType::array((
|
||||
self.fold_declaration_type(*array_ty.ty.clone()),
|
||||
Constant::Concrete(v as u32),
|
||||
)),
|
||||
_ => unreachable!("expected u32 value"),
|
||||
}
|
||||
}
|
||||
_ => t,
|
||||
},
|
||||
DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType {
|
||||
members: struct_ty
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty)))
|
||||
.collect(),
|
||||
..struct_ty
|
||||
}),
|
||||
_ => t,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> {
|
||||
use self::GType::*;
|
||||
match t {
|
||||
Array(ref array_type) => match &array_type.size.inner {
|
||||
UExpressionInner::Identifier(v) => match self.get_constant(v) {
|
||||
Some(tc) => {
|
||||
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
|
||||
Type::array(GArrayType::new(
|
||||
self.fold_type(*array_type.ty.clone()),
|
||||
expression,
|
||||
))
|
||||
}
|
||||
None => t,
|
||||
},
|
||||
_ => t,
|
||||
},
|
||||
Struct(struct_type) => Type::struc(GStructType {
|
||||
members: struct_type
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| GStructMember::new(m.id, self.fold_type(*m.ty)))
|
||||
.collect(),
|
||||
..struct_type
|
||||
}),
|
||||
_ => t,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_constant_symbol(
|
||||
&mut self,
|
||||
s: TypedConstantSymbol<'ast, T>,
|
||||
|
@ -636,11 +715,9 @@ mod tests {
|
|||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
)
|
||||
.into()])],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
|
@ -675,9 +752,8 @@ mod tests {
|
|||
const_b_id,
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(2),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
|
|
|
@ -31,10 +31,21 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_function_symbol(self, s)
|
||||
}
|
||||
|
||||
fn fold_declaration_function_key(
|
||||
&mut self,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> DeclarationFunctionKey<'ast> {
|
||||
fold_declaration_function_key(self, key)
|
||||
}
|
||||
|
||||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
|
||||
fold_signature(self, s)
|
||||
}
|
||||
|
||||
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
|
||||
DeclarationParameter {
|
||||
id: self.fold_declaration_variable(p.id),
|
||||
|
@ -190,6 +201,7 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
) -> ArrayExpressionInner<'ast, T> {
|
||||
fold_array_expression_inner(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType<'ast, T>,
|
||||
|
@ -212,7 +224,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
functions: m
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| (key, f.fold_function_symbol(fun)))
|
||||
.map(|(key, fun)| {
|
||||
(
|
||||
f.fold_declaration_function_key(key),
|
||||
f.fold_function_symbol(fun),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
@ -653,6 +670,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> DeclarationFunctionKey<'ast> {
|
||||
DeclarationFunctionKey {
|
||||
signature: f.fold_signature(key.signature),
|
||||
..key
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
fun: TypedFunction<'ast, T>,
|
||||
|
@ -668,7 +695,26 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
..fun
|
||||
signature: f.fold_signature(fun.signature),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: DeclarationSignature<'ast>,
|
||||
) -> DeclarationSignature<'ast> {
|
||||
DeclarationSignature {
|
||||
generics: s.generics,
|
||||
inputs: s
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(|o| f.fold_declaration_type(o))
|
||||
.collect(),
|
||||
outputs: s
|
||||
.outputs
|
||||
.into_iter()
|
||||
.map(|o| f.fold_declaration_type(o))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -721,9 +767,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
f: &mut F,
|
||||
e: StructExpression<'ast, T>,
|
||||
) -> StructExpression<'ast, T> {
|
||||
let ty = f.fold_struct_type(e.ty);
|
||||
StructExpression {
|
||||
inner: f.fold_struct_expression_inner(&e.ty, e.inner),
|
||||
..e
|
||||
inner: f.fold_struct_expression_inner(&ty, e.inner),
|
||||
ty,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -290,8 +290,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct TypedConstant<'ast, T> {
|
||||
ty: Type<'ast, T>,
|
||||
expression: TypedExpression<'ast, T>,
|
||||
pub ty: Type<'ast, T>,
|
||||
pub expression: TypedExpression<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedConstant<'ast, T> {
|
||||
|
|
|
@ -42,6 +42,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_function_symbol(self, s)
|
||||
}
|
||||
|
||||
fn fold_declaration_function_key(
|
||||
&mut self,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> Result<DeclarationFunctionKey<'ast>, Self::Error> {
|
||||
fold_declaration_function_key(self, key)
|
||||
}
|
||||
|
||||
fn fold_function(
|
||||
&mut self,
|
||||
f: TypedFunction<'ast, T>,
|
||||
|
@ -49,6 +56,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_signature(
|
||||
&mut self,
|
||||
s: DeclarationSignature<'ast>,
|
||||
) -> Result<DeclarationSignature<'ast>, Self::Error> {
|
||||
fold_signature(self, s)
|
||||
}
|
||||
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast>,
|
||||
|
@ -723,6 +737,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
Ok(e)
|
||||
}
|
||||
|
||||
pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
|
||||
Ok(DeclarationFunctionKey {
|
||||
signature: f.fold_signature(key.signature)?,
|
||||
..key
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
fun: TypedFunction<'ast, T>,
|
||||
|
@ -741,7 +765,26 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
..fun
|
||||
signature: f.fold_signature(fun.signature)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: DeclarationSignature<'ast>,
|
||||
) -> Result<DeclarationSignature<'ast>, F::Error> {
|
||||
Ok(DeclarationSignature {
|
||||
generics: s.generics,
|
||||
inputs: s
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(|o| f.fold_declaration_type(o))
|
||||
.collect::<Result<_, _>>()?,
|
||||
outputs: s
|
||||
.outputs
|
||||
.into_iter()
|
||||
.map(|o| f.fold_declaration_type(o))
|
||||
.collect::<Result<_, _>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -801,9 +844,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
f: &mut F,
|
||||
e: StructExpression<'ast, T>,
|
||||
) -> Result<StructExpression<'ast, T>, F::Error> {
|
||||
let ty = f.fold_struct_type(e.ty)?;
|
||||
Ok(StructExpression {
|
||||
inner: f.fold_struct_expression_inner(&e.ty, e.inner)?,
|
||||
..e
|
||||
inner: f.fold_struct_expression_inner(&ty, e.inner)?,
|
||||
ty,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::typed_absy::{OwnedTypedModuleId, UExpression, UExpressionInner};
|
||||
use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner};
|
||||
use crate::typed_absy::{TryFrom, TryInto};
|
||||
use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::collections::BTreeMap;
|
||||
|
@ -54,6 +54,7 @@ pub struct SpecializationError;
|
|||
pub enum Constant<'ast> {
|
||||
Generic(GenericIdentifier<'ast>),
|
||||
Concrete(u32),
|
||||
Identifier(&'ast str, usize),
|
||||
}
|
||||
|
||||
impl<'ast> From<u32> for Constant<'ast> {
|
||||
|
@ -79,6 +80,7 @@ impl<'ast> fmt::Display for Constant<'ast> {
|
|||
match self {
|
||||
Constant::Generic(i) => write!(f, "{}", i),
|
||||
Constant::Concrete(v) => write!(f, "{}", v),
|
||||
Constant::Identifier(v, _) => write!(f, "{}", v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -96,6 +98,9 @@ impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
|
|||
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
|
||||
}
|
||||
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
|
||||
Constant::Identifier(v, size) => {
|
||||
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -920,6 +925,7 @@ pub mod signature {
|
|||
}
|
||||
},
|
||||
Constant::Concrete(s0) => s1 == *s0 as usize,
|
||||
Constant::Identifier(_, s0) => s1 == *s0,
|
||||
}
|
||||
}
|
||||
(DeclarationType::FieldElement, GType::FieldElement)
|
||||
|
@ -945,6 +951,7 @@ pub mod signature {
|
|||
let size = match t0.size {
|
||||
Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
|
||||
Constant::Concrete(s) => Ok(s.into()),
|
||||
Constant::Identifier(_, s) => Ok((s as u32).into()),
|
||||
}?;
|
||||
|
||||
GType::Array(GArrayType { size, ty })
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
const field[2] ARRAY = [1, 2]
|
||||
const u32 N = 2
|
||||
const field[N] ARRAY = [1, 2]
|
||||
|
||||
def main() -> field[2]:
|
||||
def main() -> field[N]:
|
||||
return ARRAY
|
16
zokrates_core_test/tests/tests/constants/array_size.json
Normal file
16
zokrates_core_test/tests/tests/constants/array_size.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/array_size.zok",
|
||||
"max_constraint_count": 2,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["42", "42"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["42", "42"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
5
zokrates_core_test/tests/tests/constants/array_size.zok
Normal file
5
zokrates_core_test/tests/tests/constants/array_size.zok
Normal file
|
@ -0,0 +1,5 @@
|
|||
const u32 SIZE = 2
|
||||
|
||||
def main(field[SIZE] a) -> field[SIZE]:
|
||||
field[SIZE] b = a
|
||||
return b
|
16
zokrates_core_test/tests/tests/constants/mixed.json
Normal file
16
zokrates_core_test/tests/tests/constants/mixed.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/mixed.zok",
|
||||
"max_constraint_count": 6,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "2", "1", "3", "4", "0"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
15
zokrates_core_test/tests/tests/constants/mixed.zok
Normal file
15
zokrates_core_test/tests/tests/constants/mixed.zok
Normal file
|
@ -0,0 +1,15 @@
|
|||
const u32 N = 2
|
||||
const bool B = true
|
||||
|
||||
struct Foo {
|
||||
field[N] a
|
||||
bool b
|
||||
}
|
||||
|
||||
const Foo[N] F = [
|
||||
Foo { a: [1, 2], b: B },
|
||||
Foo { a: [3, 4], b: !B }
|
||||
]
|
||||
|
||||
def main() -> Foo[N]:
|
||||
return F
|
16
zokrates_core_test/tests/tests/constants/propagate.json
Normal file
16
zokrates_core_test/tests/tests/constants/propagate.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/propagate.zok",
|
||||
"max_constraint_count": 4,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["42", "42", "42", "42"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
5
zokrates_core_test/tests/tests/constants/propagate.zok
Normal file
5
zokrates_core_test/tests/tests/constants/propagate.zok
Normal file
|
@ -0,0 +1,5 @@
|
|||
const u32 TWO = 2
|
||||
const u32 FOUR = TWO * TWO
|
||||
|
||||
def main() -> field[FOUR]:
|
||||
return [42; FOUR]
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/struct.zok",
|
||||
"max_constraint_count": 1,
|
||||
"max_constraint_count": 6,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
|
@ -8,7 +8,7 @@
|
|||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["4"]
|
||||
"values": ["1", "2", "3", "4", "5", "6"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
struct Foo {
|
||||
field a
|
||||
field b
|
||||
const u32 N = 2
|
||||
|
||||
struct State {
|
||||
field[N] a
|
||||
field[N][N] b
|
||||
}
|
||||
|
||||
const Foo FOO = Foo { a: 2, b: 2 }
|
||||
const State STATE = State {
|
||||
a: [1, 2],
|
||||
b: [[3, 4], [5, 6]]
|
||||
}
|
||||
|
||||
def main() -> field:
|
||||
return FOO.a + FOO.b
|
||||
def main() -> State:
|
||||
return STATE
|
Loading…
Reference in a new issue