proof of concept of iterator treatment starting at flattening
This commit is contained in:
parent
d924368038
commit
dad17b79e0
21 changed files with 397 additions and 351 deletions
21
Cargo.lock
generated
21
Cargo.lock
generated
|
@ -1074,6 +1074,12 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "1.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.11.2"
|
||||
|
@ -1845,6 +1851,16 @@ dependencies = [
|
|||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_cbor"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
|
||||
dependencies = [
|
||||
"half",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.130"
|
||||
|
@ -2383,7 +2399,6 @@ name = "zokrates_cli"
|
|||
version = "0.7.7"
|
||||
dependencies = [
|
||||
"assert_cli",
|
||||
"bincode",
|
||||
"cfg-if 0.1.10",
|
||||
"clap",
|
||||
"dirs",
|
||||
|
@ -2393,8 +2408,11 @@ dependencies = [
|
|||
"lazy_static",
|
||||
"log",
|
||||
"regex 0.2.11",
|
||||
"serde",
|
||||
"serde_cbor",
|
||||
"serde_json",
|
||||
"tempdir",
|
||||
"typed-arena",
|
||||
"zokrates_abi",
|
||||
"zokrates_core",
|
||||
"zokrates_field",
|
||||
|
@ -2545,6 +2563,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"typed-arena",
|
||||
"zokrates_abi",
|
||||
"zokrates_core",
|
||||
"zokrates_field",
|
||||
|
|
|
@ -16,13 +16,15 @@ log = "0.4"
|
|||
env_logger = "0.9.0"
|
||||
cfg-if = "0.1"
|
||||
clap = "2.26.2"
|
||||
bincode = "0.8.0"
|
||||
serde_cbor = "0.11.2"
|
||||
regex = "0.2"
|
||||
zokrates_field = { version = "0.4", path = "../zokrates_field", default-features = false }
|
||||
zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
|
||||
zokrates_core = { version = "0.6", path = "../zokrates_core", default-features = false }
|
||||
typed-arena = "1.4.1"
|
||||
zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"}
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
dirs = "3.0.1"
|
||||
lazy_static = "1.4.0"
|
||||
|
||||
|
|
|
@ -6,10 +6,35 @@ use std::convert::TryFrom;
|
|||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use typed_arena::Arena;
|
||||
use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig, CompileError};
|
||||
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
|
||||
use zokrates_fs_resolver::FileSystemResolver;
|
||||
|
||||
use serde::{Serialize, Serializer};
|
||||
use std::cell::Cell;
|
||||
|
||||
pub fn write_as_cbor<I, P, W>(out: &mut W, groups: I) -> serde_cbor::Result<()>
|
||||
where
|
||||
I: IntoIterator<Item = P>,
|
||||
P: Serialize,
|
||||
W: Write,
|
||||
{
|
||||
struct Wrapper<T>(Cell<Option<T>>);
|
||||
|
||||
impl<I, P> Serialize for Wrapper<I>
|
||||
where
|
||||
I: IntoIterator<Item = P>,
|
||||
P: Serialize,
|
||||
{
|
||||
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
s.collect_seq(self.0.take().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
serde_cbor::to_writer(out, &Wrapper(Cell::new(Some(groups))))
|
||||
}
|
||||
|
||||
pub fn subcommand() -> App<'static, 'static> {
|
||||
SubCommand::with_name("compile")
|
||||
.about("Compiles into flattened conditions. Produces two files: human-readable '.ztf' file for debugging and binary file")
|
||||
|
@ -136,8 +161,10 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
|
||||
log::debug!("Compile");
|
||||
|
||||
let artifacts: CompilationArtifacts<T> = compile(source, path, Some(&resolver), &config)
|
||||
.map_err(|e| {
|
||||
let arena = Arena::new();
|
||||
|
||||
let artifacts =
|
||||
compile::<T, _>(source, path, Some(&resolver), config, &arena).map_err(|e| {
|
||||
format!(
|
||||
"Compilation failed:\n\n{}",
|
||||
e.0.iter()
|
||||
|
@ -147,10 +174,11 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
)
|
||||
})?;
|
||||
|
||||
let program_flattened = artifacts.prog();
|
||||
let program_flattened = artifacts.prog;
|
||||
let abi = artifacts.abi;
|
||||
|
||||
// number of constraints the flattened program will translate to.
|
||||
let num_constraints = program_flattened.constraint_count();
|
||||
//let num_constraints = program_flattened.constraint_count();
|
||||
|
||||
// serialize flattened program and write to binary file
|
||||
log::debug!("Serialize program");
|
||||
|
@ -159,21 +187,22 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
|
||||
let mut writer = BufWriter::new(bin_output_file);
|
||||
|
||||
program_flattened.serialize(&mut writer);
|
||||
println!("START WRITING TO DISK");
|
||||
|
||||
//program_flattened.serialize(&mut writer);
|
||||
write_as_cbor(&mut writer, program_flattened.statements).unwrap();
|
||||
|
||||
// serialize ABI spec and write to JSON file
|
||||
log::debug!("Serialize ABI");
|
||||
let abi_spec_file = File::create(&abi_spec_path)
|
||||
.map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?;
|
||||
|
||||
let abi = artifacts.abi();
|
||||
|
||||
let mut writer = BufWriter::new(abi_spec_file);
|
||||
to_writer_pretty(&mut writer, &abi).map_err(|_| "Unable to write data to file.".to_string())?;
|
||||
|
||||
if sub_matches.is_present("verbose") {
|
||||
// debugging output
|
||||
println!("Compiled program:\n{}", program_flattened);
|
||||
//println!("Compiled program:\n{}", program_flattened);
|
||||
}
|
||||
|
||||
println!("Compiled code written to '{}'", bin_output_path.display());
|
||||
|
@ -185,15 +214,15 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
.map_err(|why| format!("Could not create {}: {}", hr_output_path.display(), why))?;
|
||||
|
||||
let mut hrofb = BufWriter::new(hr_output_file);
|
||||
writeln!(&mut hrofb, "{}", program_flattened)
|
||||
.map_err(|_| "Unable to write data to file".to_string())?;
|
||||
hrofb
|
||||
.flush()
|
||||
.map_err(|_| "Unable to flush buffer".to_string())?;
|
||||
// writeln!(&mut hrofb, "{}", program_flattened)
|
||||
// .map_err(|_| "Unable to write data to file".to_string())?;
|
||||
// hrofb
|
||||
// .flush()
|
||||
// .map_err(|_| "Unable to flush buffer".to_string())?;
|
||||
|
||||
println!("Human readable code to '{}'", hr_output_path.display());
|
||||
}
|
||||
|
||||
println!("Number of constraints: {}", num_constraints);
|
||||
//println!("Number of constraints: {}", num_constraints);
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -106,7 +106,7 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
|
|||
}
|
||||
false => ConcreteSignature::new()
|
||||
.inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()])
|
||||
.outputs(vec![ConcreteType::FieldElement; ir_prog.returns.len()]),
|
||||
.outputs(vec![ConcreteType::FieldElement; ir_prog.return_count]),
|
||||
};
|
||||
|
||||
use zokrates_abi::Inputs;
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
use crate::absy::{Module, OwnedModuleId, Program};
|
||||
use crate::flatten::Flattener;
|
||||
use crate::flatten::FlattenerIterator;
|
||||
use crate::imports::{self, Importer};
|
||||
use crate::ir;
|
||||
use crate::macros;
|
||||
|
@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use typed_arena::Arena;
|
||||
use zokrates_common::Resolver;
|
||||
|
@ -25,14 +26,14 @@ use zokrates_field::Field;
|
|||
use zokrates_pest_ast as pest;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CompilationArtifacts<T: Field> {
|
||||
prog: ir::Prog<T>,
|
||||
abi: Abi,
|
||||
pub struct CompilationArtifacts<I> {
|
||||
pub prog: ir::ProgIterator<I>,
|
||||
pub abi: Abi,
|
||||
}
|
||||
|
||||
impl<T: Field> CompilationArtifacts<T> {
|
||||
pub fn prog(&self) -> &ir::Prog<T> {
|
||||
&self.prog
|
||||
impl<I> CompilationArtifacts<I> {
|
||||
pub fn prog(self) -> ir::ProgIterator<I> {
|
||||
self.prog
|
||||
}
|
||||
|
||||
pub fn abi(&self) -> &Abi {
|
||||
|
@ -183,37 +184,46 @@ impl CompileConfig {
|
|||
|
||||
type FilePath = PathBuf;
|
||||
|
||||
pub fn compile<T: Field, E: Into<imports::Error>>(
|
||||
pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
|
||||
source: String,
|
||||
location: FilePath,
|
||||
resolver: Option<&dyn Resolver<E>>,
|
||||
config: &CompileConfig,
|
||||
) -> Result<CompilationArtifacts<T>, CompileErrors> {
|
||||
let arena = Arena::new();
|
||||
|
||||
let (typed_ast, abi) = check_with_arena(source, location.clone(), resolver, config, &arena)?;
|
||||
config: CompileConfig,
|
||||
arena: &'ast Arena<String>,
|
||||
) -> Result<CompilationArtifacts<impl Iterator<Item = ir::Statement<T>> + 'ast>, CompileErrors> {
|
||||
let (typed_ast, abi): (crate::zir::ZirProgram<'_, T>, _) =
|
||||
check_with_arena(source, location.clone(), resolver, &config, arena)?;
|
||||
|
||||
// flatten input program
|
||||
log::debug!("Flatten");
|
||||
let program_flattened = Flattener::flatten(typed_ast, config);
|
||||
let program_flattened = FlattenerIterator::from_function_and_config(typed_ast.main, config);
|
||||
|
||||
// constant propagation after call resolution
|
||||
log::debug!("Propagate flat program");
|
||||
let program_flattened = program_flattened.propagate();
|
||||
// // constant propagation after call resolution
|
||||
// log::debug!("Propagate flat program");
|
||||
// let program_flattened = program_flattened.propagate();
|
||||
|
||||
// convert to ir
|
||||
log::debug!("Convert to IR");
|
||||
let ir_prog = ir::Prog::from(program_flattened);
|
||||
let ir_prog = ir::ProgIterator {
|
||||
arguments: program_flattened
|
||||
.arguments_flattened
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|a| a.into())
|
||||
.collect(),
|
||||
return_count: 0,
|
||||
statements: ir::from_flat::from_flat(program_flattened),
|
||||
};
|
||||
|
||||
// optimize
|
||||
log::debug!("Optimise IR");
|
||||
let optimized_ir_prog = ir_prog.optimize();
|
||||
|
||||
// analyse ir (check constraints)
|
||||
log::debug!("Analyse IR");
|
||||
let optimized_ir_prog = optimized_ir_prog
|
||||
.analyse()
|
||||
.map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?;
|
||||
// log::debug!("Analyse IR");
|
||||
// let optimized_ir_prog = optimized_ir_prog
|
||||
// .analyse()
|
||||
// .map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?;
|
||||
|
||||
Ok(CompilationArtifacts {
|
||||
prog: optimized_ir_prog,
|
||||
|
|
|
@ -96,32 +96,24 @@ impl fmt::Display for RuntimeError {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct FlatProg<T: Field> {
|
||||
/// FlatFunctions of the program
|
||||
pub main: FlatFunction<T>,
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for FlatProg<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.main)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Debug for FlatProg<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "flat_program(main: {}\t)", self.main)
|
||||
}
|
||||
}
|
||||
pub type FlatProg<T> = FlatFunction<T>;
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct FlatFunction<T: Field> {
|
||||
pub struct FlatFunction<T> {
|
||||
/// Arguments of the function
|
||||
pub arguments: Vec<FlatParameter>,
|
||||
/// Vector of statements that are executed when running the function
|
||||
pub statements: Vec<FlatStatement<T>>,
|
||||
}
|
||||
|
||||
pub type FlatProgIterator<T> = FlatFunctionIterator<T>;
|
||||
pub struct FlatFunctionIterator<I> {
|
||||
/// Arguments of the function
|
||||
pub arguments: Vec<FlatParameter>,
|
||||
/// Vector of statements that are executed when running the function
|
||||
pub statements: I,
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for FlatFunction<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
|
@ -167,7 +159,7 @@ impl<T: Field> fmt::Debug for FlatFunction<T> {
|
|||
/// * r1cs - R1CS in standard JSON data format
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum FlatStatement<T: Field> {
|
||||
pub enum FlatStatement<T> {
|
||||
Return(FlatExpressionList<T>),
|
||||
Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError),
|
||||
Definition(FlatVariable, FlatExpression<T>),
|
||||
|
@ -239,7 +231,7 @@ impl<T: Field> FlatStatement<T> {
|
|||
}
|
||||
|
||||
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
|
||||
pub struct FlatDirective<T: Field> {
|
||||
pub struct FlatDirective<T> {
|
||||
pub inputs: Vec<FlatExpression<T>>,
|
||||
pub outputs: Vec<FlatVariable>,
|
||||
pub solver: Solver,
|
||||
|
|
|
@ -16,17 +16,67 @@ use crate::flat_absy::{RuntimeError, *};
|
|||
use crate::solvers::Solver;
|
||||
use crate::zir::types::{Type, UBitwidth};
|
||||
use crate::zir::*;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{
|
||||
hash_map::{Entry, HashMap},
|
||||
VecDeque,
|
||||
};
|
||||
use std::convert::TryFrom;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type FlatStatements<T> = Vec<FlatStatement<T>>;
|
||||
type FlatStatements<T> = VecDeque<FlatStatement<T>>;
|
||||
|
||||
/// Flattens a function
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `funct` - `ZirFunction` that will be flattened
|
||||
impl<'ast, T: Field> FlattenerIterator<'ast, T> {
|
||||
pub fn from_function_and_config(funct: ZirFunction<'ast, T>, config: CompileConfig) -> Self {
|
||||
let mut flattener = Flattener::new(config);
|
||||
let mut statements_flattened = FlatStatements::new();
|
||||
// push parameters
|
||||
let arguments_flattened = funct
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|p| flattener.use_parameter(&p, &mut statements_flattened))
|
||||
.collect();
|
||||
|
||||
FlattenerIterator {
|
||||
statements: funct.statements.into(),
|
||||
arguments_flattened,
|
||||
statements_flattened,
|
||||
flattener,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FlattenerIterator<'ast, T> {
|
||||
pub statements: VecDeque<ZirStatement<'ast, T>>,
|
||||
pub arguments_flattened: Vec<FlatParameter>,
|
||||
pub statements_flattened: FlatStatements<T>,
|
||||
pub flattener: Flattener<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Iterator for FlattenerIterator<'ast, T> {
|
||||
type Item = FlatStatement<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.statements_flattened.is_empty() {
|
||||
match self.statements.pop_front() {
|
||||
Some(s) => {
|
||||
self.flattener
|
||||
.flatten_statement(&mut self.statements_flattened, s);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
self.statements_flattened.pop_front()
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattener, computes flattened program.
|
||||
#[derive(Debug)]
|
||||
pub struct Flattener<'ast, T: Field> {
|
||||
config: &'ast CompileConfig,
|
||||
pub struct Flattener<'ast, T> {
|
||||
config: CompileConfig,
|
||||
/// Index of the next introduced variable while processing the program.
|
||||
next_var_idx: usize,
|
||||
/// `FlatVariable`s corresponding to each `Identifier`
|
||||
|
@ -156,13 +206,8 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> Flattener<'ast, T> {
|
||||
pub fn flatten(p: ZirProgram<'ast, T>, config: &CompileConfig) -> FlatProg<T> {
|
||||
Flattener::new(config).flatten_program(p)
|
||||
}
|
||||
|
||||
/// Returns a `Flattener` with fresh `layout`.
|
||||
|
||||
fn new(config: &'ast CompileConfig) -> Flattener<'ast, T> {
|
||||
fn new(config: CompileConfig) -> Flattener<'ast, T> {
|
||||
Flattener {
|
||||
config,
|
||||
next_var_idx: 0,
|
||||
|
@ -182,7 +227,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Identifier(id) => id,
|
||||
e => {
|
||||
let res = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(res, e));
|
||||
statements_flattened.push_back(FlatStatement::Definition(res, e));
|
||||
res
|
||||
}
|
||||
}
|
||||
|
@ -241,7 +286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
|
||||
// init size_unknown = true
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
size_unknown[0],
|
||||
FlatExpression::Number(T::from(1)),
|
||||
));
|
||||
|
@ -250,14 +295,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
for (i, b) in b.iter().enumerate() {
|
||||
if *b {
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
is_not_smaller_run[i],
|
||||
a[i].into(),
|
||||
));
|
||||
|
||||
// don't need to update size_unknown in the last round
|
||||
if i < len - 1 {
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
size_unknown[i + 1],
|
||||
FlatExpression::Mult(
|
||||
box size_unknown[i].into(),
|
||||
|
@ -268,7 +313,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
} else {
|
||||
// don't need to update size_unknown in the last round
|
||||
if i < len - 1 {
|
||||
statements_flattened.push(
|
||||
statements_flattened.push_back(
|
||||
// sizeUnknown is not changing in this case
|
||||
// We sill have to assign the old value to the variable of the current run
|
||||
// This trivial definition will later be removed by the optimiser
|
||||
|
@ -285,7 +330,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let and_name = self.use_sym();
|
||||
let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone());
|
||||
statements_flattened.push(FlatStatement::Definition(and_name, and));
|
||||
statements_flattened.push_back(FlatStatement::Definition(and_name, and));
|
||||
let or = FlatExpression::Sub(
|
||||
box FlatExpression::Add(box or_left, box or_right),
|
||||
box and_name.into(),
|
||||
|
@ -310,7 +355,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
|
||||
fn eq_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
left: FlatExpression<T>,
|
||||
right: FlatExpression<T>,
|
||||
) -> FlatExpression<T> {
|
||||
|
@ -329,12 +374,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![x.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)),
|
||||
RuntimeError::Equal,
|
||||
|
@ -345,7 +390,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box x),
|
||||
RuntimeError::Equal,
|
||||
|
@ -376,7 +421,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.fold(FlatExpression::from(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
});
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::from(0)),
|
||||
FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()),
|
||||
RuntimeError::Le,
|
||||
|
@ -385,14 +430,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn make_conditional(
|
||||
&mut self,
|
||||
statements: Vec<FlatStatement<T>>,
|
||||
statements: FlatStatements<T>,
|
||||
condition: FlatExpression<T>,
|
||||
) -> Vec<FlatStatement<T>> {
|
||||
) -> FlatStatements<T> {
|
||||
statements
|
||||
.into_iter()
|
||||
.flat_map(|s| match s {
|
||||
FlatStatement::Condition(left, right, message) => {
|
||||
let mut output = vec![];
|
||||
let mut output = VecDeque::new();
|
||||
|
||||
// we transform (a == b) into (c => (a == b)) which is (!c || (a == b))
|
||||
|
||||
|
@ -410,12 +455,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
assert!(x.is_linear() && y.is_linear());
|
||||
let name_x_or_y = self.use_sym();
|
||||
output.push(FlatStatement::Directive(FlatDirective {
|
||||
output.push_back(FlatStatement::Directive(FlatDirective {
|
||||
solver: Solver::Or,
|
||||
outputs: vec![name_x_or_y],
|
||||
inputs: vec![x.clone(), y.clone()],
|
||||
}));
|
||||
output.push(FlatStatement::Condition(
|
||||
output.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Add(
|
||||
box x.clone(),
|
||||
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
|
||||
|
@ -423,7 +468,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Mult(box x.clone(), box y.clone()),
|
||||
RuntimeError::BranchIsolation,
|
||||
));
|
||||
output.push(FlatStatement::Condition(
|
||||
output.push_back(FlatStatement::Condition(
|
||||
name_x_or_y.into(),
|
||||
T::one().into(),
|
||||
message,
|
||||
|
@ -431,7 +476,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
output
|
||||
}
|
||||
s => vec![s],
|
||||
s => VecDeque::from([s]),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -457,16 +502,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||
|
||||
let condition_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
||||
statements_flattened.push_back(FlatStatement::Definition(condition_id, condition_flat));
|
||||
|
||||
self.condition_cache.insert(condition, condition_id);
|
||||
|
||||
let (consequence, alternative) = if self.config.isolate_branches {
|
||||
let mut consequence_statements = vec![];
|
||||
let mut consequence_statements = VecDeque::new();
|
||||
|
||||
let consequence = consequence.flatten(self, &mut consequence_statements);
|
||||
|
||||
let mut alternative_statements = vec![];
|
||||
let mut alternative_statements = VecDeque::new();
|
||||
|
||||
let alternative = alternative.flatten(self, &mut alternative_statements);
|
||||
|
||||
|
@ -495,13 +540,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let alternative = alternative.flat();
|
||||
|
||||
let consequence_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(consequence_id, consequence));
|
||||
statements_flattened.push_back(FlatStatement::Definition(consequence_id, consequence));
|
||||
|
||||
let alternative_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(alternative_id, alternative));
|
||||
statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative));
|
||||
|
||||
let term0_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
term0_id,
|
||||
FlatExpression::Mult(
|
||||
box condition_id.into(),
|
||||
|
@ -510,7 +555,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
));
|
||||
|
||||
let term1_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
term1_id,
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Sub(
|
||||
|
@ -522,7 +567,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
));
|
||||
|
||||
let res = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
res,
|
||||
FlatExpression::Add(
|
||||
box FlatExpression::from(term0_id),
|
||||
|
@ -582,7 +627,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
e_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![e_id],
|
||||
|
@ -590,7 +635,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// bitness checks
|
||||
for bit in e_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
|
@ -613,7 +658,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(e_id),
|
||||
e_sum,
|
||||
RuntimeError::ConstantLtSum,
|
||||
|
@ -704,7 +749,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
statements_flattened.push_back(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
|
@ -714,7 +759,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// bitness checks
|
||||
for bit in lhs_bits_be.iter().take(safe_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
|
@ -739,7 +784,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
RuntimeError::LtSum,
|
||||
|
@ -756,7 +801,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
statements_flattened.push_back(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
rhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
|
@ -766,7 +811,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// bitness checks
|
||||
for bit in rhs_bits_be.iter().take(safe_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
|
@ -791,7 +836,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_id),
|
||||
rhs_sum,
|
||||
RuntimeError::LtSum,
|
||||
|
@ -815,15 +860,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
),
|
||||
));
|
||||
|
||||
// bitness checks
|
||||
for bit in sub_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
|
@ -853,7 +900,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
subtraction_result,
|
||||
expr,
|
||||
RuntimeError::LtFinalSum,
|
||||
|
@ -885,7 +932,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let x_sub_y = FlatExpression::Sub(box x, box y);
|
||||
let name_x_mult_x = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
name_x_mult_x,
|
||||
FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y),
|
||||
));
|
||||
|
@ -972,7 +1019,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
|
@ -980,7 +1027,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// bitness checks
|
||||
for bit in sub_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
|
@ -1010,7 +1057,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
subtraction_result,
|
||||
expr,
|
||||
RuntimeError::LtFinalSum,
|
||||
|
@ -1042,12 +1089,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let y = self.flatten_boolean_expression(statements_flattened, rhs);
|
||||
assert!(x.is_linear() && y.is_linear());
|
||||
let name_x_or_y = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective {
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective {
|
||||
solver: Solver::Or,
|
||||
outputs: vec![name_x_or_y],
|
||||
inputs: vec![x.clone(), y.clone()],
|
||||
}));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Add(
|
||||
box x.clone(),
|
||||
box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()),
|
||||
|
@ -1063,7 +1110,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let name_x_and_y = self.use_sym();
|
||||
assert!(x.is_linear() && y.is_linear());
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
name_x_and_y,
|
||||
FlatExpression::Mult(box x, box y),
|
||||
));
|
||||
|
@ -1373,14 +1420,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let d = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
|
@ -1388,14 +1435,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let invd = self.use_sym();
|
||||
|
||||
// # invd = 1/d
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![invd],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(T::one()), d.clone()],
|
||||
)));
|
||||
|
||||
// assert(invd * d == 1)
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::one()),
|
||||
FlatExpression::Mult(box invd.into(), box d.clone()),
|
||||
RuntimeError::Inverse,
|
||||
|
@ -1405,7 +1452,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let q = self.use_sym();
|
||||
let r = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective {
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective {
|
||||
inputs: vec![n.clone(), d.clone()],
|
||||
outputs: vec![q, r],
|
||||
solver: Solver::EuclideanDiv,
|
||||
|
@ -1439,7 +1486,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
|
||||
// q*d == n - r
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Sub(box n, box r.into()),
|
||||
FlatExpression::Mult(box q.into(), box d),
|
||||
RuntimeError::Euclidean,
|
||||
|
@ -1520,14 +1567,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
|
@ -1550,14 +1597,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
|
@ -1612,20 +1659,20 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
let res = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
res,
|
||||
FlatExpression::Mult(box new_left, box new_right),
|
||||
));
|
||||
|
@ -1974,7 +2021,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
Entry::Vacant(_) => {
|
||||
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
bits.clone(),
|
||||
Solver::Bits(from),
|
||||
vec![e.field.clone().unwrap()],
|
||||
|
@ -1996,7 +2043,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let sum = flat_expression_from_bits(bits.clone());
|
||||
|
||||
// sum check
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
e.field.clone().unwrap(),
|
||||
sum.clone(),
|
||||
RuntimeError::Sum,
|
||||
|
@ -2055,7 +2102,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
range_check = FlatExpression::Add(box range_check, box condition.clone());
|
||||
|
||||
let conditional_element_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
conditional_element_id,
|
||||
FlatExpression::Mult(box condition, box element.flat()),
|
||||
));
|
||||
|
@ -2065,7 +2112,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
},
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
range_check,
|
||||
FlatExpression::Number(T::one()),
|
||||
RuntimeError::SelectRangeCheck,
|
||||
|
@ -2099,14 +2146,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Add(box new_left, box new_right)
|
||||
|
@ -2119,14 +2166,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
|
@ -2139,14 +2186,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
left_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
let new_right = if right_flattened.is_linear() {
|
||||
right_flattened
|
||||
} else {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
FlatExpression::Identifier(id)
|
||||
};
|
||||
FlatExpression::Mult(box new_left, box new_right)
|
||||
|
@ -2156,12 +2203,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let right_flattened = self.flatten_field_expression(statements_flattened, right);
|
||||
let new_left: FlatExpression<T> = {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, left_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, left_flattened));
|
||||
id.into()
|
||||
};
|
||||
let new_right: FlatExpression<T> = {
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(id, right_flattened));
|
||||
statements_flattened.push_back(FlatStatement::Definition(id, right_flattened));
|
||||
id.into()
|
||||
};
|
||||
|
||||
|
@ -2169,28 +2216,28 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let inverse = self.use_sym();
|
||||
|
||||
// # invb = 1/b
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![invb],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(T::one()), new_right.clone()],
|
||||
)));
|
||||
|
||||
// assert(invb * b == 1)
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::one()),
|
||||
FlatExpression::Mult(box invb.into(), box new_right.clone()),
|
||||
RuntimeError::Inverse,
|
||||
));
|
||||
|
||||
// # c = a/b
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![inverse],
|
||||
Solver::Div,
|
||||
vec![new_left.clone(), new_right.clone()],
|
||||
)));
|
||||
|
||||
// assert(c * b == a)
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
new_left,
|
||||
FlatExpression::Mult(box new_right, box inverse.into()),
|
||||
RuntimeError::Division,
|
||||
|
@ -2239,7 +2286,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// introduce a new variable
|
||||
let id = self.use_sym();
|
||||
// set it to the square of the previous one, stored in state
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
id,
|
||||
FlatExpression::Mult(
|
||||
box previous.clone(),
|
||||
|
@ -2262,7 +2309,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
true => {
|
||||
// update the result by introducing a new variable
|
||||
let id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
id,
|
||||
FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power
|
||||
));
|
||||
|
@ -2305,7 +2352,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.map(|x| x.get_field_unchecked())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
statements_flattened.push(FlatStatement::Return(FlatExpressionList {
|
||||
statements_flattened.push_back(FlatStatement::Return(FlatExpressionList {
|
||||
expressions: flat_expressions,
|
||||
}));
|
||||
}
|
||||
|
@ -2314,13 +2361,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
self.flatten_boolean_expression(statements_flattened, condition.clone());
|
||||
|
||||
let condition_id = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(condition_id, condition_flat));
|
||||
statements_flattened
|
||||
.push_back(FlatStatement::Definition(condition_id, condition_flat));
|
||||
|
||||
self.condition_cache.insert(condition, condition_id);
|
||||
|
||||
if self.config.isolate_branches {
|
||||
let mut consequence_statements = vec![];
|
||||
let mut alternative_statements = vec![];
|
||||
let mut consequence_statements = VecDeque::new();
|
||||
let mut alternative_statements = VecDeque::new();
|
||||
|
||||
consequence
|
||||
.into_iter()
|
||||
|
@ -2367,7 +2415,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let var = self.use_variable(&assignee);
|
||||
|
||||
// handle return of function call
|
||||
statements_flattened.push(FlatStatement::Definition(var, e));
|
||||
statements_flattened.push_back(FlatStatement::Definition(var, e));
|
||||
|
||||
var
|
||||
}
|
||||
|
@ -2431,14 +2479,14 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let e = self.flatten_boolean_expression(statements_flattened, e);
|
||||
|
||||
if e.is_linear() {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
e,
|
||||
FlatExpression::Number(T::from(1)),
|
||||
error.into(),
|
||||
));
|
||||
} else {
|
||||
// swap so that left side is linear
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::from(1)),
|
||||
e,
|
||||
error.into(),
|
||||
|
@ -2474,7 +2522,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
e => {
|
||||
let id = self.use_variable(&v);
|
||||
statements_flattened.push(FlatStatement::Definition(id, e));
|
||||
statements_flattened
|
||||
.push_back(FlatStatement::Definition(id, e));
|
||||
id
|
||||
}
|
||||
})
|
||||
|
@ -2502,45 +2551,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Flattens a function
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `funct` - `ZirFunction` that will be flattened
|
||||
fn flatten_function(&mut self, funct: ZirFunction<'ast, T>) -> FlatFunction<T> {
|
||||
self.layout = HashMap::new();
|
||||
|
||||
self.next_var_idx = 0;
|
||||
let mut statements_flattened: FlatStatements<T> = FlatStatements::new();
|
||||
|
||||
// push parameters
|
||||
let arguments_flattened = funct
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|p| self.use_parameter(&p, &mut statements_flattened))
|
||||
.collect();
|
||||
|
||||
// flatten statements in functions and apply substitution
|
||||
for stat in funct.statements {
|
||||
self.flatten_statement(&mut statements_flattened, stat);
|
||||
}
|
||||
|
||||
FlatFunction {
|
||||
arguments: arguments_flattened,
|
||||
statements: statements_flattened,
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens a program
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prog` - `ZirProgram` that will be flattened.
|
||||
fn flatten_program(&mut self, prog: ZirProgram<'ast, T>) -> FlatProg<T> {
|
||||
FlatProg {
|
||||
main: self.flatten_function(prog.main),
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens an equality assertion, enforcing it in the circuit.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -2571,7 +2581,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
),
|
||||
),
|
||||
};
|
||||
statements_flattened.push(FlatStatement::Condition(lhs, rhs, error));
|
||||
statements_flattened.push_back(FlatStatement::Condition(lhs, rhs, error));
|
||||
}
|
||||
|
||||
/// Identifies a non-linear expression by assigning it to a new identifier.
|
||||
|
@ -2589,7 +2599,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
true => e,
|
||||
false => {
|
||||
let sym = self.use_sym();
|
||||
statements_flattened.push(FlatStatement::Definition(sym, e));
|
||||
statements_flattened.push_back(FlatStatement::Definition(sym, e));
|
||||
FlatExpression::Identifier(sym)
|
||||
}
|
||||
}
|
||||
|
@ -2638,7 +2648,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
}
|
||||
Type::Boolean => {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
variable.into(),
|
||||
FlatExpression::Mult(box variable.into(), box variable.into()),
|
||||
RuntimeError::ArgumentBitness,
|
||||
|
@ -2649,7 +2659,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// we insert dummy condition statement for private field elements
|
||||
// to avoid unconstrained variables
|
||||
// translates to y == x * x
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
self.use_sym(),
|
||||
FlatExpression::Mult(box variable.into(), box variable.into()),
|
||||
));
|
||||
|
|
|
@ -46,7 +46,7 @@ pub fn fold_module<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> {
|
|||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
returns: p.returns.into_iter().map(|v| f.fold_variable(v)).collect(),
|
||||
return_count: p.return_count,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::flat_absy::{FlatDirective, FlatExpression, FlatProg, FlatStatement, FlatVariable};
|
||||
use crate::ir::{Directive, LinComb, Prog, QuadComb, Statement};
|
||||
use crate::flat_absy::{FlatDirective, FlatExpression, FlatStatement, FlatVariable};
|
||||
use crate::ir::{Directive, LinComb, QuadComb, Statement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
impl<T: Field> QuadComb<T> {
|
||||
|
@ -17,50 +17,25 @@ impl<T: Field> QuadComb<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> From<FlatProg<T>> for Prog<T> {
|
||||
fn from(flat_prog: FlatProg<T>) -> Prog<T> {
|
||||
// get the main function
|
||||
let main = flat_prog.main;
|
||||
|
||||
let return_expressions: Vec<FlatExpression<T>> = main
|
||||
.statements
|
||||
.iter()
|
||||
.filter_map(|s| match s {
|
||||
FlatStatement::Return(el) => Some(el.expressions.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
Prog {
|
||||
arguments: main.arguments,
|
||||
returns: return_expressions
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, _)| FlatVariable::public(index))
|
||||
.collect(),
|
||||
statements: main
|
||||
.statements
|
||||
.into_iter()
|
||||
.filter_map(|s| match s {
|
||||
pub fn from_flat<'ast, T: Field, I: Iterator<Item = FlatStatement<T>>>(
|
||||
flat_prog_iterator: I,
|
||||
) -> impl Iterator<Item = Statement<T>> {
|
||||
flat_prog_iterator.filter_map(|s| match s {
|
||||
FlatStatement::Return(..) => None,
|
||||
s => Some(s.into()),
|
||||
})
|
||||
.chain(
|
||||
return_expressions
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, expression)| {
|
||||
Statement::Constraint(
|
||||
QuadComb::from_flat_expression(expression),
|
||||
FlatVariable::public(index).into(),
|
||||
None,
|
||||
)
|
||||
}),
|
||||
)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
// .chain(
|
||||
// return_expressions
|
||||
// .into_iter()
|
||||
// .enumerate()
|
||||
// .map(|(index, expression)| {
|
||||
// Statement::Constraint(
|
||||
// QuadComb::from_flat_expression(expression),
|
||||
// FlatVariable::public(index).into(),
|
||||
// None,
|
||||
// )
|
||||
// }),
|
||||
// )
|
||||
}
|
||||
|
||||
impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
|
||||
|
|
|
@ -8,7 +8,7 @@ use zokrates_field::Field;
|
|||
|
||||
mod expression;
|
||||
pub mod folder;
|
||||
mod from_flat;
|
||||
pub mod from_flat;
|
||||
mod interpreter;
|
||||
mod serialize;
|
||||
pub mod smtlib2;
|
||||
|
@ -78,10 +78,51 @@ impl<T: Field> fmt::Display for Statement<T> {
|
|||
pub struct Prog<T> {
|
||||
pub statements: Vec<Statement<T>>,
|
||||
pub arguments: Vec<FlatParameter>,
|
||||
pub returns: Vec<FlatVariable>,
|
||||
pub return_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ProgIterator<I> {
|
||||
pub statements: I,
|
||||
pub arguments: Vec<FlatParameter>,
|
||||
pub return_count: usize,
|
||||
}
|
||||
|
||||
impl<T> From<Prog<T>> for ProgIterator<std::vec::IntoIter<Statement<T>>> {
|
||||
fn from(p: Prog<T>) -> ProgIterator<std::vec::IntoIter<Statement<T>>> {
|
||||
ProgIterator {
|
||||
statements: p.statements.into_iter(),
|
||||
arguments: p.arguments,
|
||||
return_count: p.return_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, I: Iterator<Item = Statement<T>>> From<ProgIterator<I>> for Prog<T> {
|
||||
fn from(p: ProgIterator<I>) -> Prog<T> {
|
||||
Prog {
|
||||
statements: p.statements.collect(),
|
||||
arguments: p.arguments,
|
||||
return_count: p.return_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ProgIterator<T> {
|
||||
pub fn returns(&self) -> Vec<FlatVariable> {
|
||||
(0..self.return_count)
|
||||
.map(|id| FlatVariable::public(id))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Prog<T> {
|
||||
pub fn returns(&self) -> Vec<FlatVariable> {
|
||||
(0..self.return_count)
|
||||
.map(|id| FlatVariable::public(id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn constraint_count(&self) -> usize {
|
||||
self.statements
|
||||
.iter()
|
||||
|
@ -113,14 +154,14 @@ impl<T: Field> fmt::Display for Prog<T> {
|
|||
.map(|v| format!("{}", v))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
self.returns.len(),
|
||||
self.return_count,
|
||||
self.statements
|
||||
.iter()
|
||||
.map(|s| format!("\t{}", s))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"),
|
||||
self.returns
|
||||
.iter()
|
||||
(0..self.return_count)
|
||||
.map(|i| FlatVariable::public(i))
|
||||
.map(|e| format!("{}", e))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
|
|
|
@ -49,9 +49,6 @@ pub fn visit_module<T: Field, F: Visitor<T>>(f: &mut F, p: &Prog<T>) {
|
|||
for expr in p.statements.iter() {
|
||||
f.visit_statement(expr);
|
||||
}
|
||||
for expr in p.returns.iter() {
|
||||
f.visit_variable(expr);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_statement<T: Field, F: Visitor<T>>(f: &mut F, s: &Statement<T>) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::ir::{folder::Folder, LinComb};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Canonicalizer;
|
||||
|
||||
impl<T: Field> Folder<T> for Canonicalizer {
|
||||
|
|
|
@ -17,26 +17,13 @@ use crate::solvers::Solver;
|
|||
use std::collections::hash_map::{Entry, HashMap};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DirectiveOptimizer<T: Field> {
|
||||
#[derive(Debug, Default)]
|
||||
pub struct DirectiveOptimizer<T> {
|
||||
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
|
||||
/// Map of renamings for reassigned variables while processing the program.
|
||||
substitution: HashMap<FlatVariable, FlatVariable>,
|
||||
}
|
||||
|
||||
impl<T: Field> DirectiveOptimizer<T> {
|
||||
fn new() -> DirectiveOptimizer<T> {
|
||||
DirectiveOptimizer {
|
||||
calls: HashMap::new(),
|
||||
substitution: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize(p: Prog<T>) -> Prog<T> {
|
||||
DirectiveOptimizer::new().fold_module(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
||||
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
||||
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
||||
|
|
|
@ -16,23 +16,11 @@ fn hash<T: Field>(s: &Statement<T>) -> Hash {
|
|||
hasher.finish()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct DuplicateOptimizer {
|
||||
seen: HashSet<Hash>,
|
||||
}
|
||||
|
||||
impl DuplicateOptimizer {
|
||||
fn new() -> Self {
|
||||
DuplicateOptimizer {
|
||||
seen: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
|
||||
Self::new().fold_module(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DuplicateOptimizer {
|
||||
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
||||
// in order to correctly identify duplicates, we need to first canonicalize the statements
|
||||
|
|
|
@ -10,41 +10,59 @@ mod duplicate;
|
|||
mod redefinition;
|
||||
mod tautology;
|
||||
|
||||
use self::canonicalizer::Canonicalizer;
|
||||
use self::directive::DirectiveOptimizer;
|
||||
use self::duplicate::DuplicateOptimizer;
|
||||
use self::redefinition::RedefinitionOptimizer;
|
||||
use self::tautology::TautologyOptimizer;
|
||||
|
||||
use crate::ir::Prog;
|
||||
use crate::ir::{ProgIterator, Statement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
impl<T: Field> Prog<T> {
|
||||
pub fn optimize(self) -> Self {
|
||||
impl<T: Field, I: Iterator<Item = Statement<T>>> ProgIterator<I> {
|
||||
pub fn optimize(self) -> ProgIterator<impl Iterator<Item = Statement<T>>> {
|
||||
// remove redefinitions
|
||||
log::debug!("Constraints: {}", self.constraint_count());
|
||||
log::debug!("Optimizer: Remove redefinitions");
|
||||
let r = RedefinitionOptimizer::optimize(self);
|
||||
log::debug!("Done");
|
||||
log::debug!(
|
||||
"Optimizer: Remove redefinitions and tautologies and directives and duplicates"
|
||||
);
|
||||
|
||||
// remove constraints that are always satisfied
|
||||
log::debug!("Constraints: {}", r.constraint_count());
|
||||
log::debug!("Optimizer: Remove tautologies");
|
||||
let r = TautologyOptimizer::optimize(r);
|
||||
log::debug!("Done");
|
||||
// define all optimizer steps
|
||||
let mut redefinition_optimizer = RedefinitionOptimizer::default();
|
||||
let mut tautologies_optimizer = TautologyOptimizer::default();
|
||||
let mut directive_optimizer = DirectiveOptimizer::default();
|
||||
let mut canonicalizer = Canonicalizer::default();
|
||||
let mut duplicate_optimizer = DuplicateOptimizer::default();
|
||||
|
||||
// deduplicate directives which take the same input
|
||||
log::debug!("Constraints: {}", r.constraint_count());
|
||||
log::debug!("Optimizer: Remove duplicate directive");
|
||||
let r = DirectiveOptimizer::optimize(r);
|
||||
log::debug!("Done");
|
||||
// initialize the ones that need initializing
|
||||
redefinition_optimizer.ignore.extend(self.returns().clone());
|
||||
|
||||
// remove duplicate constraints
|
||||
log::debug!("Constraints: {}", r.constraint_count());
|
||||
log::debug!("Optimizer: Remove duplicate constraints");
|
||||
let r = DuplicateOptimizer::optimize(r);
|
||||
log::debug!("Done");
|
||||
use crate::ir::folder::Folder;
|
||||
|
||||
log::debug!("Constraints: {}", r.constraint_count());
|
||||
let r = ProgIterator {
|
||||
arguments: self
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|a| redefinition_optimizer.fold_argument(a))
|
||||
.map(|a| {
|
||||
<TautologyOptimizer as Folder<T>>::fold_argument(&mut tautologies_optimizer, a)
|
||||
})
|
||||
.map(|a| directive_optimizer.fold_argument(a))
|
||||
.map(|a| {
|
||||
<DuplicateOptimizer as Folder<T>>::fold_argument(&mut duplicate_optimizer, a)
|
||||
})
|
||||
.collect(),
|
||||
statements: self
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(move |s| redefinition_optimizer.fold_statement(s))
|
||||
.flat_map(move |s| tautologies_optimizer.fold_statement(s))
|
||||
.flat_map(move |s| canonicalizer.fold_statement(s))
|
||||
.flat_map(move |s| directive_optimizer.fold_statement(s))
|
||||
.flat_map(move |s| duplicate_optimizer.fold_statement(s)),
|
||||
return_count: self.return_count,
|
||||
};
|
||||
|
||||
log::debug!("Done");
|
||||
r
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,44 +38,30 @@
|
|||
|
||||
use crate::flat_absy::flat_variable::FlatVariable;
|
||||
use crate::flat_absy::FlatParameter;
|
||||
use crate::ir::folder::{fold_module, Folder};
|
||||
use crate::ir::folder::Folder;
|
||||
use crate::ir::LinComb;
|
||||
use crate::ir::*;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RedefinitionOptimizer<T: Field> {
|
||||
pub struct RedefinitionOptimizer<T> {
|
||||
/// Map of renamings for reassigned variables while processing the program.
|
||||
substitution: HashMap<FlatVariable, CanonicalLinComb<T>>,
|
||||
/// Set of variables that should not be substituted
|
||||
ignore: HashSet<FlatVariable>,
|
||||
pub ignore: HashSet<FlatVariable>,
|
||||
}
|
||||
|
||||
impl<T: Field> RedefinitionOptimizer<T> {
|
||||
fn new() -> Self {
|
||||
impl<T> Default for RedefinitionOptimizer<T> {
|
||||
fn default() -> Self {
|
||||
RedefinitionOptimizer {
|
||||
substitution: HashMap::new(),
|
||||
ignore: HashSet::new(),
|
||||
ignore: vec![FlatVariable::one()].into_iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize(p: Prog<T>) -> Prog<T> {
|
||||
RedefinitionOptimizer::new().fold_module(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
||||
fn fold_module(&mut self, p: Prog<T>) -> Prog<T> {
|
||||
// to prevent the optimiser from replacing outputs, add them to the ignored set
|
||||
self.ignore.extend(p.returns.iter().cloned());
|
||||
|
||||
// to prevent the optimiser from replacing ~one, add it to the ignored set
|
||||
self.ignore.insert(FlatVariable::one());
|
||||
|
||||
fold_module(self, p)
|
||||
}
|
||||
|
||||
fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter {
|
||||
// to prevent the optimiser from replacing user input, add it to the ignored set
|
||||
self.ignore.insert(a.id);
|
||||
|
|
|
@ -10,17 +10,8 @@ use crate::ir::folder::Folder;
|
|||
use crate::ir::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct TautologyOptimizer {}
|
||||
|
||||
impl TautologyOptimizer {
|
||||
fn new() -> TautologyOptimizer {
|
||||
TautologyOptimizer {}
|
||||
}
|
||||
|
||||
pub fn optimize<T: Field>(p: Prog<T>) -> Prog<T> {
|
||||
TautologyOptimizer::new().fold_module(p)
|
||||
}
|
||||
}
|
||||
#[derive(Default)]
|
||||
pub struct TautologyOptimizer;
|
||||
|
||||
impl<T: Field> Folder<T> for TautologyOptimizer {
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
|
|
|
@ -101,14 +101,6 @@ impl<T: Field> Propagate<T> for FlatFunction<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Field> FlatProg<T> {
|
||||
pub fn propagate(self) -> FlatProg<T> {
|
||||
let main = self.main.propagate();
|
||||
|
||||
FlatProg { main }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -54,7 +54,8 @@ impl fmt::Debug for FieldParseError {
|
|||
}
|
||||
|
||||
pub trait Field:
|
||||
From<i32>
|
||||
'static
|
||||
+ From<i32>
|
||||
+ From<u32>
|
||||
+ From<usize>
|
||||
+ From<u128>
|
||||
|
|
|
@ -12,5 +12,6 @@ zokrates_abi = { version = "0.1", path = "../zokrates_abi" }
|
|||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
serde_json = "1.0"
|
||||
typed-arena = "1.4.1"
|
||||
|
||||
[lib]
|
||||
|
|
|
@ -149,13 +149,17 @@ fn compile_and_run<T: Field>(t: Tests) {
|
|||
let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap();
|
||||
let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap());
|
||||
|
||||
let artifacts = compile::<T, _>(code, entry_point.clone(), Some(&resolver), &config).unwrap();
|
||||
let arena = typed_arena::Arena::new();
|
||||
|
||||
let bin = artifacts.prog();
|
||||
let abi = artifacts.abi();
|
||||
let artifacts =
|
||||
compile::<T, _>(code, entry_point.clone(), Some(&resolver), config, &arena).unwrap();
|
||||
|
||||
let abi = artifacts.abi;
|
||||
let bin = artifacts.prog;
|
||||
|
||||
if let Some(target_count) = t.max_constraint_count {
|
||||
let count = bin.constraint_count();
|
||||
unimplemented!("bin.constraint_count()");
|
||||
let count = 42;
|
||||
|
||||
assert!(
|
||||
count <= target_count,
|
||||
|
@ -169,6 +173,8 @@ fn compile_and_run<T: Field>(t: Tests) {
|
|||
|
||||
let interpreter = zokrates_core::ir::Interpreter::default();
|
||||
|
||||
let bin = &bin.into();
|
||||
|
||||
for test in t.tests.into_iter() {
|
||||
let with_abi = test.abi.unwrap_or(false);
|
||||
|
||||
|
|
Loading…
Reference in a new issue