1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
schaeff 2021-06-16 16:30:46 +02:00
parent 59023ca677
commit 6b4a80d891
3 changed files with 66 additions and 67 deletions

View file

@ -1,2 +0,0 @@
def main(field[3][2] a, u32 index) -> field[2]:
return a[index]

View file

@ -7,20 +7,20 @@ use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
type ModuleConstants<'ast, T> =
type ProgramConstants<'ast, T> =
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
pub struct ConstantInliner<'ast, T> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ModuleConstants<'ast, T>,
constants: ProgramConstants<'ast, T>,
}
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ModuleConstants<'ast, T>,
constants: ProgramConstants<'ast, T>,
) -> Self {
ConstantInliner {
modules,
@ -29,7 +29,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let constants = HashMap::new();
let constants = ProgramConstants::new();
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
inliner.fold_program(p)
}
@ -37,6 +37,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
self.constants.entry(self.location.clone()).or_default();
prev
}
@ -44,18 +45,24 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
self.constants.contains_key(id)
}
fn get_constant(&mut self, id: &Identifier) -> Option<TypedExpression<'ast, T>> {
assert_eq!(id.version, 0);
match id.id {
CoreIdentifier::Call(..) => {
unreachable!("calls identifiers are only available after call inlining")
}
CoreIdentifier::Source(id) => self
.constants
.get(&self.location)
.and_then(|constants| constants.get(&id.into()))
.cloned(),
}
fn get_constant(
&self,
id: &CanonicalConstantIdentifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&id.module)
.and_then(|constants| constants.get(&id.id.into()))
.cloned()
}
fn get_constant_for_identifier(
&self,
id: &Identifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&self.location)
.and_then(|constants| constants.get(&id))
.cloned()
}
}
@ -64,48 +71,42 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
self.constants.entry(self.location.clone()).or_default();
let m = self.fold_module(self.modules.get(&id).unwrap().clone());
let m = self.modules.remove(&id).unwrap();
let m = self.fold_module(m);
self.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
id
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
// initialise a constant map for this module
self.constants.entry(self.location.clone()).or_default();
TypedModule {
constants: m
.constants
.into_iter()
.map(|(id, tc)| {
let id = self.fold_canonical_constant_identifier(id);
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
// visit the imported symbol. This triggers visiting the corresponding module if needed
let imported_id = self.fold_canonical_constant_identifier(imported_id);
self.constants
.get(&imported_id.module)
.unwrap()
.get(&imported_id.id.into())
.cloned()
// after that, the constant must have been defined defined in the global map. It is already reduced
// to a literal, so running propagation isn't required
self.get_constant(&imported_id).unwrap()
}
TypedConstantSymbol::Here(c) => {
let non_propagated_constant = fold_constant(self, c).expression;
// folding the constant above only reduces it to an expression containing only literals, not to a single literal.
// propagating with an empty map of constants reduces it to a single literal
Propagator::with_constants(&mut HashMap::default())
.fold_expression(non_propagated_constant)
.unwrap()
}
TypedConstantSymbol::Here(c) => fold_constant(self, c).expression,
};
let constant =
Propagator::with_constants(self.constants.get_mut(&self.location).unwrap())
.fold_expression(constant)
.unwrap();
// add to the constant map. The value added is always a single litteral
self.constants
.entry(self.location.clone())
.or_default()
.get_mut(&self.location)
.unwrap()
.insert(id.id.into(), constant.clone());
(
@ -136,22 +137,15 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
) -> DeclarationConstant<'ast> {
match c {
// replace constants by their concrete value in declaration types
DeclarationConstant::Constant(id) => DeclarationConstant::Concrete(
match self
.constants
.get(&id.module)
.unwrap()
.get(&id.id.into())
.cloned()
.unwrap()
{
DeclarationConstant::Constant(id) => {
DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() {
TypedExpression::Uint(UExpression {
inner: UExpressionInner::Value(v),
..
}) => v as u32,
_ => unreachable!("all constants should be reduceable to u32 literals"),
},
),
})
}
c => c,
}
}
@ -161,10 +155,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => c.try_into().unwrap(),
None => fold_field_expression(self, e),
},
FieldElementExpression::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => c.try_into().unwrap(),
None => fold_field_expression(self, e),
}
}
e => fold_field_expression(self, e),
}
}
@ -174,8 +170,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_expression(c).try_into().unwrap(),
BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => c.try_into().unwrap(),
None => fold_boolean_expression(self, e),
},
e => fold_boolean_expression(self, e),
@ -188,9 +184,9 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => {
let e: UExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
let e: UExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_uint_expression_inner(self, size, e),
@ -205,13 +201,15 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
e.into_inner()
ArrayExpressionInner::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_array_expression_inner(self, ty, e),
}
None => fold_array_expression_inner(self, ty, e),
},
}
e => fold_array_expression_inner(self, ty, e),
}
}
@ -222,9 +220,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id)
{
Some(c) => {
let e: StructExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
let e: StructExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_struct_expression_inner(self, ty, e),

View file

@ -298,6 +298,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>,
}
@ -310,6 +311,7 @@ impl<'ast, T> TypedConstant<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds
write!(f, "const {}({})", self.ty, self.expression)
}
}