optimize zir solvers by indexing
This commit is contained in:
parent
e12a4b46d3
commit
4988911183
15 changed files with 235 additions and 119 deletions
|
@ -1,4 +1,4 @@
|
|||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::marker::PhantomData;
|
||||
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
|
||||
|
@ -481,7 +481,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
// }
|
||||
#[derive(Default)]
|
||||
pub struct ArgumentFinder<'ast, T> {
|
||||
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
|
||||
pub identifiers: BTreeMap<zir::Identifier<'ast>, zir::Type>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed;
|
|||
pub use self::error::RuntimeError;
|
||||
pub use self::metadata::SourceMetadata;
|
||||
pub use self::parameter::Parameter;
|
||||
pub use self::solvers::{Solver, ZirSolver};
|
||||
pub use self::solvers::Solver;
|
||||
pub use self::variable::Variable;
|
||||
pub use format_string::FormatString;
|
||||
|
|
|
@ -2,22 +2,6 @@ use crate::zir::ZirFunction;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
||||
pub enum ZirSolver<'ast, T> {
|
||||
#[serde(borrow)]
|
||||
Function(ZirFunction<'ast, T>),
|
||||
Indexed(usize, usize),
|
||||
}
|
||||
|
||||
impl<'ast, T> fmt::Display for ZirSolver<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
ZirSolver::Function(_) => write!(f, "Zir(..)"),
|
||||
ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
||||
pub enum Solver<'ast, T> {
|
||||
ConditionEq,
|
||||
|
@ -29,22 +13,14 @@ pub enum Solver<'ast, T> {
|
|||
ShaCh,
|
||||
EuclideanDiv,
|
||||
#[serde(borrow)]
|
||||
Zir(ZirSolver<'ast, T>),
|
||||
Zir(ZirFunction<'ast, T>),
|
||||
IndexedCall(usize, usize),
|
||||
#[cfg(feature = "bellman")]
|
||||
Sha256Round,
|
||||
#[cfg(feature = "ark")]
|
||||
SnarkVerifyBls12377(usize),
|
||||
}
|
||||
|
||||
impl<'ast, T> Solver<'ast, T> {
|
||||
pub fn zir_function(function: ZirFunction<'ast, T>) -> Self {
|
||||
Solver::Zir(ZirSolver::Function(function))
|
||||
}
|
||||
pub fn zir_indexed(index: usize, argument_count: usize) -> Self {
|
||||
Solver::Zir(ZirSolver::Indexed(index, argument_count))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> fmt::Display for Solver<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
|
@ -56,7 +32,8 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> {
|
|||
Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"),
|
||||
Solver::ShaCh => write!(f, "ShaCh"),
|
||||
Solver::EuclideanDiv => write!(f, "EuclideanDiv"),
|
||||
Solver::Zir(s) => write!(f, "{}", s),
|
||||
Solver::Zir(_) => write!(f, "Zir(..)"),
|
||||
Solver::IndexedCall(index, argc) => write!(f, "IndexedCall@{}({})", index, argc),
|
||||
#[cfg(feature = "bellman")]
|
||||
Solver::Sha256Round => write!(f, "Sha256Round"),
|
||||
#[cfg(feature = "ark")]
|
||||
|
@ -76,10 +53,8 @@ impl<'ast, T> Solver<'ast, T> {
|
|||
Solver::ShaAndXorAndXorAnd => (3, 1),
|
||||
Solver::ShaCh => (3, 1),
|
||||
Solver::EuclideanDiv => (2, 2),
|
||||
Solver::Zir(s) => match s {
|
||||
ZirSolver::Function(f) => (f.arguments.len(), 1),
|
||||
ZirSolver::Indexed(_, n) => (*n, 1),
|
||||
},
|
||||
Solver::Zir(f) => (f.arguments.len(), 1),
|
||||
Solver::IndexedCall(_, n) => (*n, 1),
|
||||
#[cfg(feature = "bellman")]
|
||||
Solver::Sha256Round => (768, 26935),
|
||||
#[cfg(feature = "ark")]
|
||||
|
|
|
@ -2,8 +2,6 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement
|
|||
use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::SolverMap;
|
||||
|
||||
impl<T: Field> QuadComb<T> {
|
||||
fn from_flat_expression<U: Into<FlatExpression<T>>>(flat_expression: U) -> QuadComb<T> {
|
||||
let flat_expression = flat_expression.into();
|
||||
|
@ -26,7 +24,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator<Item = FlatStatement<'ast, T>>>
|
|||
statements: flat_prog_iterator.statements.into_iter().map(Into::into),
|
||||
arguments: flat_prog_iterator.arguments,
|
||||
return_count: flat_prog_iterator.return_count,
|
||||
solvers: SolverMap::default(),
|
||||
solvers: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::common::FormatString;
|
|||
use crate::typed::ConcreteType;
|
||||
use derivative::Derivative;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::collections::BTreeSet;
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
use zokrates_field::Field;
|
||||
|
@ -14,6 +14,7 @@ pub mod folder;
|
|||
pub mod from_flat;
|
||||
mod serialize;
|
||||
pub mod smtlib2;
|
||||
mod solver_indexer;
|
||||
pub mod visitor;
|
||||
mod witness;
|
||||
|
||||
|
@ -22,8 +23,8 @@ pub use self::expression::{CanonicalLinComb, LinComb};
|
|||
pub use self::serialize::ProgEnum;
|
||||
pub use crate::common::Parameter;
|
||||
pub use crate::common::RuntimeError;
|
||||
pub use crate::common::Solver;
|
||||
pub use crate::common::Variable;
|
||||
pub use crate::common::{Solver, ZirSolver};
|
||||
|
||||
pub use self::witness::Witness;
|
||||
|
||||
|
@ -124,7 +125,6 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> {
|
|||
}
|
||||
|
||||
pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec<Statement<'ast, T>>>;
|
||||
pub type SolverMap<'ast, T> = HashMap<u64, Solver<'ast, T>>;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
|
||||
|
@ -132,7 +132,7 @@ pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
|
|||
pub return_count: usize,
|
||||
pub statements: I,
|
||||
#[serde(borrow)]
|
||||
pub solvers: SolverMap<'ast, T>,
|
||||
pub solvers: Vec<Solver<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||
|
@ -140,7 +140,7 @@ impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T,
|
|||
arguments: Vec<Parameter>,
|
||||
statements: I,
|
||||
return_count: usize,
|
||||
solvers: SolverMap<'ast, T>,
|
||||
solvers: Vec<Solver<'ast, T>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
arguments,
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
use crate::ir::check::UnconstrainedVariableDetector;
|
||||
use crate::{
|
||||
ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer},
|
||||
Solver,
|
||||
};
|
||||
|
||||
use super::{ProgIterator, Statement};
|
||||
use serde::Deserialize;
|
||||
use serde_cbor::{self, StreamDeserializer};
|
||||
use std::io::{Read, Write};
|
||||
use std::io::{Read, Seek, Write};
|
||||
use zokrates_field::*;
|
||||
|
||||
type DynamicError = Box<dyn std::error::Error>;
|
||||
|
||||
const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0];
|
||||
const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2];
|
||||
const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3];
|
||||
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub enum ProgEnum<
|
||||
|
@ -61,17 +65,21 @@ impl<
|
|||
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||
/// serialize a program iterator, returning the number of constraints serialized
|
||||
/// Note that we only return constraints, not other statements such as directives
|
||||
pub fn serialize<W: Write>(self, mut w: W) -> Result<usize, DynamicError> {
|
||||
pub fn serialize<W: Write + Seek>(self, mut w: W) -> Result<usize, DynamicError> {
|
||||
use super::folder::Folder;
|
||||
|
||||
w.write_all(ZOKRATES_MAGIC)?;
|
||||
w.write_all(ZOKRATES_VERSION_2)?;
|
||||
w.write_all(ZOKRATES_VERSION_3)?;
|
||||
w.write_all(&T::id())?;
|
||||
|
||||
let solver_list_ptr_offset = w.stream_position()?;
|
||||
w.write_all(&[0u8; std::mem::size_of::<u64>()])?; // reserve 8 bytes
|
||||
|
||||
serde_cbor::to_writer(&mut w, &self.arguments)?;
|
||||
serde_cbor::to_writer(&mut w, &self.return_count)?;
|
||||
|
||||
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
|
||||
let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default();
|
||||
|
||||
let statements = self.statements.into_iter();
|
||||
|
||||
|
@ -80,12 +88,22 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
|
|||
if matches!(s, Statement::Constraint(..)) {
|
||||
count += 1;
|
||||
}
|
||||
let s = unconstrained_variable_detector.fold_statement(s);
|
||||
let s: Vec<Statement<T>> = solver_indexer
|
||||
.fold_statement(s)
|
||||
.into_iter()
|
||||
.flat_map(|s| unconstrained_variable_detector.fold_statement(s))
|
||||
.collect();
|
||||
for s in s {
|
||||
serde_cbor::to_writer(&mut w, &s)?;
|
||||
}
|
||||
}
|
||||
|
||||
let solver_list_offset = w.stream_position()?;
|
||||
serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?;
|
||||
|
||||
w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?;
|
||||
w.write_all(&solver_list_offset.to_le_bytes())?;
|
||||
|
||||
unconstrained_variable_detector
|
||||
.finalize()
|
||||
.map(|_| count)
|
||||
|
@ -103,11 +121,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator
|
|||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<T> {
|
||||
self.s.next().transpose().unwrap()
|
||||
self.s.next().and_then(|v| v.ok())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, R: Read>
|
||||
impl<'de, R: Read + Seek>
|
||||
ProgEnum<
|
||||
'de,
|
||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_381Field>>,
|
||||
|
@ -128,104 +146,108 @@ impl<'de, R: Read>
|
|||
r.read_exact(&mut version)
|
||||
.map_err(|_| String::from("Cannot read version"))?;
|
||||
|
||||
if &version == ZOKRATES_VERSION_2 {
|
||||
if &version == ZOKRATES_VERSION_3 {
|
||||
// Check the curve identifier, deserializing accordingly
|
||||
let mut curve = [0; 4];
|
||||
r.read_exact(&mut curve)
|
||||
.map_err(|_| String::from("Cannot read curve identifier"))?;
|
||||
|
||||
use serde::de::Deserializer;
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r);
|
||||
let mut buffer = [0u8; std::mem::size_of::<u64>()];
|
||||
r.read_exact(&mut buffer)
|
||||
.map_err(|_| String::from("Cannot read solver list pointer"))?;
|
||||
|
||||
struct ArgumentsVisitor;
|
||||
let solver_list_offset = u64::from_le_bytes(buffer);
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for ArgumentsVisitor {
|
||||
type Value = Vec<super::Parameter>;
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("seq of flat param")
|
||||
}
|
||||
let (arguments, return_count) = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let mut res = vec![];
|
||||
while let Some(e) = seq.next_element().unwrap() {
|
||||
res.push(e);
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
let arguments: Vec<super::Parameter> = Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read parameters"))?;
|
||||
|
||||
let arguments = p.deserialize_seq(ArgumentsVisitor).unwrap();
|
||||
let return_count = usize::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read return count"))?;
|
||||
|
||||
struct ReturnCountVisitor;
|
||||
(arguments, return_count)
|
||||
};
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for ReturnCountVisitor {
|
||||
type Value = usize;
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("usize")
|
||||
}
|
||||
|
||||
fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(v as usize)
|
||||
}
|
||||
|
||||
fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(v as usize)
|
||||
}
|
||||
|
||||
fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(v as usize)
|
||||
}
|
||||
}
|
||||
|
||||
let return_count = p.deserialize_u32(ReturnCountVisitor).unwrap();
|
||||
let statement_offset = r.stream_position().unwrap();
|
||||
r.seek(std::io::SeekFrom::Start(solver_list_offset))
|
||||
.unwrap();
|
||||
|
||||
match curve {
|
||||
m if m == Bls12_381Field::id() => {
|
||||
let s = p.into_iter::<Statement<Bls12_381Field>>();
|
||||
let solvers: Vec<Solver<'de, Bls12_381Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read solver map"))?
|
||||
};
|
||||
|
||||
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bls12_381Field>>();
|
||||
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
return_count,
|
||||
solvers,
|
||||
)))
|
||||
}
|
||||
m if m == Bn128Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bn128Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read solver map"))?
|
||||
};
|
||||
|
||||
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bn128Field>>();
|
||||
|
||||
Ok(ProgEnum::Bn128Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
return_count,
|
||||
solvers,
|
||||
)))
|
||||
}
|
||||
m if m == Bls12_377Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bls12_377Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read solver map"))?
|
||||
};
|
||||
|
||||
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bls12_377Field>>();
|
||||
|
||||
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
return_count,
|
||||
solvers,
|
||||
)))
|
||||
}
|
||||
m if m == Bw6_761Field::id() => {
|
||||
let solvers: Vec<Solver<'de, Bw6_761Field>> = {
|
||||
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||
Vec::deserialize(&mut p)
|
||||
.map_err(|_| String::from("Cannot read solver map"))?
|
||||
};
|
||||
|
||||
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||
|
||||
let p = serde_cbor::Deserializer::from_reader(r);
|
||||
let s = p.into_iter::<Statement<Bw6_761Field>>();
|
||||
|
||||
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
|
||||
arguments,
|
||||
UnwrappedStreamDeserializer { s },
|
||||
return_count,
|
||||
solvers,
|
||||
)))
|
||||
}
|
||||
_ => Err(String::from("Unknown curve identifier")),
|
||||
|
|
50
zokrates_ast/src/ir/solver_indexer.rs
Normal file
50
zokrates_ast/src/ir/solver_indexer.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use crate::ir::folder::Folder;
|
||||
use crate::ir::Directive;
|
||||
use crate::ir::Solver;
|
||||
use crate::zir::ZirFunction;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type Hash = u64;
|
||||
|
||||
fn hash<T: Field>(f: &ZirFunction<T>) -> Hash {
|
||||
use std::hash::Hash;
|
||||
use std::hash::Hasher;
|
||||
let mut hasher = DefaultHasher::new();
|
||||
f.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SolverIndexer<'ast, T> {
|
||||
pub solvers: Vec<Solver<'ast, T>>,
|
||||
pub index_map: HashMap<u64, usize>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> {
|
||||
fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> {
|
||||
match d.solver {
|
||||
Solver::Zir(f) => {
|
||||
let argc = f.arguments.len();
|
||||
let h = hash(&f);
|
||||
let index = match self.index_map.entry(h) {
|
||||
Entry::Occupied(v) => *v.get(),
|
||||
Entry::Vacant(entry) => {
|
||||
let index = self.solvers.len();
|
||||
entry.insert(index);
|
||||
self.solvers.push(Solver::Zir(f));
|
||||
index
|
||||
}
|
||||
};
|
||||
Directive {
|
||||
inputs: d.inputs,
|
||||
outputs: d.outputs,
|
||||
solver: Solver::IndexedCall(index, argc),
|
||||
}
|
||||
}
|
||||
_ => d,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,13 +4,14 @@ use std::fmt;
|
|||
|
||||
use crate::typed::Identifier as CoreIdentifier;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum Identifier<'ast> {
|
||||
#[serde(borrow)]
|
||||
Source(SourceIdentifier<'ast>),
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum SourceIdentifier<'ast> {
|
||||
#[serde(borrow)]
|
||||
Basic(CoreIdentifier<'ast>),
|
||||
|
@ -34,10 +35,17 @@ impl<'ast> fmt::Display for Identifier<'ast> {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Identifier::Source(s) => write!(f, "{}", s),
|
||||
Identifier::Internal(s) => write!(f, "{}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<String> for Identifier<'ast> {
|
||||
fn from(id: String) -> Identifier<'ast> {
|
||||
Identifier::Internal(id)
|
||||
}
|
||||
}
|
||||
|
||||
// this is only used in tests but somehow cfg(test) does not work
|
||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||
|
|
|
@ -4,6 +4,7 @@ mod identifier;
|
|||
pub mod lqc;
|
||||
mod parameter;
|
||||
pub mod result_folder;
|
||||
pub mod substitution;
|
||||
pub mod types;
|
||||
mod uint;
|
||||
mod variable;
|
||||
|
|
26
zokrates_ast/src/zir/substitution.rs
Normal file
26
zokrates_ast/src/zir/substitution.rs
Normal file
|
@ -0,0 +1,26 @@
|
|||
use super::{Folder, Identifier};
|
||||
use std::{collections::HashMap, marker::PhantomData};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ZirSubstitutor<'a, 'ast, T> {
|
||||
substitution: &'a HashMap<Identifier<'ast>, Identifier<'ast>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'a, 'ast, T: Field> ZirSubstitutor<'a, 'ast, T> {
|
||||
pub fn new(substitution: &'a HashMap<Identifier<'ast>, Identifier<'ast>>) -> Self {
|
||||
ZirSubstitutor {
|
||||
substitution,
|
||||
_phantom: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'a, 'ast, T> {
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
match self.substitution.get(&n) {
|
||||
Some(v) => v.clone(),
|
||||
None => n,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,8 +11,8 @@ mod utils;
|
|||
|
||||
use self::utils::flat_expression_from_bits;
|
||||
use zokrates_ast::zir::{
|
||||
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement,
|
||||
ZirExpressionList,
|
||||
substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce,
|
||||
UMetadata, ZirAssemblyStatement, ZirExpressionList,
|
||||
};
|
||||
use zokrates_interpreter::Interpreter;
|
||||
|
||||
|
@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*;
|
|||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::common::{RuntimeError, Variable};
|
||||
use zokrates_ast::flat::*;
|
||||
use zokrates_ast::ir::{Solver, ZirSolver};
|
||||
use zokrates_ast::ir::Solver;
|
||||
use zokrates_ast::zir::types::{Type, UBitwidth};
|
||||
use zokrates_ast::zir::{
|
||||
BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter,
|
||||
|
@ -1885,7 +1885,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// constants do not require directives
|
||||
if let Some(FlatExpression::Number(ref x)) = e.field {
|
||||
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
|
||||
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()], &[])
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(FlatExpression::Number)
|
||||
|
@ -2237,12 +2237,34 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.cloned()
|
||||
.map(|p| self.layout.get(&p.id.id).cloned().unwrap().into())
|
||||
.collect();
|
||||
|
||||
let outputs: Vec<Variable> = assignees
|
||||
.into_iter()
|
||||
.map(|assignee| self.use_variable(&assignee))
|
||||
.collect();
|
||||
|
||||
let solver = Solver::Zir(ZirSolver::Function(function));
|
||||
let mut substitution_map = HashMap::default();
|
||||
for (index, p) in function.arguments.iter().enumerate() {
|
||||
let new_id = format!("i{}", index).into();
|
||||
substitution_map.insert(p.id.id.clone(), new_id);
|
||||
}
|
||||
|
||||
let mut substitutor = ZirSubstitutor::new(&substitution_map);
|
||||
let function = ZirFunction {
|
||||
arguments: function
|
||||
.arguments
|
||||
.into_iter()
|
||||
.map(|p| substitutor.fold_parameter(p))
|
||||
.collect(),
|
||||
statements: function
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| substitutor.fold_statement(s))
|
||||
.collect(),
|
||||
signature: function.signature,
|
||||
};
|
||||
|
||||
let solver = Solver::Zir(function);
|
||||
let directive = FlatDirective::new(outputs, solver, inputs);
|
||||
|
||||
statements_flattened.push_back(FlatStatement::Directive(directive));
|
||||
|
|
|
@ -39,14 +39,18 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer {
|
|||
}
|
||||
|
||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||
let hashed = hash(&s);
|
||||
let result = match self.seen.get(&hashed) {
|
||||
Some(_) => vec![],
|
||||
None => vec![s],
|
||||
};
|
||||
|
||||
self.seen.insert(hashed);
|
||||
result
|
||||
match s {
|
||||
Statement::Block(s) => s.into_iter().flat_map(|s| self.fold_statement(s)).collect(),
|
||||
s => {
|
||||
let hashed = hash(&s);
|
||||
let result = match self.seen.get(&hashed) {
|
||||
Some(_) => vec![],
|
||||
None => vec![s],
|
||||
};
|
||||
self.seen.insert(hashed);
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ pub fn optimize<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
|
|||
.flat_map(move |s| directive_optimizer.fold_statement(s))
|
||||
.flat_map(move |s| duplicate_optimizer.fold_statement(s)),
|
||||
return_count: p.return_count,
|
||||
solvers: p.solvers,
|
||||
};
|
||||
|
||||
log::debug!("Done");
|
||||
|
|
|
@ -146,7 +146,7 @@ impl<T: Field> RedefinitionOptimizer<T> {
|
|||
// unwrap inputs to their constant value
|
||||
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||
// run the solver
|
||||
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
|
||||
let outputs = Interpreter::execute_solver(&d.solver, &inputs, &[]).unwrap();
|
||||
assert_eq!(outputs.len(), d.outputs.len());
|
||||
|
||||
// insert the results in the substitution
|
||||
|
|
|
@ -83,7 +83,7 @@ impl Interpreter {
|
|||
inputs.pop().unwrap(),
|
||||
))
|
||||
}
|
||||
_ => Self::execute_solver(&d.solver, &inputs),
|
||||
_ => Self::execute_solver(&d.solver, &inputs, &program.solvers),
|
||||
}
|
||||
.map_err(Error::Solver)?;
|
||||
|
||||
|
@ -164,7 +164,15 @@ impl Interpreter {
|
|||
pub fn execute_solver<'ast, T: Field>(
|
||||
solver: &Solver<'ast, T>,
|
||||
inputs: &[T],
|
||||
solvers: &[Solver<'ast, T>],
|
||||
) -> Result<Vec<T>, String> {
|
||||
let solver = match solver {
|
||||
Solver::IndexedCall(index, _) => solvers
|
||||
.get(*index)
|
||||
.ok_or_else(|| format!("Could not resolve solver at index {}", index))?,
|
||||
s => s,
|
||||
};
|
||||
|
||||
let (expected_input_count, expected_output_count) = solver.get_signature();
|
||||
assert_eq!(inputs.len(), expected_input_count);
|
||||
|
||||
|
@ -334,6 +342,7 @@ impl Interpreter {
|
|||
&inputs[*n + 8usize..],
|
||||
)
|
||||
}
|
||||
_ => unreachable!("unexpected solver"),
|
||||
};
|
||||
|
||||
assert_eq!(res.len(), expected_output_count);
|
||||
|
|
Loading…
Reference in a new issue