From e12a4b46d3c1d3a0dc95249af5b32a82f7763367 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 25 Jan 2023 14:03:39 +0100 Subject: [PATCH] wip --- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 34 +++++++++++++++++++++++++++--- zokrates_ast/src/ir/clean.rs | 1 + zokrates_ast/src/ir/folder.rs | 1 + zokrates_ast/src/ir/from_flat.rs | 3 +++ zokrates_ast/src/ir/mod.rs | 17 ++++++++++++--- zokrates_codegen/src/lib.rs | 7 ++++-- 7 files changed, 56 insertions(+), 9 deletions(-) diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfd..e0337dec 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::solvers::{Solver, ZirSolver}; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 9b4f5c90..42318342 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,22 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub enum ZirSolver<'ast, T> { + #[serde(borrow)] + Function(ZirFunction<'ast, T>), + Indexed(usize, usize), +} + +impl<'ast, T> fmt::Display for ZirSolver<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ZirSolver::Function(_) => write!(f, "Zir(..)"), + ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n), + } + } +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -13,13 +29,22 @@ pub enum Solver<'ast, T> { ShaCh, EuclideanDiv, #[serde(borrow)] - Zir(ZirFunction<'ast, T>), + Zir(ZirSolver<'ast, T>), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } +impl<'ast, T> Solver<'ast, T> { + pub fn zir_function(function: ZirFunction<'ast, T>) -> Self { + Solver::Zir(ZirSolver::Function(function)) + } + pub fn zir_indexed(index: usize, argument_count: usize) -> Self { + Solver::Zir(ZirSolver::Indexed(index, argument_count)) + } +} + impl<'ast, T> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -31,7 +56,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"), Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), - Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::Zir(s) => write!(f, "{}", s), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -51,7 +76,10 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), - Solver::Zir(f) => (f.arguments.len(), 1), + Solver::Zir(s) => match s { + ZirSolver::Function(f) => (f.arguments.len(), 1), + ZirSolver::Indexed(_, n) => (*n, 1), + }, #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs index b4fb8f44..998e3298 100644 --- a/zokrates_ast/src/ir/clean.rs +++ b/zokrates_ast/src/ir/clean.rs @@ -14,6 +14,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a .statements .into_iter() .flat_map(|s| Cleaner::default().fold_statement(s)), + solvers: self.solvers, } } } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 6e67c15d..99466add 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -50,6 +50,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_statement(s)) .collect(), return_count: p.return_count, + solvers: p.solvers, } } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index fc961cd8..7932f74b 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -2,6 +2,8 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; use zokrates_field::Field; +use super::SolverMap; + impl QuadComb { fn from_flat_expression>>(flat_expression: U) -> QuadComb { let flat_expression = flat_expression.into(); @@ -24,6 +26,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, + solvers: SolverMap::default(), } } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 78b48f80..de73b683 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -2,7 +2,7 @@ use crate::common::FormatString; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; use std::fmt; use std::hash::Hash; use zokrates_field::Field; @@ -22,8 +22,8 @@ pub use self::expression::{CanonicalLinComb, LinComb}; pub use self::serialize::ProgEnum; pub use crate::common::Parameter; pub use crate::common::RuntimeError; -pub use crate::common::Solver; pub use crate::common::Variable; +pub use crate::common::{Solver, ZirSolver}; pub use self::witness::Witness; @@ -124,20 +124,29 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { } pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; +pub type SolverMap<'ast, T> = HashMap>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub arguments: Vec, pub return_count: usize, pub statements: I, + #[serde(borrow)] + pub solvers: SolverMap<'ast, T>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { - pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { + pub fn new( + arguments: Vec, + statements: I, + return_count: usize, + solvers: SolverMap<'ast, T>, + ) -> Self { Self { arguments, return_count, statements, + solvers, } } @@ -146,6 +155,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, statements: self.statements.into_iter().collect::>(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } @@ -192,6 +202,7 @@ impl<'ast, T> Prog<'ast, T> { statements: self.statements.into_iter(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } } diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index e149a968..795fa00a 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::{RuntimeError, Variable}; use zokrates_ast::flat::*; -use zokrates_ast::ir::Solver; +use zokrates_ast::ir::{Solver, ZirSolver}; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, @@ -2241,7 +2241,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); - let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); + + let solver = Solver::Zir(ZirSolver::Function(function)); + let directive = FlatDirective::new(outputs, solver, inputs); + statements_flattened.push_back(FlatStatement::Directive(directive)); } ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => {