1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

fix imports, more tests

This commit is contained in:
dark64 2021-04-08 11:29:21 +02:00
parent 8f4ee00276
commit dafef03b1f
26 changed files with 589 additions and 145 deletions

View file

@ -0,0 +1,5 @@
const field ONE = 1
const field TWO = ONE + ONE
def main() -> field:
return TWO

View file

@ -0,0 +1,5 @@
const field a = 1
def main() -> field:
a = 2 // not allowed
return a

View file

@ -1,5 +1,7 @@
struct Bar {
}
struct Bar {}
const field ONE = 1
const field BAR = 21 * ONE
def main() -> field:
return 21
return BAR

View file

@ -1,5 +1,6 @@
struct Baz {
}
struct Baz {}
const field BAZ = 123
def main() -> field:
return 123
return BAZ

View file

@ -1,9 +1,10 @@
from "./baz" import Baz
import "./baz"
from "./baz" import main as my_function
import "./baz"
const field FOO = 144
def main() -> field:
field a = my_function()
Baz b = Baz {}
return baz()
Baz b = Baz {}
assert(baz() == my_function())
return FOO

View file

@ -1,11 +0,0 @@
from "./bar" import Bar as MyBar
from "./bar" import Bar
import "./foo"
import "./bar"
def main() -> field:
MyBar my_bar = MyBar {}
Bar bar = Bar {}
assert(my_bar == bar)
return foo() + bar()

View file

@ -0,0 +1,6 @@
from "./foo" import FOO
from "./bar" import BAR
from "./baz" import BAZ
def main() -> bool:
return FOO == BAR + BAZ

View file

@ -0,0 +1,6 @@
import "./foo"
import "./bar"
import "./baz"
def main() -> bool:
return foo() == bar() + baz()

View file

@ -0,0 +1,8 @@
from "./bar" import Bar as MyBar
from "./bar" import Bar
def main():
MyBar my_bar = MyBar {}
Bar bar = Bar {}
assert(my_bar == bar)
return

View file

@ -1,4 +1,8 @@
import "./foo" as d
from "./bar" import main as bar
from "./baz" import main as baz
import "./foo" as f
def main() -> field:
return d()
field foo = f()
assert(foo == bar() + baz())
return foo

View file

@ -407,7 +407,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
module_id: &ModuleId,
state: &mut State<'ast, T>,
functions: &mut HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>,
constants: &mut HashMap<identifier::Identifier<'ast>, TypedConstant<'ast, T>>,
constants: &mut HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>,
symbol_unifier: &mut SymbolUnifier<'ast>,
) -> Result<(), Vec<Error>> {
let mut errors: Vec<Error> = vec![];
@ -470,8 +470,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
),
true => {}
};
constants
.insert(identifier::Identifier::from(declaration.id), c.clone());
constants.insert(declaration.id, TypedConstantSymbol::Here(c.clone()));
self.insert_into_scope(Variable::with_id_and_type(c.id, c.ty), true);
}
Err(e) => {
@ -549,8 +548,21 @@ impl<'ast, T: Field> Checker<'ast, T> {
.get(import.symbol_id)
.cloned();
match (function_candidates.len(), type_candidate) {
(0, Some(t)) => {
// find constant definition candidate
let const_candidate = state
.typed_modules
.get(&import.module_id)
.unwrap()
.constants
.as_ref()
.and_then(|tc| tc.get(import.symbol_id))
.and_then(|sym| match sym {
TypedConstantSymbol::Here(tc) => Some(tc),
_ => None,
});
match (function_candidates.len(), type_candidate, const_candidate) {
(0, Some(t), None) => {
// rename the type to the declared symbol
let t = match t {
@ -585,7 +597,26 @@ impl<'ast, T: Field> Checker<'ast, T> {
.or_default()
.insert(declaration.id.to_string(), t);
}
(0, None) => {
(0, None, Some(c)) => {
match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) {
false => {
errors.push(Error {
module_id: module_id.to_path_buf(),
inner: ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}});
}
true => {
constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, declaration.id));
self.insert_into_scope(Variable::with_id_and_type(c.id.clone(), c.ty.clone()), true);
}
};
}
(0, None, None) => {
errors.push(ErrorInner {
pos: Some(pos),
message: format!(
@ -594,7 +625,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
),
}.in_file(module_id));
}
(_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
(_, Some(_), Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
_ => {
for candidate in function_candidates {

View file

@ -1,36 +1,92 @@
use crate::typed_absy::folder::{
fold_array_expression, fold_array_expression_inner, fold_boolean_expression,
fold_field_expression, fold_module, fold_struct_expression, fold_struct_expression_inner,
fold_uint_expression, fold_uint_expression_inner, Folder,
};
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, FieldElementExpression,
StructExpression, StructExpressionInner, StructType, TypedConstants, TypedModule, TypedProgram,
UBitwidth, UExpression, UExpressionInner,
};
use std::collections::HashMap;
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::convert::TryInto;
use zokrates_field::Field;
pub struct ConstantInliner<'ast, T: Field> {
constants: TypedConstants<'ast, T>,
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
}
impl<'ast, T: Field> ConstantInliner<'ast, T> {
fn with_modules_and_location(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
) -> Self {
ConstantInliner { modules, location }
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut inliner = ConstantInliner {
constants: HashMap::new(),
};
// initialize an inliner over all modules, starting from the main module
let mut inliner =
ConstantInliner::with_modules_and_location(p.modules.clone(), p.main.clone());
inliner.fold_program(p)
}
pub fn module(&self) -> &TypedModule<'ast, T> {
self.modules.get(&self.location).unwrap()
}
pub fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
prev
}
pub fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
self.modules
.get(&self.location)
.unwrap()
.constants
.as_ref()
.and_then(|c| c.get(id.clone().try_into().unwrap()))
.cloned()
.and_then(|tc| {
let symbol = self.fold_constant_symbol(tc);
match symbol {
TypedConstantSymbol::Here(tc) => Some(tc),
_ => None,
}
})
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
self.constants = p.constants.clone().unwrap_or_default();
TypedModule {
functions: fold_module(self, p).functions,
constants: None,
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
.modules
.into_iter()
.map(|(module_id, module)| {
self.change_location(module_id.clone());
(module_id, self.fold_module(module))
})
.collect(),
main: p.main,
}
}
fn fold_constant_symbol(
&mut self,
p: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
match p {
TypedConstantSymbol::There(module_id, id) => {
let location = self.change_location(module_id);
let symbol = self
.module()
.constants
.as_ref()
.and_then(|c| c.get(id))
.unwrap()
.to_owned();
let symbol = self.fold_constant_symbol(symbol);
let _ = self.change_location(location);
symbol
}
_ => fold_constant_symbol(self, p),
}
}
@ -39,7 +95,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => match self.constants.get(id).cloned() {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => fold_field_expression(self, c.expression.try_into().unwrap()),
None => fold_field_expression(self, e),
},
@ -52,7 +108,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(ref id) => match self.constants.get(id).cloned() {
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => fold_boolean_expression(self, c.expression.try_into().unwrap()),
None => fold_boolean_expression(self, e),
},
@ -66,14 +122,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() {
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let expr: UExpression<'ast, T> = c.expression.try_into().unwrap();
fold_uint_expression(self, expr).into_inner()
fold_uint_expression(self, c.expression.try_into().unwrap()).into_inner()
}
None => fold_uint_expression_inner(self, size, e),
},
// default
e => fold_uint_expression_inner(self, size, e),
}
}
@ -84,14 +138,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() {
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let expr: ArrayExpression<'ast, T> = c.expression.try_into().unwrap();
fold_array_expression(self, expr).into_inner()
fold_array_expression(self, c.expression.try_into().unwrap()).into_inner()
}
None => fold_array_expression_inner(self, ty, e),
},
// default
e => fold_array_expression_inner(self, ty, e),
}
}
@ -102,14 +154,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() {
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let expr: StructExpression<'ast, T> = c.expression.try_into().unwrap();
fold_struct_expression(self, expr).into_inner()
fold_struct_expression(self, c.expression.try_into().unwrap()).into_inner()
}
None => fold_struct_expression_inner(self, ty, e),
},
// default
e => fold_struct_expression_inner(self, ty, e),
}
}
@ -132,28 +182,29 @@ mod tests {
// def main() -> field:
// return a
let const_id = Identifier::from("a");
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(const_id.clone()).into(),
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let mut constants = TypedConstants::<Bn128Field>::new();
constants.insert(
const_id.clone(),
TypedConstant {
id: const_id.clone(),
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
))),
},
);
}),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
@ -170,7 +221,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Some(constants),
constants: Some(constants.clone()),
},
)]
.into_iter()
@ -204,7 +255,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: None,
constants: Some(constants),
},
)]
.into_iter()
@ -221,11 +272,11 @@ mod tests {
// def main() -> bool:
// return a
let const_id = Identifier::from("a");
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
const_id.clone(),
Identifier::from(const_id),
)
.into()])],
signature: DeclarationSignature::new()
@ -233,15 +284,16 @@ mod tests {
.outputs(vec![DeclarationType::Boolean]),
};
let mut constants = TypedConstants::<Bn128Field>::new();
constants.insert(
const_id.clone(),
TypedConstant {
id: const_id.clone(),
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_id),
ty: GType::Boolean,
expression: (TypedExpression::Boolean(BooleanExpression::Value(true))),
},
);
}),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
@ -258,7 +310,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Some(constants),
constants: Some(constants.clone()),
},
)]
.into_iter()
@ -292,7 +344,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: None,
constants: Some(constants),
},
)]
.into_iter()
@ -309,11 +361,11 @@ mod tests {
// def main() -> u32:
// return a
let const_id = Identifier::from("a");
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
const_id.clone(),
Identifier::from(const_id),
)
.annotate(UBitwidth::B32)
.into()])],
@ -322,17 +374,18 @@ mod tests {
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let mut constants = TypedConstants::<Bn128Field>::new();
constants.insert(
const_id.clone(),
TypedConstant {
id: const_id.clone(),
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_id),
ty: GType::Uint(UBitwidth::B32),
expression: (UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into()),
},
);
}),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
@ -349,7 +402,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Some(constants),
constants: Some(constants.clone()),
},
)]
.into_iter()
@ -383,7 +436,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: None,
constants: Some(constants),
},
)]
.into_iter()
@ -400,18 +453,18 @@ mod tests {
// def main() -> field:
// return a[0] + a[1]
let const_id = Identifier::from("a");
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(const_id.clone())
box ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
box UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(const_id.clone())
box ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
box UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
@ -423,11 +476,10 @@ mod tests {
.outputs(vec![DeclarationType::FieldElement]),
};
let mut constants = TypedConstants::<Bn128Field>::new();
constants.insert(
const_id.clone(),
TypedConstant {
id: const_id.clone(),
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_id),
ty: GType::FieldElement,
expression: TypedExpression::Array(
ArrayExpressionInner::Value(
@ -439,8 +491,10 @@ mod tests {
)
.annotate(GType::FieldElement, 2usize),
),
},
);
}),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
@ -457,7 +511,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Some(constants),
constants: Some(constants.clone()),
},
)]
.into_iter()
@ -515,7 +569,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: None,
constants: Some(constants),
},
)]
.into_iter()
@ -533,44 +587,19 @@ mod tests {
// def main() -> field:
// return b
let const_a_id = Identifier::from("a");
let const_b_id = Identifier::from("b");
let const_a_id = "a";
let const_b_id = "b";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(const_b_id.clone()).into(),
FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let mut constants = TypedConstants::<Bn128Field>::new();
constants.extend(vec![
(
const_a_id.clone(),
TypedConstant {
id: const_a_id.clone(),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
))),
},
),
(
const_b_id.clone(),
TypedConstant {
id: const_b_id.clone(),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(const_a_id.clone()),
box FieldElementExpression::Number(Bn128Field::from(1)),
))),
},
),
]);
let program = TypedProgram {
main: "main".into(),
modules: vec![(
@ -586,7 +615,37 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Some(constants),
constants: Some(
vec![
(
const_a_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_a_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(1)),
)),
}),
),
(
const_b_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_b_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(
FieldElementExpression::Add(
box FieldElementExpression::Identifier(
Identifier::from(const_a_id),
),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
)),
}),
),
]
.into_iter()
.collect(),
),
},
)]
.into_iter()
@ -622,7 +681,35 @@ mod tests {
)]
.into_iter()
.collect(),
constants: None,
constants: Some(
vec![
(
const_a_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_a_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(1)),
)),
}),
),
(
const_b_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(const_b_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(
FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
),
)),
}),
),
]
.into_iter()
.collect(),
),
},
)]
.into_iter()
@ -631,4 +718,139 @@ mod tests {
assert_eq!(program, expected_program)
}
#[test]
fn inline_imported_constant() {
// ---------------------
// module `foo`
// --------------------
// const field FOO = 42
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// from "foo" import FOO
//
// def main() -> field:
// return FOO
let foo_const_id = "FOO";
let foo_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main")
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![],
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
}),
)]
.into_iter()
.collect(),
constants: Some(
vec![(
foo_const_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(foo_const_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(42)),
)),
}),
)]
.into_iter()
.collect(),
),
};
let main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: Some(
vec![(
foo_const_id,
TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id),
)]
.into_iter()
.collect(),
),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), main_module),
("foo".into(), foo_module.clone()),
]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(42)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: Some(
vec![(
foo_const_id,
TypedConstantSymbol::Here(TypedConstant {
id: Identifier::from(foo_const_id),
ty: GType::FieldElement,
expression: (TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(42)),
)),
}),
)]
.into_iter()
.collect(),
),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), expected_main_module),
("foo".into(), foo_module),
]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
}

View file

@ -13,6 +13,13 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_module(self, p)
}
fn fold_constant_symbol(
&mut self,
p: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
fold_constant_symbol(self, p)
}
fn fold_function_symbol(
&mut self,
s: TypedFunctionSymbol<'ast, T>,
@ -193,12 +200,16 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
p: TypedModule<'ast, T>,
) -> TypedModule<'ast, T> {
TypedModule {
constants: p.constants.map(|tc| {
tc.into_iter()
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
.collect()
}),
functions: p
.functions
.into_iter()
.map(|(key, fun)| (key, f.fold_function_symbol(fun)))
.collect(),
constants: p.constants,
}
}
@ -692,6 +703,19 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
p: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
match p {
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(TypedConstant {
expression: f.fold_expression(tc.expression),
..tc
}),
there => there,
}
}
pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedFunctionSymbol<'ast, T>,

View file

@ -61,8 +61,17 @@ pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T
pub type TypedFunctionSymbols<'ast, T> =
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
/// A collection of `TypedConstant`s
pub type TypedConstants<'ast, T> = HashMap<Identifier<'ast>, TypedConstant<'ast, T>>;
pub type ConstantIdentifier<'ast> = &'ast str;
#[derive(Clone, Debug, PartialEq)]
pub enum TypedConstantSymbol<'ast, T> {
Here(TypedConstant<'ast, T>),
There(OwnedTypedModuleId, ConstantIdentifier<'ast>),
}
/// A collection of `TypedConstantSymbol`s
pub type TypedConstantSymbols<'ast, T> =
HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>;
/// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Debug, Clone)]
@ -144,7 +153,7 @@ pub struct TypedModule<'ast, T> {
/// Functions of the module
pub functions: TypedFunctionSymbols<'ast, T>,
/// Constants defined in module
pub constants: Option<TypedConstants<'ast, T>>,
pub constants: Option<TypedConstantSymbols<'ast, T>>,
}
#[derive(Clone, PartialEq)]
@ -320,7 +329,11 @@ pub struct TypedConstant<'ast, T> {
impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "TypedConstant({:?}, {:?}, ...)", self.id, self.ty)
write!(
f,
"TypedConstant({:?}, {:?}, {:?})",
self.id, self.ty, self.expression
)
}
}

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/array.zok",
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1", "2"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const field[2] ARRAY = [1, 2]
def main() -> field[2]:
return ARRAY

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/bool.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const bool BOOLEAN = true
def main() -> bool:
return BOOLEAN

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/field.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const field ONE = 1
def main() -> field:
return ONE

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/nested.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["8"]
}
}
}
]
}

View file

@ -0,0 +1,6 @@
const field A = 2
const field B = 2
const field[2] ARRAY = [A * 2, B * 2]
def main() -> field:
return ARRAY[0] + ARRAY[1]

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/struct.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["4"]
}
}
}
]
}

View file

@ -0,0 +1,9 @@
struct Foo {
field a
field b
}
const Foo FOO = Foo { a: 2, b: 2 }
def main() -> field:
return FOO.a + FOO.b

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/uint.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const u32 ONE = 0x00000001
def main() -> u32:
return ONE