1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #1322 from Zokrates/fix-nested-assembly

Fix propagation in case of nested assembly
This commit is contained in:
Thibaut Schaeffer 2023-07-24 10:55:05 +02:00 committed by GitHub
commit f900437d36
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 132 additions and 13 deletions

View file

@ -0,0 +1 @@
Fix panic in case of nested assembly blocks

View file

@ -510,6 +510,38 @@ pub struct ArgumentFinder<'ast, T> {
} }
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> { impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
fn fold_assembly_block(
&mut self,
s: zir::AssemblyBlockStatement<'ast, T>,
) -> Vec<zir::ZirStatement<'ast, T>> {
let mut statements: Vec<_> = s
.inner
.into_iter()
.rev()
.flat_map(|s| self.fold_assembly_statement(s))
.collect();
statements.reverse();
vec![zir::ZirStatement::Assembly(
zir::AssemblyBlockStatement::new(statements),
)]
}
fn fold_assembly_assignment(
&mut self,
s: zir::AssemblyAssignment<'ast, T>,
) -> Vec<zir::ZirAssemblyStatement<'ast, T>> {
let assignees: Vec<zir::ZirAssignee<'ast>> = s
.assignee
.into_iter()
.map(|v| self.fold_assignee(v))
.collect();
let expr = self.fold_function(s.expression);
for a in &assignees {
self.identifiers.remove(&a.id);
}
vec![zir::ZirAssemblyStatement::assignment(assignees, expr)]
}
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> { fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
match s { match s {
zir::ZirStatement::Definition(s) => { zir::ZirStatement::Definition(s) => {

View file

@ -84,6 +84,27 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
}) })
} }
fn fold_assembly_block(
&mut self,
s: zokrates_ast::zir::AssemblyBlockStatement<'ast, T>,
) -> Result<Vec<ZirStatement<'ast, T>>, Self::Error> {
let block: Vec<_> = s
.inner
.into_iter()
.map(|s| self.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();
Ok(match block.is_empty() {
true => vec![],
false => vec![ZirStatement::Assembly(
zokrates_ast::zir::AssemblyBlockStatement::new(block),
)],
})
}
fn fold_assembly_assignment( fn fold_assembly_assignment(
&mut self, &mut self,
s: zokrates_ast::zir::AssemblyAssignment<'ast, T>, s: zokrates_ast::zir::AssemblyAssignment<'ast, T>,

View file

@ -9,11 +9,16 @@ pub struct ZirCanonicalizer<'ast> {
impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> {
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
let id = match self.identifier_map.get(&p.id.id) {
Some(v) => Identifier::internal(*v),
None => {
let new_id = self.identifier_map.len(); let new_id = self.identifier_map.len();
self.identifier_map.insert(p.id.id.clone(), new_id); self.identifier_map.insert(p.id.id.clone(), new_id);
Identifier::internal(new_id)
}
};
Parameter { Parameter {
id: Variable::with_id_and_type(Identifier::internal(new_id), p.id.ty), id: Variable::with_id_and_type(id, p.id.ty),
..p ..p
} }
} }

View file

@ -63,8 +63,8 @@ pub struct ZirFunction<'ast, T> {
pub type IdentifierOrExpression<'ast, T, E> = pub type IdentifierOrExpression<'ast, T, E> =
expressions::IdentifierOrExpression<Identifier<'ast>, E, <E as Expr<'ast, T>>::Inner>; expressions::IdentifierOrExpression<Identifier<'ast>, E, <E as Expr<'ast, T>>::Inner>;
impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> { impl<'ast, T: fmt::Display> ZirFunction<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
writeln!( writeln!(
f, f,
"({}) -> ({}) {{", "({}) -> ({}) {{",
@ -82,11 +82,17 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
)?; )?;
for s in &self.statements { for s in &self.statements {
s.fmt_indented(f, 1)?; s.fmt_indented(f, depth + 1)?;
writeln!(f)?; writeln!(f)?;
} }
write!(f, "}}") write!(f, "{}}}", "\t".repeat(depth))
}
}
impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt_indented(f, 0)
} }
} }
@ -174,19 +180,30 @@ impl<'ast, T> WithSpan for ZirAssemblyStatement<'ast, T> {
} }
} }
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> { impl<'ast, T: fmt::Display> ZirAssemblyStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
write!(f, "{}", "\t".repeat(depth))?;
match *self { match *self {
ZirAssemblyStatement::Assignment(ref s) => { ZirAssemblyStatement::Assignment(ref s) => {
write!( write!(
f, f,
"{} <-- {};", "{} <-- (",
s.assignee s.assignee
.iter() .iter()
.map(|a| a.to_string()) .map(|a| a.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", ")
)?;
s.expression.fmt_indented(f, depth)?;
write!(
f,
")({});",
s.expression s.expression
.arguments
.iter()
.map(|a| a.id.id.to_string())
.collect::<Vec<_>>()
.join(", ")
) )
} }
ZirAssemblyStatement::Constraint(ref s) => { ZirAssemblyStatement::Constraint(ref s) => {
@ -196,6 +213,12 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
} }
} }
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt_indented(f, 0)
}
}
pub type DefinitionStatement<'ast, T> = pub type DefinitionStatement<'ast, T> =
common::expressions::DefinitionStatement<ZirAssignee<'ast>, ZirExpression<'ast, T>>; common::expressions::DefinitionStatement<ZirAssignee<'ast>, ZirExpression<'ast, T>>;
pub type AssertionStatement<'ast, T> = pub type AssertionStatement<'ast, T> =
@ -462,7 +485,8 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
ZirStatement::Assembly(s) => { ZirStatement::Assembly(s) => {
writeln!(f, "asm {{")?; writeln!(f, "asm {{")?;
for s in &s.inner { for s in &s.inner {
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; s.fmt_indented(f, depth + 1)?;
writeln!(f)?;
} }
write!(f, "{}}}", "\t".repeat(depth)) write!(f, "{}}}", "\t".repeat(depth))
} }

View file

@ -0,0 +1,16 @@
{
"curves": ["Bn128"],
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": ["4"]
},
"output": {
"Ok": {
"value": "32"
}
}
}
]
}

View file

@ -0,0 +1,20 @@
def bar(field a) -> field {
field tmp = a * a;
field mut b = 0;
asm {
b <-- tmp * 2;
}
return b;
}
def foo(field a) -> field {
field mut b = 0;
asm {
b <-- bar(a);
}
return b;
}
def main(field a) -> field {
return foo(a);
}