From 1bb524f6a2f81ac56113b4557ed963dcd0203df8 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 20 Feb 2023 23:17:02 +0100 Subject: [PATCH] clean --- zokrates_analysis/src/propagation.rs | 21 +- .../src/reducer/constants_reader.rs | 12 +- .../src/reducer/constants_writer.rs | 6 +- zokrates_analysis/src/reducer/inline.rs | 5 +- zokrates_analysis/src/reducer/mod.rs | 550 +++++------------- zokrates_analysis/src/reducer/shallow_ssa.rs | 71 +-- zokrates_ast/src/typed/folder.rs | 2 +- zokrates_ast/src/typed/types.rs | 2 +- 8 files changed, 199 insertions(+), 470 deletions(-) diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index b7701dc0..8a2f32e0 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -44,25 +44,16 @@ impl fmt::Display for Error { } } -#[derive(Debug)] -pub struct Propagator<'ast, 'a, T> { +#[derive(Debug, Default)] +pub struct Propagator<'ast, T> { // constants keeps track of constant expressions // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is - constants: &'a mut Constants<'ast, T>, + constants: Constants<'ast, T>, } -impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { - pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self { - Propagator { constants } - } - +impl<'ast, T: Field> Propagator<'ast, T> { pub fn propagate(p: TypedProgram<'ast, T>) -> Result, Error> { - let mut constants = Constants::new(); - - Propagator { - constants: &mut constants, - } - .fold_program(p) + Propagator::default().fold_program(p) } // get a mutable reference to the constant corresponding to a given assignee if any, otherwise @@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { } } -impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { +impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { type Error = Error; fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result, Error> { diff --git a/zokrates_analysis/src/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs index f0ec252e..f991f77f 100644 --- a/zokrates_analysis/src/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -65,7 +65,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -94,7 +94,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -124,7 +124,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -152,7 +152,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -182,7 +182,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -212,7 +212,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, diff --git a/zokrates_analysis/src/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs index d38daf15..50a6d25a 100644 --- a/zokrates_analysis/src/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -5,9 +5,9 @@ use crate::reducer::{ }; use std::collections::{BTreeMap, HashSet}; use zokrates_ast::typed::{ - result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed, - TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId, - TypedProgram, TypedSymbolDeclaration, UExpression, + result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol, + TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration, + UExpression, }; use zokrates_field::Field; diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index f1a7229e..7e2436cd 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -26,13 +26,10 @@ // - The body of the function is in SSA form // - The return value(s) are assigned to internal variables -use crate::reducer::ShallowTransformer; -use crate::reducer::Versions; - use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType}; use zokrates_ast::typed::CoreIdentifier; -use zokrates_ast::typed::Identifier; + use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::UBitwidth; use zokrates_ast::typed::{ diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index bd129890..51dcc3aa 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -18,30 +18,26 @@ mod shallow_ssa; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; -use zokrates_ast::typed::identifier::FrameIdentifier; use zokrates_ast::typed::result_folder::*; -use zokrates_ast::typed::types::ConcreteGenericsAssignment; -use zokrates_ast::typed::types::GGenericsAssignment; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; -use zokrates_ast::typed::Typed; +use zokrates_ast::typed::TypedAssemblyStatement; use zokrates_ast::typed::TypedAssignee; -use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; - use zokrates_ast::typed::{ ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, - FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, - TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, - TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, + FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression, + TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, + TypedStatement, UExpression, UExpressionInner, }; +use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; -use zokrates_ast::zir::result_folder::fold_assembly_statement; use zokrates_field::Field; use self::constants_writer::ConstantsWriter; use self::shallow_ssa::ShallowTransformer; -use crate::propagation::{Constants, Propagator}; +use crate::propagation; +use crate::propagation::Propagator; use std::fmt; @@ -55,7 +51,6 @@ pub type ConstantDefinitions<'ast, T> = #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Versions<'ast> { map: HashMap, usize>>, - frame: usize, } impl<'ast> Versions<'ast> { @@ -69,15 +64,15 @@ impl<'ast> Versions<'ast> { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] pub enum Error { Incompatible(String), GenericsInMain, - // TODO: give more details about what's blocking the progress - NoProgress, LoopTooLarge(u128), ConstantReduction(String, OwnedTypedModuleId), + NonConstant(String), Type(String), + Propagation(propagation::Error), } impl fmt::Display for Error { @@ -89,39 +84,36 @@ impl fmt::Display for Error { s ), Error::GenericsInMain => write!(f, "Cannot generate code for generic function"), - Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"), Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE), Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()), - Error::Type(message) => write!(f, "{}", message), + Error::NonConstant(s) => write!(f, "{}", s), + Error::Type(s) => write!(f, "{}", s), + Error::Propagation(e) => write!(f, "{}", e), } } } +impl From for Error { + fn from(e: propagation::Error) -> Self { + Self::Propagation(e) + } +} + #[derive(Debug)] struct Reducer<'ast, 'a, T> { - propagator: Propagator<'ast, 'a, T>, - statement_buffer: Vec>, - latest_frame: usize, - versions: &'a mut Versions<'ast>, program: &'a TypedProgram<'ast, T>, + propagator: Propagator<'ast, T>, + ssa: ShallowTransformer<'ast>, + statement_buffer: Vec>, } impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { - fn new( - program: &'a TypedProgram<'ast, T>, - versions: &'a mut Versions<'ast>, - constants: &'a mut Constants<'ast, T>, - ) -> Self { - // println!("create reducer with"); - // println!("{} versions", versions.len()); - // println!("{} constants", constants.len()); - + fn new(program: &'a TypedProgram<'ast, T>) -> Self { Reducer { - propagator: Propagator::with_constants(constants), + propagator: Propagator::default(), + ssa: ShallowTransformer::default(), statement_buffer: vec![], - latest_frame: 0, program, - versions, } } } @@ -133,8 +125,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, p: DeclarationParameter<'ast, T>, ) -> Result, Self::Error> { + // this is only used on the entry point let id = p.id.id.id.id.clone(); - assert!(self.versions.insert_in_frame(id, 0, 0).is_none()); + assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none()); Ok(p) } @@ -150,9 +143,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .into_iter() .map(|g| { g.map(|g| { - let g = - ShallowTransformer::with_versions(self.versions).fold_uint_expression(g); - let g = self.propagator.fold_uint_expression(g).unwrap(); + let g = self.ssa.fold_uint_expression(g); + let g = self.propagator.fold_uint_expression(g)?; self.fold_uint_expression(g) }) @@ -164,43 +156,30 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .arguments .into_iter() .map(|e| { - let e = ShallowTransformer::with_versions(self.versions).fold_expression(e); - let e = self.propagator.fold_expression(e).unwrap(); + let e = self.ssa.fold_expression(e); + let e = self.propagator.fold_expression(e)?; self.fold_expression(e) }) .collect::>()?; - // back up the current frame - let frame_backup = self.versions.frame; + self.ssa.push_call_frame(); - // create a new frame - self.latest_frame += 1; - - // point the versions to this frame - self.versions.frame = self.latest_frame; - self.versions - .map - .insert(self.versions.frame, Default::default()); - - // println!("GENERICS {:?}", generics); - - let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); + let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program); let res = match res { Ok((input_variables, arguments, generics_bindings, statements, expression)) => { - let generics_bindings: Vec<_> = generics_bindings + let generics_bindings = generics_bindings .into_iter() - .flat_map(|s| { - ShallowTransformer::with_versions(self.versions).fold_statement(s) - }) - .flat_map(|s| self.propagator.fold_statement(s).unwrap()) - .collect::>() + .flat_map(|s| self.ssa.fold_statement(s)) + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? .into_iter() - .flat_map(|s| self.fold_statement(s).unwrap()) - .collect(); - - // println!("{:#?}", propagator); + .flatten() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); self.statement_buffer.extend(generics_bindings); @@ -208,43 +187,32 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let input_bindings: Vec<_> = input_variables .into_iter() .zip(arguments) - .map(|(v, a)| { - TypedStatement::definition( - ShallowTransformer::with_versions(self.versions) - .fold_assignee(v.into()), - a, - ) - }) + .map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a)) .collect(); - let input_bindings: Vec<_> = input_bindings + let input_bindings = input_bindings .into_iter() - .flat_map(|s| self.propagator.fold_statement(s).unwrap()) - .collect(); + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); self.statement_buffer.extend(input_bindings); - let statements: Vec<_> = statements + let statements = statements .into_iter() - .flat_map(|s| self.fold_statement(s).unwrap()) - .collect(); - - // println!("FRAME READY TO SSA {}", self.versions.frame); - - let mut transformer = ShallowTransformer::with_versions(self.versions); - let propagator = &mut self.propagator; + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); self.statement_buffer.extend(statements); - // println!("call result {}", expression); + let expression = self.ssa.fold_expression(expression); - let expression = transformer.fold_expression(expression); + let expression = self.propagator.fold_expression(expression)?; - let expression = propagator.fold_expression(expression).unwrap(); - - let expression = self.fold_expression(expression).unwrap(); - - // println!("call result reduced {}", expression); + let expression = self.fold_expression(expression)?; Ok(FunctionCallOrExpression::Expression( E::from(expression).into_inner(), @@ -254,20 +222,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => Err(Error::NoProgress), + Err(InlineError::NonConstant(key, generics, arguments, _)) => { + Err(Error::NonConstant(format!( + "Generic parameters must be compile-time constants, found {}", + FunctionCallExpression::<_, E>::new(key, generics, arguments) + ))) + } Err(InlineError::Flat(embed, generics, arguments, output_type)) => { - let identifier = - Identifier::from(CoreIdentifier::Call(0).in_frame(self.versions.frame)) - .version( - *self - .versions - .map - .entry(self.versions.frame) - .or_default() - .entry(CoreIdentifier::Call(0)) - .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0), - ); + let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); let var = Variable::immutable(identifier.clone(), output_type); @@ -284,12 +246,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { } }; - // clean versions - self.versions.map.remove(&self.versions.frame); - - // restore the original frame - // println!("RESTORING BACKUP {}", frame_backup); - self.versions.frame = frame_backup; + self.ssa.pop_call_frame(); res } @@ -298,24 +255,44 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, b: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { - // // backup the statements and continue with a fresh state - // let statement_buffer = std::mem::take(&mut self.statement_buffer); + // backup the statements and continue with a fresh state + let statement_buffer = std::mem::take(&mut self.statement_buffer); - // let block = fold_block_expression(self, b)?; + let block = fold_block_expression(self, b)?; - // // put the original statements back and extract the statements created by visiting the block - // let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); + // put the original statements back and extract the statements created by visiting the block + let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); - // // return the visited block, augmented with the statements created while visiting it - // Ok(BlockExpression { - // statements: block - // .statements - // .into_iter() - // .chain(extra_statements) - // .collect(), - // ..block - // }) - todo!() + // return the visited block, augmented with the statements created while visiting it + Ok(BlockExpression { + statements: block + .statements + .into_iter() + .chain(extra_statements) + .collect(), + ..block + }) + } + + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + Ok(match s { + TypedAssemblyStatement::Assignment(a, e) => { + vec![TypedAssemblyStatement::Assignment( + self.fold_assignee(a)?, + self.fold_expression(e)?, + )] + } + TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { + vec![TypedAssemblyStatement::Constraint( + self.fold_field_expression(lhs)?, + self.fold_field_expression(rhs)?, + metadata, + )] + } + }) } fn fold_canonical_constant_identifier( @@ -336,163 +313,77 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { // then we reduce the rhs to remove the function calls // only then we transform and propagate the assignee - let mut transformer = ShallowTransformer::with_versions(self.versions); - // println!("rhs {}", rhs); - let rhs = transformer.fold_definition_rhs(rhs); - // println!("ssa-ed {}", rhs); - let rhs = self.propagator.fold_definition_rhs(rhs).unwrap(); - // println!("propagated {}", rhs); - let rhs = self.fold_definition_rhs(rhs).unwrap(); - // println!("reduced {}", rhs); + let rhs = self.ssa.fold_definition_rhs(rhs); + let rhs = self.propagator.fold_definition_rhs(rhs)?; + let rhs = self.fold_definition_rhs(rhs)?; - // println!("ASSIGNEE {}", a); - - // println!("{:?}", self.versions); - - let a = ShallowTransformer::with_versions(self.versions).fold_assignee(a); - - // println!("{:?}", self.versions); - - // println!("definition before propagation {}", TypedStatement::Definition(a.clone(), rhs.clone())); + let a = self.ssa.fold_assignee(a); self.propagator - .fold_statement(TypedStatement::Definition(a, rhs)) - .unwrap() - - // println!("final definition size: {}", s.len()); + .fold_statement(TypedStatement::Definition(a, rhs))? } TypedStatement::For(v, from, to, statements) => { - let from = - ShallowTransformer::with_versions(self.versions).fold_uint_expression(from); - let from = self.propagator.fold_uint_expression(from).unwrap(); - let from = self.fold_uint_expression(from).unwrap(); - let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to); - let to = self.propagator.fold_uint_expression(to).unwrap(); - let to = self.fold_uint_expression(to).unwrap(); + let from = self.ssa.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from..*to) + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from + ..*to) .flat_map(|index| { std::iter::once(TypedStatement::definition( v.clone().into(), UExpression::from(index as u32).into(), )) .chain(statements.clone()) - .flat_map(|s| self.fold_statement(s).unwrap()) + .map(|s| self.fold_statement(s)) .collect::>() }) - .collect::>(), - _ => unimplemented!(), - } + .collect::, _>>()? + .into_iter() + .flatten() + .collect()), + _ => Err(Error::NonConstant(format!( + "Expected loop bounds to be constant, found {}..{}", + from, to + ))), + }? } - TypedStatement::Assembly(_) => todo!(), TypedStatement::Return(e) => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - - let e = transformer.fold_expression(e); - let e = self.propagator.fold_expression(e).unwrap(); - vec![TypedStatement::Return(self.fold_expression(e).unwrap())] + let e = self.ssa.fold_expression(e); + let e = self.propagator.fold_expression(e)?; + vec![TypedStatement::Return(self.fold_expression(e)?)] } TypedStatement::Assertion(e, error) => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - - let e = transformer.fold_boolean_expression(e); - let e = self.propagator.fold_boolean_expression(e).unwrap(); + let e = self.ssa.fold_boolean_expression(e); + let e = self.propagator.fold_boolean_expression(e)?; vec![TypedStatement::Assertion( - self.fold_boolean_expression(e).unwrap(), + self.fold_boolean_expression(e)?, error, )] } - s => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - let propagator = &mut self.propagator; - transformer - .fold_statement(s) - .into_iter() - // .inspect(|s| println!("ssa {}\n", s)) - .flat_map(|s| propagator.fold_statement(s).unwrap()) - .collect::>() - .into_iter() - .flat_map(|s| fold_statement(self, s).unwrap()) - // .inspect(|s| println!("propagated {}\n", s)) - .collect() - } + s => self + .ssa + .fold_statement(s) + .into_iter() + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .map(|s| fold_statement(self, s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), }; - Ok(self - .statement_buffer - .drain(..) - .chain(res) - .collect::>()) + Ok(self.statement_buffer.drain(..).chain(res).collect()) } - // fn fold_statement( - // &mut self, - // s: TypedStatement<'ast, T>, - // ) -> Result>, Self::Error> { - // let mut transformer = ShallowTransformer::with_versions(self.versions); - // let propagator = &mut self.propagator; - - // println!("FOLD_STATEMENT: {}", s); - - // let s: Vec<_> = transformer - // .fold_statement(s) - // .into_iter() - // // .inspect(|s| println!("ssa {}\n", s)) - // .flat_map(|s| propagator.fold_statement(s).unwrap()) - // // .inspect(|s| println!("propagated {}\n", s)) - // .collect(); - - // for s in &s { - // println!("OUTER: {}", s); - // } - - // let res: Vec<_> = s - // .into_iter() - // .flat_map(|s| match s { - // TypedStatement::For(v, from, to, statements) => { - // match (from.as_inner(), to.as_inner()) { - // (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from - // ..*to) - // .flat_map(|index| { - // std::iter::once(TypedStatement::definition( - // v.clone().into(), - // UExpression::from(index as u32).into(), - // )) - // .chain(statements.clone()) - // .flat_map(|s| self.fold_statement(s).unwrap()) - // .collect::>() - // }) - // .collect(), - // _ => unimplemented!(), - // } - // } - // s => { - // println!("UNROLL/INLINE STATEMENT {}", s); - - // let s = fold_statement(self, s).unwrap(); - - // for s in &self.statement_buffer { - // println!("BUFFER {}", s); - // } - - // for s in &s { - // println!("RESULT {}", s); - // } - - // self.statement_buffer.drain(..).chain(s).collect::>() - // } - // }) - // .collect(); - - // for s in &res { - // // println!("DONE: {}", s); - // } - - // Ok(res) - // } - fn fold_array_expression_inner( &mut self, array_ty: &ArrayType<'ast, T>, @@ -500,25 +391,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { match e { ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = - ShallowTransformer::with_versions(self.versions).fold_array_expression(array); - let array = self.propagator.fold_array_expression(array).unwrap(); - let array = self.fold_array_expression(array).unwrap(); - let from = - ShallowTransformer::with_versions(self.versions).fold_uint_expression(from); - let from = self.propagator.fold_uint_expression(from).unwrap(); - let from = self.fold_uint_expression(from).unwrap(); - let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to); - let to = self.propagator.fold_uint_expression(to).unwrap(); - let to = self.fold_uint_expression(to).unwrap(); + let array = self.ssa.fold_array_expression(array); + let array = self.propagator.fold_array_expression(array)?; + let array = self.fold_array_expression(array)?; + let from = self.ssa.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { Ok(ArrayExpressionInner::Slice(box array, box from, box to)) } - _ => { - todo!("non constant slice bounds") - } + _ => Err(Error::NonConstant(format!( + "Slice bounds must be compile time constants, found {}", + ArrayExpressionInner::Slice(box array, box from, box to) + ))), } } _ => fold_array_expression_inner(self, array_ty, e), @@ -548,7 +438,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E match main_function.signature.generics.len() { 0 => { - let main_function = reduce_function(main_function, &p)?; + let main_function = Reducer::new(&p).fold_function(main_function)?; Ok(TypedProgram { main: p.main.clone(), @@ -574,135 +464,9 @@ fn reduce_function<'ast, T: Field>( f: TypedFunction<'ast, T>, program: &TypedProgram<'ast, T>, ) -> Result, Error> { - let mut versions = Versions::default(); - let mut constants = Constants::default(); - assert!(f.signature.generics.is_empty()); - // let f = match ShallowTransformer::transform(f, &generics, &mut versions) { - // Output::Complete(f) => Ok(f), - // Output::Incomplete(new_f, new_for_loop_versions) => { - // let mut for_loop_versions = new_for_loop_versions; - - // let mut f = Propagator::with_constants(&mut constants) - // .fold_function(new_f) - // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - // let mut substitutions = Substitutions::default(); - - // let mut hash = None; - - // let mut len = f.statements.len(); - - // // println!("{}", f); - - // loop { - // let mut reducer = Reducer::new( - // program, - // &mut versions, - // &mut substitutions, - // for_loop_versions, - // &mut constants, - // ); - - // println!("reduce"); - - // let new_f = TypedFunction { - // statements: f - // .statements - // .into_iter() - // .map(|s| reducer.fold_statement(s)) - // .collect::, _>>()? - // .into_iter() - // .flatten() - // .collect(), - // ..f - // }; - - // println!("done"); - - // // println!("after reduction {}", new_f); - - // println!( - // "count {}, unrolled {} loops", - // new_f.statements.len(), - // reducer.credits - // ); - - // assert!(reducer.for_loop_versions.is_empty()); - - // match reducer.complete { - // true => { - // substitutions = substitutions.canonicalize(); - - // let new_f = Sub::new(&substitutions).fold_function(new_f); - - // // println!("after last sub {}", new_f); - - // // let new_f = Propagator::with_constants(&mut constants) - // // .fold_function(new_f) - // // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - // // println!("after last prop {}", new_f); - - // break Ok(new_f); - // } - // false => { - // for_loop_versions = reducer.for_loop_versions_after; - - // println!("canonicalize"); - - // // substitutions = substitutions.canonicalize(); - - // // let new_f = Sub::new(&substitutions).fold_function(new_f); - - // println!("done"); - // // println!("after sub {}", new_f); - - // println!("propagate"); - - // // f = Propagator::with_constants(&mut constants) - // // .fold_function(new_f) - // // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - // println!("done"); - - // f = new_f; - - // // println!("after prop {}", f); - - // let new_len = f.statements.len(); - - // if new_len == len { - // let new_hash = Some(compute_hash(&f)); - - // if new_hash == hash { - // break Err(Error::NoProgress); - // } else { - // hash = new_hash; - // } - // } else { - // len = new_len; - // } - // } - // } - // } - // } - // }?; - - // Propagator::with_constants(&mut constants) - // .fold_function(f) - // .map_err(|e| Error::Incompatible(format!("{}", e))) - - Reducer::new(program, &mut versions, &mut constants).fold_function(f) -} - -fn compute_hash(f: &TypedFunction) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut s = DefaultHasher::new(); - f.hash(&mut s); - s.finish() + Reducer::new(program).fold_function(f) } #[cfg(test)] diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index 5e744fe2..0d450950 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -28,26 +28,24 @@ use zokrates_ast::typed::folder::*; use zokrates_ast::typed::identifier::FrameIdentifier; -use zokrates_ast::typed::types::ConcreteGenericsAssignment; -use zokrates_ast::typed::types::Type; + use zokrates_ast::typed::*; use zokrates_field::Field; use super::Versions; -pub struct ShallowTransformer<'ast, 'a> { +#[derive(Debug, Default)] +pub struct ShallowTransformer<'ast> { // version index for any variable name - pub versions: &'a mut Versions<'ast>, + pub versions: Versions<'ast>, + pub frames: Vec, + pub latest_frame: usize, } -impl<'ast, 'a> ShallowTransformer<'ast, 'a> { - pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self { - ShallowTransformer { versions } - } - - fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { - let frame_versions = self.versions.map.entry(self.versions.frame).or_default(); +impl<'ast> ShallowTransformer<'ast> { + pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { + let frame_versions = self.versions.map.entry(self.frame()).or_default(); let version = frame_versions .entry(c_id.clone()) @@ -66,42 +64,21 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { } } - pub fn transform( - f: TypedFunction<'ast, T>, - generics: &ConcreteGenericsAssignment<'ast>, - versions: &'a mut Versions<'ast>, - ) -> TypedFunction<'ast, T> { - let mut unroller = ShallowTransformer::with_versions(versions); - - unroller.fold_function(f, generics) + fn frame(&self) -> usize { + *self.frames.last().unwrap_or(&0) } - fn fold_function( - &mut self, - f: TypedFunction<'ast, T>, - generics: &ConcreteGenericsAssignment<'ast>, - ) -> TypedFunction<'ast, T> { - let mut f = f; + pub fn push_call_frame(&mut self) { + self.latest_frame += 1; + self.frames.push(self.latest_frame); + self.versions + .map + .insert(self.latest_frame, Default::default()); + } - f.statements = generics - .0 - .clone() - .into_iter() - .map(|(g, v)| { - TypedStatement::definition( - Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false) - .into(), - UExpression::from(v as u32).into(), - ) - }) - .chain(f.statements) - .collect(); - - for arg in &f.arguments { - let _ = self.issue_next_identifier(arg.id.id.id.id.clone()); - } - - fold_function(self, f) + pub fn pop_call_frame(&mut self) { + let frame = self.frames.pop().unwrap(); + self.versions.map.remove(&frame); } pub fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { @@ -115,7 +92,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { } } -impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { +impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -154,14 +131,14 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { let version = self .versions .map - .get(&self.versions.frame) + .get(&self.frame()) .unwrap() .get(&n.id.id) .cloned() .unwrap_or(0); let id = FrameIdentifier { - frame: self.versions.frame, + frame: self.frame(), ..n.id }; diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 70200ecd..1180874f 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -135,7 +135,7 @@ pub trait Folder<'ast, T: Field>: Sized { id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)), frame: 0, }, - id => n.id, + _id => n.id, }; Identifier { id, ..n } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index 61f49eab..60d3792f 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -240,7 +240,7 @@ impl<'ast, T> From for UExpression<'ast, T> { impl<'ast, T: Field> From> for UExpression<'ast, T> { fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { - DeclarationConstant::Generic(g) => { + DeclarationConstant::Generic(_g) => { // UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32) unreachable!() }