1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #864 from Zokrates/fix-constants

Support the use of constants in declaration types
This commit is contained in:
Thibaut Schaeffer 2021-05-16 23:29:00 +02:00 committed by GitHub
commit 2ed9e6a972
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 523 additions and 224 deletions

View file

@ -0,0 +1 @@
Support the use of constants in struct and function declarations

View file

@ -0,0 +1,4 @@
const field SIZE = 2
def main(field[SIZE] n):
return

View file

@ -0,0 +1,4 @@
const u8 SIZE = 0x02
def main(field[SIZE] n):
return

View file

@ -0,0 +1,7 @@
const u32 N = 42
def foo<N>(field[N] a) -> bool:
return true
def main():
return

File diff suppressed because it is too large Load diff

View file

@ -1,20 +1,37 @@
use crate::static_analysis::propagation::Propagator;
use crate::typed_absy::folder::*;
use crate::typed_absy::result_folder::ResultFolder;
use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember};
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
pub struct ConstantInliner<'ast, T: Field> {
pub struct ConstantInliner<'ast, 'a, T: Field> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
}
impl<'ast, T: Field> ConstantInliner<'ast, T> {
pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self {
ConstantInliner { modules, location }
impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
) -> Self {
ConstantInliner {
modules,
location,
propagator,
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone());
let mut constants = HashMap::new();
let mut inliner = ConstantInliner::new(
p.modules.clone(),
p.main.clone(),
Propagator::with_constants(&mut constants),
);
inliner.fold_program(p)
}
@ -51,12 +68,18 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> {
let _ = self.change_location(location);
symbol
}
TypedConstantSymbol::Here(tc) => self.fold_constant(tc),
TypedConstantSymbol::Here(tc) => {
let tc: TypedConstant<T> = self.fold_constant(tc);
TypedConstant {
expression: self.propagator.fold_expression(tc.expression).unwrap(),
..tc
}
}
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
@ -71,6 +94,62 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
}
}
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
match t {
DeclarationType::Array(ref array_ty) => match array_ty.size {
Constant::Identifier(name, _) => {
let tc = self.get_constant(&name.into()).unwrap();
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
match expression.inner {
UExpressionInner::Value(v) => DeclarationType::array((
self.fold_declaration_type(*array_ty.ty.clone()),
Constant::Concrete(v as u32),
)),
_ => unreachable!("expected u32 value"),
}
}
_ => t,
},
DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType {
members: struct_ty
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty)))
.collect(),
..struct_ty
}),
_ => t,
}
}
fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> {
use self::GType::*;
match t {
Array(ref array_type) => match &array_type.size.inner {
UExpressionInner::Identifier(v) => match self.get_constant(v) {
Some(tc) => {
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
Type::array(GArrayType::new(
self.fold_type(*array_type.ty.clone()),
expression,
))
}
None => t,
},
_ => t,
},
Struct(struct_type) => Type::struc(GStructType {
members: struct_type
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_type(*m.ty)))
.collect(),
..struct_type
}),
_ => t,
}
}
fn fold_constant_symbol(
&mut self,
s: TypedConstantSymbol<'ast, T>,
@ -636,11 +715,9 @@ mod tests {
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)
.into()])],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
@ -675,9 +752,8 @@ mod tests {
const_b_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2),
)),
)),
),

View file

@ -31,10 +31,21 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_function_symbol(self, s)
}
fn fold_declaration_function_key(
&mut self,
key: DeclarationFunctionKey<'ast>,
) -> DeclarationFunctionKey<'ast> {
fold_declaration_function_key(self, key)
}
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
fold_function(self, f)
}
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
fold_signature(self, s)
}
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
DeclarationParameter {
id: self.fold_declaration_variable(p.id),
@ -190,6 +201,7 @@ pub trait Folder<'ast, T: Field>: Sized {
) -> ArrayExpressionInner<'ast, T> {
fold_array_expression_inner(self, ty, e)
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
@ -212,7 +224,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
functions: m
.functions
.into_iter()
.map(|(key, fun)| (key, f.fold_function_symbol(fun)))
.map(|(key, fun)| {
(
f.fold_declaration_function_key(key),
f.fold_function_symbol(fun),
)
})
.collect(),
}
}
@ -653,6 +670,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
key: DeclarationFunctionKey<'ast>,
) -> DeclarationFunctionKey<'ast> {
DeclarationFunctionKey {
signature: f.fold_signature(key.signature),
..key
}
}
pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
@ -668,7 +695,26 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
..fun
signature: f.fold_signature(fun.signature),
}
}
fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: DeclarationSignature<'ast>,
) -> DeclarationSignature<'ast> {
DeclarationSignature {
generics: s.generics,
inputs: s
.inputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect(),
outputs: s
.outputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect(),
}
}
@ -721,9 +767,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: StructExpression<'ast, T>,
) -> StructExpression<'ast, T> {
let ty = f.fold_struct_type(e.ty);
StructExpression {
inner: f.fold_struct_expression_inner(&e.ty, e.inner),
..e
inner: f.fold_struct_expression_inner(&ty, e.inner),
ty,
}
}

View file

@ -290,8 +290,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
ty: Type<'ast, T>,
expression: TypedExpression<'ast, T>,
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>,
}
impl<'ast, T> TypedConstant<'ast, T> {

View file

@ -42,6 +42,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_function_symbol(self, s)
}
fn fold_declaration_function_key(
&mut self,
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, Self::Error> {
fold_declaration_function_key(self, key)
}
fn fold_function(
&mut self,
f: TypedFunction<'ast, T>,
@ -49,6 +56,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_function(self, f)
}
fn fold_signature(
&mut self,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, Self::Error> {
fold_signature(self, s)
}
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast>,
@ -723,6 +737,16 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
Ok(e)
}
pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
Ok(DeclarationFunctionKey {
signature: f.fold_signature(key.signature)?,
..key
})
}
pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
fun: TypedFunction<'ast, T>,
@ -741,7 +765,26 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
.into_iter()
.flatten()
.collect(),
..fun
signature: f.fold_signature(fun.signature)?,
})
}
fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, F::Error> {
Ok(DeclarationSignature {
generics: s.generics,
inputs: s
.inputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect::<Result<_, _>>()?,
outputs: s
.outputs
.into_iter()
.map(|o| f.fold_declaration_type(o))
.collect::<Result<_, _>>()?,
})
}
@ -801,9 +844,10 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, F::Error> {
let ty = f.fold_struct_type(e.ty)?;
Ok(StructExpression {
inner: f.fold_struct_expression_inner(&e.ty, e.inner)?,
..e
inner: f.fold_struct_expression_inner(&ty, e.inner)?,
ty,
})
}

View file

@ -1,4 +1,4 @@
use crate::typed_absy::{OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{TryFrom, TryInto};
use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
@ -54,6 +54,7 @@ pub struct SpecializationError;
pub enum Constant<'ast> {
Generic(GenericIdentifier<'ast>),
Concrete(u32),
Identifier(&'ast str, usize),
}
impl<'ast> From<u32> for Constant<'ast> {
@ -79,6 +80,7 @@ impl<'ast> fmt::Display for Constant<'ast> {
match self {
Constant::Generic(i) => write!(f, "{}", i),
Constant::Concrete(v) => write!(f, "{}", v),
Constant::Identifier(v, _) => write!(f, "{}", v),
}
}
}
@ -96,6 +98,9 @@ impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
}
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
Constant::Identifier(v, size) => {
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size))
}
}
}
}
@ -920,6 +925,7 @@ pub mod signature {
}
},
Constant::Concrete(s0) => s1 == *s0 as usize,
Constant::Identifier(_, s0) => s1 == *s0,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
@ -945,6 +951,7 @@ pub mod signature {
let size = match t0.size {
Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
Constant::Concrete(s) => Ok(s.into()),
Constant::Identifier(_, s) => Ok((s as u32).into()),
}?;
GType::Array(GArrayType { size, ty })

View file

@ -1,4 +1,5 @@
const field[2] ARRAY = [1, 2]
const u32 N = 2
const field[N] ARRAY = [1, 2]
def main() -> field[2]:
def main() -> field[N]:
return ARRAY

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/array_size.zok",
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": ["42", "42"]
},
"output": {
"Ok": {
"values": ["42", "42"]
}
}
}
]
}

View file

@ -0,0 +1,5 @@
const u32 SIZE = 2
def main(field[SIZE] a) -> field[SIZE]:
field[SIZE] b = a
return b

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/mixed.zok",
"max_constraint_count": 6,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1", "2", "1", "3", "4", "0"]
}
}
}
]
}

View file

@ -0,0 +1,15 @@
const u32 N = 2
const bool B = true
struct Foo {
field[N] a
bool b
}
const Foo[N] F = [
Foo { a: [1, 2], b: B },
Foo { a: [3, 4], b: !B }
]
def main() -> Foo[N]:
return F

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/propagate.zok",
"max_constraint_count": 4,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["42", "42", "42", "42"]
}
}
}
]
}

View file

@ -0,0 +1,5 @@
const u32 TWO = 2
const u32 FOUR = TWO * TWO
def main() -> field[FOUR]:
return [42; FOUR]

View file

@ -1,6 +1,6 @@
{
"entry_point": "./tests/tests/constants/struct.zok",
"max_constraint_count": 1,
"max_constraint_count": 6,
"tests": [
{
"input": {
@ -8,7 +8,7 @@
},
"output": {
"Ok": {
"values": ["4"]
"values": ["1", "2", "3", "4", "5", "6"]
}
}
}

View file

@ -1,9 +1,14 @@
struct Foo {
field a
field b
const u32 N = 2
struct State {
field[N] a
field[N][N] b
}
const Foo FOO = Foo { a: 2, b: 2 }
const State STATE = State {
a: [1, 2],
b: [[3, 4], [5, 6]]
}
def main() -> field:
return FOO.a + FOO.b
def main() -> State:
return STATE