1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #1050 from Zokrates/fix-constant-reduction

Implement missing cases in constant resolution
This commit is contained in:
Thibaut Schaeffer 2021-11-23 11:22:23 +01:00 committed by GitHub
commit 2a63d8d67a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 78 additions and 6 deletions

View file

@ -0,0 +1 @@
Fix reduction of constants

View file

@ -84,7 +84,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> {
&mut self,
s: TypedSymbolDeclaration<'ast, T>,
) -> Result<TypedSymbolDeclaration<'ast, T>, Self::Error> {
// before we treat the symbol, propagate the constants into it, as may be using constants defined earlier in this module.
// before we treat the symbol, propagate the constants into it, as it may be using constants defined earlier in this module.
let s = self.update_symbol_declaration(s);
let s = fold_symbol_declaration(self, s)?;
@ -103,6 +103,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> {
TypedConstantSymbol::Here(c) => {
let c = self.fold_constant(c)?;
// if constants were used in the rhs, they are now defined in the map
// replace them in the expression
use crate::typed_absy::folder::Folder;
let c = ConstantsReader::with_constants(&self.constants).fold_constant(c);
use crate::typed_absy::{DeclarationSignature, TypedFunction, TypedStatement};
// wrap this expression in a function

View file

@ -508,7 +508,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
// inline all constants and replace them in the program
// inline all constants and replace them in the program
let mut constants_writer = ConstantsWriter::with_program(p.clone());

View file

@ -122,7 +122,14 @@ pub trait Folder<'ast, T: Field>: Sized {
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
n
let id = match n.id {
CoreIdentifier::Constant(c) => {
CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c))
}
id => id,
};
Identifier { id, ..n }
}
fn fold_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> {
@ -1027,6 +1034,9 @@ pub fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
) -> DeclarationConstant<'ast, T> {
match c {
DeclarationConstant::Expression(e) => DeclarationConstant::Expression(f.fold_expression(e)),
DeclarationConstant::Constant(c) => {
DeclarationConstant::Constant(f.fold_canonical_constant_identifier(c))
}
c => c,
}
}

View file

@ -150,7 +150,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
Ok(n)
let id = match n.id {
CoreIdentifier::Constant(c) => {
CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?)
}
id => id,
};
Ok(Identifier { id, ..n })
}
fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result<Variable<'ast, T>, Self::Error> {
@ -1072,6 +1079,9 @@ pub fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
DeclarationConstant::Expression(e) => {
Ok(DeclarationConstant::Expression(f.fold_expression(e)?))
}
DeclarationConstant::Constant(c) => Ok(DeclarationConstant::Constant(
f.fold_canonical_constant_identifier(c)?,
)),
c => Ok(c),
}
}
@ -1238,11 +1248,14 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
p: TypedProgram<'ast, T>,
) -> Result<TypedProgram<'ast, T>, F::Error> {
Ok(TypedProgram {
main: f.fold_module_id(p.main)?,
modules: p
.modules
.into_iter()
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
.map(|(module_id, module)| {
let module_id = f.fold_module_id(module_id)?;
f.fold_module(module).map(|m| (module_id, m))
})
.collect::<Result<_, _>>()?,
main: f.fold_module_id(p.main)?,
})
}

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/constants/issue_1038/a.zok",
"tests": []
}

View file

@ -0,0 +1,5 @@
from "./b" import SIZE_WORDS
def main(field[SIZE_WORDS] a):
assert(a == [0; SIZE_WORDS])
return

View file

@ -0,0 +1,2 @@
const u32 SIZE_BYTES = 136
const u32 SIZE_WORDS = SIZE_BYTES/8

View file

@ -0,0 +1,2 @@
const u32 SIZE_BYTES = 136
const u32 SIZE_WORDS = SIZE_BYTES/8

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/constants/issue_1038/reversed/b.zok",
"tests": []
}

View file

@ -0,0 +1,5 @@
from "./a" import SIZE_WORDS
def main(field[SIZE_WORDS] a):
assert(a == [0; SIZE_WORDS])
return

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/constants/issue_1047/a.zok",
"tests": []
}

View file

@ -0,0 +1,4 @@
from "./b" import B
def main():
return

View file

@ -0,0 +1,2 @@
const field A = 1
const field B = A + 1

View file

@ -0,0 +1,2 @@
const field A = 1
const field B = A + 1

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/constants/issue_1047/reversed/b.zok",
"tests": []
}

View file

@ -0,0 +1,4 @@
from "./a" import B
def main():
return