1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #792 from Zokrates/constant-def

Introduce constant definitions
This commit is contained in:
Thibaut Schaeffer 2021-04-30 11:11:31 +02:00 committed by GitHub
commit 0d804fa6a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 2700 additions and 841 deletions

View file

@ -0,0 +1 @@
Introduce constant definitions to the language (`const` keyword)

View file

@ -8,11 +8,12 @@
- [Variables](language/variables.md)
- [Types](language/types.md)
- [Operators](language/operators.md)
- [Functions](language/functions.md)
- [Control flow](language/control_flow.md)
- [Constants](language/constants.md)
- [Functions](language/functions.md)
- [Generics](language/generics.md)
- [Imports](language/imports.md)
- [Comments](language/comments.md)
- [Generics](language/generics.md)
- [Macros](language/macros.md)
- [Toolbox](toolbox/index.md)

View file

@ -0,0 +1,17 @@
## Constants
Constants must be globally defined outside all other scopes by using a `const` keyword. Constants can be set only to a constant expression.
```zokrates
{{#include ../../../zokrates_cli/examples/book/constant_definition.zok}}
```
The value of a constant can't be changed through reassignment, and it can't be redeclared.
Constants must be explicitly typed. One can reference other constants inside the expression, as long as the referenced constant is already defined.
```zokrates
{{#include ../../../zokrates_cli/examples/book/constant_reference.zok}}
```
The naming convention for constants are similar to that of variables. All characters in a constant name are usually in uppercase.

View file

@ -44,7 +44,7 @@ from "./path/to/my/module" import main as module
Note that this legacy method is likely to become deprecated, so it is recommended to use the preferred way instead.
### Symbols
Two types of symbols can be imported
Three types of symbols can be imported
#### Functions
Functions are imported by name. If many functions have the same name but different signatures, all of them get imported, and which one to use in a particular call is inferred.
@ -52,6 +52,9 @@ Functions are imported by name. If many functions have the same name but differe
#### User-defined types
User-defined types declared with the `struct` keyword are imported by name.
#### Constants
Constants declared with the `const` keyword are imported by name.
### Relative Imports
You can import a resource in the same folder directly, like this:

View file

@ -0,0 +1,4 @@
const field PRIME = 31
def main() -> field:
return PRIME

View file

@ -0,0 +1,5 @@
const field ONE = 1
const field TWO = ONE + ONE
def main() -> field:
return TWO

View file

@ -0,0 +1,5 @@
const field a = 1
def main() -> field:
a = 2 // not allowed
return a

View file

@ -1,10 +1,10 @@
def const() -> field:
def constant() -> field:
return 123123
def add(field a,field b) -> field:
a=const()
return a+b
def add(field a, field b) -> field:
a = constant()
return a + b
def main(field a,field b) -> field:
field c = add(a, b+const())
return const()
def main(field a, field b) -> field:
field c = add(a, b + constant())
return constant()

View file

@ -1,5 +1,7 @@
struct Bar {
}
struct Bar {}
const field ONE = 1
const field BAR = 21 * ONE
def main() -> field:
return 21
return BAR

View file

@ -1,5 +1,6 @@
struct Baz {
}
struct Baz {}
const field BAZ = 123
def main() -> field:
return 123
return BAZ

View file

@ -1,9 +1,10 @@
from "./baz" import Baz
import "./baz"
from "./baz" import main as my_function
import "./baz"
const field FOO = 144
def main() -> field:
field a = my_function()
Baz b = Baz {}
return baz()
Baz b = Baz {}
assert(baz() == my_function())
return FOO

View file

@ -1,11 +0,0 @@
from "./bar" import Bar as MyBar
from "./bar" import Bar
import "./foo"
import "./bar"
def main() -> field:
MyBar my_bar = MyBar {}
Bar bar = Bar {}
assert(my_bar == bar)
return foo() + bar()

View file

@ -0,0 +1,6 @@
from "./foo" import FOO
from "./bar" import BAR
from "./baz" import BAZ
def main() -> bool:
return FOO == BAR + BAZ

View file

@ -0,0 +1,6 @@
import "./foo"
import "./bar"
import "./baz"
def main() -> bool:
return foo() == bar() + baz()

View file

@ -0,0 +1,8 @@
from "./bar" import Bar as MyBar
from "./bar" import Bar
def main():
MyBar my_bar = MyBar {}
Bar bar = Bar {}
assert(my_bar == bar)
return

View file

@ -1,4 +1,8 @@
import "./foo" as d
from "./bar" import main as bar
from "./baz" import BAZ as baz
import "./foo" as f
def main() -> field:
return d()
field foo = f()
assert(foo == bar() + baz)
return foo

View file

@ -1,6 +1,7 @@
use crate::absy;
use crate::imports;
use crate::absy::SymbolDefinition;
use num_bigint::BigUint;
use zokrates_pest_ast as pest;
@ -10,6 +11,11 @@ impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
prog.structs
.into_iter()
.map(absy::SymbolDeclarationNode::from)
.chain(
prog.constants
.into_iter()
.map(absy::SymbolDeclarationNode::from),
)
.chain(
prog.functions
.into_iter()
@ -78,7 +84,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::HereType(ty),
symbol: absy::Symbol::Here(SymbolDefinition::Struct(ty)),
}
.span(span)
}
@ -98,6 +104,27 @@ impl<'ast> From<pest::StructField<'ast>> for absy::StructDefinitionFieldNode<'as
}
}
impl<'ast> From<pest::ConstantDefinition<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(definition: pest::ConstantDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;
let span = definition.span;
let id = definition.id.span.as_str();
let ty = absy::ConstantDefinition {
ty: definition.ty.into(),
expression: definition.expression.into(),
}
.span(span.clone());
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::Here(SymbolDefinition::Constant(ty)),
}
.span(span)
}
}
impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;
@ -148,7 +175,7 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::HereFunction(function),
symbol: absy::Symbol::Here(SymbolDefinition::Function(function)),
}
.span(span)
}
@ -754,7 +781,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -771,7 +798,7 @@ mod tests {
.outputs(vec![UnresolvedType::FieldElement.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
@ -786,7 +813,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -801,7 +828,7 @@ mod tests {
.outputs(vec![UnresolvedType::Boolean.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
@ -817,7 +844,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(SymbolDefinition::Function(
absy::Function {
arguments: vec![
absy::Parameter::private(
@ -854,7 +881,7 @@ mod tests {
.outputs(vec![UnresolvedType::FieldElement.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
@ -871,7 +898,7 @@ mod tests {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(SymbolDefinition::Function(
absy::Function {
arguments: vec![absy::Parameter::private(
absy::Variable::new("a", ty.clone().mock()).into(),
@ -887,7 +914,7 @@ mod tests {
signature: UnresolvedSignature::new().inputs(vec![ty.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
@ -945,7 +972,7 @@ mod tests {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -958,7 +985,7 @@ mod tests {
signature: UnresolvedSignature::new(),
}
.into(),
),
)),
}
.into()],
imports: vec![],

View file

@ -51,10 +51,27 @@ pub struct SymbolDeclaration<'ast> {
pub symbol: Symbol<'ast>,
}
#[allow(clippy::large_enum_variant)]
#[derive(PartialEq, Clone)]
pub enum SymbolDefinition<'ast> {
Struct(StructDefinitionNode<'ast>),
Constant(ConstantDefinitionNode<'ast>),
Function(FunctionNode<'ast>),
}
impl<'ast> fmt::Debug for SymbolDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SymbolDefinition::Struct(s) => write!(f, "Struct({:?})", s),
SymbolDefinition::Constant(c) => write!(f, "Constant({:?})", c),
SymbolDefinition::Function(func) => write!(f, "Function({:?})", func),
}
}
}
#[derive(PartialEq, Clone)]
pub enum Symbol<'ast> {
HereType(StructDefinitionNode<'ast>),
HereFunction(FunctionNode<'ast>),
Here(SymbolDefinition<'ast>),
There(SymbolImportNode<'ast>),
Flat(FlatEmbed),
}
@ -62,9 +79,8 @@ pub enum Symbol<'ast> {
impl<'ast> fmt::Debug for Symbol<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Symbol::HereType(t) => write!(f, "HereType({:?})", t),
Symbol::HereFunction(fun) => write!(f, "HereFunction({:?})", fun),
Symbol::There(t) => write!(f, "There({:?})", t),
Symbol::Here(k) => write!(f, "Here({:?})", k),
Symbol::There(i) => write!(f, "There({:?})", i),
Symbol::Flat(flat) => write!(f, "Flat({:?})", flat),
}
}
@ -73,8 +89,15 @@ impl<'ast> fmt::Debug for Symbol<'ast> {
impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.symbol {
Symbol::HereType(ref t) => write!(f, "struct {} {}", self.id, t),
Symbol::HereFunction(ref fun) => write!(f, "def {}{}", self.id, fun),
Symbol::Here(ref kind) => match kind {
SymbolDefinition::Struct(t) => write!(f, "struct {} {}", self.id, t),
SymbolDefinition::Constant(c) => write!(
f,
"const {} {} = {}",
c.value.ty, self.id, c.value.expression
),
SymbolDefinition::Function(func) => write!(f, "def {}{}", self.id, func),
},
Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id),
Symbol::Flat(ref flat_fun) => {
write!(f, "def {}{}:\n\t// hidden", self.id, flat_fun.signature())
@ -146,6 +169,30 @@ impl<'ast> fmt::Display for StructDefinitionField<'ast> {
type StructDefinitionFieldNode<'ast> = Node<StructDefinitionField<'ast>>;
#[derive(Clone, PartialEq)]
pub struct ConstantDefinition<'ast> {
pub ty: UnresolvedTypeNode<'ast>,
pub expression: ExpressionNode<'ast>,
}
pub type ConstantDefinitionNode<'ast> = Node<ConstantDefinition<'ast>>;
impl<'ast> fmt::Display for ConstantDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "const {}({})", self.ty, self.expression)
}
}
impl<'ast> fmt::Debug for ConstantDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"ConstantDefinition({:?}, {:?})",
self.ty, self.expression
)
}
}
/// An import
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolImport<'ast> {

View file

@ -84,6 +84,7 @@ impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
impl<'ast> NodeValue for UnresolvedType<'ast> {}
impl<'ast> NodeValue for StructDefinition<'ast> {}
impl<'ast> NodeValue for StructDefinitionField<'ast> {}
impl<'ast> NodeValue for ConstantDefinition<'ast> {}
impl<'ast> NodeValue for Function<'ast> {}
impl<'ast> NodeValue for Module<'ast> {}
impl<'ast> NodeValue for SymbolImport<'ast> {}

View file

@ -47,6 +47,8 @@ impl ErrorInner {
}
type TypeMap<'ast> = HashMap<OwnedModuleId, HashMap<UserTypeId, DeclarationType<'ast>>>;
type ConstantMap<'ast, T> =
HashMap<OwnedModuleId, HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>>;
/// The global state of the program during semantic checks
#[derive(Debug)]
@ -57,12 +59,15 @@ struct State<'ast, T> {
typed_modules: TypedModules<'ast, T>,
/// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof
types: TypeMap<'ast>,
// The user-defined constants
constants: ConstantMap<'ast, T>,
}
/// A symbol for a given name: either a type or a group of functions. Not both!
#[derive(PartialEq, Hash, Eq, Debug)]
enum SymbolType<'ast> {
Type,
Constant,
Functions(BTreeSet<DeclarationSignature<'ast>>),
}
@ -74,9 +79,9 @@ struct SymbolUnifier<'ast> {
impl<'ast> SymbolUnifier<'ast> {
fn insert_type<S: Into<String>>(&mut self, id: S) -> bool {
let s_type = self.symbols.entry(id.into());
match s_type {
// if anything is already called `id`, we cannot introduce this type
let e = self.symbols.entry(id.into());
match e {
// if anything is already called `id`, we cannot introduce the symbol
Entry::Occupied(..) => false,
// otherwise, we can!
Entry::Vacant(v) => {
@ -86,6 +91,19 @@ impl<'ast> SymbolUnifier<'ast> {
}
}
fn insert_constant<S: Into<String>>(&mut self, id: S) -> bool {
let e = self.symbols.entry(id.into());
match e {
// if anything is already called `id`, we cannot introduce this constant
Entry::Occupied(..) => false,
// otherwise, we can!
Entry::Vacant(v) => {
v.insert(SymbolType::Constant);
true
}
}
}
fn insert_function<S: Into<String>>(
&mut self,
id: S,
@ -96,8 +114,8 @@ impl<'ast> SymbolUnifier<'ast> {
// if anything is already called `id`, it depends what it is
Entry::Occupied(mut o) => {
match o.get_mut() {
// if it's a Type, then we can't introduce a function
SymbolType::Type => false,
// if it's a Type or a Constant, then we can't introduce a function
SymbolType::Type | SymbolType::Constant => false,
// if it's a Function, we can introduce it only if it has a different signature
SymbolType::Functions(signatures) => signatures.insert(signature),
}
@ -117,6 +135,7 @@ impl<'ast, T: Field> State<'ast, T> {
modules,
typed_modules: HashMap::new(),
types: HashMap::new(),
constants: HashMap::new(),
}
}
}
@ -248,6 +267,12 @@ pub struct ScopedVariable<'ast, T> {
level: usize,
}
impl<'ast, T> ScopedVariable<'ast, T> {
fn is_constant(&self) -> bool {
self.level == 0
}
}
/// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent
impl<'ast, T> PartialEq for ScopedVariable<'ast, T> {
fn eq(&self, other: &Self) -> bool {
@ -325,6 +350,50 @@ impl<'ast, T: Field> Checker<'ast, T> {
})
}
fn check_constant_definition(
&mut self,
id: &'ast str,
c: ConstantDefinitionNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
) -> Result<TypedConstant<'ast, T>, ErrorInner> {
let pos = c.pos();
let ty = self.check_type(c.value.ty.clone(), module_id, &types)?;
let checked_expr = self.check_expression(c.value.expression.clone(), module_id, types)?;
match ty {
Type::FieldElement => {
FieldElementExpression::try_from_typed(checked_expr).map(TypedExpression::from)
}
Type::Boolean => {
BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from)
}
Type::Uint(bitwidth) => {
UExpression::try_from_typed(checked_expr, bitwidth).map(TypedExpression::from)
}
Type::Array(ref array_ty) => {
ArrayExpression::try_from_typed(checked_expr, *array_ty.ty.clone())
.map(TypedExpression::from)
}
Type::Struct(ref struct_ty) => {
StructExpression::try_from_typed(checked_expr, struct_ty.clone())
.map(TypedExpression::from)
}
Type::Int => Err(checked_expr), // Integers cannot be assigned
}
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expression `{}` of type `{}` cannot be assigned to constant `{}` of type `{}`",
e,
e.get_type(),
id,
ty
),
})
.map(|e| TypedConstant::new(ty, e))
}
fn check_struct_type_declaration(
&mut self,
id: String,
@ -378,6 +447,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
module_id: &ModuleId,
state: &mut State<'ast, T>,
functions: &mut HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>,
constants: &mut HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>,
symbol_unifier: &mut SymbolUnifier<'ast>,
) -> Result<(), Vec<Error>> {
let mut errors: Vec<Error> = vec![];
@ -386,76 +456,120 @@ impl<'ast, T: Field> Checker<'ast, T> {
let declaration = declaration.value;
match declaration.symbol.clone() {
Symbol::HereType(t) => {
match self.check_struct_type_declaration(
declaration.id.to_string(),
t.clone(),
module_id,
&state.types,
) {
Ok(ty) => {
match symbol_unifier.insert_type(declaration.id) {
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
true => {
// there should be no entry in the map for this type yet
assert!(state
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), ty)
.is_none());
}
};
}
Err(e) => errors.extend(e.into_iter().map(|inner| Error {
inner,
module_id: module_id.to_path_buf(),
})),
}
}
Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) {
Ok(funct) => {
match symbol_unifier.insert_function(declaration.id, funct.signature.clone()) {
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
Symbol::Here(kind) => match kind {
SymbolDefinition::Struct(t) => {
match self.check_struct_type_declaration(
declaration.id.to_string(),
t.clone(),
module_id,
&state.types,
) {
Ok(ty) => {
match symbol_unifier.insert_type(declaration.id) {
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
}
.in_file(module_id),
),
true => {}
};
self.functions.insert(
DeclarationFunctionKey::with_location(
module_id.to_path_buf(),
declaration.id,
)
.signature(funct.signature.clone()),
);
functions.insert(
DeclarationFunctionKey::with_location(
module_id.to_path_buf(),
declaration.id,
)
.signature(funct.signature.clone()),
TypedFunctionSymbol::Here(funct),
);
true => {
// there should be no entry in the map for this type yet
assert!(state
.types
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id.to_string(), ty)
.is_none());
}
};
}
Err(e) => errors.extend(e.into_iter().map(|inner| Error {
inner,
module_id: module_id.to_path_buf(),
})),
}
}
Err(e) => {
errors.extend(e.into_iter().map(|inner| inner.in_file(module_id)));
SymbolDefinition::Constant(c) => {
match self.check_constant_definition(declaration.id, c, module_id, &state.types)
{
Ok(c) => {
match symbol_unifier.insert_constant(declaration.id) {
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
true => {
constants.insert(
declaration.id,
TypedConstantSymbol::Here(c.clone()),
);
self.insert_into_scope(Variable::with_id_and_type(
declaration.id,
c.get_type(),
));
assert!(state
.constants
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id, c.get_type())
.is_none());
}
};
}
Err(e) => {
errors.push(e.in_file(module_id));
}
}
}
SymbolDefinition::Function(f) => {
match self.check_function(f, module_id, &state.types) {
Ok(funct) => {
match symbol_unifier
.insert_function(declaration.id, funct.signature.clone())
{
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
true => {}
};
self.functions.insert(
DeclarationFunctionKey::with_location(
module_id.to_path_buf(),
declaration.id,
)
.signature(funct.signature.clone()),
);
functions.insert(
DeclarationFunctionKey::with_location(
module_id.to_path_buf(),
declaration.id,
)
.signature(funct.signature.clone()),
TypedFunctionSymbol::Here(funct),
);
}
Err(e) => {
errors.extend(e.into_iter().map(|inner| inner.in_file(module_id)));
}
}
}
},
Symbol::There(import) => {
@ -487,8 +601,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
.get(import.symbol_id)
.cloned();
match (function_candidates.len(), type_candidate) {
(0, Some(t)) => {
// find constant definition candidate
let const_candidate = state
.constants
.entry(import.module_id.to_path_buf())
.or_default()
.get(import.symbol_id)
.cloned();
match (function_candidates.len(), type_candidate, const_candidate) {
(0, Some(t), None) => {
// rename the type to the declared symbol
let t = match t {
@ -523,7 +645,32 @@ impl<'ast, T: Field> Checker<'ast, T> {
.or_default()
.insert(declaration.id.to_string(), t);
}
(0, None) => {
(0, None, Some(ty)) => {
match symbol_unifier.insert_constant(declaration.id) {
false => {
errors.push(Error {
module_id: module_id.to_path_buf(),
inner: ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}});
}
true => {
constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id, import.symbol_id));
self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone()));
state
.constants
.entry(module_id.to_path_buf())
.or_default()
.insert(declaration.id, ty);
}
};
}
(0, None, None) => {
errors.push(ErrorInner {
pos: Some(pos),
message: format!(
@ -532,7 +679,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
),
}.in_file(module_id));
}
(_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
(_, Some(_), Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
_ => {
for candidate in function_candidates {
@ -609,6 +756,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
state: &mut State<'ast, T>,
) -> Result<(), Vec<Error>> {
let mut checked_functions = HashMap::new();
let mut checked_constants = HashMap::new();
// check if the module was already removed from the untyped ones
let to_insert = match state.modules.remove(module_id) {
@ -621,7 +769,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
// we need to create an entry in the types map to store types for this module
state.types.entry(module_id.to_path_buf()).or_default();
// we keep track of the introduced symbols to avoid colisions between types and functions
// we keep track of the introduced symbols to avoid collisions between types and functions
let mut symbol_unifier = SymbolUnifier::default();
// we go through symbol declarations and check them
@ -631,12 +779,14 @@ impl<'ast, T: Field> Checker<'ast, T> {
module_id,
state,
&mut checked_functions,
&mut checked_constants,
&mut symbol_unifier,
)?
}
Some(TypedModule {
functions: checked_functions,
constants: checked_constants,
})
}
};
@ -688,7 +838,6 @@ impl<'ast, T: Field> Checker<'ast, T> {
module_id: &ModuleId,
types: &TypeMap<'ast>,
) -> Result<TypedFunction<'ast, T>, Vec<ErrorInner>> {
assert!(self.scope.is_empty());
assert!(self.return_types.is_none());
self.enter_scope();
@ -815,7 +964,6 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
self.return_types = None;
assert!(self.scope.is_empty());
Ok(TypedFunction {
arguments: arguments_checked,
@ -1275,7 +1423,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.map_err(|e| ErrorInner {
pos: Some(pos),
message: format!(
"Expression {} of type {} cannot be assigned to {} of type {}",
"Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`",
e,
e.get_type(),
var.clone(),
@ -1408,10 +1556,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
// check that the assignee is declared
match assignee.value {
Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) {
Some(var) => Ok(TypedAssignee::Identifier(Variable::with_id_and_type(
variable_name,
var.id._type.clone(),
))),
Some(var) => match var.is_constant() {
true => Err(ErrorInner {
pos: Some(assignee.pos()),
message: format!("Assignment to constant variable `{}`", variable_name),
}),
false => Ok(TypedAssignee::Identifier(Variable::with_id_and_type(
variable_name,
var.id._type.clone(),
))),
},
None => Err(ErrorInner {
pos: Some(assignee.pos()),
message: format!("Variable `{}` is undeclared", variable_name),
@ -3094,7 +3248,7 @@ mod tests {
let foo: Module = Module {
symbols: vec![SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock()],
imports: vec![],
@ -3134,6 +3288,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: TypedConstantSymbols::default()
})
);
}
@ -3151,12 +3306,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock(),
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock(),
],
@ -3230,12 +3385,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(f0),
symbol: Symbol::Here(SymbolDefinition::Function(f0)),
}
.mock(),
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(f1),
symbol: Symbol::Here(SymbolDefinition::Function(f1)),
}
.mock(),
],
@ -3268,12 +3423,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock(),
],
@ -3321,12 +3476,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock(),
],
@ -3361,12 +3516,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock(),
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(function1()),
symbol: Symbol::Here(SymbolDefinition::Function(function1())),
}
.mock(),
],
@ -3411,12 +3566,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereType(struct0()),
symbol: Symbol::Here(SymbolDefinition::Struct(struct0())),
}
.mock(),
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereType(struct1()),
symbol: Symbol::Here(SymbolDefinition::Struct(struct1())),
}
.mock(),
],
@ -3448,12 +3603,14 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock(),
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereType(StructDefinition { fields: vec![] }.mock()),
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition { fields: vec![] }.mock(),
)),
}
.mock(),
],
@ -3488,7 +3645,7 @@ mod tests {
let bar = Module::with_symbols(vec![SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock()]);
@ -3503,7 +3660,7 @@ mod tests {
.mock(),
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereType(struct0()),
symbol: Symbol::Here(SymbolDefinition::Struct(struct0())),
}
.mock(),
],
@ -3537,7 +3694,7 @@ mod tests {
let bar = Module::with_symbols(vec![SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(function0()),
symbol: Symbol::Here(SymbolDefinition::Function(function0())),
}
.mock()]);
@ -3545,7 +3702,7 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereType(struct0()),
symbol: Symbol::Here(SymbolDefinition::Struct(struct0())),
}
.mock(),
SymbolDeclaration {
@ -3667,6 +3824,8 @@ mod tests {
let types = HashMap::new();
let mut checker: Checker<Bn128Field> = Checker::new();
checker.enter_scope();
assert_eq!(
checker.check_statement(statement, &*MODULE_ID, &types),
Err(vec![ErrorInner {
@ -3691,12 +3850,13 @@ mod tests {
let mut scope = HashSet::new();
scope.insert(ScopedVariable {
id: Variable::field_element("a"),
level: 0,
level: 1,
});
scope.insert(ScopedVariable {
id: Variable::field_element("b"),
level: 0,
level: 1,
});
let mut checker: Checker<Bn128Field> = new_with_args(scope, 1, HashSet::new());
assert_eq!(
checker.check_statement(statement, &*MODULE_ID, &types),
@ -3762,12 +3922,12 @@ mod tests {
let symbols = vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "bar",
symbol: Symbol::HereFunction(bar),
symbol: Symbol::Here(SymbolDefinition::Function(bar)),
}
.mock(),
];
@ -3877,17 +4037,17 @@ mod tests {
let symbols = vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "bar",
symbol: Symbol::HereFunction(bar),
symbol: Symbol::Here(SymbolDefinition::Function(bar)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main),
symbol: Symbol::Here(SymbolDefinition::Function(main)),
}
.mock(),
];
@ -4265,12 +4425,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main),
symbol: Symbol::Here(SymbolDefinition::Function(main)),
}
.mock(),
],
@ -4352,12 +4512,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main),
symbol: Symbol::Here(SymbolDefinition::Function(main)),
}
.mock(),
],
@ -4468,12 +4628,12 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo),
symbol: Symbol::Here(SymbolDefinition::Function(foo)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main),
symbol: Symbol::Here(SymbolDefinition::Function(main)),
}
.mock(),
],
@ -4761,12 +4921,12 @@ mod tests {
let symbols = vec![
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main1),
symbol: Symbol::Here(SymbolDefinition::Function(main1)),
}
.mock(),
SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main2),
symbol: Symbol::Here(SymbolDefinition::Function(main2)),
}
.mock(),
];
@ -4879,7 +5039,7 @@ mod tests {
imports: vec![],
symbols: vec![SymbolDeclaration {
id: "Foo",
symbol: Symbol::HereType(s.mock()),
symbol: Symbol::Here(SymbolDefinition::Struct(s.mock())),
}
.mock()],
};
@ -5009,7 +5169,7 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "Foo",
symbol: Symbol::HereType(
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
fields: vec![StructDefinitionField {
id: "foo",
@ -5018,12 +5178,12 @@ mod tests {
.mock()],
}
.mock(),
),
)),
}
.mock(),
SymbolDeclaration {
id: "Bar",
symbol: Symbol::HereType(
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
fields: vec![StructDefinitionField {
id: "foo",
@ -5032,7 +5192,7 @@ mod tests {
.mock()],
}
.mock(),
),
)),
}
.mock(),
],
@ -5078,7 +5238,7 @@ mod tests {
imports: vec![],
symbols: vec![SymbolDeclaration {
id: "Bar",
symbol: Symbol::HereType(
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
fields: vec![StructDefinitionField {
id: "foo",
@ -5087,7 +5247,7 @@ mod tests {
.mock()],
}
.mock(),
),
)),
}
.mock()],
};
@ -5111,7 +5271,7 @@ mod tests {
imports: vec![],
symbols: vec![SymbolDeclaration {
id: "Foo",
symbol: Symbol::HereType(
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
fields: vec![StructDefinitionField {
id: "foo",
@ -5120,7 +5280,7 @@ mod tests {
.mock()],
}
.mock(),
),
)),
}
.mock()],
};
@ -5146,7 +5306,7 @@ mod tests {
symbols: vec![
SymbolDeclaration {
id: "Foo",
symbol: Symbol::HereType(
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
fields: vec![StructDefinitionField {
id: "bar",
@ -5155,12 +5315,12 @@ mod tests {
.mock()],
}
.mock(),
),
)),
}
.mock(),
SymbolDeclaration {
id: "Bar",
symbol: Symbol::HereType(
symbol: Symbol::Here(SymbolDefinition::Struct(
StructDefinition {
fields: vec![StructDefinitionField {
id: "foo",
@ -5169,7 +5329,7 @@ mod tests {
.mock()],
}
.mock(),
),
)),
}
.mock(),
],
@ -5638,17 +5798,17 @@ mod tests {
let m = Module::with_symbols(vec![
absy::SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo_field),
symbol: Symbol::Here(SymbolDefinition::Function(foo_field)),
}
.mock(),
absy::SymbolDeclaration {
id: "foo",
symbol: Symbol::HereFunction(foo_u32),
symbol: Symbol::Here(SymbolDefinition::Function(foo_u32)),
}
.mock(),
absy::SymbolDeclaration {
id: "main",
symbol: Symbol::HereFunction(main),
symbol: Symbol::Here(SymbolDefinition::Function(main)),
}
.mock(),
]);
@ -5680,6 +5840,7 @@ mod tests {
let types = HashMap::new();
let mut checker: Checker<Bn128Field> = Checker::new();
checker.enter_scope();
checker
.check_statement(
@ -5713,6 +5874,8 @@ mod tests {
let types = HashMap::new();
let mut checker: Checker<Bn128Field> = Checker::new();
checker.enter_scope();
checker
.check_statement(
Statement::Declaration(
@ -5763,6 +5926,8 @@ mod tests {
let types = HashMap::new();
let mut checker: Checker<Bn128Field> = Checker::new();
checker.enter_scope();
checker
.check_statement(
Statement::Declaration(

View file

@ -0,0 +1,822 @@
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::convert::TryInto;
use zokrates_field::Field;
pub struct ConstantInliner<'ast, T: Field> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
}
impl<'ast, T: Field> ConstantInliner<'ast, T> {
pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self {
ConstantInliner { modules, location }
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone());
inliner.fold_program(p)
}
fn module(&self) -> &TypedModule<'ast, T> {
self.modules.get(&self.location).unwrap()
}
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
prev
}
fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
self.modules
.get(&self.location)
.unwrap()
.constants
.get(id.clone().try_into().unwrap())
.cloned()
.map(|symbol| self.get_canonical_constant(symbol))
}
fn get_canonical_constant(
&mut self,
symbol: TypedConstantSymbol<'ast, T>,
) -> TypedConstant<'ast, T> {
match symbol {
TypedConstantSymbol::There(module_id, id) => {
let location = self.change_location(module_id);
let symbol = self.module().constants.get(id).cloned().unwrap();
let symbol = self.get_canonical_constant(symbol);
let _ = self.change_location(location);
symbol
}
TypedConstantSymbol::Here(tc) => self.fold_constant(tc),
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
.modules
.into_iter()
.map(|(module_id, module)| {
self.change_location(module_id.clone());
(module_id, self.fold_module(module))
})
.collect(),
main: p.main,
}
}
fn fold_constant_symbol(
&mut self,
s: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
let tc = self.get_canonical_constant(s);
TypedConstantSymbol::Here(tc)
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_constant(c).try_into().unwrap(),
None => fold_field_expression(self, e),
},
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_constant(c).try_into().unwrap(),
None => fold_boolean_expression(self, e),
},
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
size: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
e.into_inner()
}
None => fold_uint_expression_inner(self, size, e),
},
e => fold_uint_expression_inner(self, size, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
e.into_inner()
}
None => fold_array_expression_inner(self, ty, e),
},
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
e.into_inner()
}
None => fold_struct_expression_inner(self, ty, e),
},
e => fold_struct_expression_inner(self, ty, e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
DeclarationFunctionKey, DeclarationType, FieldElementExpression, GType, Identifier,
TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement,
};
use zokrates_field::Bn128Field;
#[test]
fn inline_const_field() {
// const field a = 1
//
// def main() -> field:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(1)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_const_boolean() {
// const bool a = true
//
// def main() -> bool:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
Identifier::from(const_id),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::Boolean,
TypedExpression::Boolean(BooleanExpression::Value(true)),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
BooleanExpression::Value(true).into()
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_const_uint() {
// const u32 a = 0x00000001
//
// def main() -> u32:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
Identifier::from(const_id),
)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::Uint(UBitwidth::B32),
UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into(),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_const_field_array() {
// const field[2] a = [2, 2]
//
// def main() -> field:
// return a[0] + a[1]
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
box UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
box UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::Array(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::Select(
box ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
box UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::Select(
box ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
box UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_nested_const_field() {
// const field a = 1
// const field b = a + 1
//
// def main() -> field:
// return b
let const_a_id = "a";
let const_b_id = "b";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: vec![
(
const_a_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
)),
),
(
const_b_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from(
const_a_id,
)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)),
)),
),
]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: vec![
(
const_a_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
)),
),
(
const_b_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Number(Bn128Field::from(1)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)),
)),
),
]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_imported_constant() {
// ---------------------
// module `foo`
// --------------------
// const field FOO = 42
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// from "foo" import FOO
//
// def main() -> field:
// return FOO
let foo_const_id = "FOO";
let foo_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main")
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![],
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
)),
)]
.into_iter()
.collect(),
};
let main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id),
)]
.into_iter()
.collect(),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), main_module),
("foo".into(), foo_module.clone()),
]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(42)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
)),
)]
.into_iter()
.collect(),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), expected_main_module),
("foo".into(), foo_module),
]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
}

View file

@ -5,6 +5,7 @@
//! @date 2018
mod bounds_checker;
mod constant_inliner;
mod flat_propagation;
mod flatten_complex_types;
mod propagation;
@ -28,6 +29,7 @@ use self::variable_read_remover::VariableReadRemover;
use self::variable_write_remover::VariableWriteRemover;
use crate::flat_absy::FlatProg;
use crate::ir::Prog;
use crate::static_analysis::constant_inliner::ConstantInliner;
use crate::typed_absy::{abi::Abi, TypedProgram};
use crate::zir::ZirProgram;
use std::fmt;
@ -73,8 +75,11 @@ impl fmt::Display for Error {
impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
let r = reduce_program(self).map_err(Error::from)?;
// inline user-defined constants
let r = ConstantInliner::inline(self);
// reduce the program to a single function
let r = reduce_program(r).map_err(Error::from)?;
// generate abi
let abi = r.abi();
// propagate

View file

@ -264,6 +264,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
})
.collect::<Result<_, _>>()?,
..m
})
}

View file

@ -547,6 +547,7 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -761,6 +762,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -826,6 +828,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -959,6 +962,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1043,6 +1047,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1185,6 +1190,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1269,6 +1275,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1447,6 +1454,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1558,6 +1566,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1646,6 +1655,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()

View file

@ -65,7 +65,13 @@ mod tests {
);
let mut modules = HashMap::new();
modules.insert("main".into(), TypedModule { functions });
modules.insert(
"main".into(),
TypedModule {
functions,
constants: Default::default(),
},
);
let typed_ast: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),

View file

@ -9,8 +9,19 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_program(self, p)
}
fn fold_module(&mut self, p: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
fold_module(self, p)
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
fold_module(self, m)
}
fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> {
fold_constant(self, c)
}
fn fold_constant_symbol(
&mut self,
s: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
fold_constant_symbol(self, s)
}
fn fold_function_symbol(
@ -190,10 +201,15 @@ pub trait Folder<'ast, T: Field>: Sized {
pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
p: TypedModule<'ast, T>,
m: TypedModule<'ast, T>,
) -> TypedModule<'ast, T> {
TypedModule {
functions: p
constants: m
.constants
.into_iter()
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
.collect(),
functions: m
.functions
.into_iter()
.map(|(key, fun)| (key, f.fold_function_symbol(fun)))
@ -711,6 +727,26 @@ pub fn fold_struct_expression<'ast, T: Field, F: Folder<'ast, T>>(
}
}
pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
c: TypedConstant<'ast, T>,
) -> TypedConstant<'ast, T> {
TypedConstant {
ty: f.fold_type(c.ty),
expression: f.fold_expression(c.expression),
}
}
pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
match s {
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
there => there,
}
}
pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedFunctionSymbol<'ast, T>,

View file

@ -61,6 +61,18 @@ pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T
pub type TypedFunctionSymbols<'ast, T> =
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
pub type ConstantIdentifier<'ast> = &'ast str;
#[derive(Clone, Debug, PartialEq)]
pub enum TypedConstantSymbol<'ast, T> {
Here(TypedConstant<'ast, T>),
There(OwnedTypedModuleId, ConstantIdentifier<'ast>),
}
/// A collection of `TypedConstantSymbol`s
pub type TypedConstantSymbols<'ast, T> =
HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>;
/// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Debug, Clone)]
pub struct TypedProgram<'ast, T> {
@ -135,11 +147,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> {
}
}
/// A typed program as a collection of functions. Types have been resolved during semantic checking.
/// A typed module as a collection of functions. Types have been resolved during semantic checking.
#[derive(PartialEq, Clone)]
pub struct TypedModule<'ast, T> {
/// Functions of the program
/// Functions of the module
pub functions: TypedFunctionSymbols<'ast, T>,
/// Constants defined in module
pub constants: TypedConstantSymbols<'ast, T>,
}
#[derive(Clone, PartialEq)]
@ -182,22 +196,31 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = self
.functions
.constants
.iter()
.map(|(key, symbol)| match symbol {
TypedConstantSymbol::Here(ref tc) => {
format!("const {} {} = {}", tc.ty, key, tc.expression)
}
TypedConstantSymbol::There(ref module_id, ref id) => {
format!("from \"{}\" import {} as {}", module_id.display(), id, key)
}
})
.chain(self.functions.iter().map(|(key, symbol)| match symbol {
TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
TypedFunctionSymbol::There(ref fun_key) => format!(
"import {} from \"{}\" as {} // with signature {}",
fun_key.id,
"from \"{}\" import {} as {} // with signature {}",
fun_key.module.display(),
fun_key.id,
key.id,
key.signature
),
TypedFunctionSymbol::Flat(ref flat_fun) => {
format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature())
}
})
}))
.collect::<Vec<_>>();
write!(f, "{}", res.join("\n"))
}
}
@ -206,8 +229,13 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedModule<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"module(\n\tfunctions:\n\t\t{:?}\n)",
"TypedModule(\n\tFunctions:\n\t\t{:?}\n\tConstants:\n\t\t{:?}\n)",
self.functions
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
.join("\n\t\t"),
self.constants
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
@ -306,6 +334,36 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> {
}
}
#[derive(Clone, PartialEq)]
pub struct TypedConstant<'ast, T> {
ty: Type<'ast, T>,
expression: TypedExpression<'ast, T>,
}
impl<'ast, T> TypedConstant<'ast, T> {
pub fn new(ty: Type<'ast, T>, expression: TypedExpression<'ast, T>) -> Self {
TypedConstant { ty, expression }
}
}
impl<'ast, T: fmt::Debug> fmt::Debug for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "TypedConstant({:?}, {:?})", self.ty, self.expression)
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "const {}({})", self.ty, self.expression)
}
}
impl<'ast, T: Clone> Typed<'ast, T> for TypedConstant<'ast, T> {
fn get_type(&self) -> Type<'ast, T> {
self.ty.clone()
}
}
/// Something we can assign to.
#[derive(Clone, PartialEq, Hash, Eq)]
pub enum TypedAssignee<'ast, T> {
@ -1224,6 +1282,56 @@ impl<'ast, T> TryFrom<TypedExpression<'ast, T>> for StructExpression<'ast, T> {
}
}
impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for FieldElementExpression<'ast, T> {
type Error = ();
fn try_from(
tc: TypedConstant<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
tc.expression.try_into()
}
}
impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for BooleanExpression<'ast, T> {
type Error = ();
fn try_from(tc: TypedConstant<'ast, T>) -> Result<BooleanExpression<'ast, T>, Self::Error> {
tc.expression.try_into()
}
}
impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for UExpression<'ast, T> {
type Error = ();
fn try_from(tc: TypedConstant<'ast, T>) -> Result<UExpression<'ast, T>, Self::Error> {
tc.expression.try_into()
}
}
impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for ArrayExpression<'ast, T> {
type Error = ();
fn try_from(tc: TypedConstant<'ast, T>) -> Result<ArrayExpression<'ast, T>, Self::Error> {
tc.expression.try_into()
}
}
impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for StructExpression<'ast, T> {
type Error = ();
fn try_from(tc: TypedConstant<'ast, T>) -> Result<StructExpression<'ast, T>, Self::Error> {
tc.expression.try_into()
}
}
impl<'ast, T> TryFrom<TypedConstant<'ast, T>> for IntExpression<'ast, T> {
type Error = ();
fn try_from(tc: TypedConstant<'ast, T>) -> Result<IntExpression<'ast, T>, Self::Error> {
tc.expression.try_into()
}
}
impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {

View file

@ -16,9 +16,23 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_module(
&mut self,
p: TypedModule<'ast, T>,
m: TypedModule<'ast, T>,
) -> Result<TypedModule<'ast, T>, Self::Error> {
fold_module(self, p)
fold_module(self, m)
}
fn fold_constant(
&mut self,
c: TypedConstant<'ast, T>,
) -> Result<TypedConstant<'ast, T>, Self::Error> {
fold_constant(self, c)
}
fn fold_constant_symbol(
&mut self,
s: TypedConstantSymbol<'ast, T>,
) -> Result<TypedConstantSymbol<'ast, T>, Self::Error> {
fold_constant_symbol(self, s)
}
fn fold_function_symbol(
@ -793,6 +807,26 @@ pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
})
}
pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
c: TypedConstant<'ast, T>,
) -> Result<TypedConstant<'ast, T>, F::Error> {
Ok(TypedConstant {
ty: f.fold_type(c.ty)?,
expression: f.fold_expression(c.expression)?,
})
}
pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: TypedConstantSymbol<'ast, T>,
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
match s {
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
there => Ok(there),
}
}
pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: TypedFunctionSymbol<'ast, T>,
@ -805,10 +839,15 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
p: TypedModule<'ast, T>,
m: TypedModule<'ast, T>,
) -> Result<TypedModule<'ast, T>, F::Error> {
Ok(TypedModule {
functions: p
constants: m
.constants
.into_iter()
.map(|(key, tc)| f.fold_constant_symbol(tc).map(|tc| (key, tc)))
.collect::<Result<_, _>>()?,
functions: m
.functions
.into_iter()
.map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f)))

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/array.zok",
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1", "2"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const field[2] ARRAY = [1, 2]
def main() -> field[2]:
return ARRAY

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/bool.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const bool BOOLEAN = true
def main() -> bool:
return BOOLEAN

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/field.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const field ONE = 1
def main() -> field:
return ONE

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/nested.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["8"]
}
}
}
]
}

View file

@ -0,0 +1,6 @@
const field A = 2
const field B = 2
const field[2] ARRAY = [A * 2, B * 2]
def main() -> field:
return ARRAY[0] + ARRAY[1]

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/struct.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["4"]
}
}
}
]
}

View file

@ -0,0 +1,9 @@
struct Foo {
field a
field b
}
const Foo FOO = Foo { a: 2, b: 2 }
def main() -> field:
return FOO.a + FOO.b

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/uint.zok",
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["1"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
const u32 ONE = 0x00000001
def main() -> u32:
return ONE

View file

@ -37,7 +37,7 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac
var ZoKratesHighlightRules = function () {
var keywords = (
"assert|as|bool|byte|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64"
"assert|as|bool|byte|const|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64"
);
var keywordMapper = this.createKeywordMapper({

View file

@ -1,27 +1,36 @@
{
"name": "zokrates",
"displayName": "zokrates",
"description": "Syntax highlighting for the ZoKrates language",
"publisher": "zokrates",
"repository": "https://github.com/ZoKrates/ZoKrates",
"version": "0.0.1",
"engines": {
"vscode": "^1.53.0"
},
"categories": [
"Programming Languages"
"name": "zokrates",
"displayName": "zokrates",
"description": "Syntax highlighting for the ZoKrates language",
"publisher": "zokrates",
"repository": "https://github.com/ZoKrates/ZoKrates",
"version": "0.0.1",
"engines": {
"vscode": "^1.53.0"
},
"categories": [
"Programming Languages"
],
"contributes": {
"languages": [
{
"id": "zokrates",
"aliases": [
"ZoKrates",
"zokrates"
],
"extensions": [
".zok"
],
"configuration": "./language-configuration.json"
}
],
"contributes": {
"languages": [{
"id": "zokrates",
"aliases": ["ZoKrates", "zokrates"],
"extensions": [".zok"],
"configuration": "./language-configuration.json"
}],
"grammars": [{
"language": "zokrates",
"scopeName": "source.zok",
"path": "./syntaxes/zokrates.tmLanguage.json"
}]
}
"grammars": [
{
"language": "zokrates",
"scopeName": "source.zok",
"path": "./syntaxes/zokrates.tmLanguage.json"
}
]
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,349 @@
$schema: 'https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json'
name: ZoKrates
fileTypes:
- zok
scopeName: source.zok
patterns:
-
comment: attributes
name: meta.attribute.zokrates
begin: '(#)(\!?)(\[)'
beginCaptures:
'1':
name: punctuation.definition.attribute.zokrates
'2':
name: keyword.operator.attribute.inner.zokrates
'3':
name: punctuation.brackets.attribute.zokrates
end: '\]'
endCaptures:
'0':
name: punctuation.brackets.attribute.zokrates
patterns:
-
include: '#block-comments'
-
include: '#comments'
-
include: '#keywords'
-
include: '#punctuation'
-
include: '#strings'
-
include: '#types'
-
include: '#block-comments'
-
include: '#comments'
-
include: '#constants'
-
include: '#functions'
-
include: '#types'
-
include: '#keywords'
-
include: '#punctuation'
-
include: '#strings'
-
include: '#variables'
repository:
comments:
patterns:
-
comment: 'line comments'
name: comment.line.double-slash.zokrates
match: '\s*//.*'
block-comments:
patterns:
-
comment: 'empty block comments'
name: comment.block.zokrates
match: '/\*\*/'
-
comment: 'block comments'
name: comment.block.zokrates
begin: '/\*(?!\*)'
end: '\*/'
patterns:
- {include: '#block-comments'}
constants:
patterns:
-
comment: 'ALL CAPS constants'
name: constant.other.caps.zokrates
match: '\b[A-Z]{2}[A-Z0-9_]*\b'
-
comment: 'decimal integers and floats'
name: constant.numeric.decimal.zokrates
match: '\b\d[\d_]*(?:u128|u16|u32|u64|u8|f)?\b'
-
comment: 'hexadecimal integers'
name: constant.numeric.hex.zokrates
match: '\b0x[\da-fA-F_]+\b'
-
comment: booleans
name: constant.language.bool.zokrates
match: \b(true|false)\b
imports:
patterns:
-
comment: 'explicit import statement'
name: meta.import.explicit.zokrates
match: '\b(from)\s+(\".*\")(import)\s+([A-Za-z0-9_]+)\s+((as)\s+[A-Za-z0-9_]+)?\b'
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#punctuation'}
- {include: '#types'}
- {include: '#strings'}
-
comment: 'main import statement'
name: meta.import.explicit.zokrates
match: '\b(import)\s+(\".*\")\s+((as)\s+[A-Za-z0-9_]+)?\b'
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#punctuation'}
- {include: '#types'}
- {include: '#strings'}
constant-definitions:
patterns:
-
comment: 'constant definition'
name: meta.constant.definition.zokrates
match: '\b(const)\s+([A-Za-z0-9_]+)\s+([A-Za-z0-9_]+)\s+=\s+(?:.+)\b'
captures:
'1': {name: keyword.other.const.zokrates}
'2': {name: entity.name.type.zokrates}
'3': {name: entity.name.constant.zokrates}
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#constants'}
- {include: '#punctuation'}
- {include: '#types'}
- {include: '#variables'}
functions:
patterns:
-
comment: 'function definition'
name: meta.function.definition.zokrates
begin: '\b(def)\s+([A-Za-z0-9_]+)((\()|(<))'
beginCaptures:
'1': {name: keyword.other.def.zokrates}
'2': {name: entity.name.function.zokrates}
'4': {name: punctuation.brackets.round.zokrates}
'5': {name: punctuation.brackets.angle.zokrates}
end: '\:|;'
endCaptures:
'0': {name: keyword.punctuation.colon.zokrates}
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#constants'}
- {include: '#functions'}
- {include: '#punctuation'}
- {include: '#strings'}
- {include: '#types'}
- {include: '#variables'}
-
comment: 'function/method calls, chaining'
name: meta.function.call.zokrates
begin: '([A-Za-z0-9_]+)(\()'
beginCaptures:
'1': {name: entity.name.function.zokrates}
'2': {name: punctuation.brackets.round.zokrates}
end: \)
endCaptures:
'0': {name: punctuation.brackets.round.zokrates}
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#constants'}
- {include: '#functions'}
- {include: '#punctuation'}
- {include: '#strings'}
- {include: '#types'}
- {include: '#variables'}
-
comment: 'function/method calls with turbofish'
name: meta.function.call.zokrates
begin: '([A-Za-z0-9_]+)(?=::<.*>\()'
beginCaptures:
'1': {name: entity.name.function.zokrates}
end: \)
endCaptures:
'0': {name: punctuation.brackets.round.zokrates}
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#constants'}
- {include: '#functions'}
- {include: '#punctuation'}
- {include: '#strings'}
- {include: '#types'}
- {include: '#variables'}
keywords:
patterns:
-
comment: 'argument visibility'
name: keyword.visibility.zokrates
match: \b(public|private)\b
-
comment: 'control flow keywords'
name: keyword.control.zokrates
match: \b(do|else|for|do|endfor|if|then|fi|return|assert)\b
-
comment: 'storage keywords'
name: storage.type.zokrates
match: \b(struct)\b
-
comment: const
name: keyword.other.const.zokrates
match: \bconst\b
-
comment: def
name: keyword.other.def.zokrates
match: \bdef\b
-
comment: 'import keywords'
name: keyword.other.import.zokrates
match: \b(import|from|as)\b
-
comment: 'logical operators'
name: keyword.operator.logical.zokrates
match: '(\^|\||\|\||&|&&|<<|>>|!)(?!=)'
-
comment: 'single equal'
name: keyword.operator.assignment.equal.zokrates
match: '(?<![<>])=(?!=|>)'
-
comment: 'comparison operators'
name: keyword.operator.comparison.zokrates
match: '(=(=)?(?!>)|!=|<=|(?<!=)>=)'
-
comment: 'math operators'
name: keyword.operator.math.zokrates
match: '(([+%]|(\*(?!\w)))(?!=))|(-(?!>))|(/(?!/))'
-
comment: 'less than, greater than (special case)'
match: '(?:\b|(?:(\))|(\])|(\})))[ \t]+([<>])[ \t]+(?:\b|(?:(\()|(\[)|(\{)))'
captures:
'1': {name: punctuation.brackets.round.zokrates}
'2': {name: punctuation.brackets.square.zokrates}
'3': {name: punctuation.brackets.curly.zokrates}
'4': {name: keyword.operator.comparison.zokrates}
'5': {name: punctuation.brackets.round.zokrates}
'6': {name: punctuation.brackets.square.zokrates}
'7': {name: punctuation.brackets.curly.zokrates}
-
comment: 'dot access'
name: keyword.operator.access.dot.zokrates
match: '\.(?!\.)'
-
comment: 'ranges, range patterns'
name: keyword.operator.range.zokrates
match: '\.{2}(=|\.)?'
-
comment: colon
name: keyword.operator.colon.zokrates
match: ':(?!:)'
-
comment: 'dashrocket, skinny arrow'
name: keyword.operator.arrow.skinny.zokrates
match: '->'
types:
patterns:
-
comment: 'numeric types'
match: '(?<![A-Za-z])(u128|u16|u32|u64|u8|field)\b'
captures:
'1': {name: entity.name.type.numeric.zokrates}
-
comment: 'parameterized types'
begin: '\b([A-Z][A-Za-z0-9]*)(<)'
beginCaptures:
'1': {name: entity.name.type.zokrates}
'2': {name: punctuation.brackets.angle.zokrates}
end: '>'
endCaptures:
'0': {name: punctuation.brackets.angle.zokrates}
patterns:
- {include: '#block-comments'}
- {include: '#comments'}
- {include: '#keywords'}
- {include: '#punctuation'}
- {include: '#types'}
- {include: '#variables'}
-
comment: 'primitive types'
name: entity.name.type.primitive.zokrates
match: \b(bool)\b
-
comment: 'struct declarations'
match: '\b(struct)\s+([A-Z][A-Za-z0-9]*)\b'
captures:
'1': {name: storage.type.zokrates}
'2': {name: entity.name.type.struct.zokrates}
-
comment: types
name: entity.name.type.zokrates
match: '\b[A-Z][A-Za-z0-9]*\b(?!!)'
punctuation:
patterns:
-
comment: comma
name: punctuation.comma.zokrates
match: ','
-
comment: 'parentheses, round brackets'
name: punctuation.brackets.round.zokrates
match: '[()]'
-
comment: 'square brackets'
name: punctuation.brackets.square.zokrates
match: '[\[\]]'
-
comment: 'angle brackets'
name: punctuation.brackets.angle.zokrates
match: '(?<!=)[<>]'
strings:
patterns:
-
comment: 'double-quoted strings and byte strings'
name: string.quoted.double.zokrates
begin: '(b?)(")'
beginCaptures:
'1': {name: string.quoted.byte.raw.zokrates}
'2': {name: punctuation.definition.string.zokrates}
end: '"'
endCaptures:
'0': {name: punctuation.definition.string.zokrates}
-
comment: 'double-quoted raw strings and raw byte strings'
name: string.quoted.double.zokrates
begin: '(b?r)(#*)(")'
beginCaptures:
'1': {name: string.quoted.byte.raw.zokrates}
'2': {name: punctuation.definition.string.raw.zokrates}
'3': {name: punctuation.definition.string.zokrates}
end: '(")(\2)'
endCaptures:
'1': {name: punctuation.definition.string.zokrates}
'2': {name: punctuation.definition.string.raw.zokrates}
variables:
patterns:
-
comment: variables
name: variable.other.zokrates
match: '\b(?<!(?<!\.)\.)[a-z0-9_]+\b'

View file

@ -1,5 +1,5 @@
file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ function_definition* ~ EOI }
file = { SOI ~ NEWLINE* ~ pragma? ~ NEWLINE* ~ import_directive* ~ NEWLINE* ~ ty_struct_definition* ~ NEWLINE* ~ const_definition* ~ NEWLINE* ~ function_definition* ~ EOI }
pragma = { "#pragma" ~ "curve" ~ curve }
curve = @{ (ASCII_ALPHANUMERIC | "_") * }
@ -11,6 +11,7 @@ import_source = @{(!"\"" ~ ANY)*}
import_symbol = { identifier ~ ("as" ~ identifier)? }
import_symbol_list = _{ import_symbol ~ ("," ~ import_symbol)* }
function_definition = {"def" ~ identifier ~ constant_generics_declaration? ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* }
const_definition = {"const" ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*}
return_types = _{ ( "->" ~ ( "(" ~ type_list ~ ")" | ty ))? }
constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" }
constant_generics_list = _{ identifier ~ ("," ~ identifier)* }
@ -159,6 +160,6 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) }
// the ordering of reserved keywords matters: if "as" is before "assert", then "assert" gets parsed as (as)(sert) and incorrectly
// accepted
keyword = @{"assert"|"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"|
keyword = @{"assert"|"as"|"bool"|"byte"|"const"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|"from"|
"in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"|"u64"
}

View file

@ -10,14 +10,14 @@ extern crate lazy_static;
pub use ast::{
Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement,
Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator,
CallAccess, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, DecimalSuffix,
DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, FromExpression, Function,
HexLiteralExpression, HexNumberExpression, IdentifierExpression, ImportDirective, ImportSource,
ImportSymbol, InlineArrayExpression, InlineStructExpression, InlineStructMember,
IterationStatement, LiteralExpression, OptionallyTypedAssignee, Parameter, PostfixExpression,
Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement,
StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression,
UnaryOperator, Underscore, Visibility,
CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber,
DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File,
FromExpression, Function, HexLiteralExpression, HexNumberExpression, IdentifierExpression,
ImportDirective, ImportSource, ImportSymbol, InlineArrayExpression, InlineStructExpression,
InlineStructMember, IterationStatement, LiteralExpression, OptionallyTypedAssignee, Parameter,
PostfixExpression, Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression,
Statement, StructDefinition, StructField, TernaryExpression, ToExpression, Type,
UnaryExpression, UnaryOperator, Underscore, Visibility,
};
mod ast {
@ -111,6 +111,7 @@ mod ast {
pub pragma: Option<Pragma<'ast>>,
pub imports: Vec<ImportDirective<'ast>>,
pub structs: Vec<StructDefinition<'ast>>,
pub constants: Vec<ConstantDefinition<'ast>>,
pub functions: Vec<Function<'ast>>,
pub eoi: EOI,
#[pest_ast(outer())]
@ -164,6 +165,16 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::const_definition))]
pub struct ConstantDefinition<'ast> {
pub ty: Type<'ast>,
pub id: IdentifierExpression<'ast>,
pub expression: Expression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::import_directive))]
pub enum ImportDirective<'ast> {
@ -1047,6 +1058,7 @@ mod tests {
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
@ -1107,6 +1119,7 @@ mod tests {
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
@ -1191,6 +1204,7 @@ mod tests {
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
@ -1259,6 +1273,7 @@ mod tests {
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {
@ -1299,6 +1314,7 @@ mod tests {
Ok(File {
pragma: None,
structs: vec![],
constants: vec![],
functions: vec![Function {
generics: vec![],
id: IdentifierExpression {