1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
schaeff 2023-02-20 23:17:02 +01:00
parent 01450c741f
commit 1bb524f6a2
8 changed files with 199 additions and 470 deletions

View file

@ -44,25 +44,16 @@ impl fmt::Display for Error {
}
}
#[derive(Debug)]
pub struct Propagator<'ast, 'a, T> {
#[derive(Debug, Default)]
pub struct Propagator<'ast, T> {
// constants keeps track of constant expressions
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
constants: &'a mut Constants<'ast, T>,
constants: Constants<'ast, T>,
}
impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self {
Propagator { constants }
}
impl<'ast, T: Field> Propagator<'ast, T> {
pub fn propagate(p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
let mut constants = Constants::new();
Propagator {
constants: &mut constants,
}
.fold_program(p)
Propagator::default().fold_program(p)
}
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
type Error = Error;
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {

View file

@ -65,7 +65,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame,
frame: _,
},
version,
},
@ -94,7 +94,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame,
frame: _,
},
version,
},
@ -124,7 +124,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame,
frame: _,
},
version,
},
@ -152,7 +152,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame,
frame: _,
},
version,
},
@ -182,7 +182,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame,
frame: _,
},
version,
},
@ -212,7 +212,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
id:
FrameIdentifier {
id: CoreIdentifier::Constant(c),
frame,
frame: _,
},
version,
},

View file

@ -5,9 +5,9 @@ use crate::reducer::{
};
use std::collections::{BTreeMap, HashSet};
use zokrates_ast::typed::{
result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed,
TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId,
TypedProgram, TypedSymbolDeclaration, UExpression,
result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol,
TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration,
UExpression,
};
use zokrates_field::Field;

View file

@ -26,13 +26,10 @@
// - The body of the function is in SSA form
// - The return value(s) are assigned to internal variables
use crate::reducer::ShallowTransformer;
use crate::reducer::Versions;
use zokrates_ast::common::FlatEmbed;
use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType};
use zokrates_ast::typed::CoreIdentifier;
use zokrates_ast::typed::Identifier;
use zokrates_ast::typed::TypedAssignee;
use zokrates_ast::typed::UBitwidth;
use zokrates_ast::typed::{

View file

@ -18,30 +18,26 @@ mod shallow_ssa;
use self::inline::{inline_call, InlineError};
use std::collections::HashMap;
use zokrates_ast::typed::identifier::FrameIdentifier;
use zokrates_ast::typed::result_folder::*;
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
use zokrates_ast::typed::types::GGenericsAssignment;
use zokrates_ast::typed::DeclarationParameter;
use zokrates_ast::typed::Folder;
use zokrates_ast::typed::Typed;
use zokrates_ast::typed::TypedAssemblyStatement;
use zokrates_ast::typed::TypedAssignee;
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
use zokrates_ast::typed::{
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId,
TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration,
TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression,
TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram,
TypedStatement, UExpression, UExpressionInner,
};
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
use zokrates_ast::zir::result_folder::fold_assembly_statement;
use zokrates_field::Field;
use self::constants_writer::ConstantsWriter;
use self::shallow_ssa::ShallowTransformer;
use crate::propagation::{Constants, Propagator};
use crate::propagation;
use crate::propagation::Propagator;
use std::fmt;
@ -55,7 +51,6 @@ pub type ConstantDefinitions<'ast, T> =
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Versions<'ast> {
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
frame: usize,
}
impl<'ast> Versions<'ast> {
@ -69,15 +64,15 @@ impl<'ast> Versions<'ast> {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
Incompatible(String),
GenericsInMain,
// TODO: give more details about what's blocking the progress
NoProgress,
LoopTooLarge(u128),
ConstantReduction(String, OwnedTypedModuleId),
NonConstant(String),
Type(String),
Propagation(propagation::Error),
}
impl fmt::Display for Error {
@ -89,39 +84,36 @@ impl fmt::Display for Error {
s
),
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()),
Error::Type(message) => write!(f, "{}", message),
Error::NonConstant(s) => write!(f, "{}", s),
Error::Type(s) => write!(f, "{}", s),
Error::Propagation(e) => write!(f, "{}", e),
}
}
}
impl From<propagation::Error> for Error {
fn from(e: propagation::Error) -> Self {
Self::Propagation(e)
}
}
#[derive(Debug)]
struct Reducer<'ast, 'a, T> {
propagator: Propagator<'ast, 'a, T>,
statement_buffer: Vec<TypedStatement<'ast, T>>,
latest_frame: usize,
versions: &'a mut Versions<'ast>,
program: &'a TypedProgram<'ast, T>,
propagator: Propagator<'ast, T>,
ssa: ShallowTransformer<'ast>,
statement_buffer: Vec<TypedStatement<'ast, T>>,
}
impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
fn new(
program: &'a TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
constants: &'a mut Constants<'ast, T>,
) -> Self {
// println!("create reducer with");
// println!("{} versions", versions.len());
// println!("{} constants", constants.len());
fn new(program: &'a TypedProgram<'ast, T>) -> Self {
Reducer {
propagator: Propagator::with_constants(constants),
propagator: Propagator::default(),
ssa: ShallowTransformer::default(),
statement_buffer: vec![],
latest_frame: 0,
program,
versions,
}
}
}
@ -133,8 +125,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
&mut self,
p: DeclarationParameter<'ast, T>,
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
// this is only used on the entry point
let id = p.id.id.id.id.clone();
assert!(self.versions.insert_in_frame(id, 0, 0).is_none());
assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none());
Ok(p)
}
@ -150,9 +143,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
.into_iter()
.map(|g| {
g.map(|g| {
let g =
ShallowTransformer::with_versions(self.versions).fold_uint_expression(g);
let g = self.propagator.fold_uint_expression(g).unwrap();
let g = self.ssa.fold_uint_expression(g);
let g = self.propagator.fold_uint_expression(g)?;
self.fold_uint_expression(g)
})
@ -164,43 +156,30 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
.arguments
.into_iter()
.map(|e| {
let e = ShallowTransformer::with_versions(self.versions).fold_expression(e);
let e = self.propagator.fold_expression(e).unwrap();
let e = self.ssa.fold_expression(e);
let e = self.propagator.fold_expression(e)?;
self.fold_expression(e)
})
.collect::<Result<_, _>>()?;
// back up the current frame
let frame_backup = self.versions.frame;
self.ssa.push_call_frame();
// create a new frame
self.latest_frame += 1;
// point the versions to this frame
self.versions.frame = self.latest_frame;
self.versions
.map
.insert(self.versions.frame, Default::default());
// println!("GENERICS {:?}", generics);
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program);
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program);
let res = match res {
Ok((input_variables, arguments, generics_bindings, statements, expression)) => {
let generics_bindings: Vec<_> = generics_bindings
let generics_bindings = generics_bindings
.into_iter()
.flat_map(|s| {
ShallowTransformer::with_versions(self.versions).fold_statement(s)
})
.flat_map(|s| self.propagator.fold_statement(s).unwrap())
.collect::<Vec<_>>()
.flat_map(|s| self.ssa.fold_statement(s))
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|s| self.fold_statement(s).unwrap())
.collect();
// println!("{:#?}", propagator);
.flatten()
.map(|s| self.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
self.statement_buffer.extend(generics_bindings);
@ -208,43 +187,32 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let input_bindings: Vec<_> = input_variables
.into_iter()
.zip(arguments)
.map(|(v, a)| {
TypedStatement::definition(
ShallowTransformer::with_versions(self.versions)
.fold_assignee(v.into()),
a,
)
})
.map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a))
.collect();
let input_bindings: Vec<_> = input_bindings
let input_bindings = input_bindings
.into_iter()
.flat_map(|s| self.propagator.fold_statement(s).unwrap())
.collect();
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
self.statement_buffer.extend(input_bindings);
let statements: Vec<_> = statements
let statements = statements
.into_iter()
.flat_map(|s| self.fold_statement(s).unwrap())
.collect();
// println!("FRAME READY TO SSA {}", self.versions.frame);
let mut transformer = ShallowTransformer::with_versions(self.versions);
let propagator = &mut self.propagator;
.map(|s| self.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten();
self.statement_buffer.extend(statements);
// println!("call result {}", expression);
let expression = self.ssa.fold_expression(expression);
let expression = transformer.fold_expression(expression);
let expression = self.propagator.fold_expression(expression)?;
let expression = propagator.fold_expression(expression).unwrap();
let expression = self.fold_expression(expression).unwrap();
// println!("call result reduced {}", expression);
let expression = self.fold_expression(expression)?;
Ok(FunctionCallOrExpression::Expression(
E::from(expression).into_inner(),
@ -254,20 +222,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
"Call site `{}` incompatible with declaration `{}`",
conc, decl
))),
Err(InlineError::NonConstant(key, generics, arguments, _)) => Err(Error::NoProgress),
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
Err(Error::NonConstant(format!(
"Generic parameters must be compile-time constants, found {}",
FunctionCallExpression::<_, E>::new(key, generics, arguments)
)))
}
Err(InlineError::Flat(embed, generics, arguments, output_type)) => {
let identifier =
Identifier::from(CoreIdentifier::Call(0).in_frame(self.versions.frame))
.version(
*self
.versions
.map
.entry(self.versions.frame)
.or_default()
.entry(CoreIdentifier::Call(0))
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
);
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0));
let var = Variable::immutable(identifier.clone(), output_type);
@ -284,12 +246,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
};
// clean versions
self.versions.map.remove(&self.versions.frame);
// restore the original frame
// println!("RESTORING BACKUP {}", frame_backup);
self.versions.frame = frame_backup;
self.ssa.pop_call_frame();
res
}
@ -298,24 +255,44 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
&mut self,
b: BlockExpression<'ast, T, E>,
) -> Result<BlockExpression<'ast, T, E>, Self::Error> {
// // backup the statements and continue with a fresh state
// let statement_buffer = std::mem::take(&mut self.statement_buffer);
// backup the statements and continue with a fresh state
let statement_buffer = std::mem::take(&mut self.statement_buffer);
// let block = fold_block_expression(self, b)?;
let block = fold_block_expression(self, b)?;
// // put the original statements back and extract the statements created by visiting the block
// let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer);
// put the original statements back and extract the statements created by visiting the block
let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer);
// // return the visited block, augmented with the statements created while visiting it
// Ok(BlockExpression {
// statements: block
// .statements
// .into_iter()
// .chain(extra_statements)
// .collect(),
// ..block
// })
todo!()
// return the visited block, augmented with the statements created while visiting it
Ok(BlockExpression {
statements: block
.statements
.into_iter()
.chain(extra_statements)
.collect(),
..block
})
}
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
Ok(match s {
TypedAssemblyStatement::Assignment(a, e) => {
vec![TypedAssemblyStatement::Assignment(
self.fold_assignee(a)?,
self.fold_expression(e)?,
)]
}
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
vec![TypedAssemblyStatement::Constraint(
self.fold_field_expression(lhs)?,
self.fold_field_expression(rhs)?,
metadata,
)]
}
})
}
fn fold_canonical_constant_identifier(
@ -336,163 +313,77 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
// then we reduce the rhs to remove the function calls
// only then we transform and propagate the assignee
let mut transformer = ShallowTransformer::with_versions(self.versions);
// println!("rhs {}", rhs);
let rhs = transformer.fold_definition_rhs(rhs);
// println!("ssa-ed {}", rhs);
let rhs = self.propagator.fold_definition_rhs(rhs).unwrap();
// println!("propagated {}", rhs);
let rhs = self.fold_definition_rhs(rhs).unwrap();
// println!("reduced {}", rhs);
let rhs = self.ssa.fold_definition_rhs(rhs);
let rhs = self.propagator.fold_definition_rhs(rhs)?;
let rhs = self.fold_definition_rhs(rhs)?;
// println!("ASSIGNEE {}", a);
// println!("{:?}", self.versions);
let a = ShallowTransformer::with_versions(self.versions).fold_assignee(a);
// println!("{:?}", self.versions);
// println!("definition before propagation {}", TypedStatement::Definition(a.clone(), rhs.clone()));
let a = self.ssa.fold_assignee(a);
self.propagator
.fold_statement(TypedStatement::Definition(a, rhs))
.unwrap()
// println!("final definition size: {}", s.len());
.fold_statement(TypedStatement::Definition(a, rhs))?
}
TypedStatement::For(v, from, to, statements) => {
let from =
ShallowTransformer::with_versions(self.versions).fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from).unwrap();
let from = self.fold_uint_expression(from).unwrap();
let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to).unwrap();
let to = self.fold_uint_expression(to).unwrap();
let from = self.ssa.fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from)?;
let from = self.fold_uint_expression(from)?;
let to = self.ssa.fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to)?;
let to = self.fold_uint_expression(to)?;
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from..*to)
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from
..*to)
.flat_map(|index| {
std::iter::once(TypedStatement::definition(
v.clone().into(),
UExpression::from(index as u32).into(),
))
.chain(statements.clone())
.flat_map(|s| self.fold_statement(s).unwrap())
.map(|s| self.fold_statement(s))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>(),
_ => unimplemented!(),
}
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect()),
_ => Err(Error::NonConstant(format!(
"Expected loop bounds to be constant, found {}..{}",
from, to
))),
}?
}
TypedStatement::Assembly(_) => todo!(),
TypedStatement::Return(e) => {
let mut transformer = ShallowTransformer::with_versions(self.versions);
let e = transformer.fold_expression(e);
let e = self.propagator.fold_expression(e).unwrap();
vec![TypedStatement::Return(self.fold_expression(e).unwrap())]
let e = self.ssa.fold_expression(e);
let e = self.propagator.fold_expression(e)?;
vec![TypedStatement::Return(self.fold_expression(e)?)]
}
TypedStatement::Assertion(e, error) => {
let mut transformer = ShallowTransformer::with_versions(self.versions);
let e = transformer.fold_boolean_expression(e);
let e = self.propagator.fold_boolean_expression(e).unwrap();
let e = self.ssa.fold_boolean_expression(e);
let e = self.propagator.fold_boolean_expression(e)?;
vec![TypedStatement::Assertion(
self.fold_boolean_expression(e).unwrap(),
self.fold_boolean_expression(e)?,
error,
)]
}
s => {
let mut transformer = ShallowTransformer::with_versions(self.versions);
let propagator = &mut self.propagator;
transformer
.fold_statement(s)
.into_iter()
// .inspect(|s| println!("ssa {}\n", s))
.flat_map(|s| propagator.fold_statement(s).unwrap())
.collect::<Vec<_>>()
.into_iter()
.flat_map(|s| fold_statement(self, s).unwrap())
// .inspect(|s| println!("propagated {}\n", s))
.collect()
}
s => self
.ssa
.fold_statement(s)
.into_iter()
.map(|s| self.propagator.fold_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.map(|s| fold_statement(self, s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
};
Ok(self
.statement_buffer
.drain(..)
.chain(res)
.collect::<Vec<_>>())
Ok(self.statement_buffer.drain(..).chain(res).collect())
}
// fn fold_statement(
// &mut self,
// s: TypedStatement<'ast, T>,
// ) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
// let mut transformer = ShallowTransformer::with_versions(self.versions);
// let propagator = &mut self.propagator;
// println!("FOLD_STATEMENT: {}", s);
// let s: Vec<_> = transformer
// .fold_statement(s)
// .into_iter()
// // .inspect(|s| println!("ssa {}\n", s))
// .flat_map(|s| propagator.fold_statement(s).unwrap())
// // .inspect(|s| println!("propagated {}\n", s))
// .collect();
// for s in &s {
// println!("OUTER: {}", s);
// }
// let res: Vec<_> = s
// .into_iter()
// .flat_map(|s| match s {
// TypedStatement::For(v, from, to, statements) => {
// match (from.as_inner(), to.as_inner()) {
// (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from
// ..*to)
// .flat_map(|index| {
// std::iter::once(TypedStatement::definition(
// v.clone().into(),
// UExpression::from(index as u32).into(),
// ))
// .chain(statements.clone())
// .flat_map(|s| self.fold_statement(s).unwrap())
// .collect::<Vec<_>>()
// })
// .collect(),
// _ => unimplemented!(),
// }
// }
// s => {
// println!("UNROLL/INLINE STATEMENT {}", s);
// let s = fold_statement(self, s).unwrap();
// for s in &self.statement_buffer {
// println!("BUFFER {}", s);
// }
// for s in &s {
// println!("RESULT {}", s);
// }
// self.statement_buffer.drain(..).chain(s).collect::<Vec<_>>()
// }
// })
// .collect();
// for s in &res {
// // println!("DONE: {}", s);
// }
// Ok(res)
// }
fn fold_array_expression_inner(
&mut self,
array_ty: &ArrayType<'ast, T>,
@ -500,25 +391,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array =
ShallowTransformer::with_versions(self.versions).fold_array_expression(array);
let array = self.propagator.fold_array_expression(array).unwrap();
let array = self.fold_array_expression(array).unwrap();
let from =
ShallowTransformer::with_versions(self.versions).fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from).unwrap();
let from = self.fold_uint_expression(from).unwrap();
let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to).unwrap();
let to = self.fold_uint_expression(to).unwrap();
let array = self.ssa.fold_array_expression(array);
let array = self.propagator.fold_array_expression(array)?;
let array = self.fold_array_expression(array)?;
let from = self.ssa.fold_uint_expression(from);
let from = self.propagator.fold_uint_expression(from)?;
let from = self.fold_uint_expression(from)?;
let to = self.ssa.fold_uint_expression(to);
let to = self.propagator.fold_uint_expression(to)?;
let to = self.fold_uint_expression(to)?;
match (from.as_inner(), to.as_inner()) {
(UExpressionInner::Value(..), UExpressionInner::Value(..)) => {
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
}
_ => {
todo!("non constant slice bounds")
}
_ => Err(Error::NonConstant(format!(
"Slice bounds must be compile time constants, found {}",
ArrayExpressionInner::Slice(box array, box from, box to)
))),
}
}
_ => fold_array_expression_inner(self, array_ty, e),
@ -548,7 +438,7 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
match main_function.signature.generics.len() {
0 => {
let main_function = reduce_function(main_function, &p)?;
let main_function = Reducer::new(&p).fold_function(main_function)?;
Ok(TypedProgram {
main: p.main.clone(),
@ -574,135 +464,9 @@ fn reduce_function<'ast, T: Field>(
f: TypedFunction<'ast, T>,
program: &TypedProgram<'ast, T>,
) -> Result<TypedFunction<'ast, T>, Error> {
let mut versions = Versions::default();
let mut constants = Constants::default();
assert!(f.signature.generics.is_empty());
// let f = match ShallowTransformer::transform(f, &generics, &mut versions) {
// Output::Complete(f) => Ok(f),
// Output::Incomplete(new_f, new_for_loop_versions) => {
// let mut for_loop_versions = new_for_loop_versions;
// let mut f = Propagator::with_constants(&mut constants)
// .fold_function(new_f)
// .map_err(|e| Error::Incompatible(format!("{}", e)))?;
// let mut substitutions = Substitutions::default();
// let mut hash = None;
// let mut len = f.statements.len();
// // println!("{}", f);
// loop {
// let mut reducer = Reducer::new(
// program,
// &mut versions,
// &mut substitutions,
// for_loop_versions,
// &mut constants,
// );
// println!("reduce");
// let new_f = TypedFunction {
// statements: f
// .statements
// .into_iter()
// .map(|s| reducer.fold_statement(s))
// .collect::<Result<Vec<_>, _>>()?
// .into_iter()
// .flatten()
// .collect(),
// ..f
// };
// println!("done");
// // println!("after reduction {}", new_f);
// println!(
// "count {}, unrolled {} loops",
// new_f.statements.len(),
// reducer.credits
// );
// assert!(reducer.for_loop_versions.is_empty());
// match reducer.complete {
// true => {
// substitutions = substitutions.canonicalize();
// let new_f = Sub::new(&substitutions).fold_function(new_f);
// // println!("after last sub {}", new_f);
// // let new_f = Propagator::with_constants(&mut constants)
// // .fold_function(new_f)
// // .map_err(|e| Error::Incompatible(format!("{}", e)))?;
// // println!("after last prop {}", new_f);
// break Ok(new_f);
// }
// false => {
// for_loop_versions = reducer.for_loop_versions_after;
// println!("canonicalize");
// // substitutions = substitutions.canonicalize();
// // let new_f = Sub::new(&substitutions).fold_function(new_f);
// println!("done");
// // println!("after sub {}", new_f);
// println!("propagate");
// // f = Propagator::with_constants(&mut constants)
// // .fold_function(new_f)
// // .map_err(|e| Error::Incompatible(format!("{}", e)))?;
// println!("done");
// f = new_f;
// // println!("after prop {}", f);
// let new_len = f.statements.len();
// if new_len == len {
// let new_hash = Some(compute_hash(&f));
// if new_hash == hash {
// break Err(Error::NoProgress);
// } else {
// hash = new_hash;
// }
// } else {
// len = new_len;
// }
// }
// }
// }
// }
// }?;
// Propagator::with_constants(&mut constants)
// .fold_function(f)
// .map_err(|e| Error::Incompatible(format!("{}", e)))
Reducer::new(program, &mut versions, &mut constants).fold_function(f)
}
fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new();
f.hash(&mut s);
s.finish()
Reducer::new(program).fold_function(f)
}
#[cfg(test)]

View file

@ -28,26 +28,24 @@
use zokrates_ast::typed::folder::*;
use zokrates_ast::typed::identifier::FrameIdentifier;
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
use zokrates_ast::typed::types::Type;
use zokrates_ast::typed::*;
use zokrates_field::Field;
use super::Versions;
pub struct ShallowTransformer<'ast, 'a> {
#[derive(Debug, Default)]
pub struct ShallowTransformer<'ast> {
// version index for any variable name
pub versions: &'a mut Versions<'ast>,
pub versions: Versions<'ast>,
pub frames: Vec<usize>,
pub latest_frame: usize,
}
impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
ShallowTransformer { versions }
}
fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
let frame_versions = self.versions.map.entry(self.versions.frame).or_default();
impl<'ast> ShallowTransformer<'ast> {
pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
let frame_versions = self.versions.map.entry(self.frame()).or_default();
let version = frame_versions
.entry(c_id.clone())
@ -66,42 +64,21 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
}
}
pub fn transform<T: Field>(
f: TypedFunction<'ast, T>,
generics: &ConcreteGenericsAssignment<'ast>,
versions: &'a mut Versions<'ast>,
) -> TypedFunction<'ast, T> {
let mut unroller = ShallowTransformer::with_versions(versions);
unroller.fold_function(f, generics)
fn frame(&self) -> usize {
*self.frames.last().unwrap_or(&0)
}
fn fold_function<T: Field>(
&mut self,
f: TypedFunction<'ast, T>,
generics: &ConcreteGenericsAssignment<'ast>,
) -> TypedFunction<'ast, T> {
let mut f = f;
pub fn push_call_frame(&mut self) {
self.latest_frame += 1;
self.frames.push(self.latest_frame);
self.versions
.map
.insert(self.latest_frame, Default::default());
}
f.statements = generics
.0
.clone()
.into_iter()
.map(|(g, v)| {
TypedStatement::definition(
Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false)
.into(),
UExpression::from(v as u32).into(),
)
})
.chain(f.statements)
.collect();
for arg in &f.arguments {
let _ = self.issue_next_identifier(arg.id.id.id.id.clone());
}
fold_function(self, f)
pub fn pop_call_frame(&mut self) {
let frame = self.frames.pop().unwrap();
self.versions.map.remove(&frame);
}
pub fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
@ -115,7 +92,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
}
}
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
@ -154,14 +131,14 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
let version = self
.versions
.map
.get(&self.versions.frame)
.get(&self.frame())
.unwrap()
.get(&n.id.id)
.cloned()
.unwrap_or(0);
let id = FrameIdentifier {
frame: self.versions.frame,
frame: self.frame(),
..n.id
};

View file

@ -135,7 +135,7 @@ pub trait Folder<'ast, T: Field>: Sized {
id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)),
frame: 0,
},
id => n.id,
_id => n.id,
};
Identifier { id, ..n }

View file

@ -240,7 +240,7 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
fn from(c: DeclarationConstant<'ast, T>) -> Self {
match c {
DeclarationConstant::Generic(g) => {
DeclarationConstant::Generic(_g) => {
// UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32)
unreachable!()
}