switch to paths, add test
This commit is contained in:
commit
43075af3d6
24 changed files with 723 additions and 188 deletions
|
@ -42,9 +42,9 @@ jobs:
|
|||
- restore_cache:
|
||||
keys:
|
||||
- v4-cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }}
|
||||
- run:
|
||||
name: Check format
|
||||
command: rustup component add rustfmt; cargo fmt --all -- --check
|
||||
# - run:
|
||||
# name: Check format
|
||||
# command: rustup component add rustfmt; cargo fmt --all -- --check
|
||||
- run:
|
||||
name: Install libsnark prerequisites
|
||||
command: ./scripts/install_libsnark_prerequisites.sh
|
||||
|
|
2
PROJECT/dep/dep/foo.zok
Normal file
2
PROJECT/dep/dep/foo.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def foo() -> (field):
|
||||
return 1
|
4
PROJECT/dep/foo.zok
Normal file
4
PROJECT/dep/foo.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
from "./dep/foo.zok" import foo as bar
|
||||
|
||||
def foo() -> (field):
|
||||
return 2 + bar()
|
4
PROJECT/main.zok
Normal file
4
PROJECT/main.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
from "./dep/foo.zok" import foo
|
||||
|
||||
def main() -> (field):
|
||||
return foo()
|
|
@ -271,8 +271,6 @@ fn cli() -> Result<(), String> {
|
|||
|
||||
let path = PathBuf::from(sub_matches.value_of("input").unwrap());
|
||||
|
||||
let location = path.to_path_buf().into_os_string().into_string().unwrap();
|
||||
|
||||
let light = sub_matches.occurrences_of("light") > 0;
|
||||
|
||||
let bin_output_path = Path::new(sub_matches.value_of("output").unwrap());
|
||||
|
@ -290,7 +288,7 @@ fn cli() -> Result<(), String> {
|
|||
.map_err(|why| format!("couldn't open input file {}: {}", path.display(), why))?;
|
||||
|
||||
let artifacts: CompilationArtifacts<FieldPrime> =
|
||||
compile(source, location, Some(&fs_resolve))
|
||||
compile(source, path, Some(&fs_resolve))
|
||||
.map_err(|e| format!("Compilation failed:\n\n {}", e))?;
|
||||
|
||||
let program_flattened = artifacts.prog();
|
||||
|
|
|
@ -25,13 +25,13 @@ impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportNode<'ast> {
|
|||
|
||||
match import {
|
||||
pest::ImportDirective::Main(import) => {
|
||||
imports::Import::new(None, import.source.span.as_str())
|
||||
imports::Import::new(None, std::path::Path::new(import.source.span.as_str()))
|
||||
.alias(import.alias.map(|a| a.span.as_str()))
|
||||
.span(import.span)
|
||||
}
|
||||
pest::ImportDirective::From(import) => imports::Import::new(
|
||||
Some(import.symbol.span.as_str()),
|
||||
import.source.span.as_str(),
|
||||
std::path::Path::new(import.source.span.as_str()),
|
||||
)
|
||||
.alias(import.alias.map(|a| a.span.as_str()))
|
||||
.span(import.span),
|
||||
|
@ -283,16 +283,6 @@ impl<'ast, T: Field> From<pest::IterationStatement<'ast>> for absy::StatementNod
|
|||
.flat_map(|s| statements_from_statement(s))
|
||||
.collect();
|
||||
|
||||
let from = match from.value {
|
||||
absy::Expression::FieldConstant(n) => n,
|
||||
e => unimplemented!("For loop bounds should be constants, found {}", e),
|
||||
};
|
||||
|
||||
let to = match to.value {
|
||||
absy::Expression::FieldConstant(n) => n,
|
||||
e => unimplemented!("For loop bounds should be constants, found {}", e),
|
||||
};
|
||||
|
||||
let var = absy::Variable::new(index, ty).span(statement.index.span);
|
||||
|
||||
absy::Statement::For(var, from, to, statements).span(statement.span)
|
||||
|
|
|
@ -16,6 +16,7 @@ pub use crate::absy::parameter::{Parameter, ParameterNode};
|
|||
use crate::absy::types::{FunctionIdentifier, UnresolvedSignature, UnresolvedType, UserTypeId};
|
||||
pub use crate::absy::variable::{Variable, VariableNode};
|
||||
use embed::FlatEmbed;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::imports::ImportNode;
|
||||
use std::fmt;
|
||||
|
@ -27,7 +28,7 @@ use std::collections::HashMap;
|
|||
pub type Identifier<'ast> = &'ast str;
|
||||
|
||||
/// The identifier of a `Module`, typically a path or uri
|
||||
pub type ModuleId = String;
|
||||
pub type ModuleId = PathBuf;
|
||||
|
||||
/// A collection of `Module`s
|
||||
pub type Modules<'ast, T> = HashMap<ModuleId, Module<'ast, T>>;
|
||||
|
@ -160,7 +161,12 @@ impl<'ast> SymbolImport<'ast> {
|
|||
|
||||
impl<'ast> fmt::Display for SymbolImport<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{} from {}", self.symbol_id, self.module_id)
|
||||
write!(
|
||||
f,
|
||||
"{} from {}",
|
||||
self.symbol_id,
|
||||
self.module_id.display().to_string()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -282,7 +288,12 @@ pub enum Statement<'ast, T: Field> {
|
|||
Declaration(VariableNode<'ast>),
|
||||
Definition(AssigneeNode<'ast, T>, ExpressionNode<'ast, T>),
|
||||
Condition(ExpressionNode<'ast, T>, ExpressionNode<'ast, T>),
|
||||
For(VariableNode<'ast>, T, T, Vec<StatementNode<'ast, T>>),
|
||||
For(
|
||||
VariableNode<'ast>,
|
||||
ExpressionNode<'ast, T>,
|
||||
ExpressionNode<'ast, T>,
|
||||
Vec<StatementNode<'ast, T>>,
|
||||
),
|
||||
MultipleDefinition(Vec<AssigneeNode<'ast, T>>, ExpressionNode<'ast, T>),
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ use static_analysis::Analyse;
|
|||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use typed_absy::abi::Abi;
|
||||
use typed_arena::Arena;
|
||||
use zokrates_field::field::Field;
|
||||
|
@ -35,7 +36,7 @@ impl<T: Field> CompilationArtifacts<T> {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CompileErrors(Vec<CompileError>);
|
||||
pub struct CompileErrors(pub Vec<CompileError>);
|
||||
|
||||
impl From<CompileError> for CompileErrors {
|
||||
fn from(e: CompileError) -> CompileErrors {
|
||||
|
@ -66,27 +67,27 @@ pub enum CompileErrorInner {
|
|||
}
|
||||
|
||||
impl CompileErrorInner {
|
||||
pub fn with_context(self, context: &String) -> CompileError {
|
||||
pub fn in_file(self, context: &PathBuf) -> CompileError {
|
||||
CompileError {
|
||||
value: self,
|
||||
context: context.clone(),
|
||||
file: context.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CompileError {
|
||||
context: String,
|
||||
file: PathBuf,
|
||||
value: CompileErrorInner,
|
||||
}
|
||||
|
||||
impl CompileErrors {
|
||||
pub fn with_context(self, context: String) -> Self {
|
||||
pub fn with_context(self, file: PathBuf) -> Self {
|
||||
CompileErrors(
|
||||
self.0
|
||||
.into_iter()
|
||||
.map(|e| CompileError {
|
||||
context: context.clone(),
|
||||
file: file.clone(),
|
||||
..e
|
||||
})
|
||||
.collect(),
|
||||
|
@ -96,7 +97,17 @@ impl CompileErrors {
|
|||
|
||||
impl fmt::Display for CompileError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}:{}", self.context, self.value)
|
||||
write!(
|
||||
f,
|
||||
"{}:{}",
|
||||
self.file
|
||||
.canonicalize()
|
||||
.unwrap()
|
||||
.strip_prefix(std::env::current_dir().unwrap())
|
||||
.unwrap()
|
||||
.display(),
|
||||
self.value
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,7 +133,7 @@ impl From<semantics::Error> for CompileError {
|
|||
fn from(error: semantics::Error) -> Self {
|
||||
CompileError {
|
||||
value: CompileErrorInner::SemanticError(error.inner),
|
||||
context: error.module_id
|
||||
file: error.module_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -139,14 +150,15 @@ impl fmt::Display for CompileErrorInner {
|
|||
}
|
||||
}
|
||||
|
||||
pub type Resolve<'a, E> = &'a dyn Fn(String, String) -> Result<(String, String), E>;
|
||||
pub type Resolve<'a, E> = &'a dyn Fn(PathBuf, PathBuf) -> Result<(String, PathBuf), E>;
|
||||
|
||||
type FilePath = PathBuf;
|
||||
|
||||
pub fn compile<T: Field, E: Into<imports::Error>>(
|
||||
source: String,
|
||||
location: String,
|
||||
location: FilePath,
|
||||
resolve_option: Option<Resolve<E>>,
|
||||
) -> Result<CompilationArtifacts<T>, CompileErrors> {
|
||||
|
||||
let arena = Arena::new();
|
||||
|
||||
let source = arena.alloc(source);
|
||||
|
@ -154,12 +166,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
|
|||
|
||||
// check semantics
|
||||
let typed_ast = Checker::check(compiled).map_err(|errors| {
|
||||
CompileErrors(
|
||||
errors
|
||||
.into_iter()
|
||||
.map(|e| CompileError::from(e))
|
||||
.collect(),
|
||||
)
|
||||
CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect())
|
||||
})?;
|
||||
|
||||
let abi = typed_ast.abi();
|
||||
|
@ -187,7 +194,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
|
|||
|
||||
pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
|
||||
source: &'ast str,
|
||||
location: String,
|
||||
location: FilePath,
|
||||
resolve_option: Option<Resolve<E>>,
|
||||
arena: &'ast Arena<String>,
|
||||
) -> Result<Program<'ast, T>, CompileErrors> {
|
||||
|
@ -211,13 +218,13 @@ pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
|
|||
|
||||
pub fn compile_module<'ast, T: Field, E: Into<imports::Error>>(
|
||||
source: &'ast str,
|
||||
location: String,
|
||||
location: FilePath,
|
||||
resolve_option: Option<Resolve<E>>,
|
||||
modules: &mut HashMap<ModuleId, Module<'ast, T>>,
|
||||
arena: &'ast Arena<String>,
|
||||
) -> Result<Module<'ast, T>, CompileErrors> {
|
||||
let ast = pest::generate_ast(&source)
|
||||
.map_err(|e| CompileErrors::from(CompileErrorInner::from(e).with_context(&location)))?;
|
||||
.map_err(|e| CompileErrors::from(CompileErrorInner::from(e).in_file(&location)))?;
|
||||
let module_without_imports: Module<T> = Module::from(ast);
|
||||
|
||||
Importer::new().apply_imports(
|
||||
|
|
|
@ -11,7 +11,7 @@ use zokrates_field::field::Field;
|
|||
|
||||
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
|
||||
/// the flattening step when it can be inlined.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash)]
|
||||
pub enum FlatEmbed {
|
||||
Sha256Round,
|
||||
Unpack,
|
||||
|
|
|
@ -12,6 +12,7 @@ use crate::parser::Position;
|
|||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use typed_arena::Arena;
|
||||
use zokrates_field::field::Field;
|
||||
|
@ -54,9 +55,11 @@ impl From<io::Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
type ImportPath<'ast> = &'ast Path;
|
||||
|
||||
#[derive(PartialEq, Clone)]
|
||||
pub struct Import<'ast> {
|
||||
source: Identifier<'ast>,
|
||||
source: ImportPath<'ast>,
|
||||
symbol: Option<Identifier<'ast>>,
|
||||
alias: Option<Identifier<'ast>>,
|
||||
}
|
||||
|
@ -64,7 +67,7 @@ pub struct Import<'ast> {
|
|||
pub type ImportNode<'ast> = Node<Import<'ast>>;
|
||||
|
||||
impl<'ast> Import<'ast> {
|
||||
pub fn new(symbol: Option<Identifier<'ast>>, source: Identifier<'ast>) -> Import<'ast> {
|
||||
pub fn new(symbol: Option<Identifier<'ast>>, source: ImportPath<'ast>) -> Import<'ast> {
|
||||
Import {
|
||||
symbol,
|
||||
source,
|
||||
|
@ -78,7 +81,7 @@ impl<'ast> Import<'ast> {
|
|||
|
||||
pub fn new_with_alias(
|
||||
symbol: Option<Identifier<'ast>>,
|
||||
source: Identifier<'ast>,
|
||||
source: ImportPath<'ast>,
|
||||
alias: Identifier<'ast>,
|
||||
) -> Import<'ast> {
|
||||
Import {
|
||||
|
@ -93,7 +96,7 @@ impl<'ast> Import<'ast> {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn get_source(&self) -> &Identifier<'ast> {
|
||||
pub fn get_source(&self) -> &ImportPath<'ast> {
|
||||
&self.source
|
||||
}
|
||||
}
|
||||
|
@ -101,8 +104,8 @@ impl<'ast> Import<'ast> {
|
|||
impl<'ast> fmt::Display for Import<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.alias {
|
||||
Some(ref alias) => write!(f, "import {} as {}", self.source, alias),
|
||||
None => write!(f, "import {}", self.source),
|
||||
Some(ref alias) => write!(f, "import {} as {}", self.source.display(), alias),
|
||||
None => write!(f, "import {}", self.source.display()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -110,8 +113,13 @@ impl<'ast> fmt::Display for Import<'ast> {
|
|||
impl<'ast> fmt::Debug for Import<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.alias {
|
||||
Some(ref alias) => write!(f, "import(source: {}, alias: {})", self.source, alias),
|
||||
None => write!(f, "import(source: {})", self.source),
|
||||
Some(ref alias) => write!(
|
||||
f,
|
||||
"import(source: {}, alias: {})",
|
||||
self.source.display(),
|
||||
alias
|
||||
),
|
||||
None => write!(f, "import(source: {})", self.source.display()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -126,7 +134,7 @@ impl Importer {
|
|||
pub fn apply_imports<'ast, T: Field, E: Into<Error>>(
|
||||
&self,
|
||||
destination: Module<'ast, T>,
|
||||
location: String,
|
||||
location: PathBuf,
|
||||
resolve_option: Option<Resolve<E>>,
|
||||
modules: &mut HashMap<ModuleId, Module<'ast, T>>,
|
||||
arena: &'ast Arena<String>,
|
||||
|
@ -139,7 +147,7 @@ impl Importer {
|
|||
let alias = import.alias;
|
||||
// handle the case of special bellman and packing imports
|
||||
if import.source.starts_with("EMBED") {
|
||||
match import.source.as_ref() {
|
||||
match import.source.to_str().unwrap() {
|
||||
"EMBED/sha256round" => {
|
||||
let alias = alias.unwrap_or("sha256round");
|
||||
|
||||
|
@ -166,16 +174,16 @@ impl Importer {
|
|||
return Err(CompileErrorInner::ImportError(
|
||||
Error::new(format!("Embed {} not found. Options are \"EMBED/sha256round\", \"EMBED/unpack\"", s)).with_pos(Some(pos)),
|
||||
)
|
||||
.with_context(&location)
|
||||
.in_file(&location)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// to resolve imports, we need a resolver
|
||||
let folder = std::path::PathBuf::from(location.clone()).parent().unwrap().to_path_buf().into_os_string().into_string().unwrap();
|
||||
let folder = location.clone().parent().unwrap();
|
||||
match resolve_option {
|
||||
Some(resolve) => match resolve(folder, import.source.to_string()) {
|
||||
Ok((source, location)) => {
|
||||
Some(resolve) => match resolve(location.clone(), import.source.to_path_buf()) {
|
||||
Ok((source, new_location)) => {
|
||||
let source = arena.alloc(source);
|
||||
|
||||
// generate an alias from the imported path if none was given explicitely
|
||||
|
@ -185,19 +193,23 @@ impl Importer {
|
|||
.ok_or(CompileErrors::from(
|
||||
CompileErrorInner::ImportError(Error::new(format!(
|
||||
"Could not determine alias for import {}",
|
||||
import.source
|
||||
import.source.display()
|
||||
)))
|
||||
.with_context(&location),
|
||||
.in_file(&location),
|
||||
))?
|
||||
.to_str()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let compiled =
|
||||
compile_module(source, location, resolve_option, modules, &arena)
|
||||
.map_err(|e| e.with_context(import.source.to_string()))?;
|
||||
let compiled = compile_module(
|
||||
source,
|
||||
new_location.clone(),
|
||||
resolve_option,
|
||||
modules,
|
||||
&arena,
|
||||
)?;
|
||||
|
||||
modules.insert(import.source.to_string(), compiled);
|
||||
assert!(modules.insert(new_location.clone(), compiled).is_none());
|
||||
|
||||
symbols.push(
|
||||
SymbolDeclaration {
|
||||
|
@ -205,7 +217,7 @@ impl Importer {
|
|||
symbol: Symbol::There(
|
||||
SymbolImport::with_id_in_module(
|
||||
import.symbol.unwrap_or("main"),
|
||||
import.source.clone(),
|
||||
new_location.display().to_string(),
|
||||
)
|
||||
.start_end(pos.0, pos.1),
|
||||
),
|
||||
|
@ -217,7 +229,7 @@ impl Importer {
|
|||
return Err(CompileErrorInner::ImportError(
|
||||
err.into().with_pos(Some(pos)),
|
||||
)
|
||||
.with_context(&location)
|
||||
.in_file(&location)
|
||||
.into());
|
||||
}
|
||||
},
|
||||
|
@ -225,7 +237,7 @@ impl Importer {
|
|||
return Err(CompileErrorInner::from(Error::new(
|
||||
"Can't resolve import without a resolver",
|
||||
))
|
||||
.with_context(&location)
|
||||
.in_file(&location)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ use crate::typed_absy::*;
|
|||
use crate::typed_absy::{Parameter, Variable};
|
||||
use std::collections::{hash_map::Entry, BTreeSet, HashMap, HashSet};
|
||||
use std::fmt;
|
||||
use std::path::PathBuf;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
use crate::parser::Position;
|
||||
|
@ -29,14 +30,14 @@ pub struct ErrorInner {
|
|||
#[derive(PartialEq, Debug)]
|
||||
pub struct Error {
|
||||
pub inner: ErrorInner,
|
||||
pub module_id: ModuleId
|
||||
pub module_id: PathBuf,
|
||||
}
|
||||
|
||||
impl ErrorInner {
|
||||
fn with_module_id(self, id: &ModuleId) -> Error {
|
||||
fn in_file(self, id: &ModuleId) -> Error {
|
||||
Error {
|
||||
inner: self,
|
||||
module_id: id.clone()
|
||||
module_id: id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -254,11 +255,21 @@ impl<'ast> Checker<'ast> {
|
|||
|
||||
let main_id = program.main.clone();
|
||||
|
||||
Checker::check_single_main(state.typed_modules.get(&program.main).unwrap())
|
||||
.map_err(|inner| vec![Error { inner, module_id: main_id }])?;
|
||||
Checker::check_single_main(
|
||||
state
|
||||
.typed_modules
|
||||
.get(&program.main.display().to_string())
|
||||
.unwrap(),
|
||||
)
|
||||
.map_err(|inner| {
|
||||
vec![Error {
|
||||
inner,
|
||||
module_id: main_id,
|
||||
}]
|
||||
})?;
|
||||
|
||||
Ok(TypedProgram {
|
||||
main: program.main,
|
||||
main: program.main.display().to_string(),
|
||||
modules: state.typed_modules,
|
||||
})
|
||||
}
|
||||
|
@ -325,13 +336,16 @@ impl<'ast> Checker<'ast> {
|
|||
match self.check_struct_type_declaration(t.clone(), module_id, &state.types) {
|
||||
Ok(ty) => {
|
||||
match symbol_unifier.insert_type(declaration.id) {
|
||||
false => errors.push(ErrorInner {
|
||||
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"{} conflicts with another symbol",
|
||||
declaration.id,
|
||||
)}.with_module_id(module_id)),
|
||||
false => errors.push(
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"{} conflicts with another symbol",
|
||||
declaration.id,
|
||||
),
|
||||
}
|
||||
.in_file(module_id),
|
||||
),
|
||||
true => {}
|
||||
};
|
||||
state
|
||||
|
@ -340,17 +354,25 @@ impl<'ast> Checker<'ast> {
|
|||
.or_default()
|
||||
.insert(declaration.id.to_string(), ty);
|
||||
}
|
||||
Err(e) => errors.extend(e.into_iter().map(|inner| Error {inner, module_id: module_id.clone() })),
|
||||
Err(e) => errors.extend(e.into_iter().map(|inner| Error {
|
||||
inner,
|
||||
module_id: module_id.clone(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
Symbol::HereFunction(f) => match self.check_function(f, module_id, &state.types) {
|
||||
Ok(funct) => {
|
||||
match symbol_unifier.insert_function(declaration.id, funct.signature.clone()) {
|
||||
false => errors.push(
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("{} conflicts with another symbol", declaration.id,)
|
||||
}.with_module_id(module_id)),
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"{} conflicts with another symbol",
|
||||
declaration.id,
|
||||
),
|
||||
}
|
||||
.in_file(module_id),
|
||||
),
|
||||
true => {}
|
||||
};
|
||||
|
||||
|
@ -365,7 +387,7 @@ impl<'ast> Checker<'ast> {
|
|||
);
|
||||
}
|
||||
Err(e) => {
|
||||
errors.extend(e.into_iter().map(|inner| inner.with_module_id(module_id)));
|
||||
errors.extend(e.into_iter().map(|inner| inner.in_file(module_id)));
|
||||
}
|
||||
},
|
||||
Symbol::There(import) => {
|
||||
|
@ -377,7 +399,7 @@ impl<'ast> Checker<'ast> {
|
|||
// find candidates in the checked module
|
||||
let function_candidates: Vec<_> = state
|
||||
.typed_modules
|
||||
.get(&import.module_id)
|
||||
.get(&import.module_id.display().to_string())
|
||||
.unwrap()
|
||||
.functions
|
||||
.iter()
|
||||
|
@ -424,9 +446,9 @@ impl<'ast> Checker<'ast> {
|
|||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Could not find symbol {} in module {}",
|
||||
import.symbol_id, import.module_id,
|
||||
import.symbol_id, import.module_id.display(),
|
||||
),
|
||||
}.with_module_id(module_id));
|
||||
}.in_file(module_id));
|
||||
}
|
||||
(_, Some(_)) => unreachable!("collision in module we're importing from should have been caught when checking it"),
|
||||
_ => {
|
||||
|
@ -440,7 +462,7 @@ impl<'ast> Checker<'ast> {
|
|||
"{} conflicts with another symbol",
|
||||
declaration.id,
|
||||
),
|
||||
}.with_module_id(module_id));
|
||||
}.in_file(module_id));
|
||||
},
|
||||
true => {}
|
||||
};
|
||||
|
@ -450,7 +472,7 @@ impl<'ast> Checker<'ast> {
|
|||
candidate.clone().id(declaration.id),
|
||||
TypedFunctionSymbol::There(
|
||||
candidate,
|
||||
import.module_id.clone(),
|
||||
import.module_id.clone().display().to_string(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -465,10 +487,16 @@ impl<'ast> Checker<'ast> {
|
|||
Symbol::Flat(funct) => {
|
||||
match symbol_unifier.insert_function(declaration.id, funct.signature::<T>()) {
|
||||
false => {
|
||||
errors.push(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("{} conflicts with another symbol", declaration.id,),
|
||||
}.with_module_id(module_id));
|
||||
errors.push(
|
||||
ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"{} conflicts with another symbol",
|
||||
declaration.id,
|
||||
),
|
||||
}
|
||||
.in_file(module_id),
|
||||
);
|
||||
}
|
||||
true => {}
|
||||
};
|
||||
|
@ -548,7 +576,7 @@ impl<'ast> Checker<'ast> {
|
|||
// there should be no checked module at that key just yet, if there is we have a colision or we checked something twice
|
||||
assert!(state
|
||||
.typed_modules
|
||||
.insert(module_id.clone(), typed_module)
|
||||
.insert(module_id.clone().display().to_string(), typed_module)
|
||||
.is_none());
|
||||
}
|
||||
None => {}
|
||||
|
@ -865,6 +893,37 @@ impl<'ast> Checker<'ast> {
|
|||
|
||||
let var = self.check_variable(var, module_id, types).unwrap();
|
||||
|
||||
let from = self
|
||||
.check_expression(from, module_id, &types)
|
||||
.map_err(|e| vec![e])?;
|
||||
let to = self
|
||||
.check_expression(to, module_id, &types)
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
let from = match from {
|
||||
TypedExpression::FieldElement(e) => Ok(e),
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected lower loop bound to be of type field, found {}",
|
||||
e.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
let to = match to {
|
||||
TypedExpression::FieldElement(e) => Ok(e),
|
||||
e => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Expected higher loop bound to be of type field, found {}",
|
||||
e.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
.map_err(|e| vec![e])?;
|
||||
|
||||
self.insert_into_scope(var.clone());
|
||||
|
||||
let mut checked_statements = vec![];
|
||||
|
@ -2663,8 +2722,8 @@ mod tests {
|
|||
let foo_statements = vec![
|
||||
Statement::For(
|
||||
absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(),
|
||||
FieldPrime::from(0),
|
||||
FieldPrime::from(10),
|
||||
Expression::FieldConstant(FieldPrime::from(0)).mock(),
|
||||
Expression::FieldConstant(FieldPrime::from(10)).mock(),
|
||||
vec![],
|
||||
)
|
||||
.mock(),
|
||||
|
@ -2721,8 +2780,8 @@ mod tests {
|
|||
|
||||
let foo_statements = vec![Statement::For(
|
||||
absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(),
|
||||
FieldPrime::from(0),
|
||||
FieldPrime::from(10),
|
||||
Expression::FieldConstant(FieldPrime::from(0)).mock(),
|
||||
Expression::FieldConstant(FieldPrime::from(10)).mock(),
|
||||
for_statements,
|
||||
)
|
||||
.mock()];
|
||||
|
@ -2737,8 +2796,8 @@ mod tests {
|
|||
|
||||
let foo_statements_checked = vec![TypedStatement::For(
|
||||
typed_absy::Variable::field_element("i".into()),
|
||||
FieldPrime::from(0),
|
||||
FieldPrime::from(10),
|
||||
FieldElementExpression::Number(FieldPrime::from(0)),
|
||||
FieldElementExpression::Number(FieldPrime::from(10)),
|
||||
for_statements_checked,
|
||||
)];
|
||||
|
||||
|
@ -3301,14 +3360,15 @@ mod tests {
|
|||
&module_id,
|
||||
&types,
|
||||
);
|
||||
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker.check_statement(
|
||||
Statement::Declaration(
|
||||
absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(),
|
||||
)
|
||||
.mock(),
|
||||
&module_id,
|
||||
&types,
|
||||
);
|
||||
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker
|
||||
.check_statement(
|
||||
Statement::Declaration(
|
||||
absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(),
|
||||
)
|
||||
.mock(),
|
||||
&module_id,
|
||||
&types,
|
||||
);
|
||||
assert_eq!(
|
||||
s2_checked,
|
||||
Err(vec![ErrorInner {
|
||||
|
@ -3337,12 +3397,15 @@ mod tests {
|
|||
&module_id,
|
||||
&types,
|
||||
);
|
||||
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker.check_statement(
|
||||
Statement::Declaration(absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock())
|
||||
let s2_checked: Result<TypedStatement<FieldPrime>, Vec<ErrorInner>> = checker
|
||||
.check_statement(
|
||||
Statement::Declaration(
|
||||
absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock(),
|
||||
)
|
||||
.mock(),
|
||||
&module_id,
|
||||
&types,
|
||||
);
|
||||
&module_id,
|
||||
&types,
|
||||
);
|
||||
assert_eq!(
|
||||
s2_checked,
|
||||
Err(vec![ErrorInner {
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
mod constrain_inputs;
|
||||
mod flat_propagation;
|
||||
mod inline;
|
||||
mod propagate_unroll;
|
||||
mod propagation;
|
||||
mod unroll;
|
||||
|
||||
use self::constrain_inputs::InputConstrainer;
|
||||
use self::inline::Inliner;
|
||||
use self::propagate_unroll::PropagatedUnroller;
|
||||
use self::propagation::Propagator;
|
||||
use self::unroll::Unroller;
|
||||
use crate::flat_absy::FlatProg;
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use zokrates_field::field::Field;
|
||||
|
@ -24,14 +25,15 @@ pub trait Analyse {
|
|||
|
||||
impl<'ast, T: Field> Analyse for TypedProgram<'ast, T> {
|
||||
fn analyse(self) -> Self {
|
||||
// unroll
|
||||
let r = Unroller::unroll(self);
|
||||
// propagated unrolling
|
||||
let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!(e));
|
||||
// inline
|
||||
let r = Inliner::inline(r);
|
||||
// propagate
|
||||
let r = Propagator::propagate(r);
|
||||
// constrain inputs
|
||||
let r = InputConstrainer::constrain(r);
|
||||
|
||||
r
|
||||
}
|
||||
}
|
||||
|
|
235
zokrates_core/src/static_analysis/propagate_unroll.rs
Normal file
235
zokrates_core/src/static_analysis/propagate_unroll.rs
Normal file
|
@ -0,0 +1,235 @@
|
|||
//! Module containing iterative unrolling in order to unroll nested loops with variable bounds
|
||||
//!
|
||||
//! For example:
|
||||
//! ```zokrates
|
||||
//! for field i in 0..5 do
|
||||
//! for field j in i..5 do
|
||||
//! //
|
||||
//! endfor
|
||||
//! endfor
|
||||
//! ```
|
||||
//!
|
||||
//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop
|
||||
//!
|
||||
//! This module does exactly that:
|
||||
//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant
|
||||
//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet*
|
||||
//! - unroll again, this time the 5 inner loops all have constant bounds
|
||||
//!
|
||||
//! In the case that a loop bound cannot be reduced to a constant, we detect it by noticing that the unroll does
|
||||
//! not make progress anymore.
|
||||
|
||||
use static_analysis::propagation::Propagator;
|
||||
use static_analysis::unroll::{Output, Unroller};
|
||||
use typed_absy::TypedProgram;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
pub struct PropagatedUnroller;
|
||||
|
||||
impl PropagatedUnroller {
|
||||
pub fn unroll<'ast, T: Field>(
|
||||
p: TypedProgram<'ast, T>,
|
||||
) -> Result<TypedProgram<'ast, T>, &'static str> {
|
||||
let mut blocked_at = None;
|
||||
|
||||
// unroll a first time, retrieving whether the unroll is complete
|
||||
let mut unrolled = Unroller::unroll(p);
|
||||
|
||||
loop {
|
||||
// conditions to exit the loop
|
||||
unrolled = match unrolled {
|
||||
Output::Complete(p) => return Ok(p),
|
||||
Output::Incomplete(next, index) => {
|
||||
if Some(index) == blocked_at {
|
||||
return Err("Loop unrolling failed. This happened because a loop bound is not constant");
|
||||
} else {
|
||||
// update the index where we blocked
|
||||
blocked_at = Some(index);
|
||||
|
||||
// propagate
|
||||
let propagated = Propagator::propagate_verbose(next);
|
||||
|
||||
// unroll
|
||||
Unroller::unroll(propagated)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use typed_absy::types::{FunctionKey, Signature};
|
||||
use typed_absy::*;
|
||||
use zokrates_field::field::FieldPrime;
|
||||
|
||||
#[test]
|
||||
fn detect_non_constant_bound() {
|
||||
let loops = vec![TypedStatement::For(
|
||||
Variable::field_element("i".into()),
|
||||
FieldElementExpression::Identifier("i".into()),
|
||||
FieldElementExpression::Number(FieldPrime::from(2)),
|
||||
vec![],
|
||||
)];
|
||||
|
||||
let statements = loops;
|
||||
|
||||
let p = TypedProgram {
|
||||
modules: vec![(
|
||||
"main".to_string(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
FunctionKey::with_id("main"),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
signature: Signature::new(),
|
||||
statements,
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
main: "main".to_string(),
|
||||
};
|
||||
|
||||
assert!(PropagatedUnroller::unroll(p).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn for_loop() {
|
||||
// for field i in 0..2
|
||||
// for field j in i..2
|
||||
// field foo = i + j
|
||||
|
||||
// should be unrolled to
|
||||
// i_0 = 0
|
||||
// j_0 = 0
|
||||
// foo_0 = i_0 + j_0
|
||||
// j_1 = 1
|
||||
// foo_1 = i_0 + j_1
|
||||
// i_1 = 1
|
||||
// j_2 = 1
|
||||
// foo_2 = i_1 + j_1
|
||||
|
||||
let s = TypedStatement::For(
|
||||
Variable::field_element("i".into()),
|
||||
FieldElementExpression::Number(FieldPrime::from(0)),
|
||||
FieldElementExpression::Number(FieldPrime::from(2)),
|
||||
vec![TypedStatement::For(
|
||||
Variable::field_element("j".into()),
|
||||
FieldElementExpression::Identifier("i".into()),
|
||||
FieldElementExpression::Number(FieldPrime::from(2)),
|
||||
vec![
|
||||
TypedStatement::Declaration(Variable::field_element("foo".into())),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("foo".into())),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier("i".into()),
|
||||
box FieldElementExpression::Identifier("j".into()),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
)],
|
||||
);
|
||||
|
||||
let expected_statements = vec![
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("i").version(0),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(0)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("j").version(0),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(0)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("foo").version(0),
|
||||
)),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
|
||||
box FieldElementExpression::Identifier(Identifier::from("j").version(0)),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("j").version(1),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("foo").version(1),
|
||||
)),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
|
||||
box FieldElementExpression::Identifier(Identifier::from("j").version(1)),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("i").version(1),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("j").version(2),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("foo").version(2),
|
||||
)),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("i").version(1)),
|
||||
box FieldElementExpression::Identifier(Identifier::from("j").version(2)),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
];
|
||||
|
||||
let p = TypedProgram {
|
||||
modules: vec![(
|
||||
"main".to_string(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
FunctionKey::with_id("main"),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
signature: Signature::new(),
|
||||
statements: vec![s],
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
main: "main".to_string(),
|
||||
};
|
||||
|
||||
let statements = match PropagatedUnroller::unroll(p).unwrap().modules["main"].functions
|
||||
[&FunctionKey::with_id("main")]
|
||||
.clone()
|
||||
{
|
||||
TypedFunctionSymbol::Here(f) => f.statements,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
assert_eq!(statements, expected_statements);
|
||||
}
|
||||
}
|
|
@ -1,5 +1,11 @@
|
|||
//! Module containing constant propagation for the typed AST
|
||||
//!
|
||||
//! On top of the usual behavior of removing statements which assign a constant to a variable (as the variable can simply be
|
||||
//! substituted for the constant whenever used), we provide a `verbose` mode which does not remove such statements. This is done
|
||||
//! as for partial passes which do not visit the whole program, the variables being defined may be be used in parts of the program
|
||||
//! that are not visited. Keeping the statements is semantically equivalent and enables rebuilding the set of constants at the
|
||||
//! next pass.
|
||||
//!
|
||||
//! @file propagation.rs
|
||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
@ -12,19 +18,36 @@ use typed_absy::types::{StructMember, Type};
|
|||
use zokrates_field::field::Field;
|
||||
|
||||
pub struct Propagator<'ast, T: Field> {
|
||||
// constants keeps track of constant expressions
|
||||
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
|
||||
constants: HashMap<TypedAssignee<'ast, T>, TypedExpression<'ast, T>>,
|
||||
// the verbose mode doesn't remove statements which assign constants to variables
|
||||
// it's required when using propagation in combination with unrolling
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Propagator<'ast, T> {
|
||||
fn verbose() -> Self {
|
||||
Propagator {
|
||||
constants: HashMap::new(),
|
||||
verbose: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn new() -> Self {
|
||||
Propagator {
|
||||
constants: HashMap::new(),
|
||||
verbose: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn propagate(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
Propagator::new().fold_program(p)
|
||||
}
|
||||
|
||||
pub fn propagate_verbose(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
Propagator::verbose().fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
|
||||
|
@ -63,8 +86,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
let expr = self.fold_expression(expr);
|
||||
|
||||
if is_constant(&expr) {
|
||||
self.constants.insert(TypedAssignee::Identifier(var), expr);
|
||||
None
|
||||
self.constants
|
||||
.insert(TypedAssignee::Identifier(var.clone()), expr.clone());
|
||||
match self.verbose {
|
||||
true => Some(TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
expr,
|
||||
)),
|
||||
false => None,
|
||||
}
|
||||
} else {
|
||||
Some(TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
|
@ -86,9 +116,17 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
self.fold_expression(e2),
|
||||
))
|
||||
}
|
||||
// we unrolled for loops in the previous step
|
||||
TypedStatement::For(..) => {
|
||||
unreachable!("for loop is unexpected, it should have been unrolled")
|
||||
// only loops with variable bounds are expected here
|
||||
// we stop propagation here as constants maybe be modified inside the loop body
|
||||
// which we do not visit
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let from = self.fold_field_expression(from);
|
||||
let to = self.fold_field_expression(to);
|
||||
|
||||
// invalidate the constants map as any constant could be modified inside the loop body, which we don't visit
|
||||
self.constants.clear();
|
||||
|
||||
Some(TypedStatement::For(v, from, to, statements))
|
||||
}
|
||||
TypedStatement::MultipleDefinition(variables, expression_list) => {
|
||||
let expression_list = self.fold_expression_list(expression_list);
|
||||
|
@ -98,6 +136,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
))
|
||||
}
|
||||
};
|
||||
|
||||
// In verbose mode, we always return a statement
|
||||
assert!(res.is_some() || !self.verbose);
|
||||
|
||||
match res {
|
||||
Some(v) => vec![v],
|
||||
None => vec![],
|
||||
|
|
|
@ -11,19 +11,30 @@ use std::collections::HashMap;
|
|||
use std::collections::HashSet;
|
||||
use zokrates_field::field::Field;
|
||||
|
||||
pub enum Output<'ast, T: Field> {
|
||||
Complete(TypedProgram<'ast, T>),
|
||||
Incomplete(TypedProgram<'ast, T>, usize),
|
||||
}
|
||||
|
||||
pub struct Unroller<'ast> {
|
||||
substitution: HashMap<Identifier<'ast>, usize>,
|
||||
// version index for any variable name
|
||||
substitution: HashMap<&'ast str, usize>,
|
||||
// whether all statements could be unrolled so far. Loops with variable bounds cannot.
|
||||
complete: bool,
|
||||
statement_count: usize,
|
||||
}
|
||||
|
||||
impl<'ast> Unroller<'ast> {
|
||||
fn new() -> Self {
|
||||
Unroller {
|
||||
substitution: HashMap::new(),
|
||||
complete: true,
|
||||
statement_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn issue_next_ssa_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> {
|
||||
let res = match self.substitution.get(&v.id) {
|
||||
let res = match self.substitution.get(&v.id.id) {
|
||||
Some(i) => Variable {
|
||||
id: Identifier {
|
||||
id: v.id.id,
|
||||
|
@ -34,15 +45,22 @@ impl<'ast> Unroller<'ast> {
|
|||
},
|
||||
None => Variable { ..v.clone() },
|
||||
};
|
||||
|
||||
self.substitution
|
||||
.entry(v.id)
|
||||
.entry(v.id.id)
|
||||
.and_modify(|e| *e += 1)
|
||||
.or_insert(0);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn unroll<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
|
||||
Unroller::new().fold_program(p)
|
||||
pub fn unroll<T: Field>(p: TypedProgram<T>) -> Output<T> {
|
||||
let mut unroller = Unroller::new();
|
||||
let p = unroller.fold_program(p);
|
||||
|
||||
match unroller.complete {
|
||||
true => Output::Complete(p),
|
||||
false => Output::Incomplete(p, unroller.statement_count),
|
||||
}
|
||||
}
|
||||
|
||||
fn choose_many<T: Field>(
|
||||
|
@ -322,6 +340,7 @@ fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec<Access<'a
|
|||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
self.statement_count += 1;
|
||||
match s {
|
||||
TypedStatement::Declaration(_) => vec![],
|
||||
TypedStatement::Definition(assignee, expr) => {
|
||||
|
@ -349,6 +368,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
};
|
||||
|
||||
let base = self.fold_expression(base);
|
||||
|
||||
let indices = indices
|
||||
.into_iter()
|
||||
.map(|a| match a {
|
||||
|
@ -378,34 +398,45 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
vec![TypedStatement::MultipleDefinition(variables, exprs)]
|
||||
}
|
||||
TypedStatement::For(v, from, to, stats) => {
|
||||
let mut values: Vec<T> = vec![];
|
||||
let mut current = from;
|
||||
while current < to {
|
||||
values.push(current.clone());
|
||||
current = T::one() + ¤t;
|
||||
let from = self.fold_field_expression(from);
|
||||
let to = self.fold_field_expression(to);
|
||||
|
||||
match (from, to) {
|
||||
(FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => {
|
||||
let mut values: Vec<T> = vec![];
|
||||
let mut current = from;
|
||||
while current < to {
|
||||
values.push(current.clone());
|
||||
current = T::one() + ¤t;
|
||||
}
|
||||
|
||||
let res = values
|
||||
.into_iter()
|
||||
.map(|index| {
|
||||
vec![
|
||||
vec![
|
||||
TypedStatement::Declaration(v.clone()),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(v.clone()),
|
||||
FieldElementExpression::Number(index).into(),
|
||||
),
|
||||
],
|
||||
stats.clone(),
|
||||
]
|
||||
.into_iter()
|
||||
.flat_map(|x| x)
|
||||
})
|
||||
.flat_map(|x| x)
|
||||
.flat_map(|x| self.fold_statement(x))
|
||||
.collect();
|
||||
|
||||
res
|
||||
}
|
||||
(from, to) => {
|
||||
self.complete = false;
|
||||
vec![TypedStatement::For(v, from, to, stats)]
|
||||
}
|
||||
}
|
||||
|
||||
let res = values
|
||||
.into_iter()
|
||||
.map(|index| {
|
||||
vec![
|
||||
vec![
|
||||
TypedStatement::Declaration(v.clone()),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(v.clone()),
|
||||
FieldElementExpression::Number(index).into(),
|
||||
),
|
||||
],
|
||||
stats.clone(),
|
||||
]
|
||||
.into_iter()
|
||||
.flat_map(|x| x)
|
||||
})
|
||||
.flat_map(|x| x)
|
||||
.flat_map(|x| self.fold_statement(x))
|
||||
.collect();
|
||||
|
||||
res
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
|
@ -414,7 +445,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
self.substitution = HashMap::new();
|
||||
for arg in &f.arguments {
|
||||
self.substitution.insert(arg.id.id.clone(), 0);
|
||||
self.substitution.insert(arg.id.id.id.clone(), 0);
|
||||
}
|
||||
|
||||
fold_function(self, f)
|
||||
|
@ -422,7 +453,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> {
|
|||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
Identifier {
|
||||
version: self.substitution.get(&n).unwrap_or(&0).clone(),
|
||||
version: self.substitution.get(&n.id).unwrap_or(&0).clone(),
|
||||
..n
|
||||
}
|
||||
}
|
||||
|
@ -693,8 +724,8 @@ mod tests {
|
|||
|
||||
let s = TypedStatement::For(
|
||||
Variable::field_element("i".into()),
|
||||
FieldPrime::from(2),
|
||||
FieldPrime::from(5),
|
||||
FieldElementExpression::Number(FieldPrime::from(2)),
|
||||
FieldElementExpression::Number(FieldPrime::from(5)),
|
||||
vec![
|
||||
TypedStatement::Declaration(Variable::field_element("foo".into())),
|
||||
TypedStatement::Definition(
|
||||
|
@ -748,6 +779,46 @@ mod tests {
|
|||
assert_eq!(u.fold_statement(s), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn idempotence() {
|
||||
// an already unrolled program should not be modified by unrolling again
|
||||
|
||||
// a = 5
|
||||
// a_1 = 6
|
||||
// a_2 = 7
|
||||
|
||||
// should be turned into
|
||||
// a = 5
|
||||
// a_1 = 6
|
||||
// a_2 = 7
|
||||
|
||||
let mut u = Unroller::new();
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(0),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(5)).into(),
|
||||
);
|
||||
assert_eq!(u.fold_statement(s.clone()), vec![s]);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(1),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(6)).into(),
|
||||
);
|
||||
assert_eq!(u.fold_statement(s.clone()), vec![s]);
|
||||
|
||||
let s = TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("a").version(2),
|
||||
)),
|
||||
FieldElementExpression::Number(FieldPrime::from(7)).into(),
|
||||
);
|
||||
assert_eq!(u.fold_statement(s.clone()), vec![s]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn definition() {
|
||||
// field a
|
||||
|
|
|
@ -229,7 +229,7 @@ impl<'ast, T: Field> fmt::Display for TypedFunction<'ast, T> {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"({}) -> ({}):\n{}",
|
||||
"({}) -> ({}):",
|
||||
self.arguments
|
||||
.iter()
|
||||
.map(|x| format!("{}", x))
|
||||
|
@ -241,12 +241,16 @@ impl<'ast, T: Field> fmt::Display for TypedFunction<'ast, T> {
|
|||
.map(|x| format!("{}", x))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
self.statements
|
||||
.iter()
|
||||
.map(|x| format!("\t{}", x))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
)
|
||||
)?;
|
||||
|
||||
writeln!(f, "")?;
|
||||
|
||||
for s in &self.statements {
|
||||
s.fmt_indented(f, 1)?;
|
||||
writeln!(f, "")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -326,7 +330,12 @@ pub enum TypedStatement<'ast, T: Field> {
|
|||
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
|
||||
Declaration(Variable<'ast>),
|
||||
Condition(TypedExpression<'ast, T>, TypedExpression<'ast, T>),
|
||||
For(Variable<'ast>, T, T, Vec<TypedStatement<'ast, T>>),
|
||||
For(
|
||||
Variable<'ast>,
|
||||
FieldElementExpression<'ast, T>,
|
||||
FieldElementExpression<'ast, T>,
|
||||
Vec<TypedStatement<'ast, T>>,
|
||||
),
|
||||
MultipleDefinition(Vec<Variable<'ast>>, TypedExpressionList<'ast, T>),
|
||||
}
|
||||
|
||||
|
@ -364,6 +373,23 @@ impl<'ast, T: Field> fmt::Debug for TypedStatement<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TypedStatement<'ast, T> {
|
||||
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||
match self {
|
||||
TypedStatement::For(variable, from, to, statements) => {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
writeln!(f, "for {} in {}..{} do", variable, from, to)?;
|
||||
for s in statements {
|
||||
s.fmt_indented(f, depth + 1)?;
|
||||
writeln!(f, "")?;
|
||||
}
|
||||
writeln!(f, "{}endfor", "\t".repeat(depth))
|
||||
}
|
||||
s => write!(f, "{}{}", "\t".repeat(depth), s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> fmt::Display for TypedStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
|
|
2
zokrates_core_test/tests/tests/import/dep/dep/foo.zok
Normal file
2
zokrates_core_test/tests/tests/import/dep/dep/foo.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def foo() -> (field):
|
||||
return 1
|
4
zokrates_core_test/tests/tests/import/dep/foo.zok
Normal file
4
zokrates_core_test/tests/tests/import/dep/foo.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
from "./dep/foo.zok" import foo as bar
|
||||
|
||||
def foo() -> (field):
|
||||
return 2 + bar()
|
15
zokrates_core_test/tests/tests/import/import.json
Normal file
15
zokrates_core_test/tests/tests/import/import.json
Normal file
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/import/import.zok",
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": []
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["3"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
4
zokrates_core_test/tests/tests/import/import.zok
Normal file
4
zokrates_core_test/tests/tests/import/import.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
from "./dep/foo.zok" import foo
|
||||
|
||||
def main() -> (field):
|
||||
return foo()
|
25
zokrates_core_test/tests/tests/nested_loop.json
Normal file
25
zokrates_core_test/tests/tests/nested_loop.json
Normal file
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/nested_loop.zok",
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "2", "3", "4"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["4838400", "10", "25"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "1", "2", "3"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"values": ["0", "10", "25"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
27
zokrates_core_test/tests/tests/nested_loop.zok
Normal file
27
zokrates_core_test/tests/tests/nested_loop.zok
Normal file
|
@ -0,0 +1,27 @@
|
|||
def main(field[4] values) -> (field, field, field):
|
||||
field res0 = 1
|
||||
field res1 = 0
|
||||
|
||||
field counter = 0
|
||||
|
||||
for field i in 0..4 do
|
||||
for field j in i..4 do
|
||||
counter = counter + 1
|
||||
res0 = res0 * (values[i] + values[j])
|
||||
endfor
|
||||
endfor
|
||||
|
||||
for field i in 0..counter do
|
||||
res1 = res1 + 1
|
||||
endfor
|
||||
|
||||
field res2 = 0
|
||||
field i = 0
|
||||
for field i in i..5 do
|
||||
i = 5
|
||||
for field i in 0..i do
|
||||
res2 = res2 + 1
|
||||
endfor
|
||||
endfor
|
||||
|
||||
return res0, res1, res2
|
|
@ -6,20 +6,28 @@ use std::path::{Component, PathBuf};
|
|||
|
||||
const ZOKRATES_HOME: &str = &"ZOKRATES_HOME";
|
||||
|
||||
type CurrentLocation = String;
|
||||
type ImportLocation<'a> = String;
|
||||
type CurrentLocation = PathBuf;
|
||||
type ImportLocation<'a> = PathBuf;
|
||||
type SourceCode = String;
|
||||
|
||||
pub fn resolve<'a>(
|
||||
current_location: CurrentLocation,
|
||||
import_location: ImportLocation<'a>,
|
||||
) -> Result<(SourceCode, CurrentLocation), io::Error> {
|
||||
println!(
|
||||
"get file {} {}",
|
||||
current_location.display(),
|
||||
import_location.display()
|
||||
);
|
||||
|
||||
let source = Path::new(&import_location);
|
||||
|
||||
// paths starting with `./` or `../` are interpreted relative to the current file
|
||||
// other paths `abc/def` are interpreted relative to $ZOKRATES_HOME
|
||||
let base = match source.components().next() {
|
||||
Some(Component::CurDir) | Some(Component::ParentDir) => PathBuf::from(current_location),
|
||||
Some(Component::CurDir) | Some(Component::ParentDir) => {
|
||||
PathBuf::from(current_location).parent().unwrap().into()
|
||||
}
|
||||
_ => PathBuf::from(
|
||||
std::env::var(ZOKRATES_HOME).expect("$ZOKRATES_HOME is not set, please set it"),
|
||||
),
|
||||
|
@ -33,16 +41,9 @@ pub fn resolve<'a>(
|
|||
return Err(io::Error::new(io::ErrorKind::Other, "Not a file"));
|
||||
}
|
||||
|
||||
let next_location = generate_next_location(&path_owned)?;
|
||||
let source = read_to_string(path_owned)?;
|
||||
let source = read_to_string(&path_owned)?;
|
||||
|
||||
Ok((source, next_location))
|
||||
}
|
||||
|
||||
fn generate_next_location<'a>(path: &'a PathBuf) -> Result<String, io::Error> {
|
||||
path.parent()
|
||||
.ok_or(io::Error::new(io::ErrorKind::Other, "Invalid path"))
|
||||
.map(|v| v.to_path_buf().into_os_string().into_string().unwrap())
|
||||
Ok((source, path_owned))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -77,17 +77,7 @@ pub fn test_inner(test_path: &str) {
|
|||
|
||||
let code = std::fs::read_to_string(&t.entry_point).unwrap();
|
||||
|
||||
let artifacts = compile(
|
||||
code,
|
||||
t.entry_point
|
||||
.parent()
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
Some(&resolve),
|
||||
)
|
||||
.unwrap();
|
||||
let artifacts = compile(code, t.entry_point.clone(), Some(&resolve)).unwrap();
|
||||
|
||||
let bin = artifacts.prog();
|
||||
|
||||
|
|
Loading…
Reference in a new issue