1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
dark64 2022-10-05 15:27:45 +02:00
parent 90fe8494b0
commit c869649745
57 changed files with 1112 additions and 303 deletions

View file

@ -16,8 +16,8 @@ use zokrates_proof_systems::Scheme;
use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair}; use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair};
impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, GM17> for Ark { impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, GM17> for Ark {
fn setup<I: IntoIterator<Item = Statement<T>>>( fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
) -> SetupKeypair<T, GM17> { ) -> SetupKeypair<T, GM17> {
let computation = Computation::without_witness(program); let computation = Computation::without_witness(program);
@ -41,8 +41,8 @@ impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, GM17> for Ark {
} }
impl<T: Field + ArkFieldExtensions> Backend<T, GM17> for Ark { impl<T: Field + ArkFieldExtensions> Backend<T, GM17> for Ark {
fn generate_proof<I: IntoIterator<Item = Statement<T>>>( fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
witness: Witness<T>, witness: Witness<T>,
proving_key: Vec<u8>, proving_key: Vec<u8>,
) -> Proof<T, GM17> { ) -> Proof<T, GM17> {

View file

@ -19,8 +19,8 @@ use zokrates_proof_systems::Scheme;
const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications."; const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications.";
impl<T: Field + ArkFieldExtensions> Backend<T, G16> for Ark { impl<T: Field + ArkFieldExtensions> Backend<T, G16> for Ark {
fn generate_proof<I: IntoIterator<Item = Statement<T>>>( fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
witness: Witness<T>, witness: Witness<T>,
proving_key: Vec<u8>, proving_key: Vec<u8>,
) -> Proof<T, G16> { ) -> Proof<T, G16> {
@ -86,8 +86,8 @@ impl<T: Field + ArkFieldExtensions> Backend<T, G16> for Ark {
} }
impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, G16> for Ark { impl<T: Field + ArkFieldExtensions> NonUniversalBackend<T, G16> for Ark {
fn setup<I: IntoIterator<Item = Statement<T>>>( fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
) -> SetupKeypair<T, G16> { ) -> SetupKeypair<T, G16> {
println!("{}", G16_WARNING); println!("{}", G16_WARNING);

View file

@ -17,20 +17,20 @@ pub use self::parse::*;
pub struct Ark; pub struct Ark;
#[derive(Clone)] #[derive(Clone)]
pub struct Computation<T, I: IntoIterator<Item = Statement<T>>> { pub struct Computation<'a, T, I: IntoIterator<Item = Statement<'a, T>>> {
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
witness: Option<Witness<T>>, witness: Option<Witness<T>>,
} }
impl<T, I: IntoIterator<Item = Statement<T>>> Computation<T, I> { impl<'a, T, I: IntoIterator<Item = Statement<'a, T>>> Computation<'a, T, I> {
pub fn with_witness(program: ProgIterator<T, I>, witness: Witness<T>) -> Self { pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness<T>) -> Self {
Computation { Computation {
program, program,
witness: Some(witness), witness: Some(witness),
} }
} }
pub fn without_witness(program: ProgIterator<T, I>) -> Self { pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self {
Computation { Computation {
program, program,
witness: None, witness: None,
@ -72,9 +72,9 @@ fn ark_combination<T: Field + ArkFieldExtensions>(
.fold(LinearCombination::zero(), |acc, e| acc + e) .fold(LinearCombination::zero(), |acc, e| acc + e)
} }
impl<T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<T>>> impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<'a, T>>>
ConstraintSynthesizer<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr> ConstraintSynthesizer<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>
for Computation<T, I> for Computation<'a, T, I>
{ {
fn generate_constraints( fn generate_constraints(
self, self,
@ -143,7 +143,9 @@ impl<T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<T>>>
} }
} }
impl<T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<T>>> Computation<T, I> { impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator<Item = Statement<'a, T>>>
Computation<'a, T, I>
{
pub fn public_inputs_values(&self) -> Vec<<T::ArkEngine as PairingEngine>::Fr> { pub fn public_inputs_values(&self) -> Vec<<T::ArkEngine as PairingEngine>::Fr> {
self.program self.program
.public_inputs_values(self.witness.as_ref().unwrap()) .public_inputs_values(self.witness.as_ref().unwrap())

View file

@ -134,9 +134,9 @@ impl<T: Field + ArkFieldExtensions> UniversalBackend<T, marlin::Marlin> for Ark
res res
} }
fn setup<I: IntoIterator<Item = Statement<T>>>( fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>>(
srs: Vec<u8>, srs: Vec<u8>,
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
) -> Result<SetupKeypair<T, marlin::Marlin>, String> { ) -> Result<SetupKeypair<T, marlin::Marlin>, String> {
let program = program.collect(); let program = program.collect();
@ -210,8 +210,8 @@ impl<T: Field + ArkFieldExtensions> UniversalBackend<T, marlin::Marlin> for Ark
} }
impl<T: Field + ArkFieldExtensions> Backend<T, marlin::Marlin> for Ark { impl<T: Field + ArkFieldExtensions> Backend<T, marlin::Marlin> for Ark {
fn generate_proof<I: IntoIterator<Item = Statement<T>>>( fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
witness: Witness<T>, witness: Witness<T>,
proving_key: Vec<u8>, proving_key: Vec<u8>,
) -> Proof<T, marlin::Marlin> { ) -> Proof<T, marlin::Marlin> {

View file

@ -9,6 +9,7 @@ use crate::untyped::{
types::{UnresolvedSignature, UnresolvedType}, types::{UnresolvedSignature, UnresolvedType},
ConstantGenericNode, Expression, ConstantGenericNode, Expression,
}; };
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use zokrates_field::Field; use zokrates_field::Field;
@ -28,7 +29,7 @@ cfg_if::cfg_if! {
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
/// the flattening step when it can be inlined. /// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)]
pub enum FlatEmbed { pub enum FlatEmbed {
BitArrayLe, BitArrayLe,
Unpack, Unpack,
@ -317,8 +318,8 @@ impl FlatEmbed {
/// - constraint system variables /// - constraint system variables
/// - arguments /// - arguments
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
pub fn sha256_round<T: Field>( pub fn sha256_round<'ast, T: Field>(
) -> FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>> { ) -> FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>> {
use zokrates_field::Bn128Field; use zokrates_field::Bn128Field;
assert_eq!(T::id(), Bn128Field::id()); assert_eq!(T::id(), Bn128Field::id());
@ -420,9 +421,9 @@ pub fn sha256_round<T: Field>(
} }
#[cfg(feature = "ark")] #[cfg(feature = "ark")]
pub fn snark_verify_bls12_377<T: Field>( pub fn snark_verify_bls12_377<'ast, T: Field>(
n: usize, n: usize,
) -> FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>> { ) -> FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>> {
use zokrates_field::Bw6_761Field; use zokrates_field::Bw6_761Field;
assert_eq!(T::id(), Bw6_761Field::id()); assert_eq!(T::id(), Bw6_761Field::id());
@ -546,9 +547,9 @@ fn use_variable(
/// # Remarks /// # Remarks
/// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()` /// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()`
/// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` /// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
pub fn unpack_to_bitwidth<T: Field>( pub fn unpack_to_bitwidth<'ast, T: Field>(
bit_width: usize, bit_width: usize,
) -> FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>> { ) -> FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>> {
let mut counter = 0; let mut counter = 0;
let mut layout = HashMap::new(); let mut layout = HashMap::new();

View file

@ -3,6 +3,7 @@ use std::fmt;
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub enum RuntimeError { pub enum RuntimeError {
UnsatisfiedConstraint,
BellmanConstraint, BellmanConstraint,
BellmanOneBinding, BellmanOneBinding,
BellmanInputBinding, BellmanInputBinding,
@ -63,6 +64,7 @@ impl fmt::Display for RuntimeError {
use RuntimeError::*; use RuntimeError::*;
let msg = match self { let msg = match self {
UnsatisfiedConstraint => "Constraint is unsatisfied",
BellmanConstraint => "Bellman constraint is unsatisfied", BellmanConstraint => "Bellman constraint is unsatisfied",
BellmanOneBinding => "Bellman ~one binding is unsatisfied", BellmanOneBinding => "Bellman ~one binding is unsatisfied",
BellmanInputBinding => "Bellman input binding is unsatisfied", BellmanInputBinding => "Bellman input binding is unsatisfied",

View file

@ -1,8 +1,9 @@
use crate::zir::ZirFunction;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Solver { pub enum Solver<'ast, T> {
ConditionEq, ConditionEq,
Bits(usize), Bits(usize),
Div, Div,
@ -11,19 +12,24 @@ pub enum Solver {
ShaAndXorAndXorAnd, ShaAndXorAndXorAnd,
ShaCh, ShaCh,
EuclideanDiv, EuclideanDiv,
#[serde(borrow)]
Zir(ZirFunction<'ast, T>),
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
Sha256Round, Sha256Round,
#[cfg(feature = "ark")] #[cfg(feature = "ark")]
SnarkVerifyBls12377(usize), SnarkVerifyBls12377(usize),
} }
impl fmt::Display for Solver { impl<'ast, T: fmt::Debug + fmt::Display> fmt::Display for Solver<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self) match self {
Solver::Zir(func) => write!(f, "Zir({})", func),
_ => write!(f, "{:?}", self)
}
} }
} }
impl Solver { impl<'ast, T> Solver<'ast, T> {
pub fn get_signature(&self) -> (usize, usize) { pub fn get_signature(&self) -> (usize, usize) {
match self { match self {
Solver::ConditionEq => (1, 2), Solver::ConditionEq => (1, 2),
@ -34,6 +40,7 @@ impl Solver {
Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaAndXorAndXorAnd => (3, 1),
Solver::ShaCh => (3, 1), Solver::ShaCh => (3, 1),
Solver::EuclideanDiv => (2, 2), Solver::EuclideanDiv => (2, 2),
Solver::Zir(f) => (f.arguments.len(), 1),
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
Solver::Sha256Round => (768, 26935), Solver::Sha256Round => (768, 26935),
#[cfg(feature = "ark")] #[cfg(feature = "ark")]
@ -42,7 +49,7 @@ impl Solver {
} }
} }
impl Solver { impl<'ast, T> Solver<'ast, T> {
pub fn bits(width: usize) -> Self { pub fn bits(width: usize) -> Self {
Solver::Bits(width) Solver::Bits(width)
} }

View file

@ -4,8 +4,8 @@ use super::*;
use crate::common::Variable; use crate::common::Variable;
use zokrates_field::Field; use zokrates_field::Field;
pub trait Folder<T: Field>: Sized { pub trait Folder<'ast, T: Field>: Sized {
fn fold_program(&mut self, p: FlatProg<T>) -> FlatProg<T> { fn fold_program(&mut self, p: FlatProg<'ast, T>) -> FlatProg<'ast, T> {
fold_program(self, p) fold_program(self, p)
} }
@ -17,7 +17,7 @@ pub trait Folder<T: Field>: Sized {
fold_variable(self, v) fold_variable(self, v)
} }
fn fold_statement(&mut self, s: FlatStatement<T>) -> Vec<FlatStatement<T>> { fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec<FlatStatement<'ast, T>> {
fold_statement(self, s) fold_statement(self, s)
} }
@ -25,12 +25,15 @@ pub trait Folder<T: Field>: Sized {
fold_expression(self, e) fold_expression(self, e)
} }
fn fold_directive(&mut self, d: FlatDirective<T>) -> FlatDirective<T> { fn fold_directive(&mut self, d: FlatDirective<'ast, T>) -> FlatDirective<'ast, T> {
fold_directive(self, d) fold_directive(self, d)
} }
} }
pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: FlatProg<T>) -> FlatProg<T> { pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
p: FlatProg<'ast, T>,
) -> FlatProg<'ast, T> {
FlatProg { FlatProg {
arguments: p arguments: p
.arguments .arguments
@ -46,10 +49,10 @@ pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: FlatProg<T>) -> FlatPr
} }
} }
pub fn fold_statement<T: Field, F: Folder<T>>( pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F, f: &mut F,
s: FlatStatement<T>, s: FlatStatement<'ast, T>,
) -> Vec<FlatStatement<T>> { ) -> Vec<FlatStatement<'ast, T>> {
match s { match s {
FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition( FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition(
f.fold_expression(left), f.fold_expression(left),
@ -70,7 +73,7 @@ pub fn fold_statement<T: Field, F: Folder<T>>(
} }
} }
pub fn fold_expression<T: Field, F: Folder<T>>( pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F, f: &mut F,
e: FlatExpression<T>, e: FlatExpression<T>,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
@ -89,7 +92,10 @@ pub fn fold_expression<T: Field, F: Folder<T>>(
} }
} }
pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: FlatDirective<T>) -> FlatDirective<T> { pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
ds: FlatDirective<'ast, T>,
) -> FlatDirective<'ast, T> {
FlatDirective { FlatDirective {
inputs: ds inputs: ds
.inputs .inputs
@ -101,13 +107,13 @@ pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: FlatDirective<T>) -
} }
} }
pub fn fold_argument<T: Field, F: Folder<T>>(f: &mut F, a: Parameter) -> Parameter { pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter {
Parameter { Parameter {
id: f.fold_variable(a.id), id: f.fold_variable(a.id),
private: a.private, private: a.private,
} }
} }
pub fn fold_variable<T: Field, F: Folder<T>>(_f: &mut F, v: Variable) -> Variable { pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable {
v v
} }

View file

@ -24,14 +24,14 @@ use std::collections::HashMap;
use std::fmt; use std::fmt;
use zokrates_field::Field; use zokrates_field::Field;
pub type FlatProg<T> = FlatFunction<T>; pub type FlatProg<'ast, T> = FlatFunction<'ast, T>;
pub type FlatFunction<T> = FlatFunctionIterator<T, Vec<FlatStatement<T>>>; pub type FlatFunction<'ast, T> = FlatFunctionIterator<'ast, T, Vec<FlatStatement<'ast, T>>>;
pub type FlatProgIterator<T, I> = FlatFunctionIterator<T, I>; pub type FlatProgIterator<'ast, T, I> = FlatFunctionIterator<'ast, T, I>;
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub struct FlatFunctionIterator<T, I: IntoIterator<Item = FlatStatement<T>>> { pub struct FlatFunctionIterator<'ast, T, I: IntoIterator<Item = FlatStatement<'ast, T>>> {
/// Arguments of the function /// Arguments of the function
pub arguments: Vec<Parameter>, pub arguments: Vec<Parameter>,
/// Vector of statements that are executed when running the function /// Vector of statements that are executed when running the function
@ -40,8 +40,8 @@ pub struct FlatFunctionIterator<T, I: IntoIterator<Item = FlatStatement<T>>> {
pub return_count: usize, pub return_count: usize,
} }
impl<T, I: IntoIterator<Item = FlatStatement<T>>> FlatFunctionIterator<T, I> { impl<'ast, T, I: IntoIterator<Item = FlatStatement<'ast, T>>> FlatFunctionIterator<'ast, T, I> {
pub fn collect(self) -> FlatFunction<T> { pub fn collect(self) -> FlatFunction<'ast, T> {
FlatFunction { FlatFunction {
statements: self.statements.into_iter().collect(), statements: self.statements.into_iter().collect(),
arguments: self.arguments, arguments: self.arguments,
@ -50,7 +50,7 @@ impl<T, I: IntoIterator<Item = FlatStatement<T>>> FlatFunctionIterator<T, I> {
} }
} }
impl<T: Field> fmt::Display for FlatFunction<T> { impl<'ast, T: Field> fmt::Display for FlatFunction<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
f, f,
@ -81,14 +81,14 @@ impl<T: Field> fmt::Display for FlatFunction<T> {
/// * r1cs - R1CS in standard JSON data format /// * r1cs - R1CS in standard JSON data format
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub enum FlatStatement<T> { pub enum FlatStatement<'ast, T> {
Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError), Condition(FlatExpression<T>, FlatExpression<T>, RuntimeError),
Definition(Variable, FlatExpression<T>), Definition(Variable, FlatExpression<T>),
Directive(FlatDirective<T>), Directive(FlatDirective<'ast, T>),
Log(FormatString, Vec<(ConcreteType, Vec<FlatExpression<T>>)>), Log(FormatString, Vec<(ConcreteType, Vec<FlatExpression<T>>)>),
} }
impl<T: Field> fmt::Display for FlatStatement<T> { impl<'ast, T: Field> fmt::Display for FlatStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
@ -116,10 +116,10 @@ impl<T: Field> fmt::Display for FlatStatement<T> {
} }
} }
impl<T: Field> FlatStatement<T> { impl<'ast, T: Field> FlatStatement<'ast, T> {
pub fn apply_substitution( pub fn apply_substitution(
self, self,
substitution: &HashMap<Variable, Variable>, substitution: &'ast HashMap<Variable, Variable>,
) -> FlatStatement<T> { ) -> FlatStatement<T> {
match self { match self {
FlatStatement::Definition(id, x) => FlatStatement::Definition( FlatStatement::Definition(id, x) => FlatStatement::Definition(
@ -167,16 +167,16 @@ impl<T: Field> FlatStatement<T> {
} }
#[derive(Clone, Hash, Debug, PartialEq, Eq)] #[derive(Clone, Hash, Debug, PartialEq, Eq)]
pub struct FlatDirective<T> { pub struct FlatDirective<'ast, T> {
pub inputs: Vec<FlatExpression<T>>, pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<Variable>, pub outputs: Vec<Variable>,
pub solver: Solver, pub solver: Solver<'ast, T>,
} }
impl<T> FlatDirective<T> { impl<'ast, T> FlatDirective<'ast, T> {
pub fn new<E: Into<FlatExpression<T>>>( pub fn new<E: Into<FlatExpression<T>>>(
outputs: Vec<Variable>, outputs: Vec<Variable>,
solver: Solver, solver: Solver<'ast, T>,
inputs: Vec<E>, inputs: Vec<E>,
) -> Self { ) -> Self {
let (in_len, out_len) = solver.get_signature(); let (in_len, out_len) = solver.get_signature();
@ -190,7 +190,7 @@ impl<T> FlatDirective<T> {
} }
} }
impl<T: Field> fmt::Display for FlatDirective<T> { impl<'ast, T: Field> fmt::Display for FlatDirective<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
f, f,

View file

@ -13,7 +13,9 @@ pub struct UnconstrainedVariableDetector {
} }
impl UnconstrainedVariableDetector { impl UnconstrainedVariableDetector {
pub fn new<T: Field, I: IntoIterator<Item = Statement<T>>>(p: &ProgIterator<T, I>) -> Self { pub fn new<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
p: &ProgIterator<'ast, T, I>,
) -> Self {
UnconstrainedVariableDetector { UnconstrainedVariableDetector {
variables: p variables: p
.arguments .arguments
@ -32,7 +34,7 @@ impl UnconstrainedVariableDetector {
} }
} }
impl<T: Field> Folder<T> for UnconstrainedVariableDetector { impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector {
fn fold_argument(&mut self, p: Parameter) -> Parameter { fn fold_argument(&mut self, p: Parameter) -> Parameter {
p p
} }
@ -40,7 +42,7 @@ impl<T: Field> Folder<T> for UnconstrainedVariableDetector {
self.variables.remove(&v); self.variables.remove(&v);
v v
} }
fn fold_directive(&mut self, d: Directive<T>) -> Directive<T> { fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> {
self.variables.extend(d.outputs.iter()); self.variables.extend(d.outputs.iter());
d d
} }

View file

@ -4,8 +4,8 @@ use super::*;
use crate::common::Variable; use crate::common::Variable;
use zokrates_field::Field; use zokrates_field::Field;
pub trait Folder<T: Field>: Sized { pub trait Folder<'ast, T: Field>: Sized {
fn fold_program(&mut self, p: Prog<T>) -> Prog<T> { fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> {
fold_program(self, p) fold_program(self, p)
} }
@ -17,7 +17,7 @@ pub trait Folder<T: Field>: Sized {
fold_variable(self, v) fold_variable(self, v)
} }
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
fold_statement(self, s) fold_statement(self, s)
} }
@ -29,12 +29,15 @@ pub trait Folder<T: Field>: Sized {
fold_quadratic_combination(self, es) fold_quadratic_combination(self, es)
} }
fn fold_directive(&mut self, d: Directive<T>) -> Directive<T> { fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> {
fold_directive(self, d) fold_directive(self, d)
} }
} }
pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> { pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
p: Prog<'ast, T>,
) -> Prog<'ast, T> {
Prog { Prog {
arguments: p arguments: p
.arguments .arguments
@ -50,7 +53,10 @@ pub fn fold_program<T: Field, F: Folder<T>>(f: &mut F, p: Prog<T>) -> Prog<T> {
} }
} }
pub fn fold_statement<T: Field, F: Folder<T>>(f: &mut F, s: Statement<T>) -> Vec<Statement<T>> { pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: Statement<'ast, T>,
) -> Vec<Statement<'ast, T>> {
match s { match s {
Statement::Constraint(quad, lin, message) => vec![Statement::Constraint( Statement::Constraint(quad, lin, message) => vec![Statement::Constraint(
f.fold_quadratic_combination(quad), f.fold_quadratic_combination(quad),
@ -74,7 +80,10 @@ pub fn fold_statement<T: Field, F: Folder<T>>(f: &mut F, s: Statement<T>) -> Vec
} }
} }
pub fn fold_linear_combination<T: Field, F: Folder<T>>(f: &mut F, e: LinComb<T>) -> LinComb<T> { pub fn fold_linear_combination<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: LinComb<T>,
) -> LinComb<T> {
LinComb( LinComb(
e.0.into_iter() e.0.into_iter()
.map(|(variable, coefficient)| (f.fold_variable(variable), coefficient)) .map(|(variable, coefficient)| (f.fold_variable(variable), coefficient))
@ -82,7 +91,7 @@ pub fn fold_linear_combination<T: Field, F: Folder<T>>(f: &mut F, e: LinComb<T>)
) )
} }
pub fn fold_quadratic_combination<T: Field, F: Folder<T>>( pub fn fold_quadratic_combination<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F, f: &mut F,
e: QuadComb<T>, e: QuadComb<T>,
) -> QuadComb<T> { ) -> QuadComb<T> {
@ -92,7 +101,10 @@ pub fn fold_quadratic_combination<T: Field, F: Folder<T>>(
} }
} }
pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: Directive<T>) -> Directive<T> { pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
ds: Directive<'ast, T>,
) -> Directive<'ast, T> {
Directive { Directive {
inputs: ds inputs: ds
.inputs .inputs
@ -104,13 +116,13 @@ pub fn fold_directive<T: Field, F: Folder<T>>(f: &mut F, ds: Directive<T>) -> Di
} }
} }
pub fn fold_argument<T: Field, F: Folder<T>>(f: &mut F, a: Parameter) -> Parameter { pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter {
Parameter { Parameter {
id: f.fold_variable(a.id), id: f.fold_variable(a.id),
private: a.private, private: a.private,
} }
} }
pub fn fold_variable<T: Field, F: Folder<T>>(_f: &mut F, v: Variable) -> Variable { pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable {
v v
} }

View file

@ -17,9 +17,9 @@ impl<T: Field> QuadComb<T> {
} }
} }
pub fn from_flat<T: Field, I: IntoIterator<Item = FlatStatement<T>>>( pub fn from_flat<'ast, T: Field, I: IntoIterator<Item = FlatStatement<'ast, T>>>(
flat_prog_iterator: FlatProgIterator<T, I>, flat_prog_iterator: FlatProgIterator<'ast, T, I>,
) -> ProgIterator<T, impl IntoIterator<Item = Statement<T>>> { ) -> ProgIterator<T, impl IntoIterator<Item = Statement<'ast, T>>> {
ProgIterator { ProgIterator {
statements: flat_prog_iterator.statements.into_iter().map(Into::into), statements: flat_prog_iterator.statements.into_iter().map(Into::into),
arguments: flat_prog_iterator.arguments, arguments: flat_prog_iterator.arguments,
@ -52,8 +52,8 @@ impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
} }
} }
impl<T: Field> From<FlatStatement<T>> for Statement<T> { impl<'ast, T: Field> From<FlatStatement<'ast, T>> for Statement<'ast, T> {
fn from(flat_statement: FlatStatement<T>) -> Statement<T> { fn from(flat_statement: FlatStatement<'ast, T>) -> Statement<'ast, T> {
match flat_statement { match flat_statement {
FlatStatement::Condition(linear, quadratic, message) => match quadratic { FlatStatement::Condition(linear, quadratic, message) => match quadratic {
FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint(
@ -83,8 +83,8 @@ impl<T: Field> From<FlatStatement<T>> for Statement<T> {
} }
} }
impl<T: Field> From<FlatDirective<T>> for Directive<T> { impl<'ast, T: Field> From<FlatDirective<'ast, T>> for Directive<'ast, T> {
fn from(ds: FlatDirective<T>) -> Directive<T> { fn from(ds: FlatDirective<'ast, T>) -> Directive<T> {
Directive { Directive {
inputs: ds inputs: ds
.inputs .inputs

View file

@ -26,15 +26,16 @@ pub use crate::common::Variable;
pub use self::witness::Witness; pub use self::witness::Witness;
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
pub enum Statement<T> { pub enum Statement<'ast, T> {
Constraint(QuadComb<T>, LinComb<T>, Option<RuntimeError>), Constraint(QuadComb<T>, LinComb<T>, Option<RuntimeError>),
Directive(Directive<T>), #[serde(borrow)]
Directive(Directive<'ast, T>),
Log(FormatString, Vec<(ConcreteType, Vec<LinComb<T>>)>), Log(FormatString, Vec<(ConcreteType, Vec<LinComb<T>>)>),
} }
pub type PublicInputs = BTreeSet<Variable>; pub type PublicInputs = BTreeSet<Variable>;
impl<T: Field> Statement<T> { impl<'ast, T: Field> Statement<'ast, T> {
pub fn definition<U: Into<QuadComb<T>>>(v: Variable, e: U) -> Self { pub fn definition<U: Into<QuadComb<T>>>(v: Variable, e: U) -> Self {
Statement::Constraint(e.into(), v.into(), None) Statement::Constraint(e.into(), v.into(), None)
} }
@ -45,13 +46,14 @@ impl<T: Field> Statement<T> {
} }
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct Directive<T> { pub struct Directive<'ast, T> {
pub inputs: Vec<QuadComb<T>>, pub inputs: Vec<QuadComb<T>>,
pub outputs: Vec<Variable>, pub outputs: Vec<Variable>,
pub solver: Solver, #[serde(borrow)]
pub solver: Solver<'ast, T>,
} }
impl<T: Field> fmt::Display for Directive<T> { impl<'ast, T: Field> fmt::Display for Directive<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
f, f,
@ -71,7 +73,7 @@ impl<T: Field> fmt::Display for Directive<T> {
} }
} }
impl<T: Field> fmt::Display for Statement<T> { impl<'ast, T: Field> fmt::Display for Statement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
Statement::Constraint(ref quad, ref lin, _) => write!(f, "{} == {}", quad, lin), Statement::Constraint(ref quad, ref lin, _) => write!(f, "{} == {}", quad, lin),
@ -96,16 +98,16 @@ impl<T: Field> fmt::Display for Statement<T> {
} }
} }
pub type Prog<T> = ProgIterator<T, Vec<Statement<T>>>; pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec<Statement<'ast, T>>>;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct ProgIterator<T, I: IntoIterator<Item = Statement<T>>> { pub struct ProgIterator<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> {
pub arguments: Vec<Parameter>, pub arguments: Vec<Parameter>,
pub return_count: usize, pub return_count: usize,
pub statements: I, pub statements: I,
} }
impl<T, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> { impl<'ast, T, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
pub fn new(arguments: Vec<Parameter>, statements: I, return_count: usize) -> Self { pub fn new(arguments: Vec<Parameter>, statements: I, return_count: usize) -> Self {
Self { Self {
arguments, arguments,
@ -114,7 +116,7 @@ impl<T, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
} }
} }
pub fn collect(self) -> ProgIterator<T, Vec<Statement<T>>> { pub fn collect(self) -> ProgIterator<'ast, T, Vec<Statement<'ast, T>>> {
ProgIterator { ProgIterator {
statements: self.statements.into_iter().collect::<Vec<_>>(), statements: self.statements.into_iter().collect::<Vec<_>>(),
arguments: self.arguments, arguments: self.arguments,
@ -139,7 +141,7 @@ impl<T, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
} }
} }
impl<T: Field, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> { impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
pub fn public_inputs_values(&self, witness: &Witness<T>) -> Vec<T> { pub fn public_inputs_values(&self, witness: &Witness<T>) -> Vec<T> {
self.arguments self.arguments
.iter() .iter()
@ -150,7 +152,7 @@ impl<T: Field, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> {
} }
} }
impl<T> Prog<T> { impl<'ast, T> Prog<'ast, T> {
pub fn constraint_count(&self) -> usize { pub fn constraint_count(&self) -> usize {
self.statements self.statements
.iter() .iter()
@ -158,7 +160,9 @@ impl<T> Prog<T> {
.count() .count()
} }
pub fn into_prog_iter(self) -> ProgIterator<T, impl IntoIterator<Item = Statement<T>>> { pub fn into_prog_iter(
self,
) -> ProgIterator<'ast, T, impl IntoIterator<Item = Statement<'ast, T>>> {
ProgIterator { ProgIterator {
statements: self.statements.into_iter(), statements: self.statements.into_iter(),
arguments: self.arguments, arguments: self.arguments,
@ -167,7 +171,7 @@ impl<T> Prog<T> {
} }
} }
impl<T: Field> fmt::Display for Prog<T> { impl<'ast, T: Field> fmt::Display for Prog<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let returns = (0..self.return_count) let returns = (0..self.return_count)
.map(Variable::public) .map(Variable::public)

View file

@ -12,32 +12,35 @@ const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2];
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub enum ProgEnum< pub enum ProgEnum<
Bls12_381I: IntoIterator<Item = Statement<Bls12_381Field>>, 'ast,
Bn128I: IntoIterator<Item = Statement<Bn128Field>>, Bls12_381I: IntoIterator<Item = Statement<'ast, Bls12_381Field>>,
Bls12_377I: IntoIterator<Item = Statement<Bls12_377Field>>, Bn128I: IntoIterator<Item = Statement<'ast, Bn128Field>>,
Bw6_761I: IntoIterator<Item = Statement<Bw6_761Field>>, Bls12_377I: IntoIterator<Item = Statement<'ast, Bls12_377Field>>,
Bw6_761I: IntoIterator<Item = Statement<'ast, Bw6_761Field>>,
> { > {
Bls12_381Program(ProgIterator<Bls12_381Field, Bls12_381I>), Bls12_381Program(ProgIterator<'ast, Bls12_381Field, Bls12_381I>),
Bn128Program(ProgIterator<Bn128Field, Bn128I>), Bn128Program(ProgIterator<'ast, Bn128Field, Bn128I>),
Bls12_377Program(ProgIterator<Bls12_377Field, Bls12_377I>), Bls12_377Program(ProgIterator<'ast, Bls12_377Field, Bls12_377I>),
Bw6_761Program(ProgIterator<Bw6_761Field, Bw6_761I>), Bw6_761Program(ProgIterator<'ast, Bw6_761Field, Bw6_761I>),
} }
type MemoryProgEnum = ProgEnum< type MemoryProgEnum<'ast> = ProgEnum<
Vec<Statement<Bls12_381Field>>, 'ast,
Vec<Statement<Bn128Field>>, Vec<Statement<'ast, Bls12_381Field>>,
Vec<Statement<Bls12_377Field>>, Vec<Statement<'ast, Bn128Field>>,
Vec<Statement<Bw6_761Field>>, Vec<Statement<'ast, Bls12_377Field>>,
Vec<Statement<'ast, Bw6_761Field>>,
>; >;
impl< impl<
Bls12_381I: IntoIterator<Item = Statement<Bls12_381Field>>, 'ast,
Bn128I: IntoIterator<Item = Statement<Bn128Field>>, Bls12_381I: IntoIterator<Item = Statement<'ast, Bls12_381Field>>,
Bls12_377I: IntoIterator<Item = Statement<Bls12_377Field>>, Bn128I: IntoIterator<Item = Statement<'ast, Bn128Field>>,
Bw6_761I: IntoIterator<Item = Statement<Bw6_761Field>>, Bls12_377I: IntoIterator<Item = Statement<'ast, Bls12_377Field>>,
> ProgEnum<Bls12_381I, Bn128I, Bls12_377I, Bw6_761I> Bw6_761I: IntoIterator<Item = Statement<'ast, Bw6_761Field>>,
> ProgEnum<'ast, Bls12_381I, Bn128I, Bls12_377I, Bw6_761I>
{ {
pub fn collect(self) -> MemoryProgEnum { pub fn collect(self) -> MemoryProgEnum<'ast> {
match self { match self {
ProgEnum::Bls12_381Program(p) => ProgEnum::Bls12_381Program(p.collect()), ProgEnum::Bls12_381Program(p) => ProgEnum::Bls12_381Program(p.collect()),
ProgEnum::Bn128Program(p) => ProgEnum::Bn128Program(p.collect()), ProgEnum::Bn128Program(p) => ProgEnum::Bn128Program(p.collect()),
@ -55,7 +58,7 @@ impl<
} }
} }
impl<T: Field, I: IntoIterator<Item = Statement<T>>> ProgIterator<T, I> { impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'ast, T, I> {
/// serialize a program iterator, returning the number of constraints serialized /// serialize a program iterator, returning the number of constraints serialized
/// Note that we only return constraints, not other statements such as directives /// Note that we only return constraints, not other statements such as directives
pub fn serialize<W: Write>(self, mut w: W) -> Result<usize, DynamicError> { pub fn serialize<W: Write>(self, mut w: W) -> Result<usize, DynamicError> {
@ -106,10 +109,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator
impl<'de, R: Read> impl<'de, R: Read>
ProgEnum< ProgEnum<
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bls12_381Field>>, 'de,
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bn128Field>>, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_381Field>>,
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bls12_377Field>>, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bn128Field>>,
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<Bw6_761Field>>, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bls12_377Field>>,
UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead<R>, Statement<'de, Bw6_761Field>>,
> >
{ {
pub fn deserialize(mut r: R) -> Result<Self, String> { pub fn deserialize(mut r: R) -> Result<Self, String> {

View file

@ -12,9 +12,9 @@ pub trait SMTLib2 {
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result; fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result;
} }
pub struct SMTLib2Display<'a, T>(pub &'a Prog<T>); pub struct SMTLib2Display<'a, 'ast, T>(pub &'a Prog<'ast, T>);
impl<T: Field> fmt::Display for SMTLib2Display<'_, T> { impl<'ast, T: Field> fmt::Display for SMTLib2Display<'_, 'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.to_smtlib2(f) self.0.to_smtlib2(f)
} }
@ -30,7 +30,7 @@ impl<T: Field> Visitor<T> for VariableCollector {
} }
} }
impl<T: Field> SMTLib2 for Prog<T> { impl<'ast, T: Field> SMTLib2 for Prog<'ast, T> {
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut collector = VariableCollector { let mut collector = VariableCollector {
variables: BTreeSet::<Variable>::new(), variables: BTreeSet::<Variable>::new(),
@ -75,7 +75,7 @@ fn format_prefix_op_smtlib2<T: SMTLib2, Ts: SMTLib2>(
write!(f, ")") write!(f, ")")
} }
impl<T: Field> SMTLib2 for Statement<T> { impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> {
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
Statement::Constraint(ref quad, ref lin, _) => { Statement::Constraint(ref quad, ref lin, _) => {
@ -91,7 +91,7 @@ impl<T: Field> SMTLib2 for Statement<T> {
} }
} }
impl<T: Field> SMTLib2 for Directive<T> { impl<'ast, T: Field> SMTLib2 for Directive<'ast, T> {
fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "") write!(f, "")
} }

View file

@ -260,6 +260,13 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_assignee(self, a) fold_assignee(self, a)
} }
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> TypedAssemblyStatement<'ast, T> {
fold_assembly_statement(self, s)
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> { fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
fold_statement(self, s) fold_statement(self, s)
} }
@ -505,6 +512,21 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>(
} }
} }
pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedAssemblyStatement<'ast, T>,
) -> TypedAssemblyStatement<'ast, T> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
TypedAssemblyStatement::Assignment(f.fold_assignee(a), f.fold_field_expression(e))
}
TypedAssemblyStatement::Constraint(lhs, rhs) => TypedAssemblyStatement::Constraint(
f.fold_field_expression(lhs),
f.fold_field_expression(rhs),
),
}
}
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F, f: &mut F,
s: TypedStatement<'ast, T>, s: TypedStatement<'ast, T>,
@ -529,6 +551,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
TypedStatement::Log(s, e) => { TypedStatement::Log(s, e) => {
TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect()) TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect())
} }
TypedStatement::Assembly(statements) => TypedStatement::Assembly(
statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.collect(),
),
s => s, s => s,
}; };
vec![res] vec![res]

View file

@ -1,10 +1,12 @@
use crate::typed::CanonicalConstantIdentifier; use crate::typed::CanonicalConstantIdentifier;
use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
pub type SourceIdentifier<'ast> = &'ast str; pub type SourceIdentifier<'ast> = std::borrow::Cow<'ast, str>;
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum CoreIdentifier<'ast> { pub enum CoreIdentifier<'ast> {
#[serde(borrow)]
Source(ShadowedIdentifier<'ast>), Source(ShadowedIdentifier<'ast>),
Call(usize), Call(usize),
Constant(CanonicalConstantIdentifier<'ast>), Constant(CanonicalConstantIdentifier<'ast>),
@ -29,16 +31,18 @@ impl<'ast> From<CanonicalConstantIdentifier<'ast>> for CoreIdentifier<'ast> {
} }
/// A identifier for a variable /// A identifier for a variable
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Identifier<'ast> { pub struct Identifier<'ast> {
/// the id of the variable /// the id of the variable
#[serde(borrow)]
pub id: CoreIdentifier<'ast>, pub id: CoreIdentifier<'ast>,
/// the version of the variable, used after SSA transformation /// the version of the variable, used after SSA transformation
pub version: usize, pub version: usize,
} }
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ShadowedIdentifier<'ast> { pub struct ShadowedIdentifier<'ast> {
#[serde(borrow)]
pub id: SourceIdentifier<'ast>, pub id: SourceIdentifier<'ast>,
pub shadow: usize, pub shadow: usize,
} }
@ -97,7 +101,7 @@ impl<'ast> Identifier<'ast> {
// these two From implementations are only used in tests but somehow cfg(test) doesn't work // these two From implementations are only used in tests but somehow cfg(test) doesn't work
impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { impl<'ast> From<&'ast str> for CoreIdentifier<'ast> {
fn from(s: &str) -> CoreIdentifier { fn from(s: &str) -> CoreIdentifier {
CoreIdentifier::Source(ShadowedIdentifier::shadow(s, 0)) CoreIdentifier::Source(ShadowedIdentifier::shadow(std::borrow::Cow::Borrowed(s), 0))
} }
} }

View file

@ -675,6 +675,28 @@ impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> {
} }
} }
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedAssemblyStatement<'ast, T> {
Assignment(TypedAssignee<'ast, T>, FieldElementExpression<'ast, T>),
Constraint(
FieldElementExpression<'ast, T>,
FieldElementExpression<'ast, T>,
),
}
impl<'ast, T: fmt::Display> fmt::Display for TypedAssemblyStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TypedAssemblyStatement::Assignment(ref lhs, ref rhs) => {
write!(f, "{} <-- {}", lhs, rhs)
}
TypedAssemblyStatement::Constraint(ref lhs, ref rhs) => {
write!(f, "{} === {}", lhs, rhs)
}
}
}
}
/// A statement in a `TypedFunction` /// A statement in a `TypedFunction`
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
@ -695,6 +717,7 @@ pub enum TypedStatement<'ast, T> {
ConcreteGenericsAssignment<'ast>, ConcreteGenericsAssignment<'ast>,
), ),
PopCallLog, PopCallLog,
Assembly(Vec<TypedAssemblyStatement<'ast, T>>),
} }
impl<'ast, T> TypedStatement<'ast, T> { impl<'ast, T> TypedStatement<'ast, T> {
@ -719,6 +742,14 @@ impl<'ast, T: fmt::Display> TypedStatement<'ast, T> {
} }
write!(f, "{}}}", "\t".repeat(depth)) write!(f, "{}}}", "\t".repeat(depth))
} }
TypedStatement::Assembly(statements) => {
write!(f, "{}", "\t".repeat(depth))?;
writeln!(f, "asm {{")?;
for s in statements {
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
}
write!(f, "{}}}", "\t".repeat(depth))
}
s => write!(f, "{}{}", "\t".repeat(depth), s), s => write!(f, "{}{}", "\t".repeat(depth), s),
} }
} }
@ -766,6 +797,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> {
generics, generics,
), ),
TypedStatement::PopCallLog => write!(f, "// POP CALL",), TypedStatement::PopCallLog => write!(f, "// POP CALL",),
TypedStatement::Assembly(ref statements) => {
writeln!(f, "asm {{")?;
for s in statements {
writeln!(f, "\t\t{}", s)?;
}
write!(f, "\t}}")
}
} }
} }
} }
@ -1173,6 +1211,73 @@ pub enum FieldElementExpression<'ast, T> {
Select(SelectExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>),
Element(ElementExpression<'ast, T, Self>), Element(ElementExpression<'ast, T, Self>),
} }
impl<'ast, T: Clone> From<TypedAssignee<'ast, T>> for TupleExpression<'ast, T> {
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
match assignee {
TypedAssignee::Identifier(v) => {
let inner = TupleExpressionInner::Identifier(v.id);
match v._type {
GType::Tuple(tuple_ty) => inner.annotate(tuple_ty),
_ => unreachable!(),
}
}
TypedAssignee::Select(box a, box index) => TupleExpression::select(a.into(), index),
TypedAssignee::Member(box a, id) => TupleExpression::member(a.into(), id),
TypedAssignee::Element(box a, index) => TupleExpression::element(a.into(), index),
}
}
}
impl<'ast, T: Clone> From<TypedAssignee<'ast, T>> for StructExpression<'ast, T> {
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
match assignee {
TypedAssignee::Identifier(v) => {
let inner = StructExpressionInner::Identifier(v.id);
match v._type {
GType::Struct(struct_ty) => inner.annotate(struct_ty),
_ => unreachable!(),
}
}
TypedAssignee::Select(box a, box index) => StructExpression::select(a.into(), index),
TypedAssignee::Member(box a, id) => StructExpression::member(a.into(), id),
TypedAssignee::Element(box a, index) => StructExpression::element(a.into(), index),
}
}
}
impl<'ast, T: Clone> From<TypedAssignee<'ast, T>> for ArrayExpression<'ast, T> {
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
match assignee {
TypedAssignee::Identifier(v) => {
let inner = ArrayExpressionInner::Identifier(v.id);
match v._type {
GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size),
_ => unreachable!(),
}
}
TypedAssignee::Select(box a, box index) => ArrayExpression::select(a.into(), index),
TypedAssignee::Member(box a, id) => ArrayExpression::member(a.into(), id),
TypedAssignee::Element(box a, index) => ArrayExpression::element(a.into(), index),
}
}
}
impl<'ast, T: Clone> From<TypedAssignee<'ast, T>> for FieldElementExpression<'ast, T> {
fn from(assignee: TypedAssignee<'ast, T>) -> Self {
match assignee {
TypedAssignee::Identifier(v) => FieldElementExpression::Identifier(v.id),
TypedAssignee::Element(box a, index) => {
FieldElementExpression::element(a.into(), index)
}
TypedAssignee::Member(box a, id) => FieldElementExpression::member(a.into(), id),
TypedAssignee::Select(box a, box index) => {
FieldElementExpression::select(a.into(), index)
}
}
}
}
impl<'ast, T> Add for FieldElementExpression<'ast, T> { impl<'ast, T> Add for FieldElementExpression<'ast, T> {
type Output = Self; type Output = Self;
@ -1209,6 +1314,9 @@ impl<'ast, T> FieldElementExpression<'ast, T> {
pub fn pow(self, other: UExpression<'ast, T>) -> Self { pub fn pow(self, other: UExpression<'ast, T>) -> Self {
FieldElementExpression::Pow(box self, box other) FieldElementExpression::Pow(box self, box other)
} }
pub fn is_quadratic(&self) -> bool {
true // TODO: implement
}
} }
impl<'ast, T> From<T> for FieldElementExpression<'ast, T> { impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {

View file

@ -378,6 +378,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_assignee(self, a) fold_assignee(self, a)
} }
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<TypedAssemblyStatement<'ast, T>, Self::Error> {
fold_assembly_statement(self, s)
}
fn fold_statement( fn fold_statement(
&mut self, &mut self,
s: TypedStatement<'ast, T>, s: TypedStatement<'ast, T>,
@ -508,6 +515,21 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
} }
} }
pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<TypedAssemblyStatement<'ast, T>, F::Error> {
Ok(match s {
TypedAssemblyStatement::Assignment(a, e) => {
TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, f.fold_field_expression(e)?)
}
TypedAssemblyStatement::Constraint(lhs, rhs) => TypedAssemblyStatement::Constraint(
f.fold_field_expression(lhs)?,
f.fold_field_expression(rhs)?,
),
})
}
pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F, f: &mut F,
s: TypedStatement<'ast, T>, s: TypedStatement<'ast, T>,
@ -538,6 +560,12 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
.map(|e| f.fold_expression(e)) .map(|e| f.fold_expression(e))
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
), ),
TypedStatement::Assembly(statements) => TypedStatement::Assembly(
statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?,
),
s => s, s => s,
}; };
Ok(vec![res]) Ok(vec![res])

View file

@ -51,7 +51,7 @@ pub struct GenericIdentifier<'ast> {
impl<'ast> From<GenericIdentifier<'ast>> for CoreIdentifier<'ast> { impl<'ast> From<GenericIdentifier<'ast>> for CoreIdentifier<'ast> {
fn from(g: GenericIdentifier<'ast>) -> CoreIdentifier<'ast> { fn from(g: GenericIdentifier<'ast>) -> CoreIdentifier<'ast> {
// generic identifiers are always declared in the function scope, which is shadow 0 // generic identifiers are always declared in the function scope, which is shadow 0
CoreIdentifier::Source(ShadowedIdentifier::shadow(g.name(), 0)) CoreIdentifier::Source(ShadowedIdentifier::shadow(std::borrow::Cow::Borrowed(g.name()), 0))
} }
} }
@ -119,9 +119,10 @@ pub struct SpecializationError;
pub type ConstantIdentifier<'ast> = &'ast str; pub type ConstantIdentifier<'ast> = &'ast str;
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] #[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct CanonicalConstantIdentifier<'ast> { pub struct CanonicalConstantIdentifier<'ast> {
pub module: OwnedTypedModuleId, pub module: OwnedTypedModuleId,
#[serde(borrow)]
pub id: ConstantIdentifier<'ast>, pub id: ConstantIdentifier<'ast>,
} }

View file

@ -280,6 +280,7 @@ impl<'ast> From<pest::Statement<'ast>> for untyped::StatementNode<'ast> {
pest::Statement::Assertion(s) => untyped::StatementNode::from(s), pest::Statement::Assertion(s) => untyped::StatementNode::from(s),
pest::Statement::Return(s) => untyped::StatementNode::from(s), pest::Statement::Return(s) => untyped::StatementNode::from(s),
pest::Statement::Log(s) => untyped::StatementNode::from(s), pest::Statement::Log(s) => untyped::StatementNode::from(s),
pest::Statement::Assembly(s) => untyped::StatementNode::from(s),
} }
} }
} }
@ -343,6 +344,32 @@ impl<'ast> From<pest::IterationStatement<'ast>> for untyped::StatementNode<'ast>
} }
} }
impl<'ast> From<pest::AssemblyStatement<'ast>> for untyped::StatementNode<'ast> {
fn from(statement: pest::AssemblyStatement<'ast>) -> untyped::StatementNode<'ast> {
use crate::untyped::NodeValue;
let statements = statement
.inner
.into_iter()
.map(|s| match s {
pest::AssemblyStatementInner::Assignment(a) => {
untyped::AssemblyStatement::Assignment(
a.assignee.into(),
a.expression.into(),
matches!(a.operator, pest::AssignmentOperator::AssignConstrain),
)
.span(a.span)
}
pest::AssemblyStatementInner::Constraint(c) => {
untyped::AssemblyStatement::Constraint(c.lhs.into(), c.rhs.into()).span(c.span)
}
})
.collect();
untyped::Statement::Assembly(statements).span(statement.span)
}
}
impl<'ast> From<pest::Expression<'ast>> for untyped::ExpressionNode<'ast> { impl<'ast> From<pest::Expression<'ast>> for untyped::ExpressionNode<'ast> {
fn from(expression: pest::Expression<'ast>) -> untyped::ExpressionNode<'ast> { fn from(expression: pest::Expression<'ast>) -> untyped::ExpressionNode<'ast> {
match expression { match expression {

View file

@ -382,6 +382,33 @@ impl<'ast> fmt::Display for Assignee<'ast> {
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub enum AssemblyStatement<'ast> {
Assignment(AssigneeNode<'ast>, ExpressionNode<'ast>, bool),
Constraint(ExpressionNode<'ast>, ExpressionNode<'ast>),
}
pub type AssemblyStatementNode<'ast> = Node<AssemblyStatement<'ast>>;
impl<'ast> fmt::Display for AssemblyStatement<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
AssemblyStatement::Assignment(ref lhs, ref rhs, ref constrained) => {
write!(
f,
"{} <{} {}",
lhs,
if *constrained { "==" } else { "--" },
rhs
)
}
AssemblyStatement::Constraint(ref lhs, ref rhs) => {
write!(f, "{} === {}", lhs, rhs)
}
}
}
}
/// A statement in a `Function` /// A statement in a `Function`
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -397,6 +424,7 @@ pub enum Statement<'ast> {
Vec<StatementNode<'ast>>, Vec<StatementNode<'ast>>,
), ),
Log(&'ast str, Vec<ExpressionNode<'ast>>), Log(&'ast str, Vec<ExpressionNode<'ast>>),
Assembly(Vec<AssemblyStatementNode<'ast>>),
} }
pub type StatementNode<'ast> = Node<Statement<'ast>>; pub type StatementNode<'ast> = Node<Statement<'ast>>;
@ -431,7 +459,7 @@ impl<'ast> fmt::Display for Statement<'ast> {
} }
Statement::Log(ref l, ref expressions) => write!( Statement::Log(ref l, ref expressions) => write!(
f, f,
"log({}, {})", "log({}, {});",
l, l,
expressions expressions
.iter() .iter()
@ -439,6 +467,13 @@ impl<'ast> fmt::Display for Statement<'ast> {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")
), ),
Statement::Assembly(ref statements) => {
writeln!(f, "asm {{")?;
for s in statements {
writeln!(f, "\t\t{};", s)?;
}
write!(f, "\t}}")
}
} }
} }
} }

View file

@ -84,6 +84,7 @@ use super::*;
impl<'ast> NodeValue for Expression<'ast> {} impl<'ast> NodeValue for Expression<'ast> {}
impl<'ast> NodeValue for Assignee<'ast> {} impl<'ast> NodeValue for Assignee<'ast> {}
impl<'ast> NodeValue for Statement<'ast> {} impl<'ast> NodeValue for Statement<'ast> {}
impl<'ast> NodeValue for AssemblyStatement<'ast> {}
impl<'ast> NodeValue for SymbolDeclaration<'ast> {} impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for UnresolvedType<'ast> {}
impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {}

View file

@ -56,6 +56,13 @@ pub trait Folder<'ast, T: Field>: Sized {
self.fold_variable(a) self.fold_variable(a)
} }
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> ZirAssemblyStatement<'ast, T> {
fold_assembly_statement(self, s)
}
fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> { fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> {
fold_statement(self, s) fold_statement(self, s)
} }
@ -127,6 +134,24 @@ pub trait Folder<'ast, T: Field>: Sized {
} }
} }
pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: ZirAssemblyStatement<'ast, T>,
) -> ZirAssemblyStatement<'ast, T> {
match s {
ZirAssemblyStatement::Assignment(assignees, function) => {
let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect();
let function = f.fold_function(function);
ZirAssemblyStatement::Assignment(assignees, function)
}
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = f.fold_field_expression(lhs);
let rhs = f.fold_field_expression(rhs);
ZirAssemblyStatement::Constraint(lhs, rhs)
}
}
}
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F, f: &mut F,
s: ZirStatement<'ast, T>, s: ZirStatement<'ast, T>,
@ -165,6 +190,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
.map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect()))
.collect(), .collect(),
), ),
ZirStatement::Assembly(statements) => ZirStatement::Assembly(
statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.collect(),
),
}; };
vec![res] vec![res]
} }

View file

@ -1,15 +1,18 @@
use crate::zir::types::MemberId; use crate::zir::types::MemberId;
use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
use crate::typed::Identifier as CoreIdentifier; use crate::typed::Identifier as CoreIdentifier;
#[derive(Debug, PartialEq, Clone, Hash, Eq)] #[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
pub enum Identifier<'ast> { pub enum Identifier<'ast> {
#[serde(borrow)]
Source(SourceIdentifier<'ast>), Source(SourceIdentifier<'ast>),
} }
#[derive(Debug, PartialEq, Clone, Hash, Eq)] #[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)]
pub enum SourceIdentifier<'ast> { pub enum SourceIdentifier<'ast> {
#[serde(borrow)]
Basic(CoreIdentifier<'ast>), Basic(CoreIdentifier<'ast>),
Select(Box<SourceIdentifier<'ast>>, u32), Select(Box<SourceIdentifier<'ast>>, u32),
Member(Box<SourceIdentifier<'ast>>, MemberId), Member(Box<SourceIdentifier<'ast>>, MemberId),

View file

@ -21,6 +21,7 @@ use zokrates_field::Field;
pub use self::folder::Folder; pub use self::folder::Folder;
pub use self::identifier::{Identifier, SourceIdentifier}; pub use self::identifier::{Identifier, SourceIdentifier};
use serde::{Deserialize, Serialize};
/// A typed program as a collection of modules, one of them being the main /// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Eq, Debug, Clone)] #[derive(PartialEq, Eq, Debug, Clone)]
@ -34,11 +35,13 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> {
} }
} }
/// A typed function /// A typed function
#[derive(Clone, PartialEq, Eq)] #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub struct ZirFunction<'ast, T> { pub struct ZirFunction<'ast, T> {
/// Arguments of the function /// Arguments of the function
#[serde(borrow)]
pub arguments: Vec<Parameter<'ast>>, pub arguments: Vec<Parameter<'ast>>,
/// Vector of statements that are executed when running the function /// Vector of statements that are executed when running the function
#[serde(borrow)]
pub statements: Vec<ZirStatement<'ast, T>>, pub statements: Vec<ZirStatement<'ast, T>>,
/// function signature /// function signature
pub signature: Signature, pub signature: Signature,
@ -88,7 +91,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirFunction<'ast, T> {
pub type ZirAssignee<'ast> = Variable<'ast>; pub type ZirAssignee<'ast> = Variable<'ast>;
#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum RuntimeError { pub enum RuntimeError {
SourceAssertion(String), SourceAssertion(String),
SelectRangeCheck, SelectRangeCheck,
@ -113,8 +116,70 @@ impl RuntimeError {
} }
} }
// #[derive(Clone, PartialEq, Hash, Eq, Debug)]
// pub struct ZirBlock<'ast, T> {
// pub statements: Vec<ZirStatement<'ast, T>>,
// pub value: FieldElementExpression<'ast, T>,
// }
//
// impl<'ast, T> ZirBlock<'ast, T> {
// pub fn new(
// statements: Vec<ZirStatement<'ast, T>>,
// value: FieldElementExpression<'ast, T>,
// ) -> Self {
// Self { statements, value }
// }
// }
// impl<'ast, T: fmt::Display> fmt::Display for ZirBlock<'ast, T> {
// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// write!(
// f,
// "{{\n{}\n}}",
// self.statements
// .iter()
// .map(|s| s.to_string())
// .chain(std::iter::once(self.value.to_string()))
// .collect::<Vec<_>>()
// .join("\n")
// )
// }
// }
#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
pub enum ZirAssemblyStatement<'ast, T> {
Assignment(
#[serde(borrow)] Vec<ZirAssignee<'ast>>,
ZirFunction<'ast, T>,
),
Constraint(
FieldElementExpression<'ast, T>,
FieldElementExpression<'ast, T>,
),
}
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => {
write!(
f,
"{} <-- {}",
lhs.iter()
.map(|a| a.id.to_string())
.collect::<Vec<String>>()
.join(", "),
rhs
)
}
ZirAssemblyStatement::Constraint(ref lhs, ref rhs) => {
write!(f, "{} === {}", lhs, rhs)
}
}
}
}
/// A statement in a `ZirFunction` /// A statement in a `ZirFunction`
#[derive(Clone, PartialEq, Hash, Eq, Debug)] #[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
pub enum ZirStatement<'ast, T> { pub enum ZirStatement<'ast, T> {
Return(Vec<ZirExpression<'ast, T>>), Return(Vec<ZirExpression<'ast, T>>),
Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>), Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>),
@ -129,6 +194,7 @@ pub enum ZirStatement<'ast, T> {
FormatString, FormatString,
Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>, Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>,
), ),
Assembly(#[serde(borrow)] Vec<ZirAssemblyStatement<'ast, T>>),
} }
impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {
@ -142,15 +208,19 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
write!(f, "{}", "\t".repeat(depth))?; write!(f, "{}", "\t".repeat(depth))?;
match self { match self {
ZirStatement::Return(ref exprs) => { ZirStatement::Return(ref exprs) => {
write!( write!(f, "return")?;
f, if exprs.len() > 0 {
"return {};", write!(
exprs f,
.iter() " {}",
.map(|e| e.to_string()) exprs
.collect::<Vec<_>>() .iter()
.join(", ") .map(|e| e.to_string())
) .collect::<Vec<_>>()
.join(", ")
)?;
}
write!(f, ";")
} }
ZirStatement::Definition(ref lhs, ref rhs) => { ZirStatement::Definition(ref lhs, ref rhs) => {
write!(f, "{} = {};", lhs, rhs) write!(f, "{} = {};", lhs, rhs)
@ -166,7 +236,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
s.fmt_indented(f, depth + 1)?; s.fmt_indented(f, depth + 1)?;
writeln!(f)?; writeln!(f)?;
} }
write!(f, "{}}};", "\t".repeat(depth)) write!(f, "{}}}", "\t".repeat(depth))
} }
ZirStatement::Assertion(ref e, ref error) => { ZirStatement::Assertion(ref e, ref error) => {
write!(f, "assert({}", e)?; write!(f, "assert({}", e)?;
@ -200,6 +270,13 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")
), ),
ZirStatement::Assembly(statements) => {
writeln!(f, "asm {{")?;
for s in statements {
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
}
write!(f, "{}}}", "\t".repeat(depth))
}
} }
} }
} }
@ -208,8 +285,9 @@ pub trait Typed {
fn get_type(&self) -> Type; fn get_type(&self) -> Type;
} }
#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub struct ConditionalExpression<'ast, T, E> { pub struct ConditionalExpression<'ast, T, E> {
#[serde(borrow)]
pub condition: Box<BooleanExpression<'ast, T>>, pub condition: Box<BooleanExpression<'ast, T>>,
pub consequence: Box<E>, pub consequence: Box<E>,
pub alternative: Box<E>, pub alternative: Box<E>,
@ -235,9 +313,10 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress
} }
} }
#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub struct SelectExpression<'ast, T, E> { pub struct SelectExpression<'ast, T, E> {
pub array: Vec<E>, pub array: Vec<E>,
#[serde(borrow)]
pub index: Box<UExpression<'ast, T>>, pub index: Box<UExpression<'ast, T>>,
} }
@ -266,11 +345,11 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<'
} }
/// A typed expression /// A typed expression
#[derive(Clone, PartialEq, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum ZirExpression<'ast, T> { pub enum ZirExpression<'ast, T> {
Boolean(BooleanExpression<'ast, T>), Boolean(BooleanExpression<'ast, T>),
FieldElement(FieldElementExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>),
Uint(UExpression<'ast, T>), Uint(#[serde(borrow)] UExpression<'ast, T>),
} }
impl<'ast, T: Field> From<BooleanExpression<'ast, T>> for ZirExpression<'ast, T> { impl<'ast, T: Field> From<BooleanExpression<'ast, T>> for ZirExpression<'ast, T> {
@ -343,15 +422,20 @@ pub trait MultiTyped {
fn get_types(&self) -> &Vec<Type>; fn get_types(&self) -> &Vec<Type>;
} }
#[derive(Clone, PartialEq, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum ZirExpressionList<'ast, T> { pub enum ZirExpressionList<'ast, T> {
EmbedCall(FlatEmbed, Vec<u32>, Vec<ZirExpression<'ast, T>>), EmbedCall(
FlatEmbed,
Vec<u32>,
#[serde(borrow)] Vec<ZirExpression<'ast, T>>,
),
} }
/// An expression of type `field` /// An expression of type `field`
#[derive(Clone, PartialEq, Hash, Eq, Debug)] #[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
pub enum FieldElementExpression<'ast, T> { pub enum FieldElementExpression<'ast, T> {
Number(T), Number(T),
#[serde(borrow)]
Identifier(Identifier<'ast>), Identifier(Identifier<'ast>),
Select(SelectExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>),
Add( Add(
@ -372,15 +456,16 @@ pub enum FieldElementExpression<'ast, T> {
), ),
Pow( Pow(
Box<FieldElementExpression<'ast, T>>, Box<FieldElementExpression<'ast, T>>,
Box<UExpression<'ast, T>>, #[serde(borrow)] Box<UExpression<'ast, T>>,
), ),
Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>),
} }
/// An expression of type `bool` /// An expression of type `bool`
#[derive(Clone, PartialEq, Hash, Eq, Debug)] #[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
pub enum BooleanExpression<'ast, T> { pub enum BooleanExpression<'ast, T> {
Value(bool), Value(bool),
#[serde(borrow)]
Identifier(Identifier<'ast>), Identifier(Identifier<'ast>),
Select(SelectExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>),
FieldLt( FieldLt(

View file

@ -1,8 +1,10 @@
use crate::zir::Variable; use crate::zir::Variable;
use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
#[derive(Clone, PartialEq, Eq)] #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub struct Parameter<'ast> { pub struct Parameter<'ast> {
#[serde(borrow)]
pub id: Variable<'ast>, pub id: Variable<'ast>,
pub private: bool, pub private: bool,
} }

View file

@ -61,6 +61,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
self.fold_variable(a) self.fold_variable(a)
} }
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
fold_assembly_statement(self, s)
}
fn fold_statement( fn fold_statement(
&mut self, &mut self,
s: ZirStatement<'ast, T>, s: ZirStatement<'ast, T>,
@ -144,6 +151,26 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_uint_expression_inner(self, bitwidth, e) fold_uint_expression_inner(self, bitwidth, e)
} }
} }
pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, F::Error> {
Ok(match s {
ZirAssemblyStatement::Assignment(assignees, function) => {
let assignees = assignees
.into_iter()
.map(|a| f.fold_assignee(a))
.collect::<Result<Vec<_>, _>>()?;
let function = f.fold_function(function)?;
ZirAssemblyStatement::Assignment(assignees, function)
}
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = f.fold_field_expression(lhs)?;
let rhs = f.fold_field_expression(rhs)?;
ZirAssemblyStatement::Constraint(lhs, rhs)
}
})
}
pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F, f: &mut F,
@ -199,6 +226,13 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
ZirStatement::Log(l, e) ZirStatement::Log(l, e)
} }
ZirStatement::Assembly(statements) => {
let statements = statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?;
ZirStatement::Assembly(statements)
}
}; };
Ok(vec![res]) Ok(vec![res])
} }

View file

@ -1,5 +1,6 @@
use crate::zir::identifier::Identifier; use crate::zir::identifier::Identifier;
use crate::zir::types::UBitwidth; use crate::zir::types::UBitwidth;
use serde::{Deserialize, Serialize};
use zokrates_field::Field; use zokrates_field::Field;
use super::{ConditionalExpression, SelectExpression}; use super::{ConditionalExpression, SelectExpression};
@ -91,7 +92,7 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
} }
} }
#[derive(Debug, PartialEq, Eq, Clone, Hash)] #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
pub enum ShouldReduce { pub enum ShouldReduce {
Unknown, Unknown,
True, True,
@ -135,7 +136,7 @@ impl ShouldReduce {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UMetadata<T> { pub struct UMetadata<T> {
pub max: T, pub max: T,
pub should_reduce: ShouldReduce, pub should_reduce: ShouldReduce,
@ -162,16 +163,18 @@ impl<T: Field> UMetadata<T> {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UExpression<'ast, T> { pub struct UExpression<'ast, T> {
pub bitwidth: UBitwidth, pub bitwidth: UBitwidth,
pub metadata: Option<UMetadata<T>>, pub metadata: Option<UMetadata<T>>,
#[serde(borrow)]
pub inner: UExpressionInner<'ast, T>, pub inner: UExpressionInner<'ast, T>,
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UExpressionInner<'ast, T> { pub enum UExpressionInner<'ast, T> {
Value(u128), Value(u128),
#[serde(borrow)]
Identifier(Identifier<'ast>), Identifier(Identifier<'ast>),
Select(SelectExpression<'ast, T, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>), Add(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),

View file

@ -1,9 +1,11 @@
use crate::zir::types::{Type, UBitwidth}; use crate::zir::types::{Type, UBitwidth};
use crate::zir::Identifier; use crate::zir::Identifier;
use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
#[derive(Clone, PartialEq, Hash, Eq)] #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub struct Variable<'ast> { pub struct Variable<'ast> {
#[serde(borrow)]
pub id: Identifier<'ast>, pub id: Identifier<'ast>,
pub _type: Type, pub _type: Type,
} }

View file

@ -21,8 +21,8 @@ use zokrates_proof_systems::Scheme;
const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications."; const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications.";
impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman { impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
fn generate_proof<I: IntoIterator<Item = Statement<T>>>( fn generate_proof<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
witness: Witness<T>, witness: Witness<T>,
proving_key: Vec<u8>, proving_key: Vec<u8>,
) -> Proof<T, G16> { ) -> Proof<T, G16> {
@ -84,8 +84,8 @@ impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
} }
impl<T: Field + BellmanFieldExtensions> NonUniversalBackend<T, G16> for Bellman { impl<T: Field + BellmanFieldExtensions> NonUniversalBackend<T, G16> for Bellman {
fn setup<I: IntoIterator<Item = Statement<T>>>( fn setup<'a, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
) -> SetupKeypair<T, G16> { ) -> SetupKeypair<T, G16> {
println!("{}", G16_WARNING); println!("{}", G16_WARNING);
@ -99,8 +99,8 @@ impl<T: Field + BellmanFieldExtensions> NonUniversalBackend<T, G16> for Bellman
} }
impl<T: Field + BellmanFieldExtensions> MpcBackend<T, G16> for Bellman { impl<T: Field + BellmanFieldExtensions> MpcBackend<T, G16> for Bellman {
fn initialize<R: Read, W: Write, I: IntoIterator<Item = Statement<T>>>( fn initialize<'a, R: Read, W: Write, I: IntoIterator<Item = Statement<'a, T>>>(
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
phase1_radix: &mut R, phase1_radix: &mut R,
output: &mut W, output: &mut W,
) -> Result<(), String> { ) -> Result<(), String> {
@ -124,9 +124,9 @@ impl<T: Field + BellmanFieldExtensions> MpcBackend<T, G16> for Bellman {
Ok(hash) Ok(hash)
} }
fn verify<P: Read, R: Read, I: IntoIterator<Item = Statement<T>>>( fn verify<'a, P: Read, R: Read, I: IntoIterator<Item = Statement<'a, T>>>(
params: &mut P, params: &mut P,
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
phase1_radix: &mut R, phase1_radix: &mut R,
) -> Result<Vec<[u8; 64]>, String> { ) -> Result<Vec<[u8; 64]>, String> {
let params = let params =

View file

@ -22,20 +22,20 @@ pub use self::parse::*;
pub struct Bellman; pub struct Bellman;
#[derive(Clone)] #[derive(Clone)]
pub struct Computation<T, I: IntoIterator<Item = Statement<T>>> { pub struct Computation<'a, T, I: IntoIterator<Item = Statement<'a, T>>> {
program: ProgIterator<T, I>, program: ProgIterator<'a, T, I>,
witness: Option<Witness<T>>, witness: Option<Witness<T>>,
} }
impl<T: Field, I: IntoIterator<Item = Statement<T>>> Computation<T, I> { impl<'a, T: Field, I: IntoIterator<Item = Statement<'a, T>>> Computation<'a, T, I> {
pub fn with_witness(program: ProgIterator<T, I>, witness: Witness<T>) -> Self { pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness<T>) -> Self {
Computation { Computation {
program, program,
witness: Some(witness), witness: Some(witness),
} }
} }
pub fn without_witness(program: ProgIterator<T, I>) -> Self { pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self {
Computation { Computation {
program, program,
witness: None, witness: None,
@ -83,8 +83,8 @@ fn bellman_combination<T: BellmanFieldExtensions, CS: ConstraintSystem<T::Bellma
.fold(LinearCombination::zero(), |acc, e| acc + e) .fold(LinearCombination::zero(), |acc, e| acc + e)
} }
impl<T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<T>>> impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<'a, T>>>
Circuit<T::BellmanEngine> for Computation<T, I> Circuit<T::BellmanEngine> for Computation<'a, T, I>
{ {
fn synthesize<CS: ConstraintSystem<T::BellmanEngine>>( fn synthesize<CS: ConstraintSystem<T::BellmanEngine>>(
self, self,
@ -148,7 +148,9 @@ impl<T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<T>>>
} }
} }
impl<T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<T>>> Computation<T, I> { impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator<Item = Statement<'a, T>>>
Computation<'a, T, I>
{
fn get_random_seed(&self) -> Result<[u32; 8], getrandom::Error> { fn get_random_seed(&self) -> Result<[u32; 8], getrandom::Error> {
let mut seed = [0u8; 32]; let mut seed = [0u8; 32];
getrandom::getrandom(&mut seed)?; getrandom::getrandom(&mut seed)?;

View file

@ -85,8 +85,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
} }
fn cli_compute<T: Field, I: Iterator<Item = ir::Statement<T>>>( fn cli_compute<'a, T: Field, I: Iterator<Item = ir::Statement<'a, T>>>(
ir_prog: ir::ProgIterator<T, I>, ir_prog: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
println!("Computing witness..."); println!("Computing witness...");

View file

@ -136,12 +136,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
fn cli_generate_proof< fn cli_generate_proof<
'a,
T: Field, T: Field,
I: Iterator<Item = ir::Statement<T>>, I: Iterator<Item = ir::Statement<'a, T>>,
S: Scheme<T>, S: Scheme<T>,
B: Backend<T, S>, B: Backend<T, S>,
>( >(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
println!("Generating proof..."); println!("Generating proof...");

View file

@ -47,8 +47,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
} }
fn cli_smtlib2<T: Field, I: Iterator<Item = ir::Statement<T>>>( fn cli_smtlib2<'a, T: Field, I: Iterator<Item = ir::Statement<'a, T>>>(
ir_prog: ir::ProgIterator<T, I>, ir_prog: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
println!("Generating SMTLib2..."); println!("Generating SMTLib2...");

View file

@ -43,8 +43,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
} }
fn cli_inspect<T: Field, I: Iterator<Item = ir::Statement<T>>>( fn cli_inspect<'a, T: Field, I: Iterator<Item = ir::Statement<'a, T>>>(
ir_prog: ir::ProgIterator<T, I>, ir_prog: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
let ir_prog: ir::Prog<T> = ir_prog.collect(); let ir_prog: ir::Prog<T> = ir_prog.collect();

View file

@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
fn cli_mpc_init< fn cli_mpc_init<
'a,
T: Field + BellmanFieldExtensions, T: Field + BellmanFieldExtensions,
I: Iterator<Item = ir::Statement<T>>, I: Iterator<Item = ir::Statement<'a, T>>,
S: MpcScheme<T>, S: MpcScheme<T>,
B: MpcBackend<T, S>, B: MpcBackend<T, S>,
>( >(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
println!("Initializing MPC..."); println!("Initializing MPC...");

View file

@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
fn cli_mpc_verify< fn cli_mpc_verify<
'a,
T: Field + BellmanFieldExtensions, T: Field + BellmanFieldExtensions,
I: Iterator<Item = ir::Statement<T>>, I: Iterator<Item = ir::Statement<'a, T>>,
S: MpcScheme<T>, S: MpcScheme<T>,
B: MpcBackend<T, S>, B: MpcBackend<T, S>,
>( >(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
println!("Verifying contributions..."); println!("Verifying contributions...");

View file

@ -167,12 +167,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
} }
fn cli_setup_non_universal< fn cli_setup_non_universal<
'a,
T: Field, T: Field,
I: Iterator<Item = ir::Statement<T>>, I: Iterator<Item = ir::Statement<'a, T>>,
S: NonUniversalScheme<T>, S: NonUniversalScheme<T>,
B: NonUniversalBackend<T, S>, B: NonUniversalBackend<T, S>,
>( >(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {
println!("Performing setup..."); println!("Performing setup...");
@ -211,12 +212,13 @@ fn cli_setup_non_universal<
} }
fn cli_setup_universal< fn cli_setup_universal<
'a,
T: Field, T: Field,
I: Iterator<Item = ir::Statement<T>>, I: Iterator<Item = ir::Statement<'a, T>>,
S: UniversalScheme<T>, S: UniversalScheme<T>,
B: UniversalBackend<T, S>, B: UniversalBackend<T, S>,
>( >(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
srs: Vec<u8>, srs: Vec<u8>,
sub_matches: &ArgMatches, sub_matches: &ArgMatches,
) -> Result<(), String> { ) -> Result<(), String> {

View file

@ -25,13 +25,13 @@ use zokrates_field::Field;
use zokrates_pest_ast as pest; use zokrates_pest_ast as pest;
#[derive(Debug)] #[derive(Debug)]
pub struct CompilationArtifacts<T, I: IntoIterator<Item = ir::Statement<T>>> { pub struct CompilationArtifacts<'ast, T, I: IntoIterator<Item = ir::Statement<'ast, T>>> {
prog: ir::ProgIterator<T, I>, prog: ir::ProgIterator<'ast, T, I>,
abi: Abi, abi: Abi,
} }
impl<T, I: IntoIterator<Item = ir::Statement<T>>> CompilationArtifacts<T, I> { impl<'ast, T, I: IntoIterator<Item = ir::Statement<'ast, T>>> CompilationArtifacts<'ast, T, I> {
pub fn prog(self) -> ir::ProgIterator<T, I> { pub fn prog(self) -> ir::ProgIterator<'ast, T, I> {
self.prog self.prog
} }
@ -39,11 +39,11 @@ impl<T, I: IntoIterator<Item = ir::Statement<T>>> CompilationArtifacts<T, I> {
&self.abi &self.abi
} }
pub fn into_inner(self) -> (ir::ProgIterator<T, I>, Abi) { pub fn into_inner(self) -> (ir::ProgIterator<'ast, T, I>, Abi) {
(self.prog, self.abi) (self.prog, self.abi)
} }
pub fn collect(self) -> CompilationArtifacts<T, Vec<ir::Statement<T>>> { pub fn collect(self) -> CompilationArtifacts<'ast, T, Vec<ir::Statement<'ast, T>>> {
CompilationArtifacts { CompilationArtifacts {
prog: self.prog.collect(), prog: self.prog.collect(),
abi: self.abi, abi: self.abi,
@ -201,8 +201,10 @@ pub fn compile<'ast, T: Field, E: Into<imports::Error>>(
resolver: Option<&dyn Resolver<E>>, resolver: Option<&dyn Resolver<E>>,
config: CompileConfig, config: CompileConfig,
arena: &'ast Arena<String>, arena: &'ast Arena<String>,
) -> Result<CompilationArtifacts<T, impl IntoIterator<Item = ir::Statement<T>> + 'ast>, CompileErrors> ) -> Result<
{ CompilationArtifacts<'ast, T, impl IntoIterator<Item = ir::Statement<'ast, T>> + 'ast>,
CompileErrors,
> {
let (typed_ast, abi): (zokrates_ast::zir::ZirProgram<'_, T>, _) = let (typed_ast, abi): (zokrates_ast::zir::ZirProgram<'_, T>, _) =
check_with_arena(source, location, resolver, &config, arena)?; check_with_arena(source, location, resolver, &config, arena)?;

View file

@ -9,7 +9,8 @@ mod utils;
use self::utils::flat_expression_from_bits; use self::utils::flat_expression_from_bits;
use zokrates_ast::zir::{ use zokrates_ast::zir::{
ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirExpressionList, ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement,
ZirExpressionList,
}; };
use zokrates_interpreter::Interpreter; use zokrates_interpreter::Interpreter;
@ -32,7 +33,7 @@ use zokrates_ast::zir::{
}; };
use zokrates_field::Field; use zokrates_field::Field;
type FlatStatements<T> = VecDeque<FlatStatement<T>>; type FlatStatements<'ast, T> = VecDeque<FlatStatement<'ast, T>>;
/// Flattens a function /// Flattens a function
/// ///
@ -64,14 +65,14 @@ pub fn from_function_and_config<T: Field>(
pub struct FlattenerIteratorInner<'ast, T> { pub struct FlattenerIteratorInner<'ast, T> {
pub statements: VecDeque<ZirStatement<'ast, T>>, pub statements: VecDeque<ZirStatement<'ast, T>>,
pub statements_flattened: FlatStatements<T>, pub statements_flattened: FlatStatements<'ast, T>,
pub flattener: Flattener<'ast, T>, pub flattener: Flattener<'ast, T>,
} }
pub type FlattenerIterator<'ast, T> = FlatProgIterator<T, FlattenerIteratorInner<'ast, T>>; pub type FlattenerIterator<'ast, T> = FlatProgIterator<'ast, T, FlattenerIteratorInner<'ast, T>>;
impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> { impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> {
type Item = FlatStatement<T>; type Item = FlatStatement<'ast, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
while self.statements_flattened.is_empty() { while self.statements_flattened.is_empty() {
@ -127,7 +128,7 @@ trait Flatten<'ast, T: Field>:
fn flatten( fn flatten(
self, self,
flattener: &mut Flattener<'ast, T>, flattener: &mut Flattener<'ast, T>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> Self::Output; ) -> Self::Output;
} }
@ -137,7 +138,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> {
fn flatten( fn flatten(
self, self,
flattener: &mut Flattener<'ast, T>, flattener: &mut Flattener<'ast, T>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> Self::Output { ) -> Self::Output {
flattener.flatten_field_expression(statements_flattened, self) flattener.flatten_field_expression(statements_flattened, self)
} }
@ -149,7 +150,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast, T> {
fn flatten( fn flatten(
self, self,
flattener: &mut Flattener<'ast, T>, flattener: &mut Flattener<'ast, T>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> Self::Output { ) -> Self::Output {
flattener.flatten_uint_expression(statements_flattened, self) flattener.flatten_uint_expression(statements_flattened, self)
} }
@ -161,7 +162,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> {
fn flatten( fn flatten(
self, self,
flattener: &mut Flattener<'ast, T>, flattener: &mut Flattener<'ast, T>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> Self::Output { ) -> Self::Output {
flattener.flatten_boolean_expression(statements_flattened, self) flattener.flatten_boolean_expression(statements_flattened, self)
} }
@ -227,7 +228,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn define( fn define(
&mut self, &mut self,
e: FlatExpression<T>, e: FlatExpression<T>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> Variable { ) -> Variable {
match e { match e {
FlatExpression::Identifier(id) => id, FlatExpression::Identifier(id) => id,
@ -276,7 +277,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
#[must_use] #[must_use]
fn constant_le_check( fn constant_le_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
a: &[FlatExpression<T>], a: &[FlatExpression<T>],
b: &[bool], b: &[bool],
) -> Vec<FlatExpression<T>> { ) -> Vec<FlatExpression<T>> {
@ -381,7 +382,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise
fn eq_check( fn eq_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
left: FlatExpression<T>, left: FlatExpression<T>,
right: FlatExpression<T>, right: FlatExpression<T>,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
@ -434,7 +435,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `b` - the big-endian bit decomposition of the upper bound of the range /// * `b` - the big-endian bit decomposition of the upper bound of the range
fn enforce_constant_le_check_bits( fn enforce_constant_le_check_bits(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
a: &[FlatExpression<T>], a: &[FlatExpression<T>],
c: &[bool], c: &[bool],
error: RuntimeError, error: RuntimeError,
@ -464,7 +465,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `c` - the constant upper bound of the range /// * `c` - the constant upper bound of the range
fn enforce_constant_le_check( fn enforce_constant_le_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
e: FlatExpression<T>, e: FlatExpression<T>,
c: T, c: T,
error: RuntimeError, error: RuntimeError,
@ -500,7 +501,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `c` - the constant upper bound of the range /// * `c` - the constant upper bound of the range
fn enforce_constant_lt_check( fn enforce_constant_lt_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
e: FlatExpression<T>, e: FlatExpression<T>,
c: T, c: T,
error: RuntimeError, error: RuntimeError,
@ -519,9 +520,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn make_conditional( fn make_conditional(
&mut self, &mut self,
statements: FlatStatements<T>, statements: FlatStatements<'ast, T>,
condition: FlatExpression<T>, condition: FlatExpression<T>,
) -> FlatStatements<T> { ) -> FlatStatements<'ast, T> {
statements statements
.into_iter() .into_iter()
.flat_map(|s| match s { .flat_map(|s| match s {
@ -582,7 +583,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * U is the type of the expression /// * U is the type of the expression
fn flatten_conditional_expression<U: Flatten<'ast, T>>( fn flatten_conditional_expression<U: Flatten<'ast, T>>(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
e: ConditionalExpression<'ast, T, U>, e: ConditionalExpression<'ast, T, U>,
) -> FlatUExpression<T> { ) -> FlatUExpression<T> {
let condition = *e.condition; let condition = *e.condition;
@ -680,7 +681,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise /// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise
fn constant_lt_check( fn constant_lt_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
e: FlatExpression<T>, e: FlatExpression<T>,
c: T, c: T,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
@ -704,7 +705,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise /// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise
fn constant_field_le_check( fn constant_field_le_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
e: FlatExpression<T>, e: FlatExpression<T>,
c: T, c: T,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
@ -745,7 +746,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
#[must_use] #[must_use]
fn le_check( fn le_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
lhs_flattened: FlatExpression<T>, lhs_flattened: FlatExpression<T>,
rhs_flattened: FlatExpression<T>, rhs_flattened: FlatExpression<T>,
bit_width: usize, bit_width: usize,
@ -768,7 +769,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
#[must_use] #[must_use]
fn lt_check( fn lt_check(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
lhs_flattened: FlatExpression<T>, lhs_flattened: FlatExpression<T>,
rhs_flattened: FlatExpression<T>, rhs_flattened: FlatExpression<T>,
bit_width: usize, bit_width: usize,
@ -827,7 +828,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * in order to preserve composability. /// * in order to preserve composability.
fn flatten_boolean_expression( fn flatten_boolean_expression(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
expression: BooleanExpression<'ast, T>, expression: BooleanExpression<'ast, T>,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
match expression { match expression {
@ -1033,7 +1034,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `param_expressions` - Arguments of this call /// * `param_expressions` - Arguments of this call
fn flatten_embed_call( fn flatten_embed_call(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
embed: FlatEmbed, embed: FlatEmbed,
generics: Vec<u32>, generics: Vec<u32>,
param_expressions: Vec<ZirExpression<'ast, T>>, param_expressions: Vec<ZirExpression<'ast, T>>,
@ -1134,9 +1135,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn flatten_embed_call_aux( fn flatten_embed_call_aux(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
params: Vec<FlatUExpression<T>>, params: Vec<FlatUExpression<T>>,
funct: FlatFunctionIterator<T, impl IntoIterator<Item = FlatStatement<T>>>, funct: FlatFunctionIterator<'ast, T, impl IntoIterator<Item = FlatStatement<'ast, T>>>,
) -> Vec<FlatUExpression<T>> { ) -> Vec<FlatUExpression<T>> {
let mut replacement_map = HashMap::new(); let mut replacement_map = HashMap::new();
@ -1219,7 +1220,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `expr` - `ZirExpression` that will be flattened. /// * `expr` - `ZirExpression` that will be flattened.
fn flatten_expression( fn flatten_expression(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
expr: ZirExpression<'ast, T>, expr: ZirExpression<'ast, T>,
) -> FlatUExpression<T> { ) -> FlatUExpression<T> {
match expr { match expr {
@ -1235,7 +1236,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn default_xor( fn default_xor(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
left: UExpression<'ast, T>, left: UExpression<'ast, T>,
right: UExpression<'ast, T>, right: UExpression<'ast, T>,
) -> FlatUExpression<T> { ) -> FlatUExpression<T> {
@ -1296,7 +1297,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn euclidean_division( fn euclidean_division(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
target_bitwidth: UBitwidth, target_bitwidth: UBitwidth,
left: UExpression<'ast, T>, left: UExpression<'ast, T>,
right: UExpression<'ast, T>, right: UExpression<'ast, T>,
@ -1382,7 +1383,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `expr` - `UExpression` that will be flattened. /// * `expr` - `UExpression` that will be flattened.
fn flatten_uint_expression( fn flatten_uint_expression(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
expr: UExpression<'ast, T>, expr: UExpression<'ast, T>,
) -> FlatUExpression<T> { ) -> FlatUExpression<T> {
// the bitwidth for this type of uint (8, 16 or 32) // the bitwidth for this type of uint (8, 16 or 32)
@ -1875,7 +1876,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
e: &FlatUExpression<T>, e: &FlatUExpression<T>,
from: usize, from: usize,
to: usize, to: usize,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
error: RuntimeError, error: RuntimeError,
) -> Vec<FlatExpression<T>> { ) -> Vec<FlatExpression<T>> {
assert!(from <= T::get_required_bits()); assert!(from <= T::get_required_bits());
@ -1969,7 +1970,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn flatten_select_expression<U: Flatten<'ast, T>>( fn flatten_select_expression<U: Flatten<'ast, T>>(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
e: SelectExpression<'ast, T, U>, e: SelectExpression<'ast, T, U>,
) -> FlatUExpression<T> { ) -> FlatUExpression<T> {
let array = e.array; let array = e.array;
@ -2033,7 +2034,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `expr` - `FieldElementExpression` that will be flattened. /// * `expr` - `FieldElementExpression` that will be flattened.
fn flatten_field_expression( fn flatten_field_expression(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
expr: FieldElementExpression<'ast, T>, expr: FieldElementExpression<'ast, T>,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
match expr { match expr {
@ -2221,6 +2222,35 @@ impl<'ast, T: Field> Flattener<'ast, T> {
} }
} }
fn flatten_assembly_statement(
&mut self,
statements_flattened: &mut FlatStatements<'ast, T>,
stat: ZirAssemblyStatement<'ast, T>,
) {
match stat {
ZirAssemblyStatement::Assignment(assignees, function) => {
let outputs: Vec<Variable> = assignees
.iter()
.map(|a| self.use_variable(a)) /*self.layout.get(&a.id).cloned().unwrap()*/
.collect();
let inputs: Vec<FlatExpression<T>> = function
.arguments
.iter()
.cloned()
.map(|p| self.layout.get(&p.id.id).cloned().unwrap().into())
.collect();
let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs);
statements_flattened.push_back(FlatStatement::Directive(directive));
}
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = self.flatten_field_expression(statements_flattened, lhs);
let rhs = self.flatten_field_expression(statements_flattened, rhs);
self.flatten_equality_assertion(statements_flattened, lhs, rhs, RuntimeError::UnsatisfiedConstraint)
}
}
}
/// Flattens a statement /// Flattens a statement
/// ///
/// # Arguments /// # Arguments
@ -2229,10 +2259,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// * `stat` - `ZirStatement` that will be flattened. /// * `stat` - `ZirStatement` that will be flattened.
fn flatten_statement( fn flatten_statement(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
stat: ZirStatement<'ast, T>, stat: ZirStatement<'ast, T>,
) { ) {
match stat { match stat {
ZirStatement::Assembly(statements) => {
for s in statements {
self.flatten_assembly_statement(statements_flattened, s);
}
}
ZirStatement::Return(exprs) => { ZirStatement::Return(exprs) => {
#[allow(clippy::needless_collect)] #[allow(clippy::needless_collect)]
// clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator, // clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator,
@ -2633,12 +2668,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `statements_flattened` - `FlatStatements<T>` Vector where new flattened statements can be added. /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added.
/// * `lhs` - `FlatExpression<T>` Left-hand side of the equality expression. /// * `lhs` - `FlatExpression<T>` Left-hand side of the equality expression.
/// * `rhs` - `FlatExpression<T>` Right-hand side of the equality expression. /// * `rhs` - `FlatExpression<T>` Right-hand side of the equality expression.
fn flatten_equality_assertion( fn flatten_equality_assertion(
&mut self, &mut self,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
lhs: FlatExpression<T>, lhs: FlatExpression<T>,
rhs: FlatExpression<T>, rhs: FlatExpression<T>,
error: RuntimeError, error: RuntimeError,
@ -2667,11 +2702,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
/// # Arguments /// # Arguments
/// ///
/// * `e` - `FlatExpression<T>` Expression to be assigned to an identifier. /// * `e` - `FlatExpression<T>` Expression to be assigned to an identifier.
/// * `statements_flattened` - `FlatStatements<T>` Vector where new flattened statements can be added. /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added.
fn identify_expression( fn identify_expression(
&mut self, &mut self,
e: FlatExpression<T>, e: FlatExpression<T>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> FlatExpression<T> { ) -> FlatExpression<T> {
match e.is_linear() { match e.is_linear() {
true => e, true => e,
@ -2710,7 +2745,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn use_parameter( fn use_parameter(
&mut self, &mut self,
parameter: &ZirParameter<'ast>, parameter: &ZirParameter<'ast>,
statements_flattened: &mut FlatStatements<T>, statements_flattened: &mut FlatStatements<'ast, T>,
) -> Parameter { ) -> Parameter {
let variable = self.use_variable(&parameter.id); let variable = self.use_variable(&parameter.id);

View file

@ -4,7 +4,7 @@ use zokrates_field::Field;
#[derive(Default)] #[derive(Default)]
pub struct Canonicalizer; pub struct Canonicalizer;
impl<T: Field> Folder<T> for Canonicalizer { impl<'ast, T: Field> Folder<'ast, T> for Canonicalizer {
fn fold_linear_combination(&mut self, l: LinComb<T>) -> LinComb<T> { fn fold_linear_combination(&mut self, l: LinComb<T>) -> LinComb<T> {
l.into_canonical().into() l.into_canonical().into()
} }

View file

@ -15,18 +15,18 @@ use zokrates_ast::ir::*;
use zokrates_field::Field; use zokrates_field::Field;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct DirectiveOptimizer<T> { pub struct DirectiveOptimizer<'ast, T> {
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<Variable>>, calls: HashMap<(Solver<'ast, T>, Vec<QuadComb<T>>), Vec<Variable>>,
/// Map of renamings for reassigned variables while processing the program. /// Map of renamings for reassigned variables while processing the program.
substitution: HashMap<Variable, Variable>, substitution: HashMap<Variable, Variable>,
} }
impl<T: Field> Folder<T> for DirectiveOptimizer<T> { impl<'ast, T: Field> Folder<'ast, T> for DirectiveOptimizer<'ast, T> {
fn fold_variable(&mut self, v: Variable) -> Variable { fn fold_variable(&mut self, v: Variable) -> Variable {
*self.substitution.get(&v).unwrap_or(&v) *self.substitution.get(&v).unwrap_or(&v)
} }
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
match s { match s {
Statement::Directive(d) => { Statement::Directive(d) => {
let d = self.fold_directive(d); let d = self.fold_directive(d);

View file

@ -21,8 +21,8 @@ pub struct DuplicateOptimizer {
seen: HashSet<Hash>, seen: HashSet<Hash>,
} }
impl<T: Field> Folder<T> for DuplicateOptimizer { impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer {
fn fold_program(&mut self, p: Prog<T>) -> Prog<T> { fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> {
// in order to correctly identify duplicates, we need to first canonicalize the statements // in order to correctly identify duplicates, we need to first canonicalize the statements
let mut canonicalizer = Canonicalizer; let mut canonicalizer = Canonicalizer;
@ -38,7 +38,7 @@ impl<T: Field> Folder<T> for DuplicateOptimizer {
fold_program(self, p) fold_program(self, p)
} }
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
let hashed = hash(&s); let hashed = hash(&s);
let result = match self.seen.get(&hashed) { let result = match self.seen.get(&hashed) {
Some(_) => vec![], Some(_) => vec![],

View file

@ -19,9 +19,9 @@ use self::tautology::TautologyOptimizer;
use zokrates_ast::ir::{ProgIterator, Statement}; use zokrates_ast::ir::{ProgIterator, Statement};
use zokrates_field::Field; use zokrates_field::Field;
pub fn optimize<T: Field, I: IntoIterator<Item = Statement<T>>>( pub fn optimize<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
p: ProgIterator<T, I>, p: ProgIterator<'ast, T, I>,
) -> ProgIterator<T, impl IntoIterator<Item = Statement<T>>> { ) -> ProgIterator<'ast, T, impl IntoIterator<Item = Statement<'ast, T>>> {
// remove redefinitions // remove redefinitions
log::debug!("Optimizer: Remove redefinitions and tautologies and directives and duplicates"); log::debug!("Optimizer: Remove redefinitions and tautologies and directives and duplicates");

View file

@ -53,7 +53,9 @@ pub struct RedefinitionOptimizer<T> {
} }
impl<T> RedefinitionOptimizer<T> { impl<T> RedefinitionOptimizer<T> {
pub fn init<I: IntoIterator<Item = Statement<T>>>(p: &ProgIterator<T, I>) -> Self { pub fn init<'ast, I: IntoIterator<Item = Statement<'ast, T>>>(
p: &ProgIterator<'ast, T, I>,
) -> Self {
RedefinitionOptimizer { RedefinitionOptimizer {
substitution: HashMap::new(), substitution: HashMap::new(),
ignore: vec![Variable::one()] ignore: vec![Variable::one()]
@ -66,8 +68,8 @@ impl<T> RedefinitionOptimizer<T> {
} }
} }
impl<T: Field> Folder<T> for RedefinitionOptimizer<T> { impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<T> {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
match s { match s {
Statement::Constraint(quad, lin, message) => { Statement::Constraint(quad, lin, message) => {
let quad = self.fold_quadratic_combination(quad); let quad = self.fold_quadratic_combination(quad);

View file

@ -13,8 +13,8 @@ use zokrates_field::Field;
#[derive(Default)] #[derive(Default)]
pub struct TautologyOptimizer; pub struct TautologyOptimizer;
impl<T: Field> Folder<T> for TautologyOptimizer { impl<'ast, T: Field> Folder<'ast, T> for TautologyOptimizer {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec<Statement<'ast, T>> {
match s { match s {
Statement::Constraint(quad, lin, message) => match quad.try_linear() { Statement::Constraint(quad, lin, message) => match quad.try_linear() {
Ok(l) => { Ok(l) => {

View file

@ -699,7 +699,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
is_mutable: false, is_mutable: false,
}; };
assert_eq!(self.scope.level, 0); assert_eq!(self.scope.level, 0);
assert!(!self.scope.insert(id, info)); assert!(!self.scope.insert(id.into(), info));
assert!(state assert!(state
.constants .constants
.entry(module_id.to_path_buf()) .entry(module_id.to_path_buf())
@ -895,7 +895,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
is_mutable: false, is_mutable: false,
}; };
assert_eq!(self.scope.level, 0); assert_eq!(self.scope.level, 0);
assert!(!self.scope.insert(id, info)); assert!(!self.scope.insert(id.into(), info));
state state
.constants .constants
@ -1130,12 +1130,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
// for declaration signatures, generics cannot be ignored // for declaration signatures, generics cannot be ignored
generics.0.insert( generics.0.insert(
generic.clone(), generic.clone(),
UExpressionInner::Identifier(self.id_in_this_scope(generic.name()).into()) UExpressionInner::Identifier(self.id_in_this_scope(generic.name().into()).into())
.annotate(UBitwidth::B32), .annotate(UBitwidth::B32),
); );
//we don't have to check for conflicts here, because this was done when checking the signature //we don't have to check for conflicts here, because this was done when checking the signature
self.insert_into_scope(generic.name(), Type::Uint(UBitwidth::B32), false); self.insert_into_scope(generic.name().into(), Type::Uint(UBitwidth::B32), false);
} }
for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) { for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) {
@ -1144,7 +1144,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let arg = arg.value; let arg = arg.value;
let decl_v = DeclarationVariable::new( let decl_v = DeclarationVariable::new(
self.id_in_this_scope(arg.id.value.id), self.id_in_this_scope(arg.id.value.id.into()),
decl_ty.clone(), decl_ty.clone(),
arg.id.value.is_mutable, arg.id.value.is_mutable,
); );
@ -1161,7 +1161,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
ty, ty,
is_mutable, is_mutable,
}; };
match self.scope.insert(id, info) { match self.scope.insert(id.into(), info) {
false => {} false => {}
true => { true => {
errors.push(ErrorInner { errors.push(ErrorInner {
@ -1651,10 +1651,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
.map_err(|e| vec![e])?; .map_err(|e| vec![e])?;
// insert into the scope and ignore whether shadowing happened // insert into the scope and ignore whether shadowing happened
self.insert_into_scope(v.value.id, ty.clone(), v.value.is_mutable); self.insert_into_scope(v.value.id.into(), ty.clone(), v.value.is_mutable);
Ok(Variable::new( Ok(Variable::new(
self.id_in_this_scope(v.value.id), self.id_in_this_scope(v.value.id.into()),
ty, ty,
v.value.is_mutable, v.value.is_mutable,
)) ))
@ -1770,6 +1770,87 @@ impl<'ast, T: Field> Checker<'ast, T> {
} }
} }
fn check_assembly_statement(
&mut self,
stat: AssemblyStatementNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast, T>,
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, ErrorInner> {
let pos = stat.pos();
match stat.value {
AssemblyStatement::Assignment(assignee, expression, constrained) => {
let assignee = self.check_assignee(assignee, module_id, types)?;
let checked_e = self.check_expression(expression, module_id, types)?;
let e = match checked_e {
TypedExpression::FieldElement(e) => Ok(e),
TypedExpression::Int(e) => Ok(FieldElementExpression::try_from_int(e).unwrap()), // todo: handle properly
_ => Err(ErrorInner {
pos: Some(pos),
message: "Only field element expressions are allowed in the assembly"
.to_string(),
}),
}?;
let e = FieldElementExpression::block(vec![], e);
match constrained {
true => {
if !e.is_quadratic() {
return Err(ErrorInner {
pos: Some(pos),
message: "Non quadratic constraints are not allowed".to_string(),
});
}
match assignee.get_type() {
Type::FieldElement => Ok(vec![
TypedAssemblyStatement::Assignment(assignee.clone(), e.clone()),
TypedAssemblyStatement::Constraint(assignee.into(), e),
]),
_ => Err(ErrorInner {
pos: Some(pos),
message: "Assignee must be of type `field`".to_string(),
}),
}
}
false => Ok(vec![TypedAssemblyStatement::Assignment(assignee, e)]),
}
}
AssemblyStatement::Constraint(lhs, rhs) => {
let lhs = self.check_expression(lhs, module_id, types)?;
let rhs = self.check_expression(rhs, module_id, types)?;
match (lhs, rhs) {
(TypedExpression::FieldElement(lhs), TypedExpression::FieldElement(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(lhs, rhs)])
}
(TypedExpression::FieldElement(lhs), TypedExpression::Int(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(
lhs,
FieldElementExpression::try_from_int(rhs).unwrap(),
)])
}
(TypedExpression::Int(lhs), TypedExpression::FieldElement(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(
FieldElementExpression::try_from_int(lhs).unwrap(),
rhs,
)])
}
(TypedExpression::Int(lhs), TypedExpression::Int(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(
FieldElementExpression::try_from_int(lhs).unwrap(),
FieldElementExpression::try_from_int(rhs).unwrap(),
)])
}
_ => Err(ErrorInner {
pos: Some(pos),
message: "Only field element expressions are allowed in the assembly"
.to_string(),
}),
}
}
}
}
fn check_statement( fn check_statement(
&mut self, &mut self,
stat: StatementNode<'ast>, stat: StatementNode<'ast>,
@ -1779,6 +1860,18 @@ impl<'ast, T: Field> Checker<'ast, T> {
let pos = stat.pos(); let pos = stat.pos();
match stat.value { match stat.value {
Statement::Assembly(statements) => {
let mut checked_statements = vec![];
for s in statements {
checked_statements.push(
self.check_assembly_statement(s, module_id, types)
.map_err(|e| vec![e])?,
);
}
Ok(TypedStatement::Assembly(
checked_statements.into_iter().flatten().collect(),
))
}
Statement::Log(l, expressions) => { Statement::Log(l, expressions) => {
let l = FormatString::from(l); let l = FormatString::from(l);
@ -1901,10 +1994,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
.map_err(|e| vec![e])?; .map_err(|e| vec![e])?;
// insert the lhs into the scope and ignore whether shadowing happened // insert the lhs into the scope and ignore whether shadowing happened
self.insert_into_scope(var.value.id, var_ty.clone(), var.value.is_mutable); self.insert_into_scope(var.value.id.into(), var_ty.clone(), var.value.is_mutable);
let var = Variable::new( let var = Variable::new(
self.id_in_this_scope(var.value.id), self.id_in_this_scope(var.value.id.into()),
var_ty.clone(), var_ty.clone(),
var.value.is_mutable, var.value.is_mutable,
); );
@ -2037,7 +2130,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let pos = assignee.pos(); let pos = assignee.pos();
// check that the assignee is declared // check that the assignee is declared
match assignee.value { match assignee.value {
Assignee::Identifier(variable_name) => match self.scope.get(&variable_name) { Assignee::Identifier(variable_name) => match self.scope.get(&variable_name.into()) {
Some(info) => match info.is_mutable { Some(info) => match info.is_mutable {
false => Err(ErrorInner { false => Err(ErrorInner {
pos: Some(assignee.pos()), pos: Some(assignee.pos()),
@ -2346,7 +2439,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()),
Expression::Identifier(name) => { Expression::Identifier(name) => {
// check that `id` is defined in the scope // check that `id` is defined in the scope
match self.scope.get(&name) { match self.scope.get(&name.into()) {
Some(info) => { Some(info) => {
let id = info.id; let id = info.id;
match info.ty.clone() { match info.ty.clone() {
@ -3615,11 +3708,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
is_mutable: bool, is_mutable: bool,
) -> bool { ) -> bool {
let info = IdentifierInfo { let info = IdentifierInfo {
id: self.id_in_this_scope(id), id: self.id_in_this_scope(id.clone()),
ty, ty,
is_mutable, is_mutable,
}; };
self.scope.insert(id, info) self.scope.insert(id.into(), info)
} }
fn find_functions( fn find_functions(

View file

@ -14,8 +14,8 @@ struct Propagator<T> {
constants: HashMap<Variable, T>, constants: HashMap<Variable, T>,
} }
impl<T: Field> Folder<T> for Propagator<T> { impl<'ast, T: Field> Folder<'ast, T> for Propagator<T> {
fn fold_statement(&mut self, s: FlatStatement<T>) -> Vec<FlatStatement<T>> { fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec<FlatStatement<'ast, T>> {
match s { match s {
FlatStatement::Definition(var, expr) => match self.fold_expression(expr) { FlatStatement::Definition(var, expr) => match self.fold_expression(expr) {
FlatExpression::Number(n) => { FlatExpression::Number(n) => {

View file

@ -1,7 +1,8 @@
use std::collections::HashSet;
use std::marker::PhantomData; use std::marker::PhantomData;
use zokrates_ast::typed::types::UBitwidth; use zokrates_ast::typed::types::UBitwidth;
use zokrates_ast::typed::{self, Expr, Typed}; use zokrates_ast::typed::{self, Expr, Typed};
use zokrates_ast::zir::{self, Select}; use zokrates_ast::zir::{self, Folder, Select};
use zokrates_field::Field; use zokrates_field::Field;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
@ -224,6 +225,14 @@ impl<'ast, T: Field> Flattener<T> {
} }
} }
fn fold_assembly_statement(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
s: typed::TypedAssemblyStatement<'ast, T>,
) -> zir::ZirAssemblyStatement<'ast, T> {
fold_assembly_statement(self, statements_buffer, s)
}
fn fold_statement( fn fold_statement(
&mut self, &mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>, statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
@ -393,12 +402,102 @@ impl<'ast, T: Field> Flattener<T> {
} }
} }
#[derive(Default)]
pub struct ArgumentFinder<'ast, T> {
pub identifiers: HashSet<zir::Identifier<'ast>>,
_phantom: PhantomData<T>,
}
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
fn fold_name(&mut self, n: zir::Identifier<'ast>) -> zir::Identifier<'ast> {
self.identifiers.insert(n.clone());
n
}
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
match s {
zir::ZirStatement::Definition(assignee, expr) => {
let assignee = self.fold_assignee(assignee);
let expr = self.fold_expression(expr);
self.identifiers.remove(&assignee.id);
vec![zir::ZirStatement::Definition(assignee, expr)]
}
zir::ZirStatement::MultipleDefinition(assignees, list) => {
let assignees: Vec<zir::ZirAssignee<'ast>> = assignees
.into_iter()
.map(|v| self.fold_assignee(v))
.collect();
let list = self.fold_expression_list(list);
for a in &assignees {
self.identifiers.remove(&a.id);
}
vec![zir::ZirStatement::MultipleDefinition(assignees, list)]
}
s => zir::folder::fold_statement(self, s),
}
}
}
fn fold_assembly_statement<'ast, T: Field>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
s: typed::TypedAssemblyStatement<'ast, T>,
) -> zir::ZirAssemblyStatement<'ast, T> {
match s {
typed::TypedAssemblyStatement::Assignment(a, e) => {
let mut statements_buffer: Vec<zir::ZirStatement<'ast, T>> = vec![];
let a = f.fold_assignee(a);
let e = f.fold_field_expression(&mut statements_buffer, e);
statements_buffer.push(zir::ZirStatement::Return(vec![
zir::ZirExpression::FieldElement(e),
]));
let mut finder = ArgumentFinder::default();
let mut statements_buffer: Vec<zir::ZirStatement<'ast, T>> = statements_buffer
.into_iter()
.rev()
.map(|s| finder.fold_statement(s))
.flatten()
.collect();
statements_buffer.reverse();
let function = zir::ZirFunction {
signature: zir::types::Signature::default()
.inputs(vec![zir::Type::FieldElement; finder.identifiers.len()])
.outputs(a.iter().map(|a| a.get_type()).collect()),
arguments: finder
.identifiers
.into_iter()
.map(|id| zir::Parameter {
id: zir::Variable::field_element(id),
private: false,
})
.collect(),
statements: statements_buffer,
};
zir::ZirAssemblyStatement::Assignment(a, function)
}
typed::TypedAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = f.fold_field_expression(statements_buffer, lhs);
let rhs = f.fold_field_expression(statements_buffer, rhs);
zir::ZirAssemblyStatement::Constraint(lhs, rhs)
}
}
}
fn fold_statement<'ast, T: Field>( fn fold_statement<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>, statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
s: typed::TypedStatement<'ast, T>, s: typed::TypedStatement<'ast, T>,
) { ) {
let res = match s { let res = match s {
typed::TypedStatement::Assembly(statements) => {
let statements = statements
.into_iter()
.map(|s| f.fold_assembly_statement(statements_buffer, s))
.collect();
vec![zir::ZirStatement::Assembly(statements)]
}
typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return( typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return(
f.fold_expression(statements_buffer, expression), f.fold_expression(statements_buffer, expression),
)], )],

View file

@ -220,6 +220,38 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
s: TypedStatement<'ast, T>, s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, Error> { ) -> Result<Vec<TypedStatement<'ast, T>>, Error> {
match s { match s {
TypedStatement::Assembly(statements) => {
let mut assembly_statement_buffer = vec![];
let mut statement_buffer = vec![];
for s in statements {
match self.fold_assembly_statement(s)? {
TypedAssemblyStatement::Assignment(assignee, expr) => {
// invalidate the cache
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => {
statement_buffer.push(TypedStatement::Definition(
v.clone().into(),
c.into(),
));
}
None => {}
}
assembly_statement_buffer
.push(TypedAssemblyStatement::Assignment(assignee, expr));
}
s => assembly_statement_buffer.push(s),
}
}
statement_buffer.push(TypedStatement::Assembly(assembly_statement_buffer));
Ok(statement_buffer)
}
// propagation to the defined variable if rhs is a constant // propagation to the defined variable if rhs is a constant
TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => {
let assignee = self.fold_assignee(assignee)?; let assignee = self.fold_assignee(assignee)?;

View file

@ -3,7 +3,7 @@ use std::fmt;
use zokrates_ast::zir::types::UBitwidth; use zokrates_ast::zir::types::UBitwidth;
use zokrates_ast::zir::{ use zokrates_ast::zir::{
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr,
SelectExpression, SelectOrExpression, SelectExpression, SelectOrExpression, ZirAssemblyStatement,
}; };
use zokrates_ast::zir::{ use zokrates_ast::zir::{
BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression, BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression,
@ -42,6 +42,9 @@ pub struct ZirPropagator<'ast, T> {
} }
impl<'ast, T: Field> ZirPropagator<'ast, T> { impl<'ast, T: Field> ZirPropagator<'ast, T> {
pub fn with_constants(constants: Constants<'ast, T>) -> Self {
Self { constants }
}
pub fn propagate(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> { pub fn propagate(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
ZirPropagator::default().fold_program(p) ZirPropagator::default().fold_program(p)
} }
@ -50,6 +53,24 @@ impl<'ast, T: Field> ZirPropagator<'ast, T> {
impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
type Error = Error; type Error = Error;
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
match s {
ZirAssemblyStatement::Assignment(assignees, function) => {
for a in &assignees {
self.constants.remove(&a.id);
}
Ok(ZirAssemblyStatement::Assignment(
assignees,
self.fold_function(function)?,
))
}
s => fold_assembly_statement(self, s),
}
}
fn fold_statement( fn fold_statement(
&mut self, &mut self,
s: ZirStatement<'ast, T>, s: ZirStatement<'ast, T>,

View file

@ -24,21 +24,22 @@ impl Interpreter {
} }
impl Interpreter { impl Interpreter {
pub fn execute<T: Field, I: IntoIterator<Item = Statement<T>>>( pub fn execute<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>>(
&self, &self,
program: ProgIterator<T, I>, program: ProgIterator<'ast, T, I>,
inputs: &[T], inputs: &[T],
) -> ExecutionResult<T> { ) -> ExecutionResult<T> {
self.execute_with_log_stream(program, inputs, &mut std::io::sink()) self.execute_with_log_stream(program, inputs, &mut std::io::sink())
} }
pub fn execute_with_log_stream< pub fn execute_with_log_stream<
'ast,
W: std::io::Write, W: std::io::Write,
T: Field, T: Field,
I: IntoIterator<Item = Statement<T>>, I: IntoIterator<Item = Statement<'ast, T>>,
>( >(
&self, &self,
program: ProgIterator<T, I>, program: ProgIterator<'ast, T, I>,
inputs: &[T], inputs: &[T],
log_stream: &mut W, log_stream: &mut W,
) -> ExecutionResult<T> { ) -> ExecutionResult<T> {
@ -142,9 +143,9 @@ impl Interpreter {
.collect() .collect()
} }
fn check_inputs<T: Field, I: IntoIterator<Item = Statement<T>>, U>( fn check_inputs<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>, U>(
&self, &self,
program: &ProgIterator<T, I>, program: &ProgIterator<'ast, T, I>,
inputs: &[U], inputs: &[U],
) -> Result<(), Error> { ) -> Result<(), Error> {
if program.arguments.len() == inputs.len() { if program.arguments.len() == inputs.len() {
@ -157,11 +158,18 @@ impl Interpreter {
} }
} }
pub fn execute_solver<T: Field>(solver: &Solver, inputs: &[T]) -> Result<Vec<T>, String> { pub fn execute_solver<'ast, T: Field>(
solver: &Solver<'ast, T>,
inputs: &[T],
) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = solver.get_signature(); let (expected_input_count, expected_output_count) = solver.get_signature();
assert_eq!(inputs.len(), expected_input_count); assert_eq!(inputs.len(), expected_input_count);
let res = match solver { let res = match solver {
Solver::Zir(func) => {
// TODO: implement evaluation of the function
vec![inputs[1].checked_div(&inputs[0]).unwrap()]
}
Solver::ConditionEq => match inputs[0].is_zero() { Solver::ConditionEq => match inputs[0].is_zero() {
true => vec![T::zero(), T::one()], true => vec![T::zero(), T::one()],
false => vec![ false => vec![

View file

@ -349,13 +349,14 @@ mod internal {
} }
pub fn setup_universal< pub fn setup_universal<
'a,
T: Field, T: Field,
I: IntoIterator<Item = ir::Statement<T>>, I: IntoIterator<Item = ir::Statement<'a, T>>,
S: UniversalScheme<T> + Serialize, S: UniversalScheme<T> + Serialize,
B: UniversalBackend<T, S>, B: UniversalBackend<T, S>,
>( >(
srs: &[u8], srs: &[u8],
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
) -> Result<JsValue, JsValue> { ) -> Result<JsValue, JsValue> {
let keypair = B::setup(srs.to_vec(), program).map_err(|e| JsValue::from_str(&e))?; let keypair = B::setup(srs.to_vec(), program).map_err(|e| JsValue::from_str(&e))?;
Ok(JsValue::from_serde(&TaggedKeypair::<T, S>::new(keypair)).unwrap()) Ok(JsValue::from_serde(&TaggedKeypair::<T, S>::new(keypair)).unwrap())

View file

@ -52,7 +52,7 @@ _mut = {"mut"}
// Statements // Statements
statement = { (iteration_statement // does not require semicolon statement = { (iteration_statement | asm_statement // does not require semicolon
| ((log_statement | ((log_statement
|return_statement |return_statement
| definition_statement | definition_statement
@ -66,6 +66,15 @@ return_statement = { "return" ~ expression? }
definition_statement = { typed_identifier_or_assignee ~ "=" ~ expression } definition_statement = { typed_identifier_or_assignee ~ "=" ~ expression }
assertion_statement = {"assert" ~ "(" ~ expression ~ ("," ~ quoted_string)? ~ ")"} assertion_statement = {"assert" ~ "(" ~ expression ~ ("," ~ quoted_string)? ~ ")"}
op_asm_assign = @{"<--"}
op_asm_assign_constrain = @{"<=="}
asm_assignment = { assignee ~ (op_asm_assign | op_asm_assign_constrain) ~ expression }
asm_constraint = { expression ~ "===" ~ expression }
asm_statement_inner = { (asm_assignment | asm_constraint) ~ semicolon ~ NEWLINE* }
asm_statement = { "asm" ~ "{" ~ NEWLINE* ~ asm_statement_inner* ~ NEWLINE* ~ "}" }
typed_identifier_or_assignee = { typed_identifier | assignee } typed_identifier_or_assignee = { typed_identifier | assignee }
// Expressions // Expressions

View file

@ -8,11 +8,12 @@ use zokrates_parser::Rule;
extern crate lazy_static; extern crate lazy_static;
pub use ast::{ pub use ast::{
Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssemblyStatement,
Assignee, AssigneeAccess, BasicOrStructOrTupleType, BasicType, BinaryExpression, AssemblyStatementInner, AssertionStatement, Assignee, AssigneeAccess, AssignmentOperator,
BinaryOperator, CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, BasicOrStructOrTupleType, BasicType, BinaryExpression, BinaryOperator, CallAccess,
DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber,
File, FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File,
FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression,
IdentifierExpression, IdentifierOrDecimal, IfElseExpression, ImportDirective, ImportSymbol, IdentifierExpression, IdentifierOrDecimal, IfElseExpression, ImportDirective, ImportSymbol,
InlineArrayExpression, InlineStructExpression, InlineStructMember, InlineTupleExpression, InlineArrayExpression, InlineStructExpression, InlineStructMember, InlineTupleExpression,
IterationStatement, LiteralExpression, LogStatement, Parameter, PostfixExpression, Range, IterationStatement, LiteralExpression, LogStatement, Parameter, PostfixExpression, Range,
@ -366,6 +367,7 @@ mod ast {
Assertion(AssertionStatement<'ast>), Assertion(AssertionStatement<'ast>),
Iteration(IterationStatement<'ast>), Iteration(IterationStatement<'ast>),
Log(LogStatement<'ast>), Log(LogStatement<'ast>),
Assembly(AssemblyStatement<'ast>),
} }
#[derive(Debug, FromPest, PartialEq, Clone)] #[derive(Debug, FromPest, PartialEq, Clone)]
@ -423,6 +425,72 @@ mod ast {
pub span: Span<'ast>, pub span: Span<'ast>,
} }
// #[derive(Debug, FromPest, PartialEq, Eq, Clone)]
// #[pest_ast(rule(Rule::op_asm_assign))]
// pub struct AssemblyAssignOperator;
//
// #[derive(Debug, FromPest, PartialEq, Eq, Clone)]
// #[pest_ast(rule(Rule::op_asm_assign_constrain))]
// pub struct AssemblyAssignConstrainOperator;
//
// #[derive(Debug, FromPest, PartialEq, Eq, Clone)]
// #[pest_ast(rule(Rule::op_asm_constrain))]
// pub struct AssemblyConstrainOperator;
#[derive(Debug, PartialEq, Clone)]
pub enum AssignmentOperator {
Assign,
AssignConstrain,
}
impl<'ast> FromPest<'ast> for AssignmentOperator {
type Rule = Rule;
type FatalError = Void;
fn from_pest(pest: &mut Pairs<'ast, Rule>) -> Result<Self, ConversionError<Void>> {
let pair = pest.next().ok_or(::from_pest::ConversionError::NoMatch)?;
match pair.as_rule() {
Rule::op_asm_assign => Ok(AssignmentOperator::Assign),
Rule::op_asm_assign_constrain => Ok(AssignmentOperator::AssignConstrain),
_ => Err(ConversionError::NoMatch),
}
}
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::asm_assignment))]
pub struct AssemblyAssignment<'ast> {
pub assignee: Assignee<'ast>,
pub operator: AssignmentOperator,
pub expression: Expression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::asm_constraint))]
pub struct AssemblyConstraint<'ast> {
pub lhs: Expression<'ast>,
pub rhs: Expression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::asm_statement_inner))]
pub enum AssemblyStatementInner<'ast> {
Assignment(AssemblyAssignment<'ast>),
Constraint(AssemblyConstraint<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::asm_statement))]
pub struct AssemblyStatement<'ast> {
pub inner: Vec<AssemblyStatementInner<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
pub enum BinaryOperator { pub enum BinaryOperator {
BitXor, BitXor,

View file

@ -96,8 +96,8 @@ impl ToString for G2AffineFq2 {
} }
pub trait Backend<T: Field, S: Scheme<T>> { pub trait Backend<T: Field, S: Scheme<T>> {
fn generate_proof<I: IntoIterator<Item = ir::Statement<T>>>( fn generate_proof<'a, I: IntoIterator<Item = ir::Statement<'a, T>>>(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
witness: ir::Witness<T>, witness: ir::Witness<T>,
proving_key: Vec<u8>, proving_key: Vec<u8>,
) -> Proof<T, S>; ) -> Proof<T, S>;
@ -105,36 +105,36 @@ pub trait Backend<T: Field, S: Scheme<T>> {
fn verify(vk: S::VerificationKey, proof: Proof<T, S>) -> bool; fn verify(vk: S::VerificationKey, proof: Proof<T, S>) -> bool;
} }
pub trait NonUniversalBackend<T: Field, S: NonUniversalScheme<T>>: Backend<T, S> { pub trait NonUniversalBackend<T: Field, S: NonUniversalScheme<T>>: Backend<T, S> {
fn setup<I: IntoIterator<Item = ir::Statement<T>>>( fn setup<'a, I: IntoIterator<Item = ir::Statement<'a, T>>>(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
) -> SetupKeypair<T, S>; ) -> SetupKeypair<T, S>;
} }
pub trait UniversalBackend<T: Field, S: UniversalScheme<T>>: Backend<T, S> { pub trait UniversalBackend<T: Field, S: UniversalScheme<T>>: Backend<T, S> {
fn universal_setup(size: u32) -> Vec<u8>; fn universal_setup(size: u32) -> Vec<u8>;
fn setup<I: IntoIterator<Item = ir::Statement<T>>>( fn setup<'a, I: IntoIterator<Item = ir::Statement<'a, T>>>(
srs: Vec<u8>, srs: Vec<u8>,
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
) -> Result<SetupKeypair<T, S>, String>; ) -> Result<SetupKeypair<T, S>, String>;
} }
pub trait MpcBackend<T: Field, S: Scheme<T>> { pub trait MpcBackend<T: Field, S: Scheme<T>> {
fn initialize<R: Read, W: Write, I: IntoIterator<Item = ir::Statement<T>>>( fn initialize<'a, R: Read, W: Write, I: IntoIterator<Item = ir::Statement<'a, T>>>(
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
phase1_radix: &mut R, phase1_radix: &mut R,
output: &mut W, output: &mut W,
) -> Result<(), String>; ) -> Result<(), String>;
fn contribute<R: Read, W: Write, G: Rng>( fn contribute<'a, R: Read, W: Write, G: Rng>(
params: &mut R, params: &mut R,
rng: &mut G, rng: &mut G,
output: &mut W, output: &mut W,
) -> Result<[u8; 64], String>; ) -> Result<[u8; 64], String>;
fn verify<P: Read, R: Read, I: IntoIterator<Item = ir::Statement<T>>>( fn verify<'a, P: Read, R: Read, I: IntoIterator<Item = ir::Statement<'a, T>>>(
params: &mut P, params: &mut P,
program: ir::ProgIterator<T, I>, program: ir::ProgIterator<'a, T, I>,
phase1_radix: &mut R, phase1_radix: &mut R,
) -> Result<Vec<[u8; 64]>, String>; ) -> Result<Vec<[u8; 64]>, String>;