clean
This commit is contained in:
parent
01450c741f
commit
1bb524f6a2
8 changed files with 199 additions and 470 deletions
|
@ -44,25 +44,16 @@ impl fmt::Display for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Propagator<'ast, 'a, T> {
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Propagator<'ast, T> {
|
||||
// constants keeps track of constant expressions
|
||||
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
|
||||
constants: &'a mut Constants<'ast, T>,
|
||||
constants: Constants<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
|
||||
pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self {
|
||||
Propagator { constants }
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Propagator<'ast, T> {
|
||||
pub fn propagate(p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
|
||||
let mut constants = Constants::new();
|
||||
|
||||
Propagator {
|
||||
constants: &mut constants,
|
||||
}
|
||||
.fold_program(p)
|
||||
Propagator::default().fold_program(p)
|
||||
}
|
||||
|
||||
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
|
||||
|
@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
|
||||
|
|
|
@ -65,7 +65,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame,
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
|
@ -94,7 +94,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame,
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
|
@ -124,7 +124,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame,
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
|
@ -152,7 +152,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame,
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
|
@ -182,7 +182,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame,
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
|
@ -212,7 +212,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
|||
id:
|
||||
FrameIdentifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
frame,
|
||||
frame: _,
|
||||
},
|
||||
version,
|
||||
},
|
||||
|
|
|
@ -5,9 +5,9 @@ use crate::reducer::{
|
|||
};
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use zokrates_ast::typed::{
|
||||
result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed,
|
||||
TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId,
|
||||
TypedProgram, TypedSymbolDeclaration, UExpression,
|
||||
result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol,
|
||||
TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration,
|
||||
UExpression,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
|
|
@ -26,13 +26,10 @@
|
|||
// - The body of the function is in SSA form
|
||||
// - The return value(s) are assigned to internal variables
|
||||
|
||||
use crate::reducer::ShallowTransformer;
|
||||
use crate::reducer::Versions;
|
||||
|
||||
use zokrates_ast::common::FlatEmbed;
|
||||
use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType};
|
||||
use zokrates_ast::typed::CoreIdentifier;
|
||||
use zokrates_ast::typed::Identifier;
|
||||
|
||||
use zokrates_ast::typed::TypedAssignee;
|
||||
use zokrates_ast::typed::UBitwidth;
|
||||
use zokrates_ast::typed::{
|
||||
|
|
|
@ -18,30 +18,26 @@ mod shallow_ssa;
|
|||
|
||||
use self::inline::{inline_call, InlineError};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_ast::typed::identifier::FrameIdentifier;
|
||||
use zokrates_ast::typed::result_folder::*;
|
||||
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
|
||||
use zokrates_ast::typed::types::GGenericsAssignment;
|
||||
use zokrates_ast::typed::DeclarationParameter;
|
||||
use zokrates_ast::typed::Folder;
|
||||
use zokrates_ast::typed::Typed;
|
||||
use zokrates_ast::typed::TypedAssemblyStatement;
|
||||
use zokrates_ast::typed::TypedAssignee;
|
||||
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
|
||||
|
||||
use zokrates_ast::typed::{
|
||||
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId,
|
||||
TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration,
|
||||
TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression,
|
||||
TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram,
|
||||
TypedStatement, UExpression, UExpressionInner,
|
||||
};
|
||||
use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable};
|
||||
|
||||
use zokrates_ast::zir::result_folder::fold_assembly_statement;
|
||||
use zokrates_field::Field;
|
||||
|
||||
use self::constants_writer::ConstantsWriter;
|
||||
use self::shallow_ssa::ShallowTransformer;
|
||||
|
||||
use crate::propagation::{Constants, Propagator};
|
||||
use crate::propagation;
|
||||
use crate::propagation::Propagator;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
|
@ -55,7 +51,6 @@ pub type ConstantDefinitions<'ast, T> =
|
|||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct Versions<'ast> {
|
||||
map: HashMap<usize, HashMap<CoreIdentifier<'ast>, usize>>,
|
||||
frame: usize,
|
||||
}
|
||||
|
||||
impl<'ast> Versions<'ast> {
|
||||
|
@ -69,15 +64,15 @@ impl<'ast> Versions<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
Incompatible(String),
|
||||
GenericsInMain,
|
||||
// TODO: give more details about what's blocking the progress
|
||||
NoProgress,
|
||||
LoopTooLarge(u128),
|
||||
ConstantReduction(String, OwnedTypedModuleId),
|
||||
NonConstant(String),
|
||||
Type(String),
|
||||
Propagation(propagation::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
|
@ -89,39 +84,36 @@ impl fmt::Display for Error {
|
|||
s
|
||||
),
|
||||
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
|
||||
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
|
||||
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
|
||||
Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()),
|
||||
Error::Type(message) => write!(f, "{}", message),
|
||||
Error::NonConstant(s) => write!(f, "{}", s),
|
||||
Error::Type(s) => write!(f, "{}", s),
|
||||
Error::Propagation(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<propagation::Error> for Error {
|
||||
fn from(e: propagation::Error) -> Self {
|
||||
Self::Propagation(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Reducer<'ast, 'a, T> {
|
||||
propagator: Propagator<'ast, 'a, T>,
|
||||
statement_buffer: Vec<TypedStatement<'ast, T>>,
|
||||
latest_frame: usize,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
program: &'a TypedProgram<'ast, T>,
|
||||
propagator: Propagator<'ast, T>,
|
||||
ssa: ShallowTransformer<'ast>,
|
||||
statement_buffer: Vec<TypedStatement<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
||||
fn new(
|
||||
program: &'a TypedProgram<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
constants: &'a mut Constants<'ast, T>,
|
||||
) -> Self {
|
||||
// println!("create reducer with");
|
||||
// println!("{} versions", versions.len());
|
||||
// println!("{} constants", constants.len());
|
||||
|
||||
fn new(program: &'a TypedProgram<'ast, T>) -> Self {
|
||||
Reducer {
|
||||
propagator: Propagator::with_constants(constants),
|
||||
propagator: Propagator::default(),
|
||||
ssa: ShallowTransformer::default(),
|
||||
statement_buffer: vec![],
|
||||
latest_frame: 0,
|
||||
program,
|
||||
versions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -133,8 +125,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
&mut self,
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
|
||||
// this is only used on the entry point
|
||||
let id = p.id.id.id.id.clone();
|
||||
assert!(self.versions.insert_in_frame(id, 0, 0).is_none());
|
||||
assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none());
|
||||
Ok(p)
|
||||
}
|
||||
|
||||
|
@ -150,9 +143,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
.into_iter()
|
||||
.map(|g| {
|
||||
g.map(|g| {
|
||||
let g =
|
||||
ShallowTransformer::with_versions(self.versions).fold_uint_expression(g);
|
||||
let g = self.propagator.fold_uint_expression(g).unwrap();
|
||||
let g = self.ssa.fold_uint_expression(g);
|
||||
let g = self.propagator.fold_uint_expression(g)?;
|
||||
|
||||
self.fold_uint_expression(g)
|
||||
})
|
||||
|
@ -164,43 +156,30 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
.arguments
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let e = ShallowTransformer::with_versions(self.versions).fold_expression(e);
|
||||
let e = self.propagator.fold_expression(e).unwrap();
|
||||
let e = self.ssa.fold_expression(e);
|
||||
let e = self.propagator.fold_expression(e)?;
|
||||
|
||||
self.fold_expression(e)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
// back up the current frame
|
||||
let frame_backup = self.versions.frame;
|
||||
self.ssa.push_call_frame();
|
||||
|
||||
// create a new frame
|
||||
self.latest_frame += 1;
|
||||
|
||||
// point the versions to this frame
|
||||
self.versions.frame = self.latest_frame;
|
||||
self.versions
|
||||
.map
|
||||
.insert(self.versions.frame, Default::default());
|
||||
|
||||
// println!("GENERICS {:?}", generics);
|
||||
|
||||
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program);
|
||||
let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program);
|
||||
|
||||
let res = match res {
|
||||
Ok((input_variables, arguments, generics_bindings, statements, expression)) => {
|
||||
let generics_bindings: Vec<_> = generics_bindings
|
||||
let generics_bindings = generics_bindings
|
||||
.into_iter()
|
||||
.flat_map(|s| {
|
||||
ShallowTransformer::with_versions(self.versions).fold_statement(s)
|
||||
})
|
||||
.flat_map(|s| self.propagator.fold_statement(s).unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
.flat_map(|s| self.ssa.fold_statement(s))
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flat_map(|s| self.fold_statement(s).unwrap())
|
||||
.collect();
|
||||
|
||||
// println!("{:#?}", propagator);
|
||||
.flatten()
|
||||
.map(|s| self.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
self.statement_buffer.extend(generics_bindings);
|
||||
|
||||
|
@ -208,43 +187,32 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
let input_bindings: Vec<_> = input_variables
|
||||
.into_iter()
|
||||
.zip(arguments)
|
||||
.map(|(v, a)| {
|
||||
TypedStatement::definition(
|
||||
ShallowTransformer::with_versions(self.versions)
|
||||
.fold_assignee(v.into()),
|
||||
a,
|
||||
)
|
||||
})
|
||||
.map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let input_bindings: Vec<_> = input_bindings
|
||||
let input_bindings = input_bindings
|
||||
.into_iter()
|
||||
.flat_map(|s| self.propagator.fold_statement(s).unwrap())
|
||||
.collect();
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
self.statement_buffer.extend(input_bindings);
|
||||
|
||||
let statements: Vec<_> = statements
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.flat_map(|s| self.fold_statement(s).unwrap())
|
||||
.collect();
|
||||
|
||||
// println!("FRAME READY TO SSA {}", self.versions.frame);
|
||||
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
let propagator = &mut self.propagator;
|
||||
.map(|s| self.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
self.statement_buffer.extend(statements);
|
||||
|
||||
// println!("call result {}", expression);
|
||||
let expression = self.ssa.fold_expression(expression);
|
||||
|
||||
let expression = transformer.fold_expression(expression);
|
||||
let expression = self.propagator.fold_expression(expression)?;
|
||||
|
||||
let expression = propagator.fold_expression(expression).unwrap();
|
||||
|
||||
let expression = self.fold_expression(expression).unwrap();
|
||||
|
||||
// println!("call result reduced {}", expression);
|
||||
let expression = self.fold_expression(expression)?;
|
||||
|
||||
Ok(FunctionCallOrExpression::Expression(
|
||||
E::from(expression).into_inner(),
|
||||
|
@ -254,20 +222,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
"Call site `{}` incompatible with declaration `{}`",
|
||||
conc, decl
|
||||
))),
|
||||
Err(InlineError::NonConstant(key, generics, arguments, _)) => Err(Error::NoProgress),
|
||||
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
|
||||
Err(Error::NonConstant(format!(
|
||||
"Generic parameters must be compile-time constants, found {}",
|
||||
FunctionCallExpression::<_, E>::new(key, generics, arguments)
|
||||
)))
|
||||
}
|
||||
Err(InlineError::Flat(embed, generics, arguments, output_type)) => {
|
||||
let identifier =
|
||||
Identifier::from(CoreIdentifier::Call(0).in_frame(self.versions.frame))
|
||||
.version(
|
||||
*self
|
||||
.versions
|
||||
.map
|
||||
.entry(self.versions.frame)
|
||||
.or_default()
|
||||
.entry(CoreIdentifier::Call(0))
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
);
|
||||
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0));
|
||||
|
||||
let var = Variable::immutable(identifier.clone(), output_type);
|
||||
|
||||
|
@ -284,12 +246,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
}
|
||||
};
|
||||
|
||||
// clean versions
|
||||
self.versions.map.remove(&self.versions.frame);
|
||||
|
||||
// restore the original frame
|
||||
// println!("RESTORING BACKUP {}", frame_backup);
|
||||
self.versions.frame = frame_backup;
|
||||
self.ssa.pop_call_frame();
|
||||
|
||||
res
|
||||
}
|
||||
|
@ -298,24 +255,44 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
&mut self,
|
||||
b: BlockExpression<'ast, T, E>,
|
||||
) -> Result<BlockExpression<'ast, T, E>, Self::Error> {
|
||||
// // backup the statements and continue with a fresh state
|
||||
// let statement_buffer = std::mem::take(&mut self.statement_buffer);
|
||||
// backup the statements and continue with a fresh state
|
||||
let statement_buffer = std::mem::take(&mut self.statement_buffer);
|
||||
|
||||
// let block = fold_block_expression(self, b)?;
|
||||
let block = fold_block_expression(self, b)?;
|
||||
|
||||
// // put the original statements back and extract the statements created by visiting the block
|
||||
// let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer);
|
||||
// put the original statements back and extract the statements created by visiting the block
|
||||
let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer);
|
||||
|
||||
// // return the visited block, augmented with the statements created while visiting it
|
||||
// Ok(BlockExpression {
|
||||
// statements: block
|
||||
// .statements
|
||||
// .into_iter()
|
||||
// .chain(extra_statements)
|
||||
// .collect(),
|
||||
// ..block
|
||||
// })
|
||||
todo!()
|
||||
// return the visited block, augmented with the statements created while visiting it
|
||||
Ok(BlockExpression {
|
||||
statements: block
|
||||
.statements
|
||||
.into_iter()
|
||||
.chain(extra_statements)
|
||||
.collect(),
|
||||
..block
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
|
||||
Ok(match s {
|
||||
TypedAssemblyStatement::Assignment(a, e) => {
|
||||
vec![TypedAssemblyStatement::Assignment(
|
||||
self.fold_assignee(a)?,
|
||||
self.fold_expression(e)?,
|
||||
)]
|
||||
}
|
||||
TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => {
|
||||
vec![TypedAssemblyStatement::Constraint(
|
||||
self.fold_field_expression(lhs)?,
|
||||
self.fold_field_expression(rhs)?,
|
||||
metadata,
|
||||
)]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_canonical_constant_identifier(
|
||||
|
@ -336,163 +313,77 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
// then we reduce the rhs to remove the function calls
|
||||
// only then we transform and propagate the assignee
|
||||
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
// println!("rhs {}", rhs);
|
||||
let rhs = transformer.fold_definition_rhs(rhs);
|
||||
// println!("ssa-ed {}", rhs);
|
||||
let rhs = self.propagator.fold_definition_rhs(rhs).unwrap();
|
||||
// println!("propagated {}", rhs);
|
||||
let rhs = self.fold_definition_rhs(rhs).unwrap();
|
||||
// println!("reduced {}", rhs);
|
||||
let rhs = self.ssa.fold_definition_rhs(rhs);
|
||||
let rhs = self.propagator.fold_definition_rhs(rhs)?;
|
||||
let rhs = self.fold_definition_rhs(rhs)?;
|
||||
|
||||
// println!("ASSIGNEE {}", a);
|
||||
|
||||
// println!("{:?}", self.versions);
|
||||
|
||||
let a = ShallowTransformer::with_versions(self.versions).fold_assignee(a);
|
||||
|
||||
// println!("{:?}", self.versions);
|
||||
|
||||
// println!("definition before propagation {}", TypedStatement::Definition(a.clone(), rhs.clone()));
|
||||
let a = self.ssa.fold_assignee(a);
|
||||
|
||||
self.propagator
|
||||
.fold_statement(TypedStatement::Definition(a, rhs))
|
||||
.unwrap()
|
||||
|
||||
// println!("final definition size: {}", s.len());
|
||||
.fold_statement(TypedStatement::Definition(a, rhs))?
|
||||
}
|
||||
TypedStatement::For(v, from, to, statements) => {
|
||||
let from =
|
||||
ShallowTransformer::with_versions(self.versions).fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from).unwrap();
|
||||
let from = self.fold_uint_expression(from).unwrap();
|
||||
let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to).unwrap();
|
||||
let to = self.fold_uint_expression(to).unwrap();
|
||||
let from = self.ssa.fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let to = self.ssa.fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from..*to)
|
||||
(UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from
|
||||
..*to)
|
||||
.flat_map(|index| {
|
||||
std::iter::once(TypedStatement::definition(
|
||||
v.clone().into(),
|
||||
UExpression::from(index as u32).into(),
|
||||
))
|
||||
.chain(statements.clone())
|
||||
.flat_map(|s| self.fold_statement(s).unwrap())
|
||||
.map(|s| self.fold_statement(s))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect()),
|
||||
_ => Err(Error::NonConstant(format!(
|
||||
"Expected loop bounds to be constant, found {}..{}",
|
||||
from, to
|
||||
))),
|
||||
}?
|
||||
}
|
||||
TypedStatement::Assembly(_) => todo!(),
|
||||
TypedStatement::Return(e) => {
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
|
||||
let e = transformer.fold_expression(e);
|
||||
let e = self.propagator.fold_expression(e).unwrap();
|
||||
vec![TypedStatement::Return(self.fold_expression(e).unwrap())]
|
||||
let e = self.ssa.fold_expression(e);
|
||||
let e = self.propagator.fold_expression(e)?;
|
||||
vec![TypedStatement::Return(self.fold_expression(e)?)]
|
||||
}
|
||||
TypedStatement::Assertion(e, error) => {
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
|
||||
let e = transformer.fold_boolean_expression(e);
|
||||
let e = self.propagator.fold_boolean_expression(e).unwrap();
|
||||
let e = self.ssa.fold_boolean_expression(e);
|
||||
let e = self.propagator.fold_boolean_expression(e)?;
|
||||
|
||||
vec![TypedStatement::Assertion(
|
||||
self.fold_boolean_expression(e).unwrap(),
|
||||
self.fold_boolean_expression(e)?,
|
||||
error,
|
||||
)]
|
||||
}
|
||||
s => {
|
||||
let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
let propagator = &mut self.propagator;
|
||||
transformer
|
||||
.fold_statement(s)
|
||||
.into_iter()
|
||||
// .inspect(|s| println!("ssa {}\n", s))
|
||||
.flat_map(|s| propagator.fold_statement(s).unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.flat_map(|s| fold_statement(self, s).unwrap())
|
||||
// .inspect(|s| println!("propagated {}\n", s))
|
||||
.collect()
|
||||
}
|
||||
s => self
|
||||
.ssa
|
||||
.fold_statement(s)
|
||||
.into_iter()
|
||||
.map(|s| self.propagator.fold_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|s| fold_statement(self, s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
Ok(self
|
||||
.statement_buffer
|
||||
.drain(..)
|
||||
.chain(res)
|
||||
.collect::<Vec<_>>())
|
||||
Ok(self.statement_buffer.drain(..).chain(res).collect())
|
||||
}
|
||||
|
||||
// fn fold_statement(
|
||||
// &mut self,
|
||||
// s: TypedStatement<'ast, T>,
|
||||
// ) -> Result<Vec<TypedStatement<'ast, T>>, Self::Error> {
|
||||
// let mut transformer = ShallowTransformer::with_versions(self.versions);
|
||||
// let propagator = &mut self.propagator;
|
||||
|
||||
// println!("FOLD_STATEMENT: {}", s);
|
||||
|
||||
// let s: Vec<_> = transformer
|
||||
// .fold_statement(s)
|
||||
// .into_iter()
|
||||
// // .inspect(|s| println!("ssa {}\n", s))
|
||||
// .flat_map(|s| propagator.fold_statement(s).unwrap())
|
||||
// // .inspect(|s| println!("propagated {}\n", s))
|
||||
// .collect();
|
||||
|
||||
// for s in &s {
|
||||
// println!("OUTER: {}", s);
|
||||
// }
|
||||
|
||||
// let res: Vec<_> = s
|
||||
// .into_iter()
|
||||
// .flat_map(|s| match s {
|
||||
// TypedStatement::For(v, from, to, statements) => {
|
||||
// match (from.as_inner(), to.as_inner()) {
|
||||
// (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from
|
||||
// ..*to)
|
||||
// .flat_map(|index| {
|
||||
// std::iter::once(TypedStatement::definition(
|
||||
// v.clone().into(),
|
||||
// UExpression::from(index as u32).into(),
|
||||
// ))
|
||||
// .chain(statements.clone())
|
||||
// .flat_map(|s| self.fold_statement(s).unwrap())
|
||||
// .collect::<Vec<_>>()
|
||||
// })
|
||||
// .collect(),
|
||||
// _ => unimplemented!(),
|
||||
// }
|
||||
// }
|
||||
// s => {
|
||||
// println!("UNROLL/INLINE STATEMENT {}", s);
|
||||
|
||||
// let s = fold_statement(self, s).unwrap();
|
||||
|
||||
// for s in &self.statement_buffer {
|
||||
// println!("BUFFER {}", s);
|
||||
// }
|
||||
|
||||
// for s in &s {
|
||||
// println!("RESULT {}", s);
|
||||
// }
|
||||
|
||||
// self.statement_buffer.drain(..).chain(s).collect::<Vec<_>>()
|
||||
// }
|
||||
// })
|
||||
// .collect();
|
||||
|
||||
// for s in &res {
|
||||
// // println!("DONE: {}", s);
|
||||
// }
|
||||
|
||||
// Ok(res)
|
||||
// }
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
array_ty: &ArrayType<'ast, T>,
|
||||
|
@ -500,25 +391,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
ArrayExpressionInner::Slice(box array, box from, box to) => {
|
||||
let array =
|
||||
ShallowTransformer::with_versions(self.versions).fold_array_expression(array);
|
||||
let array = self.propagator.fold_array_expression(array).unwrap();
|
||||
let array = self.fold_array_expression(array).unwrap();
|
||||
let from =
|
||||
ShallowTransformer::with_versions(self.versions).fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from).unwrap();
|
||||
let from = self.fold_uint_expression(from).unwrap();
|
||||
let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to).unwrap();
|
||||
let to = self.fold_uint_expression(to).unwrap();
|
||||
let array = self.ssa.fold_array_expression(array);
|
||||
let array = self.propagator.fold_array_expression(array)?;
|
||||
let array = self.fold_array_expression(array)?;
|
||||
let from = self.ssa.fold_uint_expression(from);
|
||||
let from = self.propagator.fold_uint_expression(from)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let to = self.ssa.fold_uint_expression(to);
|
||||
let to = self.propagator.fold_uint_expression(to)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
|
||||
match (from.as_inner(), to.as_inner()) {
|
||||
(UExpressionInner::Value(..), UExpressionInner::Value(..)) => {
|
||||
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
|
||||
}
|
||||
_ => {
|
||||
todo!("non constant slice bounds")
|
||||
}
|
||||
_ => Err(Error::NonConstant(format!(
|
||||
"Slice bounds must be compile time constants, found {}",
|
||||
ArrayExpressionInner::Slice(box array, box from, box to)
|
||||
))),
|
||||
}
|
||||
}
|
||||
_ => fold_array_expression_inner(self, array_ty, e),
|
||||
|
@ -548,7 +438,7 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
|
|||
|
||||
match main_function.signature.generics.len() {
|
||||
0 => {
|
||||
let main_function = reduce_function(main_function, &p)?;
|
||||
let main_function = Reducer::new(&p).fold_function(main_function)?;
|
||||
|
||||
Ok(TypedProgram {
|
||||
main: p.main.clone(),
|
||||
|
@ -574,135 +464,9 @@ fn reduce_function<'ast, T: Field>(
|
|||
f: TypedFunction<'ast, T>,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
) -> Result<TypedFunction<'ast, T>, Error> {
|
||||
let mut versions = Versions::default();
|
||||
let mut constants = Constants::default();
|
||||
|
||||
assert!(f.signature.generics.is_empty());
|
||||
|
||||
// let f = match ShallowTransformer::transform(f, &generics, &mut versions) {
|
||||
// Output::Complete(f) => Ok(f),
|
||||
// Output::Incomplete(new_f, new_for_loop_versions) => {
|
||||
// let mut for_loop_versions = new_for_loop_versions;
|
||||
|
||||
// let mut f = Propagator::with_constants(&mut constants)
|
||||
// .fold_function(new_f)
|
||||
// .map_err(|e| Error::Incompatible(format!("{}", e)))?;
|
||||
|
||||
// let mut substitutions = Substitutions::default();
|
||||
|
||||
// let mut hash = None;
|
||||
|
||||
// let mut len = f.statements.len();
|
||||
|
||||
// // println!("{}", f);
|
||||
|
||||
// loop {
|
||||
// let mut reducer = Reducer::new(
|
||||
// program,
|
||||
// &mut versions,
|
||||
// &mut substitutions,
|
||||
// for_loop_versions,
|
||||
// &mut constants,
|
||||
// );
|
||||
|
||||
// println!("reduce");
|
||||
|
||||
// let new_f = TypedFunction {
|
||||
// statements: f
|
||||
// .statements
|
||||
// .into_iter()
|
||||
// .map(|s| reducer.fold_statement(s))
|
||||
// .collect::<Result<Vec<_>, _>>()?
|
||||
// .into_iter()
|
||||
// .flatten()
|
||||
// .collect(),
|
||||
// ..f
|
||||
// };
|
||||
|
||||
// println!("done");
|
||||
|
||||
// // println!("after reduction {}", new_f);
|
||||
|
||||
// println!(
|
||||
// "count {}, unrolled {} loops",
|
||||
// new_f.statements.len(),
|
||||
// reducer.credits
|
||||
// );
|
||||
|
||||
// assert!(reducer.for_loop_versions.is_empty());
|
||||
|
||||
// match reducer.complete {
|
||||
// true => {
|
||||
// substitutions = substitutions.canonicalize();
|
||||
|
||||
// let new_f = Sub::new(&substitutions).fold_function(new_f);
|
||||
|
||||
// // println!("after last sub {}", new_f);
|
||||
|
||||
// // let new_f = Propagator::with_constants(&mut constants)
|
||||
// // .fold_function(new_f)
|
||||
// // .map_err(|e| Error::Incompatible(format!("{}", e)))?;
|
||||
|
||||
// // println!("after last prop {}", new_f);
|
||||
|
||||
// break Ok(new_f);
|
||||
// }
|
||||
// false => {
|
||||
// for_loop_versions = reducer.for_loop_versions_after;
|
||||
|
||||
// println!("canonicalize");
|
||||
|
||||
// // substitutions = substitutions.canonicalize();
|
||||
|
||||
// // let new_f = Sub::new(&substitutions).fold_function(new_f);
|
||||
|
||||
// println!("done");
|
||||
// // println!("after sub {}", new_f);
|
||||
|
||||
// println!("propagate");
|
||||
|
||||
// // f = Propagator::with_constants(&mut constants)
|
||||
// // .fold_function(new_f)
|
||||
// // .map_err(|e| Error::Incompatible(format!("{}", e)))?;
|
||||
|
||||
// println!("done");
|
||||
|
||||
// f = new_f;
|
||||
|
||||
// // println!("after prop {}", f);
|
||||
|
||||
// let new_len = f.statements.len();
|
||||
|
||||
// if new_len == len {
|
||||
// let new_hash = Some(compute_hash(&f));
|
||||
|
||||
// if new_hash == hash {
|
||||
// break Err(Error::NoProgress);
|
||||
// } else {
|
||||
// hash = new_hash;
|
||||
// }
|
||||
// } else {
|
||||
// len = new_len;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }?;
|
||||
|
||||
// Propagator::with_constants(&mut constants)
|
||||
// .fold_function(f)
|
||||
// .map_err(|e| Error::Incompatible(format!("{}", e)))
|
||||
|
||||
Reducer::new(program, &mut versions, &mut constants).fold_function(f)
|
||||
}
|
||||
|
||||
fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut s = DefaultHasher::new();
|
||||
f.hash(&mut s);
|
||||
s.finish()
|
||||
Reducer::new(program).fold_function(f)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -28,26 +28,24 @@
|
|||
|
||||
use zokrates_ast::typed::folder::*;
|
||||
use zokrates_ast::typed::identifier::FrameIdentifier;
|
||||
use zokrates_ast::typed::types::ConcreteGenericsAssignment;
|
||||
use zokrates_ast::typed::types::Type;
|
||||
|
||||
use zokrates_ast::typed::*;
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use super::Versions;
|
||||
|
||||
pub struct ShallowTransformer<'ast, 'a> {
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ShallowTransformer<'ast> {
|
||||
// version index for any variable name
|
||||
pub versions: &'a mut Versions<'ast>,
|
||||
pub versions: Versions<'ast>,
|
||||
pub frames: Vec<usize>,
|
||||
pub latest_frame: usize,
|
||||
}
|
||||
|
||||
impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
||||
pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self {
|
||||
ShallowTransformer { versions }
|
||||
}
|
||||
|
||||
fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
let frame_versions = self.versions.map.entry(self.versions.frame).or_default();
|
||||
impl<'ast> ShallowTransformer<'ast> {
|
||||
pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> {
|
||||
let frame_versions = self.versions.map.entry(self.frame()).or_default();
|
||||
|
||||
let version = frame_versions
|
||||
.entry(c_id.clone())
|
||||
|
@ -66,42 +64,21 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn transform<T: Field>(
|
||||
f: TypedFunction<'ast, T>,
|
||||
generics: &ConcreteGenericsAssignment<'ast>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
) -> TypedFunction<'ast, T> {
|
||||
let mut unroller = ShallowTransformer::with_versions(versions);
|
||||
|
||||
unroller.fold_function(f, generics)
|
||||
fn frame(&self) -> usize {
|
||||
*self.frames.last().unwrap_or(&0)
|
||||
}
|
||||
|
||||
fn fold_function<T: Field>(
|
||||
&mut self,
|
||||
f: TypedFunction<'ast, T>,
|
||||
generics: &ConcreteGenericsAssignment<'ast>,
|
||||
) -> TypedFunction<'ast, T> {
|
||||
let mut f = f;
|
||||
pub fn push_call_frame(&mut self) {
|
||||
self.latest_frame += 1;
|
||||
self.frames.push(self.latest_frame);
|
||||
self.versions
|
||||
.map
|
||||
.insert(self.latest_frame, Default::default());
|
||||
}
|
||||
|
||||
f.statements = generics
|
||||
.0
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(g, v)| {
|
||||
TypedStatement::definition(
|
||||
Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false)
|
||||
.into(),
|
||||
UExpression::from(v as u32).into(),
|
||||
)
|
||||
})
|
||||
.chain(f.statements)
|
||||
.collect();
|
||||
|
||||
for arg in &f.arguments {
|
||||
let _ = self.issue_next_identifier(arg.id.id.id.id.clone());
|
||||
}
|
||||
|
||||
fold_function(self, f)
|
||||
pub fn pop_call_frame(&mut self) {
|
||||
let frame = self.frames.pop().unwrap();
|
||||
self.versions.map.remove(&frame);
|
||||
}
|
||||
|
||||
pub fn fold_assignee<T: Field>(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
|
@ -115,7 +92,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> {
|
||||
fn fold_assembly_statement(
|
||||
&mut self,
|
||||
s: TypedAssemblyStatement<'ast, T>,
|
||||
|
@ -154,14 +131,14 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
|
|||
let version = self
|
||||
.versions
|
||||
.map
|
||||
.get(&self.versions.frame)
|
||||
.get(&self.frame())
|
||||
.unwrap()
|
||||
.get(&n.id.id)
|
||||
.cloned()
|
||||
.unwrap_or(0);
|
||||
|
||||
let id = FrameIdentifier {
|
||||
frame: self.versions.frame,
|
||||
frame: self.frame(),
|
||||
..n.id
|
||||
};
|
||||
|
||||
|
|
|
@ -135,7 +135,7 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)),
|
||||
frame: 0,
|
||||
},
|
||||
id => n.id,
|
||||
_id => n.id,
|
||||
};
|
||||
|
||||
Identifier { id, ..n }
|
||||
|
|
|
@ -240,7 +240,7 @@ impl<'ast, T> From<u32> for UExpression<'ast, T> {
|
|||
impl<'ast, T: Field> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
|
||||
fn from(c: DeclarationConstant<'ast, T>) -> Self {
|
||||
match c {
|
||||
DeclarationConstant::Generic(g) => {
|
||||
DeclarationConstant::Generic(_g) => {
|
||||
// UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32)
|
||||
unreachable!()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue