proof of concept of iterator treatment starting at flattening
This commit is contained in:
parent
d924368038
commit
dad17b79e0
21 changed files with 397 additions and 351 deletions
21
Cargo.lock
generated
21
Cargo.lock
generated
|
@ -1074,6 +1074,12 @@ version = "0.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "half"
|
||||||
|
version = "1.8.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.11.2"
|
version = "0.11.2"
|
||||||
|
@ -1845,6 +1851,16 @@ dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_cbor"
|
||||||
|
version = "0.11.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
|
||||||
|
dependencies = [
|
||||||
|
"half",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.130"
|
version = "1.0.130"
|
||||||
|
@ -2383,7 +2399,6 @@ name = "zokrates_cli"
|
||||||
version = "0.7.7"
|
version = "0.7.7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"assert_cli",
|
"assert_cli",
|
||||||
"bincode",
|
|
||||||
"cfg-if 0.1.10",
|
"cfg-if 0.1.10",
|
||||||
"clap",
|
"clap",
|
||||||
"dirs",
|
"dirs",
|
||||||
|
@ -2393,8 +2408,11 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"log",
|
"log",
|
||||||
"regex 0.2.11",
|
"regex 0.2.11",
|
||||||
|
"serde",
|
||||||
|
"serde_cbor",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tempdir",
|
"tempdir",
|
||||||
|
"typed-arena",
|
||||||
"zokrates_abi",
|
"zokrates_abi",
|
||||||
"zokrates_core",
|
"zokrates_core",
|
||||||
"zokrates_field",
|
"zokrates_field",
|
||||||
|
@ -2545,6 +2563,7 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"typed-arena",
|
||||||
"zokrates_abi",
|
"zokrates_abi",
|
||||||
"zokrates_core",
|
"zokrates_core",
|
||||||
"zokrates_field",
|
"zokrates_field",
|
||||||
|
|
|
@ -16,13 +16,15 @@ log = "0.4"
|
||||||
env_logger = "0.9.0"
|
env_logger = "0.9.0"
|
||||||
cfg-if = "0.1"
|
cfg-if = "0.1"
|
||||||
clap = "2.26.2"
|
clap = "2.26.2"
|
||||||
bincode = "0.8.0"
|
serde_cbor = "0.11.2"
|
||||||
regex = "0.2"
|
regex = "0.2"
|
||||||
zokrates_field = { version = "0.4", path = "../zokrates_field", default-features = false }
|
zokrates_field = { version = "0.4", path = "../zokrates_field", default-features = false }
|
||||||
zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
|
zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
|
||||||
zokrates_core = { version = "0.6", path = "../zokrates_core", default-features = false }
|
zokrates_core = { version = "0.6", path = "../zokrates_core", default-features = false }
|
||||||
|
typed-arena = "1.4.1"
|
||||||
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
|
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
dirs = "3.0.1"
|
dirs = "3.0.1"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,35 @@ use std::convert::TryFrom;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{BufReader, BufWriter, Read, Write};
|
use std::io::{BufReader, BufWriter, Read, Write};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
use typed_arena::Arena;
|
||||||
use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig, CompileError};
|
use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig, CompileError};
|
||||||
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
|
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
|
||||||
use zokrates_fs_resolver::FileSystemResolver;
|
use zokrates_fs_resolver::FileSystemResolver;
|
||||||
|
|
||||||
|
use serde::{Serialize, Serializer};
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
|
pub fn write_as_cbor<I, P, W>(out: &mut W, groups: I) -> serde_cbor::Result<()>
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = P>,
|
||||||
|
P: Serialize,
|
||||||
|
W: Write,
|
||||||
|
{
|
||||||
|
struct Wrapper<T>(Cell<Option<T>>);
|
||||||
|
|
||||||
|
impl<I, P> Serialize for Wrapper<I>
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = P>,
|
||||||
|
P: Serialize,
|
||||||
|
{
|
||||||
|
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||||
|
s.collect_seq(self.0.take().unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serde_cbor::to_writer(out, &Wrapper(Cell::new(Some(groups))))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn subcommand() -> App<'static, 'static> {
|
pub fn subcommand() -> App<'static, 'static> {
|
||||||
SubCommand::with_name("compile")
|
SubCommand::with_name("compile")
|
||||||
.about("Compiles into flattened conditions. Produces two files: human-readable '.ztf' file for debugging and binary file")
|
.about("Compiles into flattened conditions. Produces two files: human-readable '.ztf' file for debugging and binary file")
|
||||||
|
@ -136,8 +161,10 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
||||||
|
|
||||||
log::debug!("Compile");
|
log::debug!("Compile");
|
||||||
|
|
||||||
let artifacts: CompilationArtifacts<T> = compile(source, path, Some(&resolver), &config)
|
let arena = Arena::new();
|
||||||
.map_err(|e| {
|
|
||||||
|
let artifacts =
|
||||||
|
compile::<T, _>(source, path, Some(&resolver), config, &arena).map_err(|e| {
|
||||||
format!(
|
format!(
|
||||||
"Compilation failed:\n\n{}",
|
"Compilation failed:\n\n{}",
|
||||||
e.0.iter()
|
e.0.iter()
|
||||||
|
@ -147,10 +174,11 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let program_flattened = artifacts.prog();
|
let program_flattened = artifacts.prog;
|
||||||
|
let abi = artifacts.abi;
|
||||||
|
|
||||||
// number of constraints the flattened program will translate to.
|
// number of constraints the flattened program will translate to.
|
||||||
let num_constraints = program_flattened.constraint_count();
|
//let num_constraints = program_flattened.constraint_count();
|
||||||
|
|
||||||
// serialize flattened program and write to binary file
|
// serialize flattened program and write to binary file
|
||||||
log::debug!("Serialize program");
|
log::debug!("Serialize program");
|
||||||
|
@ -159,21 +187,22 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
||||||
|
|
||||||
let mut writer = BufWriter::new(bin_output_file);
|
let mut writer = BufWriter::new(bin_output_file);
|
||||||
|
|
||||||
program_flattened.serialize(&mut writer);
|
println!("START WRITING TO DISK");
|
||||||
|
|
||||||
|
//program_flattened.serialize(&mut writer);
|
||||||
|
write_as_cbor(&mut writer, program_flattened.statements).unwrap();
|
||||||
|
|
||||||
// serialize ABI spec and write to JSON file
|
// serialize ABI spec and write to JSON file
|
||||||
log::debug!("Serialize ABI");
|
log::debug!("Serialize ABI");
|
||||||
let abi_spec_file = File::create(&abi_spec_path)
|
let abi_spec_file = File::create(&abi_spec_path)
|
||||||
.map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?;
|
.map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?;
|
||||||
|
|
||||||
let abi = artifacts.abi();
|
|
||||||
|
|
||||||
let mut writer = BufWriter::new(abi_spec_file);
|
let mut writer = BufWriter::new(abi_spec_file);
|
||||||
to_writer_pretty(&mut writer, &abi).map_err(|_| "Unable to write data to file.".to_string())?;
|
to_writer_pretty(&mut writer, &abi).map_err(|_| "Unable to write data to file.".to_string())?;
|
||||||
|
|
||||||
if sub_matches.is_present("verbose") {
|
if sub_matches.is_present("verbose") {
|
||||||
// debugging output
|
// debugging output
|
||||||
println!("Compiled program:\n{}", program_flattened);
|
//println!("Compiled program:\n{}", program_flattened);
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Compiled code written to '{}'", bin_output_path.display());
|
println!("Compiled code written to '{}'", bin_output_path.display());
|
||||||
|
@ -185,15 +214,15 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
||||||
.map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?;
|
.map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?;
|
||||||
|
|
||||||
let mut hrofb = BufWriter::new(hr_output_file);
|
let mut hrofb = BufWriter::new(hr_output_file);
|
||||||
writeln!(&mut hrofb, "{}", program_flattened)
|
// writeln!(&mut hrofb, "{}", program_flattened)
|
||||||
.map_err(|_| "Unable to write data to file".to_string())?;
|
// .map_err(|_| "Unable to write data to file".to_string())?;
|
||||||
hrofb
|
// hrofb
|
||||||
.flush()
|
// .flush()
|
||||||
.map_err(|_| "Unable to flush buffer".to_string())?;
|
// .map_err(|_| "Unable to flush buffer".to_string())?;
|
||||||
|
|
||||||
println!("Human readable code to '{}'", hr_output_path.display());
|
println!("Human readable code to '{}'", hr_output_path.display());
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Number of constraints: {}", num_constraints);
|
//println!("Number of constraints: {}", num_constraints);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,7 +106,7 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
|
||||||
}
|
}
|
||||||
false => ConcreteSignature::new()
|
false => ConcreteSignature::new()
|
||||||
.inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()])
|
.inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()])
|
||||||
.outputs(vec![ConcreteType::FieldElement; ir_prog.returns.len()]),
|
.outputs(vec![ConcreteType::FieldElement; ir_prog.return_count]),
|
||||||
};
|
};
|
||||||
|
|
||||||
use zokrates_abi::Inputs;
|
use zokrates_abi::Inputs;
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||||
//! @date 2018
|
//! @date 2018
|
||||||
use crate::absy::{Module, OwnedModuleId, Program};
|
use crate::absy::{Module, OwnedModuleId, Program};
|
||||||
use crate::flatten::Flattener;
|
use crate::flatten::FlattenerIterator;
|
||||||
use crate::imports::{self, Importer};
|
use crate::imports::{self, Importer};
|
||||||
use crate::ir;
|
use crate::ir;
|
||||||
use crate::macros;
|
use crate::macros;
|
||||||
|
@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
use std::io::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use typed_arena::Arena;
|
use typed_arena::Arena;
|
||||||
use zokrates_common::Resolver;
|
use zokrates_common::Resolver;
|
||||||
|
@ -25,14 +26,14 @@ use zokrates_field::Field;
|
||||||
use zokrates_pest_ast as pest;
|
use zokrates_pest_ast as pest;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CompilationArtifacts<T: Field> {
|
pub struct CompilationArtifacts<I> {
|
||||||
prog: ir::Prog<T>,
|
pub prog: ir::ProgIterator<I>,
|
||||||
abi: Abi,
|
pub abi: Abi,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> CompilationArtifacts<T> {
|
impl<I> CompilationArtifacts<I> {
|
||||||
pub fn prog(&self) -> &ir::Prog<T> {
|
pub fn prog(self) -> ir::ProgIterator<I> {
|
||||||
&self.prog
|
self.prog
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn abi(&self) -> &Abi {
|
pub fn abi(&self) -> &Abi {
|
||||||
|
@ -183,37 +184,46 @@ impl CompileConfig {
|
||||||
|
|
||||||
type FilePath = PathBuf;
|
type FilePath = PathBuf;
|
||||||
|
|
||||||
pub fn compile<T: Field, E: Into<imports::Error>>(
|
pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
|
||||||
source: String,
|
source: String,
|
||||||
location: FilePath,
|
location: FilePath,
|
||||||
resolver: Option<&dyn Resolver<E>>,
|
resolver: Option<&dyn Resolver<E>>,
|
||||||
config: &CompileConfig,
|
config: CompileConfig,
|
||||||
) -> Result<CompilationArtifacts<T>, CompileErrors> {
|
arena: &'ast Arena<String>,
|
||||||
let arena = Arena::new();
|
) -> Result<CompilationArtifacts<impl Iterator<Item = ir::Statement<T>> + 'ast>, CompileErrors> {
|
||||||
|
let (typed_ast, abi): (crate::zir::ZirProgram<'_, T>, _) =
|
||||||
let (typed_ast, abi) = check_with_arena(source, location.clone(), resolver, config, &arena)?;
|
check_with_arena(source, location.clone(), resolver, &config, arena)?;
|
||||||
|
|
||||||
// flatten input program
|
// flatten input program
|
||||||
log::debug!("Flatten");
|
log::debug!("Flatten");
|
||||||
let program_flattened = Flattener::flatten(typed_ast, config);
|
let program_flattened = FlattenerIterator::from_function_and_config(typed_ast.main, config);
|
||||||
|
|
||||||
// constant propagation after call resolution
|
// // constant propagation after call resolution
|
||||||
log::debug!("Propagate flat program");
|
// log::debug!("Propagate flat program");
|
||||||
let program_flattened = program_flattened.propagate();
|
// let program_flattened = program_flattened.propagate();
|
||||||
|
|
||||||
// convert to ir
|
// convert to ir
|
||||||
log::debug!("Convert to IR");
|
log::debug!("Convert to IR");
|
||||||
let ir_prog = ir::Prog::from(program_flattened);
|
let ir_prog = ir::ProgIterator {
|
||||||
|
arguments: program_flattened
|
||||||
|
.arguments_flattened
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.map(|a| a.into())
|
||||||
|
.collect(),
|
||||||
|
return_count: 0,
|
||||||
|
statements: ir::from_flat::from_flat(program_flattened),
|
||||||
|
};
|
||||||
|
|
||||||
// optimize
|
// optimize
|
||||||
log::debug!("Optimise IR");
|
log::debug!("Optimise IR");
|
||||||
let optimized_ir_prog = ir_prog.optimize();
|
let optimized_ir_prog = ir_prog.optimize();
|
||||||
|
|
||||||
// analyse ir (check constraints)
|
// analyse ir (check constraints)
|
||||||
log::debug!("Analyse IR");
|
// log::debug!("Analyse IR");
|
||||||
let optimized_ir_prog = optimized_ir_prog
|
// let optimized_ir_prog = optimized_ir_prog
|
||||||
.analyse()
|
// .analyse()
|
||||||
.map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?;
|
// .map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?;
|
||||||
|
|
||||||
Ok(CompilationArtifacts {
|
Ok(CompilationArtifacts {
|
||||||
prog: optimized_ir_prog,
|
prog: optimized_ir_prog,
|
||||||
|
|
|
@ -96,32 +96,24 @@ impl fmt::Display for RuntimeError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, PartialEq)]
|
pub type FlatProg<T> = FlatFunction<T>;
|
||||||
pub struct FlatProg<T: Field> {
|
|
||||||
/// FlatFunctions of the program
|
|
||||||
pub main: FlatFunction<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Field> fmt::Display for FlatProg<T> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
write!(f, "{}", self.main)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Field> fmt::Debug for FlatProg<T> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
write!(f, "flat_program(main: {}\t)", self.main)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, PartialEq)]
|
#[derive(Clone, PartialEq)]
|
||||||
pub struct FlatFunction<T: Field> {
|
pub struct FlatFunction<T> {
|
||||||
/// Arguments of the function
|
/// Arguments of the function
|
||||||
pub arguments: Vec<FlatParameter>,
|
pub arguments: Vec<FlatParameter>,
|
||||||
/// Vector of statements that are executed when running the function
|
/// Vector of statements that are executed when running the function
|
||||||
pub statements: Vec<FlatStatement<T>>,
|
pub statements: Vec<FlatStatement<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub type FlatProgIterator<T> = FlatFunctionIterator<T>;
|
||||||
|
pub struct FlatFunctionIterator<I> {
|
||||||
|
/// Arguments of the function
|
||||||
|
pub arguments: Vec<FlatParameter>,
|
||||||
|
/// Vector of statements that are executed when running the function
|
||||||
|
pub statements: I,
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: Field> fmt::Display for FlatFunction<T> {
|
impl<T: Field> fmt::Display for FlatFunction<T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(
|
write!(
|
||||||
|
@ -167,7 +159,7 @@ impl<T: Field> fmt::Debug for FlatFunction<T> {
|
||||||
/// * r1cs - R1CS in standard JSON data format
|
/// * r1cs - R1CS in standard JSON data format
|
||||||
|
|
||||||
#[derive(Clone, PartialEq)]
|
#[derive(Clone, PartialEq)]
|
||||||
pub enum FlatStatement<T: Field> {
|
pub enum FlatStatement<T> {
|
||||||
Return(FlatExpressionList<T>),
|
Return(FlatExpressionList<T>),
|
||||||
Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError),
|
Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError),
|
||||||
Definition(FlatVariable, FlatExpression<T>),
|
Definition(FlatVariable, FlatExpression<T>),
|
||||||
|
@ -239,7 +231,7 @@ impl<T: Field> FlatStatement<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
|
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
|
||||||
pub struct FlatDirective<T: Field> {
|
pub struct FlatDirective<T> {
|
||||||
pub inputs: Vec<FlatExpression<T>>,
|
pub inputs: Vec<FlatExpression<T>>,
|
||||||
pub outputs: Vec<FlatVariable>,
|
pub outputs: Vec<FlatVariable>,
|
||||||
pub solver: Solver,
|
pub solver: Solver,
|
||||||
|
|
|
@ -16,17 +16,67 @@ use crate::flat_absy::{RuntimeError, *};
|
||||||
use crate::solvers::Solver;
|
use crate::solvers::Solver;
|
||||||
use crate::zir::types::{Type, UBitwidth};
|
use crate::zir::types::{Type, UBitwidth};
|
||||||
use crate::zir::*;
|
use crate::zir::*;
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::{
|
||||||
use std::collections::HashMap;
|
hash_map::{Entry, HashMap},
|
||||||
|
VecDeque,
|
||||||
|
};
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
type FlatStatements<T> = Vec<FlatStatement<T>>;
|
type FlatStatements<T> = VecDeque<FlatStatement<T>>;
|
||||||
|
|
||||||
|
/// Flattens a function
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `funct` - `ZirFunction` that will be flattened
|
||||||
|
impl<'ast, T: Field> FlattenerIterator<'ast, T> {
|
||||||
|
pub fn from_function_and_config(funct: ZirFunction<'ast, T>, config: CompileConfig) -> Self {
|
||||||
|
let mut flattener = Flattener::new(config);
|
||||||
|
let mut statements_flattened = FlatStatements::new();
|
||||||
|
// push parameters
|
||||||
|
let arguments_flattened = funct
|
||||||
|
.arguments
|
||||||
|
.into_iter()
|
||||||
|
.map(|p| flattener.use_parameter(&p, &mut statements_flattened))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
FlattenerIterator {
|
||||||
|
statements: funct.statements.into(),
|
||||||
|
arguments_flattened,
|
||||||
|
statements_flattened,
|
||||||
|
flattener,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FlattenerIterator<'ast, T> {
|
||||||
|
pub statements: VecDeque<ZirStatement<'ast, T>>,
|
||||||
|
pub arguments_flattened: Vec<FlatParameter>,
|
||||||
|
pub statements_flattened: FlatStatements<T>,
|
||||||
|
pub flattener: Flattener<'ast, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Iterator for FlattenerIterator<'ast, T> {
|
||||||
|
type Item = FlatStatement<T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.statements_flattened.is_empty() {
|
||||||
|
match self.statements.pop_front() {
|
||||||
|
Some(s) => {
|
||||||
|
self.flattener
|
||||||
|
.flatten_statement(&mut self.statements_flattened, s);
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.statements_flattened.pop_front()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Flattener, computes flattened program.
|
/// Flattener, computes flattened program.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Flattener<'ast, T: Field> {
|
pub struct Flattener<'ast, T> {
|
||||||
config: &'ast CompileConfig,
|
config: CompileConfig,
|
||||||
/// Index of the next introduced variable while processing the program.
|
/// Index of the next introduced variable while processing the program.
|
||||||
next_var_idx: usize,
|
next_var_idx: usize,
|
||||||
/// `FlatVariable`s corresponding to each `Identifier`
|
/// `FlatVariable`s corresponding to each `Identifier`
|
||||||
|
@ -156,13 +206,8 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Flattener<'ast, T> {
|
impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
pub fn flatten(p: ZirProgram<'ast, T>, config: &CompileConfig) -> FlatProg<T> {
|
|
||||||
Flattener::new(config).flatten_program(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a `Flattener` with fresh `layout`.
|
/// Returns a `Flattener` with fresh `layout`.
|
||||||
|
fn new(config: CompileConfig) -> Flattener<'ast, T> {
|
||||||
fn new(config: &'ast CompileConfig) -> Flattener<'ast, T> {
|
|
||||||
Flattener {
|
Flattener {
|
||||||
config,
|
config,
|
||||||
next_var_idx: 0,
|
next_var_idx: 0,
|
||||||
|
@ -182,7 +227,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
FlatExpression::Identifier(id) => id,
|
FlatExpression::Identifier(id) => id,
|
||||||
e => {
|
e => {
|
||||||
let res = self.use_sym();
|
let res = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(res, e));
|
statements_flattened.push_back(FlatStatement::Definition(res, e));
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -241,7 +286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// init size_unknown = true
|
// init size_unknown = true
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
size_unknown[0],
|
size_unknown[0],
|
||||||
FlatExpression::Number(T::from(1)),
|
FlatExpression::Number(T::from(1)),
|
||||||
));
|
));
|
||||||
|
@ -250,14 +295,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
for (i, b) in b.iter().enumerate() {
|
for (i, b) in b.iter().enumerate() {
|
||||||
if *b {
|
if *b {
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
is_not_smaller_run[i],
|
is_not_smaller_run[i],
|
||||||
a[i].into(),
|
a[i].into(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// don't need to update size_unknown in the last round
|
// don't need to update size_unknown in the last round
|
||||||
if i < len - 1 {
|
if i < len - 1 {
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
size_unknown[i + 1],
|
size_unknown[i + 1],
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box size_unknown[i].into(),
|
box size_unknown[i].into(),
|
||||||
|
@ -268,7 +313,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
} else {
|
} else {
|
||||||
// don't need to update size_unknown in the last round
|
// don't need to update size_unknown in the last round
|
||||||
if i < len - 1 {
|
if i < len - 1 {
|
||||||
statements_flattened.push(
|
statements_flattened.push_back(
|
||||||
// sizeUnknown is not changing in this case
|
// sizeUnknown is not changing in this case
|
||||||
// We sill have to assign the old value to the variable of the current run
|
// We sill have to assign the old value to the variable of the current run
|
||||||
// This trivial definition will later be removed by the optimiser
|
// This trivial definition will later be removed by the optimiser
|
||||||
|
@ -285,7 +330,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
let and_name = self.use_sym();
|
let and_name = self.use_sym();
|
||||||
let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone());
|
let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone());
|
||||||
statements_flattened.push(FlatStatement::Definition(and_name, and));
|
statements_flattened.push_back(FlatStatement::Definition(and_name, and));
|
||||||
let or = FlatExpression::Sub(
|
let or = FlatExpression::Sub(
|
||||||
box FlatExpression::Add(box or_left, box or_right),
|
box FlatExpression::Add(box or_left, box or_right),
|
||||||
box and_name.into(),
|
box and_name.into(),
|
||||||
|
@ -310,7 +355,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
|
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
|
||||||
fn eq_check(
|
fn eq_check(
|
||||||
&mut self,
|
&mut self,
|
||||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
statements_flattened: &mut FlatStatements<T>,
|
||||||
left: FlatExpression<T>,
|
left: FlatExpression<T>,
|
||||||
right: FlatExpression<T>,
|
right: FlatExpression<T>,
|
||||||
) -> FlatExpression<T> {
|
) -> FlatExpression<T> {
|
||||||
|
@ -329,12 +374,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let name_y = self.use_sym();
|
let name_y = self.use_sym();
|
||||||
let name_m = self.use_sym();
|
let name_m = self.use_sym();
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
vec![name_y, name_m],
|
vec![name_y, name_m],
|
||||||
Solver::ConditionEq,
|
Solver::ConditionEq,
|
||||||
vec![x.clone()],
|
vec![x.clone()],
|
||||||
)));
|
)));
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(name_y),
|
FlatExpression::Identifier(name_y),
|
||||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||||
RuntimeError::Equal,
|
RuntimeError::Equal,
|
||||||
|
@ -345,7 +390,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
box FlatExpression::Identifier(name_y),
|
box FlatExpression::Identifier(name_y),
|
||||||
);
|
);
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Number(T::zero()),
|
FlatExpression::Number(T::zero()),
|
||||||
FlatExpression::Mult(box res.clone(), box x),
|
FlatExpression::Mult(box res.clone(), box x),
|
||||||
RuntimeError::Equal,
|
RuntimeError::Equal,
|
||||||
|
@ -376,7 +421,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
.fold(FlatExpression::from(T::zero()), |acc, e| {
|
.fold(FlatExpression::from(T::zero()), |acc, e| {
|
||||||
FlatExpression::Add(box acc, box e)
|
FlatExpression::Add(box acc, box e)
|
||||||
});
|
});
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Number(T::from(0)),
|
FlatExpression::Number(T::from(0)),
|
||||||
FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()),
|
FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()),
|
||||||
RuntimeError::Le,
|
RuntimeError::Le,
|
||||||
|
@ -385,14 +430,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
fn make_conditional(
|
fn make_conditional(
|
||||||
&mut self,
|
&mut self,
|
||||||
statements: Vec<FlatStatement<T>>,
|
statements: FlatStatements<T>,
|
||||||
condition: FlatExpression<T>,
|
condition: FlatExpression<T>,
|
||||||
) -> Vec<FlatStatement<T>> {
|
) -> FlatStatements<T> {
|
||||||
statements
|
statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(|s| match s {
|
.flat_map(|s| match s {
|
||||||
FlatStatement::Condition(left, right, message) => {
|
FlatStatement::Condition(left, right, message) => {
|
||||||
let mut output = vec![];
|
let mut output = VecDeque::new();
|
||||||
|
|
||||||
// we transform (a == b) into (c => (a == b)) which is (!c || (a == b))
|
// we transform (a == b) into (c => (a == b)) which is (!c || (a == b))
|
||||||
|
|
||||||
|
@ -410,12 +455,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
assert!(x.is_linear() && y.is_linear());
|
assert!(x.is_linear() && y.is_linear());
|
||||||
let name_x_or_y = self.use_sym();
|
let name_x_or_y = self.use_sym();
|
||||||
output.push(FlatStatement::Directive(FlatDirective {
|
output.push_back(FlatStatement::Directive(FlatDirective {
|
||||||
solver: Solver::Or,
|
solver: Solver::Or,
|
||||||
outputs: vec![name_x_or_y],
|
outputs: vec![name_x_or_y],
|
||||||
inputs: vec![x.clone(), y.clone()],
|
inputs: vec![x.clone(), y.clone()],
|
||||||
}));
|
}));
|
||||||
output.push(FlatStatement::Condition(
|
output.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Add(
|
FlatExpression::Add(
|
||||||
box x.clone(),
|
box x.clone(),
|
||||||
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
|
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
|
||||||
|
@ -423,7 +468,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
FlatExpression::Mult(box x.clone(), box y.clone()),
|
FlatExpression::Mult(box x.clone(), box y.clone()),
|
||||||
RuntimeError::BranchIsolation,
|
RuntimeError::BranchIsolation,
|
||||||
));
|
));
|
||||||
output.push(FlatStatement::Condition(
|
output.push_back(FlatStatement::Condition(
|
||||||
name_x_or_y.into(),
|
name_x_or_y.into(),
|
||||||
T::one().into(),
|
T::one().into(),
|
||||||
message,
|
message,
|
||||||
|
@ -431,7 +476,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
s => vec![s],
|
s => VecDeque::from([s]),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
@ -457,16 +502,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||||
|
|
||||||
let condition_id = self.use_sym();
|
let condition_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
statements_flattened.push_back(FlatStatement::Definition(condition_id, condition_flat));
|
||||||
|
|
||||||
self.condition_cache.insert(condition, condition_id);
|
self.condition_cache.insert(condition, condition_id);
|
||||||
|
|
||||||
let (consequence, alternative) = if self.config.isolate_branches {
|
let (consequence, alternative) = if self.config.isolate_branches {
|
||||||
let mut consequence_statements = vec![];
|
let mut consequence_statements = VecDeque::new();
|
||||||
|
|
||||||
let consequence = consequence.flatten(self, &mut consequence_statements);
|
let consequence = consequence.flatten(self, &mut consequence_statements);
|
||||||
|
|
||||||
let mut alternative_statements = vec![];
|
let mut alternative_statements = VecDeque::new();
|
||||||
|
|
||||||
let alternative = alternative.flatten(self, &mut alternative_statements);
|
let alternative = alternative.flatten(self, &mut alternative_statements);
|
||||||
|
|
||||||
|
@ -495,13 +540,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let alternative = alternative.flat();
|
let alternative = alternative.flat();
|
||||||
|
|
||||||
let consequence_id = self.use_sym();
|
let consequence_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
|
statements_flattened.push_back(FlatStatement::Definition(consequence_id, consequence));
|
||||||
|
|
||||||
let alternative_id = self.use_sym();
|
let alternative_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
|
statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative));
|
||||||
|
|
||||||
let term0_id = self.use_sym();
|
let term0_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
term0_id,
|
term0_id,
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box condition_id.into(),
|
box condition_id.into(),
|
||||||
|
@ -510,7 +555,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
));
|
));
|
||||||
|
|
||||||
let term1_id = self.use_sym();
|
let term1_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
term1_id,
|
term1_id,
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box FlatExpression::Sub(
|
box FlatExpression::Sub(
|
||||||
|
@ -522,7 +567,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
));
|
));
|
||||||
|
|
||||||
let res = self.use_sym();
|
let res = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
res,
|
res,
|
||||||
FlatExpression::Add(
|
FlatExpression::Add(
|
||||||
box FlatExpression::from(term0_id),
|
box FlatExpression::from(term0_id),
|
||||||
|
@ -582,7 +627,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
|
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
|
||||||
|
|
||||||
// add a directive to get the bits
|
// add a directive to get the bits
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
e_bits_be.clone(),
|
e_bits_be.clone(),
|
||||||
Solver::bits(bit_width),
|
Solver::bits(bit_width),
|
||||||
vec![e_id],
|
vec![e_id],
|
||||||
|
@ -590,7 +635,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
// bitness checks
|
// bitness checks
|
||||||
for bit in e_bits_be.iter().take(bit_width) {
|
for bit in e_bits_be.iter().take(bit_width) {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(*bit),
|
FlatExpression::Identifier(*bit),
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box FlatExpression::Identifier(*bit),
|
box FlatExpression::Identifier(*bit),
|
||||||
|
@ -613,7 +658,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(e_id),
|
FlatExpression::Identifier(e_id),
|
||||||
e_sum,
|
e_sum,
|
||||||
RuntimeError::ConstantLtSum,
|
RuntimeError::ConstantLtSum,
|
||||||
|
@ -704,7 +749,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||||
|
|
||||||
// add a directive to get the bits
|
// add a directive to get the bits
|
||||||
statements_flattened.push(FlatStatement::Directive(
|
statements_flattened.push_back(FlatStatement::Directive(
|
||||||
FlatDirective::new(
|
FlatDirective::new(
|
||||||
lhs_bits_be.clone(),
|
lhs_bits_be.clone(),
|
||||||
Solver::bits(safe_width),
|
Solver::bits(safe_width),
|
||||||
|
@ -714,7 +759,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
// bitness checks
|
// bitness checks
|
||||||
for bit in lhs_bits_be.iter().take(safe_width) {
|
for bit in lhs_bits_be.iter().take(safe_width) {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(*bit),
|
FlatExpression::Identifier(*bit),
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box FlatExpression::Identifier(*bit),
|
box FlatExpression::Identifier(*bit),
|
||||||
|
@ -739,7 +784,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(lhs_id),
|
FlatExpression::Identifier(lhs_id),
|
||||||
lhs_sum,
|
lhs_sum,
|
||||||
RuntimeError::LtSum,
|
RuntimeError::LtSum,
|
||||||
|
@ -756,7 +801,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||||
|
|
||||||
// add a directive to get the bits
|
// add a directive to get the bits
|
||||||
statements_flattened.push(FlatStatement::Directive(
|
statements_flattened.push_back(FlatStatement::Directive(
|
||||||
FlatDirective::new(
|
FlatDirective::new(
|
||||||
rhs_bits_be.clone(),
|
rhs_bits_be.clone(),
|
||||||
Solver::bits(safe_width),
|
Solver::bits(safe_width),
|
||||||
|
@ -766,7 +811,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
// bitness checks
|
// bitness checks
|
||||||
for bit in rhs_bits_be.iter().take(safe_width) {
|
for bit in rhs_bits_be.iter().take(safe_width) {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(*bit),
|
FlatExpression::Identifier(*bit),
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box FlatExpression::Identifier(*bit),
|
box FlatExpression::Identifier(*bit),
|
||||||
|
@ -791,7 +836,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(rhs_id),
|
FlatExpression::Identifier(rhs_id),
|
||||||
rhs_sum,
|
rhs_sum,
|
||||||
RuntimeError::LtSum,
|
RuntimeError::LtSum,
|
||||||
|
@ -815,15 +860,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||||
|
|
||||||
// add a directive to get the bits
|
// add a directive to get the bits
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(
|
||||||
|
FlatDirective::new(
|
||||||
sub_bits_be.clone(),
|
sub_bits_be.clone(),
|
||||||
Solver::bits(bit_width),
|
Solver::bits(bit_width),
|
||||||
vec![subtraction_result.clone()],
|
vec![subtraction_result.clone()],
|
||||||
)));
|
),
|
||||||
|
));
|
||||||
|
|
||||||
// bitness checks
|
// bitness checks
|
||||||
for bit in sub_bits_be.iter().take(bit_width) {
|
for bit in sub_bits_be.iter().take(bit_width) {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(*bit),
|
FlatExpression::Identifier(*bit),
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box FlatExpression::Identifier(*bit),
|
box FlatExpression::Identifier(*bit),
|
||||||
|
@ -853,7 +900,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
subtraction_result,
|
subtraction_result,
|
||||||
expr,
|
expr,
|
||||||
RuntimeError::LtFinalSum,
|
RuntimeError::LtFinalSum,
|
||||||
|
@ -885,7 +932,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let x_sub_y = FlatExpression::Sub(box x, box y);
|
let x_sub_y = FlatExpression::Sub(box x, box y);
|
||||||
let name_x_mult_x = self.use_sym();
|
let name_x_mult_x = self.use_sym();
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
name_x_mult_x,
|
name_x_mult_x,
|
||||||
FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y),
|
FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y),
|
||||||
));
|
));
|
||||||
|
@ -972,7 +1019,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||||
|
|
||||||
// add a directive to get the bits
|
// add a directive to get the bits
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
sub_bits_be.clone(),
|
sub_bits_be.clone(),
|
||||||
Solver::bits(bit_width),
|
Solver::bits(bit_width),
|
||||||
vec![subtraction_result.clone()],
|
vec![subtraction_result.clone()],
|
||||||
|
@ -980,7 +1027,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
// bitness checks
|
// bitness checks
|
||||||
for bit in sub_bits_be.iter().take(bit_width) {
|
for bit in sub_bits_be.iter().take(bit_width) {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Identifier(*bit),
|
FlatExpression::Identifier(*bit),
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box FlatExpression::Identifier(*bit),
|
box FlatExpression::Identifier(*bit),
|
||||||
|
@ -1010,7 +1057,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
subtraction_result,
|
subtraction_result,
|
||||||
expr,
|
expr,
|
||||||
RuntimeError::LtFinalSum,
|
RuntimeError::LtFinalSum,
|
||||||
|
@ -1042,12 +1089,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let y = self.flatten_boolean_expression(statements_flattened, rhs);
|
let y = self.flatten_boolean_expression(statements_flattened, rhs);
|
||||||
assert!(x.is_linear() && y.is_linear());
|
assert!(x.is_linear() && y.is_linear());
|
||||||
let name_x_or_y = self.use_sym();
|
let name_x_or_y = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective {
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective {
|
||||||
solver: Solver::Or,
|
solver: Solver::Or,
|
||||||
outputs: vec![name_x_or_y],
|
outputs: vec![name_x_or_y],
|
||||||
inputs: vec![x.clone(), y.clone()],
|
inputs: vec![x.clone(), y.clone()],
|
||||||
}));
|
}));
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Add(
|
FlatExpression::Add(
|
||||||
box x.clone(),
|
box x.clone(),
|
||||||
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
|
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
|
||||||
|
@ -1063,7 +1110,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
let name_x_and_y = self.use_sym();
|
let name_x_and_y = self.use_sym();
|
||||||
assert!(x.is_linear() && y.is_linear());
|
assert!(x.is_linear() && y.is_linear());
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
name_x_and_y,
|
name_x_and_y,
|
||||||
FlatExpression::Mult(box x, box y),
|
FlatExpression::Mult(box x, box y),
|
||||||
));
|
));
|
||||||
|
@ -1373,14 +1420,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let d = if right_flattened.is_linear() {
|
let d = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1388,14 +1435,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let invd = self.use_sym();
|
let invd = self.use_sym();
|
||||||
|
|
||||||
// # invd = 1/d
|
// # invd = 1/d
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
vec![invd],
|
vec![invd],
|
||||||
Solver::Div,
|
Solver::Div,
|
||||||
vec![FlatExpression::Number(T::one()), d.clone()],
|
vec![FlatExpression::Number(T::one()), d.clone()],
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// assert(invd * d == 1)
|
// assert(invd * d == 1)
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Number(T::one()),
|
FlatExpression::Number(T::one()),
|
||||||
FlatExpression::Mult(box invd.into(), box d.clone()),
|
FlatExpression::Mult(box invd.into(), box d.clone()),
|
||||||
RuntimeError::Inverse,
|
RuntimeError::Inverse,
|
||||||
|
@ -1405,7 +1452,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let q = self.use_sym();
|
let q = self.use_sym();
|
||||||
let r = self.use_sym();
|
let r = self.use_sym();
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective {
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective {
|
||||||
inputs: vec![n.clone(), d.clone()],
|
inputs: vec![n.clone(), d.clone()],
|
||||||
outputs: vec![q, r],
|
outputs: vec![q, r],
|
||||||
solver: Solver::EuclideanDiv,
|
solver: Solver::EuclideanDiv,
|
||||||
|
@ -1439,7 +1486,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
|
|
||||||
// q*d == n - r
|
// q*d == n - r
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Sub(box n, box r.into()),
|
FlatExpression::Sub(box n, box r.into()),
|
||||||
FlatExpression::Mult(box q.into(), box d),
|
FlatExpression::Mult(box q.into(), box d),
|
||||||
RuntimeError::Euclidean,
|
RuntimeError::Euclidean,
|
||||||
|
@ -1520,14 +1567,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let new_right = if right_flattened.is_linear() {
|
let new_right = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1550,14 +1597,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let new_right = if right_flattened.is_linear() {
|
let new_right = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1612,20 +1659,20 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let new_right = if right_flattened.is_linear() {
|
let new_right = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = self.use_sym();
|
let res = self.use_sym();
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
res,
|
res,
|
||||||
FlatExpression::Mult(box new_left, box new_right),
|
FlatExpression::Mult(box new_left, box new_right),
|
||||||
));
|
));
|
||||||
|
@ -1974,7 +2021,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
}
|
}
|
||||||
Entry::Vacant(_) => {
|
Entry::Vacant(_) => {
|
||||||
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
|
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
bits.clone(),
|
bits.clone(),
|
||||||
Solver::Bits(from),
|
Solver::Bits(from),
|
||||||
vec![e.field.clone().unwrap()],
|
vec![e.field.clone().unwrap()],
|
||||||
|
@ -1996,7 +2043,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let sum = flat_expression_from_bits(bits.clone());
|
let sum = flat_expression_from_bits(bits.clone());
|
||||||
|
|
||||||
// sum check
|
// sum check
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
e.field.clone().unwrap(),
|
e.field.clone().unwrap(),
|
||||||
sum.clone(),
|
sum.clone(),
|
||||||
RuntimeError::Sum,
|
RuntimeError::Sum,
|
||||||
|
@ -2055,7 +2102,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
range_check = FlatExpression::Add(box range_check, box condition.clone());
|
range_check = FlatExpression::Add(box range_check, box condition.clone());
|
||||||
|
|
||||||
let conditional_element_id = self.use_sym();
|
let conditional_element_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
conditional_element_id,
|
conditional_element_id,
|
||||||
FlatExpression::Mult(box condition, box element.flat()),
|
FlatExpression::Mult(box condition, box element.flat()),
|
||||||
));
|
));
|
||||||
|
@ -2065,7 +2112,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
range_check,
|
range_check,
|
||||||
FlatExpression::Number(T::one()),
|
FlatExpression::Number(T::one()),
|
||||||
RuntimeError::SelectRangeCheck,
|
RuntimeError::SelectRangeCheck,
|
||||||
|
@ -2099,14 +2146,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let new_right = if right_flattened.is_linear() {
|
let new_right = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
FlatExpression::Add(box new_left, box new_right)
|
FlatExpression::Add(box new_left, box new_right)
|
||||||
|
@ -2119,14 +2166,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let new_right = if right_flattened.is_linear() {
|
let new_right = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2139,14 +2186,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
left_flattened
|
left_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
let new_right = if right_flattened.is_linear() {
|
let new_right = if right_flattened.is_linear() {
|
||||||
right_flattened
|
right_flattened
|
||||||
} else {
|
} else {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
FlatExpression::Identifier(id)
|
FlatExpression::Identifier(id)
|
||||||
};
|
};
|
||||||
FlatExpression::Mult(box new_left, box new_right)
|
FlatExpression::Mult(box new_left, box new_right)
|
||||||
|
@ -2156,12 +2203,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
||||||
let new_left: FlatExpression<T> = {
|
let new_left: FlatExpression<T> = {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||||
id.into()
|
id.into()
|
||||||
};
|
};
|
||||||
let new_right: FlatExpression<T> = {
|
let new_right: FlatExpression<T> = {
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||||
id.into()
|
id.into()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2169,28 +2216,28 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let inverse = self.use_sym();
|
let inverse = self.use_sym();
|
||||||
|
|
||||||
// # invb = 1/b
|
// # invb = 1/b
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
vec![invb],
|
vec![invb],
|
||||||
Solver::Div,
|
Solver::Div,
|
||||||
vec![FlatExpression::Number(T::one()), new_right.clone()],
|
vec![FlatExpression::Number(T::one()), new_right.clone()],
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// assert(invb * b == 1)
|
// assert(invb * b == 1)
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Number(T::one()),
|
FlatExpression::Number(T::one()),
|
||||||
FlatExpression::Mult(box invb.into(), box new_right.clone()),
|
FlatExpression::Mult(box invb.into(), box new_right.clone()),
|
||||||
RuntimeError::Inverse,
|
RuntimeError::Inverse,
|
||||||
));
|
));
|
||||||
|
|
||||||
// # c = a/b
|
// # c = a/b
|
||||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||||
vec![inverse],
|
vec![inverse],
|
||||||
Solver::Div,
|
Solver::Div,
|
||||||
vec![new_left.clone(), new_right.clone()],
|
vec![new_left.clone(), new_right.clone()],
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// assert(c * b == a)
|
// assert(c * b == a)
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
new_left,
|
new_left,
|
||||||
FlatExpression::Mult(box new_right, box inverse.into()),
|
FlatExpression::Mult(box new_right, box inverse.into()),
|
||||||
RuntimeError::Division,
|
RuntimeError::Division,
|
||||||
|
@ -2239,7 +2286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
// introduce a new variable
|
// introduce a new variable
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
// set it to the square of the previous one, stored in state
|
// set it to the square of the previous one, stored in state
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
id,
|
id,
|
||||||
FlatExpression::Mult(
|
FlatExpression::Mult(
|
||||||
box previous.clone(),
|
box previous.clone(),
|
||||||
|
@ -2262,7 +2309,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
true => {
|
true => {
|
||||||
// update the result by introducing a new variable
|
// update the result by introducing a new variable
|
||||||
let id = self.use_sym();
|
let id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
id,
|
id,
|
||||||
FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power
|
FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power
|
||||||
));
|
));
|
||||||
|
@ -2305,7 +2352,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
.map(|x| x.get_field_unchecked())
|
.map(|x| x.get_field_unchecked())
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
statements_flattened.push(FlatStatement::Return(FlatExpressionList {
|
statements_flattened.push_back(FlatStatement::Return(FlatExpressionList {
|
||||||
expressions: flat_expressions,
|
expressions: flat_expressions,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
@ -2314,13 +2361,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||||
|
|
||||||
let condition_id = self.use_sym();
|
let condition_id = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
statements_flattened
|
||||||
|
.push_back(FlatStatement::Definition(condition_id, condition_flat));
|
||||||
|
|
||||||
self.condition_cache.insert(condition, condition_id);
|
self.condition_cache.insert(condition, condition_id);
|
||||||
|
|
||||||
if self.config.isolate_branches {
|
if self.config.isolate_branches {
|
||||||
let mut consequence_statements = vec![];
|
let mut consequence_statements = VecDeque::new();
|
||||||
let mut alternative_statements = vec![];
|
let mut alternative_statements = VecDeque::new();
|
||||||
|
|
||||||
consequence
|
consequence
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -2367,7 +2415,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let var = self.use_variable(&assignee);
|
let var = self.use_variable(&assignee);
|
||||||
|
|
||||||
// handle return of function call
|
// handle return of function call
|
||||||
statements_flattened.push(FlatStatement::Definition(var, e));
|
statements_flattened.push_back(FlatStatement::Definition(var, e));
|
||||||
|
|
||||||
var
|
var
|
||||||
}
|
}
|
||||||
|
@ -2431,14 +2479,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
let e = self.flatten_boolean_expression(statements_flattened, e);
|
let e = self.flatten_boolean_expression(statements_flattened, e);
|
||||||
|
|
||||||
if e.is_linear() {
|
if e.is_linear() {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
e,
|
e,
|
||||||
FlatExpression::Number(T::from(1)),
|
FlatExpression::Number(T::from(1)),
|
||||||
error.into(),
|
error.into(),
|
||||||
));
|
));
|
||||||
} else {
|
} else {
|
||||||
// swap so that left side is linear
|
// swap so that left side is linear
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
FlatExpression::Number(T::from(1)),
|
FlatExpression::Number(T::from(1)),
|
||||||
e,
|
e,
|
||||||
error.into(),
|
error.into(),
|
||||||
|
@ -2474,7 +2522,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
}
|
}
|
||||||
e => {
|
e => {
|
||||||
let id = self.use_variable(&v);
|
let id = self.use_variable(&v);
|
||||||
statements_flattened.push(FlatStatement::Definition(id, e));
|
statements_flattened
|
||||||
|
.push_back(FlatStatement::Definition(id, e));
|
||||||
id
|
id
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -2502,45 +2551,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flattens a function
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `funct` - `ZirFunction` that will be flattened
|
|
||||||
fn flatten_function(&mut self, funct: ZirFunction<'ast, T>) -> FlatFunction<T> {
|
|
||||||
self.layout = HashMap::new();
|
|
||||||
|
|
||||||
self.next_var_idx = 0;
|
|
||||||
let mut statements_flattened: FlatStatements<T> = FlatStatements::new();
|
|
||||||
|
|
||||||
// push parameters
|
|
||||||
let arguments_flattened = funct
|
|
||||||
.arguments
|
|
||||||
.into_iter()
|
|
||||||
.map(|p| self.use_parameter(&p, &mut statements_flattened))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// flatten statements in functions and apply substitution
|
|
||||||
for stat in funct.statements {
|
|
||||||
self.flatten_statement(&mut statements_flattened, stat);
|
|
||||||
}
|
|
||||||
|
|
||||||
FlatFunction {
|
|
||||||
arguments: arguments_flattened,
|
|
||||||
statements: statements_flattened,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Flattens a program
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `prog` - `ZirProgram` that will be flattened.
|
|
||||||
fn flatten_program(&mut self, prog: ZirProgram<'ast, T>) -> FlatProg<T> {
|
|
||||||
FlatProg {
|
|
||||||
main: self.flatten_function(prog.main),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Flattens an equality assertion, enforcing it in the circuit.
|
/// Flattens an equality assertion, enforcing it in the circuit.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -2571,7 +2581,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
statements_flattened.push(FlatStatement::Condition(lhs, rhs, error));
|
statements_flattened.push_back(FlatStatement::Condition(lhs, rhs, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Identifies a non-linear expression by assigning it to a new identifier.
|
/// Identifies a non-linear expression by assigning it to a new identifier.
|
||||||
|
@ -2589,7 +2599,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
true => e,
|
true => e,
|
||||||
false => {
|
false => {
|
||||||
let sym = self.use_sym();
|
let sym = self.use_sym();
|
||||||
statements_flattened.push(FlatStatement::Definition(sym, e));
|
statements_flattened.push_back(FlatStatement::Definition(sym, e));
|
||||||
FlatExpression::Identifier(sym)
|
FlatExpression::Identifier(sym)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2638,7 +2648,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Type::Boolean => {
|
Type::Boolean => {
|
||||||
statements_flattened.push(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
variable.into(),
|
variable.into(),
|
||||||
FlatExpression::Mult(box variable.into(), box variable.into()),
|
FlatExpression::Mult(box variable.into(), box variable.into()),
|
||||||
RuntimeError::ArgumentBitness,
|
RuntimeError::ArgumentBitness,
|
||||||
|
@ -2649,7 +2659,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
// we insert dummy condition statement for private field elements
|
// we insert dummy condition statement for private field elements
|
||||||
// to avoid unconstrained variables
|
// to avoid unconstrained variables
|
||||||
// translates to y == x * x
|
// translates to y == x * x
|
||||||
statements_flattened.push(FlatStatement::Definition(
|
statements_flattened.push_back(FlatStatement::Definition(
|
||||||
self.use_sym(),
|
self.use_sym(),
|
||||||
FlatExpression::Mult(box variable.into(), box variable.into()),
|
FlatExpression::Mult(box variable.into(), box variable.into()),
|
||||||
));
|
));
|
||||||
|
|
|
@ -46,7 +46,7 @@ pub fn fold_module<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> {
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(|s| f.fold_statement(s))
|
.flat_map(|s| f.fold_statement(s))
|
||||||
.collect(),
|
.collect(),
|
||||||
returns: p.returns.into_iter().map(|v| f.fold_variable(v)).collect(),
|
return_count: p.return_count,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::flat_absy::{FlatDirective, FlatExpression, FlatProg, FlatStatement, FlatVariable};
|
use crate::flat_absy::{FlatDirective, FlatExpression, FlatStatement, FlatVariable};
|
||||||
use crate::ir::{Directive, LinComb, Prog, QuadComb, Statement};
|
use crate::ir::{Directive, LinComb, QuadComb, Statement};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
impl<T: Field> QuadComb<T> {
|
impl<T: Field> QuadComb<T> {
|
||||||
|
@ -17,50 +17,25 @@ impl<T: Field> QuadComb<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> From<FlatProg<T>> for Prog<T> {
|
pub fn from_flat<'ast, T: Field, I: Iterator<Item = FlatStatement<T>>>(
|
||||||
fn from(flat_prog: FlatProg<T>) -> Prog<T> {
|
flat_prog_iterator: I,
|
||||||
// get the main function
|
) -> impl Iterator<Item = Statement<T>> {
|
||||||
let main = flat_prog.main;
|
flat_prog_iterator.filter_map(|s| match s {
|
||||||
|
|
||||||
let return_expressions: Vec<FlatExpression<T>> = main
|
|
||||||
.statements
|
|
||||||
.iter()
|
|
||||||
.filter_map(|s| match s {
|
|
||||||
FlatStatement::Return(el) => Some(el.expressions.clone()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Prog {
|
|
||||||
arguments: main.arguments,
|
|
||||||
returns: return_expressions
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(index, _)| FlatVariable::public(index))
|
|
||||||
.collect(),
|
|
||||||
statements: main
|
|
||||||
.statements
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|s| match s {
|
|
||||||
FlatStatement::Return(..) => None,
|
FlatStatement::Return(..) => None,
|
||||||
s => Some(s.into()),
|
s => Some(s.into()),
|
||||||
})
|
})
|
||||||
.chain(
|
// .chain(
|
||||||
return_expressions
|
// return_expressions
|
||||||
.into_iter()
|
// .into_iter()
|
||||||
.enumerate()
|
// .enumerate()
|
||||||
.map(|(index, expression)| {
|
// .map(|(index, expression)| {
|
||||||
Statement::Constraint(
|
// Statement::Constraint(
|
||||||
QuadComb::from_flat_expression(expression),
|
// QuadComb::from_flat_expression(expression),
|
||||||
FlatVariable::public(index).into(),
|
// FlatVariable::public(index).into(),
|
||||||
None,
|
// None,
|
||||||
)
|
// )
|
||||||
}),
|
// }),
|
||||||
)
|
// )
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
|
impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
|
||||||
|
|
|
@ -8,7 +8,7 @@ use zokrates_field::Field;
|
||||||
|
|
||||||
mod expression;
|
mod expression;
|
||||||
pub mod folder;
|
pub mod folder;
|
||||||
mod from_flat;
|
pub mod from_flat;
|
||||||
mod interpreter;
|
mod interpreter;
|
||||||
mod serialize;
|
mod serialize;
|
||||||
pub mod smtlib2;
|
pub mod smtlib2;
|
||||||
|
@ -78,10 +78,51 @@ impl<T: Field> fmt::Display for Statement<T> {
|
||||||
pub struct Prog<T> {
|
pub struct Prog<T> {
|
||||||
pub statements: Vec<Statement<T>>,
|
pub statements: Vec<Statement<T>>,
|
||||||
pub arguments: Vec<FlatParameter>,
|
pub arguments: Vec<FlatParameter>,
|
||||||
pub returns: Vec<FlatVariable>,
|
pub return_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct ProgIterator<I> {
|
||||||
|
pub statements: I,
|
||||||
|
pub arguments: Vec<FlatParameter>,
|
||||||
|
pub return_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<Prog<T>> for ProgIterator<std::vec::IntoIter<Statement<T>>> {
|
||||||
|
fn from(p: Prog<T>) -> ProgIterator<std::vec::IntoIter<Statement<T>>> {
|
||||||
|
ProgIterator {
|
||||||
|
statements: p.statements.into_iter(),
|
||||||
|
arguments: p.arguments,
|
||||||
|
return_count: p.return_count,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, I: Iterator<Item = Statement<T>>> From<ProgIterator<I>> for Prog<T> {
|
||||||
|
fn from(p: ProgIterator<I>) -> Prog<T> {
|
||||||
|
Prog {
|
||||||
|
statements: p.statements.collect(),
|
||||||
|
arguments: p.arguments,
|
||||||
|
return_count: p.return_count,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> ProgIterator<T> {
|
||||||
|
pub fn returns(&self) -> Vec<FlatVariable> {
|
||||||
|
(0..self.return_count)
|
||||||
|
.map(|id| FlatVariable::public(id))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> Prog<T> {
|
impl<T: Field> Prog<T> {
|
||||||
|
pub fn returns(&self) -> Vec<FlatVariable> {
|
||||||
|
(0..self.return_count)
|
||||||
|
.map(|id| FlatVariable::public(id))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn constraint_count(&self) -> usize {
|
pub fn constraint_count(&self) -> usize {
|
||||||
self.statements
|
self.statements
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -113,14 +154,14 @@ impl<T: Field> fmt::Display for Prog<T> {
|
||||||
.map(|v| format!("{}", v))
|
.map(|v| format!("{}", v))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", "),
|
.join(", "),
|
||||||
self.returns.len(),
|
self.return_count,
|
||||||
self.statements
|
self.statements
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| format!("\t{}", s))
|
.map(|s| format!("\t{}", s))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("\n"),
|
.join("\n"),
|
||||||
self.returns
|
(0..self.return_count)
|
||||||
.iter()
|
.map(|i| FlatVariable::public(i))
|
||||||
.map(|e| format!("{}", e))
|
.map(|e| format!("{}", e))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", ")
|
.join(", ")
|
||||||
|
|
|
@ -49,9 +49,6 @@ pub fn visit_module<T: Field, F: Visitor<T>>(f: &mut F, p: &Prog<T>) {
|
||||||
for expr in p.statements.iter() {
|
for expr in p.statements.iter() {
|
||||||
f.visit_statement(expr);
|
f.visit_statement(expr);
|
||||||
}
|
}
|
||||||
for expr in p.returns.iter() {
|
|
||||||
f.visit_variable(expr);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn visit_statement<T: Field, F: Visitor<T>>(f: &mut F, s: &Statement<T>) {
|
pub fn visit_statement<T: Field, F: Visitor<T>>(f: &mut F, s: &Statement<T>) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::ir::{folder::Folder, LinComb};
|
use crate::ir::{folder::Folder, LinComb};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
pub struct Canonicalizer;
|
pub struct Canonicalizer;
|
||||||
|
|
||||||
impl<T: Field> Folder<T> for Canonicalizer {
|
impl<T: Field> Folder<T> for Canonicalizer {
|
||||||
|
|
|
@ -17,26 +17,13 @@ use crate::solvers::Solver;
|
||||||
use std::collections::hash_map::{Entry, HashMap};
|
use std::collections::hash_map::{Entry, HashMap};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Default)]
|
||||||
pub struct DirectiveOptimizer<T: Field> {
|
pub struct DirectiveOptimizer<T> {
|
||||||
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
|
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
|
||||||
/// Map of renamings for reassigned variables while processing the program.
|
/// Map of renamings for reassigned variables while processing the program.
|
||||||
substitution: HashMap<FlatVariable, FlatVariable>,
|
substitution: HashMap<FlatVariable, FlatVariable>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> DirectiveOptimizer<T> {
|
|
||||||
fn new() -> DirectiveOptimizer<T> {
|
|
||||||
DirectiveOptimizer {
|
|
||||||
calls: HashMap::new(),
|
|
||||||
substitution: HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn optimize(p: Prog<T>) -> Prog<T> {
|
|
||||||
DirectiveOptimizer::new().fold_module(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
||||||
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
||||||
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
||||||
|
|
|
@ -16,23 +16,11 @@ fn hash<T: Field>(s: &Statement<T>) -> Hash {
|
||||||
hasher.finish()
|
hasher.finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Default)]
|
||||||
pub struct DuplicateOptimizer {
|
pub struct DuplicateOptimizer {
|
||||||
seen: HashSet<Hash>,
|
seen: HashSet<Hash>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DuplicateOptimizer {
|
|
||||||
fn new() -> Self {
|
|
||||||
DuplicateOptimizer {
|
|
||||||
seen: HashSet::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
|
|
||||||
Self::new().fold_module(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Field> Folder<T> for DuplicateOptimizer {
|
impl<T: Field> Folder<T> for DuplicateOptimizer {
|
||||||
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
||||||
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
||||||
|
|
|
@ -10,41 +10,59 @@ mod duplicate;
|
||||||
mod redefinition;
|
mod redefinition;
|
||||||
mod tautology;
|
mod tautology;
|
||||||
|
|
||||||
|
use self::canonicalizer::Canonicalizer;
|
||||||
use self::directive::DirectiveOptimizer;
|
use self::directive::DirectiveOptimizer;
|
||||||
use self::duplicate::DuplicateOptimizer;
|
use self::duplicate::DuplicateOptimizer;
|
||||||
use self::redefinition::RedefinitionOptimizer;
|
use self::redefinition::RedefinitionOptimizer;
|
||||||
use self::tautology::TautologyOptimizer;
|
use self::tautology::TautologyOptimizer;
|
||||||
|
|
||||||
use crate::ir::Prog;
|
use crate::ir::{ProgIterator, Statement};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
impl<T: Field> Prog<T> {
|
impl<T: Field, I: Iterator<Item = Statement<T>>> ProgIterator<I> {
|
||||||
pub fn optimize(self) -> Self {
|
pub fn optimize(self) -> ProgIterator<impl Iterator<Item = Statement<T>>> {
|
||||||
// remove redefinitions
|
// remove redefinitions
|
||||||
log::debug!("Constraints: {}", self.constraint_count());
|
log::debug!(
|
||||||
log::debug!("Optimizer: Remove redefinitions");
|
"Optimizer: Remove redefinitions and tautologies and directives and duplicates"
|
||||||
let r = RedefinitionOptimizer::optimize(self);
|
);
|
||||||
log::debug!("Done");
|
|
||||||
|
|
||||||
// remove constraints that are always satisfied
|
// define all optimizer steps
|
||||||
log::debug!("Constraints: {}", r.constraint_count());
|
let mut redefinition_optimizer = RedefinitionOptimizer::default();
|
||||||
log::debug!("Optimizer: Remove tautologies");
|
let mut tautologies_optimizer = TautologyOptimizer::default();
|
||||||
let r = TautologyOptimizer::optimize(r);
|
let mut directive_optimizer = DirectiveOptimizer::default();
|
||||||
log::debug!("Done");
|
let mut canonicalizer = Canonicalizer::default();
|
||||||
|
let mut duplicate_optimizer = DuplicateOptimizer::default();
|
||||||
|
|
||||||
// deduplicate directives which take the same input
|
// initialize the ones that need initializing
|
||||||
log::debug!("Constraints: {}", r.constraint_count());
|
redefinition_optimizer.ignore.extend(self.returns().clone());
|
||||||
log::debug!("Optimizer: Remove duplicate directive");
|
|
||||||
let r = DirectiveOptimizer::optimize(r);
|
|
||||||
log::debug!("Done");
|
|
||||||
|
|
||||||
// remove duplicate constraints
|
use crate::ir::folder::Folder;
|
||||||
log::debug!("Constraints: {}", r.constraint_count());
|
|
||||||
log::debug!("Optimizer: Remove duplicate constraints");
|
|
||||||
let r = DuplicateOptimizer::optimize(r);
|
|
||||||
log::debug!("Done");
|
|
||||||
|
|
||||||
log::debug!("Constraints: {}", r.constraint_count());
|
let r = ProgIterator {
|
||||||
|
arguments: self
|
||||||
|
.arguments
|
||||||
|
.into_iter()
|
||||||
|
.map(|a| redefinition_optimizer.fold_argument(a))
|
||||||
|
.map(|a| {
|
||||||
|
<TautologyOptimizer as Folder<T>>::fold_argument(&mut tautologies_optimizer, a)
|
||||||
|
})
|
||||||
|
.map(|a| directive_optimizer.fold_argument(a))
|
||||||
|
.map(|a| {
|
||||||
|
<DuplicateOptimizer as Folder<T>>::fold_argument(&mut duplicate_optimizer, a)
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
statements: self
|
||||||
|
.statements
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(move |s| redefinition_optimizer.fold_statement(s))
|
||||||
|
.flat_map(move |s| tautologies_optimizer.fold_statement(s))
|
||||||
|
.flat_map(move |s| canonicalizer.fold_statement(s))
|
||||||
|
.flat_map(move |s| directive_optimizer.fold_statement(s))
|
||||||
|
.flat_map(move |s| duplicate_optimizer.fold_statement(s)),
|
||||||
|
return_count: self.return_count,
|
||||||
|
};
|
||||||
|
|
||||||
|
log::debug!("Done");
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,44 +38,30 @@
|
||||||
|
|
||||||
use crate::flat_absy::flat_variable::FlatVariable;
|
use crate::flat_absy::flat_variable::FlatVariable;
|
||||||
use crate::flat_absy::FlatParameter;
|
use crate::flat_absy::FlatParameter;
|
||||||
use crate::ir::folder::{fold_module, Folder};
|
use crate::ir::folder::Folder;
|
||||||
use crate::ir::LinComb;
|
use crate::ir::LinComb;
|
||||||
use crate::ir::*;
|
use crate::ir::*;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RedefinitionOptimizer<T: Field> {
|
pub struct RedefinitionOptimizer<T> {
|
||||||
/// Map of renamings for reassigned variables while processing the program.
|
/// Map of renamings for reassigned variables while processing the program.
|
||||||
substitution: HashMap<FlatVariable, CanonicalLinComb<T>>,
|
substitution: HashMap<FlatVariable, CanonicalLinComb<T>>,
|
||||||
/// Set of variables that should not be substituted
|
/// Set of variables that should not be substituted
|
||||||
ignore: HashSet<FlatVariable>,
|
pub ignore: HashSet<FlatVariable>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> RedefinitionOptimizer<T> {
|
impl<T> Default for RedefinitionOptimizer<T> {
|
||||||
fn new() -> Self {
|
fn default() -> Self {
|
||||||
RedefinitionOptimizer {
|
RedefinitionOptimizer {
|
||||||
substitution: HashMap::new(),
|
substitution: HashMap::new(),
|
||||||
ignore: HashSet::new(),
|
ignore: vec![FlatVariable::one()].into_iter().collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn optimize(p: Prog<T>) -> Prog<T> {
|
|
||||||
RedefinitionOptimizer::new().fold_module(p)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
||||||
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
|
||||||
// to prevent the optimiser from replacing outputs, add them to the ignored set
|
|
||||||
self.ignore.extend(p.returns.iter().cloned());
|
|
||||||
|
|
||||||
// to prevent the optimiser from replacing ~one, add it to the ignored set
|
|
||||||
self.ignore.insert(FlatVariable::one());
|
|
||||||
|
|
||||||
fold_module(self, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter {
|
fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter {
|
||||||
// to prevent the optimiser from replacing user input, add it to the ignored set
|
// to prevent the optimiser from replacing user input, add it to the ignored set
|
||||||
self.ignore.insert(a.id);
|
self.ignore.insert(a.id);
|
||||||
|
|
|
@ -10,17 +10,8 @@ use crate::ir::folder::Folder;
|
||||||
use crate::ir::*;
|
use crate::ir::*;
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
pub struct TautologyOptimizer {}
|
#[derive(Default)]
|
||||||
|
pub struct TautologyOptimizer;
|
||||||
impl TautologyOptimizer {
|
|
||||||
fn new() -> TautologyOptimizer {
|
|
||||||
TautologyOptimizer {}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
|
|
||||||
TautologyOptimizer::new().fold_module(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Field> Folder<T> for TautologyOptimizer {
|
impl<T: Field> Folder<T> for TautologyOptimizer {
|
||||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||||
|
|
|
@ -101,14 +101,6 @@ impl<T: Field> Propagate<T> for FlatFunction<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Field> FlatProg<T> {
|
|
||||||
pub fn propagate(self) -> FlatProg<T> {
|
|
||||||
let main = self.main.propagate();
|
|
||||||
|
|
||||||
FlatProg { main }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -54,7 +54,8 @@ impl fmt::Debug for FieldParseError {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Field:
|
pub trait Field:
|
||||||
From<i32>
|
'static
|
||||||
|
+ From<i32>
|
||||||
+ From<u32>
|
+ From<u32>
|
||||||
+ From<usize>
|
+ From<usize>
|
||||||
+ From<u128>
|
+ From<u128>
|
||||||
|
|
|
@ -12,5 +12,6 @@ zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
|
||||||
serde = "1.0"
|
serde = "1.0"
|
||||||
serde_derive = "1.0"
|
serde_derive = "1.0"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
typed-arena = "1.4.1"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
|
|
|
@ -149,13 +149,17 @@ fn compile_and_run<T: Field>(t: Tests) {
|
||||||
let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap();
|
let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap();
|
||||||
let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap());
|
let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap());
|
||||||
|
|
||||||
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver), &config).unwrap();
|
let arena = typed_arena::Arena::new();
|
||||||
|
|
||||||
let bin = artifacts.prog();
|
let artifacts =
|
||||||
let abi = artifacts.abi();
|
compile::<T, _>(code, entry_point.clone(), Some(&resolver), config, &arena).unwrap();
|
||||||
|
|
||||||
|
let abi = artifacts.abi;
|
||||||
|
let bin = artifacts.prog;
|
||||||
|
|
||||||
if let Some(target_count) = t.max_constraint_count {
|
if let Some(target_count) = t.max_constraint_count {
|
||||||
let count = bin.constraint_count();
|
unimplemented!("bin.constraint_count()");
|
||||||
|
let count = 42;
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
count <= target_count,
|
count <= target_count,
|
||||||
|
@ -169,6 +173,8 @@ fn compile_and_run<T: Field>(t: Tests) {
|
||||||
|
|
||||||
let interpreter = zokrates_core::ir::Interpreter::default();
|
let interpreter = zokrates_core::ir::Interpreter::default();
|
||||||
|
|
||||||
|
let bin = &bin.into();
|
||||||
|
|
||||||
for test in t.tests.into_iter() {
|
for test in t.tests.into_iter() {
|
||||||
let with_abi = test.abi.unwrap_or(false);
|
let with_abi = test.abi.unwrap_or(false);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue