fix propagation in case of nested assembly
This commit is contained in:
parent
350430f66c
commit
6ea437c83c
5 changed files with 98 additions and 4 deletions
|
@ -510,6 +510,38 @@ pub struct ArgumentFinder<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
||||||
|
fn fold_assembly_block(
|
||||||
|
&mut self,
|
||||||
|
s: zir::AssemblyBlockStatement<'ast, T>,
|
||||||
|
) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||||
|
let mut statements: Vec<_> = s
|
||||||
|
.inner
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.flat_map(|s| self.fold_assembly_statement(s))
|
||||||
|
.collect();
|
||||||
|
statements.reverse();
|
||||||
|
vec![zir::ZirStatement::Assembly(
|
||||||
|
zir::AssemblyBlockStatement::new(statements),
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_assembly_assignment(
|
||||||
|
&mut self,
|
||||||
|
s: zir::AssemblyAssignment<'ast, T>,
|
||||||
|
) -> Vec<zir::ZirAssemblyStatement<'ast, T>> {
|
||||||
|
let assignees: Vec<zir::ZirAssignee<'ast>> = s
|
||||||
|
.assignee
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| self.fold_assignee(v))
|
||||||
|
.collect();
|
||||||
|
let expr = self.fold_function(s.expression);
|
||||||
|
for a in &assignees {
|
||||||
|
self.identifiers.remove(&a.id);
|
||||||
|
}
|
||||||
|
vec![zir::ZirAssemblyStatement::assignment(assignees, expr)]
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||||
match s {
|
match s {
|
||||||
zir::ZirStatement::Definition(s) => {
|
zir::ZirStatement::Definition(s) => {
|
||||||
|
|
|
@ -84,6 +84,27 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_assembly_block(
|
||||||
|
&mut self,
|
||||||
|
s: zokrates_ast::zir::AssemblyBlockStatement<'ast, T>,
|
||||||
|
) -> Result<Vec<ZirStatement<'ast, T>>, Self::Error> {
|
||||||
|
let block: Vec<_> = s
|
||||||
|
.inner
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| self.fold_assembly_statement(s))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(match block.is_empty() {
|
||||||
|
true => vec![],
|
||||||
|
false => vec![ZirStatement::Assembly(
|
||||||
|
zokrates_ast::zir::AssemblyBlockStatement::new(block),
|
||||||
|
)],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_assembly_assignment(
|
fn fold_assembly_assignment(
|
||||||
&mut self,
|
&mut self,
|
||||||
s: zokrates_ast::zir::AssemblyAssignment<'ast, T>,
|
s: zokrates_ast::zir::AssemblyAssignment<'ast, T>,
|
||||||
|
|
|
@ -9,11 +9,16 @@ pub struct ZirCanonicalizer<'ast> {
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> {
|
impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> {
|
||||||
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
|
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
|
||||||
let new_id = self.identifier_map.len();
|
let id = match self.identifier_map.get(&p.id.id) {
|
||||||
self.identifier_map.insert(p.id.id.clone(), new_id);
|
Some(v) => Identifier::internal(*v),
|
||||||
|
None => {
|
||||||
|
let new_id = self.identifier_map.len();
|
||||||
|
self.identifier_map.insert(p.id.id.clone(), new_id);
|
||||||
|
Identifier::internal(new_id)
|
||||||
|
}
|
||||||
|
};
|
||||||
Parameter {
|
Parameter {
|
||||||
id: Variable::with_id_and_type(Identifier::internal(new_id), p.id.ty),
|
id: Variable::with_id_and_type(id, p.id.ty),
|
||||||
..p
|
..p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
{
|
||||||
|
"curves": ["Bn128"],
|
||||||
|
"max_constraint_count": 1,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"values": ["4"]
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"Ok": {
|
||||||
|
"value": "32"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
def bar(field a) -> field {
|
||||||
|
field tmp = a * a;
|
||||||
|
field mut b = 0;
|
||||||
|
asm {
|
||||||
|
b <-- tmp * 2;
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
def foo(field a) -> field {
|
||||||
|
field mut b = 0;
|
||||||
|
asm {
|
||||||
|
b <-- bar(a);
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
def main(field a) -> field {
|
||||||
|
return foo(a);
|
||||||
|
}
|
Loading…
Reference in a new issue