optimize zir solvers by indexing
This commit is contained in:
parent
e12a4b46d3
commit
4988911183
15 changed files with 235 additions and 119 deletions
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::BTreeMap;
|
||||||
use std::convert::{TryFrom, TryInto};
|
use std::convert::{TryFrom, TryInto};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
|
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
|
||||||
|
@ -481,7 +481,7 @@ impl<'ast, T: Field> Flattener<T> {
|
||||||
// }
|
// }
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct ArgumentFinder<'ast, T> {
|
pub struct ArgumentFinder<'ast, T> {
|
||||||
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
|
pub identifiers: BTreeMap<zir::Identifier<'ast>, zir::Type>,
|
||||||
_phantom: PhantomData<T>,
|
_phantom: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed;
|
||||||
pub use self::error::RuntimeError;
|
pub use self::error::RuntimeError;
|
||||||
pub use self::metadata::SourceMetadata;
|
pub use self::metadata::SourceMetadata;
|
||||||
pub use self::parameter::Parameter;
|
pub use self::parameter::Parameter;
|
||||||
pub use self::solvers::{Solver, ZirSolver};
|
pub use self::solvers::Solver;
|
||||||
pub use self::variable::Variable;
|
pub use self::variable::Variable;
|
||||||
pub use format_string::FormatString;
|
pub use format_string::FormatString;
|
||||||
|
|
|
@ -2,22 +2,6 @@ use crate::zir::ZirFunction;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
|
||||||
pub enum ZirSolver<'ast, T> {
|
|
||||||
#[serde(borrow)]
|
|
||||||
Function(ZirFunction<'ast, T>),
|
|
||||||
Indexed(usize, usize),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ast, T> fmt::Display for ZirSolver<'ast, T> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
ZirSolver::Function(_) => write!(f, "Zir(..)"),
|
|
||||||
ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
||||||
pub enum Solver<'ast, T> {
|
pub enum Solver<'ast, T> {
|
||||||
ConditionEq,
|
ConditionEq,
|
||||||
|
@ -29,22 +13,14 @@ pub enum Solver<'ast, T> {
|
||||||
ShaCh,
|
ShaCh,
|
||||||
EuclideanDiv,
|
EuclideanDiv,
|
||||||
#[serde(borrow)]
|
#[serde(borrow)]
|
||||||
Zir(ZirSolver<'ast, T>),
|
Zir(ZirFunction<'ast, T>),
|
||||||
|
IndexedCall(usize, usize),
|
||||||
#[cfg(feature = "bellman")]
|
#[cfg(feature = "bellman")]
|
||||||
Sha256Round,
|
Sha256Round,
|
||||||
#[cfg(feature = "ark")]
|
#[cfg(feature = "ark")]
|
||||||
SnarkVerifyBls12377(usize),
|
SnarkVerifyBls12377(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T> Solver<'ast, T> {
|
|
||||||
pub fn zir_function(function: ZirFunction<'ast, T>) -> Self {
|
|
||||||
Solver::Zir(ZirSolver::Function(function))
|
|
||||||
}
|
|
||||||
pub fn zir_indexed(index: usize, argument_count: usize) -> Self {
|
|
||||||
Solver::Zir(ZirSolver::Indexed(index, argument_count))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ast, T> fmt::Display for Solver<'ast, T> {
|
impl<'ast, T> fmt::Display for Solver<'ast, T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
@ -56,7 +32,8 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> {
|
||||||
Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"),
|
Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"),
|
||||||
Solver::ShaCh => write!(f, "ShaCh"),
|
Solver::ShaCh => write!(f, "ShaCh"),
|
||||||
Solver::EuclideanDiv => write!(f, "EuclideanDiv"),
|
Solver::EuclideanDiv => write!(f, "EuclideanDiv"),
|
||||||
Solver::Zir(s) => write!(f, "{}", s),
|
Solver::Zir(_) => write!(f, "Zir(..)"),
|
||||||
|
Solver::IndexedCall(index, argc) => write!(f, "IndexedCall@{}({})", index, argc),
|
||||||
#[cfg(feature = "bellman")]
|
#[cfg(feature = "bellman")]
|
||||||
Solver::Sha256Round => write!(f, "Sha256Round"),
|
Solver::Sha256Round => write!(f, "Sha256Round"),
|
||||||
#[cfg(feature = "ark")]
|
#[cfg(feature = "ark")]
|
||||||
|
@ -76,10 +53,8 @@ impl<'ast, T> Solver<'ast, T> {
|
||||||
Solver::ShaAndXorAndXorAnd => (3, 1),
|
Solver::ShaAndXorAndXorAnd => (3, 1),
|
||||||
Solver::ShaCh => (3, 1),
|
Solver::ShaCh => (3, 1),
|
||||||
Solver::EuclideanDiv => (2, 2),
|
Solver::EuclideanDiv => (2, 2),
|
||||||
Solver::Zir(s) => match s {
|
Solver::Zir(f) => (f.arguments.len(), 1),
|
||||||
ZirSolver::Function(f) => (f.arguments.len(), 1),
|
Solver::IndexedCall(_, n) => (*n, 1),
|
||||||
ZirSolver::Indexed(_, n) => (*n, 1),
|
|
||||||
},
|
|
||||||
#[cfg(feature = "bellman")]
|
#[cfg(feature = "bellman")]
|
||||||
Solver::Sha256Round => (768, 26935),
|
Solver::Sha256Round => (768, 26935),
|
||||||
#[cfg(feature = "ark")]
|
#[cfg(feature = "ark")]
|
||||||
|
|
|
@ -2,8 +2,6 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement
|
||||||
use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement};
|
use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
use super::SolverMap;
|
|
||||||
|
|
||||||
impl<T: Field> QuadComb<T> {
|
impl<T: Field> QuadComb<T> {
|
||||||
fn from_flat_expression<U: Into<FlatExpression<T>>>(flat_expression: U) -> QuadComb<T> {
|
fn from_flat_expression<U: Into<FlatExpression<T>>>(flat_expression: U) -> QuadComb<T> {
|
||||||
let flat_expression = flat_expression.into();
|
let flat_expression = flat_expression.into();
|
||||||
|
@ -26,7 +24,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator<Item = FlatStatement<'ast, T>>>
|
||||||
statements: flat_prog_iterator.statements.into_iter().map(Into::into),
|
statements: flat_prog_iterator.statements.into_iter().map(Into::into),
|
||||||
arguments: flat_prog_iterator.arguments,
|
arguments: flat_prog_iterator.arguments,
|
||||||
return_count: flat_prog_iterator.return_count,
|
return_count: flat_prog_iterator.return_count,
|
||||||
solvers: SolverMap::default(),
|
solvers: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ use crate::common::FormatString;
|
||||||
use crate::typed::ConcreteType;
|
use crate::typed::ConcreteType;
|
||||||
use derivative::Derivative;
|
use derivative::Derivative;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{BTreeSet, HashMap};
|
use std::collections::BTreeSet;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
@ -14,6 +14,7 @@ pub mod folder;
|
||||||
pub mod from_flat;
|
pub mod from_flat;
|
||||||
mod serialize;
|
mod serialize;
|
||||||
pub mod smtlib2;
|
pub mod smtlib2;
|
||||||
|
mod solver_indexer;
|
||||||
pub mod visitor;
|
pub mod visitor;
|
||||||
mod witness;
|
mod witness;
|
||||||
|
|
||||||
|
@ -22,8 +23,8 @@ pub use self::expression::{CanonicalLinComb, LinComb};
|
||||||
pub use self::serialize::ProgEnum;
|
pub use self::serialize::ProgEnum;
|
||||||
pub use crate::common::Parameter;
|
pub use crate::common::Parameter;
|
||||||
pub use crate::common::RuntimeError;
|
pub use crate::common::RuntimeError;
|
||||||
|
pub use crate::common::Solver;
|
||||||
pub use crate::common::Variable;
|
pub use crate::common::Variable;
|
||||||
pub use crate::common::{Solver, ZirSolver};
|
|
||||||
|
|
||||||
pub use self::witness::Witness;
|
pub use self::witness::Witness;
|
||||||
|
|
||||||
|
@ -124,7 +125,6 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec<Statement<'ast, T>>>;
|
pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec<Statement<'ast, T>>>;
|
||||||
pub type SolverMap<'ast, T> = HashMap<u64, Solver<'ast, T>>;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
|
||||||
pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
|
pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
|
||||||
|
@ -132,7 +132,7 @@ pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
|
||||||
pub return_count: usize,
|
pub return_count: usize,
|
||||||
pub statements: I,
|
pub statements: I,
|
||||||
#[serde(borrow)]
|
#[serde(borrow)]
|
||||||
pub solvers: SolverMap<'ast, T>,
|
pub solvers: Vec<Solver<'ast, T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||||
|
@ -140,7 +140,7 @@ impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T,
|
||||||
arguments: Vec<Parameter>,
|
arguments: Vec<Parameter>,
|
||||||
statements: I,
|
statements: I,
|
||||||
return_count: usize,
|
return_count: usize,
|
||||||
solvers: SolverMap<'ast, T>,
|
solvers: Vec<Solver<'ast, T>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
arguments,
|
arguments,
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
use crate::ir::check::UnconstrainedVariableDetector;
|
use crate::{
|
||||||
|
ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer},
|
||||||
|
Solver,
|
||||||
|
};
|
||||||
|
|
||||||
use super::{ProgIterator, Statement};
|
use super::{ProgIterator, Statement};
|
||||||
|
use serde::Deserialize;
|
||||||
use serde_cbor::{self, StreamDeserializer};
|
use serde_cbor::{self, StreamDeserializer};
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Seek, Write};
|
||||||
use zokrates_field::*;
|
use zokrates_field::*;
|
||||||
|
|
||||||
type DynamicError = Box<dyn std::error::Error>;
|
type DynamicError = Box<dyn std::error::Error>;
|
||||||
|
|
||||||
const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0];
|
const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0];
|
||||||
const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2];
|
const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3];
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Debug)]
|
#[derive(PartialEq, Eq, Debug)]
|
||||||
pub enum ProgEnum<
|
pub enum ProgEnum<
|
||||||
|
@ -61,17 +65,21 @@ impl<
|
||||||
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
|
||||||
/// serialize a program iterator, returning the number of constraints serialized
|
/// serialize a program iterator, returning the number of constraints serialized
|
||||||
/// Note that we only return constraints, not other statements such as directives
|
/// Note that we only return constraints, not other statements such as directives
|
||||||
pub fn serialize<W: Write>(self, mut w: W) -> Result<usize, DynamicError> {
|
pub fn serialize<W: Write + Seek>(self, mut w: W) -> Result<usize, DynamicError> {
|
||||||
use super::folder::Folder;
|
use super::folder::Folder;
|
||||||
|
|
||||||
w.write_all(ZOKRATES_MAGIC)?;
|
w.write_all(ZOKRATES_MAGIC)?;
|
||||||
w.write_all(ZOKRATES_VERSION_2)?;
|
w.write_all(ZOKRATES_VERSION_3)?;
|
||||||
w.write_all(&T::id())?;
|
w.write_all(&T::id())?;
|
||||||
|
|
||||||
|
let solver_list_ptr_offset = w.stream_position()?;
|
||||||
|
w.write_all(&[0u8; std::mem::size_of::<u64>()])?; // reserve 8 bytes
|
||||||
|
|
||||||
serde_cbor::to_writer(&mut w, &self.arguments)?;
|
serde_cbor::to_writer(&mut w, &self.arguments)?;
|
||||||
serde_cbor::to_writer(&mut w, &self.return_count)?;
|
serde_cbor::to_writer(&mut w, &self.return_count)?;
|
||||||
|
|
||||||
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
|
let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self);
|
||||||
|
let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default();
|
||||||
|
|
||||||
let statements = self.statements.into_iter();
|
let statements = self.statements.into_iter();
|
||||||
|
|
||||||
|
@ -80,12 +88,22 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
|
||||||
if matches!(s, Statement::Constraint(..)) {
|
if matches!(s, Statement::Constraint(..)) {
|
||||||
count += 1;
|
count += 1;
|
||||||
}
|
}
|
||||||
let s = unconstrained_variable_detector.fold_statement(s);
|
let s: Vec<Statement<T>> = solver_indexer
|
||||||
|
.fold_statement(s)
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|s| unconstrained_variable_detector.fold_statement(s))
|
||||||
|
.collect();
|
||||||
for s in s {
|
for s in s {
|
||||||
serde_cbor::to_writer(&mut w, &s)?;
|
serde_cbor::to_writer(&mut w, &s)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let solver_list_offset = w.stream_position()?;
|
||||||
|
serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?;
|
||||||
|
|
||||||
|
w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?;
|
||||||
|
w.write_all(&solver_list_offset.to_le_bytes())?;
|
||||||
|
|
||||||
unconstrained_variable_detector
|
unconstrained_variable_detector
|
||||||
.finalize()
|
.finalize()
|
||||||
.map(|_| count)
|
.map(|_| count)
|
||||||
|
@ -103,11 +121,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator
|
||||||
type Item = T;
|
type Item = T;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<T> {
|
fn next(&mut self) -> Option<T> {
|
||||||
self.s.next().transpose().unwrap()
|
self.s.next().and_then(|v| v.ok())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'de, R: Read>
|
impl<'de, R: Read + Seek>
|
||||||
ProgEnum<
|
ProgEnum<
|
||||||
'de,
|
'de,
|
||||||
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_381Field>>,
|
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_381Field>>,
|
||||||
|
@ -128,104 +146,108 @@ impl<'de, R: Read>
|
||||||
r.read_exact(&mut version)
|
r.read_exact(&mut version)
|
||||||
.map_err(|_| String::from("Cannot read version"))?;
|
.map_err(|_| String::from("Cannot read version"))?;
|
||||||
|
|
||||||
if &version == ZOKRATES_VERSION_2 {
|
if &version == ZOKRATES_VERSION_3 {
|
||||||
// Check the curve identifier, deserializing accordingly
|
// Check the curve identifier, deserializing accordingly
|
||||||
let mut curve = [0; 4];
|
let mut curve = [0; 4];
|
||||||
r.read_exact(&mut curve)
|
r.read_exact(&mut curve)
|
||||||
.map_err(|_| String::from("Cannot read curve identifier"))?;
|
.map_err(|_| String::from("Cannot read curve identifier"))?;
|
||||||
|
|
||||||
use serde::de::Deserializer;
|
let mut buffer = [0u8; std::mem::size_of::<u64>()];
|
||||||
let mut p = serde_cbor::Deserializer::from_reader(r);
|
r.read_exact(&mut buffer)
|
||||||
|
.map_err(|_| String::from("Cannot read solver list pointer"))?;
|
||||||
|
|
||||||
struct ArgumentsVisitor;
|
let solver_list_offset = u64::from_le_bytes(buffer);
|
||||||
|
|
||||||
impl<'de> serde::de::Visitor<'de> for ArgumentsVisitor {
|
let (arguments, return_count) = {
|
||||||
type Value = Vec<super::Parameter>;
|
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
formatter.write_str("seq of flat param")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
let arguments: Vec<super::Parameter> = Vec::deserialize(&mut p)
|
||||||
where
|
.map_err(|_| String::from("Cannot read parameters"))?;
|
||||||
A: serde::de::SeqAccess<'de>,
|
|
||||||
{
|
|
||||||
let mut res = vec![];
|
|
||||||
while let Some(e) = seq.next_element().unwrap() {
|
|
||||||
res.push(e);
|
|
||||||
}
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let arguments = p.deserialize_seq(ArgumentsVisitor).unwrap();
|
let return_count = usize::deserialize(&mut p)
|
||||||
|
.map_err(|_| String::from("Cannot read return count"))?;
|
||||||
|
|
||||||
struct ReturnCountVisitor;
|
(arguments, return_count)
|
||||||
|
};
|
||||||
|
|
||||||
impl<'de> serde::de::Visitor<'de> for ReturnCountVisitor {
|
let statement_offset = r.stream_position().unwrap();
|
||||||
type Value = usize;
|
r.seek(std::io::SeekFrom::Start(solver_list_offset))
|
||||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
.unwrap();
|
||||||
formatter.write_str("usize")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
Ok(v as usize)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
Ok(v as usize)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
Ok(v as usize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let return_count = p.deserialize_u32(ReturnCountVisitor).unwrap();
|
|
||||||
|
|
||||||
match curve {
|
match curve {
|
||||||
m if m == Bls12_381Field::id() => {
|
m if m == Bls12_381Field::id() => {
|
||||||
let s = p.into_iter::<Statement<Bls12_381Field>>();
|
let solvers: Vec<Solver<'de, Bls12_381Field>> = {
|
||||||
|
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||||
|
Vec::deserialize(&mut p)
|
||||||
|
.map_err(|_| String::from("Cannot read solver map"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||||
|
|
||||||
|
let p = serde_cbor::Deserializer::from_reader(r);
|
||||||
|
let s = p.into_iter::<Statement<Bls12_381Field>>();
|
||||||
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
|
Ok(ProgEnum::Bls12_381Program(ProgIterator::new(
|
||||||
arguments,
|
arguments,
|
||||||
UnwrappedStreamDeserializer { s },
|
UnwrappedStreamDeserializer { s },
|
||||||
return_count,
|
return_count,
|
||||||
|
solvers,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
m if m == Bn128Field::id() => {
|
m if m == Bn128Field::id() => {
|
||||||
|
let solvers: Vec<Solver<'de, Bn128Field>> = {
|
||||||
|
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||||
|
Vec::deserialize(&mut p)
|
||||||
|
.map_err(|_| String::from("Cannot read solver map"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||||
|
|
||||||
|
let p = serde_cbor::Deserializer::from_reader(r);
|
||||||
let s = p.into_iter::<Statement<Bn128Field>>();
|
let s = p.into_iter::<Statement<Bn128Field>>();
|
||||||
|
|
||||||
Ok(ProgEnum::Bn128Program(ProgIterator::new(
|
Ok(ProgEnum::Bn128Program(ProgIterator::new(
|
||||||
arguments,
|
arguments,
|
||||||
UnwrappedStreamDeserializer { s },
|
UnwrappedStreamDeserializer { s },
|
||||||
return_count,
|
return_count,
|
||||||
|
solvers,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
m if m == Bls12_377Field::id() => {
|
m if m == Bls12_377Field::id() => {
|
||||||
|
let solvers: Vec<Solver<'de, Bls12_377Field>> = {
|
||||||
|
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||||
|
Vec::deserialize(&mut p)
|
||||||
|
.map_err(|_| String::from("Cannot read solver map"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||||
|
|
||||||
|
let p = serde_cbor::Deserializer::from_reader(r);
|
||||||
let s = p.into_iter::<Statement<Bls12_377Field>>();
|
let s = p.into_iter::<Statement<Bls12_377Field>>();
|
||||||
|
|
||||||
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
|
Ok(ProgEnum::Bls12_377Program(ProgIterator::new(
|
||||||
arguments,
|
arguments,
|
||||||
UnwrappedStreamDeserializer { s },
|
UnwrappedStreamDeserializer { s },
|
||||||
return_count,
|
return_count,
|
||||||
|
solvers,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
m if m == Bw6_761Field::id() => {
|
m if m == Bw6_761Field::id() => {
|
||||||
|
let solvers: Vec<Solver<'de, Bw6_761Field>> = {
|
||||||
|
let mut p = serde_cbor::Deserializer::from_reader(r.by_ref());
|
||||||
|
Vec::deserialize(&mut p)
|
||||||
|
.map_err(|_| String::from("Cannot read solver map"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap();
|
||||||
|
|
||||||
|
let p = serde_cbor::Deserializer::from_reader(r);
|
||||||
let s = p.into_iter::<Statement<Bw6_761Field>>();
|
let s = p.into_iter::<Statement<Bw6_761Field>>();
|
||||||
|
|
||||||
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
|
Ok(ProgEnum::Bw6_761Program(ProgIterator::new(
|
||||||
arguments,
|
arguments,
|
||||||
UnwrappedStreamDeserializer { s },
|
UnwrappedStreamDeserializer { s },
|
||||||
return_count,
|
return_count,
|
||||||
|
solvers,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
_ => Err(String::from("Unknown curve identifier")),
|
_ => Err(String::from("Unknown curve identifier")),
|
||||||
|
|
50
zokrates_ast/src/ir/solver_indexer.rs
Normal file
50
zokrates_ast/src/ir/solver_indexer.rs
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
use crate::ir::folder::Folder;
|
||||||
|
use crate::ir::Directive;
|
||||||
|
use crate::ir::Solver;
|
||||||
|
use crate::zir::ZirFunction;
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::collections::hash_map::Entry;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use zokrates_field::Field;
|
||||||
|
|
||||||
|
type Hash = u64;
|
||||||
|
|
||||||
|
fn hash<T: Field>(f: &ZirFunction<T>) -> Hash {
|
||||||
|
use std::hash::Hash;
|
||||||
|
use std::hash::Hasher;
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
f.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct SolverIndexer<'ast, T> {
|
||||||
|
pub solvers: Vec<Solver<'ast, T>>,
|
||||||
|
pub index_map: HashMap<u64, usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> {
|
||||||
|
fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> {
|
||||||
|
match d.solver {
|
||||||
|
Solver::Zir(f) => {
|
||||||
|
let argc = f.arguments.len();
|
||||||
|
let h = hash(&f);
|
||||||
|
let index = match self.index_map.entry(h) {
|
||||||
|
Entry::Occupied(v) => *v.get(),
|
||||||
|
Entry::Vacant(entry) => {
|
||||||
|
let index = self.solvers.len();
|
||||||
|
entry.insert(index);
|
||||||
|
self.solvers.push(Solver::Zir(f));
|
||||||
|
index
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Directive {
|
||||||
|
inputs: d.inputs,
|
||||||
|
outputs: d.outputs,
|
||||||
|
solver: Solver::IndexedCall(index, argc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => d,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,13 +4,14 @@ use std::fmt;
|
||||||
|
|
||||||
use crate::typed::Identifier as CoreIdentifier;
|
use crate::typed::Identifier as CoreIdentifier;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||||
pub enum Identifier<'ast> {
|
pub enum Identifier<'ast> {
|
||||||
#[serde(borrow)]
|
#[serde(borrow)]
|
||||||
Source(SourceIdentifier<'ast>),
|
Source(SourceIdentifier<'ast>),
|
||||||
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||||
pub enum SourceIdentifier<'ast> {
|
pub enum SourceIdentifier<'ast> {
|
||||||
#[serde(borrow)]
|
#[serde(borrow)]
|
||||||
Basic(CoreIdentifier<'ast>),
|
Basic(CoreIdentifier<'ast>),
|
||||||
|
@ -34,10 +35,17 @@ impl<'ast> fmt::Display for Identifier<'ast> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Identifier::Source(s) => write!(f, "{}", s),
|
Identifier::Source(s) => write!(f, "{}", s),
|
||||||
|
Identifier::Internal(s) => write!(f, "{}", s),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ast> From<String> for Identifier<'ast> {
|
||||||
|
fn from(id: String) -> Identifier<'ast> {
|
||||||
|
Identifier::Internal(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// this is only used in tests but somehow cfg(test) does not work
|
// this is only used in tests but somehow cfg(test) does not work
|
||||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||||
|
|
|
@ -4,6 +4,7 @@ mod identifier;
|
||||||
pub mod lqc;
|
pub mod lqc;
|
||||||
mod parameter;
|
mod parameter;
|
||||||
pub mod result_folder;
|
pub mod result_folder;
|
||||||
|
pub mod substitution;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
mod uint;
|
mod uint;
|
||||||
mod variable;
|
mod variable;
|
||||||
|
|
26
zokrates_ast/src/zir/substitution.rs
Normal file
26
zokrates_ast/src/zir/substitution.rs
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
use super::{Folder, Identifier};
|
||||||
|
use std::{collections::HashMap, marker::PhantomData};
|
||||||
|
use zokrates_field::Field;
|
||||||
|
|
||||||
|
pub struct ZirSubstitutor<'a, 'ast, T> {
|
||||||
|
substitution: &'a HashMap<Identifier<'ast>, Identifier<'ast>>,
|
||||||
|
_phantom: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'ast, T: Field> ZirSubstitutor<'a, 'ast, T> {
|
||||||
|
pub fn new(substitution: &'a HashMap<Identifier<'ast>, Identifier<'ast>>) -> Self {
|
||||||
|
ZirSubstitutor {
|
||||||
|
substitution,
|
||||||
|
_phantom: PhantomData::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'a, 'ast, T> {
|
||||||
|
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||||
|
match self.substitution.get(&n) {
|
||||||
|
Some(v) => v.clone(),
|
||||||
|
None => n,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,8 +11,8 @@ mod utils;
|
||||||
|
|
||||||
use self::utils::flat_expression_from_bits;
|
use self::utils::flat_expression_from_bits;
|
||||||
use zokrates_ast::zir::{
|
use zokrates_ast::zir::{
|
||||||
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement,
|
substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce,
|
||||||
ZirExpressionList,
|
UMetadata, ZirAssemblyStatement, ZirExpressionList,
|
||||||
};
|
};
|
||||||
use zokrates_interpreter::Interpreter;
|
use zokrates_interpreter::Interpreter;
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*;
|
||||||
use zokrates_ast::common::FlatEmbed;
|
use zokrates_ast::common::FlatEmbed;
|
||||||
use zokrates_ast::common::{RuntimeError, Variable};
|
use zokrates_ast::common::{RuntimeError, Variable};
|
||||||
use zokrates_ast::flat::*;
|
use zokrates_ast::flat::*;
|
||||||
use zokrates_ast::ir::{Solver, ZirSolver};
|
use zokrates_ast::ir::Solver;
|
||||||
use zokrates_ast::zir::types::{Type, UBitwidth};
|
use zokrates_ast::zir::types::{Type, UBitwidth};
|
||||||
use zokrates_ast::zir::{
|
use zokrates_ast::zir::{
|
||||||
BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter,
|
BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter,
|
||||||
|
@ -1885,7 +1885,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
// constants do not require directives
|
// constants do not require directives
|
||||||
if let Some(FlatExpression::Number(ref x)) = e.field {
|
if let Some(FlatExpression::Number(ref x)) = e.field {
|
||||||
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
|
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()], &[])
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(FlatExpression::Number)
|
.map(FlatExpression::Number)
|
||||||
|
@ -2237,12 +2237,34 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
.cloned()
|
.cloned()
|
||||||
.map(|p| self.layout.get(&p.id.id).cloned().unwrap().into())
|
.map(|p| self.layout.get(&p.id.id).cloned().unwrap().into())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let outputs: Vec<Variable> = assignees
|
let outputs: Vec<Variable> = assignees
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|assignee| self.use_variable(&assignee))
|
.map(|assignee| self.use_variable(&assignee))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let solver = Solver::Zir(ZirSolver::Function(function));
|
let mut substitution_map = HashMap::default();
|
||||||
|
for (index, p) in function.arguments.iter().enumerate() {
|
||||||
|
let new_id = format!("i{}", index).into();
|
||||||
|
substitution_map.insert(p.id.id.clone(), new_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut substitutor = ZirSubstitutor::new(&substitution_map);
|
||||||
|
let function = ZirFunction {
|
||||||
|
arguments: function
|
||||||
|
.arguments
|
||||||
|
.into_iter()
|
||||||
|
.map(|p| substitutor.fold_parameter(p))
|
||||||
|
.collect(),
|
||||||
|
statements: function
|
||||||
|
.statements
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|s| substitutor.fold_statement(s))
|
||||||
|
.collect(),
|
||||||
|
signature: function.signature,
|
||||||
|
};
|
||||||
|
|
||||||
|
let solver = Solver::Zir(function);
|
||||||
let directive = FlatDirective::new(outputs, solver, inputs);
|
let directive = FlatDirective::new(outputs, solver, inputs);
|
||||||
|
|
||||||
statements_flattened.push_back(FlatStatement::Directive(directive));
|
statements_flattened.push_back(FlatStatement::Directive(directive));
|
||||||
|
|
|
@ -39,14 +39,18 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
|
||||||
let hashed = hash(&s);
|
match s {
|
||||||
let result = match self.seen.get(&hashed) {
|
Statement::Block(s) => s.into_iter().flat_map(|s| self.fold_statement(s)).collect(),
|
||||||
Some(_) => vec![],
|
s => {
|
||||||
None => vec![s],
|
let hashed = hash(&s);
|
||||||
};
|
let result = match self.seen.get(&hashed) {
|
||||||
|
Some(_) => vec![],
|
||||||
self.seen.insert(hashed);
|
None => vec![s],
|
||||||
result
|
};
|
||||||
|
self.seen.insert(hashed);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ pub fn optimize<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
|
||||||
.flat_map(move |s| directive_optimizer.fold_statement(s))
|
.flat_map(move |s| directive_optimizer.fold_statement(s))
|
||||||
.flat_map(move |s| duplicate_optimizer.fold_statement(s)),
|
.flat_map(move |s| duplicate_optimizer.fold_statement(s)),
|
||||||
return_count: p.return_count,
|
return_count: p.return_count,
|
||||||
|
solvers: p.solvers,
|
||||||
};
|
};
|
||||||
|
|
||||||
log::debug!("Done");
|
log::debug!("Done");
|
||||||
|
|
|
@ -146,7 +146,7 @@ impl<T: Field> RedefinitionOptimizer<T> {
|
||||||
// unwrap inputs to their constant value
|
// unwrap inputs to their constant value
|
||||||
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
|
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||||
// run the solver
|
// run the solver
|
||||||
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
|
let outputs = Interpreter::execute_solver(&d.solver, &inputs, &[]).unwrap();
|
||||||
assert_eq!(outputs.len(), d.outputs.len());
|
assert_eq!(outputs.len(), d.outputs.len());
|
||||||
|
|
||||||
// insert the results in the substitution
|
// insert the results in the substitution
|
||||||
|
|
|
@ -83,7 +83,7 @@ impl Interpreter {
|
||||||
inputs.pop().unwrap(),
|
inputs.pop().unwrap(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
_ => Self::execute_solver(&d.solver, &inputs),
|
_ => Self::execute_solver(&d.solver, &inputs, &program.solvers),
|
||||||
}
|
}
|
||||||
.map_err(Error::Solver)?;
|
.map_err(Error::Solver)?;
|
||||||
|
|
||||||
|
@ -164,7 +164,15 @@ impl Interpreter {
|
||||||
pub fn execute_solver<'ast, T: Field>(
|
pub fn execute_solver<'ast, T: Field>(
|
||||||
solver: &Solver<'ast, T>,
|
solver: &Solver<'ast, T>,
|
||||||
inputs: &[T],
|
inputs: &[T],
|
||||||
|
solvers: &[Solver<'ast, T>],
|
||||||
) -> Result<Vec<T>, String> {
|
) -> Result<Vec<T>, String> {
|
||||||
|
let solver = match solver {
|
||||||
|
Solver::IndexedCall(index, _) => solvers
|
||||||
|
.get(*index)
|
||||||
|
.ok_or_else(|| format!("Could not resolve solver at index {}", index))?,
|
||||||
|
s => s,
|
||||||
|
};
|
||||||
|
|
||||||
let (expected_input_count, expected_output_count) = solver.get_signature();
|
let (expected_input_count, expected_output_count) = solver.get_signature();
|
||||||
assert_eq!(inputs.len(), expected_input_count);
|
assert_eq!(inputs.len(), expected_input_count);
|
||||||
|
|
||||||
|
@ -334,6 +342,7 @@ impl Interpreter {
|
||||||
&inputs[*n + 8usize..],
|
&inputs[*n + 8usize..],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
_ => unreachable!("unexpected solver"),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(res.len(), expected_output_count);
|
assert_eq!(res.len(), expected_output_count);
|
||||||
|
|
Loading…
Reference in a new issue