diff --git a/Cargo.lock b/Cargo.lock index 871121e2..f7e7e752 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1074,6 +1074,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + [[package]] name = "hashbrown" version = "0.11.2" @@ -1845,6 +1851,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.130" @@ -2383,7 +2399,6 @@ name = "zokrates_cli" version = "0.7.7" dependencies = [ "assert_cli", - "bincode", "cfg-if 0.1.10", "clap", "dirs", @@ -2393,8 +2408,11 @@ dependencies = [ "lazy_static", "log", "regex 0.2.11", + "serde", + "serde_cbor", "serde_json", "tempdir", + "typed-arena", "zokrates_abi", "zokrates_core", "zokrates_field", @@ -2545,6 +2563,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "typed-arena", "zokrates_abi", "zokrates_core", "zokrates_field", diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 5c2b04cc..592aef14 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -16,13 +16,15 @@ log = "0.4" env_logger = "0.9.0" cfg-if = "0.1" clap = "2.26.2" -bincode = "0.8.0" +serde_cbor = "0.11.2" regex = "0.2" zokrates_field = { version = "0.4", path = "../zokrates_field", default-features = false } zokrates_abi = { version = "0.1", path = "../zokrates_abi" } zokrates_core = { version = "0.6", path = "../zokrates_core", default-features = false } +typed-arena = "1.4.1" zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"} serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } dirs = "3.0.1" lazy_static = "1.4.0" diff --git a/zokrates_cli/src/ops/compile.rs b/zokrates_cli/src/ops/compile.rs index cfe17781..8fd4071f 100644 --- a/zokrates_cli/src/ops/compile.rs +++ b/zokrates_cli/src/ops/compile.rs @@ -6,10 +6,35 @@ use std::convert::TryFrom; use std::fs::File; use std::io::{BufReader, BufWriter, Read, Write}; use std::path::{Path, PathBuf}; +use typed_arena::Arena; use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig, CompileError}; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_fs_resolver::FileSystemResolver; +use serde::{Serialize, Serializer}; +use std::cell::Cell; + +pub fn write_as_cbor(out: &mut W, groups: I) -> serde_cbor::Result<()> +where + I: IntoIterator, + P: Serialize, + W: Write, +{ + struct Wrapper(Cell>); + + impl Serialize for Wrapper + where + I: IntoIterator, + P: Serialize, + { + fn serialize(&self, s: S) -> Result { + s.collect_seq(self.0.take().unwrap()) + } + } + + serde_cbor::to_writer(out, &Wrapper(Cell::new(Some(groups)))) +} + pub fn subcommand() -> App<'static, 'static> { SubCommand::with_name("compile") .about("Compiles into flattened conditions. Produces two files: human-readable '.ztf' file for debugging and binary file") @@ -136,8 +161,10 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { log::debug!("Compile"); - let artifacts: CompilationArtifacts = compile(source, path, Some(&resolver), &config) - .map_err(|e| { + let arena = Arena::new(); + + let artifacts = + compile::(source, path, Some(&resolver), config, &arena).map_err(|e| { format!( "Compilation failed:\n\n{}", e.0.iter() @@ -147,10 +174,11 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { ) })?; - let program_flattened = artifacts.prog(); + let program_flattened = artifacts.prog; + let abi = artifacts.abi; // number of constraints the flattened program will translate to. - let num_constraints = program_flattened.constraint_count(); + //let num_constraints = program_flattened.constraint_count(); // serialize flattened program and write to binary file log::debug!("Serialize program"); @@ -159,21 +187,22 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { let mut writer = BufWriter::new(bin_output_file); - program_flattened.serialize(&mut writer); + println!("START WRITING TO DISK"); + + //program_flattened.serialize(&mut writer); + write_as_cbor(&mut writer, program_flattened.statements).unwrap(); // serialize ABI spec and write to JSON file log::debug!("Serialize ABI"); let abi_spec_file = File::create(&abi_spec_path) .map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?; - let abi = artifacts.abi(); - let mut writer = BufWriter::new(abi_spec_file); to_writer_pretty(&mut writer, &abi).map_err(|_| "Unable to write data to file.".to_string())?; if sub_matches.is_present("verbose") { // debugging output - println!("Compiled program:\n{}", program_flattened); + //println!("Compiled program:\n{}", program_flattened); } println!("Compiled code written to '{}'", bin_output_path.display()); @@ -185,15 +214,15 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { .map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?; let mut hrofb = BufWriter::new(hr_output_file); - writeln!(&mut hrofb, "{}", program_flattened) - .map_err(|_| "Unable to write data to file".to_string())?; - hrofb - .flush() - .map_err(|_| "Unable to flush buffer".to_string())?; + // writeln!(&mut hrofb, "{}", program_flattened) + // .map_err(|_| "Unable to write data to file".to_string())?; + // hrofb + // .flush() + // .map_err(|_| "Unable to flush buffer".to_string())?; println!("Human readable code to '{}'", hr_output_path.display()); } - println!("Number of constraints: {}", num_constraints); + //println!("Number of constraints: {}", num_constraints); Ok(()) } diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index 675ef77d..1f917021 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -106,7 +106,7 @@ fn cli_compute(ir_prog: ir::Prog, sub_matches: &ArgMatches) -> Resu } false => ConcreteSignature::new() .inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()]) - .outputs(vec![ConcreteType::FieldElement; ir_prog.returns.len()]), + .outputs(vec![ConcreteType::FieldElement; ir_prog.return_count]), }; use zokrates_abi::Inputs; diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 86a9943e..f1fad4dc 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -4,7 +4,7 @@ //! @author Thibaut Schaeffer //! @date 2018 use crate::absy::{Module, OwnedModuleId, Program}; -use crate::flatten::Flattener; +use crate::flatten::FlattenerIterator; use crate::imports::{self, Importer}; use crate::ir; use crate::macros; @@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; use std::io; +use std::io::Write; use std::path::{Path, PathBuf}; use typed_arena::Arena; use zokrates_common::Resolver; @@ -25,14 +26,14 @@ use zokrates_field::Field; use zokrates_pest_ast as pest; #[derive(Debug)] -pub struct CompilationArtifacts { - prog: ir::Prog, - abi: Abi, +pub struct CompilationArtifacts { + pub prog: ir::ProgIterator, + pub abi: Abi, } -impl CompilationArtifacts { - pub fn prog(&self) -> &ir::Prog { - &self.prog +impl CompilationArtifacts { + pub fn prog(self) -> ir::ProgIterator { + self.prog } pub fn abi(&self) -> &Abi { @@ -183,37 +184,46 @@ impl CompileConfig { type FilePath = PathBuf; -pub fn compile>( +pub fn compile<'ast, T: Field, E: Into>( source: String, location: FilePath, resolver: Option<&dyn Resolver>, - config: &CompileConfig, -) -> Result, CompileErrors> { - let arena = Arena::new(); - - let (typed_ast, abi) = check_with_arena(source, location.clone(), resolver, config, &arena)?; + config: CompileConfig, + arena: &'ast Arena, +) -> Result> + 'ast>, CompileErrors> { + let (typed_ast, abi): (crate::zir::ZirProgram<'_, T>, _) = + check_with_arena(source, location.clone(), resolver, &config, arena)?; // flatten input program log::debug!("Flatten"); - let program_flattened = Flattener::flatten(typed_ast, config); + let program_flattened = FlattenerIterator::from_function_and_config(typed_ast.main, config); - // constant propagation after call resolution - log::debug!("Propagate flat program"); - let program_flattened = program_flattened.propagate(); + // // constant propagation after call resolution + // log::debug!("Propagate flat program"); + // let program_flattened = program_flattened.propagate(); // convert to ir log::debug!("Convert to IR"); - let ir_prog = ir::Prog::from(program_flattened); + let ir_prog = ir::ProgIterator { + arguments: program_flattened + .arguments_flattened + .clone() + .into_iter() + .map(|a| a.into()) + .collect(), + return_count: 0, + statements: ir::from_flat::from_flat(program_flattened), + }; // optimize log::debug!("Optimise IR"); let optimized_ir_prog = ir_prog.optimize(); // analyse ir (check constraints) - log::debug!("Analyse IR"); - let optimized_ir_prog = optimized_ir_prog - .analyse() - .map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?; + // log::debug!("Analyse IR"); + // let optimized_ir_prog = optimized_ir_prog + // .analyse() + // .map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?; Ok(CompilationArtifacts { prog: optimized_ir_prog, diff --git a/zokrates_core/src/flat_absy/mod.rs b/zokrates_core/src/flat_absy/mod.rs index 4f73c126..457b2150 100644 --- a/zokrates_core/src/flat_absy/mod.rs +++ b/zokrates_core/src/flat_absy/mod.rs @@ -96,32 +96,24 @@ impl fmt::Display for RuntimeError { } } -#[derive(Clone, PartialEq)] -pub struct FlatProg { - /// FlatFunctions of the program - pub main: FlatFunction, -} - -impl fmt::Display for FlatProg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.main) - } -} - -impl fmt::Debug for FlatProg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "flat_program(main: {}\t)", self.main) - } -} +pub type FlatProg = FlatFunction; #[derive(Clone, PartialEq)] -pub struct FlatFunction { +pub struct FlatFunction { /// Arguments of the function pub arguments: Vec, /// Vector of statements that are executed when running the function pub statements: Vec>, } +pub type FlatProgIterator = FlatFunctionIterator; +pub struct FlatFunctionIterator { + /// Arguments of the function + pub arguments: Vec, + /// Vector of statements that are executed when running the function + pub statements: I, +} + impl fmt::Display for FlatFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -167,7 +159,7 @@ impl fmt::Debug for FlatFunction { /// * r1cs - R1CS in standard JSON data format #[derive(Clone, PartialEq)] -pub enum FlatStatement { +pub enum FlatStatement { Return(FlatExpressionList), Condition(FlatExpression, FlatExpression, RuntimeError), Definition(FlatVariable, FlatExpression), @@ -239,7 +231,7 @@ impl FlatStatement { } #[derive(Clone, Hash, Debug, PartialEq, Eq)] -pub struct FlatDirective { +pub struct FlatDirective { pub inputs: Vec>, pub outputs: Vec, pub solver: Solver, diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index ba73cc41..adf84e1c 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -16,17 +16,67 @@ use crate::flat_absy::{RuntimeError, *}; use crate::solvers::Solver; use crate::zir::types::{Type, UBitwidth}; use crate::zir::*; -use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{ + hash_map::{Entry, HashMap}, + VecDeque, +}; use std::convert::TryFrom; use zokrates_field::Field; -type FlatStatements = Vec>; +type FlatStatements = VecDeque>; + +/// Flattens a function +/// +/// # Arguments +/// * `funct` - `ZirFunction` that will be flattened +impl<'ast, T: Field> FlattenerIterator<'ast, T> { + pub fn from_function_and_config(funct: ZirFunction<'ast, T>, config: CompileConfig) -> Self { + let mut flattener = Flattener::new(config); + let mut statements_flattened = FlatStatements::new(); + // push parameters + let arguments_flattened = funct + .arguments + .into_iter() + .map(|p| flattener.use_parameter(&p, &mut statements_flattened)) + .collect(); + + FlattenerIterator { + statements: funct.statements.into(), + arguments_flattened, + statements_flattened, + flattener, + } + } +} + +pub struct FlattenerIterator<'ast, T> { + pub statements: VecDeque>, + pub arguments_flattened: Vec, + pub statements_flattened: FlatStatements, + pub flattener: Flattener<'ast, T>, +} + +impl<'ast, T: Field> Iterator for FlattenerIterator<'ast, T> { + type Item = FlatStatement; + + fn next(&mut self) -> Option { + if self.statements_flattened.is_empty() { + match self.statements.pop_front() { + Some(s) => { + self.flattener + .flatten_statement(&mut self.statements_flattened, s); + } + None => {} + } + } + self.statements_flattened.pop_front() + } +} /// Flattener, computes flattened program. #[derive(Debug)] -pub struct Flattener<'ast, T: Field> { - config: &'ast CompileConfig, +pub struct Flattener<'ast, T> { + config: CompileConfig, /// Index of the next introduced variable while processing the program. next_var_idx: usize, /// `FlatVariable`s corresponding to each `Identifier` @@ -156,13 +206,8 @@ impl From for RuntimeError { } impl<'ast, T: Field> Flattener<'ast, T> { - pub fn flatten(p: ZirProgram<'ast, T>, config: &CompileConfig) -> FlatProg { - Flattener::new(config).flatten_program(p) - } - /// Returns a `Flattener` with fresh `layout`. - - fn new(config: &'ast CompileConfig) -> Flattener<'ast, T> { + fn new(config: CompileConfig) -> Flattener<'ast, T> { Flattener { config, next_var_idx: 0, @@ -182,7 +227,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Identifier(id) => id, e => { let res = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(res, e)); + statements_flattened.push_back(FlatStatement::Definition(res, e)); res } } @@ -241,7 +286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { } // init size_unknown = true - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( size_unknown[0], FlatExpression::Number(T::from(1)), )); @@ -250,14 +295,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { for (i, b) in b.iter().enumerate() { if *b { - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( is_not_smaller_run[i], a[i].into(), )); // don't need to update size_unknown in the last round if i < len - 1 { - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( size_unknown[i + 1], FlatExpression::Mult( box size_unknown[i].into(), @@ -268,7 +313,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { } else { // don't need to update size_unknown in the last round if i < len - 1 { - statements_flattened.push( + statements_flattened.push_back( // sizeUnknown is not changing in this case // We sill have to assign the old value to the variable of the current run // This trivial definition will later be removed by the optimiser @@ -285,7 +330,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let and_name = self.use_sym(); let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone()); - statements_flattened.push(FlatStatement::Definition(and_name, and)); + statements_flattened.push_back(FlatStatement::Definition(and_name, and)); let or = FlatExpression::Sub( box FlatExpression::Add(box or_left, box or_right), box and_name.into(), @@ -310,7 +355,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise fn eq_check( &mut self, - statements_flattened: &mut Vec>, + statements_flattened: &mut FlatStatements, left: FlatExpression, right: FlatExpression, ) -> FlatExpression { @@ -329,12 +374,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name_y = self.use_sym(); let name_m = self.use_sym(); - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( vec![name_y, name_m], Solver::ConditionEq, vec![x.clone()], ))); - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(name_y), FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), RuntimeError::Equal, @@ -345,7 +390,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { box FlatExpression::Identifier(name_y), ); - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Number(T::zero()), FlatExpression::Mult(box res.clone(), box x), RuntimeError::Equal, @@ -376,7 +421,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .fold(FlatExpression::from(T::zero()), |acc, e| { FlatExpression::Add(box acc, box e) }); - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Number(T::from(0)), FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()), RuntimeError::Le, @@ -385,14 +430,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn make_conditional( &mut self, - statements: Vec>, + statements: FlatStatements, condition: FlatExpression, - ) -> Vec> { + ) -> FlatStatements { statements .into_iter() .flat_map(|s| match s { FlatStatement::Condition(left, right, message) => { - let mut output = vec![]; + let mut output = VecDeque::new(); // we transform (a == b) into (c => (a == b)) which is (!c || (a == b)) @@ -410,12 +455,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); - output.push(FlatStatement::Directive(FlatDirective { + output.push_back(FlatStatement::Directive(FlatDirective { solver: Solver::Or, outputs: vec![name_x_or_y], inputs: vec![x.clone(), y.clone()], })); - output.push(FlatStatement::Condition( + output.push_back(FlatStatement::Condition( FlatExpression::Add( box x.clone(), box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), @@ -423,7 +468,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Mult(box x.clone(), box y.clone()), RuntimeError::BranchIsolation, )); - output.push(FlatStatement::Condition( + output.push_back(FlatStatement::Condition( name_x_or_y.into(), T::one().into(), message, @@ -431,7 +476,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { output } - s => vec![s], + s => VecDeque::from([s]), }) .collect() } @@ -457,16 +502,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.flatten_boolean_expression(statements_flattened, condition.clone()); let condition_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); + statements_flattened.push_back(FlatStatement::Definition(condition_id, condition_flat)); self.condition_cache.insert(condition, condition_id); let (consequence, alternative) = if self.config.isolate_branches { - let mut consequence_statements = vec![]; + let mut consequence_statements = VecDeque::new(); let consequence = consequence.flatten(self, &mut consequence_statements); - let mut alternative_statements = vec![]; + let mut alternative_statements = VecDeque::new(); let alternative = alternative.flatten(self, &mut alternative_statements); @@ -495,13 +540,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { let alternative = alternative.flat(); let consequence_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(consequence_id, consequence)); + statements_flattened.push_back(FlatStatement::Definition(consequence_id, consequence)); let alternative_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(alternative_id, alternative)); + statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative)); let term0_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( term0_id, FlatExpression::Mult( box condition_id.into(), @@ -510,7 +555,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { )); let term1_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( term1_id, FlatExpression::Mult( box FlatExpression::Sub( @@ -522,7 +567,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { )); let res = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( res, FlatExpression::Add( box FlatExpression::from(term0_id), @@ -582,7 +627,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let e_bits_be: Vec = (0..bit_width).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( e_bits_be.clone(), Solver::bits(bit_width), vec![e_id], @@ -590,7 +635,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // bitness checks for bit in e_bits_be.iter().take(bit_width) { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(*bit), FlatExpression::Mult( box FlatExpression::Identifier(*bit), @@ -613,7 +658,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(e_id), e_sum, RuntimeError::ConstantLtSum, @@ -704,7 +749,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..safe_width).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive( + statements_flattened.push_back(FlatStatement::Directive( FlatDirective::new( lhs_bits_be.clone(), Solver::bits(safe_width), @@ -714,7 +759,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // bitness checks for bit in lhs_bits_be.iter().take(safe_width) { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(*bit), FlatExpression::Mult( box FlatExpression::Identifier(*bit), @@ -739,7 +784,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(lhs_id), lhs_sum, RuntimeError::LtSum, @@ -756,7 +801,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..safe_width).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive( + statements_flattened.push_back(FlatStatement::Directive( FlatDirective::new( rhs_bits_be.clone(), Solver::bits(safe_width), @@ -766,7 +811,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // bitness checks for bit in rhs_bits_be.iter().take(safe_width) { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(*bit), FlatExpression::Mult( box FlatExpression::Identifier(*bit), @@ -791,7 +836,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(rhs_id), rhs_sum, RuntimeError::LtSum, @@ -815,15 +860,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..bit_width).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( - sub_bits_be.clone(), - Solver::bits(bit_width), - vec![subtraction_result.clone()], - ))); + statements_flattened.push_back(FlatStatement::Directive( + FlatDirective::new( + sub_bits_be.clone(), + Solver::bits(bit_width), + vec![subtraction_result.clone()], + ), + )); // bitness checks for bit in sub_bits_be.iter().take(bit_width) { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(*bit), FlatExpression::Mult( box FlatExpression::Identifier(*bit), @@ -853,7 +900,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( subtraction_result, expr, RuntimeError::LtFinalSum, @@ -885,7 +932,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let x_sub_y = FlatExpression::Sub(box x, box y); let name_x_mult_x = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( name_x_mult_x, FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y), )); @@ -972,7 +1019,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { (0..bit_width).map(|_| self.use_sym()).collect(); // add a directive to get the bits - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( sub_bits_be.clone(), Solver::bits(bit_width), vec![subtraction_result.clone()], @@ -980,7 +1027,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // bitness checks for bit in sub_bits_be.iter().take(bit_width) { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Identifier(*bit), FlatExpression::Mult( box FlatExpression::Identifier(*bit), @@ -1010,7 +1057,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( subtraction_result, expr, RuntimeError::LtFinalSum, @@ -1042,12 +1089,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { let y = self.flatten_boolean_expression(statements_flattened, rhs); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); - statements_flattened.push(FlatStatement::Directive(FlatDirective { + statements_flattened.push_back(FlatStatement::Directive(FlatDirective { solver: Solver::Or, outputs: vec![name_x_or_y], inputs: vec![x.clone(), y.clone()], })); - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Add( box x.clone(), box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), @@ -1063,7 +1110,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name_x_and_y = self.use_sym(); assert!(x.is_linear() && y.is_linear()); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( name_x_and_y, FlatExpression::Mult(box x, box y), )); @@ -1373,14 +1420,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let d = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; @@ -1388,14 +1435,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { let invd = self.use_sym(); // # invd = 1/d - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( vec![invd], Solver::Div, vec![FlatExpression::Number(T::one()), d.clone()], ))); // assert(invd * d == 1) - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Number(T::one()), FlatExpression::Mult(box invd.into(), box d.clone()), RuntimeError::Inverse, @@ -1405,7 +1452,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let q = self.use_sym(); let r = self.use_sym(); - statements_flattened.push(FlatStatement::Directive(FlatDirective { + statements_flattened.push_back(FlatStatement::Directive(FlatDirective { inputs: vec![n.clone(), d.clone()], outputs: vec![q, r], solver: Solver::EuclideanDiv, @@ -1439,7 +1486,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); // q*d == n - r - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Sub(box n, box r.into()), FlatExpression::Mult(box q.into(), box d), RuntimeError::Euclidean, @@ -1520,14 +1567,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; @@ -1550,14 +1597,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; @@ -1612,20 +1659,20 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; let res = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( res, FlatExpression::Mult(box new_left, box new_right), )); @@ -1974,7 +2021,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { } Entry::Vacant(_) => { let bits = (0..from).map(|_| self.use_sym()).collect::>(); - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( bits.clone(), Solver::Bits(from), vec![e.field.clone().unwrap()], @@ -1996,7 +2043,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let sum = flat_expression_from_bits(bits.clone()); // sum check - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( e.field.clone().unwrap(), sum.clone(), RuntimeError::Sum, @@ -2055,7 +2102,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { range_check = FlatExpression::Add(box range_check, box condition.clone()); let conditional_element_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( conditional_element_id, FlatExpression::Mult(box condition, box element.flat()), )); @@ -2065,7 +2112,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { }, ); - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( range_check, FlatExpression::Number(T::one()), RuntimeError::SelectRangeCheck, @@ -2099,14 +2146,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; FlatExpression::Add(box new_left, box new_right) @@ -2119,14 +2166,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; @@ -2139,14 +2186,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); FlatExpression::Identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); FlatExpression::Identifier(id) }; FlatExpression::Mult(box new_left, box new_right) @@ -2156,12 +2203,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { let right_flattened = self.flatten_field_expression(statements_flattened, right); let new_left: FlatExpression = { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); id.into() }; let new_right: FlatExpression = { let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); id.into() }; @@ -2169,28 +2216,28 @@ impl<'ast, T: Field> Flattener<'ast, T> { let inverse = self.use_sym(); // # invb = 1/b - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( vec![invb], Solver::Div, vec![FlatExpression::Number(T::one()), new_right.clone()], ))); // assert(invb * b == 1) - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Number(T::one()), FlatExpression::Mult(box invb.into(), box new_right.clone()), RuntimeError::Inverse, )); // # c = a/b - statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( vec![inverse], Solver::Div, vec![new_left.clone(), new_right.clone()], ))); // assert(c * b == a) - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( new_left, FlatExpression::Mult(box new_right, box inverse.into()), RuntimeError::Division, @@ -2239,7 +2286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // introduce a new variable let id = self.use_sym(); // set it to the square of the previous one, stored in state - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( id, FlatExpression::Mult( box previous.clone(), @@ -2262,7 +2309,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { true => { // update the result by introducing a new variable let id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( id, FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power )); @@ -2305,7 +2352,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|x| x.get_field_unchecked()) .collect::>(); - statements_flattened.push(FlatStatement::Return(FlatExpressionList { + statements_flattened.push_back(FlatStatement::Return(FlatExpressionList { expressions: flat_expressions, })); } @@ -2314,13 +2361,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.flatten_boolean_expression(statements_flattened, condition.clone()); let condition_id = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); + statements_flattened + .push_back(FlatStatement::Definition(condition_id, condition_flat)); self.condition_cache.insert(condition, condition_id); if self.config.isolate_branches { - let mut consequence_statements = vec![]; - let mut alternative_statements = vec![]; + let mut consequence_statements = VecDeque::new(); + let mut alternative_statements = VecDeque::new(); consequence .into_iter() @@ -2367,7 +2415,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let var = self.use_variable(&assignee); // handle return of function call - statements_flattened.push(FlatStatement::Definition(var, e)); + statements_flattened.push_back(FlatStatement::Definition(var, e)); var } @@ -2431,14 +2479,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { let e = self.flatten_boolean_expression(statements_flattened, e); if e.is_linear() { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( e, FlatExpression::Number(T::from(1)), error.into(), )); } else { // swap so that left side is linear - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( FlatExpression::Number(T::from(1)), e, error.into(), @@ -2474,7 +2522,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { } e => { let id = self.use_variable(&v); - statements_flattened.push(FlatStatement::Definition(id, e)); + statements_flattened + .push_back(FlatStatement::Definition(id, e)); id } }) @@ -2502,45 +2551,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } - /// Flattens a function - /// - /// # Arguments - /// * `funct` - `ZirFunction` that will be flattened - fn flatten_function(&mut self, funct: ZirFunction<'ast, T>) -> FlatFunction { - self.layout = HashMap::new(); - - self.next_var_idx = 0; - let mut statements_flattened: FlatStatements = FlatStatements::new(); - - // push parameters - let arguments_flattened = funct - .arguments - .into_iter() - .map(|p| self.use_parameter(&p, &mut statements_flattened)) - .collect(); - - // flatten statements in functions and apply substitution - for stat in funct.statements { - self.flatten_statement(&mut statements_flattened, stat); - } - - FlatFunction { - arguments: arguments_flattened, - statements: statements_flattened, - } - } - - /// Flattens a program - /// - /// # Arguments - /// - /// * `prog` - `ZirProgram` that will be flattened. - fn flatten_program(&mut self, prog: ZirProgram<'ast, T>) -> FlatProg { - FlatProg { - main: self.flatten_function(prog.main), - } - } - /// Flattens an equality assertion, enforcing it in the circuit. /// /// # Arguments @@ -2571,7 +2581,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), ), }; - statements_flattened.push(FlatStatement::Condition(lhs, rhs, error)); + statements_flattened.push_back(FlatStatement::Condition(lhs, rhs, error)); } /// Identifies a non-linear expression by assigning it to a new identifier. @@ -2589,7 +2599,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { true => e, false => { let sym = self.use_sym(); - statements_flattened.push(FlatStatement::Definition(sym, e)); + statements_flattened.push_back(FlatStatement::Definition(sym, e)); FlatExpression::Identifier(sym) } } @@ -2638,7 +2648,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } Type::Boolean => { - statements_flattened.push(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::Condition( variable.into(), FlatExpression::Mult(box variable.into(), box variable.into()), RuntimeError::ArgumentBitness, @@ -2649,7 +2659,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // we insert dummy condition statement for private field elements // to avoid unconstrained variables // translates to y == x * x - statements_flattened.push(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::Definition( self.use_sym(), FlatExpression::Mult(box variable.into(), box variable.into()), )); diff --git a/zokrates_core/src/ir/folder.rs b/zokrates_core/src/ir/folder.rs index 6abb7b86..b83fe2d7 100644 --- a/zokrates_core/src/ir/folder.rs +++ b/zokrates_core/src/ir/folder.rs @@ -46,7 +46,7 @@ pub fn fold_module>(f: &mut F, p: Prog) -> Prog { .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - returns: p.returns.into_iter().map(|v| f.fold_variable(v)).collect(), + return_count: p.return_count, } } diff --git a/zokrates_core/src/ir/from_flat.rs b/zokrates_core/src/ir/from_flat.rs index 9d8aff4e..92a9c523 100644 --- a/zokrates_core/src/ir/from_flat.rs +++ b/zokrates_core/src/ir/from_flat.rs @@ -1,5 +1,5 @@ -use crate::flat_absy::{FlatDirective, FlatExpression, FlatProg, FlatStatement, FlatVariable}; -use crate::ir::{Directive, LinComb, Prog, QuadComb, Statement}; +use crate::flat_absy::{FlatDirective, FlatExpression, FlatStatement, FlatVariable}; +use crate::ir::{Directive, LinComb, QuadComb, Statement}; use zokrates_field::Field; impl QuadComb { @@ -17,50 +17,25 @@ impl QuadComb { } } -impl From> for Prog { - fn from(flat_prog: FlatProg) -> Prog { - // get the main function - let main = flat_prog.main; - - let return_expressions: Vec> = main - .statements - .iter() - .filter_map(|s| match s { - FlatStatement::Return(el) => Some(el.expressions.clone()), - _ => None, - }) - .next() - .unwrap(); - - Prog { - arguments: main.arguments, - returns: return_expressions - .iter() - .enumerate() - .map(|(index, _)| FlatVariable::public(index)) - .collect(), - statements: main - .statements - .into_iter() - .filter_map(|s| match s { - FlatStatement::Return(..) => None, - s => Some(s.into()), - }) - .chain( - return_expressions - .into_iter() - .enumerate() - .map(|(index, expression)| { - Statement::Constraint( - QuadComb::from_flat_expression(expression), - FlatVariable::public(index).into(), - None, - ) - }), - ) - .collect(), - } - } +pub fn from_flat<'ast, T: Field, I: Iterator>>( + flat_prog_iterator: I, +) -> impl Iterator> { + flat_prog_iterator.filter_map(|s| match s { + FlatStatement::Return(..) => None, + s => Some(s.into()), + }) + // .chain( + // return_expressions + // .into_iter() + // .enumerate() + // .map(|(index, expression)| { + // Statement::Constraint( + // QuadComb::from_flat_expression(expression), + // FlatVariable::public(index).into(), + // None, + // ) + // }), + // ) } impl From> for LinComb { diff --git a/zokrates_core/src/ir/mod.rs b/zokrates_core/src/ir/mod.rs index 623308a2..d9b44c89 100644 --- a/zokrates_core/src/ir/mod.rs +++ b/zokrates_core/src/ir/mod.rs @@ -8,7 +8,7 @@ use zokrates_field::Field; mod expression; pub mod folder; -mod from_flat; +pub mod from_flat; mod interpreter; mod serialize; pub mod smtlib2; @@ -78,10 +78,51 @@ impl fmt::Display for Statement { pub struct Prog { pub statements: Vec>, pub arguments: Vec, - pub returns: Vec, + pub return_count: usize, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ProgIterator { + pub statements: I, + pub arguments: Vec, + pub return_count: usize, +} + +impl From> for ProgIterator>> { + fn from(p: Prog) -> ProgIterator>> { + ProgIterator { + statements: p.statements.into_iter(), + arguments: p.arguments, + return_count: p.return_count, + } + } +} + +impl>> From> for Prog { + fn from(p: ProgIterator) -> Prog { + Prog { + statements: p.statements.collect(), + arguments: p.arguments, + return_count: p.return_count, + } + } +} + +impl ProgIterator { + pub fn returns(&self) -> Vec { + (0..self.return_count) + .map(|id| FlatVariable::public(id)) + .collect() + } } impl Prog { + pub fn returns(&self) -> Vec { + (0..self.return_count) + .map(|id| FlatVariable::public(id)) + .collect() + } + pub fn constraint_count(&self) -> usize { self.statements .iter() @@ -113,14 +154,14 @@ impl fmt::Display for Prog { .map(|v| format!("{}", v)) .collect::>() .join(", "), - self.returns.len(), + self.return_count, self.statements .iter() .map(|s| format!("\t{}", s)) .collect::>() .join("\n"), - self.returns - .iter() + (0..self.return_count) + .map(|i| FlatVariable::public(i)) .map(|e| format!("{}", e)) .collect::>() .join(", ") diff --git a/zokrates_core/src/ir/visitor.rs b/zokrates_core/src/ir/visitor.rs index 2a9cc028..236ade45 100644 --- a/zokrates_core/src/ir/visitor.rs +++ b/zokrates_core/src/ir/visitor.rs @@ -49,9 +49,6 @@ pub fn visit_module>(f: &mut F, p: &Prog) { for expr in p.statements.iter() { f.visit_statement(expr); } - for expr in p.returns.iter() { - f.visit_variable(expr); - } } pub fn visit_statement>(f: &mut F, s: &Statement) { diff --git a/zokrates_core/src/optimizer/canonicalizer.rs b/zokrates_core/src/optimizer/canonicalizer.rs index 69981b05..d5c6bae3 100644 --- a/zokrates_core/src/optimizer/canonicalizer.rs +++ b/zokrates_core/src/optimizer/canonicalizer.rs @@ -1,6 +1,7 @@ use crate::ir::{folder::Folder, LinComb}; use zokrates_field::Field; +#[derive(Default)] pub struct Canonicalizer; impl Folder for Canonicalizer { diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index 2b4e8a4b..06b9eb5a 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -17,26 +17,13 @@ use crate::solvers::Solver; use std::collections::hash_map::{Entry, HashMap}; use zokrates_field::Field; -#[derive(Debug)] -pub struct DirectiveOptimizer { +#[derive(Debug, Default)] +pub struct DirectiveOptimizer { calls: HashMap<(Solver, Vec>), Vec>, /// Map of renamings for reassigned variables while processing the program. substitution: HashMap, } -impl DirectiveOptimizer { - fn new() -> DirectiveOptimizer { - DirectiveOptimizer { - calls: HashMap::new(), - substitution: HashMap::new(), - } - } - - pub fn optimize(p: Prog) -> Prog { - DirectiveOptimizer::new().fold_module(p) - } -} - impl Folder for DirectiveOptimizer { fn fold_module(&mut self, p: Prog) -> Prog { // in order to correctly identify duplicates, we need to first canonicalize the statements diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 493fe4e8..7cb95acc 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -16,23 +16,11 @@ fn hash(s: &Statement) -> Hash { hasher.finish() } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct DuplicateOptimizer { seen: HashSet, } -impl DuplicateOptimizer { - fn new() -> Self { - DuplicateOptimizer { - seen: HashSet::new(), - } - } - - pub fn optimize(p: Prog) -> Prog { - Self::new().fold_module(p) - } -} - impl Folder for DuplicateOptimizer { fn fold_module(&mut self, p: Prog) -> Prog { // in order to correctly identify duplicates, we need to first canonicalize the statements diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index 3f2c064a..132c84f4 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -10,41 +10,59 @@ mod duplicate; mod redefinition; mod tautology; +use self::canonicalizer::Canonicalizer; use self::directive::DirectiveOptimizer; use self::duplicate::DuplicateOptimizer; use self::redefinition::RedefinitionOptimizer; use self::tautology::TautologyOptimizer; -use crate::ir::Prog; +use crate::ir::{ProgIterator, Statement}; use zokrates_field::Field; -impl Prog { - pub fn optimize(self) -> Self { +impl>> ProgIterator { + pub fn optimize(self) -> ProgIterator>> { // remove redefinitions - log::debug!("Constraints: {}", self.constraint_count()); - log::debug!("Optimizer: Remove redefinitions"); - let r = RedefinitionOptimizer::optimize(self); - log::debug!("Done"); + log::debug!( + "Optimizer: Remove redefinitions and tautologies and directives and duplicates" + ); - // remove constraints that are always satisfied - log::debug!("Constraints: {}", r.constraint_count()); - log::debug!("Optimizer: Remove tautologies"); - let r = TautologyOptimizer::optimize(r); - log::debug!("Done"); + // define all optimizer steps + let mut redefinition_optimizer = RedefinitionOptimizer::default(); + let mut tautologies_optimizer = TautologyOptimizer::default(); + let mut directive_optimizer = DirectiveOptimizer::default(); + let mut canonicalizer = Canonicalizer::default(); + let mut duplicate_optimizer = DuplicateOptimizer::default(); - // deduplicate directives which take the same input - log::debug!("Constraints: {}", r.constraint_count()); - log::debug!("Optimizer: Remove duplicate directive"); - let r = DirectiveOptimizer::optimize(r); - log::debug!("Done"); + // initialize the ones that need initializing + redefinition_optimizer.ignore.extend(self.returns().clone()); - // remove duplicate constraints - log::debug!("Constraints: {}", r.constraint_count()); - log::debug!("Optimizer: Remove duplicate constraints"); - let r = DuplicateOptimizer::optimize(r); - log::debug!("Done"); + use crate::ir::folder::Folder; - log::debug!("Constraints: {}", r.constraint_count()); + let r = ProgIterator { + arguments: self + .arguments + .into_iter() + .map(|a| redefinition_optimizer.fold_argument(a)) + .map(|a| { + >::fold_argument(&mut tautologies_optimizer, a) + }) + .map(|a| directive_optimizer.fold_argument(a)) + .map(|a| { + >::fold_argument(&mut duplicate_optimizer, a) + }) + .collect(), + statements: self + .statements + .into_iter() + .flat_map(move |s| redefinition_optimizer.fold_statement(s)) + .flat_map(move |s| tautologies_optimizer.fold_statement(s)) + .flat_map(move |s| canonicalizer.fold_statement(s)) + .flat_map(move |s| directive_optimizer.fold_statement(s)) + .flat_map(move |s| duplicate_optimizer.fold_statement(s)), + return_count: self.return_count, + }; + + log::debug!("Done"); r } } diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index 1080b5a5..f99d9fee 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -38,44 +38,30 @@ use crate::flat_absy::flat_variable::FlatVariable; use crate::flat_absy::FlatParameter; -use crate::ir::folder::{fold_module, Folder}; +use crate::ir::folder::Folder; use crate::ir::LinComb; use crate::ir::*; use std::collections::{HashMap, HashSet}; use zokrates_field::Field; #[derive(Debug)] -pub struct RedefinitionOptimizer { +pub struct RedefinitionOptimizer { /// Map of renamings for reassigned variables while processing the program. substitution: HashMap>, /// Set of variables that should not be substituted - ignore: HashSet, + pub ignore: HashSet, } -impl RedefinitionOptimizer { - fn new() -> Self { +impl Default for RedefinitionOptimizer { + fn default() -> Self { RedefinitionOptimizer { substitution: HashMap::new(), - ignore: HashSet::new(), + ignore: vec![FlatVariable::one()].into_iter().collect(), } } - - pub fn optimize(p: Prog) -> Prog { - RedefinitionOptimizer::new().fold_module(p) - } } impl Folder for RedefinitionOptimizer { - fn fold_module(&mut self, p: Prog) -> Prog { - // to prevent the optimiser from replacing outputs, add them to the ignored set - self.ignore.extend(p.returns.iter().cloned()); - - // to prevent the optimiser from replacing ~one, add it to the ignored set - self.ignore.insert(FlatVariable::one()); - - fold_module(self, p) - } - fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter { // to prevent the optimiser from replacing user input, add it to the ignored set self.ignore.insert(a.id); diff --git a/zokrates_core/src/optimizer/tautology.rs b/zokrates_core/src/optimizer/tautology.rs index abee4ef7..e1146d3a 100644 --- a/zokrates_core/src/optimizer/tautology.rs +++ b/zokrates_core/src/optimizer/tautology.rs @@ -10,17 +10,8 @@ use crate::ir::folder::Folder; use crate::ir::*; use zokrates_field::Field; -pub struct TautologyOptimizer {} - -impl TautologyOptimizer { - fn new() -> TautologyOptimizer { - TautologyOptimizer {} - } - - pub fn optimize(p: Prog) -> Prog { - TautologyOptimizer::new().fold_module(p) - } -} +#[derive(Default)] +pub struct TautologyOptimizer; impl Folder for TautologyOptimizer { fn fold_statement(&mut self, s: Statement) -> Vec> { diff --git a/zokrates_core/src/static_analysis/flat_propagation.rs b/zokrates_core/src/static_analysis/flat_propagation.rs index 944bf54f..605a6d43 100644 --- a/zokrates_core/src/static_analysis/flat_propagation.rs +++ b/zokrates_core/src/static_analysis/flat_propagation.rs @@ -101,14 +101,6 @@ impl Propagate for FlatFunction { } } -impl FlatProg { - pub fn propagate(self) -> FlatProg { - let main = self.main.propagate(); - - FlatProg { main } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index 030841a9..c1df081b 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -54,7 +54,8 @@ impl fmt::Debug for FieldParseError { } pub trait Field: - From + 'static + + From + From + From + From diff --git a/zokrates_test/Cargo.toml b/zokrates_test/Cargo.toml index 7444ff2a..ab0987c7 100644 --- a/zokrates_test/Cargo.toml +++ b/zokrates_test/Cargo.toml @@ -12,5 +12,6 @@ zokrates_abi = { version = "0.1", path = "../zokrates_abi" } serde = "1.0" serde_derive = "1.0" serde_json = "1.0" +typed-arena = "1.4.1" [lib] diff --git a/zokrates_test/src/lib.rs b/zokrates_test/src/lib.rs index f3605f35..6c6fec3d 100644 --- a/zokrates_test/src/lib.rs +++ b/zokrates_test/src/lib.rs @@ -149,13 +149,17 @@ fn compile_and_run(t: Tests) { let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap()); - let artifacts = compile::(code, entry_point.clone(), Some(&resolver), &config).unwrap(); + let arena = typed_arena::Arena::new(); - let bin = artifacts.prog(); - let abi = artifacts.abi(); + let artifacts = + compile::(code, entry_point.clone(), Some(&resolver), config, &arena).unwrap(); + + let abi = artifacts.abi; + let bin = artifacts.prog; if let Some(target_count) = t.max_constraint_count { - let count = bin.constraint_count(); + unimplemented!("bin.constraint_count()"); + let count = 42; assert!( count <= target_count, @@ -169,6 +173,8 @@ fn compile_and_run(t: Tests) { let interpreter = zokrates_core::ir::Interpreter::default(); + let bin = &bin.into(); + for test in t.tests.into_iter() { let with_abi = test.abi.unwrap_or(false);