1
0
Fork 0
mirror of synced 2025-09-23 20:28:36 +00:00

support constants in declaration types

This commit is contained in:
dark64 2021-05-13 15:09:06 +02:00
parent 72312537dd
commit 5c528535f2
13 changed files with 310 additions and 150 deletions

View file

@ -0,0 +1,4 @@
const field SIZE = 2
def main(field[SIZE] n):
return

View file

@ -0,0 +1,7 @@
const u32 N = 42
def foo<N>(field[N] a) -> bool:
return true
def main():
return

View file

@ -353,11 +353,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
id: &'ast str,
c: ConstantDefinitionNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedConstant<'ast, T>, ErrorInner> {
let pos = c.pos();
let ty = self.check_type(c.value.ty.clone(), module_id, &types)?;
let checked_expr = self.check_expression(c.value.expression.clone(), module_id, types)?;
let ty = self.check_type(c.value.ty.clone(), module_id, state)?;
let checked_expr = self.check_expression(c.value.expression.clone(), module_id, state)?;
match ty {
Type::FieldElement => {
@ -397,7 +397,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
id: String,
s: StructDefinitionNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<DeclarationType<'ast>, Vec<ErrorInner>> {
let pos = s.pos();
let s = s.value;
@ -409,7 +409,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
for field in s.fields {
let member_id = field.value.id.to_string();
match self
.check_declaration_type(field.value.ty, module_id, &types, &HashMap::new())
.check_declaration_type(field.value.ty, module_id, state, &HashMap::new())
.map(|t| (member_id, t))
{
Ok(f) => match fields_set.insert(f.0.clone()) {
@ -460,7 +460,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
declaration_id.to_string(),
t.clone(),
module_id,
&state.types,
state,
) {
Ok(ty) => {
match symbol_unifier.insert_type(declaration_id) {
@ -492,7 +492,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Symbol::Here(SymbolDefinition::Constant(c)) => {
match self.check_constant_definition(declaration_id, c, module_id, &state.types) {
match self.check_constant_definition(declaration_id, c, module_id, state) {
Ok(c) => {
match symbol_unifier.insert_constant(declaration_id) {
false => errors.push(
@ -527,7 +527,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Symbol::Here(SymbolDefinition::Function(f)) => {
match self.check_function(f, module_id, &state.types) {
match self.check_function(f, module_id, state) {
Ok(funct) => {
match symbol_unifier
.insert_function(declaration_id, funct.signature.clone())
@ -831,7 +831,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
funct_node: FunctionNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedFunction<'ast, T>, Vec<ErrorInner>> {
assert!(self.return_types.is_none());
@ -849,7 +849,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let mut statements_checked = vec![];
match self.check_signature(funct.signature, module_id, types) {
match self.check_signature(funct.signature, module_id, state) {
Ok(s) => {
// define variables for the constants
for generic in &s.generics {
@ -905,7 +905,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
found_return = true;
}
match self.check_statement(stat, module_id, types) {
match self.check_statement(stat, module_id, state) {
Ok(statement) => {
if let TypedStatement::Return(e) = &statement {
match e.iter().map(|e| e.get_type()).collect::<Vec<_>>()
@ -971,7 +971,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
signature: UnresolvedSignature<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<DeclarationSignature<'ast>, Vec<ErrorInner>> {
let mut errors = vec![];
let mut inputs = vec![];
@ -981,6 +981,20 @@ impl<'ast, T: Field> Checker<'ast, T> {
let mut generics_map = HashMap::new();
for (index, g) in signature.generics.iter().enumerate() {
if let Some((key, _)) = state
.constants
.get(module_id)
.unwrap()
.get_key_value(g.value)
{
errors.push(ErrorInner {
pos: Some(g.pos()),
message: format!(
"Generic parameter {} conflicts with constant symbol {}",
g.value, key
),
});
} else {
match generics_map.insert(g.value, index).is_none() {
true => {
generics.push(Some(Constant::Generic(GenericIdentifier {
@ -996,9 +1010,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
}
}
for t in signature.inputs {
match self.check_declaration_type(t, module_id, types, &generics_map) {
match self.check_declaration_type(t, module_id, state, &generics_map) {
Ok(t) => {
inputs.push(t);
}
@ -1009,7 +1024,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
for t in signature.outputs {
match self.check_declaration_type(t, module_id, types, &generics_map) {
match self.check_declaration_type(t, module_id, state, &generics_map) {
Ok(t) => {
outputs.push(t);
}
@ -1036,7 +1051,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
ty: UnresolvedTypeNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<Type<'ast, T>, ErrorInner> {
let pos = ty.pos();
let ty = ty.value;
@ -1046,7 +1061,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
UnresolvedType::Boolean => Ok(Type::Boolean),
UnresolvedType::Uint(bitwidth) => Ok(Type::uint(bitwidth)),
UnresolvedType::Array(t, size) => {
let size = self.check_expression(size, module_id, types)?;
let size = self.check_expression(size, module_id, state)?;
let ty = size.get_type();
@ -1079,11 +1094,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
}?;
Ok(Type::Array(ArrayType::new(
self.check_type(*t, module_id, types)?,
self.check_type(*t, module_id, state)?,
size,
)))
}
UnresolvedType::User(id) => types
UnresolvedType::User(id) => state
.types
.get(module_id)
.unwrap()
.get(&id)
@ -1099,6 +1115,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
fn check_generic_expression(
&mut self,
expr: ExpressionNode<'ast>,
constants_map: &HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>,
generics_map: &HashMap<Identifier<'ast>, usize>,
) -> Result<Constant<'ast>, ErrorInner> {
let pos = expr.pos();
@ -1121,10 +1138,21 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Identifier(name) => {
// check that this generic parameter is defined
match generics_map.get(&name) {
Some(index) => Ok(Constant::Generic(GenericIdentifier {name, index: *index})),
None => Err(ErrorInner {
match (constants_map.get(name), generics_map.get(&name)) {
(Some(c), None) => {
match c {
Type::Uint(bitwidth) => Ok(Constant::Identifier(name, bitwidth.to_usize())),
_ => Err(ErrorInner {
pos: Some(pos),
message: format!(
"Expected array dimension to be a u32 constant or an identifier, found {} of type {}",
name, c
),
})
}
}
(None, Some(index)) => Ok(Constant::Generic(GenericIdentifier { name, index: *index })),
_ => Err(ErrorInner {
pos: Some(pos),
message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", name)
})
@ -1144,7 +1172,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
ty: UnresolvedTypeNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
generics_map: &HashMap<Identifier<'ast>, usize>,
) -> Result<DeclarationType<'ast>, ErrorInner> {
let pos = ty.pos();
@ -1155,15 +1183,19 @@ impl<'ast, T: Field> Checker<'ast, T> {
UnresolvedType::Boolean => Ok(DeclarationType::Boolean),
UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)),
UnresolvedType::Array(t, size) => {
let checked_size = self.check_generic_expression(size.clone(), &generics_map)?;
let checked_size = self.check_generic_expression(
size.clone(),
state.constants.get(module_id).unwrap(),
generics_map,
)?;
Ok(DeclarationType::Array(DeclarationArrayType::new(
self.check_declaration_type(*t, module_id, types, generics_map)?,
self.check_declaration_type(*t, module_id, state, generics_map)?,
checked_size,
)))
}
UnresolvedType::User(id) => {
types
UnresolvedType::User(id) => state
.types
.get(module_id)
.unwrap()
.get(&id)
@ -1171,8 +1203,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.ok_or_else(|| ErrorInner {
pos: Some(pos),
message: format!("Undefined type {}", id),
})
}
}),
}
}
@ -1180,11 +1211,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
v: crate::absy::VariableNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<Variable<'ast, T>, Vec<ErrorInner>> {
Ok(Variable::with_id_and_type(
v.value.id,
self.check_type(v.value._type, module_id, types)
self.check_type(v.value._type, module_id, state)
.map_err(|e| vec![e])?,
))
}
@ -1196,17 +1227,17 @@ impl<'ast, T: Field> Checker<'ast, T> {
statements: Vec<StatementNode<'ast>>,
pos: (Position, Position),
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedStatement<'ast, T>, Vec<ErrorInner>> {
self.check_for_var(&var).map_err(|e| vec![e])?;
let var = self.check_variable(var, module_id, types).unwrap();
let var = self.check_variable(var, module_id, state).unwrap();
let from = self
.check_expression(range.0, module_id, &types)
.check_expression(range.0, module_id, state)
.map_err(|e| vec![e])?;
let to = self
.check_expression(range.1, module_id, &types)
.check_expression(range.1, module_id, state)
.map_err(|e| vec![e])?;
let from = match from {
@ -1274,7 +1305,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let mut checked_statements = vec![];
for stat in statements {
let checked_stat = self.check_statement(stat, module_id, types)?;
let checked_stat = self.check_statement(stat, module_id, state)?;
checked_statements.push(checked_stat);
}
@ -1285,7 +1316,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
stat: StatementNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedStatement<'ast, T>, Vec<ErrorInner>> {
let pos = stat.pos();
@ -1299,7 +1330,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
for e in e.value.expressions.into_iter() {
let e_checked = self
.check_expression(e, module_id, &types)
.check_expression(e, module_id, state)
.map_err(|e| vec![e])?;
expression_list_checked.push(e_checked);
}
@ -1367,7 +1398,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
Ok(res)
}
Statement::Declaration(var) => {
let var = self.check_variable(var, module_id, types)?;
let var = self.check_variable(var, module_id, state)?;
match self.insert_into_scope(var.clone()) {
true => Ok(TypedStatement::Declaration(var)),
false => Err(ErrorInner {
@ -1386,12 +1417,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
// check the expression to be assigned
let checked_expr = self
.check_expression(expr, module_id, &types)
.check_expression(expr, module_id, state)
.map_err(|e| vec![e])?;
// check that the assignee is declared and is well formed
let var = self
.check_assignee(assignee, module_id, &types)
.check_assignee(assignee, module_id, state)
.map_err(|e| vec![e])?;
let var_type = var.get_type();
@ -1430,7 +1461,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
Statement::Assertion(e) => {
let e = self
.check_expression(e, module_id, &types)
.check_expression(e, module_id, state)
.map_err(|e| vec![e])?;
match e {
@ -1449,7 +1480,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
Statement::For(var, from, to, statements) => {
self.enter_scope();
let res = self.check_for_loop(var, (from, to), statements, pos, module_id, types);
let res = self.check_for_loop(var, (from, to), statements, pos, module_id, state);
self.exit_scope();
@ -1465,7 +1496,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
generics.into_iter().map(|g|
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, &types).and_then(|g| {
self.check_expression(g, module_id, state).and_then(|g| {
UExpression::try_from_typed(g, UBitwidth::B32).map_err(
|e| ErrorInner {
pos: Some(pos),
@ -1484,7 +1515,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
).transpose().map_err(|e| vec![e])?;
// check lhs assignees are defined
let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, types)).partition(|r| r.is_ok());
let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, state)).partition(|r| r.is_ok());
if !errors.is_empty() {
return Err(errors.into_iter().map(|e| e.unwrap_err()).collect());
@ -1497,7 +1528,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
// find argument types
let mut arguments_checked = vec![];
for arg in arguments {
let arg_checked = self.check_expression(arg, module_id, &types).map_err(|e| vec![e])?;
let arg_checked = self.check_expression(arg, module_id, state).map_err(|e| vec![e])?;
arguments_checked.push(arg_checked);
}
@ -1545,7 +1576,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
assignee: AssigneeNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedAssignee<'ast, T>, ErrorInner> {
let pos = assignee.pos();
// check that the assignee is declared
@ -1567,14 +1598,14 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
},
Assignee::Select(box assignee, box index) => {
let checked_assignee = self.check_assignee(assignee, module_id, &types)?;
let checked_assignee = self.check_assignee(assignee, module_id, state)?;
let ty = checked_assignee.get_type();
match ty {
Type::Array(..) => {
let checked_index = match index {
RangeOrExpression::Expression(e) => {
self.check_expression(e, module_id, &types)?
self.check_expression(e, module_id, state)?
}
r => unimplemented!(
"Using slices in assignments is not supported yet, found {}",
@ -1609,7 +1640,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Assignee::Member(box assignee, box member) => {
let checked_assignee = self.check_assignee(assignee, module_id, &types)?;
let checked_assignee = self.check_assignee(assignee, module_id, state)?;
let ty = checked_assignee.get_type();
match &ty {
@ -1646,14 +1677,14 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
spread_or_expression: SpreadOrExpression<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedExpressionOrSpread<'ast, T>, ErrorInner> {
match spread_or_expression {
SpreadOrExpression::Spread(s) => {
let pos = s.pos();
let checked_expression =
self.check_expression(s.value.expression, module_id, &types)?;
self.check_expression(s.value.expression, module_id, state)?;
match checked_expression {
TypedExpression::Array(a) => Ok(TypedExpressionOrSpread::Spread(a.into())),
@ -1666,9 +1697,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
}),
}
}
SpreadOrExpression::Expression(e) => self
.check_expression(e, module_id, &types)
.map(|r| r.into()),
SpreadOrExpression::Expression(e) => {
self.check_expression(e, module_id, state).map(|r| r.into())
}
}
}
@ -1676,7 +1707,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
&mut self,
expr: ExpressionNode<'ast>,
module_id: &ModuleId,
types: &TypeMap<'ast>,
state: &State<'ast, T>,
) -> Result<TypedExpression<'ast, T>, ErrorInner> {
let pos = expr.pos();
@ -1711,8 +1742,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Add(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
use self::TypedExpression::*;
@ -1746,8 +1777,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Sub(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
use self::TypedExpression::*;
@ -1777,8 +1808,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Mult(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
use self::TypedExpression::*;
@ -1812,8 +1843,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Div(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
use self::TypedExpression::*;
@ -1847,8 +1878,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Rem(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -1876,8 +1907,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Pow(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let e1_checked = match FieldElementExpression::try_from_typed(e1_checked) {
Ok(e) => e.into(),
@ -1904,7 +1935,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Neg(box e) => {
let e = self.check_expression(e, module_id, &types)?;
let e = self.check_expression(e, module_id, state)?;
match e {
TypedExpression::Int(e) => Ok(IntExpression::Neg(box e).into()),
@ -1923,7 +1954,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Pos(box e) => {
let e = self.check_expression(e, module_id, &types)?;
let e = self.check_expression(e, module_id, state)?;
match e {
TypedExpression::Int(e) => Ok(IntExpression::Pos(box e).into()),
@ -1942,9 +1973,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::IfElse(box condition, box consequence, box alternative) => {
let condition_checked = self.check_expression(condition, module_id, &types)?;
let consequence_checked = self.check_expression(consequence, module_id, &types)?;
let alternative_checked = self.check_expression(alternative, module_id, &types)?;
let condition_checked = self.check_expression(condition, module_id, state)?;
let consequence_checked = self.check_expression(consequence, module_id, state)?;
let alternative_checked = self.check_expression(alternative, module_id, state)?;
let (consequence_checked, alternative_checked) =
TypedExpression::align_without_integers(
@ -2020,7 +2051,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
.map(|g| {
g.map(|g| {
let pos = g.pos();
self.check_expression(g, module_id, &types).and_then(|g| {
self.check_expression(g, module_id, state).and_then(|g| {
UExpression::try_from_typed(g, UBitwidth::B32).map_err(
|e| ErrorInner {
pos: Some(pos),
@ -2042,7 +2073,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
// check the arguments
let mut arguments_checked = vec![];
for arg in arguments {
let arg_checked = self.check_expression(arg, module_id, &types)?;
let arg_checked = self.check_expression(arg, module_id, state)?;
arguments_checked.push(arg_checked);
}
@ -2168,8 +2199,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Lt(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2218,8 +2249,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Le(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2268,8 +2299,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Eq(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2318,8 +2349,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Ge(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2368,8 +2399,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Gt(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2418,7 +2449,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Select(box array, box index) => {
let array = self.check_expression(array, module_id, &types)?;
let array = self.check_expression(array, module_id, state)?;
match index {
RangeOrExpression::Range(r) => {
@ -2432,13 +2463,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
let from = r
.value
.from
.map(|e| self.check_expression(e, module_id, &types))
.map(|e| self.check_expression(e, module_id, state))
.unwrap_or_else(|| Ok(UExpression::from(0u32).into()))?;
let to = r
.value
.to
.map(|e| self.check_expression(e, module_id, &types))
.map(|e| self.check_expression(e, module_id, state))
.unwrap_or_else(|| Ok(array_size.clone().into()))?;
let from = UExpression::try_from_typed(from, UBitwidth::B32).map_err(|e| ErrorInner {
@ -2478,7 +2509,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
RangeOrExpression::Expression(index) => {
let index = self.check_expression(index, module_id, &types)?;
let index = self.check_expression(index, module_id, state)?;
let index =
UExpression::try_from_typed(index, UBitwidth::B32).map_err(|e| {
@ -2519,7 +2550,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Member(box e, box id) => {
let e = self.check_expression(e, module_id, &types)?;
let e = self.check_expression(e, module_id, state)?;
match e {
TypedExpression::Struct(s) => {
@ -2575,7 +2606,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
// check each expression, getting its type
let mut expressions_or_spreads_checked = vec![];
for e in expressions_or_spreads {
let e_checked = self.check_spread_or_expression(e, module_id, &types)?;
let e_checked = self.check_spread_or_expression(e, module_id, state)?;
expressions_or_spreads_checked.push(e_checked);
}
@ -2642,10 +2673,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
)
}
Expression::ArrayInitializer(box e, box count) => {
let e = self.check_expression(e, module_id, &types)?;
let e = self.check_expression(e, module_id, state)?;
let ty = e.get_type();
let count = self.check_expression(count, module_id, &types)?;
let count = self.check_expression(count, module_id, state)?;
let count =
UExpression::try_from_typed(count, UBitwidth::B32).map_err(|e| ErrorInner {
@ -2665,7 +2696,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let ty = self.check_type(
UnresolvedType::User(id.clone()).at(42, 42, 42),
module_id,
&types,
state,
)?;
let struct_type = match ty {
Type::Struct(struct_type) => struct_type,
@ -2705,7 +2736,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
match inline_members_map.remove(member.id.as_str()) {
Some(value) => {
let expression_checked =
self.check_expression(value, module_id, &types)?;
self.check_expression(value, module_id, state)?;
let expression_checked = TypedExpression::align_to_type(
expression_checked,
@ -2750,8 +2781,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
.into())
}
Expression::And(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2784,8 +2815,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Or(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
match (e1_checked, e2_checked) {
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
Ok(BooleanExpression::Or(box e1, box e2).into())
@ -2801,8 +2832,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::LeftShift(box e1, box e2) => {
let e1 = self.check_expression(e1, module_id, &types)?;
let e2 = self.check_expression(e2, module_id, &types)?;
let e1 = self.check_expression(e1, module_id, state)?;
let e2 = self.check_expression(e2, module_id, state)?;
let e2 =
UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner {
@ -2828,8 +2859,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::RightShift(box e1, box e2) => {
let e1 = self.check_expression(e1, module_id, &types)?;
let e2 = self.check_expression(e2, module_id, &types)?;
let e1 = self.check_expression(e1, module_id, state)?;
let e2 = self.check_expression(e2, module_id, state)?;
let e2 =
UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner {
@ -2857,8 +2888,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::BitOr(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2889,8 +2920,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::BitAnd(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2921,8 +2952,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::BitXor(box e1, box e2) => {
let e1_checked = self.check_expression(e1, module_id, &types)?;
let e2_checked = self.check_expression(e2, module_id, &types)?;
let e1_checked = self.check_expression(e1, module_id, state)?;
let e2_checked = self.check_expression(e2, module_id, state)?;
let (e1_checked, e2_checked) = TypedExpression::align_without_integers(
e1_checked, e2_checked,
@ -2953,7 +2984,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
}
}
Expression::Not(box e) => {
let e_checked = self.check_expression(e, module_id, &types)?;
let e_checked = self.check_expression(e, module_id, state)?;
match e_checked {
TypedExpression::Int(e) => Ok(IntExpression::Not(box e).into()),
TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()),

View file

@ -1,20 +1,26 @@
use crate::static_analysis::propagation::Propagator;
use crate::typed_absy::folder::*;
use crate::typed_absy::result_folder::ResultFolder;
use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember};
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
pub struct ConstantInliner<'ast, T: Field> {
pub struct ConstantInliner<'ast, 'a, T: Field> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
}
impl<'ast, T: Field> ConstantInliner<'ast, T> {
pub fn new(modules: TypedModules<'ast, T>, location: OwnedTypedModuleId) -> Self {
ConstantInliner { modules, location }
}
impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone());
let mut constants = HashMap::new();
let mut inliner = ConstantInliner {
modules: p.modules.clone(),
location: p.main.clone(),
propagator: Propagator::with_constants(&mut constants),
};
inliner.fold_program(p)
}
@ -51,12 +57,18 @@ impl<'ast, T: Field> ConstantInliner<'ast, T> {
let _ = self.change_location(location);
symbol
}
TypedConstantSymbol::Here(tc) => self.fold_constant(tc),
TypedConstantSymbol::Here(tc) => {
let tc: TypedConstant<T> = self.fold_constant(tc);
TypedConstant {
expression: self.propagator.fold_expression(tc.expression).unwrap(),
..tc
}
}
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
@ -71,6 +83,50 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
}
}
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
DeclarationSignature {
generics: s.generics,
inputs: s
.inputs
.into_iter()
.map(|ty| self.fold_declaration_type(ty))
.collect(),
outputs: s
.outputs
.into_iter()
.map(|ty| self.fold_declaration_type(ty))
.collect(),
}
}
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
match t {
DeclarationType::Array(ref array_ty) => match array_ty.size {
Constant::Identifier(name, _) => {
let tc = self.get_constant(&name.into()).unwrap();
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
match expression.inner {
UExpressionInner::Value(v) => DeclarationType::array((
*array_ty.ty.clone(),
Constant::Concrete(v as u32),
)),
_ => unreachable!("expected u32 value"),
}
}
_ => t,
},
DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType {
members: struct_ty
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty)))
.collect(),
..struct_ty
}),
_ => t,
}
}
fn fold_constant_symbol(
&mut self,
s: TypedConstantSymbol<'ast, T>,

View file

@ -35,6 +35,10 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_function(self, f)
}
fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> {
s
}
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
DeclarationParameter {
id: self.fold_declaration_variable(p.id),
@ -668,7 +672,7 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>(
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
..fun
signature: f.fold_signature(fun.signature),
}
}

View file

@ -290,8 +290,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
ty: Type<'ast, T>,
expression: TypedExpression<'ast, T>,
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>,
}
impl<'ast, T> TypedConstant<'ast, T> {

View file

@ -49,6 +49,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_function(self, f)
}
fn fold_signature(
&mut self,
s: DeclarationSignature<'ast>,
) -> Result<DeclarationSignature<'ast>, Self::Error> {
Ok(s)
}
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast>,
@ -741,7 +748,7 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>(
.into_iter()
.flatten()
.collect(),
..fun
signature: f.fold_signature(fun.signature)?,
})
}

View file

@ -1,4 +1,4 @@
use crate::typed_absy::{OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner};
use crate::typed_absy::{TryFrom, TryInto};
use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
@ -54,6 +54,7 @@ pub struct SpecializationError;
pub enum Constant<'ast> {
Generic(GenericIdentifier<'ast>),
Concrete(u32),
Identifier(&'ast str, usize),
}
impl<'ast> From<u32> for Constant<'ast> {
@ -79,6 +80,7 @@ impl<'ast> fmt::Display for Constant<'ast> {
match self {
Constant::Generic(i) => write!(f, "{}", i),
Constant::Concrete(v) => write!(f, "{}", v),
Constant::Identifier(v, _) => write!(f, "{}", v),
}
}
}
@ -96,6 +98,9 @@ impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
}
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
Constant::Identifier(v, size) => {
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size))
}
}
}
}
@ -920,6 +925,7 @@ pub mod signature {
}
},
Constant::Concrete(s0) => s1 == *s0 as usize,
Constant::Identifier(_, s0) => s1 == *s0,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
@ -945,6 +951,7 @@ pub mod signature {
let size = match t0.size {
Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
Constant::Concrete(s) => Ok(s.into()),
Constant::Identifier(_, s) => Ok((s as u32).into()),
}?;
GType::Array(GArrayType { size, ty })

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/array_size.zok",
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": ["42", "42"]
},
"output": {
"Ok": {
"values": ["42", "42"]
}
}
}
]
}

View file

@ -0,0 +1,5 @@
const u32 SIZE = 2
def main(field[SIZE] a) -> field[SIZE]:
field[SIZE] b = a
return b

View file

@ -0,0 +1,16 @@
{
"entry_point": "./tests/tests/constants/propagate.zok",
"max_constraint_count": 4,
"tests": [
{
"input": {
"values": []
},
"output": {
"Ok": {
"values": ["42", "42", "42", "42"]
}
}
}
]
}

View file

@ -0,0 +1,5 @@
const u32 TWO = 2
const u32 FOUR = TWO * TWO
def main() -> field[FOUR]:
return [42; FOUR]

View file

@ -1,9 +1,11 @@
struct Foo {
field a
const u32 A_SIZE = 2
struct State {
field[A_SIZE] a
field b
}
const Foo FOO = Foo { a: 2, b: 2 }
const State STATE = State { a: [1, 1], b: 2 }
def main() -> field:
return FOO.a + FOO.b
return STATE.a[0] + STATE.a[1] + STATE.b