1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

switch to paths, add test

This commit is contained in:
schaeff 2020-02-13 14:04:29 +01:00
commit 43075af3d6
24 changed files with 723 additions and 188 deletions

View file

@ -42,9 +42,9 @@ jobs:
- restore_cache:
keys:
- v4-cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }}
- run:
name: Check format
command: rustup component add rustfmt; cargo fmt --all -- --check
# - run:
# name: Check format
# command: rustup component add rustfmt; cargo fmt --all -- --check
- run:
name: Install libsnark prerequisites
command: ./scripts/install_libsnark_prerequisites.sh

2
PROJECT/dep/dep/foo.zok Normal file
View file

@ -0,0 +1,2 @@
def foo() -> (field):
return 1

4
PROJECT/dep/foo.zok Normal file
View file

@ -0,0 +1,4 @@
from "./dep/foo.zok" import foo as bar
def foo() -> (field):
return 2 + bar()

4
PROJECT/main.zok Normal file
View file

@ -0,0 +1,4 @@
from "./dep/foo.zok" import foo
def main() -> (field):
return foo()

View file

@ -271,8 +271,6 @@ fn cli() -> Result<(), String> {
let path = PathBuf::from(sub_matches.value_of("input").unwrap());
let location = path.to_path_buf().into_os_string().into_string().unwrap();
let light = sub_matches.occurrences_of("light") > 0;
let bin_output_path = Path::new(sub_matches.value_of("output").unwrap());
@ -290,7 +288,7 @@ fn cli() -> Result<(), String> {
.map_err(|why| format!("couldn't open input file {}: {}", path.display(), why))?;
let artifacts: CompilationArtifacts<FieldPrime> =
compile(source, location, Some(&fs_resolve))
compile(source, path, Some(&fs_resolve))
.map_err(|e| format!("Compilation failed:\n\n {}", e))?;
let program_flattened = artifacts.prog();

View file

@ -25,13 +25,13 @@ impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportNode<'ast> {
match import {
pest::ImportDirective::Main(import) => {
imports::Import::new(None, import.source.span.as_str())
imports::Import::new(None, std::path::Path::new(import.source.span.as_str()))
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span)
}
pest::ImportDirective::From(import) => imports::Import::new(
Some(import.symbol.span.as_str()),
import.source.span.as_str(),
std::path::Path::new(import.source.span.as_str()),
)
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span),
@ -283,16 +283,6 @@ impl<'ast, T: Field> From<pest::IterationStatement<'ast>> for absy::StatementNod
.flat_map(|s| statements_from_statement(s))
.collect();
let from = match from.value {
absy::Expression::FieldConstant(n) => n,
e => unimplemented!("For loop bounds should be constants, found {}", e),
};
let to = match to.value {
absy::Expression::FieldConstant(n) => n,
e => unimplemented!("For loop bounds should be constants, found {}", e),
};
let var = absy::Variable::new(index, ty).span(statement.index.span);
absy::Statement::For(var, from, to, statements).span(statement.span)

View file

@ -16,6 +16,7 @@ pub use crate::absy::parameter::{Parameter, ParameterNode};
use crate::absy::types::{FunctionIdentifier, UnresolvedSignature, UnresolvedType, UserTypeId};
pub use crate::absy::variable::{Variable, VariableNode};
use embed::FlatEmbed;
use std::path::PathBuf;
use crate::imports::ImportNode;
use std::fmt;
@ -27,7 +28,7 @@ use std::collections::HashMap;
pub type Identifier<'ast> = &'ast str;
/// The identifier of a `Module`, typically a path or uri
pub type ModuleId = String;
pub type ModuleId = PathBuf;
/// A collection of `Module`s
pub type Modules<'ast, T> = HashMap<ModuleId, Module<'ast, T>>;
@ -160,7 +161,12 @@ impl<'ast> SymbolImport<'ast> {
impl<'ast> fmt::Display for SymbolImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} from {}", self.symbol_id, self.module_id)
write!(
f,
"{} from {}",
self.symbol_id,
self.module_id.display().to_string()
)
}
}
@ -282,7 +288,12 @@ pub enum Statement<'ast, T: Field> {
Declaration(VariableNode<'ast>),
Definition(AssigneeNode<'ast, T>, ExpressionNode<'ast, T>),
Condition(ExpressionNode<'ast, T>, ExpressionNode<'ast, T>),
For(VariableNode<'ast>, T, T, Vec<StatementNode<'ast, T>>),
For(
VariableNode<'ast>,
ExpressionNode<'ast, T>,
ExpressionNode<'ast, T>,
Vec<StatementNode<'ast, T>>,
),
MultipleDefinition(Vec<AssigneeNode<'ast, T>>, ExpressionNode<'ast, T>),
}

View file

@ -13,6 +13,7 @@ use static_analysis::Analyse;
use std::collections::HashMap;
use std::fmt;
use std::io;
use std::path::PathBuf;
use typed_absy::abi::Abi;
use typed_arena::Arena;
use zokrates_field::field::Field;
@ -35,7 +36,7 @@ impl<T: Field> CompilationArtifacts<T> {
}
#[derive(Debug)]
pub struct CompileErrors(Vec<CompileError>);
pub struct CompileErrors(pub Vec<CompileError>);
impl From<CompileError> for CompileErrors {
fn from(e: CompileError) -> CompileErrors {
@ -66,27 +67,27 @@ pub enum CompileErrorInner {
}
impl CompileErrorInner {
pub fn with_context(self, context: &String) -> CompileError {
pub fn in_file(self, context: &PathBuf) -> CompileError {
CompileError {
value: self,
context: context.clone(),
file: context.clone(),
}
}
}
#[derive(Debug)]
pub struct CompileError {
context: String,
file: PathBuf,
value: CompileErrorInner,
}
impl CompileErrors {
pub fn with_context(self, context: String) -> Self {
pub fn with_context(self, file: PathBuf) -> Self {
CompileErrors(
self.0
.into_iter()
.map(|e| CompileError {
context: context.clone(),
file: file.clone(),
..e
})
.collect(),
@ -96,7 +97,17 @@ impl CompileErrors {
impl fmt::Display for CompileError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}:{}", self.context, self.value)
write!(
f,
"{}:{}",
self.file
.canonicalize()
.unwrap()
.strip_prefix(std::env::current_dir().unwrap())
.unwrap()
.display(),
self.value
)
}
}
@ -122,7 +133,7 @@ impl From<semantics::Error> for CompileError {
fn from(error: semantics::Error) -> Self {
CompileError {
value: CompileErrorInner::SemanticError(error.inner),
context: error.module_id
file: error.module_id,
}
}
}
@ -139,14 +150,15 @@ impl fmt::Display for CompileErrorInner {
}
}
pub type Resolve<'a, E> = &'a dyn Fn(String, String) -> Result<(String, String), E>;
pub type Resolve<'a, E> = &'a dyn Fn(PathBuf, PathBuf) -> Result<(String, PathBuf), E>;
type FilePath = PathBuf;
pub fn compile<T: Field, E: Into<imports::Error>>(
source: String,
location: String,
location: FilePath,
resolve_option: Option<Resolve<E>>,
) -> Result<CompilationArtifacts<T>, CompileErrors> {
let arena = Arena::new();
let source = arena.alloc(source);
@ -154,12 +166,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
// check semantics
let typed_ast = Checker::check(compiled).map_err(|errors| {
CompileErrors(
errors
.into_iter()
.map(|e| CompileError::from(e))
.collect(),
)
CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect())
})?;
let abi = typed_ast.abi();
@ -187,7 +194,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
source: &'ast str,
location: String,
location: FilePath,
resolve_option: Option<Resolve<E>>,
arena: &'ast Arena<String>,
) -> Result<Program<'ast, T>, CompileErrors> {
@ -211,13 +218,13 @@ pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
pub fn compile_module<'ast, T: Field, E: Into<imports::Error>>(
source: &'ast str,
location: String,
location: FilePath,
resolve_option: Option<Resolve<E>>,
modules: &mut HashMap<ModuleId, Module<'ast, T>>,
arena: &'ast Arena<String>,
) -> Result<Module<'ast, T>, CompileErrors> {
let ast = pest::generate_ast(&source)
.map_err(|e| CompileErrors::from(CompileErrorInner::from(e).with_context(&location)))?;
.map_err(|e| CompileErrors::from(CompileErrorInner::from(e).in_file(&location)))?;
let module_without_imports: Module<T> = Module::from(ast);
Importer::new().apply_imports(

View file

@ -11,7 +11,7 @@ use zokrates_field::field::Field;
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
/// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum FlatEmbed {
Sha256Round,
Unpack,

View file

@ -12,6 +12,7 @@ use crate::parser::Position;
use std::collections::HashMap;
use std::fmt;
use std::io;
use std::path::{Path, PathBuf};
use typed_arena::Arena;
use zokrates_field::field::Field;
@ -54,9 +55,11 @@ impl From<io::Error> for Error {
}
}
type ImportPath<'ast> = &'ast Path;
#[derive(PartialEq, Clone)]
pub struct Import<'ast> {
source: Identifier<'ast>,
source: ImportPath<'ast>,
symbol: Option<Identifier<'ast>>,
alias: Option<Identifier<'ast>>,
}
@ -64,7 +67,7 @@ pub struct Import<'ast> {
pub type ImportNode<'ast> = Node<Import<'ast>>;
impl<'ast> Import<'ast> {
pub fn new(symbol: Option<Identifier<'ast>>, source: Identifier<'ast>) -> Import<'ast> {
pub fn new(symbol: Option<Identifier<'ast>>, source: ImportPath<'ast>) -> Import<'ast> {
Import {
symbol,
source,
@ -78,7 +81,7 @@ impl<'ast> Import<'ast> {
pub fn new_with_alias(
symbol: Option<Identifier<'ast>>,
source: Identifier<'ast>,
source: ImportPath<'ast>,
alias: Identifier<'ast>,
) -> Import<'ast> {
Import {
@ -93,7 +96,7 @@ impl<'ast> Import<'ast> {
self
}
pub fn get_source(&self) -> &Identifier<'ast> {
pub fn get_source(&self) -> &ImportPath<'ast> {
&self.source
}
}
@ -101,8 +104,8 @@ impl<'ast> Import<'ast> {
impl<'ast> fmt::Display for Import<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.alias {
Some(ref alias) => write!(f, "import {} as {}", self.source, alias),
None => write!(f, "import {}", self.source),
Some(ref alias) => write!(f, "import {} as {}", self.source.display(), alias),
None => write!(f, "import {}", self.source.display()),
}
}
}
@ -110,8 +113,13 @@ impl<'ast> fmt::Display for Import<'ast> {
impl<'ast> fmt::Debug for Import<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.alias {
Some(ref alias) => write!(f, "import(source: {}, alias: {})", self.source, alias),
None => write!(f, "import(source: {})", self.source),
Some(ref alias) => write!(
f,
"import(source: {}, alias: {})",
self.source.display(),
alias
),
None => write!(f, "import(source: {})", self.source.display()),
}
}
}
@ -126,7 +134,7 @@ impl Importer {
pub fn apply_imports<'ast, T: Field, E: Into<Error>>(
&self,
destination: Module<'ast, T>,
location: String,
location: PathBuf,
resolve_option: Option<Resolve<E>>,
modules: &mut HashMap<ModuleId, Module<'ast, T>>,
arena: &'ast Arena<String>,
@ -139,7 +147,7 @@ impl Importer {
let alias = import.alias;
// handle the case of special bellman and packing imports
if import.source.starts_with("EMBED") {
match import.source.as_ref() {
match import.source.to_str().unwrap() {
"EMBED/sha256round" => {
let alias = alias.unwrap_or("sha256round");
@ -166,16 +174,16 @@ impl Importer {
return Err(CompileErrorInner::ImportError(
Error::new(format!("Embed {} not found. Options are \"EMBED/sha256round\", \"EMBED/unpack\"", s)).with_pos(Some(pos)),
)
.with_context(&location)
.in_file(&location)
.into());
}
}
} else {
// to resolve imports, we need a resolver
let folder = std::path::PathBuf::from(location.clone()).parent().unwrap().to_path_buf().into_os_string().into_string().unwrap();
let folder = location.clone().parent().unwrap();
match resolve_option {
Some(resolve) => match resolve(folder, import.source.to_string()) {
Ok((source, location)) => {
Some(resolve) => match resolve(location.clone(), import.source.to_path_buf()) {
Ok((source, new_location)) => {
let source = arena.alloc(source);
// generate an alias from the imported path if none was given explicitely
@ -185,19 +193,23 @@ impl Importer {
.ok_or(CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
import.source
import.source.display()
)))
.with_context(&location),
.in_file(&location),
))?
.to_str()
.unwrap(),
);
let compiled =
compile_module(source, location, resolve_option, modules, &arena)
.map_err(|e| e.with_context(import.source.to_string()))?;
let compiled = compile_module(
source,
new_location.clone(),
resolve_option,
modules,
&arena,
)?;
modules.insert(import.source.to_string(), compiled);
assert!(modules.insert(new_location.clone(), compiled).is_none());
symbols.push(
SymbolDeclaration {
@ -205,7 +217,7 @@ impl Importer {
symbol: Symbol::There(
SymbolImport::with_id_in_module(
import.symbol.unwrap_or("main"),
import.source.clone(),
new_location.display().to_string(),
)
.start_end(pos.0, pos.1),
),
@ -217,7 +229,7 @@ impl Importer {
return Err(CompileErrorInner::ImportError(
err.into().with_pos(Some(pos)),
)
.with_context(&location)
.in_file(&location)
.into());
}
},
@ -225,7 +237,7 @@ impl Importer {
return Err(CompileErrorInner::from(Error::new(
"Can't resolve import without a resolver",
))
.with_context(&location)
.in_file(&location)
.into());
}
}

View file

@ -10,6 +10,7 @@ use crate::typed_absy::*;
use crate::typed_absy::{Parameter, Variable};
use std::collections::{hash_map::Entry, BTreeSet, HashMap, HashSet};
use std::fmt;
use std::path::PathBuf;
use zokrates_field::field::Field;
use crate::parser::Position;
@ -29,14 +30,14 @@ pub struct ErrorInner {
#[derive(PartialEq, Debug)]
pub struct Error {
pub inner: ErrorInner,
pub module_id: ModuleId
pub module_id: PathBuf,
}
impl ErrorInner {
fn with_module_id(self, id: &ModuleId) -> Error {
fn in_file(self, id: &ModuleId) -> Error {
Error {
inner: self,
module_id: id.clone()
module_id: id.clone(),
}
}
}
@ -254,11 +255,21 @@ impl<'ast> Checker<'ast> {
let main_id = program.main.clone();
Checker::check_single_main(state.typed_modules.get(&program.main).unwrap())
.map_err(|inner| vec![Error { inner, module_id: main_id }])?;
Checker::check_single_main(
state
.typed_modules
.get(&program.main.display().to_string())
.unwrap(),
)
.map_err(|inner| {
vec![Error {
inner,
module_id: main_id,
}]
})?;
Ok(TypedProgram {
main: program.main,
main: program.main.display().to_string(),
modules: state.typed_modules,
})
}
@ -325,13 +336,16 @@ impl<'ast> Checker<'ast> {
match self.check_struct_type_declaration(t.clone(), module_id, &state.types) {
Ok(ty) => {
match symbol_unifier.insert_type(declaration.id) {
false => errors.push(ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
)}.with_module_id(module_id)),
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
true => {}
};
state
@ -340,17 +354,25 @@ impl<'ast> Checker<'ast> {
.or_default()
.insert(declaration.id.to_string(), ty);
}
Err(e) => errors.extend(e.into_iter().map(|inner| Error {inner, module_id: module_id.clone() })),
Err(e) => errors.extend(e.into_iter().map(|inner| Error {
inner,
module_id: module_id.clone(),
})),
}
}
Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) {
Ok(funct) => {
match symbol_unifier.insert_function(declaration.id, funct.signature.clone()) {
false => errors.push(
ErrorInner {
pos: Some(pos),
message: format!("{} conflicts with another symbol", declaration.id,)
}.with_module_id(module_id)),
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
),
true => {}
};
@ -365,7 +387,7 @@ impl<'ast> Checker<'ast> {
);
}
Err(e) => {
errors.extend(e.into_iter().map(|inner| inner.with_module_id(module_id)));
errors.extend(e.into_iter().map(|inner| inner.in_file(module_id)));
}
},
Symbol::There(import) => {
@ -377,7 +399,7 @@ impl<'ast> Checker<'ast> {
// find candidates in the checked module
let function_candidates: Vec<_> = state
.typed_modules
.get(&import.module_id)
.get(&import.module_id.display().to_string())
.unwrap()
.functions
.iter()
@ -424,9 +446,9 @@ impl<'ast> Checker<'ast> {
pos: Some(pos),
message: format!(
"Could not find symbol {} in module {}",
import.symbol_id, import.module_id,
import.symbol_id, import.module_id.display(),
),
}.with_module_id(module_id));
}.in_file(module_id));
}
(_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
_ => {
@ -440,7 +462,7 @@ impl<'ast> Checker<'ast> {
"{} conflicts with another symbol",
declaration.id,
),
}.with_module_id(module_id));
}.in_file(module_id));
},
true => {}
};
@ -450,7 +472,7 @@ impl<'ast> Checker<'ast> {
candidate.clone().id(declaration.id),
TypedFunctionSymbol::There(
candidate,
import.module_id.clone(),
import.module_id.clone().display().to_string(),
),
);
}
@ -465,10 +487,16 @@ impl<'ast> Checker<'ast> {
Symbol::Flat(funct) => {
match symbol_unifier.insert_function(declaration.id, funct.signature::<T>()) {
false => {
errors.push(ErrorInner {
pos: Some(pos),
message: format!("{} conflicts with another symbol", declaration.id,),
}.with_module_id(module_id));
errors.push(
ErrorInner {
pos: Some(pos),
message: format!(
"{} conflicts with another symbol",
declaration.id,
),
}
.in_file(module_id),
);
}
true => {}
};
@ -548,7 +576,7 @@ impl<'ast> Checker<'ast> {
// there should be no checked module at that key just yet, if there is we have a colision or we checked something twice
assert!(state
.typed_modules
.insert(module_id.clone(), typed_module)
.insert(module_id.clone().display().to_string(), typed_module)
.is_none());
}
None => {}
@ -865,6 +893,37 @@ impl<'ast> Checker<'ast> {
let var = self.check_variable(var, module_id, types).unwrap();
let from = self
.check_expression(from, module_id, &types)
.map_err(|e| vec![e])?;
let to = self
.check_expression(to, module_id, &types)
.map_err(|e| vec![e])?;
let from = match from {
TypedExpression::FieldElement(e) => Ok(e),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected lower loop bound to be of type field, found {}",
e.get_type()
),
}),
}
.map_err(|e| vec![e])?;
let to = match to {
TypedExpression::FieldElement(e) => Ok(e),
e => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected higher loop bound to be of type field, found {}",
e.get_type()
),
}),
}
.map_err(|e| vec![e])?;
self.insert_into_scope(var.clone());
let mut checked_statements = vec![];
@ -2663,8 +2722,8 @@ mod tests {
let foo_statements = vec![
Statement::For(
absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(),
FieldPrime::from(0),
FieldPrime::from(10),
Expression::FieldConstant(FieldPrime::from(0)).mock(),
Expression::FieldConstant(FieldPrime::from(10)).mock(),
vec![],
)
.mock(),
@ -2721,8 +2780,8 @@ mod tests {
let foo_statements = vec![Statement::For(
absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(),
FieldPrime::from(0),
FieldPrime::from(10),
Expression::FieldConstant(FieldPrime::from(0)).mock(),
Expression::FieldConstant(FieldPrime::from(10)).mock(),
for_statements,
)
.mock()];
@ -2737,8 +2796,8 @@ mod tests {
let foo_statements_checked = vec![TypedStatement::For(
typed_absy::Variable::field_element("i".into()),
FieldPrime::from(0),
FieldPrime::from(10),
FieldElementExpression::Number(FieldPrime::from(0)),
FieldElementExpression::Number(FieldPrime::from(10)),
for_statements_checked,
)];
@ -3301,14 +3360,15 @@ mod tests {
&module_id,
&types,
);
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker.check_statement(
Statement::Declaration(
absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(),
)
.mock(),
&module_id,
&types,
);
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker
.check_statement(
Statement::Declaration(
absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(),
)
.mock(),
&module_id,
&types,
);
assert_eq!(
s2_checked,
Err(vec![ErrorInner {
@ -3337,12 +3397,15 @@ mod tests {
&module_id,
&types,
);
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker.check_statement(
Statement::Declaration(absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock())
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker
.check_statement(
Statement::Declaration(
absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock(),
)
.mock(),
&module_id,
&types,
);
&module_id,
&types,
);
assert_eq!(
s2_checked,
Err(vec![ErrorInner {

View file

@ -7,13 +7,14 @@
mod constrain_inputs;
mod flat_propagation;
mod inline;
mod propagate_unroll;
mod propagation;
mod unroll;
use self::constrain_inputs::InputConstrainer;
use self::inline::Inliner;
use self::propagate_unroll::PropagatedUnroller;
use self::propagation::Propagator;
use self::unroll::Unroller;
use crate::flat_absy::FlatProg;
use crate::typed_absy::TypedProgram;
use zokrates_field::field::Field;
@ -24,14 +25,15 @@ pub trait Analyse {
impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> {
fn analyse(self) -> Self {
// unroll
let r = Unroller::unroll(self);
// propagated unrolling
let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!(e));
// inline
let r = Inliner::inline(r);
// propagate
let r = Propagator::propagate(r);
// constrain inputs
let r = InputConstrainer::constrain(r);
r
}
}

View file

@ -0,0 +1,235 @@
//! Module containing iterative unrolling in order to unroll nested loops with variable bounds
//!
//! For example:
//! ```zokrates
//! for field i in 0..5 do
//! for field j in i..5 do
//! //
//! endfor
//! endfor
//! ```
//!
//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop
//!
//! This module does exactly that:
//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant
//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet*
//! - unroll again, this time the 5 inner loops all have constant bounds
//!
//! In the case that a loop bound cannot be reduced to a constant, we detect it by noticing that the unroll does
//! not make progress anymore.
use static_analysis::propagation::Propagator;
use static_analysis::unroll::{Output, Unroller};
use typed_absy::TypedProgram;
use zokrates_field::field::Field;
pub struct PropagatedUnroller;
impl PropagatedUnroller {
pub fn unroll<'ast, T: Field>(
p: TypedProgram<'ast, T>,
) -> Result<TypedProgram<'ast, T>, &'static str> {
let mut blocked_at = None;
// unroll a first time, retrieving whether the unroll is complete
let mut unrolled = Unroller::unroll(p);
loop {
// conditions to exit the loop
unrolled = match unrolled {
Output::Complete(p) => return Ok(p),
Output::Incomplete(next, index) => {
if Some(index) == blocked_at {
return Err("Loop unrolling failed. This happened because a loop bound is not constant");
} else {
// update the index where we blocked
blocked_at = Some(index);
// propagate
let propagated = Propagator::propagate_verbose(next);
// unroll
Unroller::unroll(propagated)
}
}
};
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use typed_absy::types::{FunctionKey, Signature};
use typed_absy::*;
use zokrates_field::field::FieldPrime;
#[test]
fn detect_non_constant_bound() {
let loops = vec![TypedStatement::For(
Variable::field_element("i".into()),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(FieldPrime::from(2)),
vec![],
)];
let statements = loops;
let p = TypedProgram {
modules: vec![(
"main".to_string(),
TypedModule {
functions: vec![(
FunctionKey::with_id("main"),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
signature: Signature::new(),
statements,
}),
)]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
main: "main".to_string(),
};
assert!(PropagatedUnroller::unroll(p).is_err());
}
#[test]
fn for_loop() {
// for field i in 0..2
// for field j in i..2
// field foo = i + j
// should be unrolled to
// i_0 = 0
// j_0 = 0
// foo_0 = i_0 + j_0
// j_1 = 1
// foo_1 = i_0 + j_1
// i_1 = 1
// j_2 = 1
// foo_2 = i_1 + j_1
let s = TypedStatement::For(
Variable::field_element("i".into()),
FieldElementExpression::Number(FieldPrime::from(0)),
FieldElementExpression::Number(FieldPrime::from(2)),
vec![TypedStatement::For(
Variable::field_element("j".into()),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(FieldPrime::from(2)),
vec![
TypedStatement::Declaration(Variable::field_element("foo".into())),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("foo".into())),
FieldElementExpression::Add(
box FieldElementExpression::Identifier("i".into()),
box FieldElementExpression::Identifier("j".into()),
)
.into(),
),
],
)],
);
let expected_statements = vec![
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("i").version(0),
)),
FieldElementExpression::Number(FieldPrime::from(0)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(0),
)),
FieldElementExpression::Number(FieldPrime::from(0)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(0),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
box FieldElementExpression::Identifier(Identifier::from("j").version(0)),
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(1),
)),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(1),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
box FieldElementExpression::Identifier(Identifier::from("j").version(1)),
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("i").version(1),
)),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(2),
)),
FieldElementExpression::Number(FieldPrime::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(2),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(1)),
box FieldElementExpression::Identifier(Identifier::from("j").version(2)),
)
.into(),
),
];
let p = TypedProgram {
modules: vec![(
"main".to_string(),
TypedModule {
functions: vec![(
FunctionKey::with_id("main"),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
signature: Signature::new(),
statements: vec![s],
}),
)]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
main: "main".to_string(),
};
let statements = match PropagatedUnroller::unroll(p).unwrap().modules["main"].functions
[&FunctionKey::with_id("main")]
.clone()
{
TypedFunctionSymbol::Here(f) => f.statements,
_ => unreachable!(),
};
assert_eq!(statements, expected_statements);
}
}

View file

@ -1,5 +1,11 @@
//! Module containing constant propagation for the typed AST
//!
//! On top of the usual behavior of removing statements which assign a constant to a variable (as the variable can simply be
//! substituted for the constant whenever used), we provide a `verbose` mode which does not remove such statements. This is done
//! as for partial passes which do not visit the whole program, the variables being defined may be be used in parts of the program
//! that are not visited. Keeping the statements is semantically equivalent and enables rebuilding the set of constants at the
//! next pass.
//!
//! @file propagation.rs
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
@ -12,19 +18,36 @@ use typed_absy::types::{StructMember, Type};
use zokrates_field::field::Field;
pub struct Propagator<'ast, T: Field> {
// constants keeps track of constant expressions
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
constants: HashMap<TypedAssignee<'ast, T>, TypedExpression<'ast, T>>,
// the verbose mode doesn't remove statements which assign constants to variables
// it's required when using propagation in combination with unrolling
verbose: bool,
}
impl<'ast, T: Field> Propagator<'ast, T> {
fn verbose() -> Self {
Propagator {
constants: HashMap::new(),
verbose: true,
}
}
fn new() -> Self {
Propagator {
constants: HashMap::new(),
verbose: false,
}
}
pub fn propagate(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
Propagator::new().fold_program(p)
}
pub fn propagate_verbose(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
Propagator::verbose().fold_program(p)
}
}
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
@ -63,8 +86,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
let expr = self.fold_expression(expr);
if is_constant(&expr) {
self.constants.insert(TypedAssignee::Identifier(var), expr);
None
self.constants
.insert(TypedAssignee::Identifier(var.clone()), expr.clone());
match self.verbose {
true => Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
expr,
)),
false => None,
}
} else {
Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
@ -86,9 +116,17 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
self.fold_expression(e2),
))
}
// we unrolled for loops in the previous step
TypedStatement::For(..) => {
unreachable!("for loop is unexpected, it should have been unrolled")
// only loops with variable bounds are expected here
// we stop propagation here as constants maybe be modified inside the loop body
// which we do not visit
TypedStatement::For(v, from, to, statements) => {
let from = self.fold_field_expression(from);
let to = self.fold_field_expression(to);
// invalidate the constants map as any constant could be modified inside the loop body, which we don't visit
self.constants.clear();
Some(TypedStatement::For(v, from, to, statements))
}
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
@ -98,6 +136,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
))
}
};
// In verbose mode, we always return a statement
assert!(res.is_some() || !self.verbose);
match res {
Some(v) => vec![v],
None => vec![],

View file

@ -11,19 +11,30 @@ use std::collections::HashMap;
use std::collections::HashSet;
use zokrates_field::field::Field;
pub enum Output<'ast, T: Field> {
Complete(TypedProgram<'ast, T>),
Incomplete(TypedProgram<'ast, T>, usize),
}
pub struct Unroller<'ast> {
substitution: HashMap<Identifier<'ast>, usize>,
// version index for any variable name
substitution: HashMap<&'ast str, usize>,
// whether all statements could be unrolled so far. Loops with variable bounds cannot.
complete: bool,
statement_count: usize,
}
impl<'ast> Unroller<'ast> {
fn new() -> Self {
Unroller {
substitution: HashMap::new(),
complete: true,
statement_count: 0,
}
}
fn issue_next_ssa_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> {
let res = match self.substitution.get(&v.id) {
let res = match self.substitution.get(&v.id.id) {
Some(i) => Variable {
id: Identifier {
id: v.id.id,
@ -34,15 +45,22 @@ impl<'ast> Unroller<'ast> {
},
None => Variable { ..v.clone() },
};
self.substitution
.entry(v.id)
.entry(v.id.id)
.and_modify(|e| *e += 1)
.or_insert(0);
res
}
pub fn unroll<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
Unroller::new().fold_program(p)
pub fn unroll<T: Field>(p: TypedProgram<T>) -> Output<T> {
let mut unroller = Unroller::new();
let p = unroller.fold_program(p);
match unroller.complete {
true => Output::Complete(p),
false => Output::Incomplete(p, unroller.statement_count),
}
}
fn choose_many<T: Field>(
@ -322,6 +340,7 @@ fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec<Access<'a
impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
self.statement_count += 1;
match s {
TypedStatement::Declaration(_) => vec![],
TypedStatement::Definition(assignee, expr) => {
@ -349,6 +368,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
};
let base = self.fold_expression(base);
let indices = indices
.into_iter()
.map(|a| match a {
@ -378,34 +398,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
vec![TypedStatement::MultipleDefinition(variables, exprs)]
}
TypedStatement::For(v, from, to, stats) => {
let mut values: Vec<T> = vec![];
let mut current = from;
while current < to {
values.push(current.clone());
current = T::one() + &current;
let from = self.fold_field_expression(from);
let to = self.fold_field_expression(to);
match (from, to) {
(FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => {
let mut values: Vec<T> = vec![];
let mut current = from;
while current < to {
values.push(current.clone());
current = T::one() + &current;
}
let res = values
.into_iter()
.map(|index| {
vec![
vec![
TypedStatement::Declaration(v.clone()),
TypedStatement::Definition(
TypedAssignee::Identifier(v.clone()),
FieldElementExpression::Number(index).into(),
),
],
stats.clone(),
]
.into_iter()
.flat_map(|x| x)
})
.flat_map(|x| x)
.flat_map(|x| self.fold_statement(x))
.collect();
res
}
(from, to) => {
self.complete = false;
vec![TypedStatement::For(v, from, to, stats)]
}
}
let res = values
.into_iter()
.map(|index| {
vec![
vec![
TypedStatement::Declaration(v.clone()),
TypedStatement::Definition(
TypedAssignee::Identifier(v.clone()),
FieldElementExpression::Number(index).into(),
),
],
stats.clone(),
]
.into_iter()
.flat_map(|x| x)
})
.flat_map(|x| x)
.flat_map(|x| self.fold_statement(x))
.collect();
res
}
s => fold_statement(self, s),
}
@ -414,7 +445,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
self.substitution = HashMap::new();
for arg in &f.arguments {
self.substitution.insert(arg.id.id.clone(), 0);
self.substitution.insert(arg.id.id.id.clone(), 0);
}
fold_function(self, f)
@ -422,7 +453,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
Identifier {
version: self.substitution.get(&n).unwrap_or(&0).clone(),
version: self.substitution.get(&n.id).unwrap_or(&0).clone(),
..n
}
}
@ -693,8 +724,8 @@ mod tests {
let s = TypedStatement::For(
Variable::field_element("i".into()),
FieldPrime::from(2),
FieldPrime::from(5),
FieldElementExpression::Number(FieldPrime::from(2)),
FieldElementExpression::Number(FieldPrime::from(5)),
vec![
TypedStatement::Declaration(Variable::field_element("foo".into())),
TypedStatement::Definition(
@ -748,6 +779,46 @@ mod tests {
assert_eq!(u.fold_statement(s), expected);
}
#[test]
fn idempotence() {
// an already unrolled program should not be modified by unrolling again
// a = 5
// a_1 = 6
// a_2 = 7
// should be turned into
// a = 5
// a_1 = 6
// a_2 = 7
let mut u = Unroller::new();
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(0),
)),
FieldElementExpression::Number(FieldPrime::from(5)).into(),
);
assert_eq!(u.fold_statement(s.clone()), vec![s]);
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(1),
)),
FieldElementExpression::Number(FieldPrime::from(6)).into(),
);
assert_eq!(u.fold_statement(s.clone()), vec![s]);
let s = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("a").version(2),
)),
FieldElementExpression::Number(FieldPrime::from(7)).into(),
);
assert_eq!(u.fold_statement(s.clone()), vec![s]);
}
#[test]
fn definition() {
// field a

View file

@ -229,7 +229,7 @@ impl<'ast, T: Field> fmt::Display for TypedFunction<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"({}) -> ({}):\n{}",
"({}) -> ({}):",
self.arguments
.iter()
.map(|x| format!("{}", x))
@ -241,12 +241,16 @@ impl<'ast, T: Field> fmt::Display for TypedFunction<'ast, T> {
.map(|x| format!("{}", x))
.collect::<Vec<_>>()
.join(", "),
self.statements
.iter()
.map(|x| format!("\t{}", x))
.collect::<Vec<_>>()
.join("\n")
)
)?;
writeln!(f, "")?;
for s in &self.statements {
s.fmt_indented(f, 1)?;
writeln!(f, "")?;
}
Ok(())
}
}
@ -326,7 +330,12 @@ pub enum TypedStatement<'ast, T: Field> {
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
Declaration(Variable<'ast>),
Condition(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
For(Variable<'ast>, T, T, Vec<TypedStatement<'ast, T>>),
For(
Variable<'ast>,
FieldElementExpression<'ast, T>,
FieldElementExpression<'ast, T>,
Vec<TypedStatement<'ast, T>>,
),
MultipleDefinition(Vec<Variable<'ast>>, TypedExpressionList<'ast, T>),
}
@ -364,6 +373,23 @@ impl<'ast, T: Field> fmt::Debug for TypedStatement<'ast, T> {
}
}
impl<'ast, T: Field> TypedStatement<'ast, T> {
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
match self {
TypedStatement::For(variable, from, to, statements) => {
write!(f, "{}", "\t".repeat(depth))?;
writeln!(f, "for {} in {}..{} do", variable, from, to)?;
for s in statements {
s.fmt_indented(f, depth + 1)?;
writeln!(f, "")?;
}
writeln!(f, "{}endfor", "\t".repeat(depth))
}
s => write!(f, "{}{}", "\t".repeat(depth), s),
}
}
}
impl<'ast, T: Field> fmt::Display for TypedStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {

View file

@ -0,0 +1,2 @@
def foo() -> (field):
return 1

View file

@ -0,0 +1,4 @@
from "./dep/foo.zok" import foo as bar
def foo() -> (field):
return 2 + bar()

View file

@ -0,0 +1,15 @@
{
"entry_point": "./tests/tests/import/import.zok",
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["3"]
}
}
}
]
}

View file

@ -0,0 +1,4 @@
from "./dep/foo.zok" import foo
def main() -> (field):
return foo()

View file

@ -0,0 +1,25 @@
{
"entry_point": "./tests/tests/nested_loop.zok",
"tests": [
{
"input": {
"values": ["1", "2", "3", "4"]
},
"output": {
"Ok": {
"values": ["4838400", "10", "25"]
}
}
},
{
"input": {
"values": ["0", "1", "2", "3"]
},
"output": {
"Ok": {
"values": ["0", "10", "25"]
}
}
}
]
}

View file

@ -0,0 +1,27 @@
def main(field[4] values) -> (field, field, field):
field res0 = 1
field res1 = 0
field counter = 0
for field i in 0..4 do
for field j in i..4 do
counter = counter + 1
res0 = res0 * (values[i] + values[j])
endfor
endfor
for field i in 0..counter do
res1 = res1 + 1
endfor
field res2 = 0
field i = 0
for field i in i..5 do
i = 5
for field i in 0..i do
res2 = res2 + 1
endfor
endfor
return res0, res1, res2

View file

@ -6,20 +6,28 @@ use std::path::{Component, PathBuf};
const ZOKRATES_HOME: &str = &"ZOKRATES_HOME";
type CurrentLocation = String;
type ImportLocation<'a> = String;
type CurrentLocation = PathBuf;
type ImportLocation<'a> = PathBuf;
type SourceCode = String;
pub fn resolve<'a>(
current_location: CurrentLocation,
import_location: ImportLocation<'a>,
) -> Result<(SourceCode, CurrentLocation), io::Error> {
println!(
"get file {} {}",
current_location.display(),
import_location.display()
);
let source = Path::new(&import_location);
// paths starting with `./` or `../` are interpreted relative to the current file
// other paths `abc/def` are interpreted relative to $ZOKRATES_HOME
let base = match source.components().next() {
Some(Component::CurDir) | Some(Component::ParentDir) => PathBuf::from(current_location),
Some(Component::CurDir) | Some(Component::ParentDir) => {
PathBuf::from(current_location).parent().unwrap().into()
}
_ => PathBuf::from(
std::env::var(ZOKRATES_HOME).expect("$ZOKRATES_HOME is not set, please set it"),
),
@ -33,16 +41,9 @@ pub fn resolve<'a>(
return Err(io::Error::new(io::ErrorKind::Other, "Not a file"));
}
let next_location = generate_next_location(&path_owned)?;
let source = read_to_string(path_owned)?;
let source = read_to_string(&path_owned)?;
Ok((source, next_location))
}
fn generate_next_location<'a>(path: &'a PathBuf) -> Result<String, io::Error> {
path.parent()
.ok_or(io::Error::new(io::ErrorKind::Other, "Invalid path"))
.map(|v| v.to_path_buf().into_os_string().into_string().unwrap())
Ok((source, path_owned))
}
#[cfg(test)]

View file

@ -77,17 +77,7 @@ pub fn test_inner(test_path: &str) {
let code = std::fs::read_to_string(&t.entry_point).unwrap();
let artifacts = compile(
code,
t.entry_point
.parent()
.unwrap()
.to_str()
.unwrap()
.to_string(),
Some(&resolve),
)
.unwrap();
let artifacts = compile(code, t.entry_point.clone(), Some(&resolve)).unwrap();
let bin = artifacts.prog();