1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

reduce propagation memory usage

This commit is contained in:
schaeff 2023-04-20 14:40:34 +02:00
parent e46ad845d0
commit b397e8a993
5 changed files with 117 additions and 58 deletions

View file

@ -0,0 +1 @@
Reduce memory usage of compilation

View file

@ -32,6 +32,7 @@ pub enum Error {
AssertionFailed(RuntimeError),
InvalidValue(String),
OutOfBounds(u128, u128),
VariableLength(String),
}
impl fmt::Display for Error {
@ -45,6 +46,7 @@ impl fmt::Display for Error {
"Out of bounds index ({} >= {}) found during static analysis",
index, size
),
Error::VariableLength(message) => write!(f, "{}", message),
}
}
}
@ -61,6 +63,10 @@ impl<'ast, T: Field> Propagator<'ast, T> {
Propagator::default().fold_program(p)
}
pub fn clear_call_frame(&mut self, frame: usize) {
self.constants.retain(|id, _| id.id.frame != frame);
}
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
// return the identifier at the root of this assignee
fn try_get_constant_mut<'b>(
@ -1153,59 +1159,93 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> Result<SelectOrExpression<'ast, T, E>, Self::Error> {
let array = self.fold_array_expression(*e.array)?;
let index = self.fold_uint_expression(*e.index)?;
let array = *e.array;
let inner_type = array.inner_type().clone();
let size = array.size();
let ty = self.fold_array_type(*array.ty)?;
let size = match ty.size.as_inner() {
UExpressionInner::Value(v) => Ok(v),
_ => unreachable!("array size was checked when folding array type"),
}?;
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(SelectOrExpression::Expression(
v.expression_at::<E>(n.value as usize).unwrap().into_inner(),
))
} else {
Err(Error::OutOfBounds(n.value, size.value))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id.id) {
Some(a) => match a {
match (array.inner, index.into_inner()) {
// special case if the array is an identifier: check the cache and only clone the element, not the whole array
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id.id) {
Some(v) => {
// get the constant array. it was guaranteed to be a value when it was inserted
let v = match v {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(SelectOrExpression::Expression(
v.expression_at::<E>(n.value as usize)
.unwrap()
.into_inner(),
))
}
_ => unreachable!("should be an array value"),
ArrayExpressionInner::Value(v) => v,
_ => unreachable!(),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(SelectOrExpression::Expression(
E::select(
ArrayExpressionInner::Identifier(id)
.annotate(ArrayType::new(inner_type, size.value as u32)),
UExpressionInner::Value(n).annotate(UBitwidth::B32),
)
.into_inner(),
)),
_ => unreachable!(),
};
// sanity check that the value does not contain spreads
assert!(v
.value
.iter()
.all(|e| matches!(e, TypedExpressionOrSpread::Expression(_))));
if n.value < size.value {
Ok(SelectOrExpression::Expression(
// clone only the element
match v.value[n.value as usize].clone() {
TypedExpressionOrSpread::Expression(e) => {
E::try_from(e).unwrap().into_inner()
}
_ => unreachable!(),
},
))
} else {
Err(Error::OutOfBounds(n.value, size.value))
}
}
_ => Ok(SelectOrExpression::Select(SelectExpression::new(
ArrayExpressionInner::Identifier(id).annotate(ty),
UExpressionInner::Value(n).annotate(UBitwidth::B32),
))),
}
(a, i) => Ok(SelectOrExpression::Select(SelectExpression::new(
a.annotate(ArrayType::new(inner_type, size.value as u32)),
i.annotate(UBitwidth::B32),
))),
},
_ => Ok(SelectOrExpression::Select(SelectExpression::new(
array, index,
))),
}
(array, index) => {
let array = self.fold_array_expression_inner(&ty, array)?;
match (array, index) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n.value < size.value {
Ok(SelectOrExpression::Expression(
v.expression_at::<E>(n.value as usize).into_inner(),
))
} else {
Err(Error::OutOfBounds(n.value, size.value))
}
}
(a, i) => Ok(SelectOrExpression::Select(SelectExpression::new(
a.annotate(ty),
i.annotate(UBitwidth::B32),
))),
}
}
}
}
fn fold_array_type(
&mut self,
t: ArrayType<'ast, T>,
) -> Result<ArrayType<'ast, T>, Self::Error> {
let size = self.fold_uint_expression(*t.size)?;
if !size.is_constant() {
return Err(Error::VariableLength(format!(
"Array length should be fixed, found {}",
size
)));
}
Ok(ArrayType::new(self.fold_type(*t.ty)?, size))
}
fn fold_array_expression_cases(
&mut self,
ty: &ArrayType<'ast, T>,

View file

@ -106,6 +106,15 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
program,
}
}
fn push_call_frame(&mut self) {
self.ssa.push_call_frame();
}
fn pop_call_frame(&mut self) {
self.propagator.clear_call_frame(self.ssa.latest_frame);
self.ssa.pop_call_frame();
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
@ -159,7 +168,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
})
.collect::<Result<_, _>>()?;
self.ssa.push_call_frame();
self.push_call_frame();
let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program);
@ -236,7 +245,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
};
self.ssa.pop_call_frame();
self.pop_call_frame();
res
}

View file

@ -1731,28 +1731,31 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>>
impl<'ast, T: Field> ArrayValueExpression<'ast, T> {
fn expression_at_aux<
'a,
U: Select<'ast, T> + From<TypedExpression<'ast, T>> + Into<TypedExpression<'ast, T>>,
>(
v: TypedExpressionOrSpread<'ast, T>,
) -> Vec<Option<U>> {
) -> Vec<TypedExpression<'ast, T>> {
match v {
TypedExpressionOrSpread::Expression(e) => vec![Some(e.into())],
TypedExpressionOrSpread::Expression(e) => vec![e],
TypedExpressionOrSpread::Spread(s) => match s.array.size().into_inner() {
UExpressionInner::Value(size) => {
let array_ty = s.array.ty().clone();
match s.array.into_inner() {
ArrayExpressionInner::Value(v) => {
v.into_iter().flat_map(Self::expression_at_aux).collect()
}
ArrayExpressionInner::Value(v) => v
.value
.into_iter()
.flat_map(Self::expression_at_aux::<U>)
.collect(),
a => (0..size.value)
.map(|i| {
Some(U::select(a.clone().annotate(array_ty.clone()), i as u32))
U::select(a.clone().annotate(array_ty.clone()), i as u32).into()
})
.collect(),
}
}
_ => vec![None],
_ => unreachable!(),
},
}
}
@ -1760,14 +1763,16 @@ impl<'ast, T: Field> ArrayValueExpression<'ast, T> {
pub fn expression_at<
U: Select<'ast, T> + From<TypedExpression<'ast, T>> + Into<TypedExpression<'ast, T>>,
>(
&self,
self,
index: usize,
) -> Option<U> {
self.iter()
.flat_map(|v| Self::expression_at_aux(v.clone()))
.take_while(|e| e.is_some())
.map(|e| e.unwrap())
) -> U {
self.into_iter()
.flat_map(|v| Self::expression_at_aux::<U>(v))
.nth(index)
.unwrap()
.clone()
.try_into()
.unwrap()
}
}

View file

@ -0,0 +1,4 @@
def main(u32 N) {
field[N] a = [0; N];
return;
}