1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #975 from Zokrates/allow-calls-in-constants

Allow calls in constants
This commit is contained in:
Thibaut Schaeffer 2021-10-01 18:01:52 +03:00 committed by GitHub
commit 5f2d65124b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 2317 additions and 1597 deletions

View file

@ -0,0 +1 @@
Allow calls in constant definitions

View file

@ -0,0 +1,11 @@
def myFct<N, N2>(u64[N] ignored) -> u64[N2]:
assert(2*N == N2)
return [0; N2]
const u32 N = 3
const u32 N2 = 2*N
def main(u64[N] arg) -> bool:
u64[N2] someVariable = myFct(arg)
return true

View file

@ -0,0 +1,10 @@
from "./call_in_const_aux.zok" import A, foo, F
def bar(field[A] x) -> field[A]:
return x
const field[A] Y = [...bar(foo::<A>(F))[..A - 1], 1]
def main(field[A] X):
assert(X == Y)
return

View file

@ -0,0 +1,9 @@
const field F = 10
const u32 A = 10
const u32 B = A
def foo<N>(field X) -> field[N]:
return [X; N]
def main():
return

View file

@ -0,0 +1,7 @@
def yes() -> bool:
return true
const bool TRUE = yes()
def main():
return

View file

@ -0,0 +1,14 @@
// this should not compile, as A == B
const u32 A = 1
const u32 B = 1
def foo(field[A] a) -> bool:
return true
def foo(field[B] a) -> bool:
return true
def main():
assert(foo([1]))
return

View file

@ -0,0 +1,14 @@
// this should actually compile, as A != B
const u32 A = 2
const u32 B = 1
def foo(field[A] a) -> bool:
return true
def foo(field[B] a) -> bool:
return true
def main():
assert(foo([1]))
return

View file

@ -0,0 +1,6 @@
from "EMBED" import bit_array_le
const bool CONST = bit_array_le([true], [true])
def main() -> bool:
return CONST

View file

@ -0,0 +1,14 @@
def constant() -> u32:
u32 res = 0
u32 x = 3
for u32 y in 0..x do
res = res + 1
endfor
return res
const u32 CONSTANT = 1 + constant()
const u32 OTHER_CONSTANT = 42
def main(field[CONSTANT] a) -> u32:
return CONSTANT + OTHER_CONSTANT

View file

@ -0,0 +1,15 @@
struct SomeStruct<N> {
u64[N] f
}
def myFct<N, N2, N3>(SomeStruct<N> ignored) -> u32[N2]:
assert(2*N == N2)
return [N3; N2]
const u32 N = 3
const u32 N2 = 2*N
def main(SomeStruct<N> arg) -> u32:
u32[N2] someVariable = myFct::<_, _, 42>(arg)
return someVariable[0]

View file

@ -9,6 +9,16 @@ pub struct Node<T> {
pub value: T,
}
impl<T> Node<T> {
pub fn mock(e: T) -> Self {
Self {
start: Position::mock(),
end: Position::mock(),
value: e,
}
}
}
impl<T: fmt::Display> Node<T> {
pub fn pos(&self) -> (Position, Position) {
(self.start, self.end)
@ -67,8 +77,7 @@ pub trait NodeValue: fmt::Display + fmt::Debug + Sized + PartialEq {
impl<V: NodeValue> From<V> for Node<V> {
fn from(v: V) -> Node<V> {
let mock_position = Position { col: 42, line: 42 };
Node::new(mock_position, mock_position, v)
Node::new(Position::mock(), Position::mock(), v)
}
}

View file

@ -249,6 +249,8 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
let typed_ast = Checker::check(compiled)
.map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?;
log::trace!("\n{}", typed_ast);
let main_module = typed_ast.main.clone();
log::debug!("Run static analysis");

View file

@ -1,3 +1,7 @@
use crate::absy::{
types::{UnresolvedSignature, UnresolvedType},
ConstantGenericNode, Expression,
};
use crate::flat_absy::{
FlatDirective, FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement,
FlatVariable, RuntimeError,
@ -26,7 +30,7 @@ cfg_if::cfg_if! {
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
/// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
pub enum FlatEmbed {
BitArrayLe,
Unpack,
@ -45,7 +49,131 @@ pub enum FlatEmbed {
}
impl FlatEmbed {
pub fn signature(&self) -> DeclarationSignature<'static> {
pub fn signature(&self) -> UnresolvedSignature {
match self {
FlatEmbed::BitArrayLe => UnresolvedSignature::new()
.generics(vec![ConstantGenericNode::mock("N")])
.inputs(vec![
UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::Identifier("N").into(),
)
.into(),
UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::Identifier("N").into(),
)
.into(),
])
.outputs(vec![UnresolvedType::Boolean.into()]),
FlatEmbed::Unpack => UnresolvedSignature::new()
.generics(vec!["N".into()])
.inputs(vec![UnresolvedType::FieldElement.into()])
.outputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::Identifier("N").into(),
)
.into()]),
FlatEmbed::U8ToBits => UnresolvedSignature::new()
.inputs(vec![UnresolvedType::Uint(8).into()])
.outputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(8).into(),
)
.into()]),
FlatEmbed::U16ToBits => UnresolvedSignature::new()
.inputs(vec![UnresolvedType::Uint(16).into()])
.outputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(16).into(),
)
.into()]),
FlatEmbed::U32ToBits => UnresolvedSignature::new()
.inputs(vec![UnresolvedType::Uint(32).into()])
.outputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(32).into(),
)
.into()]),
FlatEmbed::U64ToBits => UnresolvedSignature::new()
.inputs(vec![UnresolvedType::Uint(64).into()])
.outputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(64).into(),
)
.into()]),
FlatEmbed::U8FromBits => UnresolvedSignature::new()
.outputs(vec![UnresolvedType::Uint(8).into()])
.inputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(8).into(),
)
.into()]),
FlatEmbed::U16FromBits => UnresolvedSignature::new()
.outputs(vec![UnresolvedType::Uint(16).into()])
.inputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(16).into(),
)
.into()]),
FlatEmbed::U32FromBits => UnresolvedSignature::new()
.outputs(vec![UnresolvedType::Uint(32).into()])
.inputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(32).into(),
)
.into()]),
FlatEmbed::U64FromBits => UnresolvedSignature::new()
.outputs(vec![UnresolvedType::Uint(64).into()])
.inputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(64).into(),
)
.into()]),
#[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => UnresolvedSignature::new()
.inputs(vec![
UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(512).into(),
)
.into(),
UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(256).into(),
)
.into(),
])
.outputs(vec![UnresolvedType::array(
UnresolvedType::Boolean.into(),
Expression::U32Constant(256).into(),
)
.into()]),
#[cfg(feature = "ark")]
FlatEmbed::SnarkVerifyBls12377 => UnresolvedSignature::new()
.generics(vec!["N".into(), "V".into()])
.inputs(vec![
UnresolvedType::array(
UnresolvedType::FieldElement.into(),
Expression::Identifier("N").into(),
)
.into(), // inputs
UnresolvedType::array(
UnresolvedType::FieldElement.into(),
Expression::U32Constant(8).into(),
)
.into(), // proof
UnresolvedType::array(
UnresolvedType::FieldElement.into(),
Expression::Identifier("V").into(),
)
.into(), // 18 + (2 * n) // vk
])
.outputs(vec![UnresolvedType::Boolean.into()]),
}
}
pub fn typed_signature<T>(&self) -> DeclarationSignature<'static, T> {
match self {
FlatEmbed::BitArrayLe => DeclarationSignature::new()
.generics(vec![Some(DeclarationConstant::Generic(
@ -177,15 +305,13 @@ impl FlatEmbed {
}
}
pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
let gen = self
.signature()
.generics
.into_iter()
.map(|c| match c.unwrap() {
pub fn generics<'ast, T>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
let gen = self.typed_signature().generics.into_iter().map(
|c: Option<DeclarationConstant<'ast, T>>| match c.unwrap() {
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
});
},
);
assert_eq!(gen.len(), assignment.0.len());
gen.map(|g| *assignment.0.get(&g).unwrap() as u32).collect()

View file

@ -15,7 +15,6 @@ impl Position {
}
}
#[cfg(test)]
pub fn mock() -> Self {
Position { line: 42, col: 42 }
}

File diff suppressed because it is too large Load diff

View file

@ -1,973 +0,0 @@
use crate::static_analysis::Propagator;
use crate::typed_absy::result_folder::*;
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryInto;
use std::fmt;
use zokrates_field::Field;
type ProgramConstants<'ast, T> =
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
#[derive(Debug, PartialEq)]
pub enum Error {
Type(String),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Type(s) => write!(f, "{}", s),
}
}
}
pub struct ConstantInliner<'ast, T> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ProgramConstants<'ast, T>,
}
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ProgramConstants<'ast, T>,
) -> Self {
ConstantInliner {
modules,
location,
constants,
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
let constants = ProgramConstants::new();
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
inliner.fold_program(p)
}
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
self.constants.entry(self.location.clone()).or_default();
prev
}
fn treated(&self, id: &TypedModuleId) -> bool {
self.constants.contains_key(id)
}
fn get_constant(
&self,
id: &CanonicalConstantIdentifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&id.module)
.and_then(|constants| constants.get(&id.id.into()))
.cloned()
}
fn get_constant_for_identifier(
&self,
id: &Identifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&self.location)
.and_then(|constants| constants.get(&id))
.cloned()
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> {
type Error = Error;
fn fold_program(
&mut self,
p: TypedProgram<'ast, T>,
) -> Result<TypedProgram<'ast, T>, Self::Error> {
self.fold_module_id(p.main.clone())?;
Ok(TypedProgram {
modules: std::mem::take(&mut self.modules),
..p
})
}
fn fold_module_id(
&mut self,
id: OwnedTypedModuleId,
) -> Result<OwnedTypedModuleId, Self::Error> {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
let m = self.modules.remove(&id).unwrap();
let m = self.fold_module(m)?;
self.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
Ok(id)
}
fn fold_module(
&mut self,
m: TypedModule<'ast, T>,
) -> Result<TypedModule<'ast, T>, Self::Error> {
Ok(TypedModule {
constants: m
.constants
.into_iter()
.map(|(id, tc)| {
let id = self.fold_canonical_constant_identifier(id)?;
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
// visit the imported symbol. This triggers visiting the corresponding module if needed
let imported_id = self.fold_canonical_constant_identifier(imported_id)?;
// after that, the constant must have been defined defined in the global map. It is already reduced
// to a literal, so running propagation isn't required
self.get_constant(&imported_id).unwrap()
}
TypedConstantSymbol::Here(c) => {
let non_propagated_constant = fold_constant(self, c)?.expression;
// folding the constant above only reduces it to an expression containing only literals, not to a single literal.
// propagating with an empty map of constants reduces it to a single literal
Propagator::with_constants(&mut HashMap::default())
.fold_expression(non_propagated_constant)
.unwrap()
}
};
if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(*id.ty.clone()).unwrap() == constant.get_type() {
// add to the constant map. The value added is always a single litteral
self.constants
.get_mut(&self.location)
.unwrap()
.insert(id.id.into(), constant.clone());
Ok((
id,
TypedConstantSymbol::Here(TypedConstant {
expression: constant,
}),
))
} else {
Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant.get_type(), id.id, id.ty)))
}
})
.collect::<Result<Vec<_>, _>>()?,
functions: m
.functions
.into_iter()
.map::<Result<_, Self::Error>, _>(|(key, fun)| {
Ok((
self.fold_declaration_function_key(key)?,
self.fold_function_symbol(fun)?,
))
})
.collect::<Result<Vec<_>, _>>()
.into_iter()
.flatten()
.collect(),
})
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast>,
) -> Result<DeclarationConstant<'ast>, Self::Error> {
match c {
// replace constants by their concrete value in declaration types
DeclarationConstant::Constant(id) => {
let id = CanonicalConstantIdentifier {
module: self.fold_module_id(id.module)?,
..id
};
Ok(DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() {
TypedExpression::Uint(UExpression {
inner: UExpressionInner::Value(v),
..
}) => v as u32,
_ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"),
}))
}
c => Ok(c),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => Ok(c.try_into().unwrap()),
None => fold_field_expression(self, e),
}
}
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => Ok(c.try_into().unwrap()),
None => fold_boolean_expression(self, e),
},
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
size: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
match e {
UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => {
let e: UExpression<'ast, T> = c.try_into().unwrap();
Ok(e.into_inner())
}
None => fold_uint_expression_inner(self, size, e),
},
e => fold_uint_expression_inner(self, size, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = c.try_into().unwrap();
Ok(e.into_inner())
}
None => fold_array_expression_inner(self, ty, e),
}
}
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
match e {
StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id)
{
Some(c) => {
let e: StructExpression<'ast, T> = c.try_into().unwrap();
Ok(e.into_inner())
}
None => fold_struct_expression_inner(self, ty, e),
},
e => fold_struct_expression_inner(self, ty, e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression,
GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol,
TypedStatement,
};
use zokrates_field::Bn128Field;
#[test]
fn inline_const_field() {
// const field a = 1
//
// def main() -> field:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(
const_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(1)),
))),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(1)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, Ok(expected_program))
}
#[test]
fn inline_const_boolean() {
// const bool a = true
//
// def main() -> bool:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
Identifier::from(const_id),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into(), DeclarationType::Boolean),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean(
BooleanExpression::Value(true),
))),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
BooleanExpression::Value(true).into()
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, Ok(expected_program))
}
#[test]
fn inline_const_uint() {
// const u32 a = 0x00000001
//
// def main() -> u32:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
Identifier::from(const_id),
)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(
const_id,
"main".into(),
DeclarationType::Uint(UBitwidth::B32),
),
TypedConstantSymbol::Here(TypedConstant::new(
UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into(),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, Ok(expected_program))
}
#[test]
fn inline_const_field_array() {
// const field[2] a = [2, 2]
//
// def main() -> field:
// return a[0] + a[1]
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(
const_id,
"main".into(),
DeclarationType::Array(DeclarationArrayType::new(
DeclarationType::FieldElement,
2u32,
)),
),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
))),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::select(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::select(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, Ok(expected_program))
}
#[test]
fn inline_nested_const_field() {
// const field a = 1
// const field b = a + 1
//
// def main() -> field:
// return b
let const_a_id = "a";
let const_b_id = "b";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: vec![
(
CanonicalConstantIdentifier::new(
const_a_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
)),
),
(
CanonicalConstantIdentifier::new(
const_b_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from(
const_a_id,
)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)),
)),
),
]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: vec![
(
CanonicalConstantIdentifier::new(
const_a_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
)),
),
(
CanonicalConstantIdentifier::new(
const_b_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2),
)),
)),
),
]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, Ok(expected_program))
}
#[test]
fn inline_imported_constant() {
// ---------------------
// module `foo`
// --------------------
// const field FOO = 42
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// from "foo" import FOO
//
// def main() -> field:
// return FOO
let foo_const_id = "FOO";
let foo_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main")
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![],
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(
foo_const_id,
"foo".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(42)),
))),
)]
.into_iter()
.collect(),
};
let main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(
foo_const_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
foo_const_id,
"foo".into(),
DeclarationType::FieldElement,
)),
)]
.into_iter()
.collect(),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), main_module),
("foo".into(), foo_module.clone()),
]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(42)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(
foo_const_id,
"main".into(),
DeclarationType::FieldElement,
),
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
FieldElementExpression::Number(Bn128Field::from(42)),
))),
)]
.into_iter()
.collect(),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), expected_main_module),
("foo".into(), foo_module),
]
.into_iter()
.collect(),
};
assert_eq!(program, Ok(expected_program))
}
}

View file

@ -0,0 +1,844 @@
// Static analysis step to replace all imported constants with the imported value
// This does *not* reduce constants to their literal value
// This step cannot fail as the imports were checked during semantics
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use zokrates_field::Field;
// a map of the canonical constants in this program. with all imported constants reduced to their canonical value
type ProgramConstants<'ast, T> =
HashMap<OwnedTypedModuleId, HashMap<ConstantIdentifier<'ast>, TypedConstant<'ast, T>>>;
pub struct ConstantResolver<'ast, T> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ProgramConstants<'ast, T>,
}
impl<'ast, 'a, T: Field> ConstantResolver<'ast, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ProgramConstants<'ast, T>,
) -> Self {
ConstantResolver {
modules,
location,
constants,
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let constants = ProgramConstants::new();
let mut inliner = ConstantResolver::new(p.modules.clone(), p.main.clone(), constants);
inliner.fold_program(p)
}
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
self.constants.entry(self.location.clone()).or_default();
prev
}
fn treated(&self, id: &TypedModuleId) -> bool {
self.constants.contains_key(id)
}
fn get_constant(
&self,
id: &CanonicalConstantIdentifier<'ast>,
) -> Option<TypedConstant<'ast, T>> {
self.constants
.get(&id.module)
.and_then(|constants| constants.get(&id.id))
.cloned()
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
self.fold_module_id(p.main.clone());
TypedProgram {
modules: std::mem::take(&mut self.modules),
..p
}
}
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
let m = self.modules.remove(&id).unwrap();
let m = self.fold_module(m);
self.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
id
}
fn fold_constant_symbol_declaration(
&mut self,
c: TypedConstantSymbolDeclaration<'ast, T>,
) -> TypedConstantSymbolDeclaration<'ast, T> {
let id = self.fold_canonical_constant_identifier(c.id);
let constant = match c.symbol {
TypedConstantSymbol::There(imported_id) => {
// visit the imported symbol. This triggers visiting the corresponding module if needed
let imported_id = self.fold_canonical_constant_identifier(imported_id);
// after that, the constant must have been defined in the global map
self.get_constant(&imported_id).unwrap()
}
TypedConstantSymbol::Here(c) => fold_constant(self, c),
};
self.constants
.get_mut(&self.location)
.unwrap()
.insert(id.id, constant.clone());
TypedConstantSymbolDeclaration {
id,
symbol: TypedConstantSymbol::Here(constant),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression,
GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol,
TypedStatement,
};
use zokrates_field::Bn128Field;
#[test]
fn inline_const_field() {
// in the absence of imports, a module is left unchanged
// const field a = 1
//
// def main() -> field:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)
.into(),
],
},
)]
.into_iter()
.collect(),
};
let expected_program = program.clone();
let program = ConstantResolver::inline(program);
assert_eq!(program, expected_program)
}
#[test]
fn no_op_const_boolean() {
// in the absence of imports, a module is left unchanged
// const bool a = true
//
// def main() -> bool:
// return main.zok/a
let const_id = CanonicalConstantIdentifier::new("a", "main".into());
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
Identifier::from(const_id.clone()),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
const_id,
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Boolean(BooleanExpression::Value(true)),
DeclarationType::Boolean,
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(main),
)
.into(),
],
},
)]
.into_iter()
.collect(),
};
let expected_program = program.clone();
let program = ConstantResolver::inline(program);
assert_eq!(program, expected_program)
}
#[test]
fn no_op_const_uint() {
// in the absence of imports, a module is left unchanged
// const u32 a = 0x00000001
//
// def main() -> u32:
// return a
let const_id = CanonicalConstantIdentifier::new("a", "main".into());
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
Identifier::from(const_id.clone()),
)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
const_id,
TypedConstantSymbol::Here(TypedConstant::new(
UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into(),
DeclarationType::Uint(UBitwidth::B32),
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(main),
)
.into(),
],
},
)]
.into_iter()
.collect(),
};
let expected_program = program.clone();
let program = ConstantResolver::inline(program);
assert_eq!(program, expected_program)
}
#[test]
fn no_op_const_field_array() {
// in the absence of imports, a module is left unchanged
// const field[2] a = [2, 2]
//
// def main() -> field:
// return a[0] + a[1]
let const_id = CanonicalConstantIdentifier::new("a", "main".into());
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id.clone()))
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id.clone()))
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2))
.into(),
FieldElementExpression::Number(Bn128Field::from(2))
.into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
),
DeclarationType::Array(DeclarationArrayType::new(
DeclarationType::FieldElement,
2u32,
)),
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)
.into(),
],
},
)]
.into_iter()
.collect(),
};
let expected_program = program.clone();
let program = ConstantResolver::inline(program);
assert_eq!(program, expected_program)
}
#[test]
fn no_op_nested_const_field() {
// const field a = 1
// const field b = a + 1
//
// def main() -> field:
// return b
let const_a_id = CanonicalConstantIdentifier::new("a", "main".into());
let const_b_id = CanonicalConstantIdentifier::new("a", "main".into());
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_b_id.clone())).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
const_a_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedConstantSymbolDeclaration::new(
const_b_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from(
const_a_id.clone(),
)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)
.into(),
],
},
)]
.into_iter()
.collect(),
};
let expected_program = program.clone();
let program = ConstantResolver::inline(program);
assert_eq!(program, expected_program)
}
#[test]
fn inline_imported_constant() {
// ---------------------
// module `foo`
// --------------------
// const field FOO = 42
// const field BAR = FOO
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// from "foo" import BAR
//
// def main() -> field:
// return FOO
// Should be resolved to
// ---------------------
// module `foo`
// --------------------
// const field BAR = ./foo.zok/FOO
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// const field FOO = 42
//
// def main() -> field:
// return FOO
let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into());
let bar_const_id = CanonicalConstantIdentifier::new("BAR", "foo".into());
let foo_module = TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
foo_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedConstantSymbolDeclaration::new(
bar_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Identifier(
foo_const_id.clone().into(),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("foo", "main")
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![],
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
}),
)
.into(),
],
};
let main_const_id = CanonicalConstantIdentifier::new("FOO", "main".into());
let main_module = TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
main_const_id.clone(),
TypedConstantSymbol::There(bar_const_id),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(
main_const_id.clone(),
))
.into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)
.into(),
],
};
let program = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), main_module),
("foo".into(), foo_module.clone()),
]
.into_iter()
.collect(),
};
let program = ConstantResolver::inline(program);
let expected_main_module = TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
main_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Identifier(
foo_const_id.into(),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(
main_const_id.clone(),
))
.into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)
.into(),
],
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), expected_main_module),
("foo".into(), foo_module),
]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_imported_constant_with_generics() {
// ---------------------
// module `foo`
// --------------------
// const field FOO = 2
// const field[FOO] BAR = [1; FOO]
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// from "foo" import FOO
// from "foo" import BAR
// const field[FOO] BAZ = BAR
//
// def main() -> field:
// return FOO
// Should be resolved to
// ---------------------
// module `foo`
// --------------------
// const field FOO = 2
// const field[FOO] BAR = [1; FOO]
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// const FOO = 2
// const BAR = [1; ./foo.zok/FOO]
// const field[FOO] BAZ = BAR
//
// def main() -> field:
// return FOO
let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into());
let bar_const_id = CanonicalConstantIdentifier::new("BAR", "foo".into());
let foo_module = TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
foo_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2),
)),
DeclarationType::FieldElement,
)),
)
.into(),
TypedConstantSymbolDeclaration::new(
bar_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Repeat(
box FieldElementExpression::Number(Bn128Field::from(1)).into(),
box UExpression::from(foo_const_id.clone()),
)
.annotate(Type::FieldElement, foo_const_id.clone()),
),
DeclarationType::Array(DeclarationArrayType::new(
DeclarationType::FieldElement,
DeclarationConstant::Constant(foo_const_id.clone()),
)),
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("foo", "main")
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![],
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
}),
)
.into(),
],
};
let main_foo_const_id = CanonicalConstantIdentifier::new("FOO", "main".into());
let main_bar_const_id = CanonicalConstantIdentifier::new("BAR", "main".into());
let main_baz_const_id = CanonicalConstantIdentifier::new("BAZ", "main".into());
let main_module = TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
main_foo_const_id.clone(),
TypedConstantSymbol::There(foo_const_id.clone()),
)
.into(),
TypedConstantSymbolDeclaration::new(
main_bar_const_id.clone(),
TypedConstantSymbol::There(bar_const_id),
)
.into(),
TypedConstantSymbolDeclaration::new(
main_baz_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Identifier(main_bar_const_id.clone().into())
.annotate(Type::FieldElement, main_foo_const_id.clone()),
),
DeclarationType::Array(DeclarationArrayType::new(
DeclarationType::FieldElement,
DeclarationConstant::Constant(foo_const_id.clone()),
)),
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(
main_foo_const_id.clone(),
))
.into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)
.into(),
],
};
let program = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), main_module),
("foo".into(), foo_module.clone()),
]
.into_iter()
.collect(),
};
let program = ConstantResolver::inline(program);
let expected_main_module = TypedModule {
symbols: vec![
TypedConstantSymbolDeclaration::new(
main_foo_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
FieldElementExpression::Number(Bn128Field::from(2)).into(),
DeclarationType::FieldElement,
)),
)
.into(),
TypedConstantSymbolDeclaration::new(
main_bar_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Repeat(
box FieldElementExpression::Number(Bn128Field::from(1)).into(),
box UExpression::from(foo_const_id.clone()),
)
.annotate(Type::FieldElement, foo_const_id.clone()),
),
DeclarationType::Array(DeclarationArrayType::new(
DeclarationType::FieldElement,
DeclarationConstant::Constant(foo_const_id.clone()),
)),
)),
)
.into(),
TypedConstantSymbolDeclaration::new(
main_baz_const_id.clone(),
TypedConstantSymbol::Here(TypedConstant::new(
TypedExpression::Array(
ArrayExpressionInner::Identifier(main_bar_const_id.into())
.annotate(Type::FieldElement, main_foo_const_id.clone()),
),
DeclarationType::Array(DeclarationArrayType::new(
DeclarationType::FieldElement,
DeclarationConstant::Constant(foo_const_id.clone()),
)),
)),
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(
main_foo_const_id.clone(),
))
.into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)
.into(),
],
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), expected_main_module),
("foo".into(), foo_module),
]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
}

View file

@ -126,7 +126,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_declaration_parameter(
&mut self,
p: typed_absy::DeclarationParameter<'ast>,
p: typed_absy::DeclarationParameter<'ast, T>,
) -> Vec<zir::Parameter<'ast>> {
let private = p.private;
self.fold_variable(crate::typed_absy::variable::try_from_g_variable(p.id).unwrap())
@ -1093,7 +1093,7 @@ fn fold_function<'ast, T: Field>(
statements: main_statements_buffer,
signature: typed_absy::types::ConcreteSignature::try_from(
crate::typed_absy::types::try_from_g_signature::<
crate::typed_absy::types::DeclarationConstant<'ast>,
crate::typed_absy::types::DeclarationConstant<'ast, T>,
crate::typed_absy::UExpression<'ast, T>,
>(fun.signature)
.unwrap(),
@ -1139,14 +1139,12 @@ fn fold_program<'ast, T: Field>(
let main_module = p.modules.remove(&p.main).unwrap();
let main_function = main_module
.functions
.into_iter()
.find(|(key, _)| key.id == "main")
.into_functions_iter()
.find(|d| d.key.id == "main")
.unwrap()
.1;
.symbol;
let main_function = match main_function {
typed_absy::TypedFunctionSymbol::Here(f) => f,
typed_absy::TypedFunctionSymbol::Here(main) => main,
_ => unreachable!(),
};

View file

@ -6,7 +6,7 @@
mod branch_isolator;
mod constant_argument_checker;
mod constant_inliner;
mod constant_resolver;
mod flat_propagation;
mod flatten_complex_types;
mod out_of_bounds;
@ -28,7 +28,7 @@ use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_write_remover::VariableWriteRemover;
use crate::compile::CompileConfig;
use crate::ir::Prog;
use crate::static_analysis::constant_inliner::ConstantInliner;
use crate::static_analysis::constant_resolver::ConstantResolver;
use crate::static_analysis::zir_propagation::ZirPropagator;
use crate::typed_absy::{abi::Abi, TypedProgram};
use crate::zir::ZirProgram;
@ -48,17 +48,10 @@ pub enum Error {
Propagation(self::propagation::Error),
ZirPropagation(self::zir_propagation::Error),
NonConstantArgument(self::constant_argument_checker::Error),
ConstantInliner(self::constant_inliner::Error),
UnconstrainedVariable(self::unconstrained_vars::Error),
OutOfBounds(self::out_of_bounds::Error),
}
impl From<constant_inliner::Error> for Error {
fn from(e: self::constant_inliner::Error) -> Self {
Error::ConstantInliner(e)
}
}
impl From<reducer::Error> for Error {
fn from(e: self::reducer::Error) -> Self {
Error::Reducer(e)
@ -102,7 +95,6 @@ impl fmt::Display for Error {
Error::Propagation(e) => write!(f, "{}", e),
Error::ZirPropagation(e) => write!(f, "{}", e),
Error::NonConstantArgument(e) => write!(f, "{}", e),
Error::ConstantInliner(e) => write!(f, "{}", e),
Error::UnconstrainedVariable(e) => write!(f, "{}", e),
Error::OutOfBounds(e) => write!(f, "{}", e),
}
@ -113,7 +105,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
// inline user-defined constants
log::debug!("Static analyser: Inline constants");
let r = ConstantInliner::inline(self).map_err(Error::from)?;
let r = ConstantResolver::inline(self);
log::trace!("\n{}", r);
// isolate branches

View file

@ -16,7 +16,7 @@ use std::convert::{TryFrom, TryInto};
use std::fmt;
use zokrates_field::Field;
type Constants<'ast, T> = HashMap<Identifier<'ast>, TypedExpression<'ast, T>>;
pub type Constants<'ast, T> = HashMap<Identifier<'ast>, TypedExpression<'ast, T>>;
#[derive(Debug, PartialEq)]
pub enum Error {
@ -45,6 +45,7 @@ impl fmt::Display for Error {
}
}
#[derive(Debug)]
pub struct Propagator<'ast, 'a, T: Field> {
// constants keeps track of constant expressions
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
@ -149,21 +150,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
})
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> Result<TypedModule<'ast, T>, Error> {
Ok(TypedModule {
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
if key.id == "main" {
self.fold_function_symbol(fun).map(|f| (key, f))
} else {
Ok((key, fun))
}
})
.collect::<Result<_, _>>()?,
..m
})
fn fold_function_symbol_declaration(
&mut self,
s: TypedFunctionSymbolDeclaration<'ast, T>,
) -> Result<TypedFunctionSymbolDeclaration<'ast, T>, Error> {
if s.key.id == "main" {
let key = s.key;
self.fold_function_symbol(s.symbol)
.map(|f| TypedFunctionSymbolDeclaration { key, symbol: f })
} else {
Ok(s)
}
}
fn fold_function(

View file

@ -0,0 +1,169 @@
// given a (partial) map of values for program constants, replace where applicable constants by their value
use crate::static_analysis::reducer::ConstantDefinitions;
use crate::typed_absy::{
folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
DeclarationConstant, Expr, FieldElementExpression, Identifier, StructExpression,
StructExpressionInner, StructType, TypedProgram, TypedSymbolDeclaration, UBitwidth,
UExpression, UExpressionInner,
};
use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
pub struct ConstantsReader<'a, 'ast, T> {
constants: &'a ConstantDefinitions<'ast, T>,
}
impl<'a, 'ast, T: Field> ConstantsReader<'a, 'ast, T> {
pub fn with_constants(constants: &'a ConstantDefinitions<'ast, T>) -> Self {
Self { constants }
}
pub fn read_into_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
self.fold_program(p)
}
pub fn read_into_symbol_declaration(
&mut self,
d: TypedSymbolDeclaration<'ast, T>,
) -> TypedSymbolDeclaration<'ast, T> {
self.fold_symbol_declaration(d)
}
}
impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => v.try_into().unwrap(),
None => FieldElementExpression::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}),
}
}
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => v.try_into().unwrap(),
None => BooleanExpression::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}),
}
}
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
ty: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => UExpression::try_from(v).unwrap().into_inner(),
None => UExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}),
}
}
e => fold_uint_expression_inner(self, ty, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => ArrayExpression::try_from(v).unwrap().into_inner(),
None => ArrayExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}),
}
}
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}) => {
assert_eq!(version, 0);
match self.constants.get(&c).cloned() {
Some(v) => StructExpression::try_from(v).unwrap().into_inner(),
None => StructExpressionInner::Identifier(Identifier {
id: CoreIdentifier::Constant(c),
version,
}),
}
}
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast, T>,
) -> DeclarationConstant<'ast, T> {
match c {
DeclarationConstant::Constant(c) => {
let c = self.fold_canonical_constant_identifier(c);
match self.constants.get(&c).cloned() {
Some(e) => match UExpression::try_from(e).unwrap().into_inner() {
UExpressionInner::Value(v) => DeclarationConstant::Concrete(v as u32),
_ => unreachable!(),
},
None => DeclarationConstant::Constant(c),
}
}
c => fold_declaration_constant(self, c),
}
}
}

View file

@ -0,0 +1,163 @@
// A folder to inline all constant definitions down to a single literal and register them in the state for later use.
use crate::static_analysis::reducer::{
constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error,
};
use crate::typed_absy::{
result_folder::*, types::ConcreteGenericsAssignment, OwnedTypedModuleId, TypedConstant,
TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram,
TypedSymbolDeclaration, UExpression,
};
use std::collections::{BTreeMap, HashSet};
use zokrates_field::Field;
pub struct ConstantsWriter<'ast, T> {
treated: HashSet<OwnedTypedModuleId>,
constants: ConstantDefinitions<'ast, T>,
location: OwnedTypedModuleId,
program: TypedProgram<'ast, T>,
}
impl<'ast, T: Field> ConstantsWriter<'ast, T> {
pub fn with_program(program: TypedProgram<'ast, T>) -> Self {
ConstantsWriter {
constants: ConstantDefinitions::default(),
location: program.main.clone(),
treated: HashSet::default(),
program,
}
}
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
self.treated.insert(self.location.clone());
prev
}
fn treated(&self, id: &TypedModuleId) -> bool {
self.treated.contains(id)
}
fn update_program(&mut self) {
let mut p = TypedProgram {
main: "".into(),
modules: BTreeMap::default(),
};
std::mem::swap(&mut self.program, &mut p);
self.program = ConstantsReader::with_constants(&self.constants).read_into_program(p);
}
fn update_symbol_declaration(
&self,
d: TypedSymbolDeclaration<'ast, T>,
) -> TypedSymbolDeclaration<'ast, T> {
ConstantsReader::with_constants(&self.constants).read_into_symbol_declaration(d)
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> {
type Error = Error;
fn fold_module_id(
&mut self,
id: OwnedTypedModuleId,
) -> Result<OwnedTypedModuleId, Self::Error> {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
// I did not find a way to achieve this without cloning the module. Assuming we do not clone:
// to fold the module, we need to consume it, so it gets removed from the modules
// but to inline the calls while folding the module, all modules must be present
// therefore we clone...
// this does not lead to a module being folded more than once, as the first time
// we change location to this module, it's added to the `treated` set
let m = self.program.modules.get(&id).cloned().unwrap();
let m = self.fold_module(m)?;
self.program.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
Ok(id)
}
fn fold_symbol_declaration(
&mut self,
s: TypedSymbolDeclaration<'ast, T>,
) -> Result<TypedSymbolDeclaration<'ast, T>, Self::Error> {
// before we treat the symbol, propagate the constants into it, as may be using constants defined earlier in this module.
let s = self.update_symbol_declaration(s);
let s = fold_symbol_declaration(self, s)?;
// after we treat the symbol, propagate again, as treating this symbol may have triggered checking another module, resolving new constants which this symbol may be using.
Ok(self.update_symbol_declaration(s))
}
fn fold_constant_symbol_declaration(
&mut self,
d: TypedConstantSymbolDeclaration<'ast, T>,
) -> Result<TypedConstantSymbolDeclaration<'ast, T>, Self::Error> {
let id = self.fold_canonical_constant_identifier(d.id)?;
match d.symbol {
TypedConstantSymbol::Here(c) => {
let c = self.fold_constant(c)?;
use crate::typed_absy::{DeclarationSignature, TypedFunction, TypedStatement};
// wrap this expression in a function
let wrapper = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![c.expression])],
signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]),
};
let mut inlined_wrapper = reduce_function(
wrapper,
ConcreteGenericsAssignment::default(),
&self.program,
)?;
if let TypedStatement::Return(mut expressions) =
inlined_wrapper.statements.pop().unwrap()
{
assert_eq!(expressions.len(), 1);
let constant_expression = expressions.pop().unwrap();
use crate::typed_absy::Constant;
if !constant_expression.is_constant() {
return Err(Error::ConstantReduction(id.id.to_string(), id.module));
};
use crate::typed_absy::Typed;
if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(
c.ty.clone(),
)
.unwrap()
== constant_expression.get_type()
{
// add to the constant map
self.constants
.insert(id.clone(), constant_expression.clone());
// after we reduced a constant, propagate it through the whole program
self.update_program();
Ok(TypedConstantSymbolDeclaration {
id,
symbol: TypedConstantSymbol::Here(TypedConstant {
expression: constant_expression,
ty: c.ty,
}),
})
} else {
Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, c.ty)))
}
} else {
Err(Error::ConstantReduction(id.id.to_string(), id.module))
}
}
_ => unreachable!("all constants should be local"),
}
}
}

View file

@ -35,13 +35,13 @@ use crate::typed_absy::Identifier;
use crate::typed_absy::TypedAssignee;
use crate::typed_absy::{
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
Signature, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, Types,
UExpression, UExpressionInner, Variable,
Signature, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedProgram,
TypedStatement, Types, UExpression, UExpressionInner, Variable,
};
use zokrates_field::Field;
pub enum InlineError<'ast, T> {
Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>),
Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>),
Flat(
FlatEmbed,
Vec<u32>,
@ -49,7 +49,7 @@ pub enum InlineError<'ast, T> {
Types<'ast, T>,
),
NonConstant(
DeclarationFunctionKey<'ast>,
DeclarationFunctionKey<'ast, T>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Types<'ast, T>,
@ -57,20 +57,20 @@ pub enum InlineError<'ast, T> {
}
fn get_canonical_function<'ast, T: Field>(
function_key: DeclarationFunctionKey<'ast>,
function_key: DeclarationFunctionKey<'ast, T>,
program: &TypedProgram<'ast, T>,
) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) {
match program
) -> TypedFunctionSymbolDeclaration<'ast, T> {
let s = program
.modules
.get(&function_key.module)
.unwrap()
.functions
.iter()
.find(|(key, _)| function_key == **key)
.unwrap()
{
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program),
(key, s) => (key.clone(), s.clone()),
.functions_iter()
.find(|d| d.key == function_key)
.unwrap();
match &s.symbol {
TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), &program),
_ => s.clone(),
}
}
@ -80,7 +80,7 @@ type InlineResult<'ast, T> = Result<
>;
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
k: DeclarationFunctionKey<'ast>,
k: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output: &E::Ty,
@ -134,7 +134,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
}
};
let (decl_key, symbol) = get_canonical_function(k.clone(), program);
let decl = get_canonical_function(k.clone(), program);
// get an assignment of generics for this call site
let assignment: ConcreteGenericsAssignment<'ast> = k
@ -144,18 +144,18 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
InlineError::Generic(
k.clone(),
ConcreteFunctionKey {
module: decl_key.module.clone(),
id: decl_key.id,
module: decl.key.module.clone(),
id: decl.key.id,
signature: inferred_signature.clone(),
},
)
})?;
let f = match symbol {
let f = match decl.symbol {
TypedFunctionSymbol::Here(f) => Ok(f),
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
e,
e.generics(&assignment),
e.generics::<T>(&assignment),
arguments.clone(),
output_types,
)),
@ -169,7 +169,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
};
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone());
let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone());
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
.arguments

View file

@ -11,6 +11,8 @@
// - unroll loops
// - inline function calls. This includes applying shallow-ssa on the target function
mod constants_reader;
mod constants_writer;
mod inline;
mod shallow_ssa;
@ -18,26 +20,33 @@ use self::inline::{inline_call, InlineError};
use crate::typed_absy::result_folder::*;
use crate::typed_absy::types::ConcreteGenericsAssignment;
use crate::typed_absy::types::GGenericsAssignment;
use crate::typed_absy::CanonicalConstantIdentifier;
use crate::typed_absy::Folder;
use std::collections::HashMap;
use crate::typed_absy::{
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, TypedExpression,
TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule,
TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable,
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId,
TypedExpression, TypedExpressionList, TypedExpressionListInner, TypedFunction,
TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, TypedStatement,
UExpression, UExpressionInner, Variable,
};
use zokrates_field::Field;
use self::constants_writer::ConstantsWriter;
use self::shallow_ssa::ShallowTransformer;
use crate::static_analysis::Propagator;
use crate::static_analysis::propagation::{Constants, Propagator};
use std::fmt;
const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
// A map to register the canonical value of all constants. The values must be literals.
pub type ConstantDefinitions<'ast, T> =
HashMap<CanonicalConstantIdentifier<'ast>, TypedExpression<'ast, T>>;
// An SSA version map, giving access to the latest version number for each identifier
pub type Versions<'ast> = HashMap<CoreIdentifier<'ast>, usize>;
@ -55,6 +64,8 @@ pub enum Error {
// TODO: give more details about what's blocking the progress
NoProgress,
LoopTooLarge(u128),
ConstantReduction(String, OwnedTypedModuleId),
Type(String),
}
impl fmt::Display for Error {
@ -68,6 +79,8 @@ impl fmt::Display for Error {
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()),
Error::Type(message) => write!(f, "{}", message),
}
}
}
@ -159,6 +172,7 @@ fn register<'ast>(
}
}
#[derive(Debug)]
struct Reducer<'ast, 'a, T> {
statement_buffer: Vec<TypedStatement<'ast, T>>,
for_loop_versions: Vec<Versions<'ast>>,
@ -304,6 +318,13 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
})
}
fn fold_canonical_constant_identifier(
&mut self,
_: CanonicalConstantIdentifier<'ast>,
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
unreachable!("canonical constant identifiers should not be folded, they should be inlined")
}
fn fold_statement(
&mut self,
s: TypedStatement<'ast, T>,
@ -487,15 +508,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
// inline all constants and replace them in the program
let mut constants_writer = ConstantsWriter::with_program(p.clone());
let p = constants_writer.fold_program(p)?;
// inline starting from main
let main_module = p.modules.get(&p.main).unwrap().clone();
let (main_key, main_function) = main_module
.functions
.iter()
.find(|(k, _)| k.id == "main")
let decl = main_module
.functions_iter()
.find(|d| d.key.id == "main")
.unwrap();
let main_function = match main_function {
let main_function = match &decl.symbol {
TypedFunctionSymbol::Here(f) => f.clone(),
_ => unreachable!(),
};
@ -509,13 +536,11 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
modules: vec![(
p.main.clone(),
TypedModule {
functions: vec![(
main_key.clone(),
symbols: vec![TypedFunctionSymbolDeclaration::new(
decl.key.clone(),
TypedFunctionSymbol::Here(main_function),
)]
.into_iter()
.collect(),
constants: Default::default(),
)
.into()],
},
)]
.into_iter()
@ -533,7 +558,9 @@ fn reduce_function<'ast, T: Field>(
) -> Result<TypedFunction<'ast, T>, Error> {
let mut versions = Versions::default();
match ShallowTransformer::transform(f, &generics, &mut versions) {
let mut constants = Constants::default();
let f = match ShallowTransformer::transform(f, &generics, &mut versions) {
Output::Complete(f) => Ok(f),
Output::Incomplete(new_f, new_for_loop_versions) => {
let mut for_loop_versions = new_for_loop_versions;
@ -542,8 +569,6 @@ fn reduce_function<'ast, T: Field>(
let mut substitutions = Substitutions::default();
let mut constants: HashMap<Identifier<'ast>, TypedExpression<'ast, T>> = HashMap::new();
let mut hash = None;
loop {
@ -600,7 +625,11 @@ fn reduce_function<'ast, T: Field>(
}
}
}
}
}?;
Propagator::with_constants(&mut constants)
.fold_function(f)
.map_err(|e| Error::Incompatible(format!("{}", e)))
}
fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
@ -710,27 +739,26 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![
(
symbols: vec![
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(foo),
),
(
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
),
]
.into_iter()
.collect(),
constants: Default::default(),
)
.into(),
],
},
)]
.into_iter()
@ -786,17 +814,15 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
symbols: vec![TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: Default::default(),
)
.into()],
},
)]
.into_iter()
@ -913,24 +939,23 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![
(
symbols: vec![
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
),
]
.into_iter()
.collect(),
constants: Default::default(),
)
.into(),
],
},
)]
.into_iter()
@ -1005,17 +1030,15 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
symbols: vec![TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: Default::default(),
)
.into()],
},
)]
.into_iter()
@ -1141,24 +1164,23 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![
(
symbols: vec![
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
),
]
.into_iter()
.collect(),
constants: Default::default(),
)
.into(),
],
},
)]
.into_iter()
@ -1233,17 +1255,15 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
symbols: vec![TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: Default::default(),
)
.into()],
},
)]
.into_iter()
@ -1399,25 +1419,25 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![
(
symbols: vec![
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "bar")
.signature(bar_signature.clone()),
TypedFunctionSymbol::Here(bar),
),
(
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main"),
TypedFunctionSymbol::Here(main),
),
]
.into_iter()
.collect(),
constants: Default::default(),
)
.into(),
],
},
)]
.into_iter()
@ -1459,14 +1479,12 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
symbols: vec![TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main")
.signature(DeclarationSignature::new()),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: Default::default(),
)
.into()],
},
)]
.into_iter()
@ -1540,22 +1558,21 @@ mod tests {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![
(
symbols: vec![
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
TypedFunctionSymbol::Here(foo),
),
(
)
.into(),
TypedFunctionSymbolDeclaration::new(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
),
TypedFunctionSymbol::Here(main),
),
]
.into_iter()
.collect(),
constants: Default::default(),
)
.into(),
],
},
)]
.into_iter()

View file

@ -35,15 +35,15 @@ mod tests {
};
use crate::typed_absy::{
parameter::DeclarationParameter, variable::DeclarationVariable, ConcreteType,
TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram,
TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule,
TypedProgram,
};
use std::collections::HashMap;
use std::collections::BTreeMap;
use zokrates_field::Bn128Field;
#[test]
fn generate_abi_from_typed_ast() {
let mut functions = HashMap::new();
functions.insert(
let symbols = vec![TypedFunctionSymbolDeclaration::new(
ConcreteFunctionKey::with_location("main", "main").into(),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![
@ -62,16 +62,11 @@ mod tests {
.outputs(vec![ConcreteType::FieldElement])
.into(),
}),
);
)
.into()];
let mut modules = HashMap::new();
modules.insert(
"main".into(),
TypedModule {
functions,
constants: Default::default(),
},
);
let mut modules = BTreeMap::new();
modules.insert("main".into(), TypedModule { symbols });
let typed_ast: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),

View file

@ -47,6 +47,27 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_module(self, m)
}
fn fold_symbol_declaration(
&mut self,
s: TypedSymbolDeclaration<'ast, T>,
) -> TypedSymbolDeclaration<'ast, T> {
fold_symbol_declaration(self, s)
}
fn fold_function_symbol_declaration(
&mut self,
s: TypedFunctionSymbolDeclaration<'ast, T>,
) -> TypedFunctionSymbolDeclaration<'ast, T> {
fold_function_symbol_declaration(self, s)
}
fn fold_constant_symbol_declaration(
&mut self,
s: TypedConstantSymbolDeclaration<'ast, T>,
) -> TypedConstantSymbolDeclaration<'ast, T> {
fold_constant_symbol_declaration(self, s)
}
fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> {
fold_constant(self, c)
}
@ -67,8 +88,8 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_declaration_function_key(
&mut self,
key: DeclarationFunctionKey<'ast>,
) -> DeclarationFunctionKey<'ast> {
key: DeclarationFunctionKey<'ast, T>,
) -> DeclarationFunctionKey<'ast, T> {
fold_declaration_function_key(self, key)
}
@ -76,18 +97,24 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_function(self, f)
}
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
fn fold_signature(
&mut self,
s: DeclarationSignature<'ast, T>,
) -> DeclarationSignature<'ast, T> {
fold_signature(self, s)
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
c: DeclarationConstant<'ast, T>,
) -> DeclarationConstant<'ast, T> {
fold_declaration_constant(self, c)
}
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast, T>,
) -> DeclarationParameter<'ast, T> {
DeclarationParameter {
id: self.fold_declaration_variable(p.id),
..p
@ -107,8 +134,8 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_declaration_variable(
&mut self,
v: DeclarationVariable<'ast>,
) -> DeclarationVariable<'ast> {
v: DeclarationVariable<'ast, T>,
) -> DeclarationVariable<'ast, T> {
DeclarationVariable {
id: self.fold_name(v.id),
_type: self.fold_declaration_type(v._type),
@ -155,7 +182,7 @@ pub trait Folder<'ast, T: Field>: Sized {
}
}
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
fn fold_declaration_type(&mut self, t: DeclarationType<'ast, T>) -> DeclarationType<'ast, T> {
use self::GType::*;
match t {
@ -167,8 +194,8 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_declaration_array_type(
&mut self,
t: DeclarationArrayType<'ast>,
) -> DeclarationArrayType<'ast> {
t: DeclarationArrayType<'ast, T>,
) -> DeclarationArrayType<'ast, T> {
DeclarationArrayType {
ty: box self.fold_declaration_type(*t.ty),
size: self.fold_declaration_constant(t.size),
@ -177,8 +204,8 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_declaration_struct_type(
&mut self,
t: DeclarationStructType<'ast>,
) -> DeclarationStructType<'ast> {
t: DeclarationStructType<'ast, T>,
) -> DeclarationStructType<'ast, T> {
DeclarationStructType {
generics: t
.generics
@ -232,7 +259,6 @@ pub trait Folder<'ast, T: Field>: Sized {
CanonicalConstantIdentifier {
module: self.fold_module_id(i.module),
id: i.id,
ty: box self.fold_declaration_type(*i.ty),
}
}
@ -378,29 +404,48 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
m: TypedModule<'ast, T>,
) -> TypedModule<'ast, T> {
TypedModule {
constants: m
.constants
symbols: m
.symbols
.into_iter()
.map(|(id, tc)| {
(
f.fold_canonical_constant_identifier(id),
f.fold_constant_symbol(tc),
)
})
.collect(),
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
(
f.fold_declaration_function_key(key),
f.fold_function_symbol(fun),
)
})
.map(|s| f.fold_symbol_declaration(s))
.collect(),
}
}
pub fn fold_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
d: TypedSymbolDeclaration<'ast, T>,
) -> TypedSymbolDeclaration<'ast, T> {
match d {
TypedSymbolDeclaration::Function(d) => {
TypedSymbolDeclaration::Function(f.fold_function_symbol_declaration(d))
}
TypedSymbolDeclaration::Constant(d) => {
TypedSymbolDeclaration::Constant(f.fold_constant_symbol_declaration(d))
}
}
}
pub fn fold_function_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
d: TypedFunctionSymbolDeclaration<'ast, T>,
) -> TypedFunctionSymbolDeclaration<'ast, T> {
TypedFunctionSymbolDeclaration {
key: f.fold_declaration_function_key(d.key),
symbol: f.fold_function_symbol(d.symbol),
}
}
pub fn fold_constant_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
d: TypedConstantSymbolDeclaration<'ast, T>,
) -> TypedConstantSymbolDeclaration<'ast, T> {
TypedConstantSymbolDeclaration {
id: f.fold_canonical_constant_identifier(d.id),
symbol: f.fold_constant_symbol(d.symbol),
}
}
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedStatement<'ast, T>,
@ -902,8 +947,8 @@ pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T
pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
key: DeclarationFunctionKey<'ast>,
) -> DeclarationFunctionKey<'ast> {
key: DeclarationFunctionKey<'ast, T>,
) -> DeclarationFunctionKey<'ast, T> {
DeclarationFunctionKey {
module: f.fold_module_id(key.module),
signature: f.fold_signature(key.signature),
@ -955,8 +1000,8 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: DeclarationSignature<'ast>,
) -> DeclarationSignature<'ast> {
s: DeclarationSignature<'ast, T>,
) -> DeclarationSignature<'ast, T> {
DeclarationSignature {
generics: s.generics,
inputs: s
@ -972,11 +1017,14 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
}
}
fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
_: &mut F,
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
c
pub fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
c: DeclarationConstant<'ast, T>,
) -> DeclarationConstant<'ast, T> {
match c {
DeclarationConstant::Expression(e) => DeclarationConstant::Expression(f.fold_expression(e)),
c => c,
}
}
pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
@ -1056,6 +1104,7 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
) -> TypedConstant<'ast, T> {
TypedConstant {
expression: f.fold_expression(c.expression),
ty: f.fold_declaration_type(c.ty),
}
}

View file

@ -1,10 +1,12 @@
use crate::typed_absy::CanonicalConstantIdentifier;
use std::convert::TryInto;
use std::fmt;
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
pub enum CoreIdentifier<'ast> {
Source(&'ast str),
Call(usize),
Constant(CanonicalConstantIdentifier<'ast>),
}
impl<'ast> fmt::Display for CoreIdentifier<'ast> {
@ -12,6 +14,7 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
match self {
CoreIdentifier::Source(s) => write!(f, "{}", s),
CoreIdentifier::Call(i) => write!(f, "#CALL_RETURN_AT_INDEX_{}", i),
CoreIdentifier::Constant(c) => write!(f, "{}/{}", c.module.display(), c.id),
}
}
}
@ -22,8 +25,14 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> {
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for CoreIdentifier<'ast> {
fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> {
CoreIdentifier::Constant(s)
}
}
/// A identifier for a variable
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
pub struct Identifier<'ast> {
/// the id of the variable
pub id: CoreIdentifier<'ast>,
@ -52,6 +61,12 @@ impl<'ast> fmt::Display for Identifier<'ast> {
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for Identifier<'ast> {
fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> {
Identifier::from(CoreIdentifier::Constant(id))
}
}
impl<'ast> From<&'ast str> for Identifier<'ast> {
fn from(id: &'ast str) -> Identifier<'ast> {
Identifier::from(CoreIdentifier::Source(id))

View file

@ -40,7 +40,7 @@ trait IntegerInference: Sized {
}
impl<'ast, T> IntegerInference for Type<'ast, T> {
type Pattern = DeclarationType<'ast>;
type Pattern = DeclarationType<'ast, T>;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
match (self, other) {
@ -72,7 +72,7 @@ impl<'ast, T> IntegerInference for Type<'ast, T> {
}
impl<'ast, T> IntegerInference for ArrayType<'ast, T> {
type Pattern = DeclarationArrayType<'ast>;
type Pattern = DeclarationArrayType<'ast, T>;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
let s0 = self.size;
@ -88,7 +88,7 @@ impl<'ast, T> IntegerInference for ArrayType<'ast, T> {
}
impl<'ast, T> IntegerInference for StructType<'ast, T> {
type Pattern = DeclarationStructType<'ast>;
type Pattern = DeclarationStructType<'ast, T>;
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
Ok(DeclarationStructType {
@ -228,7 +228,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum IntExpression<'ast, T> {
Value(BigUint),
Pos(Box<IntExpression<'ast, T>>),
@ -424,7 +424,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
v,
&DeclarationArrayType::new(
DeclarationType::FieldElement,
DeclarationConstant::Concrete(0),
DeclarationConstant::from(0u32),
),
)
.map_err(|(e, _)| match e {
@ -542,7 +542,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
v,
&DeclarationArrayType::new(
DeclarationType::Uint(*bitwidth),
DeclarationConstant::Concrete(0),
DeclarationConstant::from(0u32),
),
)
.map_err(|(e, _)| match e {

View file

@ -20,9 +20,9 @@ pub use self::identifier::CoreIdentifier;
pub use self::parameter::{DeclarationParameter, GParameter};
pub use self::types::{
CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType,
ConstantIdentifier, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
ConstantIdentifier, DeclarationArrayType, DeclarationConstant, DeclarationFunctionKey,
DeclarationSignature, DeclarationStructType, DeclarationType, GArrayType, GStructType, GType,
GenericIdentifier, IntoTypes, Signature, StructType, Type, Types, UBitwidth,
};
use crate::typed_absy::types::ConcreteGenericsAssignment;
@ -35,7 +35,7 @@ pub use crate::typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMeta
use crate::embed::FlatEmbed;
use std::collections::HashMap;
use std::collections::BTreeMap;
use std::convert::{TryFrom, TryInto};
use std::fmt;
@ -54,14 +54,14 @@ pub type OwnedTypedModuleId = PathBuf;
pub type TypedModuleId = Path;
/// A collection of `TypedModule`s
pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T>>;
pub type TypedModules<'ast, T> = BTreeMap<OwnedTypedModuleId, TypedModule<'ast, T>>;
/// A collection of `TypedFunctionSymbol`s
/// # Remarks
/// * It is the role of the semantic checker to make sure there are no duplicates for a given `FunctionKey`
/// in a given `TypedModule`, hence the use of a HashMap
/// in a given `TypedModule`, hence the use of a BTreeMap
pub type TypedFunctionSymbols<'ast, T> =
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
BTreeMap<DeclarationFunctionKey<'ast, T>, TypedFunctionSymbol<'ast, T>>;
#[derive(Clone, Debug, PartialEq)]
pub enum TypedConstantSymbol<'ast, T> {
@ -91,12 +91,11 @@ impl<'ast, T> TypedProgram<'ast, T> {
impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn abi(&self) -> Abi {
let main = self.modules[&self.main]
.functions
.iter()
.find(|(id, _)| id.id == "main")
let main = &self.modules[&self.main]
.functions_iter()
.find(|s| s.key.id == "main")
.unwrap()
.1;
.symbol;
let main = match main {
TypedFunctionSymbol::Here(main) => main,
_ => unreachable!(),
@ -109,7 +108,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
.map(|p| {
types::ConcreteType::try_from(
crate::typed_absy::types::try_from_g_type::<
crate::typed_absy::types::DeclarationConstant<'ast>,
DeclarationConstant<'ast, T>,
UExpression<'ast, T>,
>(p.id._type.clone())
.unwrap(),
@ -129,7 +128,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
.map(|ty| {
types::ConcreteType::try_from(
crate::typed_absy::types::try_from_g_type::<
crate::typed_absy::types::DeclarationConstant<'ast>,
DeclarationConstant<'ast, T>,
UExpression<'ast, T>,
>(ty.clone())
.unwrap(),
@ -163,19 +162,90 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> {
}
}
#[derive(PartialEq, Debug, Clone)]
pub struct TypedFunctionSymbolDeclaration<'ast, T> {
pub key: DeclarationFunctionKey<'ast, T>,
pub symbol: TypedFunctionSymbol<'ast, T>,
}
impl<'ast, T> TypedFunctionSymbolDeclaration<'ast, T> {
pub fn new(key: DeclarationFunctionKey<'ast, T>, symbol: TypedFunctionSymbol<'ast, T>) -> Self {
TypedFunctionSymbolDeclaration { key, symbol }
}
}
#[derive(PartialEq, Debug, Clone)]
pub struct TypedConstantSymbolDeclaration<'ast, T> {
pub id: CanonicalConstantIdentifier<'ast>,
pub symbol: TypedConstantSymbol<'ast, T>,
}
impl<'ast, T> TypedConstantSymbolDeclaration<'ast, T> {
pub fn new(
id: CanonicalConstantIdentifier<'ast>,
symbol: TypedConstantSymbol<'ast, T>,
) -> Self {
TypedConstantSymbolDeclaration { id, symbol }
}
}
#[derive(PartialEq, Debug, Clone)]
pub enum TypedSymbolDeclaration<'ast, T> {
Function(TypedFunctionSymbolDeclaration<'ast, T>),
Constant(TypedConstantSymbolDeclaration<'ast, T>),
}
impl<'ast, T> From<TypedFunctionSymbolDeclaration<'ast, T>> for TypedSymbolDeclaration<'ast, T> {
fn from(d: TypedFunctionSymbolDeclaration<'ast, T>) -> Self {
Self::Function(d)
}
}
impl<'ast, T> From<TypedConstantSymbolDeclaration<'ast, T>> for TypedSymbolDeclaration<'ast, T> {
fn from(d: TypedConstantSymbolDeclaration<'ast, T>) -> Self {
Self::Constant(d)
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedSymbolDeclaration<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TypedSymbolDeclaration::Function(fun) => write!(f, "{}", fun),
TypedSymbolDeclaration::Constant(c) => write!(f, "{}", c),
}
}
}
pub type TypedSymbolDeclarations<'ast, T> = Vec<TypedSymbolDeclaration<'ast, T>>;
/// A typed module as a collection of functions. Types have been resolved during semantic checking.
#[derive(PartialEq, Debug, Clone)]
pub struct TypedModule<'ast, T> {
/// Functions of the module
pub functions: TypedFunctionSymbols<'ast, T>,
/// Constants defined in module
pub constants: TypedConstantSymbols<'ast, T>,
pub symbols: TypedSymbolDeclarations<'ast, T>,
}
impl<'ast, T> TypedModule<'ast, T> {
pub fn functions_iter(&self) -> impl Iterator<Item = &TypedFunctionSymbolDeclaration<'ast, T>> {
self.symbols.iter().filter_map(|s| match s {
TypedSymbolDeclaration::Function(d) => Some(d),
_ => None,
})
}
pub fn into_functions_iter(
self,
) -> impl Iterator<Item = TypedFunctionSymbolDeclaration<'ast, T>> {
self.symbols.into_iter().filter_map(|s| match s {
TypedSymbolDeclaration::Function(d) => Some(d),
_ => None,
})
}
}
#[derive(Clone, PartialEq, Debug)]
pub enum TypedFunctionSymbol<'ast, T> {
Here(TypedFunction<'ast, T>),
There(DeclarationFunctionKey<'ast>),
There(DeclarationFunctionKey<'ast, T>),
Flat(FlatEmbed),
}
@ -183,17 +253,61 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
pub fn signature<'a>(
&'a self,
modules: &'a TypedModules<'ast, T>,
) -> DeclarationSignature<'ast> {
) -> DeclarationSignature<'ast, T> {
match self {
TypedFunctionSymbol::Here(f) => f.signature.clone(),
TypedFunctionSymbol::There(key) => modules
.get(&key.module)
.unwrap()
.functions
.get(key)
.functions_iter()
.find(|d| d.key == *key)
.unwrap()
.symbol
.signature(&modules),
TypedFunctionSymbol::Flat(flat_fun) => flat_fun.signature(),
TypedFunctionSymbol::Flat(flat_fun) => flat_fun.typed_signature(),
}
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedConstantSymbolDeclaration<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.symbol {
TypedConstantSymbol::Here(ref tc) => {
write!(f, "const {} {} = {}", tc.ty, self.id, tc.expression)
}
TypedConstantSymbol::There(ref imported_id) => {
write!(
f,
"from \"{}\" import {} as {}",
imported_id.module.display(),
imported_id.id,
self.id
)
}
}
}
}
impl<'ast, T: fmt::Display> fmt::Display for TypedFunctionSymbolDeclaration<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.symbol {
TypedFunctionSymbol::Here(ref function) => write!(f, "def {}{}", self.key.id, function),
TypedFunctionSymbol::There(ref fun_key) => write!(
f,
"from \"{}\" import {} as {} // with signature {}",
fun_key.module.display(),
fun_key.id,
self.key.id,
self.key.signature
),
TypedFunctionSymbol::Flat(ref flat_fun) => {
write!(
f,
"def {}{}:\n\t// hidden",
self.key.id,
flat_fun.typed_signature::<T>()
)
}
}
}
}
@ -201,34 +315,9 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = self
.constants
.symbols
.iter()
.map(|(id, symbol)| match symbol {
TypedConstantSymbol::Here(ref tc) => {
format!("const {} {} = {}", id.ty, id.id, tc)
}
TypedConstantSymbol::There(ref imported_id) => {
format!(
"from \"{}\" import {} as {}",
imported_id.module.display(),
imported_id.id,
id.id
)
}
})
.chain(self.functions.iter().map(|(key, symbol)| match symbol {
TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
TypedFunctionSymbol::There(ref fun_key) => format!(
"from \"{}\" import {} as {} // with signature {}",
fun_key.module.display(),
fun_key.id,
key.id,
key.signature
),
TypedFunctionSymbol::Flat(ref flat_fun) => {
format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature())
}
}))
.map(|s| format!("{}", s))
.collect::<Vec<_>>();
write!(f, "{}", res.join("\n"))
@ -239,11 +328,11 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
#[derive(Clone, PartialEq, Debug, Hash)]
pub struct TypedFunction<'ast, T> {
/// Arguments of the function
pub arguments: Vec<DeclarationParameter<'ast>>,
pub arguments: Vec<DeclarationParameter<'ast, T>>,
/// Vector of statements that are executed when running the function
pub statements: Vec<TypedStatement<'ast, T>>,
/// function signature
pub signature: DeclarationSignature<'ast>,
pub signature: DeclarationSignature<'ast, T>,
}
impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
@ -312,11 +401,12 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
pub expression: TypedExpression<'ast, T>,
pub ty: DeclarationType<'ast, T>,
}
impl<'ast, T> TypedConstant<'ast, T> {
pub fn new(expression: TypedExpression<'ast, T>) -> Self {
TypedConstant { expression }
pub fn new(expression: TypedExpression<'ast, T>, ty: DeclarationType<'ast, T>) -> Self {
TypedConstant { expression, ty }
}
}
@ -333,14 +423,14 @@ impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> {
}
/// Something we can assign to.
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedAssignee<'ast, T> {
Identifier(Variable<'ast, T>),
Select(Box<TypedAssignee<'ast, T>>, Box<UExpression<'ast, T>>),
Member(Box<TypedAssignee<'ast, T>>, MemberId),
}
#[derive(Clone, PartialEq, Hash, Eq, Debug)]
#[derive(Clone, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TypedSpread<'ast, T> {
pub array: ArrayExpression<'ast, T>,
}
@ -357,7 +447,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedSpread<'ast, T> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedExpressionOrSpread<'ast, T> {
Expression(TypedExpression<'ast, T>),
Spread(TypedSpread<'ast, T>),
@ -487,7 +577,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> {
/// A statement in a `TypedFunction`
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedStatement<'ast, T> {
Return(Vec<TypedExpression<'ast, T>>),
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
@ -502,7 +592,7 @@ pub enum TypedStatement<'ast, T> {
MultipleDefinition(Vec<TypedAssignee<'ast, T>>, TypedExpressionList<'ast, T>),
// Aux
PushCallLog(
DeclarationFunctionKey<'ast>,
DeclarationFunctionKey<'ast, T>,
ConcreteGenericsAssignment<'ast>,
),
PopCallLog,
@ -575,7 +665,7 @@ pub trait Typed<'ast, T> {
/// A typed expression
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedExpression<'ast, T> {
Boolean(BooleanExpression<'ast, T>),
FieldElement(FieldElementExpression<'ast, T>),
@ -714,7 +804,7 @@ pub trait MultiTyped<'ast, T> {
fn get_types(&self) -> &Vec<Type<'ast, T>>;
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct TypedExpressionList<'ast, T> {
pub inner: TypedExpressionListInner<'ast, T>,
@ -727,7 +817,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum TypedExpressionListInner<'ast, T> {
FunctionCall(FunctionCallExpression<'ast, T, TypedExpressionList<'ast, T>>),
EmbedCall(FlatEmbed, Vec<u32>, Vec<TypedExpression<'ast, T>>),
@ -744,7 +834,7 @@ impl<'ast, T> TypedExpressionListInner<'ast, T> {
TypedExpressionList { inner: self, types }
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct BlockExpression<'ast, T, E> {
pub statements: Vec<TypedStatement<'ast, T>>,
pub value: Box<E>,
@ -759,7 +849,7 @@ impl<'ast, T, E> BlockExpression<'ast, T, E> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct MemberExpression<'ast, T, E> {
pub struc: Box<StructExpression<'ast, T>>,
pub id: MemberId,
@ -782,7 +872,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for MemberExpression<'ast, T, E> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct SelectExpression<'ast, T, E> {
pub array: Box<ArrayExpression<'ast, T>>,
pub index: Box<UExpression<'ast, T>>,
@ -805,7 +895,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct IfElseExpression<'ast, T, E> {
pub condition: Box<BooleanExpression<'ast, T>>,
pub consequence: Box<E>,
@ -832,9 +922,9 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<'
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct FunctionCallExpression<'ast, T, E> {
pub function_key: DeclarationFunctionKey<'ast>,
pub function_key: DeclarationFunctionKey<'ast, T>,
pub generics: Vec<Option<UExpression<'ast, T>>>,
pub arguments: Vec<TypedExpression<'ast, T>>,
ty: PhantomData<E>,
@ -842,7 +932,7 @@ pub struct FunctionCallExpression<'ast, T, E> {
impl<'ast, T, E> FunctionCallExpression<'ast, T, E> {
pub fn new(
function_key: DeclarationFunctionKey<'ast>,
function_key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self {
@ -885,7 +975,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T,
}
/// An expression of type `field`
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum FieldElementExpression<'ast, T> {
Block(BlockExpression<'ast, T, Self>),
Number(T),
@ -962,7 +1052,7 @@ impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
}
/// An expression of type `bool`
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum BooleanExpression<'ast, T> {
Block(BlockExpression<'ast, T, Self>),
Identifier(Identifier<'ast>),
@ -1027,13 +1117,13 @@ impl<'ast, T> From<bool> for BooleanExpression<'ast, T> {
/// * Contrary to basic types which are represented as enums, we wrap an enum `ArrayExpressionInner` in a struct in order to keep track of the type (content and size)
/// of the array. Only using an enum would require generics, which would propagate up to TypedExpression which we want to keep simple, hence this "runtime"
/// type checking
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct ArrayExpression<'ast, T> {
pub ty: Box<ArrayType<'ast, T>>,
pub inner: ArrayExpressionInner<'ast, T>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct ArrayValue<'ast, T>(pub Vec<TypedExpressionOrSpread<'ast, T>>);
impl<'ast, T> From<Vec<TypedExpressionOrSpread<'ast, T>>> for ArrayValue<'ast, T> {
@ -1111,7 +1201,7 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>> for Arra
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum ArrayExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>),
Identifier(Identifier<'ast>),
@ -1151,7 +1241,7 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct StructExpression<'ast, T> {
ty: StructType<'ast, T>,
inner: StructExpressionInner<'ast, T>,
@ -1175,7 +1265,7 @@ impl<'ast, T> StructExpression<'ast, T> {
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub enum StructExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, StructExpression<'ast, T>>),
Identifier(Identifier<'ast>),
@ -1897,7 +1987,7 @@ impl<'ast, T: Field> Id<'ast, T> for TypedExpressionList<'ast, T> {
pub trait FunctionCall<'ast, T>: Expr<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner;
@ -1905,7 +1995,7 @@ pub trait FunctionCall<'ast, T>: Expr<'ast, T> {
impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {
@ -1915,7 +2005,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {
@ -1925,7 +2015,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {
@ -1935,7 +2025,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> {
impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {
@ -1945,7 +2035,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {
@ -1955,7 +2045,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
impl<'ast, T: Field> FunctionCall<'ast, T> for TypedExpressionList<'ast, T> {
fn function_call(
key: DeclarationFunctionKey<'ast>,
key: DeclarationFunctionKey<'ast, T>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
) -> Self::Inner {

View file

@ -18,12 +18,12 @@ impl<'ast, S> From<GVariable<'ast, S>> for GParameter<'ast, S> {
}
}
pub type DeclarationParameter<'ast> = GParameter<'ast, DeclarationConstant<'ast>>;
pub type DeclarationParameter<'ast, T> = GParameter<'ast, DeclarationConstant<'ast, T>>;
impl<'ast, S: fmt::Display + Clone> fmt::Display for GParameter<'ast, S> {
impl<'ast, S: fmt::Display> fmt::Display for GParameter<'ast, S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let visibility = if self.private { "private " } else { "" };
write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id)
write!(f, "{}{} {}", visibility, self.id._type, self.id.id)
}
}

View file

@ -55,6 +55,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_module(self, m)
}
fn fold_symbol_declaration(
&mut self,
s: TypedSymbolDeclaration<'ast, T>,
) -> Result<TypedSymbolDeclaration<'ast, T>, Self::Error> {
fold_symbol_declaration(self, s)
}
fn fold_function_symbol_declaration(
&mut self,
s: TypedFunctionSymbolDeclaration<'ast, T>,
) -> Result<TypedFunctionSymbolDeclaration<'ast, T>, Self::Error> {
fold_function_symbol_declaration(self, s)
}
fn fold_constant_symbol_declaration(
&mut self,
s: TypedConstantSymbolDeclaration<'ast, T>,
) -> Result<TypedConstantSymbolDeclaration<'ast, T>, Self::Error> {
fold_constant_symbol_declaration(self, s)
}
fn fold_constant(
&mut self,
c: TypedConstant<'ast, T>,
@ -78,8 +99,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_declaration_function_key(
&mut self,
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, Self::Error> {
key: DeclarationFunctionKey<'ast, T>,
) -> Result<DeclarationFunctionKey<'ast, T>, Self::Error> {
fold_declaration_function_key(self, key)
}
@ -92,22 +113,22 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_signature(
&mut self,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, Self::Error> {
s: DeclarationSignature<'ast, T>,
) -> Result<DeclarationSignature<'ast, T>, Self::Error> {
fold_signature(self, s)
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast>,
) -> Result<DeclarationConstant<'ast>, Self::Error> {
c: DeclarationConstant<'ast, T>,
) -> Result<DeclarationConstant<'ast, T>, Self::Error> {
fold_declaration_constant(self, c)
}
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast>,
) -> Result<DeclarationParameter<'ast>, Self::Error> {
p: DeclarationParameter<'ast, T>,
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
Ok(DeclarationParameter {
id: self.fold_declaration_variable(p.id)?,
..p
@ -121,7 +142,6 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
Ok(CanonicalConstantIdentifier {
module: self.fold_module_id(i.module)?,
id: i.id,
ty: box self.fold_declaration_type(*i.ty)?,
})
}
@ -142,8 +162,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_declaration_variable(
&mut self,
v: DeclarationVariable<'ast>,
) -> Result<DeclarationVariable<'ast>, Self::Error> {
v: DeclarationVariable<'ast, T>,
) -> Result<DeclarationVariable<'ast, T>, Self::Error> {
Ok(DeclarationVariable {
id: self.fold_name(v.id)?,
_type: self.fold_declaration_type(v._type)?,
@ -246,8 +266,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_declaration_type(
&mut self,
t: DeclarationType<'ast>,
) -> Result<DeclarationType<'ast>, Self::Error> {
t: DeclarationType<'ast, T>,
) -> Result<DeclarationType<'ast, T>, Self::Error> {
use self::GType::*;
match t {
@ -259,8 +279,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_declaration_array_type(
&mut self,
t: DeclarationArrayType<'ast>,
) -> Result<DeclarationArrayType<'ast>, Self::Error> {
t: DeclarationArrayType<'ast, T>,
) -> Result<DeclarationArrayType<'ast, T>, Self::Error> {
Ok(DeclarationArrayType {
ty: box self.fold_declaration_type(*t.ty)?,
size: self.fold_declaration_constant(t.size)?,
@ -269,8 +289,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_declaration_struct_type(
&mut self,
t: DeclarationStructType<'ast>,
) -> Result<DeclarationStructType<'ast>, Self::Error> {
t: DeclarationStructType<'ast, T>,
) -> Result<DeclarationStructType<'ast, T>, Self::Error> {
Ok(DeclarationStructType {
generics: t
.generics
@ -977,8 +997,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
key: DeclarationFunctionKey<'ast, T>,
) -> Result<DeclarationFunctionKey<'ast, T>, F::Error> {
Ok(DeclarationFunctionKey {
module: f.fold_module_id(key.module)?,
signature: f.fold_signature(key.signature)?,
@ -1008,12 +1028,16 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
})
}
fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
pub fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, F::Error> {
s: DeclarationSignature<'ast, T>,
) -> Result<DeclarationSignature<'ast, T>, F::Error> {
Ok(DeclarationSignature {
generics: s.generics,
generics: s
.generics
.into_iter()
.map(|g| g.map(|g| f.fold_declaration_constant(g)).transpose())
.collect::<Result<_, _>>()?,
inputs: s
.inputs
.into_iter()
@ -1027,11 +1051,16 @@ fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
})
}
fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
_: &mut F,
c: DeclarationConstant<'ast>,
) -> Result<DeclarationConstant<'ast>, F::Error> {
Ok(c)
pub fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
c: DeclarationConstant<'ast, T>,
) -> Result<DeclarationConstant<'ast, T>, F::Error> {
match c {
DeclarationConstant::Expression(e) => {
Ok(DeclarationConstant::Expression(f.fold_expression(e)?))
}
c => Ok(c),
}
}
pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
@ -1115,6 +1144,7 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedConstant<'ast, T>, F::Error> {
Ok(TypedConstant {
expression: f.fold_expression(c.expression)?,
ty: f.fold_declaration_type(c.ty)?,
})
}
@ -1143,20 +1173,49 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
}
}
pub fn fold_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
d: TypedSymbolDeclaration<'ast, T>,
) -> Result<TypedSymbolDeclaration<'ast, T>, F::Error> {
Ok(match d {
TypedSymbolDeclaration::Function(d) => {
TypedSymbolDeclaration::Function(f.fold_function_symbol_declaration(d)?)
}
TypedSymbolDeclaration::Constant(d) => {
TypedSymbolDeclaration::Constant(f.fold_constant_symbol_declaration(d)?)
}
})
}
pub fn fold_function_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
d: TypedFunctionSymbolDeclaration<'ast, T>,
) -> Result<TypedFunctionSymbolDeclaration<'ast, T>, F::Error> {
Ok(TypedFunctionSymbolDeclaration {
key: f.fold_declaration_function_key(d.key)?,
symbol: f.fold_function_symbol(d.symbol)?,
})
}
pub fn fold_constant_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
d: TypedConstantSymbolDeclaration<'ast, T>,
) -> Result<TypedConstantSymbolDeclaration<'ast, T>, F::Error> {
Ok(TypedConstantSymbolDeclaration {
id: f.fold_canonical_constant_identifier(d.id)?,
symbol: f.fold_constant_symbol(d.symbol)?,
})
}
pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
m: TypedModule<'ast, T>,
) -> Result<TypedModule<'ast, T>, F::Error> {
Ok(TypedModule {
constants: m
.constants
symbols: m
.symbols
.into_iter()
.map(|(key, tc)| f.fold_constant_symbol(tc).map(|tc| (key, tc)))
.collect::<Result<_, _>>()?,
functions: m
.functions
.into_iter()
.map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f)))
.map(|s| f.fold_symbol_declaration(s))
.collect::<Result<_, _>>()?,
})
}

View file

@ -1,4 +1,6 @@
use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{
CoreIdentifier, Identifier, OwnedTypedModuleId, TypedExpression, UExpression, UExpressionInner,
};
use crate::typed_absy::{TryFrom, TryInto};
use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
@ -46,7 +48,7 @@ impl<'ast, T> IntoTypes<'ast, T> for Types<'ast, T> {
}
}
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
pub struct Types<'ast, T> {
pub inner: Vec<Type<'ast, T>>,
}
@ -107,69 +109,77 @@ pub type ConstantIdentifier<'ast> = &'ast str;
pub struct CanonicalConstantIdentifier<'ast> {
pub module: OwnedTypedModuleId,
pub id: ConstantIdentifier<'ast>,
pub ty: Box<DeclarationType<'ast>>,
}
impl<'ast> fmt::Display for CanonicalConstantIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}/{}", self.module.display(), self.id)
}
}
impl<'ast> CanonicalConstantIdentifier<'ast> {
pub fn new(
id: ConstantIdentifier<'ast>,
module: OwnedTypedModuleId,
ty: DeclarationType<'ast>,
) -> Self {
CanonicalConstantIdentifier {
module,
id,
ty: box ty,
}
pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self {
CanonicalConstantIdentifier { module, id }
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DeclarationConstant<'ast> {
pub enum DeclarationConstant<'ast, T> {
Generic(GenericIdentifier<'ast>),
Concrete(u32),
Constant(CanonicalConstantIdentifier<'ast>),
Expression(TypedExpression<'ast, T>),
}
impl<'ast, T> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast> {
impl<'ast, T: PartialEq> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast, T> {
fn eq(&self, other: &UExpression<'ast, T>) -> bool {
match (self, other.as_inner()) {
(DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32,
match (self, other) {
(
DeclarationConstant::Concrete(c),
UExpression {
bitwidth: UBitwidth::B32,
inner: UExpressionInner::Value(v),
..
},
) => *c == *v as u32,
(DeclarationConstant::Expression(TypedExpression::Uint(e0)), e1) => e0 == e1,
(DeclarationConstant::Expression(..), _) => false, // type error
_ => true,
}
}
}
impl<'ast, T> PartialEq<DeclarationConstant<'ast>> for UExpression<'ast, T> {
fn eq(&self, other: &DeclarationConstant<'ast>) -> bool {
impl<'ast, T: PartialEq> PartialEq<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
fn eq(&self, other: &DeclarationConstant<'ast, T>) -> bool {
other.eq(self)
}
}
impl<'ast> From<u32> for DeclarationConstant<'ast> {
impl<'ast, T> From<u32> for DeclarationConstant<'ast, T> {
fn from(e: u32) -> Self {
DeclarationConstant::Concrete(e)
}
}
impl<'ast> From<usize> for DeclarationConstant<'ast> {
impl<'ast, T> From<usize> for DeclarationConstant<'ast, T> {
fn from(e: usize) -> Self {
DeclarationConstant::Concrete(e as u32)
}
}
impl<'ast> From<GenericIdentifier<'ast>> for DeclarationConstant<'ast> {
impl<'ast, T> From<GenericIdentifier<'ast>> for DeclarationConstant<'ast, T> {
fn from(e: GenericIdentifier<'ast>) -> Self {
DeclarationConstant::Generic(e)
}
}
impl<'ast> fmt::Display for DeclarationConstant<'ast> {
impl<'ast, T: fmt::Display> fmt::Display for DeclarationConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
DeclarationConstant::Generic(i) => write!(f, "{}", i),
DeclarationConstant::Concrete(v) => write!(f, "{}", v),
DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id),
DeclarationConstant::Expression(e) => write!(f, "{}", e),
}
}
}
@ -180,8 +190,8 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
}
}
impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
fn from(c: DeclarationConstant<'ast>) -> Self {
impl<'ast, T> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
fn from(c: DeclarationConstant<'ast, T>) -> Self {
match c {
DeclarationConstant::Generic(i) => {
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
@ -190,8 +200,10 @@ impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
}
DeclarationConstant::Constant(v) => {
UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32)
UExpressionInner::Identifier(CoreIdentifier::from(v).into())
.annotate(UBitwidth::B32)
}
DeclarationConstant::Expression(e) => e.try_into().unwrap(),
}
}
}
@ -209,7 +221,7 @@ impl<'ast, T> TryInto<usize> for UExpression<'ast, T> {
}
}
impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
impl<'ast, T> TryInto<usize> for DeclarationConstant<'ast, T> {
type Error = SpecializationError;
fn try_into(self) -> Result<usize, Self::Error> {
@ -230,7 +242,7 @@ pub struct GStructMember<S> {
pub ty: Box<GType<S>>,
}
pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>>;
pub type DeclarationStructMember<'ast, T> = GStructMember<DeclarationConstant<'ast, T>>;
pub type ConcreteStructMember = GStructMember<usize>;
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
@ -270,7 +282,7 @@ pub struct GArrayType<S> {
pub ty: Box<GType<S>>,
}
pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
pub type DeclarationArrayType<'ast, T> = GArrayType<DeclarationConstant<'ast, T>>;
pub type ConcreteArrayType = GArrayType<usize>;
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
@ -336,7 +348,7 @@ pub struct StructLocation {
pub name: String,
}
impl<'ast> From<ConcreteArrayType> for DeclarationArrayType<'ast> {
impl<'ast, T> From<ConcreteArrayType> for DeclarationArrayType<'ast, T> {
fn from(t: ConcreteArrayType) -> Self {
try_from_g_array_type(t).unwrap()
}
@ -352,7 +364,7 @@ pub struct GStructType<S> {
pub members: Vec<GStructMember<S>>,
}
pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
pub type DeclarationStructType<'ast, T> = GStructType<DeclarationConstant<'ast, T>>;
pub type ConcreteStructType = GStructType<usize>;
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
@ -416,7 +428,7 @@ impl<'ast, T> From<ConcreteStructType> for StructType<'ast, T> {
}
}
impl<'ast> From<ConcreteStructType> for DeclarationStructType<'ast> {
impl<'ast, T> From<ConcreteStructType> for DeclarationStructType<'ast, T> {
fn from(t: ConcreteStructType) -> Self {
try_from_g_struct_type(t).unwrap()
}
@ -609,7 +621,7 @@ impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType<S> {
}
}
pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
pub type DeclarationType<'ast, T> = GType<DeclarationConstant<'ast, T>>;
pub type ConcreteType = GType<usize>;
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
@ -652,7 +664,7 @@ impl<'ast, T> From<ConcreteType> for Type<'ast, T> {
}
}
impl<'ast> From<ConcreteType> for DeclarationType<'ast> {
impl<'ast, T> From<ConcreteType> for DeclarationType<'ast, T> {
fn from(t: ConcreteType) -> Self {
try_from_g_type(t).unwrap()
}
@ -738,7 +750,7 @@ impl<S> GType<S> {
}
impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool {
pub fn can_be_specialized_to(&self, other: &DeclarationType<'ast, T>) -> bool {
use self::GType::*;
if other == self {
@ -811,14 +823,14 @@ impl ConcreteType {
pub type FunctionIdentifier<'ast> = &'ast str;
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
#[derive(PartialEq, Eq, Hash, Debug, Clone, PartialOrd, Ord)]
pub struct GFunctionKey<'ast, S> {
pub module: OwnedTypedModuleId,
pub id: FunctionIdentifier<'ast>,
pub signature: GSignature<S>,
}
pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, DeclarationConstant<'ast>>;
pub type DeclarationFunctionKey<'ast, T> = GFunctionKey<'ast, DeclarationConstant<'ast, T>>;
pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>;
pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>;
@ -828,7 +840,7 @@ impl<'ast, S: fmt::Display> fmt::Display for GFunctionKey<'ast, S> {
}
}
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct GGenericsAssignment<'ast, S>(pub BTreeMap<GenericIdentifier<'ast>, S>);
pub type ConcreteGenericsAssignment<'ast> = GGenericsAssignment<'ast, usize>;
@ -854,8 +866,8 @@ impl<'ast, S: fmt::Display> fmt::Display for GGenericsAssignment<'ast, S> {
}
}
impl<'ast> PartialEq<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool {
impl<'ast, T> PartialEq<DeclarationFunctionKey<'ast, T>> for ConcreteFunctionKey<'ast> {
fn eq(&self, other: &DeclarationFunctionKey<'ast, T>) -> bool {
self.module == other.module && self.id == other.id && self.signature == other.signature
}
}
@ -884,7 +896,7 @@ impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> {
}
}
impl<'ast> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast> {
impl<'ast, T> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast, T> {
fn from(k: ConcreteFunctionKey<'ast>) -> Self {
try_from_g_function_key(k).unwrap()
}
@ -931,8 +943,8 @@ impl<'ast> ConcreteFunctionKey<'ast> {
use std::collections::btree_map::Entry;
pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
decl_ty: &DeclarationType<'ast>,
pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq<usize>>(
decl_ty: &DeclarationType<'ast, T>,
ty: &GType<S>,
constants: &mut GGenericsAssignment<'ast, S>,
) -> bool {
@ -953,9 +965,9 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
}
},
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
// in the other cases, we do not know the value yet, so we optimistically assume it's correct
// if it does not match, it will be caught during inlining
DeclarationConstant::Constant(..) => true,
DeclarationConstant::Constant(..) | DeclarationConstant::Expression(..) => true,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
@ -963,6 +975,11 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
(DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1,
(DeclarationType::Struct(s0), GType::Struct(s1)) => {
s0.canonical_location == s1.canonical_location
&& s0
.members
.iter()
.zip(s1.members.iter())
.all(|(m0, m1)| check_type(&*m0.ty, &*m1.ty, constants))
}
_ => false,
}
@ -970,16 +987,12 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
let bitwidth = match *c.ty {
DeclarationType::Uint(bitwidth) => bitwidth,
_ => unreachable!(),
};
UExpressionInner::Identifier(Identifier::from(c.id)).annotate(bitwidth)
UExpressionInner::Identifier(Identifier::from(CoreIdentifier::Constant(c)))
.annotate(UBitwidth::B32)
}
}
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast> {
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast, T> {
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
DeclarationConstant::Constant(c)
}
@ -987,21 +1000,21 @@ impl<'ast> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast>
pub fn specialize_declaration_type<
'ast,
T,
S: Clone + PartialEq + From<u32> + fmt::Debug + From<CanonicalConstantIdentifier<'ast>>,
>(
decl_ty: DeclarationType<'ast>,
decl_ty: DeclarationType<'ast, T>,
generics: &GGenericsAssignment<'ast, S>,
) -> Result<GType<S>, GenericIdentifier<'ast>> {
Ok(match decl_ty {
DeclarationType::Int => unreachable!(),
DeclarationType::Array(t0) => {
// let s1 = t1.size.clone();
let ty = box specialize_declaration_type(*t0.ty, &generics)?;
let size = match t0.size {
DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s),
DeclarationConstant::Concrete(s) => Ok(s.into()),
DeclarationConstant::Constant(c) => Ok(c.into()),
DeclarationConstant::Expression(..) => unreachable!("the semantic checker should not yield this DeclarationConstant variant")
}?;
GType::Array(GArrayType { size, ty })
@ -1028,11 +1041,8 @@ pub fn specialize_declaration_type<
generics.0.get(&s).cloned().ok_or(s).map(Some)
}
DeclarationConstant::Concrete(s) => Ok(Some(s.into())),
DeclarationConstant::Constant(..) => {
unreachable!(
"identifiers should have been removed in constant inlining"
)
}
DeclarationConstant::Constant(c) => Ok(Some(c.into())),
DeclarationConstant::Expression(..) => unreachable!("the semantic checker should not yield this DeclarationConstant variant"),
},
_ => Ok(None),
})
@ -1096,12 +1106,12 @@ pub mod signature {
}
}
pub type DeclarationSignature<'ast> = GSignature<DeclarationConstant<'ast>>;
pub type DeclarationSignature<'ast, T> = GSignature<DeclarationConstant<'ast, T>>;
pub type ConcreteSignature = GSignature<usize>;
pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>;
impl<'ast> PartialEq<DeclarationSignature<'ast>> for ConcreteSignature {
fn eq(&self, other: &DeclarationSignature<'ast>) -> bool {
impl<'ast, T> PartialEq<DeclarationSignature<'ast, T>> for ConcreteSignature {
fn eq(&self, other: &DeclarationSignature<'ast, T>) -> bool {
// we keep track of the value of constants in a map, as a given constant can only have one value
let mut constants = ConcreteGenericsAssignment::default();
@ -1110,11 +1120,11 @@ pub mod signature {
.iter()
.chain(other.outputs.iter())
.zip(self.inputs.iter().chain(self.outputs.iter()))
.all(|(decl_ty, ty)| check_type::<usize>(decl_ty, ty, &mut constants))
.all(|(decl_ty, ty)| check_type::<T, usize>(decl_ty, ty, &mut constants))
}
}
impl<'ast> DeclarationSignature<'ast> {
impl<'ast, T: Clone + PartialEq + fmt::Debug> DeclarationSignature<'ast, T> {
pub fn specialize(
&self,
values: Vec<Option<u32>>,
@ -1155,7 +1165,7 @@ pub mod signature {
}
}
pub fn get_output_types<T: Clone + PartialEq + fmt::Debug>(
pub fn get_output_types(
&self,
generics: Vec<Option<UExpression<'ast, T>>>,
inputs: Vec<Type<'ast, T>>,
@ -1234,7 +1244,7 @@ pub mod signature {
}
}
impl<'ast> From<ConcreteSignature> for DeclarationSignature<'ast> {
impl<'ast, T> From<ConcreteSignature> for DeclarationSignature<'ast, T> {
fn from(s: ConcreteSignature) -> Self {
try_from_g_signature(s).unwrap()
}
@ -1349,6 +1359,7 @@ pub mod signature {
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
#[test]
fn signature() {
@ -1365,7 +1376,7 @@ pub mod signature {
// <P>(field[P])
// <Q>(field[Q])
let generic1 = DeclarationSignature::new()
let generic1 = DeclarationSignature::<Bn128Field>::new()
.generics(vec![Some(
GenericIdentifier {
name: "P",

View file

@ -133,13 +133,13 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct UMetadata {
pub bitwidth: Option<Bitwidth>,
pub should_reduce: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct UExpression<'ast, T> {
pub bitwidth: UBitwidth,
pub metadata: Option<UMetadata>,
@ -173,7 +173,7 @@ impl<'ast, T> PartialEq<usize> for UExpression<'ast, T> {
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum UExpressionInner<'ast, T> {
Block(BlockExpression<'ast, T, UExpression<'ast, T>>),
Identifier(Identifier<'ast>),

View file

@ -5,13 +5,13 @@ use crate::typed_absy::UExpression;
use crate::typed_absy::{TryFrom, TryInto};
use std::fmt;
#[derive(Clone, PartialEq, Hash, Eq)]
#[derive(Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
pub struct GVariable<'ast, S> {
pub id: Identifier<'ast>,
pub _type: GType<S>,
}
pub type DeclarationVariable<'ast> = GVariable<'ast, DeclarationConstant<'ast>>;
pub type DeclarationVariable<'ast, T> = GVariable<'ast, DeclarationConstant<'ast, T>>;
pub type ConcreteVariable<'ast> = GVariable<'ast, usize>;
pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>;