From c8696497454b036366433ed54e0a0aa2e2eb1915 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 5 Oct 2022 15:27:45 +0200 Subject: [PATCH] wip --- zokrates_ark/src/gm17.rs | 8 +- zokrates_ark/src/groth16.rs | 8 +- zokrates_ark/src/lib.rs | 18 +-- zokrates_ark/src/marlin.rs | 8 +- zokrates_ast/src/common/embed.rs | 15 +- zokrates_ast/src/common/error.rs | 2 + zokrates_ast/src/common/solvers.rs | 17 ++- zokrates_ast/src/flat/folder.rs | 30 ++-- zokrates_ast/src/flat/mod.rs | 34 ++--- zokrates_ast/src/ir/check.rs | 8 +- zokrates_ast/src/ir/folder.rs | 34 +++-- zokrates_ast/src/ir/from_flat.rs | 14 +- zokrates_ast/src/ir/mod.rs | 34 +++-- zokrates_ast/src/ir/serialize.rs | 52 +++---- zokrates_ast/src/ir/smtlib2.rs | 10 +- zokrates_ast/src/typed/folder.rs | 28 ++++ zokrates_ast/src/typed/identifier.rs | 14 +- zokrates_ast/src/typed/mod.rs | 108 +++++++++++++++ zokrates_ast/src/typed/result_folder.rs | 28 ++++ zokrates_ast/src/typed/types.rs | 5 +- zokrates_ast/src/untyped/from_ast.rs | 27 ++++ zokrates_ast/src/untyped/mod.rs | 37 ++++- zokrates_ast/src/untyped/node.rs | 1 + zokrates_ast/src/zir/folder.rs | 31 +++++ zokrates_ast/src/zir/identifier.rs | 7 +- zokrates_ast/src/zir/mod.rs | 129 +++++++++++++++--- zokrates_ast/src/zir/parameter.rs | 4 +- zokrates_ast/src/zir/result_folder.rs | 34 +++++ zokrates_ast/src/zir/uint.rs | 11 +- zokrates_ast/src/zir/variable.rs | 4 +- zokrates_bellman/src/groth16.rs | 16 +-- zokrates_bellman/src/lib.rs | 18 +-- zokrates_cli/src/ops/compute_witness.rs | 4 +- zokrates_cli/src/ops/generate_proof.rs | 5 +- zokrates_cli/src/ops/generate_smtlib2.rs | 4 +- zokrates_cli/src/ops/inspect.rs | 4 +- zokrates_cli/src/ops/mpc/init.rs | 5 +- zokrates_cli/src/ops/mpc/verify.rs | 5 +- zokrates_cli/src/ops/setup.rs | 10 +- zokrates_core/src/compile.rs | 18 +-- zokrates_core/src/flatten/mod.rs | 113 +++++++++------ zokrates_core/src/optimizer/canonicalizer.rs | 2 +- zokrates_core/src/optimizer/directive.rs | 8 +- zokrates_core/src/optimizer/duplicate.rs | 6 +- zokrates_core/src/optimizer/mod.rs | 6 +- zokrates_core/src/optimizer/redefinition.rs | 8 +- zokrates_core/src/optimizer/tautology.rs | 4 +- zokrates_core/src/semantics.rs | 121 ++++++++++++++-- .../src/static_analysis/flat_propagation.rs | 4 +- .../static_analysis/flatten_complex_types.rs | 101 +++++++++++++- .../src/static_analysis/propagation.rs | 32 +++++ .../src/static_analysis/zir_propagation.rs | 23 +++- zokrates_interpreter/src/lib.rs | 22 ++- zokrates_js/src/lib.rs | 5 +- zokrates_parser/src/zokrates.pest | 11 +- zokrates_pest_ast/src/lib.rs | 78 ++++++++++- zokrates_proof_systems/src/lib.rs | 22 +-- 57 files changed, 1112 insertions(+), 303 deletions(-) diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index 209a1abd..0e119eee 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -16,8 +16,8 @@ use zokrates_proof_systems::Scheme; use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair}; impl NonUniversalBackend for Ark { - fn setup>>( - program: ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, ) -> SetupKeypair { let computation = Computation::without_witness(program); @@ -41,8 +41,8 @@ impl NonUniversalBackend for Ark { } impl Backend for Ark { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index 617d34ed..0de18845 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -19,8 +19,8 @@ use zokrates_proof_systems::Scheme; const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications."; impl Backend for Ark { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { @@ -86,8 +86,8 @@ impl Backend for Ark { } impl NonUniversalBackend for Ark { - fn setup>>( - program: ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, ) -> SetupKeypair { println!("{}", G16_WARNING); diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index f5c3b320..425be3a8 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -17,20 +17,20 @@ pub use self::parse::*; pub struct Ark; #[derive(Clone)] -pub struct Computation>> { - program: ProgIterator, +pub struct Computation<'a, T, I: IntoIterator>> { + program: ProgIterator<'a, T, I>, witness: Option>, } -impl>> Computation { - pub fn with_witness(program: ProgIterator, witness: Witness) -> Self { +impl<'a, T, I: IntoIterator>> Computation<'a, T, I> { + pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness) -> Self { Computation { program, witness: Some(witness), } } - pub fn without_witness(program: ProgIterator) -> Self { + pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self { Computation { program, witness: None, @@ -72,9 +72,9 @@ fn ark_combination( .fold(LinearCombination::zero(), |acc, e| acc + e) } -impl>> +impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator>> ConstraintSynthesizer<<::ArkEngine as PairingEngine>::Fr> - for Computation + for Computation<'a, T, I> { fn generate_constraints( self, @@ -143,7 +143,9 @@ impl>> } } -impl>> Computation { +impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator>> + Computation<'a, T, I> +{ pub fn public_inputs_values(&self) -> Vec<::Fr> { self.program .public_inputs_values(self.witness.as_ref().unwrap()) diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index cc85f6a2..24204a6c 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -134,9 +134,9 @@ impl UniversalBackend for Ark res } - fn setup>>( + fn setup<'a, I: IntoIterator>>( srs: Vec, - program: ProgIterator, + program: ProgIterator<'a, T, I>, ) -> Result, String> { let program = program.collect(); @@ -210,8 +210,8 @@ impl UniversalBackend for Ark } impl Backend for Ark { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { diff --git a/zokrates_ast/src/common/embed.rs b/zokrates_ast/src/common/embed.rs index 4133c5c8..0e8361d9 100644 --- a/zokrates_ast/src/common/embed.rs +++ b/zokrates_ast/src/common/embed.rs @@ -9,6 +9,7 @@ use crate::untyped::{ types::{UnresolvedSignature, UnresolvedType}, ConstantGenericNode, Expression, }; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use zokrates_field::Field; @@ -28,7 +29,7 @@ cfg_if::cfg_if! { /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until /// the flattening step when it can be inlined. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)] pub enum FlatEmbed { BitArrayLe, Unpack, @@ -317,8 +318,8 @@ impl FlatEmbed { /// - constraint system variables /// - arguments #[cfg(feature = "bellman")] -pub fn sha256_round( -) -> FlatFunctionIterator>> { +pub fn sha256_round<'ast, T: Field>( +) -> FlatFunctionIterator<'ast, T, impl IntoIterator>> { use zokrates_field::Bn128Field; assert_eq!(T::id(), Bn128Field::id()); @@ -420,9 +421,9 @@ pub fn sha256_round( } #[cfg(feature = "ark")] -pub fn snark_verify_bls12_377( +pub fn snark_verify_bls12_377<'ast, T: Field>( n: usize, -) -> FlatFunctionIterator>> { +) -> FlatFunctionIterator<'ast, T, impl IntoIterator>> { use zokrates_field::Bw6_761Field; assert_eq!(T::id(), Bw6_761Field::id()); @@ -546,9 +547,9 @@ fn use_variable( /// # Remarks /// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()` /// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` -pub fn unpack_to_bitwidth( +pub fn unpack_to_bitwidth<'ast, T: Field>( bit_width: usize, -) -> FlatFunctionIterator>> { +) -> FlatFunctionIterator<'ast, T, impl IntoIterator>> { let mut counter = 0; let mut layout = HashMap::new(); diff --git a/zokrates_ast/src/common/error.rs b/zokrates_ast/src/common/error.rs index 45ef0422..b5ea381b 100644 --- a/zokrates_ast/src/common/error.rs +++ b/zokrates_ast/src/common/error.rs @@ -3,6 +3,7 @@ use std::fmt; #[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] pub enum RuntimeError { + UnsatisfiedConstraint, BellmanConstraint, BellmanOneBinding, BellmanInputBinding, @@ -63,6 +64,7 @@ impl fmt::Display for RuntimeError { use RuntimeError::*; let msg = match self { + UnsatisfiedConstraint => "Constraint is unsatisfied", BellmanConstraint => "Bellman constraint is unsatisfied", BellmanOneBinding => "Bellman ~one binding is unsatisfied", BellmanInputBinding => "Bellman input binding is unsatisfied", diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index d8387f26..c3e4fe77 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -1,8 +1,9 @@ +use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] -pub enum Solver { +pub enum Solver<'ast, T> { ConditionEq, Bits(usize), Div, @@ -11,19 +12,24 @@ pub enum Solver { ShaAndXorAndXorAnd, ShaCh, EuclideanDiv, + #[serde(borrow)] + Zir(ZirFunction<'ast, T>), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } -impl fmt::Display for Solver { +impl<'ast, T: fmt::Debug + fmt::Display> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?}", self) + match self { + Solver::Zir(func) => write!(f, "Zir({})", func), + _ => write!(f, "{:?}", self) + } } } -impl Solver { +impl<'ast, T> Solver<'ast, T> { pub fn get_signature(&self) -> (usize, usize) { match self { Solver::ConditionEq => (1, 2), @@ -34,6 +40,7 @@ impl Solver { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), + Solver::Zir(f) => (f.arguments.len(), 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] @@ -42,7 +49,7 @@ impl Solver { } } -impl Solver { +impl<'ast, T> Solver<'ast, T> { pub fn bits(width: usize) -> Self { Solver::Bits(width) } diff --git a/zokrates_ast/src/flat/folder.rs b/zokrates_ast/src/flat/folder.rs index 4c9baeb0..5f8b4e39 100644 --- a/zokrates_ast/src/flat/folder.rs +++ b/zokrates_ast/src/flat/folder.rs @@ -4,8 +4,8 @@ use super::*; use crate::common::Variable; use zokrates_field::Field; -pub trait Folder: Sized { - fn fold_program(&mut self, p: FlatProg) -> FlatProg { +pub trait Folder<'ast, T: Field>: Sized { + fn fold_program(&mut self, p: FlatProg<'ast, T>) -> FlatProg<'ast, T> { fold_program(self, p) } @@ -17,7 +17,7 @@ pub trait Folder: Sized { fold_variable(self, v) } - fn fold_statement(&mut self, s: FlatStatement) -> Vec> { + fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -25,12 +25,15 @@ pub trait Folder: Sized { fold_expression(self, e) } - fn fold_directive(&mut self, d: FlatDirective) -> FlatDirective { + fn fold_directive(&mut self, d: FlatDirective<'ast, T>) -> FlatDirective<'ast, T> { fold_directive(self, d) } } -pub fn fold_program>(f: &mut F, p: FlatProg) -> FlatProg { +pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + p: FlatProg<'ast, T>, +) -> FlatProg<'ast, T> { FlatProg { arguments: p .arguments @@ -46,10 +49,10 @@ pub fn fold_program>(f: &mut F, p: FlatProg) -> FlatPr } } -pub fn fold_statement>( +pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - s: FlatStatement, -) -> Vec> { + s: FlatStatement<'ast, T>, +) -> Vec> { match s { FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition( f.fold_expression(left), @@ -70,7 +73,7 @@ pub fn fold_statement>( } } -pub fn fold_expression>( +pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FlatExpression, ) -> FlatExpression { @@ -89,7 +92,10 @@ pub fn fold_expression>( } } -pub fn fold_directive>(f: &mut F, ds: FlatDirective) -> FlatDirective { +pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ds: FlatDirective<'ast, T>, +) -> FlatDirective<'ast, T> { FlatDirective { inputs: ds .inputs @@ -101,13 +107,13 @@ pub fn fold_directive>(f: &mut F, ds: FlatDirective) - } } -pub fn fold_argument>(f: &mut F, a: Parameter) -> Parameter { +pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), private: a.private, } } -pub fn fold_variable>(_f: &mut F, v: Variable) -> Variable { +pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable { v } diff --git a/zokrates_ast/src/flat/mod.rs b/zokrates_ast/src/flat/mod.rs index 903fff2c..cf38e553 100644 --- a/zokrates_ast/src/flat/mod.rs +++ b/zokrates_ast/src/flat/mod.rs @@ -24,14 +24,14 @@ use std::collections::HashMap; use std::fmt; use zokrates_field::Field; -pub type FlatProg = FlatFunction; +pub type FlatProg<'ast, T> = FlatFunction<'ast, T>; -pub type FlatFunction = FlatFunctionIterator>>; +pub type FlatFunction<'ast, T> = FlatFunctionIterator<'ast, T, Vec>>; -pub type FlatProgIterator = FlatFunctionIterator; +pub type FlatProgIterator<'ast, T, I> = FlatFunctionIterator<'ast, T, I>; #[derive(Clone, PartialEq, Eq, Debug)] -pub struct FlatFunctionIterator>> { +pub struct FlatFunctionIterator<'ast, T, I: IntoIterator>> { /// Arguments of the function pub arguments: Vec, /// Vector of statements that are executed when running the function @@ -40,8 +40,8 @@ pub struct FlatFunctionIterator>> { pub return_count: usize, } -impl>> FlatFunctionIterator { - pub fn collect(self) -> FlatFunction { +impl<'ast, T, I: IntoIterator>> FlatFunctionIterator<'ast, T, I> { + pub fn collect(self) -> FlatFunction<'ast, T> { FlatFunction { statements: self.statements.into_iter().collect(), arguments: self.arguments, @@ -50,7 +50,7 @@ impl>> FlatFunctionIterator { } } -impl fmt::Display for FlatFunction { +impl<'ast, T: Field> fmt::Display for FlatFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -81,14 +81,14 @@ impl fmt::Display for FlatFunction { /// * r1cs - R1CS in standard JSON data format #[derive(Clone, PartialEq, Eq, Debug)] -pub enum FlatStatement { +pub enum FlatStatement<'ast, T> { Condition(FlatExpression, FlatExpression, RuntimeError), Definition(Variable, FlatExpression), - Directive(FlatDirective), + Directive(FlatDirective<'ast, T>), Log(FormatString, Vec<(ConcreteType, Vec>)>), } -impl fmt::Display for FlatStatement { +impl<'ast, T: Field> fmt::Display for FlatStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), @@ -116,10 +116,10 @@ impl fmt::Display for FlatStatement { } } -impl FlatStatement { +impl<'ast, T: Field> FlatStatement<'ast, T> { pub fn apply_substitution( self, - substitution: &HashMap, + substitution: &'ast HashMap, ) -> FlatStatement { match self { FlatStatement::Definition(id, x) => FlatStatement::Definition( @@ -167,16 +167,16 @@ impl FlatStatement { } #[derive(Clone, Hash, Debug, PartialEq, Eq)] -pub struct FlatDirective { +pub struct FlatDirective<'ast, T> { pub inputs: Vec>, pub outputs: Vec, - pub solver: Solver, + pub solver: Solver<'ast, T>, } -impl FlatDirective { +impl<'ast, T> FlatDirective<'ast, T> { pub fn new>>( outputs: Vec, - solver: Solver, + solver: Solver<'ast, T>, inputs: Vec, ) -> Self { let (in_len, out_len) = solver.get_signature(); @@ -190,7 +190,7 @@ impl FlatDirective { } } -impl fmt::Display for FlatDirective { +impl<'ast, T: Field> fmt::Display for FlatDirective<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, diff --git a/zokrates_ast/src/ir/check.rs b/zokrates_ast/src/ir/check.rs index 11c5fd84..41cac7b0 100644 --- a/zokrates_ast/src/ir/check.rs +++ b/zokrates_ast/src/ir/check.rs @@ -13,7 +13,9 @@ pub struct UnconstrainedVariableDetector { } impl UnconstrainedVariableDetector { - pub fn new>>(p: &ProgIterator) -> Self { + pub fn new<'ast, T: Field, I: IntoIterator>>( + p: &ProgIterator<'ast, T, I>, + ) -> Self { UnconstrainedVariableDetector { variables: p .arguments @@ -32,7 +34,7 @@ impl UnconstrainedVariableDetector { } } -impl Folder for UnconstrainedVariableDetector { +impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector { fn fold_argument(&mut self, p: Parameter) -> Parameter { p } @@ -40,7 +42,7 @@ impl Folder for UnconstrainedVariableDetector { self.variables.remove(&v); v } - fn fold_directive(&mut self, d: Directive) -> Directive { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { self.variables.extend(d.outputs.iter()); d } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 22245252..6d1c20b9 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -4,8 +4,8 @@ use super::*; use crate::common::Variable; use zokrates_field::Field; -pub trait Folder: Sized { - fn fold_program(&mut self, p: Prog) -> Prog { +pub trait Folder<'ast, T: Field>: Sized { + fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> { fold_program(self, p) } @@ -17,7 +17,7 @@ pub trait Folder: Sized { fold_variable(self, v) } - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -29,12 +29,15 @@ pub trait Folder: Sized { fold_quadratic_combination(self, es) } - fn fold_directive(&mut self, d: Directive) -> Directive { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { fold_directive(self, d) } } -pub fn fold_program>(f: &mut F, p: Prog) -> Prog { +pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + p: Prog<'ast, T>, +) -> Prog<'ast, T> { Prog { arguments: p .arguments @@ -50,7 +53,10 @@ pub fn fold_program>(f: &mut F, p: Prog) -> Prog { } } -pub fn fold_statement>(f: &mut F, s: Statement) -> Vec> { +pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: Statement<'ast, T>, +) -> Vec> { match s { Statement::Constraint(quad, lin, message) => vec![Statement::Constraint( f.fold_quadratic_combination(quad), @@ -74,7 +80,10 @@ pub fn fold_statement>(f: &mut F, s: Statement) -> Vec } } -pub fn fold_linear_combination>(f: &mut F, e: LinComb) -> LinComb { +pub fn fold_linear_combination<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: LinComb, +) -> LinComb { LinComb( e.0.into_iter() .map(|(variable, coefficient)| (f.fold_variable(variable), coefficient)) @@ -82,7 +91,7 @@ pub fn fold_linear_combination>(f: &mut F, e: LinComb) ) } -pub fn fold_quadratic_combination>( +pub fn fold_quadratic_combination<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: QuadComb, ) -> QuadComb { @@ -92,7 +101,10 @@ pub fn fold_quadratic_combination>( } } -pub fn fold_directive>(f: &mut F, ds: Directive) -> Directive { +pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ds: Directive<'ast, T>, +) -> Directive<'ast, T> { Directive { inputs: ds .inputs @@ -104,13 +116,13 @@ pub fn fold_directive>(f: &mut F, ds: Directive) -> Di } } -pub fn fold_argument>(f: &mut F, a: Parameter) -> Parameter { +pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), private: a.private, } } -pub fn fold_variable>(_f: &mut F, v: Variable) -> Variable { +pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable { v } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index e5e3cb73..35acad34 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -17,9 +17,9 @@ impl QuadComb { } } -pub fn from_flat>>( - flat_prog_iterator: FlatProgIterator, -) -> ProgIterator>> { +pub fn from_flat<'ast, T: Field, I: IntoIterator>>( + flat_prog_iterator: FlatProgIterator<'ast, T, I>, +) -> ProgIterator>> { ProgIterator { statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, @@ -52,8 +52,8 @@ impl From> for LinComb { } } -impl From> for Statement { - fn from(flat_statement: FlatStatement) -> Statement { +impl<'ast, T: Field> From> for Statement<'ast, T> { + fn from(flat_statement: FlatStatement<'ast, T>) -> Statement<'ast, T> { match flat_statement { FlatStatement::Condition(linear, quadratic, message) => match quadratic { FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( @@ -83,8 +83,8 @@ impl From> for Statement { } } -impl From> for Directive { - fn from(ds: FlatDirective) -> Directive { +impl<'ast, T: Field> From> for Directive<'ast, T> { + fn from(ds: FlatDirective<'ast, T>) -> Directive { Directive { inputs: ds .inputs diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index bd068b89..9e852bfc 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -26,15 +26,16 @@ pub use crate::common::Variable; pub use self::witness::Witness; #[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)] -pub enum Statement { +pub enum Statement<'ast, T> { Constraint(QuadComb, LinComb, Option), - Directive(Directive), + #[serde(borrow)] + Directive(Directive<'ast, T>), Log(FormatString, Vec<(ConcreteType, Vec>)>), } pub type PublicInputs = BTreeSet; -impl Statement { +impl<'ast, T: Field> Statement<'ast, T> { pub fn definition>>(v: Variable, e: U) -> Self { Statement::Constraint(e.into(), v.into(), None) } @@ -45,13 +46,14 @@ impl Statement { } #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub struct Directive { +pub struct Directive<'ast, T> { pub inputs: Vec>, pub outputs: Vec, - pub solver: Solver, + #[serde(borrow)] + pub solver: Solver<'ast, T>, } -impl fmt::Display for Directive { +impl<'ast, T: Field> fmt::Display for Directive<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -71,7 +73,7 @@ impl fmt::Display for Directive { } } -impl fmt::Display for Statement { +impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Constraint(ref quad, ref lin, _) => write!(f, "{} == {}", quad, lin), @@ -96,16 +98,16 @@ impl fmt::Display for Statement { } } -pub type Prog = ProgIterator>>; +pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] -pub struct ProgIterator>> { +pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub arguments: Vec, pub return_count: usize, pub statements: I, } -impl>> ProgIterator { +impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { Self { arguments, @@ -114,7 +116,7 @@ impl>> ProgIterator { } } - pub fn collect(self) -> ProgIterator>> { + pub fn collect(self) -> ProgIterator<'ast, T, Vec>> { ProgIterator { statements: self.statements.into_iter().collect::>(), arguments: self.arguments, @@ -139,7 +141,7 @@ impl>> ProgIterator { } } -impl>> ProgIterator { +impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { pub fn public_inputs_values(&self, witness: &Witness) -> Vec { self.arguments .iter() @@ -150,7 +152,7 @@ impl>> ProgIterator { } } -impl Prog { +impl<'ast, T> Prog<'ast, T> { pub fn constraint_count(&self) -> usize { self.statements .iter() @@ -158,7 +160,9 @@ impl Prog { .count() } - pub fn into_prog_iter(self) -> ProgIterator>> { + pub fn into_prog_iter( + self, + ) -> ProgIterator<'ast, T, impl IntoIterator>> { ProgIterator { statements: self.statements.into_iter(), arguments: self.arguments, @@ -167,7 +171,7 @@ impl Prog { } } -impl fmt::Display for Prog { +impl<'ast, T: Field> fmt::Display for Prog<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let returns = (0..self.return_count) .map(Variable::public) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 39c74636..09d00390 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -12,32 +12,35 @@ const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< - Bls12_381I: IntoIterator>, - Bn128I: IntoIterator>, - Bls12_377I: IntoIterator>, - Bw6_761I: IntoIterator>, + 'ast, + Bls12_381I: IntoIterator>, + Bn128I: IntoIterator>, + Bls12_377I: IntoIterator>, + Bw6_761I: IntoIterator>, > { - Bls12_381Program(ProgIterator), - Bn128Program(ProgIterator), - Bls12_377Program(ProgIterator), - Bw6_761Program(ProgIterator), + Bls12_381Program(ProgIterator<'ast, Bls12_381Field, Bls12_381I>), + Bn128Program(ProgIterator<'ast, Bn128Field, Bn128I>), + Bls12_377Program(ProgIterator<'ast, Bls12_377Field, Bls12_377I>), + Bw6_761Program(ProgIterator<'ast, Bw6_761Field, Bw6_761I>), } -type MemoryProgEnum = ProgEnum< - Vec>, - Vec>, - Vec>, - Vec>, +type MemoryProgEnum<'ast> = ProgEnum< + 'ast, + Vec>, + Vec>, + Vec>, + Vec>, >; impl< - Bls12_381I: IntoIterator>, - Bn128I: IntoIterator>, - Bls12_377I: IntoIterator>, - Bw6_761I: IntoIterator>, - > ProgEnum + 'ast, + Bls12_381I: IntoIterator>, + Bn128I: IntoIterator>, + Bls12_377I: IntoIterator>, + Bw6_761I: IntoIterator>, + > ProgEnum<'ast, Bls12_381I, Bn128I, Bls12_377I, Bw6_761I> { - pub fn collect(self) -> MemoryProgEnum { + pub fn collect(self) -> MemoryProgEnum<'ast> { match self { ProgEnum::Bls12_381Program(p) => ProgEnum::Bls12_381Program(p.collect()), ProgEnum::Bn128Program(p) => ProgEnum::Bn128Program(p.collect()), @@ -55,7 +58,7 @@ impl< } } -impl>> ProgIterator { +impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives pub fn serialize(self, mut w: W) -> Result { @@ -106,10 +109,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator impl<'de, R: Read> ProgEnum< - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, + 'de, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_381Field>>, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bn128Field>>, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_377Field>>, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bw6_761Field>>, > { pub fn deserialize(mut r: R) -> Result { diff --git a/zokrates_ast/src/ir/smtlib2.rs b/zokrates_ast/src/ir/smtlib2.rs index 8bdd04d3..1a80c874 100644 --- a/zokrates_ast/src/ir/smtlib2.rs +++ b/zokrates_ast/src/ir/smtlib2.rs @@ -12,9 +12,9 @@ pub trait SMTLib2 { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result; } -pub struct SMTLib2Display<'a, T>(pub &'a Prog); +pub struct SMTLib2Display<'a, 'ast, T>(pub &'a Prog<'ast, T>); -impl fmt::Display for SMTLib2Display<'_, T> { +impl<'ast, T: Field> fmt::Display for SMTLib2Display<'_, 'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.0.to_smtlib2(f) } @@ -30,7 +30,7 @@ impl Visitor for VariableCollector { } } -impl SMTLib2 for Prog { +impl<'ast, T: Field> SMTLib2 for Prog<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut collector = VariableCollector { variables: BTreeSet::::new(), @@ -75,7 +75,7 @@ fn format_prefix_op_smtlib2( write!(f, ")") } -impl SMTLib2 for Statement { +impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Constraint(ref quad, ref lin, _) => { @@ -91,7 +91,7 @@ impl SMTLib2 for Statement { } } -impl SMTLib2 for Directive { +impl<'ast, T: Field> SMTLib2 for Directive<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "") } diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index e1e6d971..c0f6daf6 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -260,6 +260,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> TypedAssemblyStatement<'ast, T> { + fold_assembly_statement(self, s) + } + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -505,6 +512,21 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> TypedAssemblyStatement<'ast, T> { + match s { + TypedAssemblyStatement::Assignment(a, e) => { + TypedAssemblyStatement::Assignment(f.fold_assignee(a), f.fold_field_expression(e)) + } + TypedAssemblyStatement::Constraint(lhs, rhs) => TypedAssemblyStatement::Constraint( + f.fold_field_expression(lhs), + f.fold_field_expression(rhs), + ), + } +} + pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, @@ -529,6 +551,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( TypedStatement::Log(s, e) => { TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect()) } + TypedStatement::Assembly(statements) => TypedStatement::Assembly( + statements + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect(), + ), s => s, }; vec![res] diff --git a/zokrates_ast/src/typed/identifier.rs b/zokrates_ast/src/typed/identifier.rs index abcd2f40..772eb2bf 100644 --- a/zokrates_ast/src/typed/identifier.rs +++ b/zokrates_ast/src/typed/identifier.rs @@ -1,10 +1,12 @@ use crate::typed::CanonicalConstantIdentifier; +use serde::{Deserialize, Serialize}; use std::fmt; -pub type SourceIdentifier<'ast> = &'ast str; +pub type SourceIdentifier<'ast> = std::borrow::Cow<'ast, str>; -#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum CoreIdentifier<'ast> { + #[serde(borrow)] Source(ShadowedIdentifier<'ast>), Call(usize), Constant(CanonicalConstantIdentifier<'ast>), @@ -29,16 +31,18 @@ impl<'ast> From> for CoreIdentifier<'ast> { } /// A identifier for a variable -#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Identifier<'ast> { /// the id of the variable + #[serde(borrow)] pub id: CoreIdentifier<'ast>, /// the version of the variable, used after SSA transformation pub version: usize, } -#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct ShadowedIdentifier<'ast> { + #[serde(borrow)] pub id: SourceIdentifier<'ast>, pub shadow: usize, } @@ -97,7 +101,7 @@ impl<'ast> Identifier<'ast> { // these two From implementations are only used in tests but somehow cfg(test) doesn't work impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { fn from(s: &str) -> CoreIdentifier { - CoreIdentifier::Source(ShadowedIdentifier::shadow(s, 0)) + CoreIdentifier::Source(ShadowedIdentifier::shadow(std::borrow::Cow::Borrowed(s), 0)) } } diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index a43f59bd..c8c8c1c6 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -675,6 +675,28 @@ impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> { } } +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub enum TypedAssemblyStatement<'ast, T> { + Assignment(TypedAssignee<'ast, T>, FieldElementExpression<'ast, T>), + Constraint( + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + ), +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedAssemblyStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + TypedAssemblyStatement::Assignment(ref lhs, ref rhs) => { + write!(f, "{} <-- {}", lhs, rhs) + } + TypedAssemblyStatement::Constraint(ref lhs, ref rhs) => { + write!(f, "{} === {}", lhs, rhs) + } + } + } +} + /// A statement in a `TypedFunction` #[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] @@ -695,6 +717,7 @@ pub enum TypedStatement<'ast, T> { ConcreteGenericsAssignment<'ast>, ), PopCallLog, + Assembly(Vec>), } impl<'ast, T> TypedStatement<'ast, T> { @@ -719,6 +742,14 @@ impl<'ast, T: fmt::Display> TypedStatement<'ast, T> { } write!(f, "{}}}", "\t".repeat(depth)) } + TypedStatement::Assembly(statements) => { + write!(f, "{}", "\t".repeat(depth))?; + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; + } + write!(f, "{}}}", "\t".repeat(depth)) + } s => write!(f, "{}{}", "\t".repeat(depth), s), } } @@ -766,6 +797,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { generics, ), TypedStatement::PopCallLog => write!(f, "// POP CALL",), + TypedStatement::Assembly(ref statements) => { + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "\t\t{}", s)?; + } + write!(f, "\t}}") + } } } } @@ -1173,6 +1211,73 @@ pub enum FieldElementExpression<'ast, T> { Select(SelectExpression<'ast, T, Self>), Element(ElementExpression<'ast, T, Self>), } + +impl<'ast, T: Clone> From> for TupleExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = TupleExpressionInner::Identifier(v.id); + match v._type { + GType::Tuple(tuple_ty) => inner.annotate(tuple_ty), + _ => unreachable!(), + } + } + TypedAssignee::Select(box a, box index) => TupleExpression::select(a.into(), index), + TypedAssignee::Member(box a, id) => TupleExpression::member(a.into(), id), + TypedAssignee::Element(box a, index) => TupleExpression::element(a.into(), index), + } + } +} + +impl<'ast, T: Clone> From> for StructExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = StructExpressionInner::Identifier(v.id); + match v._type { + GType::Struct(struct_ty) => inner.annotate(struct_ty), + _ => unreachable!(), + } + } + TypedAssignee::Select(box a, box index) => StructExpression::select(a.into(), index), + TypedAssignee::Member(box a, id) => StructExpression::member(a.into(), id), + TypedAssignee::Element(box a, index) => StructExpression::element(a.into(), index), + } + } +} + +impl<'ast, T: Clone> From> for ArrayExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = ArrayExpressionInner::Identifier(v.id); + match v._type { + GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size), + _ => unreachable!(), + } + } + TypedAssignee::Select(box a, box index) => ArrayExpression::select(a.into(), index), + TypedAssignee::Member(box a, id) => ArrayExpression::member(a.into(), id), + TypedAssignee::Element(box a, index) => ArrayExpression::element(a.into(), index), + } + } +} + +impl<'ast, T: Clone> From> for FieldElementExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => FieldElementExpression::Identifier(v.id), + TypedAssignee::Element(box a, index) => { + FieldElementExpression::element(a.into(), index) + } + TypedAssignee::Member(box a, id) => FieldElementExpression::member(a.into(), id), + TypedAssignee::Select(box a, box index) => { + FieldElementExpression::select(a.into(), index) + } + } + } +} + impl<'ast, T> Add for FieldElementExpression<'ast, T> { type Output = Self; @@ -1209,6 +1314,9 @@ impl<'ast, T> FieldElementExpression<'ast, T> { pub fn pow(self, other: UExpression<'ast, T>) -> Self { FieldElementExpression::Pow(box self, box other) } + pub fn is_quadratic(&self) -> bool { + true // TODO: implement + } } impl<'ast, T> From for FieldElementExpression<'ast, T> { diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index fe6cd157..7732bc39 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -378,6 +378,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result, Self::Error> { + fold_assembly_statement(self, s) + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, @@ -508,6 +515,21 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } } +pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Result, F::Error> { + Ok(match s { + TypedAssemblyStatement::Assignment(a, e) => { + TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, f.fold_field_expression(e)?) + } + TypedAssemblyStatement::Constraint(lhs, rhs) => TypedAssemblyStatement::Constraint( + f.fold_field_expression(lhs)?, + f.fold_field_expression(rhs)?, + ), + }) +} + pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, @@ -538,6 +560,12 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( .map(|e| f.fold_expression(e)) .collect::, _>>()?, ), + TypedStatement::Assembly(statements) => TypedStatement::Assembly( + statements + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()?, + ), s => s, }; Ok(vec![res]) diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index 6d29da5c..8bd863f9 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -51,7 +51,7 @@ pub struct GenericIdentifier<'ast> { impl<'ast> From> for CoreIdentifier<'ast> { fn from(g: GenericIdentifier<'ast>) -> CoreIdentifier<'ast> { // generic identifiers are always declared in the function scope, which is shadow 0 - CoreIdentifier::Source(ShadowedIdentifier::shadow(g.name(), 0)) + CoreIdentifier::Source(ShadowedIdentifier::shadow(std::borrow::Cow::Borrowed(g.name()), 0)) } } @@ -119,9 +119,10 @@ pub struct SpecializationError; pub type ConstantIdentifier<'ast> = &'ast str; -#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub struct CanonicalConstantIdentifier<'ast> { pub module: OwnedTypedModuleId, + #[serde(borrow)] pub id: ConstantIdentifier<'ast>, } diff --git a/zokrates_ast/src/untyped/from_ast.rs b/zokrates_ast/src/untyped/from_ast.rs index 349a03dc..b37af590 100644 --- a/zokrates_ast/src/untyped/from_ast.rs +++ b/zokrates_ast/src/untyped/from_ast.rs @@ -280,6 +280,7 @@ impl<'ast> From> for untyped::StatementNode<'ast> { pest::Statement::Assertion(s) => untyped::StatementNode::from(s), pest::Statement::Return(s) => untyped::StatementNode::from(s), pest::Statement::Log(s) => untyped::StatementNode::from(s), + pest::Statement::Assembly(s) => untyped::StatementNode::from(s), } } } @@ -343,6 +344,32 @@ impl<'ast> From> for untyped::StatementNode<'ast> } } +impl<'ast> From> for untyped::StatementNode<'ast> { + fn from(statement: pest::AssemblyStatement<'ast>) -> untyped::StatementNode<'ast> { + use crate::untyped::NodeValue; + + let statements = statement + .inner + .into_iter() + .map(|s| match s { + pest::AssemblyStatementInner::Assignment(a) => { + untyped::AssemblyStatement::Assignment( + a.assignee.into(), + a.expression.into(), + matches!(a.operator, pest::AssignmentOperator::AssignConstrain), + ) + .span(a.span) + } + pest::AssemblyStatementInner::Constraint(c) => { + untyped::AssemblyStatement::Constraint(c.lhs.into(), c.rhs.into()).span(c.span) + } + }) + .collect(); + + untyped::Statement::Assembly(statements).span(statement.span) + } +} + impl<'ast> From> for untyped::ExpressionNode<'ast> { fn from(expression: pest::Expression<'ast>) -> untyped::ExpressionNode<'ast> { match expression { diff --git a/zokrates_ast/src/untyped/mod.rs b/zokrates_ast/src/untyped/mod.rs index 8fabc2ec..07541ede 100644 --- a/zokrates_ast/src/untyped/mod.rs +++ b/zokrates_ast/src/untyped/mod.rs @@ -382,6 +382,33 @@ impl<'ast> fmt::Display for Assignee<'ast> { } } +#[derive(Debug, Clone, PartialEq)] +pub enum AssemblyStatement<'ast> { + Assignment(AssigneeNode<'ast>, ExpressionNode<'ast>, bool), + Constraint(ExpressionNode<'ast>, ExpressionNode<'ast>), +} + +pub type AssemblyStatementNode<'ast> = Node>; + +impl<'ast> fmt::Display for AssemblyStatement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + AssemblyStatement::Assignment(ref lhs, ref rhs, ref constrained) => { + write!( + f, + "{} <{} {}", + lhs, + if *constrained { "==" } else { "--" }, + rhs + ) + } + AssemblyStatement::Constraint(ref lhs, ref rhs) => { + write!(f, "{} === {}", lhs, rhs) + } + } + } +} + /// A statement in a `Function` #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq)] @@ -397,6 +424,7 @@ pub enum Statement<'ast> { Vec>, ), Log(&'ast str, Vec>), + Assembly(Vec>), } pub type StatementNode<'ast> = Node>; @@ -431,7 +459,7 @@ impl<'ast> fmt::Display for Statement<'ast> { } Statement::Log(ref l, ref expressions) => write!( f, - "log({}, {})", + "log({}, {});", l, expressions .iter() @@ -439,6 +467,13 @@ impl<'ast> fmt::Display for Statement<'ast> { .collect::>() .join(", ") ), + Statement::Assembly(ref statements) => { + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "\t\t{};", s)?; + } + write!(f, "\t}}") + } } } } diff --git a/zokrates_ast/src/untyped/node.rs b/zokrates_ast/src/untyped/node.rs index 44cda49f..62ef299d 100644 --- a/zokrates_ast/src/untyped/node.rs +++ b/zokrates_ast/src/untyped/node.rs @@ -84,6 +84,7 @@ use super::*; impl<'ast> NodeValue for Expression<'ast> {} impl<'ast> NodeValue for Assignee<'ast> {} impl<'ast> NodeValue for Statement<'ast> {} +impl<'ast> NodeValue for AssemblyStatement<'ast> {} impl<'ast> NodeValue for SymbolDeclaration<'ast> {} impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {} diff --git a/zokrates_ast/src/zir/folder.rs b/zokrates_ast/src/zir/folder.rs index f934ed3c..6523e2f8 100644 --- a/zokrates_ast/src/zir/folder.rs +++ b/zokrates_ast/src/zir/folder.rs @@ -56,6 +56,13 @@ pub trait Folder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> ZirAssemblyStatement<'ast, T> { + fold_assembly_statement(self, s) + } + fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -127,6 +134,24 @@ pub trait Folder<'ast, T: Field>: Sized { } } +pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> ZirAssemblyStatement<'ast, T> { + match s { + ZirAssemblyStatement::Assignment(assignees, function) => { + let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect(); + let function = f.fold_function(function); + ZirAssemblyStatement::Assignment(assignees, function) + } + ZirAssemblyStatement::Constraint(lhs, rhs) => { + let lhs = f.fold_field_expression(lhs); + let rhs = f.fold_field_expression(rhs); + ZirAssemblyStatement::Constraint(lhs, rhs) + } + } +} + pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, @@ -165,6 +190,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) .collect(), ), + ZirStatement::Assembly(statements) => ZirStatement::Assembly( + statements + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect(), + ), }; vec![res] } diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index aae839d3..249b2630 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -1,15 +1,18 @@ use crate::zir::types::MemberId; +use serde::{Deserialize, Serialize}; use std::fmt; use crate::typed::Identifier as CoreIdentifier; -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] pub enum Identifier<'ast> { + #[serde(borrow)] Source(SourceIdentifier<'ast>), } -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] pub enum SourceIdentifier<'ast> { + #[serde(borrow)] Basic(CoreIdentifier<'ast>), Select(Box>, u32), Member(Box>, MemberId), diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index d2cbf633..c1752b52 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -21,6 +21,7 @@ use zokrates_field::Field; pub use self::folder::Folder; pub use self::identifier::{Identifier, SourceIdentifier}; +use serde::{Deserialize, Serialize}; /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Eq, Debug, Clone)] @@ -34,11 +35,13 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { } } /// A typed function -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct ZirFunction<'ast, T> { /// Arguments of the function + #[serde(borrow)] pub arguments: Vec>, /// Vector of statements that are executed when running the function + #[serde(borrow)] pub statements: Vec>, /// function signature pub signature: Signature, @@ -88,7 +91,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirFunction<'ast, T> { pub type ZirAssignee<'ast> = Variable<'ast>; -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub enum RuntimeError { SourceAssertion(String), SelectRangeCheck, @@ -113,8 +116,70 @@ impl RuntimeError { } } +// #[derive(Clone, PartialEq, Hash, Eq, Debug)] +// pub struct ZirBlock<'ast, T> { +// pub statements: Vec>, +// pub value: FieldElementExpression<'ast, T>, +// } +// +// impl<'ast, T> ZirBlock<'ast, T> { +// pub fn new( +// statements: Vec>, +// value: FieldElementExpression<'ast, T>, +// ) -> Self { +// Self { statements, value } +// } +// } +// impl<'ast, T: fmt::Display> fmt::Display for ZirBlock<'ast, T> { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// write!( +// f, +// "{{\n{}\n}}", +// self.statements +// .iter() +// .map(|s| s.to_string()) +// .chain(std::iter::once(self.value.to_string())) +// .collect::>() +// .join("\n") +// ) +// } +// } + +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +pub enum ZirAssemblyStatement<'ast, T> { + Assignment( + #[serde(borrow)] Vec>, + ZirFunction<'ast, T>, + ), + Constraint( + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + ), +} + +impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => { + write!( + f, + "{} <-- {}", + lhs.iter() + .map(|a| a.id.to_string()) + .collect::>() + .join(", "), + rhs + ) + } + ZirAssemblyStatement::Constraint(ref lhs, ref rhs) => { + write!(f, "{} === {}", lhs, rhs) + } + } + } +} + /// A statement in a `ZirFunction` -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum ZirStatement<'ast, T> { Return(Vec>), Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>), @@ -129,6 +194,7 @@ pub enum ZirStatement<'ast, T> { FormatString, Vec<(ConcreteType, Vec>)>, ), + Assembly(#[serde(borrow)] Vec>), } impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { @@ -142,15 +208,19 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { write!(f, "{}", "\t".repeat(depth))?; match self { ZirStatement::Return(ref exprs) => { - write!( - f, - "return {};", - exprs - .iter() - .map(|e| e.to_string()) - .collect::>() - .join(", ") - ) + write!(f, "return")?; + if exprs.len() > 0 { + write!( + f, + " {}", + exprs + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, ";") } ZirStatement::Definition(ref lhs, ref rhs) => { write!(f, "{} = {};", lhs, rhs) @@ -166,7 +236,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { s.fmt_indented(f, depth + 1)?; writeln!(f)?; } - write!(f, "{}}};", "\t".repeat(depth)) + write!(f, "{}}}", "\t".repeat(depth)) } ZirStatement::Assertion(ref e, ref error) => { write!(f, "assert({}", e)?; @@ -200,6 +270,13 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { .collect::>() .join(", ") ), + ZirStatement::Assembly(statements) => { + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; + } + write!(f, "{}}}", "\t".repeat(depth)) + } } } } @@ -208,8 +285,9 @@ pub trait Typed { fn get_type(&self) -> Type; } -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct ConditionalExpression<'ast, T, E> { + #[serde(borrow)] pub condition: Box>, pub consequence: Box, pub alternative: Box, @@ -235,9 +313,10 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress } } -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct SelectExpression<'ast, T, E> { pub array: Vec, + #[serde(borrow)] pub index: Box>, } @@ -266,11 +345,11 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<' } /// A typed expression -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub enum ZirExpression<'ast, T> { Boolean(BooleanExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>), - Uint(UExpression<'ast, T>), + Uint(#[serde(borrow)] UExpression<'ast, T>), } impl<'ast, T: Field> From> for ZirExpression<'ast, T> { @@ -343,15 +422,20 @@ pub trait MultiTyped { fn get_types(&self) -> &Vec; } -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub enum ZirExpressionList<'ast, T> { - EmbedCall(FlatEmbed, Vec, Vec>), + EmbedCall( + FlatEmbed, + Vec, + #[serde(borrow)] Vec>, + ), } /// An expression of type `field` -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum FieldElementExpression<'ast, T> { Number(T), + #[serde(borrow)] Identifier(Identifier<'ast>), Select(SelectExpression<'ast, T, Self>), Add( @@ -372,15 +456,16 @@ pub enum FieldElementExpression<'ast, T> { ), Pow( Box>, - Box>, + #[serde(borrow)] Box>, ), Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>), } /// An expression of type `bool` -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum BooleanExpression<'ast, T> { Value(bool), + #[serde(borrow)] Identifier(Identifier<'ast>), Select(SelectExpression<'ast, T, Self>), FieldLt( diff --git a/zokrates_ast/src/zir/parameter.rs b/zokrates_ast/src/zir/parameter.rs index 08d26c93..203a291a 100644 --- a/zokrates_ast/src/zir/parameter.rs +++ b/zokrates_ast/src/zir/parameter.rs @@ -1,8 +1,10 @@ use crate::zir::Variable; +use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct Parameter<'ast> { + #[serde(borrow)] pub id: Variable<'ast>, pub private: bool, } diff --git a/zokrates_ast/src/zir/result_folder.rs b/zokrates_ast/src/zir/result_folder.rs index 6b172ff8..16f78f01 100644 --- a/zokrates_ast/src/zir/result_folder.rs +++ b/zokrates_ast/src/zir/result_folder.rs @@ -61,6 +61,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result, Self::Error> { + fold_assembly_statement(self, s) + } + fn fold_statement( &mut self, s: ZirStatement<'ast, T>, @@ -144,6 +151,26 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_uint_expression_inner(self, bitwidth, e) } } +pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Result, F::Error> { + Ok(match s { + ZirAssemblyStatement::Assignment(assignees, function) => { + let assignees = assignees + .into_iter() + .map(|a| f.fold_assignee(a)) + .collect::, _>>()?; + let function = f.fold_function(function)?; + ZirAssemblyStatement::Assignment(assignees, function) + } + ZirAssemblyStatement::Constraint(lhs, rhs) => { + let lhs = f.fold_field_expression(lhs)?; + let rhs = f.fold_field_expression(rhs)?; + ZirAssemblyStatement::Constraint(lhs, rhs) + } + }) +} pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, @@ -199,6 +226,13 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( ZirStatement::Log(l, e) } + ZirStatement::Assembly(statements) => { + let statements = statements + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()?; + ZirStatement::Assembly(statements) + } }; Ok(vec![res]) } diff --git a/zokrates_ast/src/zir/uint.rs b/zokrates_ast/src/zir/uint.rs index 5d09b428..12619b5c 100644 --- a/zokrates_ast/src/zir/uint.rs +++ b/zokrates_ast/src/zir/uint.rs @@ -1,5 +1,6 @@ use crate::zir::identifier::Identifier; use crate::zir::types::UBitwidth; +use serde::{Deserialize, Serialize}; use zokrates_field::Field; use super::{ConditionalExpression, SelectExpression}; @@ -91,7 +92,7 @@ impl<'ast, T> From for UExpression<'ast, T> { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub enum ShouldReduce { Unknown, True, @@ -135,7 +136,7 @@ impl ShouldReduce { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct UMetadata { pub max: T, pub should_reduce: ShouldReduce, @@ -162,16 +163,18 @@ impl UMetadata { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct UExpression<'ast, T> { pub bitwidth: UBitwidth, pub metadata: Option>, + #[serde(borrow)] pub inner: UExpressionInner<'ast, T>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum UExpressionInner<'ast, T> { Value(u128), + #[serde(borrow)] Identifier(Identifier<'ast>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), Add(Box>, Box>), diff --git a/zokrates_ast/src/zir/variable.rs b/zokrates_ast/src/zir/variable.rs index 91fda8c7..14d32972 100644 --- a/zokrates_ast/src/zir/variable.rs +++ b/zokrates_ast/src/zir/variable.rs @@ -1,9 +1,11 @@ use crate::zir::types::{Type, UBitwidth}; use crate::zir::Identifier; +use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct Variable<'ast> { + #[serde(borrow)] pub id: Identifier<'ast>, pub _type: Type, } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index c918ccf5..79d699ee 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -21,8 +21,8 @@ use zokrates_proof_systems::Scheme; const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications."; impl Backend for Bellman { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { @@ -84,8 +84,8 @@ impl Backend for Bellman { } impl NonUniversalBackend for Bellman { - fn setup>>( - program: ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, ) -> SetupKeypair { println!("{}", G16_WARNING); @@ -99,8 +99,8 @@ impl NonUniversalBackend for Bellman } impl MpcBackend for Bellman { - fn initialize>>( - program: ProgIterator, + fn initialize<'a, R: Read, W: Write, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, phase1_radix: &mut R, output: &mut W, ) -> Result<(), String> { @@ -124,9 +124,9 @@ impl MpcBackend for Bellman { Ok(hash) } - fn verify>>( + fn verify<'a, P: Read, R: Read, I: IntoIterator>>( params: &mut P, - program: ProgIterator, + program: ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String> { let params = diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 4bf39624..26bcf392 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -22,20 +22,20 @@ pub use self::parse::*; pub struct Bellman; #[derive(Clone)] -pub struct Computation>> { - program: ProgIterator, +pub struct Computation<'a, T, I: IntoIterator>> { + program: ProgIterator<'a, T, I>, witness: Option>, } -impl>> Computation { - pub fn with_witness(program: ProgIterator, witness: Witness) -> Self { +impl<'a, T: Field, I: IntoIterator>> Computation<'a, T, I> { + pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness) -> Self { Computation { program, witness: Some(witness), } } - pub fn without_witness(program: ProgIterator) -> Self { + pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self { Computation { program, witness: None, @@ -83,8 +83,8 @@ fn bellman_combination>> - Circuit for Computation +impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator>> + Circuit for Computation<'a, T, I> { fn synthesize>( self, @@ -148,7 +148,9 @@ impl>> } } -impl>> Computation { +impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator>> + Computation<'a, T, I> +{ fn get_random_seed(&self) -> Result<[u32; 8], getrandom::Error> { let mut seed = [0u8; 32]; getrandom::getrandom(&mut seed)?; diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index 865f3c9b..ab2a959b 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -85,8 +85,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } } -fn cli_compute>>( - ir_prog: ir::ProgIterator, +fn cli_compute<'a, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Computing witness..."); diff --git a/zokrates_cli/src/ops/generate_proof.rs b/zokrates_cli/src/ops/generate_proof.rs index 2a62042a..319cf561 100644 --- a/zokrates_cli/src/ops/generate_proof.rs +++ b/zokrates_cli/src/ops/generate_proof.rs @@ -136,12 +136,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_generate_proof< + 'a, T: Field, - I: Iterator>, + I: Iterator>, S: Scheme, B: Backend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Generating proof..."); diff --git a/zokrates_cli/src/ops/generate_smtlib2.rs b/zokrates_cli/src/ops/generate_smtlib2.rs index b1bf6f6a..ac58f8e5 100644 --- a/zokrates_cli/src/ops/generate_smtlib2.rs +++ b/zokrates_cli/src/ops/generate_smtlib2.rs @@ -47,8 +47,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } } -fn cli_smtlib2>>( - ir_prog: ir::ProgIterator, +fn cli_smtlib2<'a, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Generating SMTLib2..."); diff --git a/zokrates_cli/src/ops/inspect.rs b/zokrates_cli/src/ops/inspect.rs index 523d664a..338c01c4 100644 --- a/zokrates_cli/src/ops/inspect.rs +++ b/zokrates_cli/src/ops/inspect.rs @@ -43,8 +43,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } } -fn cli_inspect>>( - ir_prog: ir::ProgIterator, +fn cli_inspect<'a, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { let ir_prog: ir::Prog = ir_prog.collect(); diff --git a/zokrates_cli/src/ops/mpc/init.rs b/zokrates_cli/src/ops/mpc/init.rs index 71136366..eb7ba16e 100644 --- a/zokrates_cli/src/ops/mpc/init.rs +++ b/zokrates_cli/src/ops/mpc/init.rs @@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_mpc_init< + 'a, T: Field + BellmanFieldExtensions, - I: Iterator>, + I: Iterator>, S: MpcScheme, B: MpcBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Initializing MPC..."); diff --git a/zokrates_cli/src/ops/mpc/verify.rs b/zokrates_cli/src/ops/mpc/verify.rs index 6beca03d..fa014bd0 100644 --- a/zokrates_cli/src/ops/mpc/verify.rs +++ b/zokrates_cli/src/ops/mpc/verify.rs @@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_mpc_verify< + 'a, T: Field + BellmanFieldExtensions, - I: Iterator>, + I: Iterator>, S: MpcScheme, B: MpcBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Verifying contributions..."); diff --git a/zokrates_cli/src/ops/setup.rs b/zokrates_cli/src/ops/setup.rs index e9e8a216..0fc56953 100644 --- a/zokrates_cli/src/ops/setup.rs +++ b/zokrates_cli/src/ops/setup.rs @@ -167,12 +167,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_setup_non_universal< + 'a, T: Field, - I: Iterator>, + I: Iterator>, S: NonUniversalScheme, B: NonUniversalBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Performing setup..."); @@ -211,12 +212,13 @@ fn cli_setup_non_universal< } fn cli_setup_universal< + 'a, T: Field, - I: Iterator>, + I: Iterator>, S: UniversalScheme, B: UniversalBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, srs: Vec, sub_matches: &ArgMatches, ) -> Result<(), String> { diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 68f35560..a50b57b3 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -25,13 +25,13 @@ use zokrates_field::Field; use zokrates_pest_ast as pest; #[derive(Debug)] -pub struct CompilationArtifacts>> { - prog: ir::ProgIterator, +pub struct CompilationArtifacts<'ast, T, I: IntoIterator>> { + prog: ir::ProgIterator<'ast, T, I>, abi: Abi, } -impl>> CompilationArtifacts { - pub fn prog(self) -> ir::ProgIterator { +impl<'ast, T, I: IntoIterator>> CompilationArtifacts<'ast, T, I> { + pub fn prog(self) -> ir::ProgIterator<'ast, T, I> { self.prog } @@ -39,11 +39,11 @@ impl>> CompilationArtifacts { &self.abi } - pub fn into_inner(self) -> (ir::ProgIterator, Abi) { + pub fn into_inner(self) -> (ir::ProgIterator<'ast, T, I>, Abi) { (self.prog, self.abi) } - pub fn collect(self) -> CompilationArtifacts>> { + pub fn collect(self) -> CompilationArtifacts<'ast, T, Vec>> { CompilationArtifacts { prog: self.prog.collect(), abi: self.abi, @@ -201,8 +201,10 @@ pub fn compile<'ast, T: Field, E: Into>( resolver: Option<&dyn Resolver>, config: CompileConfig, arena: &'ast Arena, -) -> Result> + 'ast>, CompileErrors> -{ +) -> Result< + CompilationArtifacts<'ast, T, impl IntoIterator> + 'ast>, + CompileErrors, +> { let (typed_ast, abi): (zokrates_ast::zir::ZirProgram<'_, T>, _) = check_with_arena(source, location, resolver, &config, arena)?; diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index cd2e455c..8168aee6 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -9,7 +9,8 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirExpressionList, + ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, + ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -32,7 +33,7 @@ use zokrates_ast::zir::{ }; use zokrates_field::Field; -type FlatStatements = VecDeque>; +type FlatStatements<'ast, T> = VecDeque>; /// Flattens a function /// @@ -64,14 +65,14 @@ pub fn from_function_and_config( pub struct FlattenerIteratorInner<'ast, T> { pub statements: VecDeque>, - pub statements_flattened: FlatStatements, + pub statements_flattened: FlatStatements<'ast, T>, pub flattener: Flattener<'ast, T>, } -pub type FlattenerIterator<'ast, T> = FlatProgIterator>; +pub type FlattenerIterator<'ast, T> = FlatProgIterator<'ast, T, FlattenerIteratorInner<'ast, T>>; impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> { - type Item = FlatStatement; + type Item = FlatStatement<'ast, T>; fn next(&mut self) -> Option { while self.statements_flattened.is_empty() { @@ -127,7 +128,7 @@ trait Flatten<'ast, T: Field>: fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output; } @@ -137,7 +138,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_field_expression(statements_flattened, self) } @@ -149,7 +150,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_uint_expression(statements_flattened, self) } @@ -161,7 +162,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_boolean_expression(statements_flattened, self) } @@ -227,7 +228,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn define( &mut self, e: FlatExpression, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Variable { match e { FlatExpression::Identifier(id) => id, @@ -276,7 +277,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { #[must_use] fn constant_le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, a: &[FlatExpression], b: &[bool], ) -> Vec> { @@ -381,7 +382,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise fn eq_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, left: FlatExpression, right: FlatExpression, ) -> FlatExpression { @@ -434,7 +435,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `b` - the big-endian bit decomposition of the upper bound of the range fn enforce_constant_le_check_bits( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, a: &[FlatExpression], c: &[bool], error: RuntimeError, @@ -464,7 +465,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `c` - the constant upper bound of the range fn enforce_constant_le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, error: RuntimeError, @@ -500,7 +501,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `c` - the constant upper bound of the range fn enforce_constant_lt_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, error: RuntimeError, @@ -519,9 +520,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn make_conditional( &mut self, - statements: FlatStatements, + statements: FlatStatements<'ast, T>, condition: FlatExpression, - ) -> FlatStatements { + ) -> FlatStatements<'ast, T> { statements .into_iter() .flat_map(|s| match s { @@ -582,7 +583,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * U is the type of the expression fn flatten_conditional_expression>( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: ConditionalExpression<'ast, T, U>, ) -> FlatUExpression { let condition = *e.condition; @@ -680,7 +681,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise fn constant_lt_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, ) -> FlatExpression { @@ -704,7 +705,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise fn constant_field_le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, ) -> FlatExpression { @@ -745,7 +746,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { #[must_use] fn le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, lhs_flattened: FlatExpression, rhs_flattened: FlatExpression, bit_width: usize, @@ -768,7 +769,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { #[must_use] fn lt_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, lhs_flattened: FlatExpression, rhs_flattened: FlatExpression, bit_width: usize, @@ -827,7 +828,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * in order to preserve composability. fn flatten_boolean_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { match expression { @@ -1033,7 +1034,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `param_expressions` - Arguments of this call fn flatten_embed_call( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, embed: FlatEmbed, generics: Vec, param_expressions: Vec>, @@ -1134,9 +1135,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn flatten_embed_call_aux( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, params: Vec>, - funct: FlatFunctionIterator>>, + funct: FlatFunctionIterator<'ast, T, impl IntoIterator>>, ) -> Vec> { let mut replacement_map = HashMap::new(); @@ -1219,7 +1220,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `expr` - `ZirExpression` that will be flattened. fn flatten_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expr: ZirExpression<'ast, T>, ) -> FlatUExpression { match expr { @@ -1235,7 +1236,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn default_xor( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, left: UExpression<'ast, T>, right: UExpression<'ast, T>, ) -> FlatUExpression { @@ -1296,7 +1297,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn euclidean_division( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, target_bitwidth: UBitwidth, left: UExpression<'ast, T>, right: UExpression<'ast, T>, @@ -1382,7 +1383,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `expr` - `UExpression` that will be flattened. fn flatten_uint_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expr: UExpression<'ast, T>, ) -> FlatUExpression { // the bitwidth for this type of uint (8, 16 or 32) @@ -1875,7 +1876,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { e: &FlatUExpression, from: usize, to: usize, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, error: RuntimeError, ) -> Vec> { assert!(from <= T::get_required_bits()); @@ -1969,7 +1970,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn flatten_select_expression>( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: SelectExpression<'ast, T, U>, ) -> FlatUExpression { let array = e.array; @@ -2033,7 +2034,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `expr` - `FieldElementExpression` that will be flattened. fn flatten_field_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expr: FieldElementExpression<'ast, T>, ) -> FlatExpression { match expr { @@ -2221,6 +2222,35 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } + fn flatten_assembly_statement( + &mut self, + statements_flattened: &mut FlatStatements<'ast, T>, + stat: ZirAssemblyStatement<'ast, T>, + ) { + match stat { + ZirAssemblyStatement::Assignment(assignees, function) => { + let outputs: Vec = assignees + .iter() + .map(|a| self.use_variable(a)) /*self.layout.get(&a.id).cloned().unwrap()*/ + .collect(); + let inputs: Vec> = function + .arguments + .iter() + .cloned() + .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) + .collect(); + let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); + statements_flattened.push_back(FlatStatement::Directive(directive)); + } + ZirAssemblyStatement::Constraint(lhs, rhs) => { + let lhs = self.flatten_field_expression(statements_flattened, lhs); + let rhs = self.flatten_field_expression(statements_flattened, rhs); + + self.flatten_equality_assertion(statements_flattened, lhs, rhs, RuntimeError::UnsatisfiedConstraint) + } + } + } + /// Flattens a statement /// /// # Arguments @@ -2229,10 +2259,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `stat` - `ZirStatement` that will be flattened. fn flatten_statement( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, stat: ZirStatement<'ast, T>, ) { match stat { + ZirStatement::Assembly(statements) => { + for s in statements { + self.flatten_assembly_statement(statements_flattened, s); + } + } ZirStatement::Return(exprs) => { #[allow(clippy::needless_collect)] // clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator, @@ -2633,12 +2668,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `statements_flattened` - `FlatStatements` Vector where new flattened statements can be added. + /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. /// * `lhs` - `FlatExpression` Left-hand side of the equality expression. /// * `rhs` - `FlatExpression` Right-hand side of the equality expression. fn flatten_equality_assertion( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, lhs: FlatExpression, rhs: FlatExpression, error: RuntimeError, @@ -2667,11 +2702,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// # Arguments /// /// * `e` - `FlatExpression` Expression to be assigned to an identifier. - /// * `statements_flattened` - `FlatStatements` Vector where new flattened statements can be added. + /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. fn identify_expression( &mut self, e: FlatExpression, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> FlatExpression { match e.is_linear() { true => e, @@ -2710,7 +2745,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn use_parameter( &mut self, parameter: &ZirParameter<'ast>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Parameter { let variable = self.use_variable(¶meter.id); diff --git a/zokrates_core/src/optimizer/canonicalizer.rs b/zokrates_core/src/optimizer/canonicalizer.rs index 4a65bc85..57810aed 100644 --- a/zokrates_core/src/optimizer/canonicalizer.rs +++ b/zokrates_core/src/optimizer/canonicalizer.rs @@ -4,7 +4,7 @@ use zokrates_field::Field; #[derive(Default)] pub struct Canonicalizer; -impl Folder for Canonicalizer { +impl<'ast, T: Field> Folder<'ast, T> for Canonicalizer { fn fold_linear_combination(&mut self, l: LinComb) -> LinComb { l.into_canonical().into() } diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index afabc87b..b216bef6 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -15,18 +15,18 @@ use zokrates_ast::ir::*; use zokrates_field::Field; #[derive(Debug, Default)] -pub struct DirectiveOptimizer { - calls: HashMap<(Solver, Vec>), Vec>, +pub struct DirectiveOptimizer<'ast, T> { + calls: HashMap<(Solver<'ast, T>, Vec>), Vec>, /// Map of renamings for reassigned variables while processing the program. substitution: HashMap, } -impl Folder for DirectiveOptimizer { +impl<'ast, T: Field> Folder<'ast, T> for DirectiveOptimizer<'ast, T> { fn fold_variable(&mut self, v: Variable) -> Variable { *self.substitution.get(&v).unwrap_or(&v) } - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Directive(d) => { let d = self.fold_directive(d); diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 68c95409..664cfc2d 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -21,8 +21,8 @@ pub struct DuplicateOptimizer { seen: HashSet, } -impl Folder for DuplicateOptimizer { - fn fold_program(&mut self, p: Prog) -> Prog { +impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { + fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> { // in order to correctly identify duplicates, we need to first canonicalize the statements let mut canonicalizer = Canonicalizer; @@ -38,7 +38,7 @@ impl Folder for DuplicateOptimizer { fold_program(self, p) } - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { let hashed = hash(&s); let result = match self.seen.get(&hashed) { Some(_) => vec![], diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index cecee2e3..1f94740a 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -19,9 +19,9 @@ use self::tautology::TautologyOptimizer; use zokrates_ast::ir::{ProgIterator, Statement}; use zokrates_field::Field; -pub fn optimize>>( - p: ProgIterator, -) -> ProgIterator>> { +pub fn optimize<'ast, T: Field, I: IntoIterator>>( + p: ProgIterator<'ast, T, I>, +) -> ProgIterator<'ast, T, impl IntoIterator>> { // remove redefinitions log::debug!("Optimizer: Remove redefinitions and tautologies and directives and duplicates"); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index e853be11..8f0e2447 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -53,7 +53,9 @@ pub struct RedefinitionOptimizer { } impl RedefinitionOptimizer { - pub fn init>>(p: &ProgIterator) -> Self { + pub fn init<'ast, I: IntoIterator>>( + p: &ProgIterator<'ast, T, I>, + ) -> Self { RedefinitionOptimizer { substitution: HashMap::new(), ignore: vec![Variable::one()] @@ -66,8 +68,8 @@ impl RedefinitionOptimizer { } } -impl Folder for RedefinitionOptimizer { - fn fold_statement(&mut self, s: Statement) -> Vec> { +impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Constraint(quad, lin, message) => { let quad = self.fold_quadratic_combination(quad); diff --git a/zokrates_core/src/optimizer/tautology.rs b/zokrates_core/src/optimizer/tautology.rs index 4a9ce847..855efa11 100644 --- a/zokrates_core/src/optimizer/tautology.rs +++ b/zokrates_core/src/optimizer/tautology.rs @@ -13,8 +13,8 @@ use zokrates_field::Field; #[derive(Default)] pub struct TautologyOptimizer; -impl Folder for TautologyOptimizer { - fn fold_statement(&mut self, s: Statement) -> Vec> { +impl<'ast, T: Field> Folder<'ast, T> for TautologyOptimizer { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Constraint(quad, lin, message) => match quad.try_linear() { Ok(l) => { diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 8e27f27f..baba3ab2 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -699,7 +699,7 @@ impl<'ast, T: Field> Checker<'ast, T> { is_mutable: false, }; assert_eq!(self.scope.level, 0); - assert!(!self.scope.insert(id, info)); + assert!(!self.scope.insert(id.into(), info)); assert!(state .constants .entry(module_id.to_path_buf()) @@ -895,7 +895,7 @@ impl<'ast, T: Field> Checker<'ast, T> { is_mutable: false, }; assert_eq!(self.scope.level, 0); - assert!(!self.scope.insert(id, info)); + assert!(!self.scope.insert(id.into(), info)); state .constants @@ -1130,12 +1130,12 @@ impl<'ast, T: Field> Checker<'ast, T> { // for declaration signatures, generics cannot be ignored generics.0.insert( generic.clone(), - UExpressionInner::Identifier(self.id_in_this_scope(generic.name()).into()) + UExpressionInner::Identifier(self.id_in_this_scope(generic.name().into()).into()) .annotate(UBitwidth::B32), ); //we don't have to check for conflicts here, because this was done when checking the signature - self.insert_into_scope(generic.name(), Type::Uint(UBitwidth::B32), false); + self.insert_into_scope(generic.name().into(), Type::Uint(UBitwidth::B32), false); } for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) { @@ -1144,7 +1144,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let arg = arg.value; let decl_v = DeclarationVariable::new( - self.id_in_this_scope(arg.id.value.id), + self.id_in_this_scope(arg.id.value.id.into()), decl_ty.clone(), arg.id.value.is_mutable, ); @@ -1161,7 +1161,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ty, is_mutable, }; - match self.scope.insert(id, info) { + match self.scope.insert(id.into(), info) { false => {} true => { errors.push(ErrorInner { @@ -1651,10 +1651,10 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e])?; // insert into the scope and ignore whether shadowing happened - self.insert_into_scope(v.value.id, ty.clone(), v.value.is_mutable); + self.insert_into_scope(v.value.id.into(), ty.clone(), v.value.is_mutable); Ok(Variable::new( - self.id_in_this_scope(v.value.id), + self.id_in_this_scope(v.value.id.into()), ty, v.value.is_mutable, )) @@ -1770,6 +1770,87 @@ impl<'ast, T: Field> Checker<'ast, T> { } } + fn check_assembly_statement( + &mut self, + stat: AssemblyStatementNode<'ast>, + module_id: &ModuleId, + types: &TypeMap<'ast, T>, + ) -> Result>, ErrorInner> { + let pos = stat.pos(); + + match stat.value { + AssemblyStatement::Assignment(assignee, expression, constrained) => { + let assignee = self.check_assignee(assignee, module_id, types)?; + let checked_e = self.check_expression(expression, module_id, types)?; + + let e = match checked_e { + TypedExpression::FieldElement(e) => Ok(e), + TypedExpression::Int(e) => Ok(FieldElementExpression::try_from_int(e).unwrap()), // todo: handle properly + _ => Err(ErrorInner { + pos: Some(pos), + message: "Only field element expressions are allowed in the assembly" + .to_string(), + }), + }?; + + let e = FieldElementExpression::block(vec![], e); + match constrained { + true => { + if !e.is_quadratic() { + return Err(ErrorInner { + pos: Some(pos), + message: "Non quadratic constraints are not allowed".to_string(), + }); + } + match assignee.get_type() { + Type::FieldElement => Ok(vec![ + TypedAssemblyStatement::Assignment(assignee.clone(), e.clone()), + TypedAssemblyStatement::Constraint(assignee.into(), e), + ]), + _ => Err(ErrorInner { + pos: Some(pos), + message: "Assignee must be of type `field`".to_string(), + }), + } + } + false => Ok(vec![TypedAssemblyStatement::Assignment(assignee, e)]), + } + } + AssemblyStatement::Constraint(lhs, rhs) => { + let lhs = self.check_expression(lhs, module_id, types)?; + let rhs = self.check_expression(rhs, module_id, types)?; + match (lhs, rhs) { + (TypedExpression::FieldElement(lhs), TypedExpression::FieldElement(rhs)) => { + Ok(vec![TypedAssemblyStatement::Constraint(lhs, rhs)]) + } + (TypedExpression::FieldElement(lhs), TypedExpression::Int(rhs)) => { + Ok(vec![TypedAssemblyStatement::Constraint( + lhs, + FieldElementExpression::try_from_int(rhs).unwrap(), + )]) + } + (TypedExpression::Int(lhs), TypedExpression::FieldElement(rhs)) => { + Ok(vec![TypedAssemblyStatement::Constraint( + FieldElementExpression::try_from_int(lhs).unwrap(), + rhs, + )]) + } + (TypedExpression::Int(lhs), TypedExpression::Int(rhs)) => { + Ok(vec![TypedAssemblyStatement::Constraint( + FieldElementExpression::try_from_int(lhs).unwrap(), + FieldElementExpression::try_from_int(rhs).unwrap(), + )]) + } + _ => Err(ErrorInner { + pos: Some(pos), + message: "Only field element expressions are allowed in the assembly" + .to_string(), + }), + } + } + } + } + fn check_statement( &mut self, stat: StatementNode<'ast>, @@ -1779,6 +1860,18 @@ impl<'ast, T: Field> Checker<'ast, T> { let pos = stat.pos(); match stat.value { + Statement::Assembly(statements) => { + let mut checked_statements = vec![]; + for s in statements { + checked_statements.push( + self.check_assembly_statement(s, module_id, types) + .map_err(|e| vec![e])?, + ); + } + Ok(TypedStatement::Assembly( + checked_statements.into_iter().flatten().collect(), + )) + } Statement::Log(l, expressions) => { let l = FormatString::from(l); @@ -1901,10 +1994,10 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e])?; // insert the lhs into the scope and ignore whether shadowing happened - self.insert_into_scope(var.value.id, var_ty.clone(), var.value.is_mutable); + self.insert_into_scope(var.value.id.into(), var_ty.clone(), var.value.is_mutable); let var = Variable::new( - self.id_in_this_scope(var.value.id), + self.id_in_this_scope(var.value.id.into()), var_ty.clone(), var.value.is_mutable, ); @@ -2037,7 +2130,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let pos = assignee.pos(); // check that the assignee is declared match assignee.value { - Assignee::Identifier(variable_name) => match self.scope.get(&variable_name) { + Assignee::Identifier(variable_name) => match self.scope.get(&variable_name.into()) { Some(info) => match info.is_mutable { false => Err(ErrorInner { pos: Some(assignee.pos()), @@ -2346,7 +2439,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), Expression::Identifier(name) => { // check that `id` is defined in the scope - match self.scope.get(&name) { + match self.scope.get(&name.into()) { Some(info) => { let id = info.id; match info.ty.clone() { @@ -3615,11 +3708,11 @@ impl<'ast, T: Field> Checker<'ast, T> { is_mutable: bool, ) -> bool { let info = IdentifierInfo { - id: self.id_in_this_scope(id), + id: self.id_in_this_scope(id.clone()), ty, is_mutable, }; - self.scope.insert(id, info) + self.scope.insert(id.into(), info) } fn find_functions( diff --git a/zokrates_core/src/static_analysis/flat_propagation.rs b/zokrates_core/src/static_analysis/flat_propagation.rs index f69b9313..155d803c 100644 --- a/zokrates_core/src/static_analysis/flat_propagation.rs +++ b/zokrates_core/src/static_analysis/flat_propagation.rs @@ -14,8 +14,8 @@ struct Propagator { constants: HashMap, } -impl Folder for Propagator { - fn fold_statement(&mut self, s: FlatStatement) -> Vec> { +impl<'ast, T: Field> Folder<'ast, T> for Propagator { + fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec> { match s { FlatStatement::Definition(var, expr) => match self.fold_expression(expr) { FlatExpression::Number(n) => { diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index c4683e4e..2964ff06 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -1,7 +1,8 @@ +use std::collections::HashSet; use std::marker::PhantomData; use zokrates_ast::typed::types::UBitwidth; use zokrates_ast::typed::{self, Expr, Typed}; -use zokrates_ast::zir::{self, Select}; +use zokrates_ast::zir::{self, Folder, Select}; use zokrates_field::Field; use std::convert::{TryFrom, TryInto}; @@ -224,6 +225,14 @@ impl<'ast, T: Field> Flattener { } } + fn fold_assembly_statement( + &mut self, + statements_buffer: &mut Vec>, + s: typed::TypedAssemblyStatement<'ast, T>, + ) -> zir::ZirAssemblyStatement<'ast, T> { + fold_assembly_statement(self, statements_buffer, s) + } + fn fold_statement( &mut self, statements_buffer: &mut Vec>, @@ -393,12 +402,102 @@ impl<'ast, T: Field> Flattener { } } +#[derive(Default)] +pub struct ArgumentFinder<'ast, T> { + pub identifiers: HashSet>, + _phantom: PhantomData, +} + +impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> { + fn fold_name(&mut self, n: zir::Identifier<'ast>) -> zir::Identifier<'ast> { + self.identifiers.insert(n.clone()); + n + } + fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec> { + match s { + zir::ZirStatement::Definition(assignee, expr) => { + let assignee = self.fold_assignee(assignee); + let expr = self.fold_expression(expr); + self.identifiers.remove(&assignee.id); + vec![zir::ZirStatement::Definition(assignee, expr)] + } + zir::ZirStatement::MultipleDefinition(assignees, list) => { + let assignees: Vec> = assignees + .into_iter() + .map(|v| self.fold_assignee(v)) + .collect(); + let list = self.fold_expression_list(list); + for a in &assignees { + self.identifiers.remove(&a.id); + } + vec![zir::ZirStatement::MultipleDefinition(assignees, list)] + } + s => zir::folder::fold_statement(self, s), + } + } +} + +fn fold_assembly_statement<'ast, T: Field>( + f: &mut Flattener, + statements_buffer: &mut Vec>, + s: typed::TypedAssemblyStatement<'ast, T>, +) -> zir::ZirAssemblyStatement<'ast, T> { + match s { + typed::TypedAssemblyStatement::Assignment(a, e) => { + let mut statements_buffer: Vec> = vec![]; + let a = f.fold_assignee(a); + let e = f.fold_field_expression(&mut statements_buffer, e); + statements_buffer.push(zir::ZirStatement::Return(vec![ + zir::ZirExpression::FieldElement(e), + ])); + + let mut finder = ArgumentFinder::default(); + let mut statements_buffer: Vec> = statements_buffer + .into_iter() + .rev() + .map(|s| finder.fold_statement(s)) + .flatten() + .collect(); + statements_buffer.reverse(); + + let function = zir::ZirFunction { + signature: zir::types::Signature::default() + .inputs(vec![zir::Type::FieldElement; finder.identifiers.len()]) + .outputs(a.iter().map(|a| a.get_type()).collect()), + arguments: finder + .identifiers + .into_iter() + .map(|id| zir::Parameter { + id: zir::Variable::field_element(id), + private: false, + }) + .collect(), + statements: statements_buffer, + }; + + zir::ZirAssemblyStatement::Assignment(a, function) + } + typed::TypedAssemblyStatement::Constraint(lhs, rhs) => { + let lhs = f.fold_field_expression(statements_buffer, lhs); + let rhs = f.fold_field_expression(statements_buffer, rhs); + zir::ZirAssemblyStatement::Constraint(lhs, rhs) + } + } +} + fn fold_statement<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, s: typed::TypedStatement<'ast, T>, ) { let res = match s { + typed::TypedStatement::Assembly(statements) => { + let statements = statements + .into_iter() + .map(|s| f.fold_assembly_statement(statements_buffer, s)) + .collect(); + vec![zir::ZirStatement::Assembly(statements)] + } typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return( f.fold_expression(statements_buffer, expression), )], diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 311eec21..f2b31ea9 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -220,6 +220,38 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { s: TypedStatement<'ast, T>, ) -> Result>, Error> { match s { + TypedStatement::Assembly(statements) => { + let mut assembly_statement_buffer = vec![]; + let mut statement_buffer = vec![]; + + for s in statements { + match self.fold_assembly_statement(s)? { + TypedAssemblyStatement::Assignment(assignee, expr) => { + // invalidate the cache + let v = self + .try_get_constant_mut(&assignee) + .map(|(v, _)| v) + .unwrap_or_else(|v| v); + + match self.constants.remove(&v.id) { + Some(c) => { + statement_buffer.push(TypedStatement::Definition( + v.clone().into(), + c.into(), + )); + } + None => {} + } + assembly_statement_buffer + .push(TypedAssemblyStatement::Assignment(assignee, expr)); + } + s => assembly_statement_buffer.push(s), + } + } + + statement_buffer.push(TypedStatement::Assembly(assembly_statement_buffer)); + Ok(statement_buffer) + } // propagation to the defined variable if rhs is a constant TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { let assignee = self.fold_assignee(assignee)?; diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 72bbdf1f..d573f8a0 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -3,7 +3,7 @@ use std::fmt; use zokrates_ast::zir::types::UBitwidth; use zokrates_ast::zir::{ result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, - SelectExpression, SelectOrExpression, + SelectExpression, SelectOrExpression, ZirAssemblyStatement, }; use zokrates_ast::zir::{ BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression, @@ -42,6 +42,9 @@ pub struct ZirPropagator<'ast, T> { } impl<'ast, T: Field> ZirPropagator<'ast, T> { + pub fn with_constants(constants: Constants<'ast, T>) -> Self { + Self { constants } + } pub fn propagate(p: ZirProgram) -> Result, Error> { ZirPropagator::default().fold_program(p) } @@ -50,6 +53,24 @@ impl<'ast, T: Field> ZirPropagator<'ast, T> { impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { type Error = Error; + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result, Self::Error> { + match s { + ZirAssemblyStatement::Assignment(assignees, function) => { + for a in &assignees { + self.constants.remove(&a.id); + } + Ok(ZirAssemblyStatement::Assignment( + assignees, + self.fold_function(function)?, + )) + } + s => fold_assembly_statement(self, s), + } + } + fn fold_statement( &mut self, s: ZirStatement<'ast, T>, diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 3776d84f..f24cad3b 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -24,21 +24,22 @@ impl Interpreter { } impl Interpreter { - pub fn execute>>( + pub fn execute<'ast, T: Field, I: IntoIterator>>( &self, - program: ProgIterator, + program: ProgIterator<'ast, T, I>, inputs: &[T], ) -> ExecutionResult { self.execute_with_log_stream(program, inputs, &mut std::io::sink()) } pub fn execute_with_log_stream< + 'ast, W: std::io::Write, T: Field, - I: IntoIterator>, + I: IntoIterator>, >( &self, - program: ProgIterator, + program: ProgIterator<'ast, T, I>, inputs: &[T], log_stream: &mut W, ) -> ExecutionResult { @@ -142,9 +143,9 @@ impl Interpreter { .collect() } - fn check_inputs>, U>( + fn check_inputs<'ast, T: Field, I: IntoIterator>, U>( &self, - program: &ProgIterator, + program: &ProgIterator<'ast, T, I>, inputs: &[U], ) -> Result<(), Error> { if program.arguments.len() == inputs.len() { @@ -157,11 +158,18 @@ impl Interpreter { } } - pub fn execute_solver(solver: &Solver, inputs: &[T]) -> Result, String> { + pub fn execute_solver<'ast, T: Field>( + solver: &Solver<'ast, T>, + inputs: &[T], + ) -> Result, String> { let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); let res = match solver { + Solver::Zir(func) => { + // TODO: implement evaluation of the function + vec![inputs[1].checked_div(&inputs[0]).unwrap()] + } Solver::ConditionEq => match inputs[0].is_zero() { true => vec![T::zero(), T::one()], false => vec![ diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index cf24252e..e2c6715d 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -349,13 +349,14 @@ mod internal { } pub fn setup_universal< + 'a, T: Field, - I: IntoIterator>, + I: IntoIterator>, S: UniversalScheme + Serialize, B: UniversalBackend, >( srs: &[u8], - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, ) -> Result { let keypair = B::setup(srs.to_vec(), program).map_err(|e| JsValue::from_str(&e))?; Ok(JsValue::from_serde(&TaggedKeypair::::new(keypair)).unwrap()) diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 67ab2ec2..d0319d13 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -52,7 +52,7 @@ _mut = {"mut"} // Statements -statement = { (iteration_statement // does not require semicolon +statement = { (iteration_statement | asm_statement // does not require semicolon | ((log_statement |return_statement | definition_statement @@ -66,6 +66,15 @@ return_statement = { "return" ~ expression? } definition_statement = { typed_identifier_or_assignee ~ "=" ~ expression } assertion_statement = {"assert" ~ "(" ~ expression ~ ("," ~ quoted_string)? ~ ")"} +op_asm_assign = @{"<--"} +op_asm_assign_constrain = @{"<=="} + +asm_assignment = { assignee ~ (op_asm_assign | op_asm_assign_constrain) ~ expression } +asm_constraint = { expression ~ "===" ~ expression } + +asm_statement_inner = { (asm_assignment | asm_constraint) ~ semicolon ~ NEWLINE* } +asm_statement = { "asm" ~ "{" ~ NEWLINE* ~ asm_statement_inner* ~ NEWLINE* ~ "}" } + typed_identifier_or_assignee = { typed_identifier | assignee } // Expressions diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 0564e13b..7e241886 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -8,11 +8,12 @@ use zokrates_parser::Rule; extern crate lazy_static; pub use ast::{ - Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, - Assignee, AssigneeAccess, BasicOrStructOrTupleType, BasicType, BinaryExpression, - BinaryOperator, CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, - DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, - File, FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, + Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssemblyStatement, + AssemblyStatementInner, AssertionStatement, Assignee, AssigneeAccess, AssignmentOperator, + BasicOrStructOrTupleType, BasicType, BinaryExpression, BinaryOperator, CallAccess, + ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, + DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, + FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, IdentifierExpression, IdentifierOrDecimal, IfElseExpression, ImportDirective, ImportSymbol, InlineArrayExpression, InlineStructExpression, InlineStructMember, InlineTupleExpression, IterationStatement, LiteralExpression, LogStatement, Parameter, PostfixExpression, Range, @@ -366,6 +367,7 @@ mod ast { Assertion(AssertionStatement<'ast>), Iteration(IterationStatement<'ast>), Log(LogStatement<'ast>), + Assembly(AssemblyStatement<'ast>), } #[derive(Debug, FromPest, PartialEq, Clone)] @@ -423,6 +425,72 @@ mod ast { pub span: Span<'ast>, } + // #[derive(Debug, FromPest, PartialEq, Eq, Clone)] + // #[pest_ast(rule(Rule::op_asm_assign))] + // pub struct AssemblyAssignOperator; + // + // #[derive(Debug, FromPest, PartialEq, Eq, Clone)] + // #[pest_ast(rule(Rule::op_asm_assign_constrain))] + // pub struct AssemblyAssignConstrainOperator; + // + // #[derive(Debug, FromPest, PartialEq, Eq, Clone)] + // #[pest_ast(rule(Rule::op_asm_constrain))] + // pub struct AssemblyConstrainOperator; + + #[derive(Debug, PartialEq, Clone)] + pub enum AssignmentOperator { + Assign, + AssignConstrain, + } + + impl<'ast> FromPest<'ast> for AssignmentOperator { + type Rule = Rule; + type FatalError = Void; + + fn from_pest(pest: &mut Pairs<'ast, Rule>) -> Result> { + let pair = pest.next().ok_or(::from_pest::ConversionError::NoMatch)?; + match pair.as_rule() { + Rule::op_asm_assign => Ok(AssignmentOperator::Assign), + Rule::op_asm_assign_constrain => Ok(AssignmentOperator::AssignConstrain), + _ => Err(ConversionError::NoMatch), + } + } + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_assignment))] + pub struct AssemblyAssignment<'ast> { + pub assignee: Assignee<'ast>, + pub operator: AssignmentOperator, + pub expression: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_constraint))] + pub struct AssemblyConstraint<'ast> { + pub lhs: Expression<'ast>, + pub rhs: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_statement_inner))] + pub enum AssemblyStatementInner<'ast> { + Assignment(AssemblyAssignment<'ast>), + Constraint(AssemblyConstraint<'ast>), + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_statement))] + pub struct AssemblyStatement<'ast> { + pub inner: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, PartialEq, Eq, Clone)] pub enum BinaryOperator { BitXor, diff --git a/zokrates_proof_systems/src/lib.rs b/zokrates_proof_systems/src/lib.rs index 231fbeee..4184f6c9 100644 --- a/zokrates_proof_systems/src/lib.rs +++ b/zokrates_proof_systems/src/lib.rs @@ -96,8 +96,8 @@ impl ToString for G2AffineFq2 { } pub trait Backend> { - fn generate_proof>>( - program: ir::ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ir::ProgIterator<'a, T, I>, witness: ir::Witness, proving_key: Vec, ) -> Proof; @@ -105,36 +105,36 @@ pub trait Backend> { fn verify(vk: S::VerificationKey, proof: Proof) -> bool; } pub trait NonUniversalBackend>: Backend { - fn setup>>( - program: ir::ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ir::ProgIterator<'a, T, I>, ) -> SetupKeypair; } pub trait UniversalBackend>: Backend { fn universal_setup(size: u32) -> Vec; - fn setup>>( + fn setup<'a, I: IntoIterator>>( srs: Vec, - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, ) -> Result, String>; } pub trait MpcBackend> { - fn initialize>>( - program: ir::ProgIterator, + fn initialize<'a, R: Read, W: Write, I: IntoIterator>>( + program: ir::ProgIterator<'a, T, I>, phase1_radix: &mut R, output: &mut W, ) -> Result<(), String>; - fn contribute( + fn contribute<'a, R: Read, W: Write, G: Rng>( params: &mut R, rng: &mut G, output: &mut W, ) -> Result<[u8; 64], String>; - fn verify>>( + fn verify<'a, P: Read, R: Read, I: IntoIterator>>( params: &mut P, - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String>;