1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

proof of concept of iterator treatment starting at flattening

This commit is contained in:
schaeff 2021-10-25 16:08:45 +02:00
parent d924368038
commit dad17b79e0
21 changed files with 397 additions and 351 deletions

21
Cargo.lock generated
View file

@ -1074,6 +1074,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "half"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.11.2" version = "0.11.2"
@ -1845,6 +1851,16 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.130" version = "1.0.130"
@ -2383,7 +2399,6 @@ name = "zokrates_cli"
version = "0.7.7" version = "0.7.7"
dependencies = [ dependencies = [
"assert_cli", "assert_cli",
"bincode",
"cfg-if 0.1.10", "cfg-if 0.1.10",
"clap", "clap",
"dirs", "dirs",
@ -2393,8 +2408,11 @@ dependencies = [
"lazy_static", "lazy_static",
"log", "log",
"regex 0.2.11", "regex 0.2.11",
"serde",
"serde_cbor",
"serde_json", "serde_json",
"tempdir", "tempdir",
"typed-arena",
"zokrates_abi", "zokrates_abi",
"zokrates_core", "zokrates_core",
"zokrates_field", "zokrates_field",
@ -2545,6 +2563,7 @@ dependencies = [
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json", "serde_json",
"typed-arena",
"zokrates_abi", "zokrates_abi",
"zokrates_core", "zokrates_core",
"zokrates_field", "zokrates_field",

View file

@ -16,13 +16,15 @@ log = "0.4"
env_logger = "0.9.0" env_logger = "0.9.0"
cfg-if = "0.1" cfg-if = "0.1"
clap = "2.26.2" clap = "2.26.2"
bincode = "0.8.0" serde_cbor = "0.11.2"
regex = "0.2" regex = "0.2"
zokrates_field = { version = "0.4", path = "../zokrates_field", default-features = false } zokrates_field = { version = "0.4", path = "../zokrates_field", default-features = false }
zokrates_abi = { version = "0.1", path = "../zokrates_abi" } zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
zokrates_core = { version = "0.6", path = "../zokrates_core", default-features = false } zokrates_core = { version = "0.6", path = "../zokrates_core", default-features = false }
typed-arena = "1.4.1"
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"} zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
serde_json = "1.0" serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
dirs = "3.0.1" dirs = "3.0.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"

View file

@ -6,10 +6,35 @@ use std::convert::TryFrom;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write}; use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use typed_arena::Arena;
use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig, CompileError}; use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig, CompileError};
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
use zokrates_fs_resolver::FileSystemResolver; use zokrates_fs_resolver::FileSystemResolver;
use serde::{Serialize, Serializer};
use std::cell::Cell;
pub fn write_as_cbor<I, P, W>(out: &mut W, groups: I) -> serde_cbor::Result<()>
where
I: IntoIterator<Item = P>,
P: Serialize,
W: Write,
{
struct Wrapper<T>(Cell<Option<T>>);
impl<I, P> Serialize for Wrapper<I>
where
I: IntoIterator<Item = P>,
P: Serialize,
{
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
s.collect_seq(self.0.take().unwrap())
}
}
serde_cbor::to_writer(out, &Wrapper(Cell::new(Some(groups))))
}
pub fn subcommand() -> App<'static, 'static> { pub fn subcommand() -> App<'static, 'static> {
SubCommand::with_name("compile") SubCommand::with_name("compile")
.about("Compiles into flattened conditions. Produces two files: human-readable '.ztf' file for debugging and binary file") .about("Compiles into flattened conditions. Produces two files: human-readable '.ztf' file for debugging and binary file")
@ -136,8 +161,10 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
log::debug!("Compile"); log::debug!("Compile");
let artifacts: CompilationArtifacts<T> = compile(source, path, Some(&resolver), &config) let arena = Arena::new();
.map_err(|e| {
let artifacts =
compile::<T, _>(source, path, Some(&resolver), config, &arena).map_err(|e| {
format!( format!(
"Compilation failed:\n\n{}", "Compilation failed:\n\n{}",
e.0.iter() e.0.iter()
@ -147,10 +174,11 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
) )
})?; })?;
let program_flattened = artifacts.prog(); let program_flattened = artifacts.prog;
let abi = artifacts.abi;
// number of constraints the flattened program will translate to. // number of constraints the flattened program will translate to.
let num_constraints = program_flattened.constraint_count(); //let num_constraints = program_flattened.constraint_count();
// serialize flattened program and write to binary file // serialize flattened program and write to binary file
log::debug!("Serialize program"); log::debug!("Serialize program");
@ -159,21 +187,22 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
let mut writer = BufWriter::new(bin_output_file); let mut writer = BufWriter::new(bin_output_file);
program_flattened.serialize(&mut writer); println!("START WRITING TO DISK");
//program_flattened.serialize(&mut writer);
write_as_cbor(&mut writer, program_flattened.statements).unwrap();
// serialize ABI spec and write to JSON file // serialize ABI spec and write to JSON file
log::debug!("Serialize ABI"); log::debug!("Serialize ABI");
let abi_spec_file = File::create(&abi_spec_path) let abi_spec_file = File::create(&abi_spec_path)
.map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?; .map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?;
let abi = artifacts.abi();
let mut writer = BufWriter::new(abi_spec_file); let mut writer = BufWriter::new(abi_spec_file);
to_writer_pretty(&mut writer, &abi).map_err(|_| "Unable to write data to file.".to_string())?; to_writer_pretty(&mut writer, &abi).map_err(|_| "Unable to write data to file.".to_string())?;
if sub_matches.is_present("verbose") { if sub_matches.is_present("verbose") {
// debugging output // debugging output
println!("Compiled program:\n{}", program_flattened); //println!("Compiled program:\n{}", program_flattened);
} }
println!("Compiled code written to '{}'", bin_output_path.display()); println!("Compiled code written to '{}'", bin_output_path.display());
@ -185,15 +214,15 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
.map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?; .map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?;
let mut hrofb = BufWriter::new(hr_output_file); let mut hrofb = BufWriter::new(hr_output_file);
writeln!(&mut hrofb, "{}", program_flattened) // writeln!(&mut hrofb, "{}", program_flattened)
.map_err(|_| "Unable to write data to file".to_string())?; // .map_err(|_| "Unable to write data to file".to_string())?;
hrofb // hrofb
.flush() // .flush()
.map_err(|_| "Unable to flush buffer".to_string())?; // .map_err(|_| "Unable to flush buffer".to_string())?;
println!("Human readable code to '{}'", hr_output_path.display()); println!("Human readable code to '{}'", hr_output_path.display());
} }
println!("Number of constraints: {}", num_constraints); //println!("Number of constraints: {}", num_constraints);
Ok(()) Ok(())
} }

View file

@ -106,7 +106,7 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
} }
false => ConcreteSignature::new() false => ConcreteSignature::new()
.inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()]) .inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()])
.outputs(vec![ConcreteType::FieldElement; ir_prog.returns.len()]), .outputs(vec![ConcreteType::FieldElement; ir_prog.return_count]),
}; };
use zokrates_abi::Inputs; use zokrates_abi::Inputs;

View file

@ -4,7 +4,7 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr> //! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018 //! @date 2018
use crate::absy::{Module, OwnedModuleId, Program}; use crate::absy::{Module, OwnedModuleId, Program};
use crate::flatten::Flattener; use crate::flatten::FlattenerIterator;
use crate::imports::{self, Importer}; use crate::imports::{self, Importer};
use crate::ir; use crate::ir;
use crate::macros; use crate::macros;
@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::io; use std::io;
use std::io::Write;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use typed_arena::Arena; use typed_arena::Arena;
use zokrates_common::Resolver; use zokrates_common::Resolver;
@ -25,14 +26,14 @@ use zokrates_field::Field;
use zokrates_pest_ast as pest; use zokrates_pest_ast as pest;
#[derive(Debug)] #[derive(Debug)]
pub struct CompilationArtifacts<T: Field> { pub struct CompilationArtifacts<I> {
prog: ir::Prog<T>, pub prog: ir::ProgIterator<I>,
abi: Abi, pub abi: Abi,
} }
impl<T: Field> CompilationArtifacts<T> { impl<I> CompilationArtifacts<I> {
pub fn prog(&self) -> &ir::Prog<T> { pub fn prog(self) -> ir::ProgIterator<I> {
&self.prog self.prog
} }
pub fn abi(&self) -> &Abi { pub fn abi(&self) -> &Abi {
@ -183,37 +184,46 @@ impl CompileConfig {
type FilePath = PathBuf; type FilePath = PathBuf;
pub fn compile<T: Field, E: Into<imports::Error>>( pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
source: String, source: String,
location: FilePath, location: FilePath,
resolver: Option<&dyn Resolver<E>>, resolver: Option<&dyn Resolver<E>>,
config: &CompileConfig, config: CompileConfig,
) -> Result<CompilationArtifacts<T>, CompileErrors> { arena: &'ast Arena<String>,
let arena = Arena::new(); ) -> Result<CompilationArtifacts<impl Iterator<Item = ir::Statement<T>> + 'ast>, CompileErrors> {
let (typed_ast, abi): (crate::zir::ZirProgram<'_, T>, _) =
let (typed_ast, abi) = check_with_arena(source, location.clone(), resolver, config, &arena)?; check_with_arena(source, location.clone(), resolver, &config, arena)?;
// flatten input program // flatten input program
log::debug!("Flatten"); log::debug!("Flatten");
let program_flattened = Flattener::flatten(typed_ast, config); let program_flattened = FlattenerIterator::from_function_and_config(typed_ast.main, config);
// constant propagation after call resolution // // constant propagation after call resolution
log::debug!("Propagate flat program"); // log::debug!("Propagate flat program");
let program_flattened = program_flattened.propagate(); // let program_flattened = program_flattened.propagate();
// convert to ir // convert to ir
log::debug!("Convert to IR"); log::debug!("Convert to IR");
let ir_prog = ir::Prog::from(program_flattened); let ir_prog = ir::ProgIterator {
arguments: program_flattened
.arguments_flattened
.clone()
.into_iter()
.map(|a| a.into())
.collect(),
return_count: 0,
statements: ir::from_flat::from_flat(program_flattened),
};
// optimize // optimize
log::debug!("Optimise IR"); log::debug!("Optimise IR");
let optimized_ir_prog = ir_prog.optimize(); let optimized_ir_prog = ir_prog.optimize();
// analyse ir (check constraints) // analyse ir (check constraints)
log::debug!("Analyse IR"); // log::debug!("Analyse IR");
let optimized_ir_prog = optimized_ir_prog // let optimized_ir_prog = optimized_ir_prog
.analyse() // .analyse()
.map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?; // .map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?;
Ok(CompilationArtifacts { Ok(CompilationArtifacts {
prog: optimized_ir_prog, prog: optimized_ir_prog,

View file

@ -96,32 +96,24 @@ impl fmt::Display for RuntimeError {
} }
} }
#[derive(Clone, PartialEq)] pub type FlatProg<T> = FlatFunction<T>;
pub struct FlatProg<T: Field> {
/// FlatFunctions of the program
pub main: FlatFunction<T>,
}
impl<T: Field> fmt::Display for FlatProg<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.main)
}
}
impl<T: Field> fmt::Debug for FlatProg<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "flat_program(main: {}\t)", self.main)
}
}
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct FlatFunction<T: Field> { pub struct FlatFunction<T> {
/// Arguments of the function /// Arguments of the function
pub arguments: Vec<FlatParameter>, pub arguments: Vec<FlatParameter>,
/// Vector of statements that are executed when running the function /// Vector of statements that are executed when running the function
pub statements: Vec<FlatStatement<T>>, pub statements: Vec<FlatStatement<T>>,
} }
pub type FlatProgIterator<T> = FlatFunctionIterator<T>;
pub struct FlatFunctionIterator<I> {
/// Arguments of the function
pub arguments: Vec<FlatParameter>,
/// Vector of statements that are executed when running the function
pub statements: I,
}
impl<T: Field> fmt::Display for FlatFunction<T> { impl<T: Field> fmt::Display for FlatFunction<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
@ -167,7 +159,7 @@ impl<T: Field> fmt::Debug for FlatFunction<T> {
/// * r1cs - R1CS in standard JSON data format /// * r1cs - R1CS in standard JSON data format
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub enum FlatStatement<T: Field> { pub enum FlatStatement<T> {
Return(FlatExpressionList<T>), Return(FlatExpressionList<T>),
Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError), Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError),
Definition(FlatVariable, FlatExpression<T>), Definition(FlatVariable, FlatExpression<T>),
@ -239,7 +231,7 @@ impl<T: Field> FlatStatement<T> {
} }
#[derive(Clone, Hash, Debug, PartialEq, Eq)] #[derive(Clone, Hash, Debug, PartialEq, Eq)]
pub struct FlatDirective<T: Field> { pub struct FlatDirective<T> {
pub inputs: Vec<FlatExpression<T>>, pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<FlatVariable>, pub outputs: Vec<FlatVariable>,
pub solver: Solver, pub solver: Solver,

View file

@ -16,17 +16,67 @@ use crate::flat_absy::{RuntimeError, *};
use crate::solvers::Solver; use crate::solvers::Solver;
use crate::zir::types::{Type, UBitwidth}; use crate::zir::types::{Type, UBitwidth};
use crate::zir::*; use crate::zir::*;
use std::collections::hash_map::Entry; use std::collections::{
use std::collections::HashMap; hash_map::{Entry, HashMap},
VecDeque,
};
use std::convert::TryFrom; use std::convert::TryFrom;
use zokrates_field::Field; use zokrates_field::Field;
type FlatStatements<T> = Vec<FlatStatement<T>>; type FlatStatements<T> = VecDeque<FlatStatement<T>>;
/// Flattens a function
///
/// # Arguments
/// * `funct` - `ZirFunction` that will be flattened
impl<'ast, T: Field> FlattenerIterator<'ast, T> {
pub fn from_function_and_config(funct: ZirFunction<'ast, T>, config: CompileConfig) -> Self {
let mut flattener = Flattener::new(config);
let mut statements_flattened = FlatStatements::new();
// push parameters
let arguments_flattened = funct
.arguments
.into_iter()
.map(|p| flattener.use_parameter(&p, &mut statements_flattened))
.collect();
FlattenerIterator {
statements: funct.statements.into(),
arguments_flattened,
statements_flattened,
flattener,
}
}
}
pub struct FlattenerIterator<'ast, T> {
pub statements: VecDeque<ZirStatement<'ast, T>>,
pub arguments_flattened: Vec<FlatParameter>,
pub statements_flattened: FlatStatements<T>,
pub flattener: Flattener<'ast, T>,
}
impl<'ast, T: Field> Iterator for FlattenerIterator<'ast, T> {
type Item = FlatStatement<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.statements_flattened.is_empty() {
match self.statements.pop_front() {
Some(s) => {
self.flattener
.flatten_statement(&mut self.statements_flattened, s);
}
None => {}
}
}
self.statements_flattened.pop_front()
}
}
/// Flattener, computes flattened program. /// Flattener, computes flattened program.
#[derive(Debug)] #[derive(Debug)]
pub struct Flattener<'ast, T: Field> { pub struct Flattener<'ast, T> {
config: &'ast CompileConfig, config: CompileConfig,
/// Index of the next introduced variable while processing the program. /// Index of the next introduced variable while processing the program.
next_var_idx: usize, next_var_idx: usize,
/// `FlatVariable`s corresponding to each `Identifier` /// `FlatVariable`s corresponding to each `Identifier`
@ -156,13 +206,8 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
} }
impl<'ast, T: Field> Flattener<'ast, T> { impl<'ast, T: Field> Flattener<'ast, T> {
pub fn flatten(p: ZirProgram<'ast, T>, config: &CompileConfig) -> FlatProg<T> {
Flattener::new(config).flatten_program(p)
}
/// Returns a `Flattener` with fresh `layout`. /// Returns a `Flattener` with fresh `layout`.
fn new(config: CompileConfig) -> Flattener<'ast, T> {
fn new(config: &'ast CompileConfig) -> Flattener<'ast, T> {
Flattener { Flattener {
config, config,
next_var_idx: 0, next_var_idx: 0,
@ -182,7 +227,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Identifier(id) => id, FlatExpression::Identifier(id) => id,
e => { e => {
let res = self.use_sym(); let res = self.use_sym();
statements_flattened.push(FlatStatement::Definition(res, e)); statements_flattened.push_back(FlatStatement::Definition(res, e));
res res
} }
} }
@ -241,7 +286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} }
// init size_unknown = true // init size_unknown = true
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
size_unknown[0], size_unknown[0],
FlatExpression::Number(T::from(1)), FlatExpression::Number(T::from(1)),
)); ));
@ -250,14 +295,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
for (i, b) in b.iter().enumerate() { for (i, b) in b.iter().enumerate() {
if *b { if *b {
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
is_not_smaller_run[i], is_not_smaller_run[i],
a[i].into(), a[i].into(),
)); ));
// don't need to update size_unknown in the last round // don't need to update size_unknown in the last round
if i < len - 1 { if i < len - 1 {
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
size_unknown[i + 1], size_unknown[i + 1],
FlatExpression::Mult( FlatExpression::Mult(
box size_unknown[i].into(), box size_unknown[i].into(),
@ -268,7 +313,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} else { } else {
// don't need to update size_unknown in the last round // don't need to update size_unknown in the last round
if i < len - 1 { if i < len - 1 {
statements_flattened.push( statements_flattened.push_back(
// sizeUnknown is not changing in this case // sizeUnknown is not changing in this case
// We sill have to assign the old value to the variable of the current run // We sill have to assign the old value to the variable of the current run
// This trivial definition will later be removed by the optimiser // This trivial definition will later be removed by the optimiser
@ -285,7 +330,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let and_name = self.use_sym(); let and_name = self.use_sym();
let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone()); let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone());
statements_flattened.push(FlatStatement::Definition(and_name, and)); statements_flattened.push_back(FlatStatement::Definition(and_name, and));
let or = FlatExpression::Sub( let or = FlatExpression::Sub(
box FlatExpression::Add(box or_left, box or_right), box FlatExpression::Add(box or_left, box or_right),
box and_name.into(), box and_name.into(),
@ -310,7 +355,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
fn eq_check( fn eq_check(
&mut self, &mut self,
statements_flattened: &mut Vec<FlatStatement<T>>, statements_flattened: &mut FlatStatements<T>,
left: FlatExpression<T>, left: FlatExpression<T>,
right: FlatExpression<T>, right: FlatExpression<T>,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
@ -329,12 +374,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let name_y = self.use_sym(); let name_y = self.use_sym();
let name_m = self.use_sym(); let name_m = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m], vec![name_y, name_m],
Solver::ConditionEq, Solver::ConditionEq,
vec![x.clone()], vec![x.clone()],
))); )));
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(name_y), FlatExpression::Identifier(name_y),
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
RuntimeError::Equal, RuntimeError::Equal,
@ -345,7 +390,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box FlatExpression::Identifier(name_y), box FlatExpression::Identifier(name_y),
); );
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::zero()), FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box x), FlatExpression::Mult(box res.clone(), box x),
RuntimeError::Equal, RuntimeError::Equal,
@ -376,7 +421,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.fold(FlatExpression::from(T::zero()), |acc, e| { .fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e) FlatExpression::Add(box acc, box e)
}); });
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::from(0)), FlatExpression::Number(T::from(0)),
FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()), FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()),
RuntimeError::Le, RuntimeError::Le,
@ -385,14 +430,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn make_conditional( fn make_conditional(
&mut self, &mut self,
statements: Vec<FlatStatement<T>>, statements: FlatStatements<T>,
condition: FlatExpression<T>, condition: FlatExpression<T>,
) -> Vec<FlatStatement<T>> { ) -> FlatStatements<T> {
statements statements
.into_iter() .into_iter()
.flat_map(|s| match s { .flat_map(|s| match s {
FlatStatement::Condition(left, right, message) => { FlatStatement::Condition(left, right, message) => {
let mut output = vec![]; let mut output = VecDeque::new();
// we transform (a == b) into (c => (a == b)) which is (!c || (a == b)) // we transform (a == b) into (c => (a == b)) which is (!c || (a == b))
@ -410,12 +455,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert!(x.is_linear() && y.is_linear()); assert!(x.is_linear() && y.is_linear());
let name_x_or_y = self.use_sym(); let name_x_or_y = self.use_sym();
output.push(FlatStatement::Directive(FlatDirective { output.push_back(FlatStatement::Directive(FlatDirective {
solver: Solver::Or, solver: Solver::Or,
outputs: vec![name_x_or_y], outputs: vec![name_x_or_y],
inputs: vec![x.clone(), y.clone()], inputs: vec![x.clone(), y.clone()],
})); }));
output.push(FlatStatement::Condition( output.push_back(FlatStatement::Condition(
FlatExpression::Add( FlatExpression::Add(
box x.clone(), box x.clone(),
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
@ -423,7 +468,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Mult(box x.clone(), box y.clone()), FlatExpression::Mult(box x.clone(), box y.clone()),
RuntimeError::BranchIsolation, RuntimeError::BranchIsolation,
)); ));
output.push(FlatStatement::Condition( output.push_back(FlatStatement::Condition(
name_x_or_y.into(), name_x_or_y.into(),
T::one().into(), T::one().into(),
message, message,
@ -431,7 +476,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
output output
} }
s => vec![s], s => VecDeque::from([s]),
}) })
.collect() .collect()
} }
@ -457,16 +502,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
self.flatten_boolean_expression(statements_flattened, condition.clone()); self.flatten_boolean_expression(statements_flattened, condition.clone());
let condition_id = self.use_sym(); let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); statements_flattened.push_back(FlatStatement::Definition(condition_id, condition_flat));
self.condition_cache.insert(condition, condition_id); self.condition_cache.insert(condition, condition_id);
let (consequence, alternative) = if self.config.isolate_branches { let (consequence, alternative) = if self.config.isolate_branches {
let mut consequence_statements = vec![]; let mut consequence_statements = VecDeque::new();
let consequence = consequence.flatten(self, &mut consequence_statements); let consequence = consequence.flatten(self, &mut consequence_statements);
let mut alternative_statements = vec![]; let mut alternative_statements = VecDeque::new();
let alternative = alternative.flatten(self, &mut alternative_statements); let alternative = alternative.flatten(self, &mut alternative_statements);
@ -495,13 +540,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let alternative = alternative.flat(); let alternative = alternative.flat();
let consequence_id = self.use_sym(); let consequence_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence)); statements_flattened.push_back(FlatStatement::Definition(consequence_id, consequence));
let alternative_id = self.use_sym(); let alternative_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative)); statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative));
let term0_id = self.use_sym(); let term0_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
term0_id, term0_id,
FlatExpression::Mult( FlatExpression::Mult(
box condition_id.into(), box condition_id.into(),
@ -510,7 +555,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)); ));
let term1_id = self.use_sym(); let term1_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
term1_id, term1_id,
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Sub( box FlatExpression::Sub(
@ -522,7 +567,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
)); ));
let res = self.use_sym(); let res = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
res, res,
FlatExpression::Add( FlatExpression::Add(
box FlatExpression::from(term0_id), box FlatExpression::from(term0_id),
@ -582,7 +627,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect(); let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits // add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
e_bits_be.clone(), e_bits_be.clone(),
Solver::bits(bit_width), Solver::bits(bit_width),
vec![e_id], vec![e_id],
@ -590,7 +635,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// bitness checks // bitness checks
for bit in e_bits_be.iter().take(bit_width) { for bit in e_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit), FlatExpression::Identifier(*bit),
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Identifier(*bit), box FlatExpression::Identifier(*bit),
@ -613,7 +658,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
} }
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(e_id), FlatExpression::Identifier(e_id),
e_sum, e_sum,
RuntimeError::ConstantLtSum, RuntimeError::ConstantLtSum,
@ -704,7 +749,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..safe_width).map(|_| self.use_sym()).collect(); (0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits // add a directive to get the bits
statements_flattened.push(FlatStatement::Directive( statements_flattened.push_back(FlatStatement::Directive(
FlatDirective::new( FlatDirective::new(
lhs_bits_be.clone(), lhs_bits_be.clone(),
Solver::bits(safe_width), Solver::bits(safe_width),
@ -714,7 +759,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// bitness checks // bitness checks
for bit in lhs_bits_be.iter().take(safe_width) { for bit in lhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit), FlatExpression::Identifier(*bit),
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Identifier(*bit), box FlatExpression::Identifier(*bit),
@ -739,7 +784,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
} }
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(lhs_id), FlatExpression::Identifier(lhs_id),
lhs_sum, lhs_sum,
RuntimeError::LtSum, RuntimeError::LtSum,
@ -756,7 +801,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..safe_width).map(|_| self.use_sym()).collect(); (0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits // add a directive to get the bits
statements_flattened.push(FlatStatement::Directive( statements_flattened.push_back(FlatStatement::Directive(
FlatDirective::new( FlatDirective::new(
rhs_bits_be.clone(), rhs_bits_be.clone(),
Solver::bits(safe_width), Solver::bits(safe_width),
@ -766,7 +811,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// bitness checks // bitness checks
for bit in rhs_bits_be.iter().take(safe_width) { for bit in rhs_bits_be.iter().take(safe_width) {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit), FlatExpression::Identifier(*bit),
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Identifier(*bit), box FlatExpression::Identifier(*bit),
@ -791,7 +836,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
} }
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(rhs_id), FlatExpression::Identifier(rhs_id),
rhs_sum, rhs_sum,
RuntimeError::LtSum, RuntimeError::LtSum,
@ -815,15 +860,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..bit_width).map(|_| self.use_sym()).collect(); (0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits // add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(
FlatDirective::new(
sub_bits_be.clone(), sub_bits_be.clone(),
Solver::bits(bit_width), Solver::bits(bit_width),
vec![subtraction_result.clone()], vec![subtraction_result.clone()],
))); ),
));
// bitness checks // bitness checks
for bit in sub_bits_be.iter().take(bit_width) { for bit in sub_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit), FlatExpression::Identifier(*bit),
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Identifier(*bit), box FlatExpression::Identifier(*bit),
@ -853,7 +900,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
} }
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
subtraction_result, subtraction_result,
expr, expr,
RuntimeError::LtFinalSum, RuntimeError::LtFinalSum,
@ -885,7 +932,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let x_sub_y = FlatExpression::Sub(box x, box y); let x_sub_y = FlatExpression::Sub(box x, box y);
let name_x_mult_x = self.use_sym(); let name_x_mult_x = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
name_x_mult_x, name_x_mult_x,
FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y), FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y),
)); ));
@ -972,7 +1019,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..bit_width).map(|_| self.use_sym()).collect(); (0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits // add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(), sub_bits_be.clone(),
Solver::bits(bit_width), Solver::bits(bit_width),
vec![subtraction_result.clone()], vec![subtraction_result.clone()],
@ -980,7 +1027,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// bitness checks // bitness checks
for bit in sub_bits_be.iter().take(bit_width) { for bit in sub_bits_be.iter().take(bit_width) {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit), FlatExpression::Identifier(*bit),
FlatExpression::Mult( FlatExpression::Mult(
box FlatExpression::Identifier(*bit), box FlatExpression::Identifier(*bit),
@ -1010,7 +1057,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
} }
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
subtraction_result, subtraction_result,
expr, expr,
RuntimeError::LtFinalSum, RuntimeError::LtFinalSum,
@ -1042,12 +1089,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let y = self.flatten_boolean_expression(statements_flattened, rhs); let y = self.flatten_boolean_expression(statements_flattened, rhs);
assert!(x.is_linear() && y.is_linear()); assert!(x.is_linear() && y.is_linear());
let name_x_or_y = self.use_sym(); let name_x_or_y = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective { statements_flattened.push_back(FlatStatement::Directive(FlatDirective {
solver: Solver::Or, solver: Solver::Or,
outputs: vec![name_x_or_y], outputs: vec![name_x_or_y],
inputs: vec![x.clone(), y.clone()], inputs: vec![x.clone(), y.clone()],
})); }));
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Add( FlatExpression::Add(
box x.clone(), box x.clone(),
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
@ -1063,7 +1110,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let name_x_and_y = self.use_sym(); let name_x_and_y = self.use_sym();
assert!(x.is_linear() && y.is_linear()); assert!(x.is_linear() && y.is_linear());
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
name_x_and_y, name_x_and_y,
FlatExpression::Mult(box x, box y), FlatExpression::Mult(box x, box y),
)); ));
@ -1373,14 +1420,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let d = if right_flattened.is_linear() { let d = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
@ -1388,14 +1435,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let invd = self.use_sym(); let invd = self.use_sym();
// # invd = 1/d // # invd = 1/d
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![invd], vec![invd],
Solver::Div, Solver::Div,
vec![FlatExpression::Number(T::one()), d.clone()], vec![FlatExpression::Number(T::one()), d.clone()],
))); )));
// assert(invd * d == 1) // assert(invd * d == 1)
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::one()), FlatExpression::Number(T::one()),
FlatExpression::Mult(box invd.into(), box d.clone()), FlatExpression::Mult(box invd.into(), box d.clone()),
RuntimeError::Inverse, RuntimeError::Inverse,
@ -1405,7 +1452,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let q = self.use_sym(); let q = self.use_sym();
let r = self.use_sym(); let r = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective { statements_flattened.push_back(FlatStatement::Directive(FlatDirective {
inputs: vec![n.clone(), d.clone()], inputs: vec![n.clone(), d.clone()],
outputs: vec![q, r], outputs: vec![q, r],
solver: Solver::EuclideanDiv, solver: Solver::EuclideanDiv,
@ -1439,7 +1486,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
// q*d == n - r // q*d == n - r
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Sub(box n, box r.into()), FlatExpression::Sub(box n, box r.into()),
FlatExpression::Mult(box q.into(), box d), FlatExpression::Mult(box q.into(), box d),
RuntimeError::Euclidean, RuntimeError::Euclidean,
@ -1520,14 +1567,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let new_right = if right_flattened.is_linear() { let new_right = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
@ -1550,14 +1597,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let new_right = if right_flattened.is_linear() { let new_right = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
@ -1612,20 +1659,20 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let new_right = if right_flattened.is_linear() { let new_right = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let res = self.use_sym(); let res = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
res, res,
FlatExpression::Mult(box new_left, box new_right), FlatExpression::Mult(box new_left, box new_right),
)); ));
@ -1974,7 +2021,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} }
Entry::Vacant(_) => { Entry::Vacant(_) => {
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>(); let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
bits.clone(), bits.clone(),
Solver::Bits(from), Solver::Bits(from),
vec![e.field.clone().unwrap()], vec![e.field.clone().unwrap()],
@ -1996,7 +2043,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let sum = flat_expression_from_bits(bits.clone()); let sum = flat_expression_from_bits(bits.clone());
// sum check // sum check
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
e.field.clone().unwrap(), e.field.clone().unwrap(),
sum.clone(), sum.clone(),
RuntimeError::Sum, RuntimeError::Sum,
@ -2055,7 +2102,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
range_check = FlatExpression::Add(box range_check, box condition.clone()); range_check = FlatExpression::Add(box range_check, box condition.clone());
let conditional_element_id = self.use_sym(); let conditional_element_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
conditional_element_id, conditional_element_id,
FlatExpression::Mult(box condition, box element.flat()), FlatExpression::Mult(box condition, box element.flat()),
)); ));
@ -2065,7 +2112,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}, },
); );
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
range_check, range_check,
FlatExpression::Number(T::one()), FlatExpression::Number(T::one()),
RuntimeError::SelectRangeCheck, RuntimeError::SelectRangeCheck,
@ -2099,14 +2146,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let new_right = if right_flattened.is_linear() { let new_right = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
FlatExpression::Add(box new_left, box new_right) FlatExpression::Add(box new_left, box new_right)
@ -2119,14 +2166,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let new_right = if right_flattened.is_linear() { let new_right = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
@ -2139,14 +2186,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
left_flattened left_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
let new_right = if right_flattened.is_linear() { let new_right = if right_flattened.is_linear() {
right_flattened right_flattened
} else { } else {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
FlatExpression::Identifier(id) FlatExpression::Identifier(id)
}; };
FlatExpression::Mult(box new_left, box new_right) FlatExpression::Mult(box new_left, box new_right)
@ -2156,12 +2203,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let right_flattened = self.flatten_field_expression(statements_flattened, right); let right_flattened = self.flatten_field_expression(statements_flattened, right);
let new_left: FlatExpression<T> = { let new_left: FlatExpression<T> = {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, left_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
id.into() id.into()
}; };
let new_right: FlatExpression<T> = { let new_right: FlatExpression<T> = {
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(id, right_flattened)); statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
id.into() id.into()
}; };
@ -2169,28 +2216,28 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let inverse = self.use_sym(); let inverse = self.use_sym();
// # invb = 1/b // # invb = 1/b
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![invb], vec![invb],
Solver::Div, Solver::Div,
vec![FlatExpression::Number(T::one()), new_right.clone()], vec![FlatExpression::Number(T::one()), new_right.clone()],
))); )));
// assert(invb * b == 1) // assert(invb * b == 1)
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::one()), FlatExpression::Number(T::one()),
FlatExpression::Mult(box invb.into(), box new_right.clone()), FlatExpression::Mult(box invb.into(), box new_right.clone()),
RuntimeError::Inverse, RuntimeError::Inverse,
)); ));
// # c = a/b // # c = a/b
statements_flattened.push(FlatStatement::Directive(FlatDirective::new( statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![inverse], vec![inverse],
Solver::Div, Solver::Div,
vec![new_left.clone(), new_right.clone()], vec![new_left.clone(), new_right.clone()],
))); )));
// assert(c * b == a) // assert(c * b == a)
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
new_left, new_left,
FlatExpression::Mult(box new_right, box inverse.into()), FlatExpression::Mult(box new_right, box inverse.into()),
RuntimeError::Division, RuntimeError::Division,
@ -2239,7 +2286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// introduce a new variable // introduce a new variable
let id = self.use_sym(); let id = self.use_sym();
// set it to the square of the previous one, stored in state // set it to the square of the previous one, stored in state
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
id, id,
FlatExpression::Mult( FlatExpression::Mult(
box previous.clone(), box previous.clone(),
@ -2262,7 +2309,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
true => { true => {
// update the result by introducing a new variable // update the result by introducing a new variable
let id = self.use_sym(); let id = self.use_sym();
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
id, id,
FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power
)); ));
@ -2305,7 +2352,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.map(|x| x.get_field_unchecked()) .map(|x| x.get_field_unchecked())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
statements_flattened.push(FlatStatement::Return(FlatExpressionList { statements_flattened.push_back(FlatStatement::Return(FlatExpressionList {
expressions: flat_expressions, expressions: flat_expressions,
})); }));
} }
@ -2314,13 +2361,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
self.flatten_boolean_expression(statements_flattened, condition.clone()); self.flatten_boolean_expression(statements_flattened, condition.clone());
let condition_id = self.use_sym(); let condition_id = self.use_sym();
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat)); statements_flattened
.push_back(FlatStatement::Definition(condition_id, condition_flat));
self.condition_cache.insert(condition, condition_id); self.condition_cache.insert(condition, condition_id);
if self.config.isolate_branches { if self.config.isolate_branches {
let mut consequence_statements = vec![]; let mut consequence_statements = VecDeque::new();
let mut alternative_statements = vec![]; let mut alternative_statements = VecDeque::new();
consequence consequence
.into_iter() .into_iter()
@ -2367,7 +2415,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let var = self.use_variable(&assignee); let var = self.use_variable(&assignee);
// handle return of function call // handle return of function call
statements_flattened.push(FlatStatement::Definition(var, e)); statements_flattened.push_back(FlatStatement::Definition(var, e));
var var
} }
@ -2431,14 +2479,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let e = self.flatten_boolean_expression(statements_flattened, e); let e = self.flatten_boolean_expression(statements_flattened, e);
if e.is_linear() { if e.is_linear() {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
e, e,
FlatExpression::Number(T::from(1)), FlatExpression::Number(T::from(1)),
error.into(), error.into(),
)); ));
} else { } else {
// swap so that left side is linear // swap so that left side is linear
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::from(1)), FlatExpression::Number(T::from(1)),
e, e,
error.into(), error.into(),
@ -2474,7 +2522,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} }
e => { e => {
let id = self.use_variable(&v); let id = self.use_variable(&v);
statements_flattened.push(FlatStatement::Definition(id, e)); statements_flattened
.push_back(FlatStatement::Definition(id, e));
id id
} }
}) })
@ -2502,45 +2551,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} }
} }
/// Flattens a function
///
/// # Arguments
/// * `funct` - `ZirFunction` that will be flattened
fn flatten_function(&mut self, funct: ZirFunction<'ast, T>) -> FlatFunction<T> {
self.layout = HashMap::new();
self.next_var_idx = 0;
let mut statements_flattened: FlatStatements<T> = FlatStatements::new();
// push parameters
let arguments_flattened = funct
.arguments
.into_iter()
.map(|p| self.use_parameter(&p, &mut statements_flattened))
.collect();
// flatten statements in functions and apply substitution
for stat in funct.statements {
self.flatten_statement(&mut statements_flattened, stat);
}
FlatFunction {
arguments: arguments_flattened,
statements: statements_flattened,
}
}
/// Flattens a program
///
/// # Arguments
///
/// * `prog` - `ZirProgram` that will be flattened.
fn flatten_program(&mut self, prog: ZirProgram<'ast, T>) -> FlatProg<T> {
FlatProg {
main: self.flatten_function(prog.main),
}
}
/// Flattens an equality assertion, enforcing it in the circuit. /// Flattens an equality assertion, enforcing it in the circuit.
/// ///
/// # Arguments /// # Arguments
@ -2571,7 +2581,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
), ),
), ),
}; };
statements_flattened.push(FlatStatement::Condition(lhs, rhs, error)); statements_flattened.push_back(FlatStatement::Condition(lhs, rhs, error));
} }
/// Identifies a non-linear expression by assigning it to a new identifier. /// Identifies a non-linear expression by assigning it to a new identifier.
@ -2589,7 +2599,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
true => e, true => e,
false => { false => {
let sym = self.use_sym(); let sym = self.use_sym();
statements_flattened.push(FlatStatement::Definition(sym, e)); statements_flattened.push_back(FlatStatement::Definition(sym, e));
FlatExpression::Identifier(sym) FlatExpression::Identifier(sym)
} }
} }
@ -2638,7 +2648,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
); );
} }
Type::Boolean => { Type::Boolean => {
statements_flattened.push(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
variable.into(), variable.into(),
FlatExpression::Mult(box variable.into(), box variable.into()), FlatExpression::Mult(box variable.into(), box variable.into()),
RuntimeError::ArgumentBitness, RuntimeError::ArgumentBitness,
@ -2649,7 +2659,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// we insert dummy condition statement for private field elements // we insert dummy condition statement for private field elements
// to avoid unconstrained variables // to avoid unconstrained variables
// translates to y == x * x // translates to y == x * x
statements_flattened.push(FlatStatement::Definition( statements_flattened.push_back(FlatStatement::Definition(
self.use_sym(), self.use_sym(),
FlatExpression::Mult(box variable.into(), box variable.into()), FlatExpression::Mult(box variable.into(), box variable.into()),
)); ));

View file

@ -46,7 +46,7 @@ pub fn fold_module<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> {
.into_iter() .into_iter()
.flat_map(|s| f.fold_statement(s)) .flat_map(|s| f.fold_statement(s))
.collect(), .collect(),
returns: p.returns.into_iter().map(|v| f.fold_variable(v)).collect(), return_count: p.return_count,
} }
} }

View file

@ -1,5 +1,5 @@
use crate::flat_absy::{FlatDirective, FlatExpression, FlatProg, FlatStatement, FlatVariable}; use crate::flat_absy::{FlatDirective, FlatExpression, FlatStatement, FlatVariable};
use crate::ir::{Directive, LinComb, Prog, QuadComb, Statement}; use crate::ir::{Directive, LinComb, QuadComb, Statement};
use zokrates_field::Field; use zokrates_field::Field;
impl<T: Field> QuadComb<T> { impl<T: Field> QuadComb<T> {
@ -17,50 +17,25 @@ impl<T: Field> QuadComb<T> {
} }
} }
impl<T: Field> From<FlatProg<T>> for Prog<T> { pub fn from_flat<'ast, T: Field, I: Iterator<Item = FlatStatement<T>>>(
fn from(flat_prog: FlatProg<T>) -> Prog<T> { flat_prog_iterator: I,
// get the main function ) -> impl Iterator<Item = Statement<T>> {
let main = flat_prog.main; flat_prog_iterator.filter_map(|s| match s {
let return_expressions: Vec<FlatExpression<T>> = main
.statements
.iter()
.filter_map(|s| match s {
FlatStatement::Return(el) => Some(el.expressions.clone()),
_ => None,
})
.next()
.unwrap();
Prog {
arguments: main.arguments,
returns: return_expressions
.iter()
.enumerate()
.map(|(index, _)| FlatVariable::public(index))
.collect(),
statements: main
.statements
.into_iter()
.filter_map(|s| match s {
FlatStatement::Return(..) => None, FlatStatement::Return(..) => None,
s => Some(s.into()), s => Some(s.into()),
}) })
.chain( // .chain(
return_expressions // return_expressions
.into_iter() // .into_iter()
.enumerate() // .enumerate()
.map(|(index, expression)| { // .map(|(index, expression)| {
Statement::Constraint( // Statement::Constraint(
QuadComb::from_flat_expression(expression), // QuadComb::from_flat_expression(expression),
FlatVariable::public(index).into(), // FlatVariable::public(index).into(),
None, // None,
) // )
}), // }),
) // )
.collect(),
}
}
} }
impl<T: Field> From<FlatExpression<T>> for LinComb<T> { impl<T: Field> From<FlatExpression<T>> for LinComb<T> {

View file

@ -8,7 +8,7 @@ use zokrates_field::Field;
mod expression; mod expression;
pub mod folder; pub mod folder;
mod from_flat; pub mod from_flat;
mod interpreter; mod interpreter;
mod serialize; mod serialize;
pub mod smtlib2; pub mod smtlib2;
@ -78,10 +78,51 @@ impl<T: Field> fmt::Display for Statement<T> {
pub struct Prog<T> { pub struct Prog<T> {
pub statements: Vec<Statement<T>>, pub statements: Vec<Statement<T>>,
pub arguments: Vec<FlatParameter>, pub arguments: Vec<FlatParameter>,
pub returns: Vec<FlatVariable>, pub return_count: usize,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ProgIterator<I> {
pub statements: I,
pub arguments: Vec<FlatParameter>,
pub return_count: usize,
}
impl<T> From<Prog<T>> for ProgIterator<std::vec::IntoIter<Statement<T>>> {
fn from(p: Prog<T>) -> ProgIterator<std::vec::IntoIter<Statement<T>>> {
ProgIterator {
statements: p.statements.into_iter(),
arguments: p.arguments,
return_count: p.return_count,
}
}
}
impl<T, I: Iterator<Item = Statement<T>>> From<ProgIterator<I>> for Prog<T> {
fn from(p: ProgIterator<I>) -> Prog<T> {
Prog {
statements: p.statements.collect(),
arguments: p.arguments,
return_count: p.return_count,
}
}
}
impl<T> ProgIterator<T> {
pub fn returns(&self) -> Vec<FlatVariable> {
(0..self.return_count)
.map(|id| FlatVariable::public(id))
.collect()
}
} }
impl<T: Field> Prog<T> { impl<T: Field> Prog<T> {
pub fn returns(&self) -> Vec<FlatVariable> {
(0..self.return_count)
.map(|id| FlatVariable::public(id))
.collect()
}
pub fn constraint_count(&self) -> usize { pub fn constraint_count(&self) -> usize {
self.statements self.statements
.iter() .iter()
@ -113,14 +154,14 @@ impl<T: Field> fmt::Display for Prog<T> {
.map(|v| format!("{}", v)) .map(|v| format!("{}", v))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
self.returns.len(), self.return_count,
self.statements self.statements
.iter() .iter()
.map(|s| format!("\t{}", s)) .map(|s| format!("\t{}", s))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"), .join("\n"),
self.returns (0..self.return_count)
.iter() .map(|i| FlatVariable::public(i))
.map(|e| format!("{}", e)) .map(|e| format!("{}", e))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")

View file

@ -49,9 +49,6 @@ pub fn visit_module<T: Field, F: Visitor<T>>(f: &mut F, p: &Prog<T>) {
for expr in p.statements.iter() { for expr in p.statements.iter() {
f.visit_statement(expr); f.visit_statement(expr);
} }
for expr in p.returns.iter() {
f.visit_variable(expr);
}
} }
pub fn visit_statement<T: Field, F: Visitor<T>>(f: &mut F, s: &Statement<T>) { pub fn visit_statement<T: Field, F: Visitor<T>>(f: &mut F, s: &Statement<T>) {

View file

@ -1,6 +1,7 @@
use crate::ir::{folder::Folder, LinComb}; use crate::ir::{folder::Folder, LinComb};
use zokrates_field::Field; use zokrates_field::Field;
#[derive(Default)]
pub struct Canonicalizer; pub struct Canonicalizer;
impl<T: Field> Folder<T> for Canonicalizer { impl<T: Field> Folder<T> for Canonicalizer {

View file

@ -17,26 +17,13 @@ use crate::solvers::Solver;
use std::collections::hash_map::{Entry, HashMap}; use std::collections::hash_map::{Entry, HashMap};
use zokrates_field::Field; use zokrates_field::Field;
#[derive(Debug)] #[derive(Debug, Default)]
pub struct DirectiveOptimizer<T: Field> { pub struct DirectiveOptimizer<T> {
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>, calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
/// Map of renamings for reassigned variables while processing the program. /// Map of renamings for reassigned variables while processing the program.
substitution: HashMap<FlatVariable, FlatVariable>, substitution: HashMap<FlatVariable, FlatVariable>,
} }
impl<T: Field> DirectiveOptimizer<T> {
fn new() -> DirectiveOptimizer<T> {
DirectiveOptimizer {
calls: HashMap::new(),
substitution: HashMap::new(),
}
}
pub fn optimize(p: Prog<T>) -> Prog<T> {
DirectiveOptimizer::new().fold_module(p)
}
}
impl<T: Field> Folder<T> for DirectiveOptimizer<T> { impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> { fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
// in order to correctly identify duplicates, we need to first canonicalize the statements // in order to correctly identify duplicates, we need to first canonicalize the statements

View file

@ -16,23 +16,11 @@ fn hash<T: Field>(s: &Statement<T>) -> Hash {
hasher.finish() hasher.finish()
} }
#[derive(Debug)] #[derive(Debug, Default)]
pub struct DuplicateOptimizer { pub struct DuplicateOptimizer {
seen: HashSet<Hash>, seen: HashSet<Hash>,
} }
impl DuplicateOptimizer {
fn new() -> Self {
DuplicateOptimizer {
seen: HashSet::new(),
}
}
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
Self::new().fold_module(p)
}
}
impl<T: Field> Folder<T> for DuplicateOptimizer { impl<T: Field> Folder<T> for DuplicateOptimizer {
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> { fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
// in order to correctly identify duplicates, we need to first canonicalize the statements // in order to correctly identify duplicates, we need to first canonicalize the statements

View file

@ -10,41 +10,59 @@ mod duplicate;
mod redefinition; mod redefinition;
mod tautology; mod tautology;
use self::canonicalizer::Canonicalizer;
use self::directive::DirectiveOptimizer; use self::directive::DirectiveOptimizer;
use self::duplicate::DuplicateOptimizer; use self::duplicate::DuplicateOptimizer;
use self::redefinition::RedefinitionOptimizer; use self::redefinition::RedefinitionOptimizer;
use self::tautology::TautologyOptimizer; use self::tautology::TautologyOptimizer;
use crate::ir::Prog; use crate::ir::{ProgIterator, Statement};
use zokrates_field::Field; use zokrates_field::Field;
impl<T: Field> Prog<T> { impl<T: Field, I: Iterator<Item = Statement<T>>> ProgIterator<I> {
pub fn optimize(self) -> Self { pub fn optimize(self) -> ProgIterator<impl Iterator<Item = Statement<T>>> {
// remove redefinitions // remove redefinitions
log::debug!("Constraints: {}", self.constraint_count()); log::debug!(
log::debug!("Optimizer: Remove redefinitions"); "Optimizer: Remove redefinitions and tautologies and directives and duplicates"
let r = RedefinitionOptimizer::optimize(self); );
log::debug!("Done");
// remove constraints that are always satisfied // define all optimizer steps
log::debug!("Constraints: {}", r.constraint_count()); let mut redefinition_optimizer = RedefinitionOptimizer::default();
log::debug!("Optimizer: Remove tautologies"); let mut tautologies_optimizer = TautologyOptimizer::default();
let r = TautologyOptimizer::optimize(r); let mut directive_optimizer = DirectiveOptimizer::default();
log::debug!("Done"); let mut canonicalizer = Canonicalizer::default();
let mut duplicate_optimizer = DuplicateOptimizer::default();
// deduplicate directives which take the same input // initialize the ones that need initializing
log::debug!("Constraints: {}", r.constraint_count()); redefinition_optimizer.ignore.extend(self.returns().clone());
log::debug!("Optimizer: Remove duplicate directive");
let r = DirectiveOptimizer::optimize(r);
log::debug!("Done");
// remove duplicate constraints use crate::ir::folder::Folder;
log::debug!("Constraints: {}", r.constraint_count());
log::debug!("Optimizer: Remove duplicate constraints");
let r = DuplicateOptimizer::optimize(r);
log::debug!("Done");
log::debug!("Constraints: {}", r.constraint_count()); let r = ProgIterator {
arguments: self
.arguments
.into_iter()
.map(|a| redefinition_optimizer.fold_argument(a))
.map(|a| {
<TautologyOptimizer as Folder<T>>::fold_argument(&mut tautologies_optimizer, a)
})
.map(|a| directive_optimizer.fold_argument(a))
.map(|a| {
<DuplicateOptimizer as Folder<T>>::fold_argument(&mut duplicate_optimizer, a)
})
.collect(),
statements: self
.statements
.into_iter()
.flat_map(move |s| redefinition_optimizer.fold_statement(s))
.flat_map(move |s| tautologies_optimizer.fold_statement(s))
.flat_map(move |s| canonicalizer.fold_statement(s))
.flat_map(move |s| directive_optimizer.fold_statement(s))
.flat_map(move |s| duplicate_optimizer.fold_statement(s)),
return_count: self.return_count,
};
log::debug!("Done");
r r
} }
} }

View file

@ -38,44 +38,30 @@
use crate::flat_absy::flat_variable::FlatVariable; use crate::flat_absy::flat_variable::FlatVariable;
use crate::flat_absy::FlatParameter; use crate::flat_absy::FlatParameter;
use crate::ir::folder::{fold_module, Folder}; use crate::ir::folder::Folder;
use crate::ir::LinComb; use crate::ir::LinComb;
use crate::ir::*; use crate::ir::*;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use zokrates_field::Field; use zokrates_field::Field;
#[derive(Debug)] #[derive(Debug)]
pub struct RedefinitionOptimizer<T: Field> { pub struct RedefinitionOptimizer<T> {
/// Map of renamings for reassigned variables while processing the program. /// Map of renamings for reassigned variables while processing the program.
substitution: HashMap<FlatVariable, CanonicalLinComb<T>>, substitution: HashMap<FlatVariable, CanonicalLinComb<T>>,
/// Set of variables that should not be substituted /// Set of variables that should not be substituted
ignore: HashSet<FlatVariable>, pub ignore: HashSet<FlatVariable>,
} }
impl<T: Field> RedefinitionOptimizer<T> { impl<T> Default for RedefinitionOptimizer<T> {
fn new() -> Self { fn default() -> Self {
RedefinitionOptimizer { RedefinitionOptimizer {
substitution: HashMap::new(), substitution: HashMap::new(),
ignore: HashSet::new(), ignore: vec![FlatVariable::one()].into_iter().collect(),
} }
} }
pub fn optimize(p: Prog<T>) -> Prog<T> {
RedefinitionOptimizer::new().fold_module(p)
}
} }
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> { impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
// to prevent the optimiser from replacing outputs, add them to the ignored set
self.ignore.extend(p.returns.iter().cloned());
// to prevent the optimiser from replacing ~one, add it to the ignored set
self.ignore.insert(FlatVariable::one());
fold_module(self, p)
}
fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter { fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter {
// to prevent the optimiser from replacing user input, add it to the ignored set // to prevent the optimiser from replacing user input, add it to the ignored set
self.ignore.insert(a.id); self.ignore.insert(a.id);

View file

@ -10,17 +10,8 @@ use crate::ir::folder::Folder;
use crate::ir::*; use crate::ir::*;
use zokrates_field::Field; use zokrates_field::Field;
pub struct TautologyOptimizer {} #[derive(Default)]
pub struct TautologyOptimizer;
impl TautologyOptimizer {
fn new() -> TautologyOptimizer {
TautologyOptimizer {}
}
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
TautologyOptimizer::new().fold_module(p)
}
}
impl<T: Field> Folder<T> for TautologyOptimizer { impl<T: Field> Folder<T> for TautologyOptimizer {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {

View file

@ -101,14 +101,6 @@ impl<T: Field> Propagate<T> for FlatFunction<T> {
} }
} }
impl<T: Field> FlatProg<T> {
pub fn propagate(self) -> FlatProg<T> {
let main = self.main.propagate();
FlatProg { main }
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -54,7 +54,8 @@ impl fmt::Debug for FieldParseError {
} }
pub trait Field: pub trait Field:
From<i32> 'static
+ From<i32>
+ From<u32> + From<u32>
+ From<usize> + From<usize>
+ From<u128> + From<u128>

View file

@ -12,5 +12,6 @@ zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
serde = "1.0" serde = "1.0"
serde_derive = "1.0" serde_derive = "1.0"
serde_json = "1.0" serde_json = "1.0"
typed-arena = "1.4.1"
[lib] [lib]

View file

@ -149,13 +149,17 @@ fn compile_and_run<T: Field>(t: Tests) {
let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap();
let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap()); let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap());
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver), &config).unwrap(); let arena = typed_arena::Arena::new();
let bin = artifacts.prog(); let artifacts =
let abi = artifacts.abi(); compile::<T, _>(code, entry_point.clone(), Some(&resolver), config, &arena).unwrap();
let abi = artifacts.abi;
let bin = artifacts.prog;
if let Some(target_count) = t.max_constraint_count { if let Some(target_count) = t.max_constraint_count {
let count = bin.constraint_count(); unimplemented!("bin.constraint_count()");
let count = 42;
assert!( assert!(
count <= target_count, count <= target_count,
@ -169,6 +173,8 @@ fn compile_and_run<T: Field>(t: Tests) {
let interpreter = zokrates_core::ir::Interpreter::default(); let interpreter = zokrates_core::ir::Interpreter::default();
let bin = &bin.into();
for test in t.tests.into_iter() { for test in t.tests.into_iter() {
let with_abi = test.abi.unwrap_or(false); let with_abi = test.abi.unwrap_or(false);