1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
dark64 2023-01-25 14:03:39 +01:00
parent 553da16837
commit e12a4b46d3
7 changed files with 56 additions and 9 deletions

View file

@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed;
pub use self::error::RuntimeError;
pub use self::metadata::SourceMetadata;
pub use self::parameter::Parameter;
pub use self::solvers::Solver;
pub use self::solvers::{Solver, ZirSolver};
pub use self::variable::Variable;
pub use format_string::FormatString;

View file

@ -2,6 +2,22 @@ use crate::zir::ZirFunction;
use serde::{Deserialize, Serialize};
use std::fmt;
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum ZirSolver<'ast, T> {
#[serde(borrow)]
Function(ZirFunction<'ast, T>),
Indexed(usize, usize),
}
impl<'ast, T> fmt::Display for ZirSolver<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ZirSolver::Function(_) => write!(f, "Zir(..)"),
ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n),
}
}
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Solver<'ast, T> {
ConditionEq,
@ -13,13 +29,22 @@ pub enum Solver<'ast, T> {
ShaCh,
EuclideanDiv,
#[serde(borrow)]
Zir(ZirFunction<'ast, T>),
Zir(ZirSolver<'ast, T>),
#[cfg(feature = "bellman")]
Sha256Round,
#[cfg(feature = "ark")]
SnarkVerifyBls12377(usize),
}
impl<'ast, T> Solver<'ast, T> {
pub fn zir_function(function: ZirFunction<'ast, T>) -> Self {
Solver::Zir(ZirSolver::Function(function))
}
pub fn zir_indexed(index: usize, argument_count: usize) -> Self {
Solver::Zir(ZirSolver::Indexed(index, argument_count))
}
}
impl<'ast, T> fmt::Display for Solver<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
@ -31,7 +56,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> {
Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"),
Solver::ShaCh => write!(f, "ShaCh"),
Solver::EuclideanDiv => write!(f, "EuclideanDiv"),
Solver::Zir(_) => write!(f, "Zir(..)"),
Solver::Zir(s) => write!(f, "{}", s),
#[cfg(feature = "bellman")]
Solver::Sha256Round => write!(f, "Sha256Round"),
#[cfg(feature = "ark")]
@ -51,7 +76,10 @@ impl<'ast, T> Solver<'ast, T> {
Solver::ShaAndXorAndXorAnd => (3, 1),
Solver::ShaCh => (3, 1),
Solver::EuclideanDiv => (2, 2),
Solver::Zir(f) => (f.arguments.len(), 1),
Solver::Zir(s) => match s {
ZirSolver::Function(f) => (f.arguments.len(), 1),
ZirSolver::Indexed(_, n) => (*n, 1),
},
#[cfg(feature = "bellman")]
Solver::Sha256Round => (768, 26935),
#[cfg(feature = "ark")]

View file

@ -14,6 +14,7 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
.statements
.into_iter()
.flat_map(|s| Cleaner::default().fold_statement(s)),
solvers: self.solvers,
}
}
}

View file

@ -50,6 +50,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
.flat_map(|s| f.fold_statement(s))
.collect(),
return_count: p.return_count,
solvers: p.solvers,
}
}

View file

@ -2,6 +2,8 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement
use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement};
use zokrates_field::Field;
use super::SolverMap;
impl<T: Field> QuadComb<T> {
fn from_flat_expression<U: Into<FlatExpression<T>>>(flat_expression: U) -> QuadComb<T> {
let flat_expression = flat_expression.into();
@ -24,6 +26,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator<Item = FlatStatement<'ast, T>>>
statements: flat_prog_iterator.statements.into_iter().map(Into::into),
arguments: flat_prog_iterator.arguments,
return_count: flat_prog_iterator.return_count,
solvers: SolverMap::default(),
}
}

View file

@ -2,7 +2,7 @@ use crate::common::FormatString;
use crate::typed::ConcreteType;
use derivative::Derivative;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::collections::{BTreeSet, HashMap};
use std::fmt;
use std::hash::Hash;
use zokrates_field::Field;
@ -22,8 +22,8 @@ pub use self::expression::{CanonicalLinComb, LinComb};
pub use self::serialize::ProgEnum;
pub use crate::common::Parameter;
pub use crate::common::RuntimeError;
pub use crate::common::Solver;
pub use crate::common::Variable;
pub use crate::common::{Solver, ZirSolver};
pub use self::witness::Witness;
@ -124,20 +124,29 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> {
}
pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec<Statement<'ast, T>>>;
pub type SolverMap<'ast, T> = HashMap<u64, Solver<'ast, T>>;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
pub arguments: Vec<Parameter>,
pub return_count: usize,
pub statements: I,
#[serde(borrow)]
pub solvers: SolverMap<'ast, T>,
}
impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
pub fn new(arguments: Vec<Parameter>, statements: I, return_count: usize) -> Self {
pub fn new(
arguments: Vec<Parameter>,
statements: I,
return_count: usize,
solvers: SolverMap<'ast, T>,
) -> Self {
Self {
arguments,
return_count,
statements,
solvers,
}
}
@ -146,6 +155,7 @@ impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T,
statements: self.statements.into_iter().collect::<Vec<_>>(),
arguments: self.arguments,
return_count: self.return_count,
solvers: self.solvers,
}
}
@ -192,6 +202,7 @@ impl<'ast, T> Prog<'ast, T> {
statements: self.statements.into_iter(),
arguments: self.arguments,
return_count: self.return_count,
solvers: self.solvers,
}
}
}

View file

@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*;
use zokrates_ast::common::FlatEmbed;
use zokrates_ast::common::{RuntimeError, Variable};
use zokrates_ast::flat::*;
use zokrates_ast::ir::Solver;
use zokrates_ast::ir::{Solver, ZirSolver};
use zokrates_ast::zir::types::{Type, UBitwidth};
use zokrates_ast::zir::{
BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter,
@ -2241,7 +2241,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.into_iter()
.map(|assignee| self.use_variable(&assignee))
.collect();
let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs);
let solver = Solver::Zir(ZirSolver::Function(function));
let directive = FlatDirective::new(outputs, solver, inputs);
statements_flattened.push_back(FlatStatement::Directive(directive));
}
ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => {