fix imports, more tests
This commit is contained in:
parent
8f4ee00276
commit
dafef03b1f
26 changed files with 589 additions and 145 deletions
5
zokrates_cli/examples/book/constant_definition.zok
Normal file
5
zokrates_cli/examples/book/constant_definition.zok
Normal file
|
@ -0,0 +1,5 @@
|
|||
const field ONE = 1
|
||||
const field TWO = ONE + ONE
|
||||
|
||||
def main() -> field:
|
||||
return TWO
|
|
@ -0,0 +1,5 @@
|
|||
const field a = 1
|
||||
|
||||
def main() -> field:
|
||||
a = 2 // not allowed
|
||||
return a
|
|
@ -1,5 +1,7 @@
|
|||
struct Bar {
|
||||
}
|
||||
struct Bar {}
|
||||
|
||||
const field ONE = 1
|
||||
const field BAR = 21 * ONE
|
||||
|
||||
def main() -> field:
|
||||
return 21
|
||||
return BAR
|
|
@ -1,5 +1,6 @@
|
|||
struct Baz {
|
||||
}
|
||||
struct Baz {}
|
||||
|
||||
const field BAZ = 123
|
||||
|
||||
def main() -> field:
|
||||
return 123
|
||||
return BAZ
|
|
@ -1,9 +1,10 @@
|
|||
from "./baz" import Baz
|
||||
|
||||
import "./baz"
|
||||
from "./baz" import main as my_function
|
||||
import "./baz"
|
||||
|
||||
const field FOO = 144
|
||||
|
||||
def main() -> field:
|
||||
field a = my_function()
|
||||
Baz b = Baz {}
|
||||
return baz()
|
||||
Baz b = Baz {}
|
||||
assert(baz() == my_function())
|
||||
return FOO
|
|
@ -1,11 +0,0 @@
|
|||
from "./bar" import Bar as MyBar
|
||||
from "./bar" import Bar
|
||||
|
||||
import "./foo"
|
||||
import "./bar"
|
||||
|
||||
def main() -> field:
|
||||
MyBar my_bar = MyBar {}
|
||||
Bar bar = Bar {}
|
||||
assert(my_bar == bar)
|
||||
return foo() + bar()
|
6
zokrates_cli/examples/imports/import_constants.zok
Normal file
6
zokrates_cli/examples/imports/import_constants.zok
Normal file
|
@ -0,0 +1,6 @@
|
|||
from "./foo" import FOO
|
||||
from "./bar" import BAR
|
||||
from "./baz" import BAZ
|
||||
|
||||
def main() -> bool:
|
||||
return FOO == BAR + BAZ
|
6
zokrates_cli/examples/imports/import_functions.zok
Normal file
6
zokrates_cli/examples/imports/import_functions.zok
Normal file
|
@ -0,0 +1,6 @@
|
|||
import "./foo"
|
||||
import "./bar"
|
||||
import "./baz"
|
||||
|
||||
def main() -> bool:
|
||||
return foo() == bar() + baz()
|
8
zokrates_cli/examples/imports/import_structs.zok
Normal file
8
zokrates_cli/examples/imports/import_structs.zok
Normal file
|
@ -0,0 +1,8 @@
|
|||
from "./bar" import Bar as MyBar
|
||||
from "./bar" import Bar
|
||||
|
||||
def main():
|
||||
MyBar my_bar = MyBar {}
|
||||
Bar bar = Bar {}
|
||||
assert(my_bar == bar)
|
||||
return
|
|
@ -1,4 +1,8 @@
|
|||
import "./foo" as d
|
||||
from "./bar" import main as bar
|
||||
from "./baz" import main as baz
|
||||
import "./foo" as f
|
||||
|
||||
def main() -> field:
|
||||
return d()
|
||||
field foo = f()
|
||||
assert(foo == bar() + baz())
|
||||
return foo
|
|
@ -407,7 +407,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
module_id: &ModuleId,
|
||||
state: &mut State<'ast, T>,
|
||||
functions: &mut HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>,
|
||||
constants: &mut HashMap<identifier::Identifier<'ast>, TypedConstant<'ast, T>>,
|
||||
constants: &mut HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>,
|
||||
symbol_unifier: &mut SymbolUnifier<'ast>,
|
||||
) -> Result<(), Vec<Error>> {
|
||||
let mut errors: Vec<Error> = vec![];
|
||||
|
@ -470,8 +470,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
),
|
||||
true => {}
|
||||
};
|
||||
constants
|
||||
.insert(identifier::Identifier::from(declaration.id), c.clone());
|
||||
constants.insert(declaration.id, TypedConstantSymbol::Here(c.clone()));
|
||||
self.insert_into_scope(Variable::with_id_and_type(c.id, c.ty), true);
|
||||
}
|
||||
Err(e) => {
|
||||
|
@ -549,8 +548,21 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.get(import.symbol_id)
|
||||
.cloned();
|
||||
|
||||
match (function_candidates.len(), type_candidate) {
|
||||
(0, Some(t)) => {
|
||||
// find constant definition candidate
|
||||
let const_candidate = state
|
||||
.typed_modules
|
||||
.get(&import.module_id)
|
||||
.unwrap()
|
||||
.constants
|
||||
.as_ref()
|
||||
.and_then(|tc| tc.get(import.symbol_id))
|
||||
.and_then(|sym| match sym {
|
||||
TypedConstantSymbol::Here(tc) => Some(tc),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
match (function_candidates.len(), type_candidate, const_candidate) {
|
||||
(0, Some(t), None) => {
|
||||
|
||||
// rename the type to the declared symbol
|
||||
let t = match t {
|
||||
|
@ -585,7 +597,26 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.or_default()
|
||||
.insert(declaration.id.to_string(), t);
|
||||
}
|
||||
(0, None) => {
|
||||
(0, None, Some(c)) => {
|
||||
match symbol_unifier.insert_symbol(declaration.id, SymbolType::Constant) {
|
||||
false => {
|
||||
errors.push(Error {
|
||||
module_id: module_id.to_path_buf(),
|
||||
inner: ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"{} conflicts with another symbol",
|
||||
declaration.id,
|
||||
),
|
||||
}});
|
||||
}
|
||||
true => {
|
||||
constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, declaration.id));
|
||||
self.insert_into_scope(Variable::with_id_and_type(c.id.clone(), c.ty.clone()), true);
|
||||
}
|
||||
};
|
||||
}
|
||||
(0, None, None) => {
|
||||
errors.push(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -594,7 +625,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
),
|
||||
}.in_file(module_id));
|
||||
}
|
||||
(_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
|
||||
(_, Some(_), Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
|
||||
_ => {
|
||||
for candidate in function_candidates {
|
||||
|
||||
|
|
|
@ -1,36 +1,92 @@
|
|||
use crate::typed_absy::folder::{
|
||||
fold_array_expression, fold_array_expression_inner, fold_boolean_expression,
|
||||
fold_field_expression, fold_module, fold_struct_expression, fold_struct_expression_inner,
|
||||
fold_uint_expression, fold_uint_expression_inner, Folder,
|
||||
};
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, FieldElementExpression,
|
||||
StructExpression, StructExpressionInner, StructType, TypedConstants, TypedModule, TypedProgram,
|
||||
UBitwidth, UExpression, UExpressionInner,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use crate::typed_absy::folder::*;
|
||||
use crate::typed_absy::*;
|
||||
use std::convert::TryInto;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ConstantInliner<'ast, T: Field> {
|
||||
constants: TypedConstants<'ast, T>,
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ConstantInliner<'ast, T> {
|
||||
fn with_modules_and_location(
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
) -> Self {
|
||||
ConstantInliner { modules, location }
|
||||
}
|
||||
|
||||
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
let mut inliner = ConstantInliner {
|
||||
constants: HashMap::new(),
|
||||
};
|
||||
// initialize an inliner over all modules, starting from the main module
|
||||
let mut inliner =
|
||||
ConstantInliner::with_modules_and_location(p.modules.clone(), p.main.clone());
|
||||
|
||||
inliner.fold_program(p)
|
||||
}
|
||||
|
||||
pub fn module(&self) -> &TypedModule<'ast, T> {
|
||||
self.modules.get(&self.location).unwrap()
|
||||
}
|
||||
|
||||
pub fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
let prev = self.location.clone();
|
||||
self.location = location;
|
||||
prev
|
||||
}
|
||||
|
||||
pub fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
|
||||
self.modules
|
||||
.get(&self.location)
|
||||
.unwrap()
|
||||
.constants
|
||||
.as_ref()
|
||||
.and_then(|c| c.get(id.clone().try_into().unwrap()))
|
||||
.cloned()
|
||||
.and_then(|tc| {
|
||||
let symbol = self.fold_constant_symbol(tc);
|
||||
match symbol {
|
||||
TypedConstantSymbol::Here(tc) => Some(tc),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||
fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
|
||||
self.constants = p.constants.clone().unwrap_or_default();
|
||||
TypedModule {
|
||||
functions: fold_module(self, p).functions,
|
||||
constants: None,
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
TypedProgram {
|
||||
modules: p
|
||||
.modules
|
||||
.into_iter()
|
||||
.map(|(module_id, module)| {
|
||||
self.change_location(module_id.clone());
|
||||
(module_id, self.fold_module(module))
|
||||
})
|
||||
.collect(),
|
||||
main: p.main,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_constant_symbol(
|
||||
&mut self,
|
||||
p: TypedConstantSymbol<'ast, T>,
|
||||
) -> TypedConstantSymbol<'ast, T> {
|
||||
match p {
|
||||
TypedConstantSymbol::There(module_id, id) => {
|
||||
let location = self.change_location(module_id);
|
||||
let symbol = self
|
||||
.module()
|
||||
.constants
|
||||
.as_ref()
|
||||
.and_then(|c| c.get(id))
|
||||
.unwrap()
|
||||
.to_owned();
|
||||
|
||||
let symbol = self.fold_constant_symbol(symbol);
|
||||
let _ = self.change_location(location);
|
||||
symbol
|
||||
}
|
||||
_ => fold_constant_symbol(self, p),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,7 +95,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: FieldElementExpression<'ast, T>,
|
||||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Identifier(ref id) => match self.constants.get(id).cloned() {
|
||||
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => fold_field_expression(self, c.expression.try_into().unwrap()),
|
||||
None => fold_field_expression(self, e),
|
||||
},
|
||||
|
@ -52,7 +108,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Identifier(ref id) => match self.constants.get(id).cloned() {
|
||||
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => fold_boolean_expression(self, c.expression.try_into().unwrap()),
|
||||
None => fold_boolean_expression(self, e),
|
||||
},
|
||||
|
@ -66,14 +122,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() {
|
||||
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let expr: UExpression<'ast, T> = c.expression.try_into().unwrap();
|
||||
fold_uint_expression(self, expr).into_inner()
|
||||
fold_uint_expression(self, c.expression.try_into().unwrap()).into_inner()
|
||||
}
|
||||
None => fold_uint_expression_inner(self, size, e),
|
||||
},
|
||||
// default
|
||||
e => fold_uint_expression_inner(self, size, e),
|
||||
}
|
||||
}
|
||||
|
@ -84,14 +138,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> ArrayExpressionInner<'ast, T> {
|
||||
match e {
|
||||
ArrayExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() {
|
||||
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let expr: ArrayExpression<'ast, T> = c.expression.try_into().unwrap();
|
||||
fold_array_expression(self, expr).into_inner()
|
||||
fold_array_expression(self, c.expression.try_into().unwrap()).into_inner()
|
||||
}
|
||||
None => fold_array_expression_inner(self, ty, e),
|
||||
},
|
||||
// default
|
||||
e => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
@ -102,14 +154,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
e: StructExpressionInner<'ast, T>,
|
||||
) -> StructExpressionInner<'ast, T> {
|
||||
match e {
|
||||
StructExpressionInner::Identifier(ref id) => match self.constants.get(id).cloned() {
|
||||
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let expr: StructExpression<'ast, T> = c.expression.try_into().unwrap();
|
||||
fold_struct_expression(self, expr).into_inner()
|
||||
fold_struct_expression(self, c.expression.try_into().unwrap()).into_inner()
|
||||
}
|
||||
None => fold_struct_expression_inner(self, ty, e),
|
||||
},
|
||||
// default
|
||||
e => fold_struct_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
@ -132,28 +182,29 @@ mod tests {
|
|||
// def main() -> field:
|
||||
// return a
|
||||
|
||||
let const_id = Identifier::from("a");
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(const_id.clone()).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let mut constants = TypedConstants::<Bn128Field>::new();
|
||||
constants.insert(
|
||||
const_id.clone(),
|
||||
TypedConstant {
|
||||
id: const_id.clone(),
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
))),
|
||||
},
|
||||
);
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
|
@ -170,7 +221,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(constants),
|
||||
constants: Some(constants.clone()),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -204,7 +255,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: None,
|
||||
constants: Some(constants),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -221,11 +272,11 @@ mod tests {
|
|||
// def main() -> bool:
|
||||
// return a
|
||||
|
||||
let const_id = Identifier::from("a");
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
|
||||
const_id.clone(),
|
||||
Identifier::from(const_id),
|
||||
)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
|
@ -233,15 +284,16 @@ mod tests {
|
|||
.outputs(vec![DeclarationType::Boolean]),
|
||||
};
|
||||
|
||||
let mut constants = TypedConstants::<Bn128Field>::new();
|
||||
constants.insert(
|
||||
const_id.clone(),
|
||||
TypedConstant {
|
||||
id: const_id.clone(),
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_id),
|
||||
ty: GType::Boolean,
|
||||
expression: (TypedExpression::Boolean(BooleanExpression::Value(true))),
|
||||
},
|
||||
);
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
|
@ -258,7 +310,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(constants),
|
||||
constants: Some(constants.clone()),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -292,7 +344,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: None,
|
||||
constants: Some(constants),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -309,11 +361,11 @@ mod tests {
|
|||
// def main() -> u32:
|
||||
// return a
|
||||
|
||||
let const_id = Identifier::from("a");
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
|
||||
const_id.clone(),
|
||||
Identifier::from(const_id),
|
||||
)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into()])],
|
||||
|
@ -322,17 +374,18 @@ mod tests {
|
|||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
};
|
||||
|
||||
let mut constants = TypedConstants::<Bn128Field>::new();
|
||||
constants.insert(
|
||||
const_id.clone(),
|
||||
TypedConstant {
|
||||
id: const_id.clone(),
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_id),
|
||||
ty: GType::Uint(UBitwidth::B32),
|
||||
expression: (UExpressionInner::Value(1u128)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into()),
|
||||
},
|
||||
);
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
|
@ -349,7 +402,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(constants),
|
||||
constants: Some(constants.clone()),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -383,7 +436,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: None,
|
||||
constants: Some(constants),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -400,18 +453,18 @@ mod tests {
|
|||
// def main() -> field:
|
||||
// return a[0] + a[1]
|
||||
|
||||
let const_id = Identifier::from("a");
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(const_id.clone())
|
||||
box ArrayExpressionInner::Identifier(Identifier::from(const_id))
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
box UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(const_id.clone())
|
||||
box ArrayExpressionInner::Identifier(Identifier::from(const_id))
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
box UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
|
@ -423,11 +476,10 @@ mod tests {
|
|||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let mut constants = TypedConstants::<Bn128Field>::new();
|
||||
constants.insert(
|
||||
const_id.clone(),
|
||||
TypedConstant {
|
||||
id: const_id.clone(),
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
|
@ -439,8 +491,10 @@ mod tests {
|
|||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
),
|
||||
},
|
||||
);
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
|
@ -457,7 +511,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(constants),
|
||||
constants: Some(constants.clone()),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -515,7 +569,7 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: None,
|
||||
constants: Some(constants),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -533,44 +587,19 @@ mod tests {
|
|||
// def main() -> field:
|
||||
// return b
|
||||
|
||||
let const_a_id = Identifier::from("a");
|
||||
let const_b_id = Identifier::from("b");
|
||||
let const_a_id = "a";
|
||||
let const_b_id = "b";
|
||||
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(const_b_id.clone()).into(),
|
||||
FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let mut constants = TypedConstants::<Bn128Field>::new();
|
||||
constants.extend(vec![
|
||||
(
|
||||
const_a_id.clone(),
|
||||
TypedConstant {
|
||||
id: const_a_id.clone(),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
))),
|
||||
},
|
||||
),
|
||||
(
|
||||
const_b_id.clone(),
|
||||
TypedConstant {
|
||||
id: const_b_id.clone(),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(const_a_id.clone()),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
))),
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
|
@ -586,7 +615,37 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(constants),
|
||||
constants: Some(
|
||||
vec![
|
||||
(
|
||||
const_a_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_a_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
)),
|
||||
}),
|
||||
),
|
||||
(
|
||||
const_b_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_b_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(
|
||||
Identifier::from(const_a_id),
|
||||
),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
),
|
||||
)),
|
||||
}),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -622,7 +681,35 @@ mod tests {
|
|||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: None,
|
||||
constants: Some(
|
||||
vec![
|
||||
(
|
||||
const_a_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_a_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
)),
|
||||
}),
|
||||
),
|
||||
(
|
||||
const_b_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(const_b_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
),
|
||||
)),
|
||||
}),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -631,4 +718,139 @@ mod tests {
|
|||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_imported_constant() {
|
||||
// ---------------------
|
||||
// module `foo`
|
||||
// --------------------
|
||||
// const field FOO = 42
|
||||
//
|
||||
// def main():
|
||||
// return
|
||||
//
|
||||
// ---------------------
|
||||
// module `main`
|
||||
// ---------------------
|
||||
// from "foo" import FOO
|
||||
//
|
||||
// def main() -> field:
|
||||
// return FOO
|
||||
|
||||
let foo_const_id = "FOO";
|
||||
let foo_module = TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main")
|
||||
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![],
|
||||
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(
|
||||
vec![(
|
||||
foo_const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(foo_const_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
)),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
|
||||
let main_module = TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(
|
||||
vec![(
|
||||
foo_const_id,
|
||||
TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), main_module),
|
||||
("foo".into(), foo_module.clone()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
let expected_main_module = TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(42)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Some(
|
||||
vec![(
|
||||
foo_const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
id: Identifier::from(foo_const_id),
|
||||
ty: GType::FieldElement,
|
||||
expression: (TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
)),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), expected_main_module),
|
||||
("foo".into(), foo_module),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,13 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_module(self, p)
|
||||
}
|
||||
|
||||
fn fold_constant_symbol(
|
||||
&mut self,
|
||||
p: TypedConstantSymbol<'ast, T>,
|
||||
) -> TypedConstantSymbol<'ast, T> {
|
||||
fold_constant_symbol(self, p)
|
||||
}
|
||||
|
||||
fn fold_function_symbol(
|
||||
&mut self,
|
||||
s: TypedFunctionSymbol<'ast, T>,
|
||||
|
@ -193,12 +200,16 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
p: TypedModule<'ast, T>,
|
||||
) -> TypedModule<'ast, T> {
|
||||
TypedModule {
|
||||
constants: p.constants.map(|tc| {
|
||||
tc.into_iter()
|
||||
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
|
||||
.collect()
|
||||
}),
|
||||
functions: p
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| (key, f.fold_function_symbol(fun)))
|
||||
.collect(),
|
||||
constants: p.constants,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -692,6 +703,19 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
p: TypedConstantSymbol<'ast, T>,
|
||||
) -> TypedConstantSymbol<'ast, T> {
|
||||
match p {
|
||||
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(TypedConstant {
|
||||
expression: f.fold_expression(tc.expression),
|
||||
..tc
|
||||
}),
|
||||
there => there,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedFunctionSymbol<'ast, T>,
|
||||
|
|
|
@ -61,8 +61,17 @@ pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T
|
|||
pub type TypedFunctionSymbols<'ast, T> =
|
||||
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
|
||||
|
||||
/// A collection of `TypedConstant`s
|
||||
pub type TypedConstants<'ast, T> = HashMap<Identifier<'ast>, TypedConstant<'ast, T>>;
|
||||
pub type ConstantIdentifier<'ast> = &'ast str;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum TypedConstantSymbol<'ast, T> {
|
||||
Here(TypedConstant<'ast, T>),
|
||||
There(OwnedTypedModuleId, ConstantIdentifier<'ast>),
|
||||
}
|
||||
|
||||
/// A collection of `TypedConstantSymbol`s
|
||||
pub type TypedConstantSymbols<'ast, T> =
|
||||
HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>;
|
||||
|
||||
/// A typed program as a collection of modules, one of them being the main
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
|
@ -144,7 +153,7 @@ pub struct TypedModule<'ast, T> {
|
|||
/// Functions of the module
|
||||
pub functions: TypedFunctionSymbols<'ast, T>,
|
||||
/// Constants defined in module
|
||||
pub constants: Option<TypedConstants<'ast, T>>,
|
||||
pub constants: Option<TypedConstantSymbols<'ast, T>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
|
@ -320,7 +329,11 @@ pub struct TypedConstant<'ast, T> {
|
|||
|
||||
impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "TypedConstant({:?}, {:?}, ...)", self.id, self.ty)
|
||||
write!(
|
||||
f,
|
||||
"TypedConstant({:?}, {:?}, {:?})",
|
||||
self.id, self.ty, self.expression
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
16
zokrates_core_test/tests/tests/constants/array.json
Normal file
16
zokrates_core_test/tests/tests/constants/array.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/array.zok",
|
||||
"max_constraint_count": 2,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1", "2"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
4
zokrates_core_test/tests/tests/constants/array.zok
Normal file
4
zokrates_core_test/tests/tests/constants/array.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
const field[2] ARRAY = [1, 2]
|
||||
|
||||
def main() -> field[2]:
|
||||
return ARRAY
|
16
zokrates_core_test/tests/tests/constants/bool.json
Normal file
16
zokrates_core_test/tests/tests/constants/bool.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/bool.zok",
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
4
zokrates_core_test/tests/tests/constants/bool.zok
Normal file
4
zokrates_core_test/tests/tests/constants/bool.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
const bool BOOLEAN = true
|
||||
|
||||
def main() -> bool:
|
||||
return BOOLEAN
|
16
zokrates_core_test/tests/tests/constants/field.json
Normal file
16
zokrates_core_test/tests/tests/constants/field.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/field.zok",
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
4
zokrates_core_test/tests/tests/constants/field.zok
Normal file
4
zokrates_core_test/tests/tests/constants/field.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
const field ONE = 1
|
||||
|
||||
def main() -> field:
|
||||
return ONE
|
16
zokrates_core_test/tests/tests/constants/nested.json
Normal file
16
zokrates_core_test/tests/tests/constants/nested.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/nested.zok",
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["8"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
6
zokrates_core_test/tests/tests/constants/nested.zok
Normal file
6
zokrates_core_test/tests/tests/constants/nested.zok
Normal file
|
@ -0,0 +1,6 @@
|
|||
const field A = 2
|
||||
const field B = 2
|
||||
const field[2] ARRAY = [A * 2, B * 2]
|
||||
|
||||
def main() -> field:
|
||||
return ARRAY[0] + ARRAY[1]
|
16
zokrates_core_test/tests/tests/constants/struct.json
Normal file
16
zokrates_core_test/tests/tests/constants/struct.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/struct.zok",
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["4"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
9
zokrates_core_test/tests/tests/constants/struct.zok
Normal file
9
zokrates_core_test/tests/tests/constants/struct.zok
Normal file
|
@ -0,0 +1,9 @@
|
|||
struct Foo {
|
||||
field a
|
||||
field b
|
||||
}
|
||||
|
||||
const Foo FOO = Foo { a: 2, b: 2 }
|
||||
|
||||
def main() -> field:
|
||||
return FOO.a + FOO.b
|
16
zokrates_core_test/tests/tests/constants/uint.json
Normal file
16
zokrates_core_test/tests/tests/constants/uint.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/constants/uint.zok",
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["1"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
4
zokrates_core_test/tests/tests/constants/uint.zok
Normal file
4
zokrates_core_test/tests/tests/constants/uint.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
const u32 ONE = 0x00000001
|
||||
|
||||
def main() -> u32:
|
||||
return ONE
|
Loading…
Reference in a new issue