Merge pull request #975 from Zokrates/allow-calls-in-constants
Allow calls in constants
This commit is contained in:
commit
5f2d65124b
34 changed files with 2317 additions and 1597 deletions
1
changelogs/unreleased/975-schaeff
Normal file
1
changelogs/unreleased/975-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Allow calls in constant definitions
|
11
zokrates_cli/examples/array_generic_inference.zok
Normal file
11
zokrates_cli/examples/array_generic_inference.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
def myFct<N, N2>(u64[N] ignored) -> u64[N2]:
|
||||
assert(2*N == N2)
|
||||
return [0; N2]
|
||||
|
||||
const u32 N = 3
|
||||
|
||||
const u32 N2 = 2*N
|
||||
|
||||
def main(u64[N] arg) -> bool:
|
||||
u64[N2] someVariable = myFct(arg)
|
||||
return true
|
10
zokrates_cli/examples/call_in_const.zok
Normal file
10
zokrates_cli/examples/call_in_const.zok
Normal file
|
@ -0,0 +1,10 @@
|
|||
from "./call_in_const_aux.zok" import A, foo, F
|
||||
|
||||
def bar(field[A] x) -> field[A]:
|
||||
return x
|
||||
|
||||
const field[A] Y = [...bar(foo::<A>(F))[..A - 1], 1]
|
||||
|
||||
def main(field[A] X):
|
||||
assert(X == Y)
|
||||
return
|
9
zokrates_cli/examples/call_in_const_aux.zok
Normal file
9
zokrates_cli/examples/call_in_const_aux.zok
Normal file
|
@ -0,0 +1,9 @@
|
|||
const field F = 10
|
||||
const u32 A = 10
|
||||
const u32 B = A
|
||||
|
||||
def foo<N>(field X) -> field[N]:
|
||||
return [X; N]
|
||||
|
||||
def main():
|
||||
return
|
7
zokrates_cli/examples/call_in_constant.zok
Normal file
7
zokrates_cli/examples/call_in_constant.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
def yes() -> bool:
|
||||
return true
|
||||
|
||||
const bool TRUE = yes()
|
||||
|
||||
def main():
|
||||
return
|
|
@ -0,0 +1,14 @@
|
|||
// this should not compile, as A == B
|
||||
|
||||
const u32 A = 1
|
||||
const u32 B = 1
|
||||
|
||||
def foo(field[A] a) -> bool:
|
||||
return true
|
||||
|
||||
def foo(field[B] a) -> bool:
|
||||
return true
|
||||
|
||||
def main():
|
||||
assert(foo([1]))
|
||||
return
|
|
@ -0,0 +1,14 @@
|
|||
// this should actually compile, as A != B
|
||||
|
||||
const u32 A = 2
|
||||
const u32 B = 1
|
||||
|
||||
def foo(field[A] a) -> bool:
|
||||
return true
|
||||
|
||||
def foo(field[B] a) -> bool:
|
||||
return true
|
||||
|
||||
def main():
|
||||
assert(foo([1]))
|
||||
return
|
|
@ -0,0 +1,6 @@
|
|||
from "EMBED" import bit_array_le
|
||||
|
||||
const bool CONST = bit_array_le([true], [true])
|
||||
|
||||
def main() -> bool:
|
||||
return CONST
|
14
zokrates_cli/examples/complex_call_in_constant.zok
Normal file
14
zokrates_cli/examples/complex_call_in_constant.zok
Normal file
|
@ -0,0 +1,14 @@
|
|||
def constant() -> u32:
|
||||
u32 res = 0
|
||||
u32 x = 3
|
||||
for u32 y in 0..x do
|
||||
res = res + 1
|
||||
endfor
|
||||
return res
|
||||
|
||||
const u32 CONSTANT = 1 + constant()
|
||||
|
||||
const u32 OTHER_CONSTANT = 42
|
||||
|
||||
def main(field[CONSTANT] a) -> u32:
|
||||
return CONSTANT + OTHER_CONSTANT
|
15
zokrates_cli/examples/struct_generic_inference.zok
Normal file
15
zokrates_cli/examples/struct_generic_inference.zok
Normal file
|
@ -0,0 +1,15 @@
|
|||
struct SomeStruct<N> {
|
||||
u64[N] f
|
||||
}
|
||||
|
||||
def myFct<N, N2, N3>(SomeStruct<N> ignored) -> u32[N2]:
|
||||
assert(2*N == N2)
|
||||
return [N3; N2]
|
||||
|
||||
const u32 N = 3
|
||||
|
||||
const u32 N2 = 2*N
|
||||
|
||||
def main(SomeStruct<N> arg) -> u32:
|
||||
u32[N2] someVariable = myFct::<_, _, 42>(arg)
|
||||
return someVariable[0]
|
|
@ -9,6 +9,16 @@ pub struct Node<T> {
|
|||
pub value: T,
|
||||
}
|
||||
|
||||
impl<T> Node<T> {
|
||||
pub fn mock(e: T) -> Self {
|
||||
Self {
|
||||
start: Position::mock(),
|
||||
end: Position::mock(),
|
||||
value: e,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Display> Node<T> {
|
||||
pub fn pos(&self) -> (Position, Position) {
|
||||
(self.start, self.end)
|
||||
|
@ -67,8 +77,7 @@ pub trait NodeValue: fmt::Display + fmt::Debug + Sized + PartialEq {
|
|||
|
||||
impl<V: NodeValue> From<V> for Node<V> {
|
||||
fn from(v: V) -> Node<V> {
|
||||
let mock_position = Position { col: 42, line: 42 };
|
||||
Node::new(mock_position, mock_position, v)
|
||||
Node::new(Position::mock(), Position::mock(), v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -249,6 +249,8 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
|
|||
let typed_ast = Checker::check(compiled)
|
||||
.map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?;
|
||||
|
||||
log::trace!("\n{}", typed_ast);
|
||||
|
||||
let main_module = typed_ast.main.clone();
|
||||
|
||||
log::debug!("Run static analysis");
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
use crate::absy::{
|
||||
types::{UnresolvedSignature, UnresolvedType},
|
||||
ConstantGenericNode, Expression,
|
||||
};
|
||||
use crate::flat_absy::{
|
||||
FlatDirective, FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement,
|
||||
FlatVariable, RuntimeError,
|
||||
|
@ -26,7 +30,7 @@ cfg_if::cfg_if! {
|
|||
|
||||
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
|
||||
/// the flattening step when it can be inlined.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
|
||||
pub enum FlatEmbed {
|
||||
BitArrayLe,
|
||||
Unpack,
|
||||
|
@ -45,7 +49,131 @@ pub enum FlatEmbed {
|
|||
}
|
||||
|
||||
impl FlatEmbed {
|
||||
pub fn signature(&self) -> DeclarationSignature<'static> {
|
||||
pub fn signature(&self) -> UnresolvedSignature {
|
||||
match self {
|
||||
FlatEmbed::BitArrayLe => UnresolvedSignature::new()
|
||||
.generics(vec![ConstantGenericNode::mock("N")])
|
||||
.inputs(vec![
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::Identifier("N").into(),
|
||||
)
|
||||
.into(),
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::Identifier("N").into(),
|
||||
)
|
||||
.into(),
|
||||
])
|
||||
.outputs(vec![UnresolvedType::Boolean.into()]),
|
||||
FlatEmbed::Unpack => UnresolvedSignature::new()
|
||||
.generics(vec!["N".into()])
|
||||
.inputs(vec![UnresolvedType::FieldElement.into()])
|
||||
.outputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::Identifier("N").into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U8ToBits => UnresolvedSignature::new()
|
||||
.inputs(vec![UnresolvedType::Uint(8).into()])
|
||||
.outputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(8).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U16ToBits => UnresolvedSignature::new()
|
||||
.inputs(vec![UnresolvedType::Uint(16).into()])
|
||||
.outputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(16).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U32ToBits => UnresolvedSignature::new()
|
||||
.inputs(vec![UnresolvedType::Uint(32).into()])
|
||||
.outputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(32).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U64ToBits => UnresolvedSignature::new()
|
||||
.inputs(vec![UnresolvedType::Uint(64).into()])
|
||||
.outputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(64).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U8FromBits => UnresolvedSignature::new()
|
||||
.outputs(vec![UnresolvedType::Uint(8).into()])
|
||||
.inputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(8).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U16FromBits => UnresolvedSignature::new()
|
||||
.outputs(vec![UnresolvedType::Uint(16).into()])
|
||||
.inputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(16).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U32FromBits => UnresolvedSignature::new()
|
||||
.outputs(vec![UnresolvedType::Uint(32).into()])
|
||||
.inputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(32).into(),
|
||||
)
|
||||
.into()]),
|
||||
FlatEmbed::U64FromBits => UnresolvedSignature::new()
|
||||
.outputs(vec![UnresolvedType::Uint(64).into()])
|
||||
.inputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(64).into(),
|
||||
)
|
||||
.into()]),
|
||||
#[cfg(feature = "bellman")]
|
||||
FlatEmbed::Sha256Round => UnresolvedSignature::new()
|
||||
.inputs(vec![
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(512).into(),
|
||||
)
|
||||
.into(),
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(256).into(),
|
||||
)
|
||||
.into(),
|
||||
])
|
||||
.outputs(vec![UnresolvedType::array(
|
||||
UnresolvedType::Boolean.into(),
|
||||
Expression::U32Constant(256).into(),
|
||||
)
|
||||
.into()]),
|
||||
#[cfg(feature = "ark")]
|
||||
FlatEmbed::SnarkVerifyBls12377 => UnresolvedSignature::new()
|
||||
.generics(vec!["N".into(), "V".into()])
|
||||
.inputs(vec![
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::FieldElement.into(),
|
||||
Expression::Identifier("N").into(),
|
||||
)
|
||||
.into(), // inputs
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::FieldElement.into(),
|
||||
Expression::U32Constant(8).into(),
|
||||
)
|
||||
.into(), // proof
|
||||
UnresolvedType::array(
|
||||
UnresolvedType::FieldElement.into(),
|
||||
Expression::Identifier("V").into(),
|
||||
)
|
||||
.into(), // 18 + (2 * n) // vk
|
||||
])
|
||||
.outputs(vec![UnresolvedType::Boolean.into()]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn typed_signature<T>(&self) -> DeclarationSignature<'static, T> {
|
||||
match self {
|
||||
FlatEmbed::BitArrayLe => DeclarationSignature::new()
|
||||
.generics(vec![Some(DeclarationConstant::Generic(
|
||||
|
@ -177,15 +305,13 @@ impl FlatEmbed {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
|
||||
let gen = self
|
||||
.signature()
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|c| match c.unwrap() {
|
||||
pub fn generics<'ast, T>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
|
||||
let gen = self.typed_signature().generics.into_iter().map(
|
||||
|c: Option<DeclarationConstant<'ast, T>>| match c.unwrap() {
|
||||
DeclarationConstant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
assert_eq!(gen.len(), assignment.0.len());
|
||||
gen.map(|g| *assignment.0.get(&g).unwrap() as u32).collect()
|
||||
|
|
|
@ -15,7 +15,6 @@ impl Position {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn mock() -> Self {
|
||||
Position { line: 42, col: 42 }
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,973 +0,0 @@
|
|||
use crate::static_analysis::Propagator;
|
||||
use crate::typed_absy::result_folder::*;
|
||||
use crate::typed_absy::types::DeclarationConstant;
|
||||
use crate::typed_absy::*;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::fmt;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type ProgramConstants<'ast, T> =
|
||||
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
Type(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Error::Type(s) => write!(f, "{}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub struct ConstantInliner<'ast, T> {
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
constants: ProgramConstants<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
||||
pub fn new(
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
constants: ProgramConstants<'ast, T>,
|
||||
) -> Self {
|
||||
ConstantInliner {
|
||||
modules,
|
||||
location,
|
||||
constants,
|
||||
}
|
||||
}
|
||||
pub fn inline(p: TypedProgram<'ast, T>) -> Result<TypedProgram<'ast, T>, Error> {
|
||||
let constants = ProgramConstants::new();
|
||||
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
|
||||
inliner.fold_program(p)
|
||||
}
|
||||
|
||||
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
let prev = self.location.clone();
|
||||
self.location = location;
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
prev
|
||||
}
|
||||
|
||||
fn treated(&self, id: &TypedModuleId) -> bool {
|
||||
self.constants.contains_key(id)
|
||||
}
|
||||
|
||||
fn get_constant(
|
||||
&self,
|
||||
id: &CanonicalConstantIdentifier<'ast>,
|
||||
) -> Option<TypedExpression<'ast, T>> {
|
||||
self.constants
|
||||
.get(&id.module)
|
||||
.and_then(|constants| constants.get(&id.id.into()))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn get_constant_for_identifier(
|
||||
&self,
|
||||
id: &Identifier<'ast>,
|
||||
) -> Option<TypedExpression<'ast, T>> {
|
||||
self.constants
|
||||
.get(&self.location)
|
||||
.and_then(|constants| constants.get(&id))
|
||||
.cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_program(
|
||||
&mut self,
|
||||
p: TypedProgram<'ast, T>,
|
||||
) -> Result<TypedProgram<'ast, T>, Self::Error> {
|
||||
self.fold_module_id(p.main.clone())?;
|
||||
|
||||
Ok(TypedProgram {
|
||||
modules: std::mem::take(&mut self.modules),
|
||||
..p
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_module_id(
|
||||
&mut self,
|
||||
id: OwnedTypedModuleId,
|
||||
) -> Result<OwnedTypedModuleId, Self::Error> {
|
||||
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
|
||||
if !self.treated(&id) {
|
||||
let current_m_id = self.change_location(id.clone());
|
||||
let m = self.modules.remove(&id).unwrap();
|
||||
let m = self.fold_module(m)?;
|
||||
self.modules.insert(id.clone(), m);
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn fold_module(
|
||||
&mut self,
|
||||
m: TypedModule<'ast, T>,
|
||||
) -> Result<TypedModule<'ast, T>, Self::Error> {
|
||||
Ok(TypedModule {
|
||||
constants: m
|
||||
.constants
|
||||
.into_iter()
|
||||
.map(|(id, tc)| {
|
||||
|
||||
let id = self.fold_canonical_constant_identifier(id)?;
|
||||
|
||||
let constant = match tc {
|
||||
TypedConstantSymbol::There(imported_id) => {
|
||||
// visit the imported symbol. This triggers visiting the corresponding module if needed
|
||||
let imported_id = self.fold_canonical_constant_identifier(imported_id)?;
|
||||
// after that, the constant must have been defined defined in the global map. It is already reduced
|
||||
// to a literal, so running propagation isn't required
|
||||
self.get_constant(&imported_id).unwrap()
|
||||
}
|
||||
TypedConstantSymbol::Here(c) => {
|
||||
let non_propagated_constant = fold_constant(self, c)?.expression;
|
||||
// folding the constant above only reduces it to an expression containing only literals, not to a single literal.
|
||||
// propagating with an empty map of constants reduces it to a single literal
|
||||
Propagator::with_constants(&mut HashMap::default())
|
||||
.fold_expression(non_propagated_constant)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(*id.ty.clone()).unwrap() == constant.get_type() {
|
||||
// add to the constant map. The value added is always a single litteral
|
||||
self.constants
|
||||
.get_mut(&self.location)
|
||||
.unwrap()
|
||||
.insert(id.id.into(), constant.clone());
|
||||
|
||||
Ok((
|
||||
id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
expression: constant,
|
||||
}),
|
||||
))
|
||||
} else {
|
||||
Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant.get_type(), id.id, id.ty)))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
functions: m
|
||||
.functions
|
||||
.into_iter()
|
||||
.map::<Result<_, Self::Error>, _>(|(key, fun)| {
|
||||
Ok((
|
||||
self.fold_declaration_function_key(key)?,
|
||||
self.fold_function_symbol(fun)?,
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> Result<DeclarationConstant<'ast>, Self::Error> {
|
||||
match c {
|
||||
// replace constants by their concrete value in declaration types
|
||||
DeclarationConstant::Constant(id) => {
|
||||
let id = CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(id.module)?,
|
||||
..id
|
||||
};
|
||||
|
||||
Ok(DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() {
|
||||
TypedExpression::Uint(UExpression {
|
||||
inner: UExpressionInner::Value(v),
|
||||
..
|
||||
}) => v as u32,
|
||||
_ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"),
|
||||
}))
|
||||
}
|
||||
c => Ok(c),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
FieldElementExpression::Identifier(ref id) => {
|
||||
match self.get_constant_for_identifier(id) {
|
||||
Some(c) => Ok(c.try_into().unwrap()),
|
||||
None => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) {
|
||||
Some(c) => Ok(c.try_into().unwrap()),
|
||||
None => fold_boolean_expression(self, e),
|
||||
},
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
size: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) {
|
||||
Some(c) => {
|
||||
let e: UExpression<'ast, T> = c.try_into().unwrap();
|
||||
Ok(e.into_inner())
|
||||
}
|
||||
None => fold_uint_expression_inner(self, size, e),
|
||||
},
|
||||
e => fold_uint_expression_inner(self, size, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &ArrayType<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
ArrayExpressionInner::Identifier(ref id) => {
|
||||
match self.get_constant_for_identifier(id) {
|
||||
Some(c) => {
|
||||
let e: ArrayExpression<'ast, T> = c.try_into().unwrap();
|
||||
Ok(e.into_inner())
|
||||
}
|
||||
None => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
e => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id)
|
||||
{
|
||||
Some(c) => {
|
||||
let e: StructExpression<'ast, T> = c.try_into().unwrap();
|
||||
Ok(e.into_inner())
|
||||
}
|
||||
None => fold_struct_expression_inner(self, ty, e),
|
||||
},
|
||||
e => fold_struct_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::typed_absy::types::DeclarationSignature;
|
||||
use crate::typed_absy::{
|
||||
DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression,
|
||||
GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol,
|
||||
TypedStatement,
|
||||
};
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn inline_const_field() {
|
||||
// const field a = 1
|
||||
//
|
||||
// def main() -> field:
|
||||
// return a
|
||||
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: constants.clone(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, Ok(expected_program))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_const_boolean() {
|
||||
// const bool a = true
|
||||
//
|
||||
// def main() -> bool:
|
||||
// return a
|
||||
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
|
||||
Identifier::from(const_id),
|
||||
)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into(), DeclarationType::Boolean),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean(
|
||||
BooleanExpression::Value(true),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: constants.clone(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
BooleanExpression::Value(true).into()
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, Ok(expected_program))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_const_uint() {
|
||||
// const u32 a = 0x00000001
|
||||
//
|
||||
// def main() -> u32:
|
||||
// return a
|
||||
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
|
||||
Identifier::from(const_id),
|
||||
)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_id,
|
||||
"main".into(),
|
||||
DeclarationType::Uint(UBitwidth::B32),
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
UExpressionInner::Value(1u128)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
)),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: constants.clone(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, Ok(expected_program))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_const_field_array() {
|
||||
// const field[2] a = [2, 2]
|
||||
//
|
||||
// def main() -> field:
|
||||
// return a[0] + a[1]
|
||||
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
FieldElementExpression::select(
|
||||
ArrayExpressionInner::Identifier(Identifier::from(const_id))
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpressionInner::Identifier(Identifier::from(const_id))
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_id,
|
||||
"main".into(),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
2u32,
|
||||
)),
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: constants.clone(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
FieldElementExpression::select(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, Ok(expected_program))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_nested_const_field() {
|
||||
// const field a = 1
|
||||
// const field b = a + 1
|
||||
//
|
||||
// def main() -> field:
|
||||
// return b
|
||||
|
||||
let const_a_id = "a";
|
||||
let const_b_id = "b";
|
||||
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![
|
||||
(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_a_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_b_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from(
|
||||
const_a_id,
|
||||
)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
|
||||
let expected_main = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![
|
||||
(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_a_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
(
|
||||
CanonicalConstantIdentifier::new(
|
||||
const_b_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(2),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, Ok(expected_program))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_imported_constant() {
|
||||
// ---------------------
|
||||
// module `foo`
|
||||
// --------------------
|
||||
// const field FOO = 42
|
||||
//
|
||||
// def main():
|
||||
// return
|
||||
//
|
||||
// ---------------------
|
||||
// module `main`
|
||||
// ---------------------
|
||||
// from "foo" import FOO
|
||||
//
|
||||
// def main() -> field:
|
||||
// return FOO
|
||||
|
||||
let foo_const_id = "FOO";
|
||||
let foo_module = TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main")
|
||||
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![],
|
||||
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"foo".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let main_module = TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"foo".into(),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), main_module),
|
||||
("foo".into(), foo_module.clone()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantInliner::inline(program);
|
||||
let expected_main_module = TypedModule {
|
||||
functions: vec![(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(42)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"main".into(),
|
||||
DeclarationType::FieldElement,
|
||||
),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(Bn128Field::from(42)),
|
||||
))),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), expected_main_module),
|
||||
("foo".into(), foo_module),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, Ok(expected_program))
|
||||
}
|
||||
}
|
844
zokrates_core/src/static_analysis/constant_resolver.rs
Normal file
844
zokrates_core/src/static_analysis/constant_resolver.rs
Normal file
|
@ -0,0 +1,844 @@
|
|||
// Static analysis step to replace all imported constants with the imported value
|
||||
// This does *not* reduce constants to their literal value
|
||||
// This step cannot fail as the imports were checked during semantics
|
||||
|
||||
use crate::typed_absy::folder::*;
|
||||
use crate::typed_absy::*;
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::Field;
|
||||
|
||||
// a map of the canonical constants in this program. with all imported constants reduced to their canonical value
|
||||
type ProgramConstants<'ast, T> =
|
||||
HashMap<OwnedTypedModuleId, HashMap<ConstantIdentifier<'ast>, TypedConstant<'ast, T>>>;
|
||||
|
||||
pub struct ConstantResolver<'ast, T> {
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
constants: ProgramConstants<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ConstantResolver<'ast, T> {
|
||||
pub fn new(
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
constants: ProgramConstants<'ast, T>,
|
||||
) -> Self {
|
||||
ConstantResolver {
|
||||
modules,
|
||||
location,
|
||||
constants,
|
||||
}
|
||||
}
|
||||
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
let constants = ProgramConstants::new();
|
||||
let mut inliner = ConstantResolver::new(p.modules.clone(), p.main.clone(), constants);
|
||||
inliner.fold_program(p)
|
||||
}
|
||||
|
||||
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
let prev = self.location.clone();
|
||||
self.location = location;
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
prev
|
||||
}
|
||||
|
||||
fn treated(&self, id: &TypedModuleId) -> bool {
|
||||
self.constants.contains_key(id)
|
||||
}
|
||||
|
||||
fn get_constant(
|
||||
&self,
|
||||
id: &CanonicalConstantIdentifier<'ast>,
|
||||
) -> Option<TypedConstant<'ast, T>> {
|
||||
self.constants
|
||||
.get(&id.module)
|
||||
.and_then(|constants| constants.get(&id.id))
|
||||
.cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> {
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
self.fold_module_id(p.main.clone());
|
||||
|
||||
TypedProgram {
|
||||
modules: std::mem::take(&mut self.modules),
|
||||
..p
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
|
||||
if !self.treated(&id) {
|
||||
let current_m_id = self.change_location(id.clone());
|
||||
let m = self.modules.remove(&id).unwrap();
|
||||
let m = self.fold_module(m);
|
||||
self.modules.insert(id.clone(), m);
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
id
|
||||
}
|
||||
|
||||
fn fold_constant_symbol_declaration(
|
||||
&mut self,
|
||||
c: TypedConstantSymbolDeclaration<'ast, T>,
|
||||
) -> TypedConstantSymbolDeclaration<'ast, T> {
|
||||
let id = self.fold_canonical_constant_identifier(c.id);
|
||||
|
||||
let constant = match c.symbol {
|
||||
TypedConstantSymbol::There(imported_id) => {
|
||||
// visit the imported symbol. This triggers visiting the corresponding module if needed
|
||||
let imported_id = self.fold_canonical_constant_identifier(imported_id);
|
||||
// after that, the constant must have been defined in the global map
|
||||
self.get_constant(&imported_id).unwrap()
|
||||
}
|
||||
TypedConstantSymbol::Here(c) => fold_constant(self, c),
|
||||
};
|
||||
self.constants
|
||||
.get_mut(&self.location)
|
||||
.unwrap()
|
||||
.insert(id.id, constant.clone());
|
||||
|
||||
TypedConstantSymbolDeclaration {
|
||||
id,
|
||||
symbol: TypedConstantSymbol::Here(constant),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::typed_absy::types::DeclarationSignature;
|
||||
use crate::typed_absy::{
|
||||
DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression,
|
||||
GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol,
|
||||
TypedStatement,
|
||||
};
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn inline_const_field() {
|
||||
// in the absence of imports, a module is left unchanged
|
||||
|
||||
// const field a = 1
|
||||
//
|
||||
// def main() -> field:
|
||||
// return a
|
||||
|
||||
let const_id = "a";
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let expected_program = program.clone();
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_op_const_boolean() {
|
||||
// in the absence of imports, a module is left unchanged
|
||||
|
||||
// const bool a = true
|
||||
//
|
||||
// def main() -> bool:
|
||||
// return main.zok/a
|
||||
|
||||
let const_id = CanonicalConstantIdentifier::new("a", "main".into());
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
|
||||
Identifier::from(const_id.clone()),
|
||||
)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::Boolean(BooleanExpression::Value(true)),
|
||||
DeclarationType::Boolean,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Boolean]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let expected_program = program.clone();
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_op_const_uint() {
|
||||
// in the absence of imports, a module is left unchanged
|
||||
|
||||
// const u32 a = 0x00000001
|
||||
//
|
||||
// def main() -> u32:
|
||||
// return a
|
||||
|
||||
let const_id = CanonicalConstantIdentifier::new("a", "main".into());
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
|
||||
Identifier::from(const_id.clone()),
|
||||
)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
const_id,
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
UExpressionInner::Value(1u128)
|
||||
.annotate(UBitwidth::B32)
|
||||
.into(),
|
||||
DeclarationType::Uint(UBitwidth::B32),
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let expected_program = program.clone();
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_op_const_field_array() {
|
||||
// in the absence of imports, a module is left unchanged
|
||||
|
||||
// const field[2] a = [2, 2]
|
||||
//
|
||||
// def main() -> field:
|
||||
// return a[0] + a[1]
|
||||
|
||||
let const_id = CanonicalConstantIdentifier::new("a", "main".into());
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
|
||||
FieldElementExpression::select(
|
||||
ArrayExpressionInner::Identifier(Identifier::from(const_id.clone()))
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
FieldElementExpression::select(
|
||||
ArrayExpressionInner::Identifier(Identifier::from(const_id.clone()))
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into()])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![
|
||||
FieldElementExpression::Number(Bn128Field::from(2))
|
||||
.into(),
|
||||
FieldElementExpression::Number(Bn128Field::from(2))
|
||||
.into(),
|
||||
]
|
||||
.into(),
|
||||
)
|
||||
.annotate(GType::FieldElement, 2usize),
|
||||
),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
2u32,
|
||||
)),
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let expected_program = program.clone();
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_op_nested_const_field() {
|
||||
// const field a = 1
|
||||
// const field b = a + 1
|
||||
//
|
||||
// def main() -> field:
|
||||
// return b
|
||||
|
||||
let const_a_id = CanonicalConstantIdentifier::new("a", "main".into());
|
||||
let const_b_id = CanonicalConstantIdentifier::new("a", "main".into());
|
||||
|
||||
let main: TypedFunction<Bn128Field> = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(const_b_id.clone())).into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
const_a_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(1),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
const_b_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from(
|
||||
const_a_id.clone(),
|
||||
)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let expected_program = program.clone();
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_imported_constant() {
|
||||
// ---------------------
|
||||
// module `foo`
|
||||
// --------------------
|
||||
// const field FOO = 42
|
||||
// const field BAR = FOO
|
||||
//
|
||||
// def main():
|
||||
// return
|
||||
//
|
||||
// ---------------------
|
||||
// module `main`
|
||||
// ---------------------
|
||||
// from "foo" import BAR
|
||||
//
|
||||
// def main() -> field:
|
||||
// return FOO
|
||||
|
||||
// Should be resolved to
|
||||
|
||||
// ---------------------
|
||||
// module `foo`
|
||||
// --------------------
|
||||
// const field BAR = ./foo.zok/FOO
|
||||
//
|
||||
// def main():
|
||||
// return
|
||||
//
|
||||
// ---------------------
|
||||
// module `main`
|
||||
// ---------------------
|
||||
// const field FOO = 42
|
||||
//
|
||||
// def main() -> field:
|
||||
// return FOO
|
||||
|
||||
let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into());
|
||||
let bar_const_id = CanonicalConstantIdentifier::new("BAR", "foo".into());
|
||||
let foo_module = TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
foo_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(42),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
bar_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Identifier(
|
||||
foo_const_id.clone().into(),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("foo", "main")
|
||||
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![],
|
||||
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
|
||||
}),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
};
|
||||
|
||||
let main_const_id = CanonicalConstantIdentifier::new("FOO", "main".into());
|
||||
let main_module = TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_const_id.clone(),
|
||||
TypedConstantSymbol::There(bar_const_id),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(
|
||||
main_const_id.clone(),
|
||||
))
|
||||
.into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), main_module),
|
||||
("foo".into(), foo_module.clone()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
let expected_main_module = TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Identifier(
|
||||
foo_const_id.into(),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(
|
||||
main_const_id.clone(),
|
||||
))
|
||||
.into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), expected_main_module),
|
||||
("foo".into(), foo_module),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_imported_constant_with_generics() {
|
||||
// ---------------------
|
||||
// module `foo`
|
||||
// --------------------
|
||||
// const field FOO = 2
|
||||
// const field[FOO] BAR = [1; FOO]
|
||||
//
|
||||
// def main():
|
||||
// return
|
||||
//
|
||||
// ---------------------
|
||||
// module `main`
|
||||
// ---------------------
|
||||
// from "foo" import FOO
|
||||
// from "foo" import BAR
|
||||
// const field[FOO] BAZ = BAR
|
||||
//
|
||||
// def main() -> field:
|
||||
// return FOO
|
||||
|
||||
// Should be resolved to
|
||||
|
||||
// ---------------------
|
||||
// module `foo`
|
||||
// --------------------
|
||||
// const field FOO = 2
|
||||
// const field[FOO] BAR = [1; FOO]
|
||||
//
|
||||
// def main():
|
||||
// return
|
||||
//
|
||||
// ---------------------
|
||||
// module `main`
|
||||
// ---------------------
|
||||
// const FOO = 2
|
||||
// const BAR = [1; ./foo.zok/FOO]
|
||||
// const field[FOO] BAZ = BAR
|
||||
//
|
||||
// def main() -> field:
|
||||
// return FOO
|
||||
|
||||
let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into());
|
||||
let bar_const_id = CanonicalConstantIdentifier::new("BAR", "foo".into());
|
||||
let foo_module = TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
foo_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
Bn128Field::from(2),
|
||||
)),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
bar_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Repeat(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
box UExpression::from(foo_const_id.clone()),
|
||||
)
|
||||
.annotate(Type::FieldElement, foo_const_id.clone()),
|
||||
),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Constant(foo_const_id.clone()),
|
||||
)),
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("foo", "main")
|
||||
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![],
|
||||
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
|
||||
}),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
};
|
||||
|
||||
let main_foo_const_id = CanonicalConstantIdentifier::new("FOO", "main".into());
|
||||
let main_bar_const_id = CanonicalConstantIdentifier::new("BAR", "main".into());
|
||||
let main_baz_const_id = CanonicalConstantIdentifier::new("BAZ", "main".into());
|
||||
|
||||
let main_module = TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_foo_const_id.clone(),
|
||||
TypedConstantSymbol::There(foo_const_id.clone()),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_bar_const_id.clone(),
|
||||
TypedConstantSymbol::There(bar_const_id),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_baz_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Identifier(main_bar_const_id.clone().into())
|
||||
.annotate(Type::FieldElement, main_foo_const_id.clone()),
|
||||
),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Constant(foo_const_id.clone()),
|
||||
)),
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(
|
||||
main_foo_const_id.clone(),
|
||||
))
|
||||
.into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
};
|
||||
|
||||
let program = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), main_module),
|
||||
("foo".into(), foo_module.clone()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let program = ConstantResolver::inline(program);
|
||||
let expected_main_module = TypedModule {
|
||||
symbols: vec![
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_foo_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
FieldElementExpression::Number(Bn128Field::from(2)).into(),
|
||||
DeclarationType::FieldElement,
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_bar_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Repeat(
|
||||
box FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
box UExpression::from(foo_const_id.clone()),
|
||||
)
|
||||
.annotate(Type::FieldElement, foo_const_id.clone()),
|
||||
),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Constant(foo_const_id.clone()),
|
||||
)),
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedConstantSymbolDeclaration::new(
|
||||
main_baz_const_id.clone(),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
TypedExpression::Array(
|
||||
ArrayExpressionInner::Identifier(main_bar_const_id.into())
|
||||
.annotate(Type::FieldElement, main_foo_const_id.clone()),
|
||||
),
|
||||
DeclarationType::Array(DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Constant(foo_const_id.clone()),
|
||||
)),
|
||||
)),
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
FieldElementExpression::Identifier(Identifier::from(
|
||||
main_foo_const_id.clone(),
|
||||
))
|
||||
.into(),
|
||||
])],
|
||||
signature: DeclarationSignature::new()
|
||||
.inputs(vec![])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
}),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
};
|
||||
|
||||
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
modules: vec![
|
||||
("main".into(), expected_main_module),
|
||||
("foo".into(), foo_module),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
assert_eq!(program, expected_program)
|
||||
}
|
||||
}
|
|
@ -126,7 +126,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
|
||||
fn fold_declaration_parameter(
|
||||
&mut self,
|
||||
p: typed_absy::DeclarationParameter<'ast>,
|
||||
p: typed_absy::DeclarationParameter<'ast, T>,
|
||||
) -> Vec<zir::Parameter<'ast>> {
|
||||
let private = p.private;
|
||||
self.fold_variable(crate::typed_absy::variable::try_from_g_variable(p.id).unwrap())
|
||||
|
@ -1093,7 +1093,7 @@ fn fold_function<'ast, T: Field>(
|
|||
statements: main_statements_buffer,
|
||||
signature: typed_absy::types::ConcreteSignature::try_from(
|
||||
crate::typed_absy::types::try_from_g_signature::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
crate::typed_absy::types::DeclarationConstant<'ast, T>,
|
||||
crate::typed_absy::UExpression<'ast, T>,
|
||||
>(fun.signature)
|
||||
.unwrap(),
|
||||
|
@ -1139,14 +1139,12 @@ fn fold_program<'ast, T: Field>(
|
|||
let main_module = p.modules.remove(&p.main).unwrap();
|
||||
|
||||
let main_function = main_module
|
||||
.functions
|
||||
.into_iter()
|
||||
.find(|(key, _)| key.id == "main")
|
||||
.into_functions_iter()
|
||||
.find(|d| d.key.id == "main")
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
.symbol;
|
||||
let main_function = match main_function {
|
||||
typed_absy::TypedFunctionSymbol::Here(f) => f,
|
||||
typed_absy::TypedFunctionSymbol::Here(main) => main,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
mod branch_isolator;
|
||||
mod constant_argument_checker;
|
||||
mod constant_inliner;
|
||||
mod constant_resolver;
|
||||
mod flat_propagation;
|
||||
mod flatten_complex_types;
|
||||
mod out_of_bounds;
|
||||
|
@ -28,7 +28,7 @@ use self::unconstrained_vars::UnconstrainedVariableDetector;
|
|||
use self::variable_write_remover::VariableWriteRemover;
|
||||
use crate::compile::CompileConfig;
|
||||
use crate::ir::Prog;
|
||||
use crate::static_analysis::constant_inliner::ConstantInliner;
|
||||
use crate::static_analysis::constant_resolver::ConstantResolver;
|
||||
use crate::static_analysis::zir_propagation::ZirPropagator;
|
||||
use crate::typed_absy::{abi::Abi, TypedProgram};
|
||||
use crate::zir::ZirProgram;
|
||||
|
@ -48,17 +48,10 @@ pub enum Error {
|
|||
Propagation(self::propagation::Error),
|
||||
ZirPropagation(self::zir_propagation::Error),
|
||||
NonConstantArgument(self::constant_argument_checker::Error),
|
||||
ConstantInliner(self::constant_inliner::Error),
|
||||
UnconstrainedVariable(self::unconstrained_vars::Error),
|
||||
OutOfBounds(self::out_of_bounds::Error),
|
||||
}
|
||||
|
||||
impl From<constant_inliner::Error> for Error {
|
||||
fn from(e: self::constant_inliner::Error) -> Self {
|
||||
Error::ConstantInliner(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reducer::Error> for Error {
|
||||
fn from(e: self::reducer::Error) -> Self {
|
||||
Error::Reducer(e)
|
||||
|
@ -102,7 +95,6 @@ impl fmt::Display for Error {
|
|||
Error::Propagation(e) => write!(f, "{}", e),
|
||||
Error::ZirPropagation(e) => write!(f, "{}", e),
|
||||
Error::NonConstantArgument(e) => write!(f, "{}", e),
|
||||
Error::ConstantInliner(e) => write!(f, "{}", e),
|
||||
Error::UnconstrainedVariable(e) => write!(f, "{}", e),
|
||||
Error::OutOfBounds(e) => write!(f, "{}", e),
|
||||
}
|
||||
|
@ -113,7 +105,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
|
||||
// inline user-defined constants
|
||||
log::debug!("Static analyser: Inline constants");
|
||||
let r = ConstantInliner::inline(self).map_err(Error::from)?;
|
||||
let r = ConstantResolver::inline(self);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// isolate branches
|
||||
|
|
|
@ -16,7 +16,7 @@ use std::convert::{TryFrom, TryInto};
|
|||
use std::fmt;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type Constants<'ast, T> = HashMap<Identifier<'ast>, TypedExpression<'ast, T>>;
|
||||
pub type Constants<'ast, T> = HashMap<Identifier<'ast>, TypedExpression<'ast, T>>;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
|
@ -45,6 +45,7 @@ impl fmt::Display for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Propagator<'ast, 'a, T: Field> {
|
||||
// constants keeps track of constant expressions
|
||||
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
|
||||
|
@ -149,21 +150,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> Result<TypedModule<'ast, T>, Error> {
|
||||
Ok(TypedModule {
|
||||
functions: m
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| {
|
||||
if key.id == "main" {
|
||||
self.fold_function_symbol(fun).map(|f| (key, f))
|
||||
} else {
|
||||
Ok((key, fun))
|
||||
}
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
..m
|
||||
})
|
||||
fn fold_function_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedFunctionSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedFunctionSymbolDeclaration<'ast, T>, Error> {
|
||||
if s.key.id == "main" {
|
||||
let key = s.key;
|
||||
self.fold_function_symbol(s.symbol)
|
||||
.map(|f| TypedFunctionSymbolDeclaration { key, symbol: f })
|
||||
} else {
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_function(
|
||||
|
|
169
zokrates_core/src/static_analysis/reducer/constants_reader.rs
Normal file
169
zokrates_core/src/static_analysis/reducer/constants_reader.rs
Normal file
|
@ -0,0 +1,169 @@
|
|||
// given a (partial) map of values for program constants, replace where applicable constants by their value
|
||||
|
||||
use crate::static_analysis::reducer::ConstantDefinitions;
|
||||
use crate::typed_absy::{
|
||||
folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
|
||||
DeclarationConstant, Expr, FieldElementExpression, Identifier, StructExpression,
|
||||
StructExpressionInner, StructType, TypedProgram, TypedSymbolDeclaration, UBitwidth,
|
||||
UExpression, UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
pub struct ConstantsReader<'a, 'ast, T> {
|
||||
constants: &'a ConstantDefinitions<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'a, 'ast, T: Field> ConstantsReader<'a, 'ast, T> {
|
||||
pub fn with_constants(constants: &'a ConstantDefinitions<'ast, T>) -> Self {
|
||||
Self { constants }
|
||||
}
|
||||
|
||||
pub fn read_into_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
self.fold_program(p)
|
||||
}
|
||||
|
||||
pub fn read_into_symbol_declaration(
|
||||
&mut self,
|
||||
d: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> TypedSymbolDeclaration<'ast, T> {
|
||||
self.fold_symbol_declaration(d)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> {
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}) => {
|
||||
assert_eq!(version, 0);
|
||||
match self.constants.get(&c).cloned() {
|
||||
Some(v) => v.try_into().unwrap(),
|
||||
None => FieldElementExpression::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}) => {
|
||||
assert_eq!(version, 0);
|
||||
match self.constants.get(&c).cloned() {
|
||||
Some(v) => v.try_into().unwrap(),
|
||||
None => BooleanExpression::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}),
|
||||
}
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
ty: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> UExpressionInner<'ast, T> {
|
||||
match e {
|
||||
UExpressionInner::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}) => {
|
||||
assert_eq!(version, 0);
|
||||
|
||||
match self.constants.get(&c).cloned() {
|
||||
Some(v) => UExpression::try_from(v).unwrap().into_inner(),
|
||||
None => UExpressionInner::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}),
|
||||
}
|
||||
}
|
||||
e => fold_uint_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &ArrayType<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> ArrayExpressionInner<'ast, T> {
|
||||
match e {
|
||||
ArrayExpressionInner::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}) => {
|
||||
assert_eq!(version, 0);
|
||||
match self.constants.get(&c).cloned() {
|
||||
Some(v) => ArrayExpression::try_from(v).unwrap().into_inner(),
|
||||
None => ArrayExpressionInner::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}),
|
||||
}
|
||||
}
|
||||
e => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> StructExpressionInner<'ast, T> {
|
||||
match e {
|
||||
StructExpressionInner::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}) => {
|
||||
assert_eq!(version, 0);
|
||||
match self.constants.get(&c).cloned() {
|
||||
Some(v) => StructExpression::try_from(v).unwrap().into_inner(),
|
||||
None => StructExpressionInner::Identifier(Identifier {
|
||||
id: CoreIdentifier::Constant(c),
|
||||
version,
|
||||
}),
|
||||
}
|
||||
}
|
||||
e => fold_struct_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
c: DeclarationConstant<'ast, T>,
|
||||
) -> DeclarationConstant<'ast, T> {
|
||||
match c {
|
||||
DeclarationConstant::Constant(c) => {
|
||||
let c = self.fold_canonical_constant_identifier(c);
|
||||
|
||||
match self.constants.get(&c).cloned() {
|
||||
Some(e) => match UExpression::try_from(e).unwrap().into_inner() {
|
||||
UExpressionInner::Value(v) => DeclarationConstant::Concrete(v as u32),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
None => DeclarationConstant::Constant(c),
|
||||
}
|
||||
}
|
||||
c => fold_declaration_constant(self, c),
|
||||
}
|
||||
}
|
||||
}
|
163
zokrates_core/src/static_analysis/reducer/constants_writer.rs
Normal file
163
zokrates_core/src/static_analysis/reducer/constants_writer.rs
Normal file
|
@ -0,0 +1,163 @@
|
|||
// A folder to inline all constant definitions down to a single literal and register them in the state for later use.
|
||||
|
||||
use crate::static_analysis::reducer::{
|
||||
constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error,
|
||||
};
|
||||
use crate::typed_absy::{
|
||||
result_folder::*, types::ConcreteGenericsAssignment, OwnedTypedModuleId, TypedConstant,
|
||||
TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram,
|
||||
TypedSymbolDeclaration, UExpression,
|
||||
};
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ConstantsWriter<'ast, T> {
|
||||
treated: HashSet<OwnedTypedModuleId>,
|
||||
constants: ConstantDefinitions<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
program: TypedProgram<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ConstantsWriter<'ast, T> {
|
||||
pub fn with_program(program: TypedProgram<'ast, T>) -> Self {
|
||||
ConstantsWriter {
|
||||
constants: ConstantDefinitions::default(),
|
||||
location: program.main.clone(),
|
||||
treated: HashSet::default(),
|
||||
program,
|
||||
}
|
||||
}
|
||||
|
||||
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
let prev = self.location.clone();
|
||||
self.location = location;
|
||||
self.treated.insert(self.location.clone());
|
||||
prev
|
||||
}
|
||||
|
||||
fn treated(&self, id: &TypedModuleId) -> bool {
|
||||
self.treated.contains(id)
|
||||
}
|
||||
|
||||
fn update_program(&mut self) {
|
||||
let mut p = TypedProgram {
|
||||
main: "".into(),
|
||||
modules: BTreeMap::default(),
|
||||
};
|
||||
std::mem::swap(&mut self.program, &mut p);
|
||||
self.program = ConstantsReader::with_constants(&self.constants).read_into_program(p);
|
||||
}
|
||||
|
||||
fn update_symbol_declaration(
|
||||
&self,
|
||||
d: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> TypedSymbolDeclaration<'ast, T> {
|
||||
ConstantsReader::with_constants(&self.constants).read_into_symbol_declaration(d)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_module_id(
|
||||
&mut self,
|
||||
id: OwnedTypedModuleId,
|
||||
) -> Result<OwnedTypedModuleId, Self::Error> {
|
||||
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
|
||||
if !self.treated(&id) {
|
||||
let current_m_id = self.change_location(id.clone());
|
||||
// I did not find a way to achieve this without cloning the module. Assuming we do not clone:
|
||||
// to fold the module, we need to consume it, so it gets removed from the modules
|
||||
// but to inline the calls while folding the module, all modules must be present
|
||||
// therefore we clone...
|
||||
// this does not lead to a module being folded more than once, as the first time
|
||||
// we change location to this module, it's added to the `treated` set
|
||||
let m = self.program.modules.get(&id).cloned().unwrap();
|
||||
let m = self.fold_module(m)?;
|
||||
self.program.modules.insert(id.clone(), m);
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn fold_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedSymbolDeclaration<'ast, T>, Self::Error> {
|
||||
// before we treat the symbol, propagate the constants into it, as may be using constants defined earlier in this module.
|
||||
let s = self.update_symbol_declaration(s);
|
||||
|
||||
let s = fold_symbol_declaration(self, s)?;
|
||||
|
||||
// after we treat the symbol, propagate again, as treating this symbol may have triggered checking another module, resolving new constants which this symbol may be using.
|
||||
Ok(self.update_symbol_declaration(s))
|
||||
}
|
||||
|
||||
fn fold_constant_symbol_declaration(
|
||||
&mut self,
|
||||
d: TypedConstantSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedConstantSymbolDeclaration<'ast, T>, Self::Error> {
|
||||
let id = self.fold_canonical_constant_identifier(d.id)?;
|
||||
|
||||
match d.symbol {
|
||||
TypedConstantSymbol::Here(c) => {
|
||||
let c = self.fold_constant(c)?;
|
||||
|
||||
use crate::typed_absy::{DeclarationSignature, TypedFunction, TypedStatement};
|
||||
|
||||
// wrap this expression in a function
|
||||
let wrapper = TypedFunction {
|
||||
arguments: vec![],
|
||||
statements: vec![TypedStatement::Return(vec![c.expression])],
|
||||
signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]),
|
||||
};
|
||||
|
||||
let mut inlined_wrapper = reduce_function(
|
||||
wrapper,
|
||||
ConcreteGenericsAssignment::default(),
|
||||
&self.program,
|
||||
)?;
|
||||
|
||||
if let TypedStatement::Return(mut expressions) =
|
||||
inlined_wrapper.statements.pop().unwrap()
|
||||
{
|
||||
assert_eq!(expressions.len(), 1);
|
||||
let constant_expression = expressions.pop().unwrap();
|
||||
|
||||
use crate::typed_absy::Constant;
|
||||
if !constant_expression.is_constant() {
|
||||
return Err(Error::ConstantReduction(id.id.to_string(), id.module));
|
||||
};
|
||||
|
||||
use crate::typed_absy::Typed;
|
||||
if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(
|
||||
c.ty.clone(),
|
||||
)
|
||||
.unwrap()
|
||||
== constant_expression.get_type()
|
||||
{
|
||||
// add to the constant map
|
||||
self.constants
|
||||
.insert(id.clone(), constant_expression.clone());
|
||||
|
||||
// after we reduced a constant, propagate it through the whole program
|
||||
self.update_program();
|
||||
|
||||
Ok(TypedConstantSymbolDeclaration {
|
||||
id,
|
||||
symbol: TypedConstantSymbol::Here(TypedConstant {
|
||||
expression: constant_expression,
|
||||
ty: c.ty,
|
||||
}),
|
||||
})
|
||||
} else {
|
||||
Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, c.ty)))
|
||||
}
|
||||
} else {
|
||||
Err(Error::ConstantReduction(id.id.to_string(), id.module))
|
||||
}
|
||||
}
|
||||
_ => unreachable!("all constants should be local"),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -35,13 +35,13 @@ use crate::typed_absy::Identifier;
|
|||
use crate::typed_absy::TypedAssignee;
|
||||
use crate::typed_absy::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
|
||||
Signature, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, Types,
|
||||
UExpression, UExpressionInner, Variable,
|
||||
Signature, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedProgram,
|
||||
TypedStatement, Types, UExpression, UExpressionInner, Variable,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub enum InlineError<'ast, T> {
|
||||
Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>),
|
||||
Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>),
|
||||
Flat(
|
||||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
|
@ -49,7 +49,7 @@ pub enum InlineError<'ast, T> {
|
|||
Types<'ast, T>,
|
||||
),
|
||||
NonConstant(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Types<'ast, T>,
|
||||
|
@ -57,20 +57,20 @@ pub enum InlineError<'ast, T> {
|
|||
}
|
||||
|
||||
fn get_canonical_function<'ast, T: Field>(
|
||||
function_key: DeclarationFunctionKey<'ast>,
|
||||
function_key: DeclarationFunctionKey<'ast, T>,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) {
|
||||
match program
|
||||
) -> TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
let s = program
|
||||
.modules
|
||||
.get(&function_key.module)
|
||||
.unwrap()
|
||||
.functions
|
||||
.iter()
|
||||
.find(|(key, _)| function_key == **key)
|
||||
.unwrap()
|
||||
{
|
||||
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program),
|
||||
(key, s) => (key.clone(), s.clone()),
|
||||
.functions_iter()
|
||||
.find(|d| d.key == function_key)
|
||||
.unwrap();
|
||||
|
||||
match &s.symbol {
|
||||
TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), &program),
|
||||
_ => s.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,7 +80,7 @@ type InlineResult<'ast, T> = Result<
|
|||
>;
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
||||
k: DeclarationFunctionKey<'ast>,
|
||||
k: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output: &E::Ty,
|
||||
|
@ -134,7 +134,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
}
|
||||
};
|
||||
|
||||
let (decl_key, symbol) = get_canonical_function(k.clone(), program);
|
||||
let decl = get_canonical_function(k.clone(), program);
|
||||
|
||||
// get an assignment of generics for this call site
|
||||
let assignment: ConcreteGenericsAssignment<'ast> = k
|
||||
|
@ -144,18 +144,18 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
InlineError::Generic(
|
||||
k.clone(),
|
||||
ConcreteFunctionKey {
|
||||
module: decl_key.module.clone(),
|
||||
id: decl_key.id,
|
||||
module: decl.key.module.clone(),
|
||||
id: decl.key.id,
|
||||
signature: inferred_signature.clone(),
|
||||
},
|
||||
)
|
||||
})?;
|
||||
|
||||
let f = match symbol {
|
||||
let f = match decl.symbol {
|
||||
TypedFunctionSymbol::Here(f) => Ok(f),
|
||||
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
|
||||
e,
|
||||
e.generics(&assignment),
|
||||
e.generics::<T>(&assignment),
|
||||
arguments.clone(),
|
||||
output_types,
|
||||
)),
|
||||
|
@ -169,7 +169,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
|
|||
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
|
||||
};
|
||||
|
||||
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone());
|
||||
let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone());
|
||||
|
||||
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
|
||||
.arguments
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
// - unroll loops
|
||||
// - inline function calls. This includes applying shallow-ssa on the target function
|
||||
|
||||
mod constants_reader;
|
||||
mod constants_writer;
|
||||
mod inline;
|
||||
mod shallow_ssa;
|
||||
|
||||
|
@ -18,26 +20,33 @@ use self::inline::{inline_call, InlineError};
|
|||
use crate::typed_absy::result_folder::*;
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
use crate::typed_absy::types::GGenericsAssignment;
|
||||
use crate::typed_absy::CanonicalConstantIdentifier;
|
||||
use crate::typed_absy::Folder;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::typed_absy::{
|
||||
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, TypedExpression,
|
||||
TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule,
|
||||
TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable,
|
||||
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId,
|
||||
TypedExpression, TypedExpressionList, TypedExpressionListInner, TypedFunction,
|
||||
TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, TypedStatement,
|
||||
UExpression, UExpressionInner, Variable,
|
||||
};
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
use self::constants_writer::ConstantsWriter;
|
||||
use self::shallow_ssa::ShallowTransformer;
|
||||
|
||||
use crate::static_analysis::Propagator;
|
||||
use crate::static_analysis::propagation::{Constants, Propagator};
|
||||
|
||||
use std::fmt;
|
||||
|
||||
const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
|
||||
|
||||
// A map to register the canonical value of all constants. The values must be literals.
|
||||
pub type ConstantDefinitions<'ast, T> =
|
||||
HashMap<CanonicalConstantIdentifier<'ast>, TypedExpression<'ast, T>>;
|
||||
|
||||
// An SSA version map, giving access to the latest version number for each identifier
|
||||
pub type Versions<'ast> = HashMap<CoreIdentifier<'ast>, usize>;
|
||||
|
||||
|
@ -55,6 +64,8 @@ pub enum Error {
|
|||
// TODO: give more details about what's blocking the progress
|
||||
NoProgress,
|
||||
LoopTooLarge(u128),
|
||||
ConstantReduction(String, OwnedTypedModuleId),
|
||||
Type(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
|
@ -68,6 +79,8 @@ impl fmt::Display for Error {
|
|||
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
|
||||
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
|
||||
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
|
||||
Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()),
|
||||
Error::Type(message) => write!(f, "{}", message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -159,6 +172,7 @@ fn register<'ast>(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Reducer<'ast, 'a, T> {
|
||||
statement_buffer: Vec<TypedStatement<'ast, T>>,
|
||||
for_loop_versions: Vec<Versions<'ast>>,
|
||||
|
@ -304,6 +318,13 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_canonical_constant_identifier(
|
||||
&mut self,
|
||||
_: CanonicalConstantIdentifier<'ast>,
|
||||
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
|
||||
unreachable!("canonical constant identifiers should not be folded, they should be inlined")
|
||||
}
|
||||
|
||||
fn fold_statement(
|
||||
&mut self,
|
||||
s: TypedStatement<'ast, T>,
|
||||
|
@ -487,15 +508,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
}
|
||||
|
||||
pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
// inline all constants and replace them in the program
|
||||
|
||||
let mut constants_writer = ConstantsWriter::with_program(p.clone());
|
||||
|
||||
let p = constants_writer.fold_program(p)?;
|
||||
|
||||
// inline starting from main
|
||||
let main_module = p.modules.get(&p.main).unwrap().clone();
|
||||
|
||||
let (main_key, main_function) = main_module
|
||||
.functions
|
||||
.iter()
|
||||
.find(|(k, _)| k.id == "main")
|
||||
let decl = main_module
|
||||
.functions_iter()
|
||||
.find(|d| d.key.id == "main")
|
||||
.unwrap();
|
||||
|
||||
let main_function = match main_function {
|
||||
let main_function = match &decl.symbol {
|
||||
TypedFunctionSymbol::Here(f) => f.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
@ -509,13 +536,11 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
|
|||
modules: vec![(
|
||||
p.main.clone(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
main_key.clone(),
|
||||
symbols: vec![TypedFunctionSymbolDeclaration::new(
|
||||
decl.key.clone(),
|
||||
TypedFunctionSymbol::Here(main_function),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into()],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -533,7 +558,9 @@ fn reduce_function<'ast, T: Field>(
|
|||
) -> Result<TypedFunction<'ast, T>, Error> {
|
||||
let mut versions = Versions::default();
|
||||
|
||||
match ShallowTransformer::transform(f, &generics, &mut versions) {
|
||||
let mut constants = Constants::default();
|
||||
|
||||
let f = match ShallowTransformer::transform(f, &generics, &mut versions) {
|
||||
Output::Complete(f) => Ok(f),
|
||||
Output::Incomplete(new_f, new_for_loop_versions) => {
|
||||
let mut for_loop_versions = new_for_loop_versions;
|
||||
|
@ -542,8 +569,6 @@ fn reduce_function<'ast, T: Field>(
|
|||
|
||||
let mut substitutions = Substitutions::default();
|
||||
|
||||
let mut constants: HashMap<Identifier<'ast>, TypedExpression<'ast, T>> = HashMap::new();
|
||||
|
||||
let mut hash = None;
|
||||
|
||||
loop {
|
||||
|
@ -600,7 +625,11 @@ fn reduce_function<'ast, T: Field>(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}?;
|
||||
|
||||
Propagator::with_constants(&mut constants)
|
||||
.fold_function(f)
|
||||
.map_err(|e| Error::Incompatible(format!("{}", e)))
|
||||
}
|
||||
|
||||
fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
|
||||
|
@ -710,27 +739,26 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
symbols: vec![
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "foo").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -786,17 +814,15 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
symbols: vec![TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into()],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -913,24 +939,23 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
symbols: vec![
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -1005,17 +1030,15 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
symbols: vec![TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into()],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -1141,24 +1164,23 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
symbols: vec![
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -1233,17 +1255,15 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
symbols: vec![TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into()],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -1399,25 +1419,25 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
symbols: vec![
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "bar")
|
||||
.signature(bar_signature.clone()),
|
||||
TypedFunctionSymbol::Here(bar),
|
||||
),
|
||||
(
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main"),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -1459,14 +1479,12 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
symbols: vec![TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main")
|
||||
.signature(DeclarationSignature::new()),
|
||||
TypedFunctionSymbol::Here(expected_main),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into()],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
@ -1540,22 +1558,21 @@ mod tests {
|
|||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![
|
||||
(
|
||||
symbols: vec![
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "foo")
|
||||
.signature(foo_signature.clone()),
|
||||
TypedFunctionSymbol::Here(foo),
|
||||
),
|
||||
(
|
||||
)
|
||||
.into(),
|
||||
TypedFunctionSymbolDeclaration::new(
|
||||
DeclarationFunctionKey::with_location("main", "main").signature(
|
||||
DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
|
||||
),
|
||||
TypedFunctionSymbol::Here(main),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
constants: Default::default(),
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
|
|
|
@ -35,15 +35,15 @@ mod tests {
|
|||
};
|
||||
use crate::typed_absy::{
|
||||
parameter::DeclarationParameter, variable::DeclarationVariable, ConcreteType,
|
||||
TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram,
|
||||
TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule,
|
||||
TypedProgram,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn generate_abi_from_typed_ast() {
|
||||
let mut functions = HashMap::new();
|
||||
functions.insert(
|
||||
let symbols = vec![TypedFunctionSymbolDeclaration::new(
|
||||
ConcreteFunctionKey::with_location("main", "main").into(),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![
|
||||
|
@ -62,16 +62,11 @@ mod tests {
|
|||
.outputs(vec![ConcreteType::FieldElement])
|
||||
.into(),
|
||||
}),
|
||||
);
|
||||
)
|
||||
.into()];
|
||||
|
||||
let mut modules = HashMap::new();
|
||||
modules.insert(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions,
|
||||
constants: Default::default(),
|
||||
},
|
||||
);
|
||||
let mut modules = BTreeMap::new();
|
||||
modules.insert("main".into(), TypedModule { symbols });
|
||||
|
||||
let typed_ast: TypedProgram<Bn128Field> = TypedProgram {
|
||||
main: "main".into(),
|
||||
|
|
|
@ -47,6 +47,27 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_module(self, m)
|
||||
}
|
||||
|
||||
fn fold_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> TypedSymbolDeclaration<'ast, T> {
|
||||
fold_symbol_declaration(self, s)
|
||||
}
|
||||
|
||||
fn fold_function_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedFunctionSymbolDeclaration<'ast, T>,
|
||||
) -> TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
fold_function_symbol_declaration(self, s)
|
||||
}
|
||||
|
||||
fn fold_constant_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedConstantSymbolDeclaration<'ast, T>,
|
||||
) -> TypedConstantSymbolDeclaration<'ast, T> {
|
||||
fold_constant_symbol_declaration(self, s)
|
||||
}
|
||||
|
||||
fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> {
|
||||
fold_constant(self, c)
|
||||
}
|
||||
|
@ -67,8 +88,8 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_function_key(
|
||||
&mut self,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> DeclarationFunctionKey<'ast> {
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
) -> DeclarationFunctionKey<'ast, T> {
|
||||
fold_declaration_function_key(self, key)
|
||||
}
|
||||
|
||||
|
@ -76,18 +97,24 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
|
||||
fn fold_signature(
|
||||
&mut self,
|
||||
s: DeclarationSignature<'ast, T>,
|
||||
) -> DeclarationSignature<'ast, T> {
|
||||
fold_signature(self, s)
|
||||
}
|
||||
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> DeclarationConstant<'ast> {
|
||||
c: DeclarationConstant<'ast, T>,
|
||||
) -> DeclarationConstant<'ast, T> {
|
||||
fold_declaration_constant(self, c)
|
||||
}
|
||||
|
||||
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> DeclarationParameter<'ast, T> {
|
||||
DeclarationParameter {
|
||||
id: self.fold_declaration_variable(p.id),
|
||||
..p
|
||||
|
@ -107,8 +134,8 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_variable(
|
||||
&mut self,
|
||||
v: DeclarationVariable<'ast>,
|
||||
) -> DeclarationVariable<'ast> {
|
||||
v: DeclarationVariable<'ast, T>,
|
||||
) -> DeclarationVariable<'ast, T> {
|
||||
DeclarationVariable {
|
||||
id: self.fold_name(v.id),
|
||||
_type: self.fold_declaration_type(v._type),
|
||||
|
@ -155,7 +182,7 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
|
||||
fn fold_declaration_type(&mut self, t: DeclarationType<'ast, T>) -> DeclarationType<'ast, T> {
|
||||
use self::GType::*;
|
||||
|
||||
match t {
|
||||
|
@ -167,8 +194,8 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_array_type(
|
||||
&mut self,
|
||||
t: DeclarationArrayType<'ast>,
|
||||
) -> DeclarationArrayType<'ast> {
|
||||
t: DeclarationArrayType<'ast, T>,
|
||||
) -> DeclarationArrayType<'ast, T> {
|
||||
DeclarationArrayType {
|
||||
ty: box self.fold_declaration_type(*t.ty),
|
||||
size: self.fold_declaration_constant(t.size),
|
||||
|
@ -177,8 +204,8 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_struct_type(
|
||||
&mut self,
|
||||
t: DeclarationStructType<'ast>,
|
||||
) -> DeclarationStructType<'ast> {
|
||||
t: DeclarationStructType<'ast, T>,
|
||||
) -> DeclarationStructType<'ast, T> {
|
||||
DeclarationStructType {
|
||||
generics: t
|
||||
.generics
|
||||
|
@ -232,7 +259,6 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(i.module),
|
||||
id: i.id,
|
||||
ty: box self.fold_declaration_type(*i.ty),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -378,29 +404,48 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
m: TypedModule<'ast, T>,
|
||||
) -> TypedModule<'ast, T> {
|
||||
TypedModule {
|
||||
constants: m
|
||||
.constants
|
||||
symbols: m
|
||||
.symbols
|
||||
.into_iter()
|
||||
.map(|(id, tc)| {
|
||||
(
|
||||
f.fold_canonical_constant_identifier(id),
|
||||
f.fold_constant_symbol(tc),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
functions: m
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| {
|
||||
(
|
||||
f.fold_declaration_function_key(key),
|
||||
f.fold_function_symbol(fun),
|
||||
)
|
||||
})
|
||||
.map(|s| f.fold_symbol_declaration(s))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
d: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> TypedSymbolDeclaration<'ast, T> {
|
||||
match d {
|
||||
TypedSymbolDeclaration::Function(d) => {
|
||||
TypedSymbolDeclaration::Function(f.fold_function_symbol_declaration(d))
|
||||
}
|
||||
TypedSymbolDeclaration::Constant(d) => {
|
||||
TypedSymbolDeclaration::Constant(f.fold_constant_symbol_declaration(d))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_function_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
d: TypedFunctionSymbolDeclaration<'ast, T>,
|
||||
) -> TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
TypedFunctionSymbolDeclaration {
|
||||
key: f.fold_declaration_function_key(d.key),
|
||||
symbol: f.fold_function_symbol(d.symbol),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_constant_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
d: TypedConstantSymbolDeclaration<'ast, T>,
|
||||
) -> TypedConstantSymbolDeclaration<'ast, T> {
|
||||
TypedConstantSymbolDeclaration {
|
||||
id: f.fold_canonical_constant_identifier(d.id),
|
||||
symbol: f.fold_constant_symbol(d.symbol),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: TypedStatement<'ast, T>,
|
||||
|
@ -902,8 +947,8 @@ pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T
|
|||
|
||||
pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> DeclarationFunctionKey<'ast> {
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
) -> DeclarationFunctionKey<'ast, T> {
|
||||
DeclarationFunctionKey {
|
||||
module: f.fold_module_id(key.module),
|
||||
signature: f.fold_signature(key.signature),
|
||||
|
@ -955,8 +1000,8 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: DeclarationSignature<'ast>,
|
||||
) -> DeclarationSignature<'ast> {
|
||||
s: DeclarationSignature<'ast, T>,
|
||||
) -> DeclarationSignature<'ast, T> {
|
||||
DeclarationSignature {
|
||||
generics: s.generics,
|
||||
inputs: s
|
||||
|
@ -972,11 +1017,14 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
_: &mut F,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> DeclarationConstant<'ast> {
|
||||
c
|
||||
pub fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
c: DeclarationConstant<'ast, T>,
|
||||
) -> DeclarationConstant<'ast, T> {
|
||||
match c {
|
||||
DeclarationConstant::Expression(e) => DeclarationConstant::Expression(f.fold_expression(e)),
|
||||
c => c,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
|
@ -1056,6 +1104,7 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> TypedConstant<'ast, T> {
|
||||
TypedConstant {
|
||||
expression: f.fold_expression(c.expression),
|
||||
ty: f.fold_declaration_type(c.ty),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use crate::typed_absy::CanonicalConstantIdentifier;
|
||||
use std::convert::TryInto;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum CoreIdentifier<'ast> {
|
||||
Source(&'ast str),
|
||||
Call(usize),
|
||||
Constant(CanonicalConstantIdentifier<'ast>),
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for CoreIdentifier<'ast> {
|
||||
|
@ -12,6 +14,7 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
|
|||
match self {
|
||||
CoreIdentifier::Source(s) => write!(f, "{}", s),
|
||||
CoreIdentifier::Call(i) => write!(f, "#CALL_RETURN_AT_INDEX_{}", i),
|
||||
CoreIdentifier::Constant(c) => write!(f, "{}/{}", c.module.display(), c.id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,8 +25,14 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for CoreIdentifier<'ast> {
|
||||
fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> {
|
||||
CoreIdentifier::Constant(s)
|
||||
}
|
||||
}
|
||||
|
||||
/// A identifier for a variable
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct Identifier<'ast> {
|
||||
/// the id of the variable
|
||||
pub id: CoreIdentifier<'ast>,
|
||||
|
@ -52,6 +61,12 @@ impl<'ast> fmt::Display for Identifier<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for Identifier<'ast> {
|
||||
fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> {
|
||||
Identifier::from(CoreIdentifier::Constant(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<&'ast str> for Identifier<'ast> {
|
||||
fn from(id: &'ast str) -> Identifier<'ast> {
|
||||
Identifier::from(CoreIdentifier::Source(id))
|
||||
|
|
|
@ -40,7 +40,7 @@ trait IntegerInference: Sized {
|
|||
}
|
||||
|
||||
impl<'ast, T> IntegerInference for Type<'ast, T> {
|
||||
type Pattern = DeclarationType<'ast>;
|
||||
type Pattern = DeclarationType<'ast, T>;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
|
||||
match (self, other) {
|
||||
|
@ -72,7 +72,7 @@ impl<'ast, T> IntegerInference for Type<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T> IntegerInference for ArrayType<'ast, T> {
|
||||
type Pattern = DeclarationArrayType<'ast>;
|
||||
type Pattern = DeclarationArrayType<'ast, T>;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
|
||||
let s0 = self.size;
|
||||
|
@ -88,7 +88,7 @@ impl<'ast, T> IntegerInference for ArrayType<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T> IntegerInference for StructType<'ast, T> {
|
||||
type Pattern = DeclarationStructType<'ast>;
|
||||
type Pattern = DeclarationStructType<'ast, T>;
|
||||
|
||||
fn get_common_pattern(self, other: Self) -> Result<Self::Pattern, (Self, Self)> {
|
||||
Ok(DeclarationStructType {
|
||||
|
@ -228,7 +228,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
|
||||
pub enum IntExpression<'ast, T> {
|
||||
Value(BigUint),
|
||||
Pos(Box<IntExpression<'ast, T>>),
|
||||
|
@ -424,7 +424,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
|||
v,
|
||||
&DeclarationArrayType::new(
|
||||
DeclarationType::FieldElement,
|
||||
DeclarationConstant::Concrete(0),
|
||||
DeclarationConstant::from(0u32),
|
||||
),
|
||||
)
|
||||
.map_err(|(e, _)| match e {
|
||||
|
@ -542,7 +542,7 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
v,
|
||||
&DeclarationArrayType::new(
|
||||
DeclarationType::Uint(*bitwidth),
|
||||
DeclarationConstant::Concrete(0),
|
||||
DeclarationConstant::from(0u32),
|
||||
),
|
||||
)
|
||||
.map_err(|(e, _)| match e {
|
||||
|
|
|
@ -20,9 +20,9 @@ pub use self::identifier::CoreIdentifier;
|
|||
pub use self::parameter::{DeclarationParameter, GParameter};
|
||||
pub use self::types::{
|
||||
CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType,
|
||||
ConstantIdentifier, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
|
||||
DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
|
||||
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
|
||||
ConstantIdentifier, DeclarationArrayType, DeclarationConstant, DeclarationFunctionKey,
|
||||
DeclarationSignature, DeclarationStructType, DeclarationType, GArrayType, GStructType, GType,
|
||||
GenericIdentifier, IntoTypes, Signature, StructType, Type, Types, UBitwidth,
|
||||
};
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
|
||||
|
@ -35,7 +35,7 @@ pub use crate::typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMeta
|
|||
|
||||
use crate::embed::FlatEmbed;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::fmt;
|
||||
|
||||
|
@ -54,14 +54,14 @@ pub type OwnedTypedModuleId = PathBuf;
|
|||
pub type TypedModuleId = Path;
|
||||
|
||||
/// A collection of `TypedModule`s
|
||||
pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T>>;
|
||||
pub type TypedModules<'ast, T> = BTreeMap<OwnedTypedModuleId, TypedModule<'ast, T>>;
|
||||
|
||||
/// A collection of `TypedFunctionSymbol`s
|
||||
/// # Remarks
|
||||
/// * It is the role of the semantic checker to make sure there are no duplicates for a given `FunctionKey`
|
||||
/// in a given `TypedModule`, hence the use of a HashMap
|
||||
/// in a given `TypedModule`, hence the use of a BTreeMap
|
||||
pub type TypedFunctionSymbols<'ast, T> =
|
||||
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
|
||||
BTreeMap<DeclarationFunctionKey<'ast, T>, TypedFunctionSymbol<'ast, T>>;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum TypedConstantSymbol<'ast, T> {
|
||||
|
@ -91,12 +91,11 @@ impl<'ast, T> TypedProgram<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> TypedProgram<'ast, T> {
|
||||
pub fn abi(&self) -> Abi {
|
||||
let main = self.modules[&self.main]
|
||||
.functions
|
||||
.iter()
|
||||
.find(|(id, _)| id.id == "main")
|
||||
let main = &self.modules[&self.main]
|
||||
.functions_iter()
|
||||
.find(|s| s.key.id == "main")
|
||||
.unwrap()
|
||||
.1;
|
||||
.symbol;
|
||||
let main = match main {
|
||||
TypedFunctionSymbol::Here(main) => main,
|
||||
_ => unreachable!(),
|
||||
|
@ -109,7 +108,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
.map(|p| {
|
||||
types::ConcreteType::try_from(
|
||||
crate::typed_absy::types::try_from_g_type::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
DeclarationConstant<'ast, T>,
|
||||
UExpression<'ast, T>,
|
||||
>(p.id._type.clone())
|
||||
.unwrap(),
|
||||
|
@ -129,7 +128,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
.map(|ty| {
|
||||
types::ConcreteType::try_from(
|
||||
crate::typed_absy::types::try_from_g_type::<
|
||||
crate::typed_absy::types::DeclarationConstant<'ast>,
|
||||
DeclarationConstant<'ast, T>,
|
||||
UExpression<'ast, T>,
|
||||
>(ty.clone())
|
||||
.unwrap(),
|
||||
|
@ -163,19 +162,90 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub struct TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
pub key: DeclarationFunctionKey<'ast, T>,
|
||||
pub symbol: TypedFunctionSymbol<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
pub fn new(key: DeclarationFunctionKey<'ast, T>, symbol: TypedFunctionSymbol<'ast, T>) -> Self {
|
||||
TypedFunctionSymbolDeclaration { key, symbol }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub struct TypedConstantSymbolDeclaration<'ast, T> {
|
||||
pub id: CanonicalConstantIdentifier<'ast>,
|
||||
pub symbol: TypedConstantSymbol<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedConstantSymbolDeclaration<'ast, T> {
|
||||
pub fn new(
|
||||
id: CanonicalConstantIdentifier<'ast>,
|
||||
symbol: TypedConstantSymbol<'ast, T>,
|
||||
) -> Self {
|
||||
TypedConstantSymbolDeclaration { id, symbol }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub enum TypedSymbolDeclaration<'ast, T> {
|
||||
Function(TypedFunctionSymbolDeclaration<'ast, T>),
|
||||
Constant(TypedConstantSymbolDeclaration<'ast, T>),
|
||||
}
|
||||
|
||||
impl<'ast, T> From<TypedFunctionSymbolDeclaration<'ast, T>> for TypedSymbolDeclaration<'ast, T> {
|
||||
fn from(d: TypedFunctionSymbolDeclaration<'ast, T>) -> Self {
|
||||
Self::Function(d)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<TypedConstantSymbolDeclaration<'ast, T>> for TypedSymbolDeclaration<'ast, T> {
|
||||
fn from(d: TypedConstantSymbolDeclaration<'ast, T>) -> Self {
|
||||
Self::Constant(d)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedSymbolDeclaration<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
TypedSymbolDeclaration::Function(fun) => write!(f, "{}", fun),
|
||||
TypedSymbolDeclaration::Constant(c) => write!(f, "{}", c),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type TypedSymbolDeclarations<'ast, T> = Vec<TypedSymbolDeclaration<'ast, T>>;
|
||||
|
||||
/// A typed module as a collection of functions. Types have been resolved during semantic checking.
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub struct TypedModule<'ast, T> {
|
||||
/// Functions of the module
|
||||
pub functions: TypedFunctionSymbols<'ast, T>,
|
||||
/// Constants defined in module
|
||||
pub constants: TypedConstantSymbols<'ast, T>,
|
||||
pub symbols: TypedSymbolDeclarations<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedModule<'ast, T> {
|
||||
pub fn functions_iter(&self) -> impl Iterator<Item = &TypedFunctionSymbolDeclaration<'ast, T>> {
|
||||
self.symbols.iter().filter_map(|s| match s {
|
||||
TypedSymbolDeclaration::Function(d) => Some(d),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_functions_iter(
|
||||
self,
|
||||
) -> impl Iterator<Item = TypedFunctionSymbolDeclaration<'ast, T>> {
|
||||
self.symbols.into_iter().filter_map(|s| match s {
|
||||
TypedSymbolDeclaration::Function(d) => Some(d),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum TypedFunctionSymbol<'ast, T> {
|
||||
Here(TypedFunction<'ast, T>),
|
||||
There(DeclarationFunctionKey<'ast>),
|
||||
There(DeclarationFunctionKey<'ast, T>),
|
||||
Flat(FlatEmbed),
|
||||
}
|
||||
|
||||
|
@ -183,17 +253,61 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
|
|||
pub fn signature<'a>(
|
||||
&'a self,
|
||||
modules: &'a TypedModules<'ast, T>,
|
||||
) -> DeclarationSignature<'ast> {
|
||||
) -> DeclarationSignature<'ast, T> {
|
||||
match self {
|
||||
TypedFunctionSymbol::Here(f) => f.signature.clone(),
|
||||
TypedFunctionSymbol::There(key) => modules
|
||||
.get(&key.module)
|
||||
.unwrap()
|
||||
.functions
|
||||
.get(key)
|
||||
.functions_iter()
|
||||
.find(|d| d.key == *key)
|
||||
.unwrap()
|
||||
.symbol
|
||||
.signature(&modules),
|
||||
TypedFunctionSymbol::Flat(flat_fun) => flat_fun.signature(),
|
||||
TypedFunctionSymbol::Flat(flat_fun) => flat_fun.typed_signature(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedConstantSymbolDeclaration<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.symbol {
|
||||
TypedConstantSymbol::Here(ref tc) => {
|
||||
write!(f, "const {} {} = {}", tc.ty, self.id, tc.expression)
|
||||
}
|
||||
TypedConstantSymbol::There(ref imported_id) => {
|
||||
write!(
|
||||
f,
|
||||
"from \"{}\" import {} as {}",
|
||||
imported_id.module.display(),
|
||||
imported_id.id,
|
||||
self.id
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedFunctionSymbolDeclaration<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.symbol {
|
||||
TypedFunctionSymbol::Here(ref function) => write!(f, "def {}{}", self.key.id, function),
|
||||
TypedFunctionSymbol::There(ref fun_key) => write!(
|
||||
f,
|
||||
"from \"{}\" import {} as {} // with signature {}",
|
||||
fun_key.module.display(),
|
||||
fun_key.id,
|
||||
self.key.id,
|
||||
self.key.signature
|
||||
),
|
||||
TypedFunctionSymbol::Flat(ref flat_fun) => {
|
||||
write!(
|
||||
f,
|
||||
"def {}{}:\n\t// hidden",
|
||||
self.key.id,
|
||||
flat_fun.typed_signature::<T>()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -201,34 +315,9 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> {
|
|||
impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let res = self
|
||||
.constants
|
||||
.symbols
|
||||
.iter()
|
||||
.map(|(id, symbol)| match symbol {
|
||||
TypedConstantSymbol::Here(ref tc) => {
|
||||
format!("const {} {} = {}", id.ty, id.id, tc)
|
||||
}
|
||||
TypedConstantSymbol::There(ref imported_id) => {
|
||||
format!(
|
||||
"from \"{}\" import {} as {}",
|
||||
imported_id.module.display(),
|
||||
imported_id.id,
|
||||
id.id
|
||||
)
|
||||
}
|
||||
})
|
||||
.chain(self.functions.iter().map(|(key, symbol)| match symbol {
|
||||
TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function),
|
||||
TypedFunctionSymbol::There(ref fun_key) => format!(
|
||||
"from \"{}\" import {} as {} // with signature {}",
|
||||
fun_key.module.display(),
|
||||
fun_key.id,
|
||||
key.id,
|
||||
key.signature
|
||||
),
|
||||
TypedFunctionSymbol::Flat(ref flat_fun) => {
|
||||
format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature())
|
||||
}
|
||||
}))
|
||||
.map(|s| format!("{}", s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
write!(f, "{}", res.join("\n"))
|
||||
|
@ -239,11 +328,11 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
|
|||
#[derive(Clone, PartialEq, Debug, Hash)]
|
||||
pub struct TypedFunction<'ast, T> {
|
||||
/// Arguments of the function
|
||||
pub arguments: Vec<DeclarationParameter<'ast>>,
|
||||
pub arguments: Vec<DeclarationParameter<'ast, T>>,
|
||||
/// Vector of statements that are executed when running the function
|
||||
pub statements: Vec<TypedStatement<'ast, T>>,
|
||||
/// function signature
|
||||
pub signature: DeclarationSignature<'ast>,
|
||||
pub signature: DeclarationSignature<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
||||
|
@ -312,11 +401,12 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
|
|||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct TypedConstant<'ast, T> {
|
||||
pub expression: TypedExpression<'ast, T>,
|
||||
pub ty: DeclarationType<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, T> TypedConstant<'ast, T> {
|
||||
pub fn new(expression: TypedExpression<'ast, T>) -> Self {
|
||||
TypedConstant { expression }
|
||||
pub fn new(expression: TypedExpression<'ast, T>, ty: DeclarationType<'ast, T>) -> Self {
|
||||
TypedConstant { expression, ty }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -333,14 +423,14 @@ impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> {
|
|||
}
|
||||
|
||||
/// Something we can assign to.
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedAssignee<'ast, T> {
|
||||
Identifier(Variable<'ast, T>),
|
||||
Select(Box<TypedAssignee<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Member(Box<TypedAssignee<'ast, T>>, MemberId),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
|
||||
pub struct TypedSpread<'ast, T> {
|
||||
pub array: ArrayExpression<'ast, T>,
|
||||
}
|
||||
|
@ -357,7 +447,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedSpread<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedExpressionOrSpread<'ast, T> {
|
||||
Expression(TypedExpression<'ast, T>),
|
||||
Spread(TypedSpread<'ast, T>),
|
||||
|
@ -487,7 +577,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> {
|
|||
|
||||
/// A statement in a `TypedFunction`
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedStatement<'ast, T> {
|
||||
Return(Vec<TypedExpression<'ast, T>>),
|
||||
Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>),
|
||||
|
@ -502,7 +592,7 @@ pub enum TypedStatement<'ast, T> {
|
|||
MultipleDefinition(Vec<TypedAssignee<'ast, T>>, TypedExpressionList<'ast, T>),
|
||||
// Aux
|
||||
PushCallLog(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
DeclarationFunctionKey<'ast, T>,
|
||||
ConcreteGenericsAssignment<'ast>,
|
||||
),
|
||||
PopCallLog,
|
||||
|
@ -575,7 +665,7 @@ pub trait Typed<'ast, T> {
|
|||
|
||||
/// A typed expression
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedExpression<'ast, T> {
|
||||
Boolean(BooleanExpression<'ast, T>),
|
||||
FieldElement(FieldElementExpression<'ast, T>),
|
||||
|
@ -714,7 +804,7 @@ pub trait MultiTyped<'ast, T> {
|
|||
fn get_types(&self) -> &Vec<Type<'ast, T>>;
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
|
||||
pub struct TypedExpressionList<'ast, T> {
|
||||
pub inner: TypedExpressionListInner<'ast, T>,
|
||||
|
@ -727,7 +817,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum TypedExpressionListInner<'ast, T> {
|
||||
FunctionCall(FunctionCallExpression<'ast, T, TypedExpressionList<'ast, T>>),
|
||||
EmbedCall(FlatEmbed, Vec<u32>, Vec<TypedExpression<'ast, T>>),
|
||||
|
@ -744,7 +834,7 @@ impl<'ast, T> TypedExpressionListInner<'ast, T> {
|
|||
TypedExpressionList { inner: self, types }
|
||||
}
|
||||
}
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct BlockExpression<'ast, T, E> {
|
||||
pub statements: Vec<TypedStatement<'ast, T>>,
|
||||
pub value: Box<E>,
|
||||
|
@ -759,7 +849,7 @@ impl<'ast, T, E> BlockExpression<'ast, T, E> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct MemberExpression<'ast, T, E> {
|
||||
pub struc: Box<StructExpression<'ast, T>>,
|
||||
pub id: MemberId,
|
||||
|
@ -782,7 +872,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for MemberExpression<'ast, T, E> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct SelectExpression<'ast, T, E> {
|
||||
pub array: Box<ArrayExpression<'ast, T>>,
|
||||
pub index: Box<UExpression<'ast, T>>,
|
||||
|
@ -805,7 +895,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct IfElseExpression<'ast, T, E> {
|
||||
pub condition: Box<BooleanExpression<'ast, T>>,
|
||||
pub consequence: Box<E>,
|
||||
|
@ -832,9 +922,9 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<'
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct FunctionCallExpression<'ast, T, E> {
|
||||
pub function_key: DeclarationFunctionKey<'ast>,
|
||||
pub function_key: DeclarationFunctionKey<'ast, T>,
|
||||
pub generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
pub arguments: Vec<TypedExpression<'ast, T>>,
|
||||
ty: PhantomData<E>,
|
||||
|
@ -842,7 +932,7 @@ pub struct FunctionCallExpression<'ast, T, E> {
|
|||
|
||||
impl<'ast, T, E> FunctionCallExpression<'ast, T, E> {
|
||||
pub fn new(
|
||||
function_key: DeclarationFunctionKey<'ast>,
|
||||
function_key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self {
|
||||
|
@ -885,7 +975,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T,
|
|||
}
|
||||
|
||||
/// An expression of type `field`
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum FieldElementExpression<'ast, T> {
|
||||
Block(BlockExpression<'ast, T, Self>),
|
||||
Number(T),
|
||||
|
@ -962,7 +1052,7 @@ impl<'ast, T> From<T> for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
|
||||
/// An expression of type `bool`
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum BooleanExpression<'ast, T> {
|
||||
Block(BlockExpression<'ast, T, Self>),
|
||||
Identifier(Identifier<'ast>),
|
||||
|
@ -1027,13 +1117,13 @@ impl<'ast, T> From<bool> for BooleanExpression<'ast, T> {
|
|||
/// * Contrary to basic types which are represented as enums, we wrap an enum `ArrayExpressionInner` in a struct in order to keep track of the type (content and size)
|
||||
/// of the array. Only using an enum would require generics, which would propagate up to TypedExpression which we want to keep simple, hence this "runtime"
|
||||
/// type checking
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct ArrayExpression<'ast, T> {
|
||||
pub ty: Box<ArrayType<'ast, T>>,
|
||||
pub inner: ArrayExpressionInner<'ast, T>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
|
||||
pub struct ArrayValue<'ast, T>(pub Vec<TypedExpressionOrSpread<'ast, T>>);
|
||||
|
||||
impl<'ast, T> From<Vec<TypedExpressionOrSpread<'ast, T>>> for ArrayValue<'ast, T> {
|
||||
|
@ -1111,7 +1201,7 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>> for Arra
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum ArrayExpressionInner<'ast, T> {
|
||||
Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
|
@ -1151,7 +1241,7 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct StructExpression<'ast, T> {
|
||||
ty: StructType<'ast, T>,
|
||||
inner: StructExpressionInner<'ast, T>,
|
||||
|
@ -1175,7 +1265,7 @@ impl<'ast, T> StructExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub enum StructExpressionInner<'ast, T> {
|
||||
Block(BlockExpression<'ast, T, StructExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
|
@ -1897,7 +1987,7 @@ impl<'ast, T: Field> Id<'ast, T> for TypedExpressionList<'ast, T> {
|
|||
|
||||
pub trait FunctionCall<'ast, T>: Expr<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner;
|
||||
|
@ -1905,7 +1995,7 @@ pub trait FunctionCall<'ast, T>: Expr<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
|
@ -1915,7 +2005,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
|
@ -1925,7 +2015,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
|
@ -1935,7 +2025,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
|
@ -1945,7 +2035,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
|
@ -1955,7 +2045,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> {
|
|||
|
||||
impl<'ast, T: Field> FunctionCall<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
fn function_call(
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
) -> Self::Inner {
|
||||
|
|
|
@ -18,12 +18,12 @@ impl<'ast, S> From<GVariable<'ast, S>> for GParameter<'ast, S> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type DeclarationParameter<'ast> = GParameter<'ast, DeclarationConstant<'ast>>;
|
||||
pub type DeclarationParameter<'ast, T> = GParameter<'ast, DeclarationConstant<'ast, T>>;
|
||||
|
||||
impl<'ast, S: fmt::Display + Clone> fmt::Display for GParameter<'ast, S> {
|
||||
impl<'ast, S: fmt::Display> fmt::Display for GParameter<'ast, S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let visibility = if self.private { "private " } else { "" };
|
||||
write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id)
|
||||
write!(f, "{}{} {}", visibility, self.id._type, self.id.id)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_module(self, m)
|
||||
}
|
||||
|
||||
fn fold_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedSymbolDeclaration<'ast, T>, Self::Error> {
|
||||
fold_symbol_declaration(self, s)
|
||||
}
|
||||
|
||||
fn fold_function_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedFunctionSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedFunctionSymbolDeclaration<'ast, T>, Self::Error> {
|
||||
fold_function_symbol_declaration(self, s)
|
||||
}
|
||||
|
||||
fn fold_constant_symbol_declaration(
|
||||
&mut self,
|
||||
s: TypedConstantSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedConstantSymbolDeclaration<'ast, T>, Self::Error> {
|
||||
fold_constant_symbol_declaration(self, s)
|
||||
}
|
||||
|
||||
fn fold_constant(
|
||||
&mut self,
|
||||
c: TypedConstant<'ast, T>,
|
||||
|
@ -78,8 +99,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_function_key(
|
||||
&mut self,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> Result<DeclarationFunctionKey<'ast>, Self::Error> {
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
) -> Result<DeclarationFunctionKey<'ast, T>, Self::Error> {
|
||||
fold_declaration_function_key(self, key)
|
||||
}
|
||||
|
||||
|
@ -92,22 +113,22 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_signature(
|
||||
&mut self,
|
||||
s: DeclarationSignature<'ast>,
|
||||
) -> Result<DeclarationSignature<'ast>, Self::Error> {
|
||||
s: DeclarationSignature<'ast, T>,
|
||||
) -> Result<DeclarationSignature<'ast, T>, Self::Error> {
|
||||
fold_signature(self, s)
|
||||
}
|
||||
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> Result<DeclarationConstant<'ast>, Self::Error> {
|
||||
c: DeclarationConstant<'ast, T>,
|
||||
) -> Result<DeclarationConstant<'ast, T>, Self::Error> {
|
||||
fold_declaration_constant(self, c)
|
||||
}
|
||||
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast>,
|
||||
) -> Result<DeclarationParameter<'ast>, Self::Error> {
|
||||
p: DeclarationParameter<'ast, T>,
|
||||
) -> Result<DeclarationParameter<'ast, T>, Self::Error> {
|
||||
Ok(DeclarationParameter {
|
||||
id: self.fold_declaration_variable(p.id)?,
|
||||
..p
|
||||
|
@ -121,7 +142,6 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
Ok(CanonicalConstantIdentifier {
|
||||
module: self.fold_module_id(i.module)?,
|
||||
id: i.id,
|
||||
ty: box self.fold_declaration_type(*i.ty)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -142,8 +162,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_variable(
|
||||
&mut self,
|
||||
v: DeclarationVariable<'ast>,
|
||||
) -> Result<DeclarationVariable<'ast>, Self::Error> {
|
||||
v: DeclarationVariable<'ast, T>,
|
||||
) -> Result<DeclarationVariable<'ast, T>, Self::Error> {
|
||||
Ok(DeclarationVariable {
|
||||
id: self.fold_name(v.id)?,
|
||||
_type: self.fold_declaration_type(v._type)?,
|
||||
|
@ -246,8 +266,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_type(
|
||||
&mut self,
|
||||
t: DeclarationType<'ast>,
|
||||
) -> Result<DeclarationType<'ast>, Self::Error> {
|
||||
t: DeclarationType<'ast, T>,
|
||||
) -> Result<DeclarationType<'ast, T>, Self::Error> {
|
||||
use self::GType::*;
|
||||
|
||||
match t {
|
||||
|
@ -259,8 +279,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_array_type(
|
||||
&mut self,
|
||||
t: DeclarationArrayType<'ast>,
|
||||
) -> Result<DeclarationArrayType<'ast>, Self::Error> {
|
||||
t: DeclarationArrayType<'ast, T>,
|
||||
) -> Result<DeclarationArrayType<'ast, T>, Self::Error> {
|
||||
Ok(DeclarationArrayType {
|
||||
ty: box self.fold_declaration_type(*t.ty)?,
|
||||
size: self.fold_declaration_constant(t.size)?,
|
||||
|
@ -269,8 +289,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
|
||||
fn fold_declaration_struct_type(
|
||||
&mut self,
|
||||
t: DeclarationStructType<'ast>,
|
||||
) -> Result<DeclarationStructType<'ast>, Self::Error> {
|
||||
t: DeclarationStructType<'ast, T>,
|
||||
) -> Result<DeclarationStructType<'ast, T>, Self::Error> {
|
||||
Ok(DeclarationStructType {
|
||||
generics: t
|
||||
.generics
|
||||
|
@ -977,8 +997,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
key: DeclarationFunctionKey<'ast>,
|
||||
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
|
||||
key: DeclarationFunctionKey<'ast, T>,
|
||||
) -> Result<DeclarationFunctionKey<'ast, T>, F::Error> {
|
||||
Ok(DeclarationFunctionKey {
|
||||
module: f.fold_module_id(key.module)?,
|
||||
signature: f.fold_signature(key.signature)?,
|
||||
|
@ -1008,12 +1028,16 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
pub fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
s: DeclarationSignature<'ast>,
|
||||
) -> Result<DeclarationSignature<'ast>, F::Error> {
|
||||
s: DeclarationSignature<'ast, T>,
|
||||
) -> Result<DeclarationSignature<'ast, T>, F::Error> {
|
||||
Ok(DeclarationSignature {
|
||||
generics: s.generics,
|
||||
generics: s
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|g| g.map(|g| f.fold_declaration_constant(g)).transpose())
|
||||
.collect::<Result<_, _>>()?,
|
||||
inputs: s
|
||||
.inputs
|
||||
.into_iter()
|
||||
|
@ -1027,11 +1051,16 @@ fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
_: &mut F,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> Result<DeclarationConstant<'ast>, F::Error> {
|
||||
Ok(c)
|
||||
pub fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
c: DeclarationConstant<'ast, T>,
|
||||
) -> Result<DeclarationConstant<'ast, T>, F::Error> {
|
||||
match c {
|
||||
DeclarationConstant::Expression(e) => {
|
||||
Ok(DeclarationConstant::Expression(f.fold_expression(e)?))
|
||||
}
|
||||
c => Ok(c),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
|
@ -1115,6 +1144,7 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
) -> Result<TypedConstant<'ast, T>, F::Error> {
|
||||
Ok(TypedConstant {
|
||||
expression: f.fold_expression(c.expression)?,
|
||||
ty: f.fold_declaration_type(c.ty)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1143,20 +1173,49 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fold_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
d: TypedSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedSymbolDeclaration<'ast, T>, F::Error> {
|
||||
Ok(match d {
|
||||
TypedSymbolDeclaration::Function(d) => {
|
||||
TypedSymbolDeclaration::Function(f.fold_function_symbol_declaration(d)?)
|
||||
}
|
||||
TypedSymbolDeclaration::Constant(d) => {
|
||||
TypedSymbolDeclaration::Constant(f.fold_constant_symbol_declaration(d)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_function_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
d: TypedFunctionSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedFunctionSymbolDeclaration<'ast, T>, F::Error> {
|
||||
Ok(TypedFunctionSymbolDeclaration {
|
||||
key: f.fold_declaration_function_key(d.key)?,
|
||||
symbol: f.fold_function_symbol(d.symbol)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_constant_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
d: TypedConstantSymbolDeclaration<'ast, T>,
|
||||
) -> Result<TypedConstantSymbolDeclaration<'ast, T>, F::Error> {
|
||||
Ok(TypedConstantSymbolDeclaration {
|
||||
id: f.fold_canonical_constant_identifier(d.id)?,
|
||||
symbol: f.fold_constant_symbol(d.symbol)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
m: TypedModule<'ast, T>,
|
||||
) -> Result<TypedModule<'ast, T>, F::Error> {
|
||||
Ok(TypedModule {
|
||||
constants: m
|
||||
.constants
|
||||
symbols: m
|
||||
.symbols
|
||||
.into_iter()
|
||||
.map(|(key, tc)| f.fold_constant_symbol(tc).map(|tc| (key, tc)))
|
||||
.collect::<Result<_, _>>()?,
|
||||
functions: m
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f)))
|
||||
.map(|s| f.fold_symbol_declaration(s))
|
||||
.collect::<Result<_, _>>()?,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner};
|
||||
use crate::typed_absy::{
|
||||
CoreIdentifier, Identifier, OwnedTypedModuleId, TypedExpression, UExpression, UExpressionInner,
|
||||
};
|
||||
use crate::typed_absy::{TryFrom, TryInto};
|
||||
use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::collections::BTreeMap;
|
||||
|
@ -46,7 +48,7 @@ impl<'ast, T> IntoTypes<'ast, T> for Types<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct Types<'ast, T> {
|
||||
pub inner: Vec<Type<'ast, T>>,
|
||||
}
|
||||
|
@ -107,69 +109,77 @@ pub type ConstantIdentifier<'ast> = &'ast str;
|
|||
pub struct CanonicalConstantIdentifier<'ast> {
|
||||
pub module: OwnedTypedModuleId,
|
||||
pub id: ConstantIdentifier<'ast>,
|
||||
pub ty: Box<DeclarationType<'ast>>,
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for CanonicalConstantIdentifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}/{}", self.module.display(), self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> CanonicalConstantIdentifier<'ast> {
|
||||
pub fn new(
|
||||
id: ConstantIdentifier<'ast>,
|
||||
module: OwnedTypedModuleId,
|
||||
ty: DeclarationType<'ast>,
|
||||
) -> Self {
|
||||
CanonicalConstantIdentifier {
|
||||
module,
|
||||
id,
|
||||
ty: box ty,
|
||||
}
|
||||
pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self {
|
||||
CanonicalConstantIdentifier { module, id }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum DeclarationConstant<'ast> {
|
||||
pub enum DeclarationConstant<'ast, T> {
|
||||
Generic(GenericIdentifier<'ast>),
|
||||
Concrete(u32),
|
||||
Constant(CanonicalConstantIdentifier<'ast>),
|
||||
Expression(TypedExpression<'ast, T>),
|
||||
}
|
||||
|
||||
impl<'ast, T> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast> {
|
||||
impl<'ast, T: PartialEq> PartialEq<UExpression<'ast, T>> for DeclarationConstant<'ast, T> {
|
||||
fn eq(&self, other: &UExpression<'ast, T>) -> bool {
|
||||
match (self, other.as_inner()) {
|
||||
(DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32,
|
||||
match (self, other) {
|
||||
(
|
||||
DeclarationConstant::Concrete(c),
|
||||
UExpression {
|
||||
bitwidth: UBitwidth::B32,
|
||||
inner: UExpressionInner::Value(v),
|
||||
..
|
||||
},
|
||||
) => *c == *v as u32,
|
||||
(DeclarationConstant::Expression(TypedExpression::Uint(e0)), e1) => e0 == e1,
|
||||
(DeclarationConstant::Expression(..), _) => false, // type error
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> PartialEq<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationConstant<'ast>) -> bool {
|
||||
impl<'ast, T: PartialEq> PartialEq<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
|
||||
fn eq(&self, other: &DeclarationConstant<'ast, T>) -> bool {
|
||||
other.eq(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<u32> for DeclarationConstant<'ast> {
|
||||
impl<'ast, T> From<u32> for DeclarationConstant<'ast, T> {
|
||||
fn from(e: u32) -> Self {
|
||||
DeclarationConstant::Concrete(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<usize> for DeclarationConstant<'ast> {
|
||||
impl<'ast, T> From<usize> for DeclarationConstant<'ast, T> {
|
||||
fn from(e: usize) -> Self {
|
||||
DeclarationConstant::Concrete(e as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<GenericIdentifier<'ast>> for DeclarationConstant<'ast> {
|
||||
impl<'ast, T> From<GenericIdentifier<'ast>> for DeclarationConstant<'ast, T> {
|
||||
fn from(e: GenericIdentifier<'ast>) -> Self {
|
||||
DeclarationConstant::Generic(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for DeclarationConstant<'ast> {
|
||||
impl<'ast, T: fmt::Display> fmt::Display for DeclarationConstant<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
DeclarationConstant::Generic(i) => write!(f, "{}", i),
|
||||
DeclarationConstant::Concrete(v) => write!(f, "{}", v),
|
||||
DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id),
|
||||
DeclarationConstant::Expression(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -180,8 +190,8 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
||||
fn from(c: DeclarationConstant<'ast>) -> Self {
|
||||
impl<'ast, T> From<DeclarationConstant<'ast, T>> for UExpression<'ast, T> {
|
||||
fn from(c: DeclarationConstant<'ast, T>) -> Self {
|
||||
match c {
|
||||
DeclarationConstant::Generic(i) => {
|
||||
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
|
||||
|
@ -190,8 +200,10 @@ impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
|||
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Constant(v) => {
|
||||
UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32)
|
||||
UExpressionInner::Identifier(CoreIdentifier::from(v).into())
|
||||
.annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Expression(e) => e.try_into().unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -209,7 +221,7 @@ impl<'ast, T> TryInto<usize> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
|
||||
impl<'ast, T> TryInto<usize> for DeclarationConstant<'ast, T> {
|
||||
type Error = SpecializationError;
|
||||
|
||||
fn try_into(self) -> Result<usize, Self::Error> {
|
||||
|
@ -230,7 +242,7 @@ pub struct GStructMember<S> {
|
|||
pub ty: Box<GType<S>>,
|
||||
}
|
||||
|
||||
pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>>;
|
||||
pub type DeclarationStructMember<'ast, T> = GStructMember<DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteStructMember = GStructMember<usize>;
|
||||
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -270,7 +282,7 @@ pub struct GArrayType<S> {
|
|||
pub ty: Box<GType<S>>,
|
||||
}
|
||||
|
||||
pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
|
||||
pub type DeclarationArrayType<'ast, T> = GArrayType<DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteArrayType = GArrayType<usize>;
|
||||
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -336,7 +348,7 @@ pub struct StructLocation {
|
|||
pub name: String,
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteArrayType> for DeclarationArrayType<'ast> {
|
||||
impl<'ast, T> From<ConcreteArrayType> for DeclarationArrayType<'ast, T> {
|
||||
fn from(t: ConcreteArrayType) -> Self {
|
||||
try_from_g_array_type(t).unwrap()
|
||||
}
|
||||
|
@ -352,7 +364,7 @@ pub struct GStructType<S> {
|
|||
pub members: Vec<GStructMember<S>>,
|
||||
}
|
||||
|
||||
pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
|
||||
pub type DeclarationStructType<'ast, T> = GStructType<DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteStructType = GStructType<usize>;
|
||||
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -416,7 +428,7 @@ impl<'ast, T> From<ConcreteStructType> for StructType<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteStructType> for DeclarationStructType<'ast> {
|
||||
impl<'ast, T> From<ConcreteStructType> for DeclarationStructType<'ast, T> {
|
||||
fn from(t: ConcreteStructType) -> Self {
|
||||
try_from_g_struct_type(t).unwrap()
|
||||
}
|
||||
|
@ -609,7 +621,7 @@ impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType<S> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
|
||||
pub type DeclarationType<'ast, T> = GType<DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteType = GType<usize>;
|
||||
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -652,7 +664,7 @@ impl<'ast, T> From<ConcreteType> for Type<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteType> for DeclarationType<'ast> {
|
||||
impl<'ast, T> From<ConcreteType> for DeclarationType<'ast, T> {
|
||||
fn from(t: ConcreteType) -> Self {
|
||||
try_from_g_type(t).unwrap()
|
||||
}
|
||||
|
@ -738,7 +750,7 @@ impl<S> GType<S> {
|
|||
}
|
||||
|
||||
impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
|
||||
pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool {
|
||||
pub fn can_be_specialized_to(&self, other: &DeclarationType<'ast, T>) -> bool {
|
||||
use self::GType::*;
|
||||
|
||||
if other == self {
|
||||
|
@ -811,14 +823,14 @@ impl ConcreteType {
|
|||
|
||||
pub type FunctionIdentifier<'ast> = &'ast str;
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
|
||||
#[derive(PartialEq, Eq, Hash, Debug, Clone, PartialOrd, Ord)]
|
||||
pub struct GFunctionKey<'ast, S> {
|
||||
pub module: OwnedTypedModuleId,
|
||||
pub id: FunctionIdentifier<'ast>,
|
||||
pub signature: GSignature<S>,
|
||||
}
|
||||
|
||||
pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, DeclarationConstant<'ast>>;
|
||||
pub type DeclarationFunctionKey<'ast, T> = GFunctionKey<'ast, DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>;
|
||||
pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>;
|
||||
|
||||
|
@ -828,7 +840,7 @@ impl<'ast, S: fmt::Display> fmt::Display for GFunctionKey<'ast, S> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
|
||||
pub struct GGenericsAssignment<'ast, S>(pub BTreeMap<GenericIdentifier<'ast>, S>);
|
||||
|
||||
pub type ConcreteGenericsAssignment<'ast> = GGenericsAssignment<'ast, usize>;
|
||||
|
@ -854,8 +866,8 @@ impl<'ast, S: fmt::Display> fmt::Display for GGenericsAssignment<'ast, S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> PartialEq<DeclarationFunctionKey<'ast>> for ConcreteFunctionKey<'ast> {
|
||||
fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool {
|
||||
impl<'ast, T> PartialEq<DeclarationFunctionKey<'ast, T>> for ConcreteFunctionKey<'ast> {
|
||||
fn eq(&self, other: &DeclarationFunctionKey<'ast, T>) -> bool {
|
||||
self.module == other.module && self.id == other.id && self.signature == other.signature
|
||||
}
|
||||
}
|
||||
|
@ -884,7 +896,7 @@ impl<'ast, T> From<ConcreteFunctionKey<'ast>> for FunctionKey<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast> {
|
||||
impl<'ast, T> From<ConcreteFunctionKey<'ast>> for DeclarationFunctionKey<'ast, T> {
|
||||
fn from(k: ConcreteFunctionKey<'ast>) -> Self {
|
||||
try_from_g_function_key(k).unwrap()
|
||||
}
|
||||
|
@ -931,8 +943,8 @@ impl<'ast> ConcreteFunctionKey<'ast> {
|
|||
|
||||
use std::collections::btree_map::Entry;
|
||||
|
||||
pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
||||
decl_ty: &DeclarationType<'ast>,
|
||||
pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq<usize>>(
|
||||
decl_ty: &DeclarationType<'ast, T>,
|
||||
ty: >ype<S>,
|
||||
constants: &mut GGenericsAssignment<'ast, S>,
|
||||
) -> bool {
|
||||
|
@ -953,9 +965,9 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
|||
}
|
||||
},
|
||||
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
|
||||
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
|
||||
// in the other cases, we do not know the value yet, so we optimistically assume it's correct
|
||||
// if it does not match, it will be caught during inlining
|
||||
DeclarationConstant::Constant(..) => true,
|
||||
DeclarationConstant::Constant(..) | DeclarationConstant::Expression(..) => true,
|
||||
}
|
||||
}
|
||||
(DeclarationType::FieldElement, GType::FieldElement)
|
||||
|
@ -963,6 +975,11 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
|||
(DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1,
|
||||
(DeclarationType::Struct(s0), GType::Struct(s1)) => {
|
||||
s0.canonical_location == s1.canonical_location
|
||||
&& s0
|
||||
.members
|
||||
.iter()
|
||||
.zip(s1.members.iter())
|
||||
.all(|(m0, m1)| check_type(&*m0.ty, &*m1.ty, constants))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
|
@ -970,16 +987,12 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq<usize>>(
|
|||
|
||||
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for UExpression<'ast, T> {
|
||||
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
|
||||
let bitwidth = match *c.ty {
|
||||
DeclarationType::Uint(bitwidth) => bitwidth,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
UExpressionInner::Identifier(Identifier::from(c.id)).annotate(bitwidth)
|
||||
UExpressionInner::Identifier(Identifier::from(CoreIdentifier::Constant(c)))
|
||||
.annotate(UBitwidth::B32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast> {
|
||||
impl<'ast, T> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast, T> {
|
||||
fn from(c: CanonicalConstantIdentifier<'ast>) -> Self {
|
||||
DeclarationConstant::Constant(c)
|
||||
}
|
||||
|
@ -987,21 +1000,21 @@ impl<'ast> From<CanonicalConstantIdentifier<'ast>> for DeclarationConstant<'ast>
|
|||
|
||||
pub fn specialize_declaration_type<
|
||||
'ast,
|
||||
T,
|
||||
S: Clone + PartialEq + From<u32> + fmt::Debug + From<CanonicalConstantIdentifier<'ast>>,
|
||||
>(
|
||||
decl_ty: DeclarationType<'ast>,
|
||||
decl_ty: DeclarationType<'ast, T>,
|
||||
generics: &GGenericsAssignment<'ast, S>,
|
||||
) -> Result<GType<S>, GenericIdentifier<'ast>> {
|
||||
Ok(match decl_ty {
|
||||
DeclarationType::Int => unreachable!(),
|
||||
DeclarationType::Array(t0) => {
|
||||
// let s1 = t1.size.clone();
|
||||
|
||||
let ty = box specialize_declaration_type(*t0.ty, &generics)?;
|
||||
let size = match t0.size {
|
||||
DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s),
|
||||
DeclarationConstant::Concrete(s) => Ok(s.into()),
|
||||
DeclarationConstant::Constant(c) => Ok(c.into()),
|
||||
DeclarationConstant::Expression(..) => unreachable!("the semantic checker should not yield this DeclarationConstant variant")
|
||||
}?;
|
||||
|
||||
GType::Array(GArrayType { size, ty })
|
||||
|
@ -1028,11 +1041,8 @@ pub fn specialize_declaration_type<
|
|||
generics.0.get(&s).cloned().ok_or(s).map(Some)
|
||||
}
|
||||
DeclarationConstant::Concrete(s) => Ok(Some(s.into())),
|
||||
DeclarationConstant::Constant(..) => {
|
||||
unreachable!(
|
||||
"identifiers should have been removed in constant inlining"
|
||||
)
|
||||
}
|
||||
DeclarationConstant::Constant(c) => Ok(Some(c.into())),
|
||||
DeclarationConstant::Expression(..) => unreachable!("the semantic checker should not yield this DeclarationConstant variant"),
|
||||
},
|
||||
_ => Ok(None),
|
||||
})
|
||||
|
@ -1096,12 +1106,12 @@ pub mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
pub type DeclarationSignature<'ast> = GSignature<DeclarationConstant<'ast>>;
|
||||
pub type DeclarationSignature<'ast, T> = GSignature<DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteSignature = GSignature<usize>;
|
||||
pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>;
|
||||
|
||||
impl<'ast> PartialEq<DeclarationSignature<'ast>> for ConcreteSignature {
|
||||
fn eq(&self, other: &DeclarationSignature<'ast>) -> bool {
|
||||
impl<'ast, T> PartialEq<DeclarationSignature<'ast, T>> for ConcreteSignature {
|
||||
fn eq(&self, other: &DeclarationSignature<'ast, T>) -> bool {
|
||||
// we keep track of the value of constants in a map, as a given constant can only have one value
|
||||
let mut constants = ConcreteGenericsAssignment::default();
|
||||
|
||||
|
@ -1110,11 +1120,11 @@ pub mod signature {
|
|||
.iter()
|
||||
.chain(other.outputs.iter())
|
||||
.zip(self.inputs.iter().chain(self.outputs.iter()))
|
||||
.all(|(decl_ty, ty)| check_type::<usize>(decl_ty, ty, &mut constants))
|
||||
.all(|(decl_ty, ty)| check_type::<T, usize>(decl_ty, ty, &mut constants))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> DeclarationSignature<'ast> {
|
||||
impl<'ast, T: Clone + PartialEq + fmt::Debug> DeclarationSignature<'ast, T> {
|
||||
pub fn specialize(
|
||||
&self,
|
||||
values: Vec<Option<u32>>,
|
||||
|
@ -1155,7 +1165,7 @@ pub mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get_output_types<T: Clone + PartialEq + fmt::Debug>(
|
||||
pub fn get_output_types(
|
||||
&self,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
inputs: Vec<Type<'ast, T>>,
|
||||
|
@ -1234,7 +1244,7 @@ pub mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<ConcreteSignature> for DeclarationSignature<'ast> {
|
||||
impl<'ast, T> From<ConcreteSignature> for DeclarationSignature<'ast, T> {
|
||||
fn from(s: ConcreteSignature) -> Self {
|
||||
try_from_g_signature(s).unwrap()
|
||||
}
|
||||
|
@ -1349,6 +1359,7 @@ pub mod signature {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn signature() {
|
||||
|
@ -1365,7 +1376,7 @@ pub mod signature {
|
|||
// <P>(field[P])
|
||||
// <Q>(field[Q])
|
||||
|
||||
let generic1 = DeclarationSignature::new()
|
||||
let generic1 = DeclarationSignature::<Bn128Field>::new()
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier {
|
||||
name: "P",
|
||||
|
|
|
@ -133,13 +133,13 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct UMetadata {
|
||||
pub bitwidth: Option<Bitwidth>,
|
||||
pub should_reduce: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct UExpression<'ast, T> {
|
||||
pub bitwidth: UBitwidth,
|
||||
pub metadata: Option<UMetadata>,
|
||||
|
@ -173,7 +173,7 @@ impl<'ast, T> PartialEq<usize> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
|
||||
pub enum UExpressionInner<'ast, T> {
|
||||
Block(BlockExpression<'ast, T, UExpression<'ast, T>>),
|
||||
Identifier(Identifier<'ast>),
|
||||
|
|
|
@ -5,13 +5,13 @@ use crate::typed_absy::UExpression;
|
|||
use crate::typed_absy::{TryFrom, TryInto};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct GVariable<'ast, S> {
|
||||
pub id: Identifier<'ast>,
|
||||
pub _type: GType<S>,
|
||||
}
|
||||
|
||||
pub type DeclarationVariable<'ast> = GVariable<'ast, DeclarationConstant<'ast>>;
|
||||
pub type DeclarationVariable<'ast, T> = GVariable<'ast, DeclarationConstant<'ast, T>>;
|
||||
pub type ConcreteVariable<'ast> = GVariable<'ast, usize>;
|
||||
pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>;
|
||||
|
||||
|
|
Loading…
Reference in a new issue