diff --git a/fathom/src/core.rs b/fathom/src/core.rs index 64b3ddefc..54c584ca8 100644 --- a/fathom/src/core.rs +++ b/fathom/src/core.rs @@ -2,6 +2,8 @@ use std::fmt; +use scoped_arena::Scope; + use crate::env::{Index, Level}; use crate::source::Span; use crate::symbol::Symbol; @@ -23,9 +25,9 @@ pub enum Item<'arena> { /// The label that identifies this definition label: Symbol, /// The type of the defined expression - r#type: &'arena Term<'arena>, + r#type: Term<'arena>, /// The defined expression - expr: &'arena Term<'arena>, + expr: Term<'arena>, }, } @@ -144,13 +146,7 @@ pub enum Term<'arena> { /// Annotated expressions. Ann(Span, &'arena Term<'arena>, &'arena Term<'arena>), /// Let expressions. - Let( - Span, - Option, - &'arena Term<'arena>, - &'arena Term<'arena>, - &'arena Term<'arena>, - ), + Let(Span, &'arena LetDef<'arena>, &'arena Term<'arena>), /// The type of types. Universe(Span), @@ -205,6 +201,13 @@ pub enum Term<'arena> { ), } +#[derive(Debug, Clone)] +pub struct LetDef<'arena> { + pub name: Option, + pub r#type: Term<'arena>, + pub expr: Term<'arena>, +} + impl<'arena> Term<'arena> { /// Get the source span of the term. pub fn span(&self) -> Span { @@ -214,7 +217,7 @@ impl<'arena> Term<'arena> { | Term::MetaVar(span, _) | Term::InsertedMeta(span, _, _) | Term::Ann(span, _, _) - | Term::Let(span, _, _, _, _) + | Term::Let(span, ..) | Term::Universe(span) | Term::FunType(span, ..) | Term::FunLit(span, ..) @@ -233,7 +236,7 @@ impl<'arena> Term<'arena> { } /// Returns `true` if the term contains an occurrence of the local variable. - pub fn binds_local(&self, mut var: Index) -> bool { + pub fn binds_local(&self, var: Index) -> bool { match self { Term::LocalVar(_, v) => *v == var, Term::ItemVar(_, _) @@ -244,9 +247,9 @@ impl<'arena> Term<'arena> { | Term::ConstLit(_, _) => false, Term::Ann(_, expr, r#type) => expr.binds_local(var) || r#type.binds_local(var), - Term::Let(_, _, def_type, def_expr, body_expr) => { - def_type.binds_local(var) - || def_expr.binds_local(var) + Term::Let(_, def, body_expr) => { + def.r#type.binds_local(var) + || def.r#expr.binds_local(var) || body_expr.binds_local(var.prev()) } Term::FunType(.., param_type, body_type) => { @@ -259,11 +262,8 @@ impl<'arena> Term<'arena> { Term::RecordType(_, _, terms) | Term::RecordLit(_, _, terms) | Term::FormatRecord(_, _, terms) - | Term::FormatOverlap(_, _, terms) => terms.iter().any(|term| { - let result = term.binds_local(var); - var = var.prev(); - result - }), + | Term::FormatOverlap(_, _, terms) => Iterator::zip(var.iter_from(), terms.iter()) + .any(|(var, term)| term.binds_local(var)), Term::RecordProj(_, head_expr, _) => head_expr.binds_local(var), Term::ArrayLit(_, elem_exprs) => elem_exprs.iter().any(|term| term.binds_local(var)), Term::FormatCond(_, _, format, pred) => { @@ -280,6 +280,10 @@ impl<'arena> Term<'arena> { pub fn is_error(&self) -> bool { matches!(self, Term::Prim(_, Prim::ReportedError)) } + + pub fn error(span: impl Into) -> Term<'arena> { + Term::Prim(span.into(), Prim::ReportedError) + } } macro_rules! def_prims { @@ -672,6 +676,21 @@ impl Ord for Const { } } +impl Const { + /// Return the number of inhabitants of `self`. + /// `None` represents infinity + pub fn num_inhabitants(&self) -> Option { + match self { + Const::Bool(_) => Some(2), + Const::U8(_, _) | Const::S8(_) => Some(1 << 8), + Const::U16(_, _) | Const::S16(_) => Some(1 << 16), + Const::U32(_, _) | Const::S32(_) => Some(1 << 32), + Const::U64(_, _) | Const::S64(_) => Some(1 << 64), + Const::F32(_) | Const::F64(_) | Const::Pos(_) | Const::Ref(_) => None, + } + } +} + pub trait ToBeBytes { fn to_be_bytes(self) -> [u8; N]; } @@ -731,6 +750,332 @@ impl UIntStyle { } } +pub struct Builder<'arena> { + scope: &'arena Scope<'arena>, +} + +// Proxy type to allow many different types to be passed to +// `TermBuilder::fun_apps` for convenience +pub struct FunAppArg<'arena> { + span: Span, + plicity: Plicity, + term: Term<'arena>, +} + +impl<'arena> From<(Span, Plicity, Term<'arena>)> for FunAppArg<'arena> { + fn from((span, plicity, term): (Span, Plicity, Term<'arena>)) -> Self { + Self { + span, + plicity, + term, + } + } +} + +impl<'arena> From<(Plicity, Term<'arena>)> for FunAppArg<'arena> { + fn from((plicity, term): (Plicity, Term<'arena>)) -> Self { + Self { + span: Span::Empty, + plicity, + term, + } + } +} + +impl<'arena> From> for FunAppArg<'arena> { + fn from(term: Term<'arena>) -> Self { + Self { + span: Span::Empty, + plicity: Plicity::Explicit, + term, + } + } +} + +// Proxy type to allow many different types to be passed to +// `TermBuilder::fun_types` for convenience +pub struct FunTypeParam<'arena> { + span: Span, + plicity: Plicity, + name: Option, + r#type: Term<'arena>, +} + +impl<'arena> From<(Span, Plicity, Option, Term<'arena>)> for FunTypeParam<'arena> { + fn from((span, plicity, name, term): (Span, Plicity, Option, Term<'arena>)) -> Self { + Self { + span, + plicity, + name, + r#type: term, + } + } +} + +impl<'arena> From<(Plicity, Option, Term<'arena>)> for FunTypeParam<'arena> { + fn from((plicity, name, term): (Plicity, Option, Term<'arena>)) -> Self { + Self { + span: Span::Empty, + plicity, + name, + r#type: term, + } + } +} + +impl<'arena> From<(Option, Term<'arena>)> for FunTypeParam<'arena> { + fn from((name, term): (Option, Term<'arena>)) -> Self { + Self { + span: Span::Empty, + plicity: Plicity::Explicit, + name, + r#type: term, + } + } +} + +impl<'arena> From<(Plicity, Term<'arena>)> for FunTypeParam<'arena> { + fn from((plicity, term): (Plicity, Term<'arena>)) -> Self { + Self { + span: Span::Empty, + plicity, + name: None, + r#type: term, + } + } +} + +impl<'arena> From> for FunTypeParam<'arena> { + fn from(term: Term<'arena>) -> Self { + Self { + span: Span::Empty, + plicity: Plicity::Explicit, + name: None, + r#type: term, + } + } +} + +impl<'arena> Builder<'arena> { + pub fn new(scope: &'arena Scope<'arena>) -> Self { + Self { scope } + } + + pub fn ann( + &self, + span: impl Into, + expr: Term<'arena>, + r#type: Term<'arena>, + ) -> Term<'arena> { + Term::Ann( + span.into(), + self.scope.to_scope(expr), + self.scope.to_scope(r#type), + ) + } + + pub fn r#let( + &self, + span: impl Into, + def: LetDef<'arena>, + body: Term<'arena>, + ) -> Term<'arena> { + Term::Let( + span.into(), + self.scope.to_scope(def), + self.scope.to_scope(body), + ) + } + + pub fn fun_type( + &self, + span: impl Into, + plicity: Plicity, + name: impl Into>, + input: Term<'arena>, + output: Term<'arena>, + ) -> Term<'arena> { + Term::FunType( + span.into(), + plicity, + name.into(), + self.scope.to_scope(input), + self.scope.to_scope(output), + ) + } + + pub fn fun_types(&self, params: I, output: Term<'arena>) -> Term<'arena> + where + I: IntoIterator, + I::IntoIter: DoubleEndedIterator, + T: Into>, + { + params.into_iter().rev().fold(output, |output, param| { + let FunTypeParam { + span, + plicity, + name, + r#type: term, + } = param.into(); + self.fun_type(span, plicity, name, term, output) + }) + } + + pub fn arrow( + &self, + span: impl Into, + plicity: Plicity, + input: Term<'arena>, + output: Term<'arena>, + ) -> Term<'arena> { + self.fun_type(span, plicity, None, input, output) + } + + pub fn arrows(&self, params: I, output: Term<'arena>) -> Term<'arena> + where + I: IntoIterator)>, + I::IntoIter: DoubleEndedIterator, + { + params + .into_iter() + .fold(output, |output, (span, plicity, input)| { + self.arrow(span, plicity, input, output) + }) + } + + pub fn fun_lit( + &self, + span: impl Into, + plicity: Plicity, + name: impl Into>, + body: Term<'arena>, + ) -> Term<'arena> { + Term::FunLit(span.into(), plicity, name.into(), self.scope.to_scope(body)) + } + + pub fn fun_app( + &self, + span: impl Into, + plicity: Plicity, + fun: Term<'arena>, + arg: Term<'arena>, + ) -> Term<'arena> { + Term::FunApp( + span.into(), + plicity, + self.scope.to_scope(fun), + self.scope.to_scope(arg), + ) + } + + pub fn fun_apps( + &self, + fun: Term<'arena>, + args: impl IntoIterator>>, + ) -> Term<'arena> { + args.into_iter().fold(fun, |fun, arg| { + let FunAppArg { + span, + plicity, + term, + } = arg.into(); + self.fun_app(span, plicity, fun, term) + }) + } + + pub fn tuple_type(&self, span: impl Into, types: &'arena [Term<'arena>]) -> Term<'arena> { + let labels = Symbol::get_tuple_labels(0..types.len()); + let labels = self.scope.to_scope_from_iter(labels.iter().copied()); + Term::RecordType(span.into(), labels, types) + } + + pub fn record_proj( + &self, + span: impl Into, + head: Term<'arena>, + label: Symbol, + ) -> Term<'arena> { + Term::RecordProj(span.into(), self.scope.to_scope(head), label) + } + + pub fn record_projs( + &self, + head: Term<'arena>, + labels: impl IntoIterator, Symbol)>, + ) -> Term<'arena> { + labels.into_iter().fold(head, |head, (span, label)| { + self.record_proj(span, head, label) + }) + } + + pub fn format_cond( + &self, + span: impl Into, + name: Symbol, + format: Term<'arena>, + pred: Term<'arena>, + ) -> Term<'arena> { + Term::FormatCond( + span.into(), + name, + self.scope.to_scope(format), + self.scope.to_scope(pred), + ) + } + + pub fn const_match( + &self, + span: impl Into, + scrut: Term<'arena>, + branches: &'arena [(Const, Term<'arena>)], + default: impl Into, Term<'arena>)>>, + ) -> Term<'arena> { + Term::ConstMatch( + span.into(), + self.scope.to_scope(scrut), + branches, + default + .into() + .map(|(name, expr)| (name, self.scope.to_scope(expr) as &_)), + ) + } + + pub fn if_then_else( + &self, + span: impl Into, + cond: Term<'arena>, + then: Term<'arena>, + r#else: Term<'arena>, + ) -> Term<'arena> { + self.const_match( + span, + cond, + self.scope + .to_scope_from_iter([(Const::Bool(false), r#else), (Const::Bool(true), then)]), + None, + ) + } + + pub fn binop( + &self, + span: impl Into, + op_span: impl Into, + op: Prim, + lhs: Term<'arena>, + rhs: Term<'arena>, + ) -> Term<'arena> { + let args_span = Span::merge(&lhs.span(), &rhs.span()); + + self.fun_apps( + Term::Prim(op_span.into(), op), + [ + (span.into(), Plicity::Explicit, lhs), + (args_span, Plicity::Explicit, rhs), + ], + ) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/fathom/src/core/binary.rs b/fathom/src/core/binary.rs index 661f014e6..c64072692 100644 --- a/fathom/src/core/binary.rs +++ b/fathom/src/core/binary.rs @@ -425,15 +425,11 @@ impl<'arena, 'data> Context<'arena, 'data> { (Prim::FormatF32Le, []) => read_const(reader, span, read_f32le, Const::F32), (Prim::FormatF64Be, []) => read_const(reader, span, read_f64be, Const::F64), (Prim::FormatF64Le, []) => read_const(reader, span, read_f64le, Const::F64), - (Prim::FormatRepeatLen8, [FunApp(_, len), FunApp(_, format)]) => self.read_repeat_len(reader, span, len, format), - (Prim::FormatRepeatLen16, [FunApp(_, len), FunApp(_, format)]) => self.read_repeat_len(reader, span, len, format), - (Prim::FormatRepeatLen32, [FunApp(_, len), FunApp(_, format)]) => self.read_repeat_len(reader, span, len, format), - (Prim::FormatRepeatLen64, [FunApp(_, len), FunApp(_, format)]) => self.read_repeat_len(reader, span, len, format), + (Prim::FormatRepeatLen8 | Prim::FormatRepeatLen16 | Prim::FormatRepeatLen32 | Prim::FormatRepeatLen64, + [FunApp(_, len), FunApp(_, format)]) => self.read_repeat_len(reader, span, len, format), (Prim::FormatRepeatUntilEnd, [FunApp(_,format)]) => self.read_repeat_until_end(reader, format), - (Prim::FormatLimit8, [FunApp(_, limit), FunApp(_, format)]) => self.read_limit(reader, limit, format), - (Prim::FormatLimit16, [FunApp(_, limit), FunApp(_, format)]) => self.read_limit(reader, limit, format), - (Prim::FormatLimit32, [FunApp(_, limit), FunApp(_, format)]) => self.read_limit(reader, limit, format), - (Prim::FormatLimit64, [FunApp(_, limit), FunApp(_, format)]) => self.read_limit(reader, limit, format), + (Prim::FormatLimit8 | Prim::FormatLimit16 | Prim::FormatLimit32 | Prim::FormatLimit64, + [FunApp(_, limit), FunApp(_, format)]) => self.read_limit(reader, limit, format), (Prim::FormatLink, [FunApp(_, pos), FunApp(_, format)]) => self.read_link(span, pos, format), (Prim::FormatDeref, [FunApp(_, format), FunApp(_, r#ref)]) => self.read_deref(format, r#ref), (Prim::FormatStreamPos, []) => read_stream_pos(reader, span), @@ -642,9 +638,7 @@ fn read_s8(reader: &mut BufferReader<'_>) -> Result { /// Generates a function that reads a multi-byte primitive. macro_rules! read_multibyte_prim { ($read_multibyte_prim:ident, $from_bytes:ident, $T:ident) => { - fn $read_multibyte_prim<'data>( - reader: &mut BufferReader<'data>, - ) -> Result<$T, BufferError> { + fn $read_multibyte_prim(reader: &mut BufferReader) -> Result<$T, BufferError> { Ok($T::$from_bytes(*reader.read_byte_array()?)) } }; diff --git a/fathom/src/core/pretty.rs b/fathom/src/core/pretty.rs index 22035551b..e3ce4e86a 100644 --- a/fathom/src/core/pretty.rs +++ b/fathom/src/core/pretty.rs @@ -143,17 +143,17 @@ impl<'arena> Context { self.term_prec(Prec::Top, r#type), ]), ), - Term::Let(_, def_pattern, def_type, def_expr, body_expr) => self.paren( + Term::Let(_, def, body_expr) => self.paren( prec > Prec::Let, RcDoc::concat([ RcDoc::concat([ RcDoc::text("let"), RcDoc::space(), - self.ann_pattern(Prec::Top, *def_pattern, def_type), + self.ann_pattern(Prec::Top, def.name, &def.r#type), RcDoc::space(), RcDoc::text("="), RcDoc::softline(), - self.term_prec(Prec::Let, def_expr), + self.term_prec(Prec::Let, &def.expr), RcDoc::text(";"), ]) .group(), diff --git a/fathom/src/core/prim.rs b/fathom/src/core/prim.rs index e4e99808f..5f79e50aa 100644 --- a/fathom/src/core/prim.rs +++ b/fathom/src/core/prim.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use fxhash::FxHashMap; use scoped_arena::Scope; +use super::Builder; use crate::core::semantics::{ArcValue, Elim, ElimEnv, Head, Value}; use crate::core::{self, Const, Plicity, Prim, UIntStyle}; use crate::env::{self, SharedEnv, UniqueEnv}; @@ -51,89 +52,92 @@ impl<'arena> Env<'arena> { const ARRAY32_TYPE: Term<'_> = Term::Prim(Span::Empty, Array32Type); const ARRAY64_TYPE: Term<'_> = Term::Prim(Span::Empty, Array64Type); const POS_TYPE: Term<'_> = Term::Prim(Span::Empty, PosType); + const REF_TYPE: Term<'_> = Term::Prim(Span::Empty, RefType); + const OPTION_TYPE: Term<'_> = Term::Prim(Span::Empty, OptionType); let mut env = EnvBuilder::new(scope); + let builder = env.builder(); + + // comments force rustfmt not to mess with grouping + for prim in [ + VoidType, BoolType, PosType, FormatType, // + U8Type, U16Type, U32Type, U64Type, // + S8Type, S16Type, S32Type, S64Type, // + F32Type, F64Type, + ] { + env.define_prim(prim, &UNIVERSE); + } + + for prim in [OptionType, ArrayType] { + env.define_prim_fun(prim, [&UNIVERSE], &UNIVERSE); + } + + for (prim, arg) in [ + (Array8Type, &U8_TYPE), + (Array16Type, &U16_TYPE), + (Array32Type, &U32_TYPE), + (Array64Type, &U64_TYPE), + ] { + env.define_prim_fun(prim, [arg, &UNIVERSE], &UNIVERSE); + } - env.define_prim(VoidType, &UNIVERSE); - env.define_prim(BoolType, &UNIVERSE); - env.define_prim(U8Type, &UNIVERSE); - env.define_prim(U16Type, &UNIVERSE); - env.define_prim(U32Type, &UNIVERSE); - env.define_prim(U64Type, &UNIVERSE); - env.define_prim(S8Type, &UNIVERSE); - env.define_prim(S16Type, &UNIVERSE); - env.define_prim(S32Type, &UNIVERSE); - env.define_prim(S64Type, &UNIVERSE); - env.define_prim(F32Type, &UNIVERSE); - env.define_prim(F64Type, &UNIVERSE); - env.define_prim_fun(OptionType, [&UNIVERSE], &UNIVERSE); - env.define_prim_fun(ArrayType, [&UNIVERSE], &UNIVERSE); - env.define_prim_fun(Array8Type, [&U8_TYPE, &UNIVERSE], &UNIVERSE); - env.define_prim_fun(Array16Type, [&U16_TYPE, &UNIVERSE], &UNIVERSE); - env.define_prim_fun(Array32Type, [&U32_TYPE, &UNIVERSE], &UNIVERSE); - env.define_prim_fun(Array64Type, [&U64_TYPE, &UNIVERSE], &UNIVERSE); - env.define_prim(PosType, &UNIVERSE); env.define_prim_fun(RefType, [&FORMAT_TYPE], &UNIVERSE); - env.define_prim(FormatType, &UNIVERSE); - - env.define_prim(FormatU8, &FORMAT_TYPE); - env.define_prim(FormatU16Be, &FORMAT_TYPE); - env.define_prim(FormatU16Le, &FORMAT_TYPE); - env.define_prim(FormatU32Be, &FORMAT_TYPE); - env.define_prim(FormatU32Le, &FORMAT_TYPE); - env.define_prim(FormatU64Be, &FORMAT_TYPE); - env.define_prim(FormatU64Le, &FORMAT_TYPE); - env.define_prim(FormatS8, &FORMAT_TYPE); - env.define_prim(FormatS16Be, &FORMAT_TYPE); - env.define_prim(FormatS16Le, &FORMAT_TYPE); - env.define_prim(FormatS32Be, &FORMAT_TYPE); - env.define_prim(FormatS32Le, &FORMAT_TYPE); - env.define_prim(FormatS64Be, &FORMAT_TYPE); - env.define_prim(FormatS64Le, &FORMAT_TYPE); - env.define_prim(FormatF32Be, &FORMAT_TYPE); - env.define_prim(FormatF32Le, &FORMAT_TYPE); - env.define_prim(FormatF64Be, &FORMAT_TYPE); - env.define_prim(FormatF64Le, &FORMAT_TYPE); - env.define_prim_fun(FormatRepeatLen8, [&U8_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatRepeatLen16, [&U16_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatRepeatLen32, [&U32_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatRepeatLen64, [&U64_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); + + // rustfmt messes with grouping regardless of comments for some reason :( + for prim in [ + FormatU8, + FormatS8, + FormatU16Be, + FormatU16Le, + FormatU32Be, + FormatU32Le, + FormatU64Be, + FormatU64Le, + FormatS16Be, + FormatS16Le, + FormatS32Be, + FormatS32Le, + FormatS64Be, + FormatS64Le, + FormatF32Be, + FormatF32Le, + FormatF64Be, + FormatF64Le, + ] { + env.define_prim(prim, &FORMAT_TYPE); + } + + for (prim1, prim2, arg) in [ + (FormatRepeatLen8, FormatLimit8, &U8_TYPE), + (FormatRepeatLen16, FormatLimit16, &U16_TYPE), + (FormatRepeatLen32, FormatLimit32, &U32_TYPE), + (FormatRepeatLen64, FormatLimit64, &U64_TYPE), + ] { + env.define_prim_fun(prim1, [arg, &FORMAT_TYPE], &FORMAT_TYPE); + env.define_prim_fun(prim2, [arg, &FORMAT_TYPE], &FORMAT_TYPE); + } + env.define_prim_fun(FormatRepeatUntilEnd, [&FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatLimit8, [&U8_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatLimit16, [&U16_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatLimit32, [&U32_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); - env.define_prim_fun(FormatLimit64, [&U64_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); env.define_prim_fun(FormatLink, [&POS_TYPE, &FORMAT_TYPE], &FORMAT_TYPE); env.define_prim( FormatDeref, - &core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("f"), - &FORMAT_TYPE, - &Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - &Term::FunApp( - Span::Empty, - Plicity::Explicit, - &Term::Prim(Span::Empty, RefType), - &VAR0, - ), - &FORMAT_TYPE, - ), + &builder.fun_types( + [ + (Plicity::Implicit, env.name("f"), FORMAT_TYPE), + (Plicity::Explicit, None, builder.fun_apps(REF_TYPE, [VAR0])), + ], + FORMAT_TYPE, ), ); env.define_prim(FormatStreamPos, &FORMAT_TYPE); env.define_prim( FormatSucceed, - &core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - &Term::FunType(Span::Empty, Plicity::Explicit, None, &VAR0, &FORMAT_TYPE), + &builder.fun_types( + [ + (Plicity::Implicit, env.name("A"), UNIVERSE), + (Plicity::Explicit, None, VAR0), + ], + FORMAT_TYPE, ), ); env.define_prim(FormatFail, &FORMAT_TYPE); @@ -141,23 +145,16 @@ impl<'arena> Env<'arena> { FormatUnwrap, // fun (@A : Type) -> Option A -> Format // fun (@A : Type) -> Option A@0 -> Format - &core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - &Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - &Term::FunApp( - Span::Empty, + &builder.fun_types( + [ + (Plicity::Implicit, env.name("A"), UNIVERSE), + ( Plicity::Explicit, - &Term::Prim(Span::Empty, OptionType), - &VAR0, + None, + builder.fun_apps(OPTION_TYPE, [VAR0]), ), - &FORMAT_TYPE, - ), + ], + FORMAT_TYPE, ), ); env.define_prim_fun(FormatRepr, [&FORMAT_TYPE], &UNIVERSE); @@ -165,167 +162,148 @@ impl<'arena> Env<'arena> { // fun (@A : Type) -> Void -> A env.define_prim( Absurd, - &core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - &core::Term::FunType(Span::Empty, Plicity::Explicit, None, &VOID_TYPE, &VAR1), + &builder.fun_types( + [ + (Plicity::Implicit, env.name("A"), UNIVERSE), + (Plicity::Explicit, None, VOID_TYPE), + ], + VAR1, ), ); - env.define_prim_fun(BoolEq, [&BOOL_TYPE, &BOOL_TYPE], &BOOL_TYPE); - env.define_prim_fun(BoolNeq, [&BOOL_TYPE, &BOOL_TYPE], &BOOL_TYPE); env.define_prim_fun(BoolNot, [&BOOL_TYPE], &BOOL_TYPE); - env.define_prim_fun(BoolAnd, [&BOOL_TYPE, &BOOL_TYPE], &BOOL_TYPE); - env.define_prim_fun(BoolOr, [&BOOL_TYPE, &BOOL_TYPE], &BOOL_TYPE); - env.define_prim_fun(BoolXor, [&BOOL_TYPE, &BOOL_TYPE], &BOOL_TYPE); - - env.define_prim_fun(U8Eq, [&U8_TYPE, &U8_TYPE], &BOOL_TYPE); - env.define_prim_fun(U8Neq, [&U8_TYPE, &U8_TYPE], &BOOL_TYPE); - env.define_prim_fun(U8Lt, [&U8_TYPE, &U8_TYPE], &BOOL_TYPE); - env.define_prim_fun(U8Gt, [&U8_TYPE, &U8_TYPE], &BOOL_TYPE); - env.define_prim_fun(U8Lte, [&U8_TYPE, &U8_TYPE], &BOOL_TYPE); - env.define_prim_fun(U8Gte, [&U8_TYPE, &U8_TYPE], &BOOL_TYPE); - env.define_prim_fun(U8Add, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Sub, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Mul, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Div, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Not, [&U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Shl, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Shr, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8And, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Or, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - env.define_prim_fun(U8Xor, [&U8_TYPE, &U8_TYPE], &U8_TYPE); - - env.define_prim_fun(U16Eq, [&U16_TYPE, &U16_TYPE], &BOOL_TYPE); - env.define_prim_fun(U16Neq, [&U16_TYPE, &U16_TYPE], &BOOL_TYPE); - env.define_prim_fun(U16Lt, [&U16_TYPE, &U16_TYPE], &BOOL_TYPE); - env.define_prim_fun(U16Gt, [&U16_TYPE, &U16_TYPE], &BOOL_TYPE); - env.define_prim_fun(U16Lte, [&U16_TYPE, &U16_TYPE], &BOOL_TYPE); - env.define_prim_fun(U16Gte, [&U16_TYPE, &U16_TYPE], &BOOL_TYPE); - env.define_prim_fun(U16Add, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Sub, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Mul, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Div, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Not, [&U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Shl, [&U16_TYPE, &U8_TYPE], &U16_TYPE); - env.define_prim_fun(U16Shr, [&U16_TYPE, &U8_TYPE], &U16_TYPE); - env.define_prim_fun(U16And, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Or, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - env.define_prim_fun(U16Xor, [&U16_TYPE, &U16_TYPE], &U16_TYPE); - - env.define_prim_fun(U32Eq, [&U32_TYPE, &U32_TYPE], &BOOL_TYPE); - env.define_prim_fun(U32Neq, [&U32_TYPE, &U32_TYPE], &BOOL_TYPE); - env.define_prim_fun(U32Lt, [&U32_TYPE, &U32_TYPE], &BOOL_TYPE); - env.define_prim_fun(U32Gt, [&U32_TYPE, &U32_TYPE], &BOOL_TYPE); - env.define_prim_fun(U32Lte, [&U32_TYPE, &U32_TYPE], &BOOL_TYPE); - env.define_prim_fun(U32Gte, [&U32_TYPE, &U32_TYPE], &BOOL_TYPE); - env.define_prim_fun(U32Add, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Sub, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Mul, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Div, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Not, [&U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Shl, [&U32_TYPE, &U8_TYPE], &U32_TYPE); - env.define_prim_fun(U32Shr, [&U32_TYPE, &U8_TYPE], &U32_TYPE); - env.define_prim_fun(U32And, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Or, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - env.define_prim_fun(U32Xor, [&U32_TYPE, &U32_TYPE], &U32_TYPE); - - env.define_prim_fun(U64Eq, [&U64_TYPE, &U64_TYPE], &BOOL_TYPE); - env.define_prim_fun(U64Neq, [&U64_TYPE, &U64_TYPE], &BOOL_TYPE); - env.define_prim_fun(U64Lt, [&U64_TYPE, &U64_TYPE], &BOOL_TYPE); - env.define_prim_fun(U64Gt, [&U64_TYPE, &U64_TYPE], &BOOL_TYPE); - env.define_prim_fun(U64Lte, [&U64_TYPE, &U64_TYPE], &BOOL_TYPE); - env.define_prim_fun(U64Gte, [&U64_TYPE, &U64_TYPE], &BOOL_TYPE); - env.define_prim_fun(U64Add, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Sub, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Mul, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Div, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Not, [&U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Shl, [&U64_TYPE, &U8_TYPE], &U64_TYPE); - env.define_prim_fun(U64Shr, [&U64_TYPE, &U8_TYPE], &U64_TYPE); - env.define_prim_fun(U64And, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Or, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - env.define_prim_fun(U64Xor, [&U64_TYPE, &U64_TYPE], &U64_TYPE); - - env.define_prim_fun(S8Eq, [&S8_TYPE, &S8_TYPE], &BOOL_TYPE); - env.define_prim_fun(S8Neq, [&S8_TYPE, &S8_TYPE], &BOOL_TYPE); - env.define_prim_fun(S8Lt, [&S8_TYPE, &S8_TYPE], &BOOL_TYPE); - env.define_prim_fun(S8Gt, [&S8_TYPE, &S8_TYPE], &BOOL_TYPE); - env.define_prim_fun(S8Lte, [&S8_TYPE, &S8_TYPE], &BOOL_TYPE); - env.define_prim_fun(S8Gte, [&S8_TYPE, &S8_TYPE], &BOOL_TYPE); - env.define_prim_fun(S8Neg, [&S8_TYPE], &S8_TYPE); - env.define_prim_fun(S8Add, [&S8_TYPE, &S8_TYPE], &S8_TYPE); - env.define_prim_fun(S8Sub, [&S8_TYPE, &S8_TYPE], &S8_TYPE); - env.define_prim_fun(S8Mul, [&S8_TYPE, &S8_TYPE], &S8_TYPE); - env.define_prim_fun(S8Div, [&S8_TYPE, &S8_TYPE], &S8_TYPE); - env.define_prim_fun(S8Abs, [&S8_TYPE], &S8_TYPE); - env.define_prim_fun(S8UAbs, [&S8_TYPE], &U8_TYPE); - - env.define_prim_fun(S16Eq, [&S16_TYPE, &S16_TYPE], &BOOL_TYPE); - env.define_prim_fun(S16Neq, [&S16_TYPE, &S16_TYPE], &BOOL_TYPE); - env.define_prim_fun(S16Lt, [&S16_TYPE, &S16_TYPE], &BOOL_TYPE); - env.define_prim_fun(S16Gt, [&S16_TYPE, &S16_TYPE], &BOOL_TYPE); - env.define_prim_fun(S16Lte, [&S16_TYPE, &S16_TYPE], &BOOL_TYPE); - env.define_prim_fun(S16Gte, [&S16_TYPE, &S16_TYPE], &BOOL_TYPE); - env.define_prim_fun(S16Neg, [&S16_TYPE], &S16_TYPE); - env.define_prim_fun(S16Add, [&S16_TYPE, &S16_TYPE], &S16_TYPE); - env.define_prim_fun(S16Sub, [&S16_TYPE, &S16_TYPE], &S16_TYPE); - env.define_prim_fun(S16Mul, [&S16_TYPE, &S16_TYPE], &S16_TYPE); - env.define_prim_fun(S16Div, [&S16_TYPE, &S16_TYPE], &S16_TYPE); - env.define_prim_fun(S16Abs, [&S16_TYPE], &S16_TYPE); - env.define_prim_fun(S16UAbs, [&S16_TYPE], &U16_TYPE); - - env.define_prim_fun(S32Eq, [&S32_TYPE, &S32_TYPE], &BOOL_TYPE); - env.define_prim_fun(S32Neq, [&S32_TYPE, &S32_TYPE], &BOOL_TYPE); - env.define_prim_fun(S32Lt, [&S32_TYPE, &S32_TYPE], &BOOL_TYPE); - env.define_prim_fun(S32Gt, [&S32_TYPE, &S32_TYPE], &BOOL_TYPE); - env.define_prim_fun(S32Lte, [&S32_TYPE, &S32_TYPE], &BOOL_TYPE); - env.define_prim_fun(S32Gte, [&S32_TYPE, &S32_TYPE], &BOOL_TYPE); - env.define_prim_fun(S32Neg, [&S32_TYPE], &S32_TYPE); - env.define_prim_fun(S32Add, [&S32_TYPE, &S32_TYPE], &S32_TYPE); - env.define_prim_fun(S32Sub, [&S32_TYPE, &S32_TYPE], &S32_TYPE); - env.define_prim_fun(S32Mul, [&S32_TYPE, &S32_TYPE], &S32_TYPE); - env.define_prim_fun(S32Div, [&S32_TYPE, &S32_TYPE], &S32_TYPE); - env.define_prim_fun(S32Abs, [&S32_TYPE], &S32_TYPE); - env.define_prim_fun(S32UAbs, [&S32_TYPE], &U32_TYPE); - - env.define_prim_fun(S64Eq, [&S64_TYPE, &S64_TYPE], &BOOL_TYPE); - env.define_prim_fun(S64Neq, [&S64_TYPE, &S64_TYPE], &BOOL_TYPE); - env.define_prim_fun(S64Lt, [&S64_TYPE, &S64_TYPE], &BOOL_TYPE); - env.define_prim_fun(S64Gt, [&S64_TYPE, &S64_TYPE], &BOOL_TYPE); - env.define_prim_fun(S64Lte, [&S64_TYPE, &S64_TYPE], &BOOL_TYPE); - env.define_prim_fun(S64Gte, [&S64_TYPE, &S64_TYPE], &BOOL_TYPE); - env.define_prim_fun(S64Neg, [&S64_TYPE], &S64_TYPE); - env.define_prim_fun(S64Add, [&S64_TYPE, &S64_TYPE], &S64_TYPE); - env.define_prim_fun(S64Sub, [&S64_TYPE, &S64_TYPE], &S64_TYPE); - env.define_prim_fun(S64Mul, [&S64_TYPE, &S64_TYPE], &S64_TYPE); - env.define_prim_fun(S64Div, [&S64_TYPE, &S64_TYPE], &S64_TYPE); - env.define_prim_fun(S64Abs, [&S64_TYPE], &S64_TYPE); - env.define_prim_fun(S64UAbs, [&S64_TYPE], &U64_TYPE); + for prim in [BoolEq, BoolNeq, BoolAnd, BoolOr, BoolXor] { + env.define_prim_fun(prim, [&BOOL_TYPE, &BOOL_TYPE], &BOOL_TYPE); + } + + struct UintPrims { + r#type: &'static Term<'static>, + not: Prim, + relops: [Prim; 6], + binops: [Prim; 7], + shifts: [Prim; 2], + } + + const U8_PRIMS: UintPrims = UintPrims { + r#type: &U8_TYPE, + not: U8Not, + relops: [U8Eq, U8Neq, U8Lt, U8Gt, U8Lte, U8Gte], + binops: [U8Add, U8Sub, U8Mul, U8Div, U8And, U8Or, U8Xor], + shifts: [U8Shl, U8Shr], + }; + + const U16_PRIMS: UintPrims = UintPrims { + r#type: &U16_TYPE, + not: U16Not, + relops: [U16Eq, U16Neq, U16Lt, U16Gt, U16Lte, U16Gte], + binops: [U16Add, U16Sub, U16Mul, U16Div, U16And, U16Or, U16Xor], + shifts: [U16Shl, U16Shr], + }; + + const U32_PRIMS: UintPrims = UintPrims { + r#type: &U32_TYPE, + not: U32Not, + relops: [U32Eq, U32Neq, U32Lt, U32Gt, U32Lte, U32Gte], + binops: [U32Add, U32Sub, U32Mul, U32Div, U32And, U32Or, U32Xor], + shifts: [U32Shl, U32Shr], + }; + + const U64_PRIMS: UintPrims = UintPrims { + r#type: &U64_TYPE, + not: U64Not, + relops: [U64Eq, U64Neq, U64Lt, U64Gt, U64Lte, U64Gte], + binops: [U64Add, U64Sub, U64Mul, U64Div, U64And, U64Or, U64Xor], + shifts: [U64Shl, U64Shr], + }; + + for schema in [U8_PRIMS, U16_PRIMS, U32_PRIMS, U64_PRIMS] { + let r#type = schema.r#type; + env.define_prim_fun(schema.not, [r#type], r#type); + for prim in schema.relops { + env.define_prim_fun(prim, [r#type, r#type], &BOOL_TYPE); + } + for prim in schema.binops { + env.define_prim_fun(prim, [r#type, r#type], r#type); + } + for prim in schema.shifts { + env.define_prim_fun(prim, [r#type, &U8_TYPE], r#type); + } + } + + struct SintPrims { + signed_type: &'static Term<'static>, + unsigned_type: &'static Term<'static>, + relops: [Prim; 6], + binops: [Prim; 4], + neg: Prim, + abs: Prim, + uabs: Prim, + } + + const S8_PRIMS: SintPrims = SintPrims { + signed_type: &S8_TYPE, + unsigned_type: &U8_TYPE, + relops: [S8Eq, S8Neq, S8Lt, S8Gt, S8Lte, S8Gte], + binops: [S8Add, S8Sub, S8Mul, S8Div], + neg: S8Neg, + abs: S8Abs, + uabs: S8UAbs, + }; + + const S16_PRIMS: SintPrims = SintPrims { + signed_type: &S16_TYPE, + unsigned_type: &U16_TYPE, + relops: [S16Eq, S16Neq, S16Lt, S16Gt, S16Lte, S16Gte], + binops: [S16Add, S16Sub, S16Mul, S16Div], + neg: S16Neg, + abs: S16Abs, + uabs: S16UAbs, + }; + + const S32_PRIMS: SintPrims = SintPrims { + signed_type: &S32_TYPE, + unsigned_type: &U32_TYPE, + relops: [S32Eq, S32Neq, S32Lt, S32Gt, S32Lte, S32Gte], + binops: [S32Add, S32Sub, S32Mul, S32Div], + neg: S32Neg, + abs: S32Abs, + uabs: S32UAbs, + }; + + const S64_PRIMS: SintPrims = SintPrims { + signed_type: &S64_TYPE, + unsigned_type: &U64_TYPE, + relops: [S64Eq, S64Neq, S64Lt, S64Gt, S64Lte, S64Gte], + binops: [S64Add, S64Sub, S64Mul, S64Div], + neg: S64Neg, + abs: S64Abs, + uabs: S64UAbs, + }; + + for schema in [S8_PRIMS, S16_PRIMS, S32_PRIMS, S64_PRIMS] { + let r#type = schema.signed_type; + + for prim in schema.relops { + env.define_prim_fun(prim, [r#type, r#type], &BOOL_TYPE); + } + for prim in schema.binops { + env.define_prim_fun(prim, [r#type, r#type], r#type); + } + env.define_prim_fun(schema.neg, [r#type], r#type); + env.define_prim_fun(schema.abs, [r#type], r#type); + env.define_prim_fun(schema.uabs, [r#type], schema.unsigned_type); + } env.define_prim( OptionSome, // fun (@A : Type) -> A -> Option A // fun (@A : Type) -> A@0 -> Option A@1 - &core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - &Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - &VAR0, - &Term::FunApp( - Span::Empty, - Plicity::Explicit, - &Term::Prim(Span::Empty, OptionType), - &VAR1, - ), - ), + &builder.fun_types( + [ + (Plicity::Implicit, env.name("A"), UNIVERSE), + (Plicity::Explicit, None, VAR0), + ], + builder.fun_app(Span::Empty, Plicity::Explicit, OPTION_TYPE, VAR1), ), ); env.define_prim( @@ -337,163 +315,91 @@ impl<'arena> Env<'arena> { Plicity::Implicit, env.name("A"), &UNIVERSE, - &Term::FunApp( - Span::Empty, - Plicity::Explicit, - &Term::Prim(Span::Empty, OptionType), - &VAR0, - ), + &Term::FunApp(Span::Empty, Plicity::Explicit, &OPTION_TYPE, &VAR0), ), ); env.define_prim( OptionFold, // fun (@A : Type) (@B : Type) -> B -> (A -> B ) -> Option A -> B // fun (@A : Type) (@B : Type) -> B@0 -> (A@2 -> B@2) -> Option A@3 -> B@3 - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("B"), - &UNIVERSE, - scope.to_scope(core::Term::FunType( - Span::Empty, + &builder.fun_types( + [ + (Plicity::Implicit, env.name("A"), UNIVERSE), + (Plicity::Implicit, env.name("B"), UNIVERSE), + (Plicity::Explicit, None, VAR0), + (Plicity::Explicit, None, builder.fun_types([VAR2], VAR2)), + ( Plicity::Explicit, None, - &VAR0, // B@0 - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - // A@2 -> B@2 - &Term::FunType(Span::Empty, Plicity::Explicit, None, &VAR2, &VAR2), - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - // Option A@3 - &Term::FunApp( - Span::Empty, - Plicity::Explicit, - &Term::Prim(Span::Empty, OptionType), - &VAR3, - ), - &VAR3, // B@3 - )), - )), - )), - )), - )), + builder.fun_apps(OPTION_TYPE, [VAR3]), + ), + ], + VAR3, + ), ); // fun (@len : UN) (@A : Type) -> (A -> Bool) -> ArrayN len A -> Option A // fun (@len : UN) (@A : Type) -> (A@0 -> Bool) -> ArrayN len@2 A@1 -> Option // A@2 - let find_type = |index_type, array_type| { - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("len"), - index_type, - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - scope.to_scope(core::Term::FunType( - Span::Empty, + fn find_type<'arena>( + env: &EnvBuilder<'arena>, + index_type: Term<'arena>, + array_type: Term<'arena>, + ) -> &'arena Term<'arena> { + let builder = env.builder(); + env.scope.to_scope(builder.fun_types( + [ + (Plicity::Implicit, env.name("len"), index_type), + (Plicity::Implicit, env.name("A"), UNIVERSE), + ( Plicity::Explicit, None, - // (A@0 -> Bool) - &Term::FunType(Span::Empty, Plicity::Explicit, None, &VAR0, &BOOL_TYPE), - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - // ArrayN len@2 A@1 - scope.to_scope(Term::FunApp( - Span::Empty, - Plicity::Explicit, - scope.to_scope(Term::FunApp( - Span::Empty, - Plicity::Explicit, - array_type, - &VAR2, - )), - &VAR1, - )), - // Option A@2 - &Term::FunApp( - Span::Empty, - Plicity::Explicit, - &Term::Prim(Span::Empty, OptionType), - &VAR2, - ), - )), - )), - )), + builder.fun_types([VAR0], BOOL_TYPE), + ), + ( + Plicity::Explicit, + None, + builder.fun_apps(array_type, [VAR2, VAR1]), + ), + ], + Term::FunApp(Span::Empty, Plicity::Explicit, &OPTION_TYPE, &VAR2), )) - }; - let array8_find_type = find_type(&U8_TYPE, &ARRAY8_TYPE); - let array16_find_type = find_type(&U16_TYPE, &ARRAY16_TYPE); - let array32_find_type = find_type(&U32_TYPE, &ARRAY32_TYPE); - let array64_find_type = find_type(&U64_TYPE, &ARRAY64_TYPE); - env.define_prim(Array8Find, array8_find_type); - env.define_prim(Array16Find, array16_find_type); - env.define_prim(Array32Find, array32_find_type); - env.define_prim(Array64Find, array64_find_type); + } // fun (@len : UN) (@A : Type) (index : UN) -> ArrayN len A -> A // fun (@len : UN) (@A : Type) (index : UN) -> ArrayN len@2 A@1 -> A@2 - let array_index_type = |index_type, array_type| { - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("len"), - index_type, - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Implicit, - env.name("A"), - &UNIVERSE, - scope.to_scope(core::Term::FunType( - Span::Empty, + fn array_index_type<'arena>( + env: &EnvBuilder<'arena>, + index_type: Term<'arena>, + array_type: Term<'arena>, + ) -> &'arena Term<'arena> { + let builder = env.builder(); + env.scope.to_scope(builder.fun_types( + [ + (Plicity::Implicit, env.name("len"), index_type.clone()), + (Plicity::Implicit, env.name("A"), UNIVERSE), + (Plicity::Explicit, env.name("index"), index_type), + ( Plicity::Explicit, - env.name("index"), - index_type, - scope.to_scope(core::Term::FunType( - Span::Empty, - Plicity::Explicit, - None, - // ArrayN len@2 A@1 - scope.to_scope(Term::FunApp( - Span::Empty, - Plicity::Explicit, - scope.to_scope(Term::FunApp( - Span::Empty, - Plicity::Explicit, - array_type, - &VAR2, - )), - &VAR1, - )), - &VAR2, // A@2 - )), - )), - )), + None, + builder.fun_apps(array_type, [VAR2, VAR1]), + ), + ], + VAR2, )) - }; - let array8_index_type = array_index_type(&U8_TYPE, &ARRAY8_TYPE); - let array16_index_type = array_index_type(&U16_TYPE, &ARRAY16_TYPE); - let array32_index_type = array_index_type(&U32_TYPE, &ARRAY32_TYPE); - let array64_index_type = array_index_type(&U64_TYPE, &ARRAY64_TYPE); - env.define_prim(Array8Index, array8_index_type); - env.define_prim(Array16Index, array16_index_type); - env.define_prim(Array32Index, array32_index_type); - env.define_prim(Array64Index, array64_index_type); + } + + for (prim1, prim2, index_type, array_type) in [ + (Array8Find, Array8Index, U8_TYPE, ARRAY8_TYPE), + (Array16Find, Array16Index, U16_TYPE, ARRAY16_TYPE), + (Array32Find, Array32Index, U32_TYPE, ARRAY32_TYPE), + (Array64Find, Array64Index, U64_TYPE, ARRAY64_TYPE), + ] { + let find_type = find_type(&env, index_type.clone(), array_type.clone()); + let array_index_type = array_index_type(&env, index_type, array_type); + env.define_prim(prim1, find_type); + env.define_prim(prim2, array_index_type); + } env.define_prim_fun(PosAddU8, [&POS_TYPE, &U8_TYPE], &POS_TYPE); env.define_prim_fun(PosAddU16, [&POS_TYPE, &U16_TYPE], &POS_TYPE); @@ -523,6 +429,10 @@ impl<'arena> EnvBuilder<'arena> { } } + fn builder(&self) -> Builder<'arena> { + Builder::new(self.scope) + } + fn name(&self, name: &'static str) -> Option { Some(Symbol::intern_static(name)) } diff --git a/fathom/src/core/semantics.rs b/fathom/src/core/semantics.rs index 5be7f9dbe..71201b363 100644 --- a/fathom/src/core/semantics.rs +++ b/fathom/src/core/semantics.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use scoped_arena::Scope; +use super::{Builder, LetDef}; use crate::alloc::SliceVec; use crate::core::{prim, Const, LocalInfo, Plicity, Prim, Term}; use crate::env::{EnvLen, Index, Level, SharedEnv, SliceEnv}; @@ -300,8 +301,8 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { self.apply_local_infos(head_expr, local_infos) } Term::Ann(span, expr, _) => Spanned::merge(*span, self.eval(expr)), - Term::Let(span, _, _, def_expr, body_expr) => { - let def_expr = self.eval(def_expr); + Term::Let(span, def, body_expr) => { + let def_expr = self.eval(&def.expr); self.local_exprs.push(def_expr); let body_expr = self.eval(body_expr); self.local_exprs.pop(); @@ -667,21 +668,17 @@ impl<'in_arena, 'env> QuoteEnv<'in_arena, 'env> { // NOTE: this copies more than is necessary when `'in_arena == 'out_arena`: // for example when copying label slices. + let builder = Builder::new(scope); let value = self.elim_env.force(value); let span = value.span(); match value.as_ref() { Value::Stuck(head, spine) => spine.iter().fold( self.quote_head(scope, span, head), |head_expr, elim| match elim { - Elim::FunApp(plicity, arg_expr) => Term::FunApp( - span, - *plicity, - scope.to_scope(head_expr), - scope.to_scope(self.quote(scope, arg_expr)), - ), - Elim::RecordProj(label) => { - Term::RecordProj(span, scope.to_scope(head_expr), *label) + Elim::FunApp(plicity, arg_expr) => { + builder.fun_app(span, *plicity, head_expr, self.quote(scope, arg_expr)) } + Elim::RecordProj(label) => builder.record_proj(span, head_expr, *label), Elim::ConstMatch(branches) => { let mut branches = branches.clone(); let mut pattern_branches = SliceVec::new(scope, branches.num_patterns()); @@ -699,9 +696,9 @@ impl<'in_arena, 'env> QuoteEnv<'in_arena, 'env> { } }; - Term::ConstMatch( + builder.const_match( span, - scope.to_scope(head_expr), + head_expr, pattern_branches.into(), default_branch .map(|(name, expr)| (name, self.quote_closure(scope, &expr))), @@ -712,20 +709,19 @@ impl<'in_arena, 'env> QuoteEnv<'in_arena, 'env> { Value::Universe => Term::Universe(span), - Value::FunType(plicity, param_name, param_type, body_type) => Term::FunType( + Value::FunType(plicity, param_name, param_type, body_type) => builder.fun_type( span, *plicity, *param_name, - scope.to_scope(self.quote(scope, param_type)), + self.quote(scope, param_type), self.quote_closure(scope, body_type), ), - Value::FunLit(plicity, param_name, body_expr) => Term::FunLit( + Value::FunLit(plicity, param_name, body_expr) => builder.fun_lit( span, *plicity, *param_name, self.quote_closure(scope, body_expr), ), - Value::RecordType(labels, types) => Term::RecordType( span, scope.to_scope_from_iter(labels.iter().copied()), @@ -746,10 +742,10 @@ impl<'in_arena, 'env> QuoteEnv<'in_arena, 'env> { scope.to_scope_from_iter(labels.iter().copied()), self.quote_telescope(scope, formats), ), - Value::FormatCond(label, format, cond) => Term::FormatCond( + Value::FormatCond(label, format, cond) => builder.format_cond( span, *label, - scope.to_scope(self.quote(scope, format)), + self.quote(scope, format), self.quote_closure(scope, cond), ), Value::FormatOverlap(labels, formats) => Term::FormatOverlap( @@ -792,7 +788,7 @@ impl<'in_arena, 'env> QuoteEnv<'in_arena, 'env> { &mut self, scope: &'out_arena Scope<'out_arena>, closure: &Closure<'in_arena>, - ) -> &'out_arena Term<'out_arena> { + ) -> Term<'out_arena> { let var = Arc::new(Value::local_var(self.local_exprs.next_level())); let value = self.elim_env.apply_closure(closure, Spanned::empty(var)); @@ -800,7 +796,7 @@ impl<'in_arena, 'env> QuoteEnv<'in_arena, 'env> { let term = self.quote(scope, &value); self.pop_local(); - scope.to_scope(term) + term } /// Quote a [telescope][Telescope] back into a slice of [terms][Term]. @@ -848,6 +844,8 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { scope: &'out_arena Scope<'out_arena>, term: &Term<'arena>, ) -> Term<'out_arena> { + let builder = Builder::new(scope); + match term { Term::ItemVar(span, var) => Term::ItemVar(*span, *var), Term::LocalVar(span, var) => Term::LocalVar(*span, *var), @@ -871,29 +869,31 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { Term::InsertedMeta(*span, *var, infos) } }, - Term::Ann(span, expr, r#type) => Term::Ann( + Term::Ann(span, expr, r#type) => builder.ann( *span, - scope.to_scope(self.unfold_metas(scope, expr)), - scope.to_scope(self.unfold_metas(scope, r#type)), + self.unfold_metas(scope, expr), + self.unfold_metas(scope, r#type), ), - Term::Let(span, def_name, def_type, def_expr, body_expr) => Term::Let( + Term::Let(span, def, body_expr) => builder.r#let( *span, - *def_name, - scope.to_scope(self.unfold_metas(scope, def_type)), - scope.to_scope(self.unfold_metas(scope, def_expr)), + LetDef { + name: def.name, + r#type: (self.unfold_metas(scope, &def.r#type)), + expr: (self.unfold_metas(scope, &def.expr)), + }, self.unfold_bound_metas(scope, body_expr), ), Term::Universe(span) => Term::Universe(*span), - Term::FunType(span, plicity, param_name, param_type, body_type) => Term::FunType( + Term::FunType(span, plicity, param_name, param_type, body_type) => builder.fun_type( *span, *plicity, *param_name, - scope.to_scope(self.unfold_metas(scope, param_type)), + self.unfold_metas(scope, param_type), self.unfold_bound_metas(scope, body_type), ), - Term::FunLit(span, plicity, param_name, body_expr) => Term::FunLit( + Term::FunLit(span, plicity, param_name, body_expr) => builder.fun_lit( *span, *plicity, *param_name, @@ -921,10 +921,10 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { scope.to_scope_from_iter(labels.iter().copied()), self.unfold_telescope_metas(scope, formats), ), - Term::FormatCond(span, name, format, pred) => Term::FormatCond( + Term::FormatCond(span, name, format, pred) => builder.format_cond( *span, *name, - scope.to_scope(self.unfold_metas(scope, format)), + self.unfold_metas(scope, format), self.unfold_bound_metas(scope, pred), ), Term::FormatOverlap(span, labels, formats) => Term::FormatOverlap( @@ -945,6 +945,8 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { scope: &'out_arena Scope<'out_arena>, term: &Term<'arena>, ) -> TermOrValue<'arena, 'out_arena> { + let builder = Builder::new(scope); + // Recurse to find the head of an elimination, checking if it's a // metavariable. If so, check if it has a solution, and then apply // eliminations to the solution in turn on our way back out. @@ -971,11 +973,11 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { Term::FunApp(span, plicity, head_expr, arg_expr) => { match self.unfold_meta_var_spines(scope, head_expr) { - TermOrValue::Term(head_expr) => TermOrValue::Term(Term::FunApp( + TermOrValue::Term(head_expr) => TermOrValue::Term(builder.fun_app( *span, *plicity, - scope.to_scope(head_expr), - scope.to_scope(self.unfold_metas(scope, arg_expr)), + head_expr, + self.unfold_metas(scope, arg_expr), )), TermOrValue::Value(head_expr) => { let arg_expr = self.eval(arg_expr); @@ -985,11 +987,9 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { } Term::RecordProj(span, head_expr, label) => { match self.unfold_meta_var_spines(scope, head_expr) { - TermOrValue::Term(head_expr) => TermOrValue::Term(Term::RecordProj( - *span, - scope.to_scope(head_expr), - *label, - )), + TermOrValue::Term(head_expr) => { + TermOrValue::Term(builder.record_proj(*span, head_expr, *label)) + } TermOrValue::Value(head_expr) => { TermOrValue::Value(self.elim_env.record_proj(head_expr, *label)) } @@ -997,16 +997,19 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { } Term::ConstMatch(span, head_expr, branches, default_branch) => { match self.unfold_meta_var_spines(scope, head_expr) { - TermOrValue::Term(head_expr) => TermOrValue::Term(Term::ConstMatch( - *span, - scope.to_scope(head_expr), - scope.to_scope_from_iter( - (branches.iter()) - .map(|(r#const, expr)| (*r#const, self.unfold_metas(scope, expr))), + TermOrValue::Term(head_expr) => TermOrValue::Term( + builder.const_match( + *span, + head_expr, + scope.to_scope_from_iter( + branches.iter().map(|(r#const, expr)| { + (*r#const, self.unfold_metas(scope, expr)) + }), + ), + default_branch + .map(|(name, expr)| (name, self.unfold_bound_metas(scope, expr))), ), - default_branch - .map(|(name, expr)| (name, self.unfold_bound_metas(scope, expr))), - )), + ), TermOrValue::Value(head_expr) => { let branches = Branches::new(self.local_exprs.clone(), branches, *default_branch); @@ -1023,14 +1026,14 @@ impl<'arena, 'env> EvalEnv<'arena, 'env> { &mut self, scope: &'out_arena Scope<'out_arena>, term: &Term<'arena>, - ) -> &'out_arena Term<'out_arena> { + ) -> Term<'out_arena> { let var = Arc::new(Value::local_var(self.local_exprs.len().next_level())); self.local_exprs.push(Spanned::empty(var)); let term = self.unfold_metas(scope, term); self.local_exprs.pop(); - scope.to_scope(term) + term } fn unfold_telescope_metas<'out_arena>( diff --git a/fathom/src/env.rs b/fathom/src/env.rs index e4775e146..8d425b48b 100644 --- a/fathom/src/env.rs +++ b/fathom/src/env.rs @@ -54,6 +54,16 @@ impl Index { pub const fn prev(self) -> Index { Index(self.0 + 1) } + + /// An iterator over indices, listed from the most recently bound. + pub fn iter() -> impl Iterator { + (0..).map(Self) + } + + /// An iterator over indices, listed from `self` + pub fn iter_from(self) -> impl Iterator { + (self.0..).map(Self) + } } impl fmt::Debug for Index { @@ -70,11 +80,6 @@ impl fmt::Display for Index { } } -/// An iterator over indices, listed from the most recently bound. -pub fn indices() -> impl Iterator { - (0..).map(Index) -} - /// A de Bruijn level, which represents a variable by counting the number of /// binders between the binder that introduced the variable and the start of the /// environment. For example: @@ -101,6 +106,16 @@ impl Level { pub const fn next(self) -> Level { Level(self.0 + 1) } + + /// An iterator over levels, listed from the least recently bound. + pub fn iter() -> impl Iterator { + (0..).map(Self) + } + + /// An iterator over levels, listed from `self` + pub fn iter_from(self) -> impl Iterator { + (self.0..).map(Self) + } } impl fmt::Debug for Level { @@ -117,11 +132,6 @@ impl fmt::Display for Level { } } -/// An iterator over levels, listed from the least recently bound. -pub fn levels() -> impl Iterator { - (0..).map(Level) -} - /// The length of an environment. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct EnvLen(RawVar); @@ -270,11 +280,12 @@ impl SliceEnv { impl SliceEnv { pub fn elem_level(&self, entry: &Entry) -> Option { - Iterator::zip(levels(), self.iter()).find_map(|(var, e)| (entry == e).then_some(var)) + Iterator::zip(Level::iter(), self.iter()).find_map(|(var, e)| (entry == e).then_some(var)) } pub fn elem_index(&self, entry: &Entry) -> Option { - Iterator::zip(indices(), self.iter().rev()).find_map(|(var, e)| (entry == e).then_some(var)) + Iterator::zip(Index::iter(), self.iter().rev()) + .find_map(|(var, e)| (entry == e).then_some(var)) } } diff --git a/fathom/src/source.rs b/fathom/src/source.rs index 60f0e1286..ae035c549 100644 --- a/fathom/src/source.rs +++ b/fathom/src/source.rs @@ -162,6 +162,12 @@ impl fmt::Debug for ByteRange { } } +impl From<(BytePos, BytePos)> for ByteRange { + fn from((start, end): (BytePos, BytePos)) -> Self { + Self { start, end } + } +} + impl ByteRange { pub fn new(start: BytePos, end: BytePos) -> ByteRange { ByteRange { start, end } diff --git a/fathom/src/surface.rs b/fathom/src/surface.rs index 354a86c21..f0e6cdf69 100644 --- a/fathom/src/surface.rs +++ b/fathom/src/surface.rs @@ -1,6 +1,7 @@ //! Surface language. use std::fmt; +use std::ops::Deref; use codespan_reporting::diagnostic::{Diagnostic, Label}; use lalrpop_util::lalrpop_mod; @@ -68,9 +69,9 @@ pub struct ItemDef<'arena, Range> { /// Parameter patterns params: &'arena [Param<'arena, Range>], /// An optional type annotation for the defined expression - r#type: Option<&'arena Term<'arena, Range>>, + r#type: Option>, /// The defined expression - expr: &'arena Term<'arena, Range>, + expr: Term<'arena, Range>, } /// Surface patterns. @@ -93,7 +94,7 @@ pub enum Pattern { /// Boolean literal patterns BooleanLiteral(Range, bool), // TODO: Record literal patterns - // RecordLiteral(Range, &'arena [((Range, StringId), Pattern<'arena, Range>)]), + // RecordLiteral(Range, &'arena [((Range, Symbol), Pattern<'arena, Range>)]), } #[derive(Debug, Clone, Copy)] @@ -198,9 +199,7 @@ pub enum Term<'arena, Range> { /// Let expressions. Let( Range, - Pattern, - Option<&'arena Term<'arena, Range>>, - &'arena Term<'arena, Range>, + &'arena LetDef<'arena, Range>, &'arena Term<'arena, Range>, ), /// If expressions @@ -299,7 +298,7 @@ impl<'arena, Range: Clone> Term<'arena, Range> { | Term::Hole(range, _) | Term::Placeholder(range) | Term::Ann(range, _, _) - | Term::Let(range, _, _, _, _) + | Term::Let(range, ..) | Term::If(range, _, _, _) | Term::Match(range, _, _) | Term::Universe(range) @@ -324,6 +323,88 @@ impl<'arena, Range: Clone> Term<'arena, Range> { } } +impl<'arena, Range> Term<'arena, Range> { + /// Apply `f` to all the child terms of self. + /// Useful for implementing functions that need to behave differently on a + /// few specific node types, but otherwise just recurse through the tree + /// normally + pub fn walk(&self, mut f: impl FnMut(&Self)) { + let mut recur2 = |term1, term2| { + f(term1); + f(term2) + }; + + match self { + Term::Name(_, _) + | Term::Hole(_, _) + | Term::Placeholder(_) + | Term::Universe(_) + | Term::StringLiteral(_, _) + | Term::NumberLiteral(_, _) + | Term::BooleanLiteral(_, _) + | Term::ReportedError(_) => {} + + Term::Paren(_, term) => f(term), + Term::Ann(_, term, r#type) => recur2(term, r#type), + Term::Let(_, def, body) => { + if let Some(r#type) = def.r#type.as_ref() { + f(r#type); + } + f(&def.expr); + f(body); + } + Term::If(_, cond, then, r#else) => { + f(cond); + f(then); + f(r#else) + } + Term::Match(_, scrut, branches) => { + f(scrut); + branches.iter().for_each(|(_, term)| f(term)); + } + Term::Arrow(_, _, input, output) => recur2(input, output), + Term::FunType(_, params, body) | Term::FunLiteral(_, params, body) => { + params.iter().for_each(|param| { + if let Some(r#type) = param.r#type.as_ref() { + f(r#type) + } + }); + f(body) + } + Term::App(_, head, args) => { + f(head); + args.iter().for_each(|arg| f(&arg.term)); + } + Term::RecordType(_, fields) => fields.iter().for_each(|field| f(&field.r#type)), + Term::RecordLiteral(_, field) => field.iter().for_each(|field| { + if let Some(term) = field.expr.as_ref() { + f(term) + } + }), + Term::Proj(_, head, _) => f(head), + Term::Tuple(_, terms) | Term::ArrayLiteral(_, terms) => terms.iter().for_each(f), + Term::FormatRecord(_, fields) | Term::FormatOverlap(_, fields) => { + fields.iter().for_each(|field| match field { + FormatField::Format { format, pred, .. } => { + f(format); + if let Some(pred) = pred { + f(pred) + } + } + FormatField::Computed { r#type, expr, .. } => { + if let Some(r#type) = r#type { + f(r#type); + } + f(expr); + } + }) + } + Term::FormatCond(_, _, format, pred) => recur2(format, pred), + Term::BinOp(_, lhs, _, rhs) => recur2(lhs, rhs), + } + } +} + impl<'arena> Term<'arena, FileRange> { /// Parse a term from the `source` string, interning strings to the /// supplied `interner` and allocating nodes to the `arena`. @@ -347,6 +428,13 @@ impl<'arena> Term<'arena, FileRange> { } } +#[derive(Debug, Clone)] +pub struct LetDef<'arena, Range> { + pub pattern: Pattern, + pub r#type: Option>, + pub expr: Term<'arena, Range>, +} + #[derive(Debug, Clone)] pub struct Param<'arena, Range> { pub plicity: Plicity, @@ -513,6 +601,215 @@ fn format_expected(expected: &[impl std::fmt::Display]) -> Option { }) } +pub struct Builder<'arena> { + scope: &'arena Scope<'arena>, +} + +impl<'arena> Deref for Builder<'arena> { + type Target = Scope<'arena>; + + fn deref(&self) -> &Self::Target { + self.scope + } +} + +impl<'arena> Builder<'arena> { + pub fn new(scope: &'arena Scope<'arena>) -> Self { + Self { scope } + } + + pub fn into_scope(self) -> &'arena Scope<'arena> { + self.scope + } + + pub fn paren( + &self, + range: impl Into, + term: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::Paren(range.into(), self.scope.to_scope(term)) + } + + pub fn ann( + &self, + range: impl Into, + term: Term<'arena, Range>, + r#type: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::Ann( + range.into(), + self.scope.to_scope(term), + self.scope.to_scope(r#type), + ) + } + + pub fn r#let( + &self, + range: impl Into, + def: LetDef<'arena, Range>, + body: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::Let( + range.into(), + self.scope.to_scope(def), + self.scope.to_scope(body), + ) + } + + pub fn if_then_else( + &self, + range: impl Into, + cond: Term<'arena, Range>, + then: Term<'arena, Range>, + r#else: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::If( + range.into(), + self.scope.to_scope(cond), + self.scope.to_scope(then), + self.scope.to_scope(r#else), + ) + } + + pub fn r#match( + &self, + range: impl Into, + scrut: Term<'arena, Range>, + branches: &'arena [(Pattern, Term<'arena, Range>)], + ) -> Term<'arena, Range> { + Term::Match(range.into(), self.scope.to_scope(scrut), branches) + } + + pub fn arrow( + &self, + range: impl Into, + plicity: Plicity, + r#type: Term<'arena, Range>, + body: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::Arrow( + range.into(), + plicity, + self.scope.to_scope(r#type), + self.scope.to_scope(body), + ) + } + + pub fn fun_type( + &self, + range: impl Into, + params: &'arena [Param<'arena, Range>], + body: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::FunType(range.into(), params, self.scope.to_scope(body)) + } + + pub fn fun_lit( + &self, + range: impl Into, + params: &'arena [Param<'arena, Range>], + body: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::FunLiteral(range.into(), params, self.scope.to_scope(body)) + } + + pub fn record_type( + &self, + range: impl Into, + fields: &'arena [TypeField<'arena, Range>], + ) -> Term<'arena, Range> { + Term::RecordType(range.into(), fields) + } + + pub fn record_lit( + &self, + range: impl Into, + fields: &'arena [ExprField<'arena, Range>], + ) -> Term<'arena, Range> { + Term::RecordLiteral(range.into(), fields) + } + + pub fn tuple( + &self, + range: impl Into, + terms: &'arena [Term<'arena, Range>], + ) -> Term<'arena, Range> { + Term::Tuple(range.into(), terms) + } + + pub fn fun_app( + &self, + range: impl Into, + fun: Term<'arena, Range>, + args: &'arena [Arg<'arena, Range>], + ) -> Term<'arena, Range> { + Term::App(range.into(), self.scope.to_scope(fun), args) + } + + pub fn record_proj( + &self, + range: impl Into, + head: Term<'arena, Range>, + labels: &'arena [(Range, Symbol)], + ) -> Term<'arena, Range> { + Term::Proj(range.into(), self.scope.to_scope(head), labels) + } + + pub fn array_lit( + &self, + range: impl Into, + terms: &'arena [Term<'arena, Range>], + ) -> Term<'arena, Range> { + Term::ArrayLiteral(range.into(), terms) + } + + pub fn format_record( + &self, + range: impl Into, + fields: &'arena [FormatField<'arena, Range>], + ) -> Term<'arena, Range> { + Term::FormatRecord(range.into(), fields) + } + + pub fn format_overlap( + &self, + range: impl Into, + fields: &'arena [FormatField<'arena, Range>], + ) -> Term<'arena, Range> { + Term::FormatOverlap(range.into(), fields) + } + + pub fn format_cond( + &self, + range: impl Into, + label: (Range, Symbol), + format: Term<'arena, Range>, + pred: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::FormatCond( + range.into(), + label, + self.scope.to_scope(format), + self.scope.to_scope(pred), + ) + } + + pub fn binop( + &self, + range: impl Into, + lhs: Term<'arena, Range>, + op: BinOp, + rhs: Term<'arena, Range>, + ) -> Term<'arena, Range> { + Term::BinOp( + range.into(), + self.scope.to_scope(lhs), + op, + self.scope.to_scope(rhs), + ) + } +} + #[cfg(test)] mod tests { use super::*; @@ -528,7 +825,7 @@ mod tests { #[cfg(target_pointer_width = "64")] fn term_size() { assert_eq!(std::mem::size_of::>(), 32); - assert_eq!(std::mem::size_of::>(), 48); + assert_eq!(std::mem::size_of::>(), 40); } #[test] diff --git a/fathom/src/surface/distillation.rs b/fathom/src/surface/distillation.rs index d8669ee0d..7b9a1b1d9 100644 --- a/fathom/src/surface/distillation.rs +++ b/fathom/src/surface/distillation.rs @@ -2,6 +2,7 @@ use scoped_arena::Scope; +use super::{Builder, LetDef}; use crate::alloc::SliceVec; use crate::core; use crate::core::{Const, Plicity, UIntStyle}; @@ -62,6 +63,10 @@ impl<'arena, 'env> Context<'arena, 'env> { } } + fn builder(&self) -> Builder<'arena> { + Builder::new(self.scope) + } + fn is_bound(&self, name: Symbol) -> bool { (self.local_names.iter()).any(|local_name| *local_name == Some(name)) || self.item_names.iter().any(|item_name| *item_name == name) @@ -86,8 +91,18 @@ impl<'arena, 'env> Context<'arena, 'env> { name } - fn pop_local(&mut self) { + fn with_scope(&mut self, mut f: impl FnMut(&mut Self) -> T) -> T { + let len = self.local_len(); + let res = f(self); + self.local_names.truncate(len); + res + } + + fn with_local(&mut self, name: Option, mut f: impl FnMut(&mut Self) -> T) -> T { + self.local_names.push(name); + let res = f(self); self.local_names.pop(); + res } fn truncate_local(&mut self, len: EnvLen) { @@ -142,8 +157,8 @@ impl<'arena, 'env> Context<'arena, 'env> { r#type, expr, } => { - let r#type = scope.to_scope(self.check_prec(Prec::Top, r#type)); - let expr = scope.to_scope(self.check_prec(Prec::Let, expr)); + let r#type = self.check_prec(Prec::Top, r#type); + let expr = self.check_prec(Prec::Let, expr); Item::Def(ItemDef { range: (), @@ -214,10 +229,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let expr = self.check_number_literal(number); let r#type = self.synth_prim(prim_type); - self.paren( - prec > Prec::Top, - Term::Ann((), self.scope.to_scope(expr), self.scope.to_scope(r#type)), - ) + self.paren(prec > Prec::Top, self.builder().ann((), expr, r#type)) } fn synth_number_literal_styled, const N: usize>( @@ -230,10 +242,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let expr = self.check_number_literal_styled(number, style); let r#type = self.synth_prim(prim_type); - self.paren( - prec > Prec::Top, - Term::Ann((), self.scope.to_scope(expr), self.scope.to_scope(r#type)), - ) + self.paren(prec > Prec::Top, self.builder().ann((), expr, r#type)) } fn check_dependent_tuple( @@ -242,15 +251,15 @@ impl<'arena, 'env> Context<'arena, 'env> { exprs: &[core::Term<'_>], ) -> Term<'arena, ()> { self.local_names.reserve(labels.len()); - let initial_local_len = self.local_len(); - let exprs = (self.scope).to_scope_from_iter( - Iterator::zip(labels.iter(), exprs.iter()).map(|(label, expr)| { - let expr = self.check_prec(Prec::Top, expr); - self.push_local(Some(*label)); - expr - }), - ); - self.truncate_local(initial_local_len); + let exprs = self.with_scope(|this| { + (this.scope).to_scope_from_iter(Iterator::zip(labels.iter(), exprs.iter()).map( + |(label, expr)| { + let expr = this.check_prec(Prec::Top, expr); + this.push_local(Some(*label)); + expr + }, + )) + }); Term::Tuple((), exprs) } @@ -315,13 +324,15 @@ impl<'arena, 'env> Context<'arena, 'env> { /// Wrap a term in parens. fn paren(&self, wrap: bool, term: Term<'arena, ()>) -> Term<'arena, ()> { if wrap { - Term::Paren((), self.scope.to_scope(term)) + self.builder().paren((), term) } else { term } } fn term_prec(&mut self, mode: Mode, prec: Prec, term: &core::Term<'_>) -> Term<'arena, ()> { + let builder = self.builder(); + match (term, mode) { (core::Term::ItemVar(_, var), _) => match self.get_item_name(*var) { Some(name) => Term::Name((), name), @@ -349,7 +360,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let head_expr = self.scope.to_scope(head_expr); let mut args = SliceVec::new(self.scope, num_params); - for (var, info) in Iterator::zip(env::levels(), local_infos.iter()) { + for (var, info) in Iterator::zip(env::Level::iter(), local_infos.iter()) { match info { core::LocalInfo::Def => {} core::LocalInfo::Param => { @@ -373,101 +384,88 @@ impl<'arena, 'env> Context<'arena, 'env> { let expr = self.check_prec(Prec::Let, expr); let r#type = self.check_prec(Prec::Top, r#type); - self.paren( - prec > Prec::Top, - Term::Ann((), self.scope.to_scope(expr), self.scope.to_scope(r#type)), - ) + self.paren(prec > Prec::Top, builder.ann((), expr, r#type)) } - (core::Term::Let(_, name, r#type, expr, body), _) => { - let r#type = self.term_prec(mode, Prec::Top, r#type); - let expr = self.term_prec(mode, Prec::Let, expr); - let name = self.freshen_name(*name, body); - let name = self.push_local(name); + (core::Term::Let(_, def, body), _) => { + let r#type = self.term_prec(mode, Prec::Top, &def.r#type); + let expr = self.term_prec(mode, Prec::Let, &def.expr); + let name = self.freshen_name(def.name, body); let pattern = name_to_pattern(name); - let body = self.term_prec(mode, Prec::Top, body); - self.pop_local(); + let body = self.with_local(name, |this| this.term_prec(mode, Prec::Top, body)); self.paren( prec > Prec::Let, - Term::Let( + builder.r#let( (), - pattern, - Some(self.scope.to_scope(r#type)), - self.scope.to_scope(expr), - self.scope.to_scope(body), + LetDef { + pattern, + r#type: Some(r#type), + expr, + }, + body, ), ) } (core::Term::Universe(_), _) => Term::Universe(()), (core::Term::FunType(..), _) => { - let initial_local_len = self.local_len(); - let mut params = Vec::new(); let mut body_type = term; - let body_type = loop { - match body_type { - // Use an explicit parameter if it is referenced in the body - core::Term::FunType(_, plicity, param_name, param_type, next_body_type) - if next_body_type.binds_local(Index::last()) => - { - let param_type = self.check_prec(Prec::Top, param_type); - let param_name = self.freshen_name(*param_name, next_body_type); - let param_name = self.push_local(param_name); - params.push(Param { - plicity: *plicity, - pattern: name_to_pattern(param_name), - r#type: Some(param_type), - }); - body_type = next_body_type; - } - // Use arrow sugar if the parameter is not referenced in the body type. - core::Term::FunType(_, plicity, _, param_type, body_type) => { - let param_type = self.check_prec(Prec::App, param_type); - - self.push_local(None); - let body_type = self.check_prec(Prec::Fun, body_type); - self.pop_local(); - - break Term::Arrow( - (), - *plicity, - self.scope.to_scope(param_type), - self.scope.to_scope(body_type), - ); + let body_type = self.with_scope(|this| { + loop { + match body_type { + // Use an explicit parameter if it is referenced in the body + core::Term::FunType( + _, + plicity, + param_name, + param_type, + next_body_type, + ) if next_body_type.binds_local(Index::last()) => { + let param_type = this.check_prec(Prec::Top, param_type); + let param_name = this.freshen_name(*param_name, next_body_type); + let param_name = this.push_local(param_name); + params.push(Param { + plicity: *plicity, + pattern: name_to_pattern(param_name), + r#type: Some(param_type), + }); + body_type = next_body_type; + } + // Use arrow sugar if the parameter is not referenced in the body type. + core::Term::FunType(_, plicity, _, param_type, body_type) => { + let param_type = this.check_prec(Prec::App, param_type); + let body_type = this + .with_local(None, |this| this.check_prec(Prec::Fun, body_type)); + break builder.arrow((), *plicity, param_type, body_type); + } + body_type => break this.check_prec(Prec::Fun, body_type), } - body_type => break self.check_prec(Prec::Fun, body_type), } - }; - - self.truncate_local(initial_local_len); + }); self.paren( prec > Prec::Fun, if params.is_empty() { body_type } else { - Term::FunType( - (), - self.scope.to_scope_from_iter(params), - self.scope.to_scope(body_type), - ) + builder.fun_type((), self.scope.to_scope_from_iter(params), body_type) }, ) } (core::Term::FunLit(..), _) => { - let initial_local_len = self.local_len(); let mut params = Vec::new(); let mut body_expr = term; - while let core::Term::FunLit(_, plicity, param_name, next_body_expr) = body_expr { - let param_name = self.freshen_name(*param_name, next_body_expr); - params.push((*plicity, self.push_local(param_name))); - body_expr = next_body_expr; - } - - let body_expr = self.term_prec(mode, Prec::Let, body_expr); - self.truncate_local(initial_local_len); + let body_expr = self.with_scope(|this| { + while let core::Term::FunLit(_, plicity, param_name, next_body_expr) = body_expr + { + let param_name = this.freshen_name(*param_name, next_body_expr); + params.push((*plicity, this.push_local(param_name))); + body_expr = next_body_expr; + } + this.term_prec(mode, Prec::Let, body_expr) + }); let params = params.into_iter().map(|(plicity, name)| Param { plicity, @@ -477,11 +475,7 @@ impl<'arena, 'env> Context<'arena, 'env> { self.paren( prec > Prec::Fun, - Term::FunLiteral( - (), - self.scope.to_scope_from_iter(params), - self.scope.to_scope(body_expr), - ), + builder.fun_lit((), self.scope.to_scope_from_iter(params), body_expr), ) } (core::Term::FunApp(..), _) => { @@ -490,9 +484,9 @@ impl<'arena, 'env> Context<'arena, 'env> { // ((op lhs) rhs) if let core::Term::FunApp(.., core::Term::FunApp(.., core::Term::Prim(_, prim), lhs), rhs,) = term { if let Some(op) = prim_to_bin_op(prim) { - let lhs = self.scope.to_scope(self.synth_prec(op.lhs_prec(), lhs)); - let rhs = self.scope.to_scope(self.synth_prec(op.rhs_prec(), rhs)); - return self.paren(prec > op.precedence(), Term::BinOp((), lhs, op, rhs)); + let lhs = self.synth_prec(op.lhs_prec(), lhs); + let rhs = self.synth_prec(op.rhs_prec(), rhs); + return self.paren(prec > op.precedence(), builder.binop((), lhs, op, rhs)); } }; @@ -519,24 +513,25 @@ impl<'arena, 'env> Context<'arena, 'env> { (core::Term::RecordType(_, labels, types), _) if is_tuple_type(labels, types) => { let tuple = self.check_dependent_tuple(labels, types); match mode { - Mode::Synth => Term::Ann((), self.scope.to_scope(tuple), &Term::Universe(())), + Mode::Synth => builder.ann((), tuple, Term::Universe(())), Mode::Check => tuple, } } (core::Term::RecordType(_, labels, types), _) => { self.local_names.reserve(labels.len()); - let initial_local_len = self.local_len(); - let type_fields = (self.scope).to_scope_from_iter( - Iterator::zip(labels.iter(), types.iter()).map(|(label, r#type)| { - let r#type = self.check_prec(Prec::Top, r#type); - self.push_local(Some(*label)); - TypeField { - label: ((), *label), - r#type, - } - }), - ); - self.truncate_local(initial_local_len); + + let type_fields = self.with_scope(|this| { + (this.scope).to_scope_from_iter(Iterator::zip(labels.iter(), types.iter()).map( + |(label, r#type)| { + let r#type = this.check_prec(Prec::Top, r#type); + this.push_local(Some(*label)); + TypeField { + label: ((), *label), + r#type, + } + }, + )) + }); Term::RecordType((), type_fields) } @@ -575,9 +570,9 @@ impl<'arena, 'env> Context<'arena, 'env> { } let head_expr = self.synth_prec(Prec::Atomic, head_expr); - Term::Proj( + builder.record_proj( (), - self.scope.to_scope(head_expr), + head_expr, self.scope.to_scope_from_iter(labels.into_iter().rev()), ) } @@ -598,15 +593,8 @@ impl<'arena, 'env> Context<'arena, 'env> { } (core::Term::FormatCond(_, label, format, cond), _) => { let format = self.check_prec(Prec::Top, format); - self.push_local(Some(*label)); - let cond = self.check_prec(Prec::Top, cond); - self.pop_local(); - Term::FormatCond( - (), - ((), *label), - self.scope.to_scope(format), - self.scope.to_scope(cond), - ) + let cond = self.with_local(Some(*label), |this| this.check_prec(Prec::Top, cond)); + builder.format_cond((), ((), *label), format, cond) } (core::Term::FormatOverlap(_, labels, formats), _) => { Term::FormatOverlap((), self.synth_format_fields(labels, formats)) @@ -659,12 +647,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let else_expr = self.term_prec(mode, Prec::Let, else_expr); return self.paren( prec > Prec::Let, - Term::If( - (), - self.scope.to_scope(cond_expr), - self.scope.to_scope(then_expr), - self.scope.to_scope(else_expr), - ), + builder.if_then_else((), cond_expr, then_expr, else_expr), ); } @@ -683,13 +666,13 @@ impl<'arena, 'env> Context<'arena, 'env> { if let Some((name, expr)) = default_expr { let name = self.freshen_name(*name, expr); - let name = self.push_local(name); - let expr = self.term_prec(mode, Prec::Top, expr); - branches.push((name_to_pattern(name), expr)); - self.pop_local(); + self.with_local(name, |this| { + let expr = this.term_prec(mode, Prec::Top, expr); + branches.push((name_to_pattern(name), expr)); + }); } - Term::Match((), self.scope.to_scope(head_expr), branches.into()) + builder.r#match((), head_expr, branches.into()) } } } @@ -740,7 +723,7 @@ fn is_tuple_type(labels: &[Symbol], types: &[core::Term<'_>]) -> bool { // For each type in the telescope, ensure that the subsequent types in // the telescope do not depend on the current field. && (1..=types.len()).all(|index| { - Iterator::zip(types[index..].iter(), env::indices()) + Iterator::zip(types[index..].iter(), env::Index::iter()) .all(|(expr, var)| !expr.binds_local(var)) }) } diff --git a/fathom/src/surface/elaboration.rs b/fathom/src/surface/elaboration.rs index cf4a19466..29507a120 100644 --- a/fathom/src/surface/elaboration.rs +++ b/fathom/src/surface/elaboration.rs @@ -33,7 +33,7 @@ use crate::files::FileId; use crate::source::{BytePos, ByteRange, FileRange, Span, Spanned}; use crate::surface::elaboration::reporting::Message; use crate::surface::{ - distillation, pretty, BinOp, ExprField, FormatField, Item, Module, Param, Pattern, Term, + distillation, pretty, BinOp, FormatField, Item, LetDef, Module, Param, Pattern, Term, }; use crate::symbol::Symbol; @@ -329,6 +329,40 @@ impl<'arena> Context<'arena> { Some((local_var, local_type)) } + /// Run `f`, potentially modifying the local environment, then restore the + /// local environment to its previous state. + fn with_scope(&mut self, mut f: impl FnMut(&mut Self) -> T) -> T { + let initial_len = self.local_env.len(); + let result = f(self); + self.local_env.truncate(initial_len); + result + } + + fn with_def( + &mut self, + name: impl Into>, + expr: ArcValue<'arena>, + r#type: ArcValue<'arena>, + mut f: impl FnMut(&mut Self) -> T, + ) -> T { + self.local_env.push_def(name.into(), expr, r#type); + let result = f(self); + self.local_env.pop(); + result + } + + fn with_param( + &mut self, + name: impl Into>, + r#type: ArcValue<'arena>, + mut f: impl FnMut(&mut Self) -> T, + ) -> T { + self.local_env.push_param(name.into(), r#type); + let result = f(self); + self.local_env.pop(); + result + } + /// Push an unsolved term onto the context, to be updated later during /// unification. fn push_unsolved_term( @@ -383,6 +417,10 @@ impl<'arena> Context<'arena> { } } + pub fn builder(&self) -> core::Builder<'arena> { + core::Builder::new(self.scope) + } + pub fn eval_env(&mut self) -> semantics::EvalEnv<'arena, '_> { semantics::ElimEnv::new(&self.item_env.exprs, &self.meta_env.exprs) .eval_env(&mut self.local_env.exprs) @@ -588,12 +626,11 @@ impl<'arena> Context<'arena> { (Value::Stuck(Head::Prim(Prim::FormatType), elims), Value::Universe) if elims.is_empty() => { - core::Term::FunApp( + self.builder().fun_app( span, Plicity::Explicit, - self.scope - .to_scope(core::Term::Prim(span, core::Prim::FormatRepr)), - self.scope.to_scope(expr), + core::Term::Prim(span, core::Prim::FormatRepr), + expr, ) } @@ -616,7 +653,7 @@ impl<'arena> Context<'arena> { expected: self.pretty_value(&to), error, }); - core::Term::Prim(span, Prim::ReportedError) + core::Term::error(span) } }, } @@ -636,8 +673,12 @@ impl<'arena> Context<'arena> { for item in elab_order.iter().copied().map(|i| &surface_module.items[i]) { match item { Item::Def(item) => { - let (expr, r#type) = - self.synth_fun_lit(item.range, item.params, item.expr, item.r#type); + let (expr, r#type) = self.synth_fun_lit( + item.range, + item.params, + &item.expr, + item.r#type.as_ref(), + ); let expr_value = self.eval_env().eval(&expr); let type_value = self.eval_env().eval(&r#type); @@ -646,8 +687,8 @@ impl<'arena> Context<'arena> { items.push(core::Item::Def { label: item.label.1, - r#type: self.scope.to_scope(r#type), - expr: self.scope.to_scope(expr), + r#type, + expr, }); } Item::ReportedError(_) => {} @@ -662,13 +703,13 @@ impl<'arena> Context<'arena> { expr, } => { // TODO: Unfold unsolved metas to reported errors - let r#type = self.eval_env().unfold_metas(scope, r#type); - let expr = self.eval_env().unfold_metas(scope, expr); + let r#type = self.eval_env().unfold_metas(scope, &r#type); + let expr = self.eval_env().unfold_metas(scope, &expr); core::Item::Def { label, - r#type: scope.to_scope(r#type), - expr: scope.to_scope(expr), + r#type, + expr, } } })); @@ -889,74 +930,50 @@ impl<'arena> Context<'arena> { } } - /// Push a local definition onto the context. - /// The supplied `pattern` is expected to be irrefutable. - fn push_local_def( - &mut self, - pattern: CheckedPattern, - expr: ArcValue<'arena>, - r#type: ArcValue<'arena>, - ) -> Option { - let name = match pattern { - CheckedPattern::Binder(_, name) => Some(name), - CheckedPattern::Placeholder(_) => None, - // FIXME: generate failing parameter expressions? - CheckedPattern::ConstLit(range, _) => { - self.push_message(Message::RefutablePattern { - pattern_range: range, - }); - None - } - CheckedPattern::ReportedError(_) => None, - }; - - self.local_env.push_def(name, expr, r#type); - - name - } - - /// Push a local parameter onto the context. - /// The supplied `pattern` is expected to be irrefutable. - fn push_local_param( - &mut self, - pattern: CheckedPattern, - r#type: ArcValue<'arena>, - ) -> (Option, ArcValue<'arena>) { - let name = match pattern { - CheckedPattern::Binder(_, name) => Some(name), - CheckedPattern::Placeholder(_) => None, - // FIXME: generate failing parameter expressions? - CheckedPattern::ConstLit(range, _) => { - self.push_message(Message::RefutablePattern { - pattern_range: range, - }); - None - } - CheckedPattern::ReportedError(_) => None, - }; - - let expr = self.local_env.push_param(name, r#type); - - (name, expr) + /// Report an error if `pattern` is refutable + fn check_pattern_refutability(&mut self, pattern: &CheckedPattern) { + if let CheckedPattern::ConstLit(range, _) = pattern { + self.push_message(Message::RefutablePattern { + pattern_range: *range, + }); + } } /// Elaborate a list of parameters, pushing them onto the context. fn synth_and_push_params( &mut self, + mut range: FileRange, params: &[Param], - ) -> Vec<(ByteRange, Plicity, Option, core::Term<'arena>)> { + ) -> Vec<(Span, Plicity, Option, core::Term<'arena>)> { self.local_env.reserve(params.len()); Vec::from_iter(params.iter().map(|param| { - let range = param.pattern.range(); + let old_range = range; + range = self.file_range(ByteRange::merge(param.pattern.range(), range.byte_range())); + let (pattern, r#type, type_value) = self.synth_ann_pattern(¶m.pattern, param.r#type.as_ref()); - let (name, _) = self.push_local_param(pattern, type_value); + self.check_pattern_refutability(&pattern); - (range, param.plicity, name, r#type) + let name = pattern.name(); + self.local_env.push_param(name, type_value); + (old_range.into(), param.plicity, name, r#type) })) } + fn synth_let_def( + &mut self, + def: &LetDef<'_, ByteRange>, + ) -> (core::LetDef<'arena>, ArcValue<'arena>) { + let (pattern, r#type, type_value) = + self.synth_ann_pattern(&def.pattern, def.r#type.as_ref()); + let name = pattern.name(); + self.check_pattern_refutability(&pattern); + + let expr = self.check(&def.expr, &type_value); + (core::LetDef { name, r#type, expr }, type_value) + } + /// Check that a surface term conforms to the given type. /// /// Returns the elaborated term in the core language. @@ -970,39 +987,23 @@ impl<'arena> Context<'arena> { match (surface_term, expected_type.as_ref()) { (Term::Paren(_, term), _) => self.check(term, &expected_type), - (Term::Let(_, def_pattern, def_type, def_expr, body_expr), _) => { - let (def_pattern, def_type, def_type_value) = - self.synth_ann_pattern(def_pattern, *def_type); - let def_expr = self.check(def_expr, &def_type_value); - let def_expr_value = self.eval_env().eval(&def_expr); - - let def_name = self.push_local_def(def_pattern, def_expr_value, def_type_value); // TODO: split on constants - let body_expr = self.check(body_expr, &expected_type); - self.local_env.pop(); - - core::Term::Let( - file_range.into(), - def_name, - self.scope.to_scope(def_type), - self.scope.to_scope(def_expr), - self.scope.to_scope(body_expr), - ) + (Term::Let(_, def, body_expr), _) => { + let (def, type_value) = self.synth_let_def(def); + let expr_value = self.eval_env().eval(&def.expr); + + let body_expr = self.with_def(def.name, expr_value, type_value, |this| { + this.check(body_expr, &expected_type) + }); + + self.builder().r#let(file_range, def, body_expr) } (Term::If(_, cond_expr, then_expr, else_expr), _) => { let cond_expr = self.check(cond_expr, &self.bool_type.clone()); let then_expr = self.check(then_expr, &expected_type); let else_expr = self.check(else_expr, &expected_type); - core::Term::ConstMatch( - file_range.into(), - self.scope.to_scope(cond_expr), - // NOTE: in lexicographic order: in Rust, `false < true` - self.scope.to_scope_from_iter([ - (Const::Bool(false), else_expr), - (Const::Bool(true), then_expr), - ]), - None, - ) + self.builder() + .if_then_else(file_range, cond_expr, then_expr, else_expr) } (Term::Match(range, scrutinee_expr, equations), _) => { self.check_match(*range, scrutinee_expr, equations, &expected_type) @@ -1017,20 +1018,13 @@ impl<'arena> Context<'arena> { let (synth_term, synth_type) = self.synth_and_insert_implicit_apps(surface_term); self.coerce(surface_range, synth_term, &synth_type, &expected_type) } - (Term::RecordLiteral(_, expr_fields), Value::RecordType(labels, types)) => { + (Term::RecordLiteral(range, expr_fields), Value::RecordType(labels, types)) => { // TODO: improve handling of duplicate labels - if expr_fields.len() != labels.len() - || Iterator::zip(expr_fields.iter(), labels.iter()) - .any(|(expr_field, type_label)| expr_field.label.1 != *type_label) + if self + .check_record_fields(*range, expr_fields, |field| field.label, labels) + .is_err() { - self.push_message(Message::MismatchedFieldLabels { - range: file_range, - expr_labels: (expr_fields.iter()) - .map(|ExprField { label, .. }| (self.file_range(label.0), label.1)) - .collect(), - type_labels: labels.to_vec(), - }); - return core::Term::Prim(file_range.into(), Prim::ReportedError); + return core::Term::error(file_range); } let mut types = types.clone(); @@ -1054,20 +1048,17 @@ impl<'arena> Context<'arena> { let labels = Symbol::get_tuple_labels(0..elem_exprs.len()); let labels = self.scope.to_scope_from_iter(labels.iter().copied()); - let initial_local_len = self.local_env.len(); - let universe = &self.universe.clone(); - let types = self.scope.to_scope_from_iter( - Iterator::zip(labels.iter(), elem_exprs.iter()).map(|(label, elem_expr)| { - let r#type = self.check(elem_expr, universe); - let type_value = self.eval_env().eval(&r#type); - self.local_env.push_param(Some(*label), type_value); - r#type - }), - ); - - self.local_env.truncate(initial_local_len); - - core::Term::RecordType(file_range.into(), labels, types) + self.with_scope(|this| { + let universe = &this.universe.clone(); + let types = + (this.scope).to_scope_from_iter(elem_exprs.iter().map(|elem_expr| { + let r#type = this.check(elem_expr, universe); + let type_value = this.eval_env().eval(&r#type); + this.local_env.push_param(None, type_value); + r#type + })); + core::Term::RecordType(file_range.into(), labels, types) + }) } (Term::Tuple(_, elem_exprs), Value::Stuck(Head::Prim(Prim::FormatType), args)) if args.is_empty() => @@ -1076,50 +1067,25 @@ impl<'arena> Context<'arena> { let labels = Symbol::get_tuple_labels(0..elem_exprs.len()); let labels = self.scope.to_scope_from_iter(labels.iter().copied()); - let initial_local_len = self.local_env.len(); - let format_type = self.format_type.clone(); - let formats = self.scope.to_scope_from_iter( - Iterator::zip(labels.iter(), elem_exprs.iter()).map(|(label, elem_expr)| { - let format = self.check(elem_expr, &format_type); - let format_value = self.eval_env().eval(&format); - let r#type = self.elim_env().format_repr(&format_value); - self.local_env.push_param(Some(*label), r#type); - format - }), - ); - - self.local_env.truncate(initial_local_len); - - core::Term::FormatRecord(file_range.into(), labels, formats) - } - (Term::Tuple(_, elem_exprs), Value::RecordType(labels, types)) => { - if elem_exprs.len() != labels.len() { - let mut expr_labels = Vec::with_capacity(elem_exprs.len()); - let mut elem_exprs = elem_exprs.iter().enumerate().peekable(); - let mut label_iter = labels.iter(); - - // use the label names from the expected type - while let Some(((_, elem_expr), label)) = - Option::zip(elem_exprs.peek(), label_iter.next()) - { - expr_labels.push((self.file_range(elem_expr.range()), *label)); - elem_exprs.next(); - } - - // use numeric labels for excess elems - for (index, elem_expr) in elem_exprs { - expr_labels.push(( - self.file_range(elem_expr.range()), - Symbol::get_tuple_label(index), - )); - } - - self.push_message(Message::MismatchedFieldLabels { - range: file_range, - expr_labels, - type_labels: labels.to_vec(), - }); - return core::Term::Prim(file_range.into(), Prim::ReportedError); + self.with_scope(|this| { + let format_type = this.format_type.clone(); + let formats = + (this.scope).to_scope_from_iter(elem_exprs.iter().map(|elem_expr| { + let format = this.check(elem_expr, &format_type); + let format_value = this.eval_env().eval(&format); + let r#type = this.elim_env().format_repr(&format_value); + this.local_env.push_param(None, r#type); + format + })); + core::Term::FormatRecord(file_range.into(), labels, formats) + }) + } + (Term::Tuple(range, elem_exprs), Value::RecordType(labels, types)) => { + if self + .check_tuple_fields(*range, elem_exprs, Term::range, labels) + .is_err() + { + return core::Term::error(file_range); } let mut types = types.clone(); @@ -1141,27 +1107,20 @@ impl<'arena> Context<'arena> { let (len_value, elem_type) = match expected_type.match_prim_spine() { Some((Prim::ArrayType, [App(_, elem_type)])) => (None, elem_type), - Some((Prim::Array8Type, [App(_, len), App(_, elem_type)])) => { - (Some(len), elem_type) - } - Some((Prim::Array16Type, [App(_, len), App(_, elem_type)])) => { - (Some(len), elem_type) - } - Some((Prim::Array32Type, [App(_, len), App(_, elem_type)])) => { - (Some(len), elem_type) - } - Some((Prim::Array64Type, [App(_, len), App(_, elem_type)])) => { - (Some(len), elem_type) - } - Some((Prim::ReportedError, _)) => { - return core::Term::Prim(file_range.into(), Prim::ReportedError) - } + Some(( + Prim::Array8Type + | Prim::Array16Type + | Prim::Array32Type + | Prim::Array64Type, + [App(_, len), App(_, elem_type)], + )) => (Some(len), elem_type), + Some((Prim::ReportedError, _)) => return core::Term::error(file_range), _ => { self.push_message(Message::ArrayLiteralNotSupported { range: file_range, expected_type: self.pretty_value(&expected_type), }); - return core::Term::Prim(file_range.into(), Prim::ReportedError); + return core::Term::error(file_range); } }; @@ -1172,7 +1131,7 @@ impl<'arena> Context<'arena> { Some(Value::ConstLit(Const::U32(len, _))) => Some(*len as u64), Some(Value::ConstLit(Const::U64(len, _))) => Some(*len), Some(Value::Stuck(Head::Prim(Prim::ReportedError), _)) => { - return core::Term::Prim(file_range.into(), Prim::ReportedError); + return core::Term::error(file_range) } _ => None, }; @@ -1197,7 +1156,7 @@ impl<'arena> Context<'arena> { expected_len: self.pretty_value(len_value.unwrap()), }); - core::Term::Prim(file_range.into(), Prim::ReportedError) + return core::Term::error(file_range); } } } @@ -1223,7 +1182,7 @@ impl<'arena> Context<'arena> { match constant { Some(constant) => core::Term::ConstLit(file_range.into(), constant), - None => core::Term::Prim(file_range.into(), Prim::ReportedError), + None => core::Term::error(file_range), } } (Term::NumberLiteral(range, lit), _) => { @@ -1244,19 +1203,19 @@ impl<'arena> Context<'arena> { range: file_range, expected_type: self.pretty_value(&expected_type), }); - return core::Term::Prim(file_range.into(), Prim::ReportedError); + return core::Term::error(file_range); } }; match constant { Some(constant) => core::Term::ConstLit(file_range.into(), constant), - None => core::Term::Prim(file_range.into(), Prim::ReportedError), + None => core::Term::error(file_range), } } (Term::BinOp(range, lhs, op, rhs), _) => { self.check_bin_op(*range, lhs, *op, rhs, &expected_type) } - (Term::ReportedError(_), _) => core::Term::Prim(file_range.into(), Prim::ReportedError), + (Term::ReportedError(_), _) => core::Term::error(file_range), (_, _) => { let surface_range = surface_term.range(); let (synth_term, synth_type) = self.synth(surface_term); @@ -1281,12 +1240,9 @@ impl<'arena> Context<'arena> { let arg_term = self.push_unsolved_term(source, param_type.clone()); let arg_value = self.eval_env().eval(&arg_term); - term = core::Term::FunApp( - file_range.into(), - Plicity::Implicit, - self.scope.to_scope(term), - self.scope.to_scope(arg_term), - ); + term = self + .builder() + .fun_app(file_range, Plicity::Implicit, term, arg_term); r#type = self.elim_env().apply_closure(body_type, arg_value); } (term, r#type) @@ -1364,32 +1320,18 @@ impl<'arena> Context<'arena> { let type_value = self.eval_env().eval(&r#type); let expr = self.check(expr, &type_value); - let ann_expr = core::Term::Ann( - file_range.into(), - self.scope.to_scope(expr), - self.scope.to_scope(r#type), - ); - + let ann_expr = self.builder().ann(file_range, expr, r#type); (ann_expr, type_value) } - Term::Let(_, def_pattern, def_type, def_expr, body_expr) => { - let (def_pattern, def_type, def_type_value) = - self.synth_ann_pattern(def_pattern, *def_type); - let def_expr = self.check(def_expr, &def_type_value); - let def_expr_value = self.eval_env().eval(&def_expr); - - let def_name = self.push_local_def(def_pattern, def_expr_value, def_type_value); - let (body_expr, body_type) = self.synth(body_expr); - self.local_env.pop(); - - let let_expr = core::Term::Let( - file_range.into(), - def_name, - self.scope.to_scope(def_type), - self.scope.to_scope(def_expr), - self.scope.to_scope(body_expr), - ); + Term::Let(_, def, body_expr) => { + let (def, type_value) = self.synth_let_def(def); + let expr_value = self.eval_env().eval(&def.expr); + + let (body, body_type) = self.with_def(def.name, expr_value, r#type_value, |this| { + this.synth(body_expr) + }); + let let_expr = self.builder().r#let(file_range, def, body); (let_expr, body_type) } Term::If(_, cond_expr, then_expr, else_expr) => { @@ -1397,16 +1339,9 @@ impl<'arena> Context<'arena> { let (then_expr, r#type) = self.synth(then_expr); let else_expr = self.check(else_expr, &r#type); - let match_expr = core::Term::ConstMatch( - file_range.into(), - self.scope.to_scope(cond_expr), - // NOTE: in lexicographic order: in Rust, `false < true` - self.scope.to_scope_from_iter([ - (Const::Bool(false), else_expr), - (Const::Bool(true), then_expr), - ]), - None, - ); + let match_expr = self + .builder() + .if_then_else(file_range, cond_expr, then_expr, else_expr); (match_expr, r#type) } @@ -1427,46 +1362,29 @@ impl<'arena> Context<'arena> { let param_type = self.check(param_type, &universe); let param_type_value = self.eval_env().eval(¶m_type); - self.local_env.push_param(None, param_type_value); - let body_type = self.check(body_type, &universe); - self.local_env.pop(); + let body_type = self.with_param(None, param_type_value, |this| { + this.check(body_type, &universe) + }); - let fun_type = core::Term::FunType( - file_range.into(), - *plicity, - None, - self.scope.to_scope(param_type), - self.scope.to_scope(body_type), - ); + let fun_type = self + .builder() + .arrow(file_range, *plicity, param_type, body_type); - (fun_type, self.universe.clone()) + (fun_type, universe) } - Term::FunType(range, params, body_type) => { - let initial_local_len = self.local_env.len(); - - let params = self.synth_and_push_params(params); - let mut fun_type = self.check(body_type, &self.universe.clone()); - self.local_env.truncate(initial_local_len); + Term::FunType(_, params, body_type) => { + let universe = self.universe.clone(); - // Construct the function type from the parameters in reverse - for (i, (param_range, plicity, name, r#type)) in - params.into_iter().enumerate().rev() - { - let range = match i { - 0 => *range, // Use the range of the full function type - _ => ByteRange::merge(param_range, body_type.range()), - }; + let (params, fun_type) = self.with_scope(|this| { + let params = this.synth_and_push_params(file_range, params); + let fun_type = this.check(body_type, &universe); + (params, fun_type) + }); - fun_type = core::Term::FunType( - self.file_range(range).into(), - plicity, - name, - self.scope.to_scope(r#type), - self.scope.to_scope(fun_type), - ); - } + // Construct the function type from the parameters + let fun_type = self.builder().fun_types(params, fun_type); - (fun_type, self.universe.clone()) + (fun_type, universe) } Term::FunLiteral(range, params, body_expr) => { let (expr, r#type) = self.synth_fun_lit(*range, params, body_expr, None); @@ -1527,36 +1445,34 @@ impl<'arena> Context<'arena> { let arg_expr = self.check(&arg.term, param_type); let arg_expr_value = self.eval_env().eval(&arg_expr); - head_expr = core::Term::FunApp( - self.file_range(head_range).into(), + head_expr = self.builder().fun_app( + self.file_range(head_range), arg.plicity, - self.scope.to_scope(head_expr), - self.scope.to_scope(arg_expr), + head_expr, + arg_expr, ); head_type = self.elim_env().apply_closure(body_type, arg_expr_value); } (head_expr, head_type) } - Term::RecordType(range, type_fields) => { - let universe = self.universe.clone(); - let initial_local_len = self.local_env.len(); + Term::RecordType(range, type_fields) => self.with_scope(|this| { let (labels, type_fields) = - self.report_duplicate_labels(*range, type_fields, |f| f.label); - let mut types = SliceVec::new(self.scope, labels.len()); + this.report_duplicate_labels(*range, type_fields, |f| f.label); + + let universe = this.universe.clone(); + let mut types = SliceVec::new(this.scope, labels.len()); for type_field in type_fields { - let r#type = self.check(&type_field.r#type, &universe); - let type_value = self.eval_env().eval(&r#type); - self.local_env + let r#type = this.check(&type_field.r#type, &universe); + let type_value = this.eval_env().eval(&r#type); + this.local_env .push_param(Some(type_field.label.1), type_value); types.push(r#type); } - self.local_env.truncate(initial_local_len); let record_type = core::Term::RecordType(file_range.into(), labels, types.into()); - (record_type, universe) - } + }), Term::RecordLiteral(range, expr_fields) => { let (labels, expr_fields) = self.report_duplicate_labels(*range, expr_fields, |f| f.label); @@ -1582,8 +1498,8 @@ impl<'arena> Context<'arena> { let labels = Symbol::get_tuple_labels(0..elem_exprs.len()); let labels = self.scope.to_scope_from_iter(labels.iter().copied()); - let mut exprs = SliceVec::new(self.scope, labels.len()); - let mut types = SliceVec::new(self.scope, labels.len()); + let mut exprs = SliceVec::new(self.scope, elem_exprs.len()); + let mut types = SliceVec::new(self.scope, elem_exprs.len()); for elem_exprs in elem_exprs.iter() { let (expr, r#type) = self.synth(elem_exprs); @@ -1591,9 +1507,10 @@ impl<'arena> Context<'arena> { exprs.push(expr); } - let types = Telescope::new(self.local_env.exprs.clone(), types.into()); let term = core::Term::RecordLit(file_range.into(), labels, exprs.into()); - let r#type = Spanned::empty(Arc::new(Value::RecordType(labels, types))); + let r#type = core::Term::RecordType(Span::Empty, labels, types.into()); + let r#type = self.eval_env().eval(&r#type); + (term, r#type) } Term::Proj(range, head_expr, labels) => { @@ -1618,10 +1535,9 @@ impl<'arena> Context<'arena> { if *proj_label == label { // The field was found. Update the head expression // and continue elaborating the next projection. - head_expr = core::Term::RecordProj( - self.file_range(ByteRange::merge(head_range, *label_range)) - .into(), - self.scope.to_scope(head_expr), + head_expr = self.builder().record_proj( + self.file_range(ByteRange::merge(head_range, *label_range)), + head_expr, *proj_label, ); head_type = r#type; @@ -1641,9 +1557,8 @@ impl<'arena> Context<'arena> { // There's been an error when elaborating the head of // the projection, so avoid trying to elaborate any // further to prevent cascading type errors. - (core::Term::Prim(_, Prim::ReportedError), _) - | (_, Value::Stuck(Head::Prim(Prim::ReportedError), _)) => { - return self.synth_reported_error(*range); + (expr, r#type) if expr.is_error() || r#type.is_error() => { + return self.synth_reported_error(*range) } // The head expression was not a record type. // Fallthrough with an error. @@ -1687,21 +1602,17 @@ impl<'arena> Context<'arena> { } Term::FormatCond(_, (_, name), format, pred) => { let format_type = self.format_type.clone(); + let bool_type = self.bool_type.clone(); let format = self.check(format, &format_type); let format_value = self.eval_env().eval(&format); let repr_type = self.elim_env().format_repr(&format_value); - self.local_env.push_param(Some(*name), repr_type); - let bool_type = self.bool_type.clone(); - let pred_expr = self.check(pred, &bool_type); - self.local_env.pop(); + let pred_expr = + self.with_param(*name, repr_type, |this| this.check(pred, &bool_type)); - let cond_format = core::Term::FormatCond( - file_range.into(), - *name, - self.scope.to_scope(format), - self.scope.to_scope(pred_expr), - ); + let cond_format = self + .builder() + .format_cond(file_range, *name, format, pred_expr); (cond_format, format_type) } @@ -1737,18 +1648,20 @@ impl<'arena> Context<'arena> { param.r#type.as_ref(), param_type, ); - let (name, arg_expr) = self.push_local_param(pattern, param_type.clone()); + self.check_pattern_refutability(&pattern); + let name = pattern.name(); + let arg_expr = self.local_env.push_param(name, param_type.clone()); let body_type = self.elim_env().apply_closure(next_body_type, arg_expr); let body_expr = self.check_fun_lit(range, next_params, body_expr, &body_type); self.local_env.pop(); - core::Term::FunLit( - self.file_range(range).into(), + self.builder().fun_lit( + self.file_range(range), param.plicity, name, - self.scope.to_scope(body_expr), + body_expr, ) } // If an implicit function is expected, try to generalize the @@ -1760,11 +1673,11 @@ impl<'arena> Context<'arena> { let body_type = self.elim_env().apply_closure(next_body_type, arg_expr); let body_expr = self.check_fun_lit(range, params, body_expr, &body_type); self.local_env.pop(); - core::Term::FunLit( - file_range.into(), + self.builder().fun_lit( + file_range, Plicity::Implicit, *param_name, - self.scope.to_scope(body_expr), + body_expr, ) } // Attempt to elaborate the the body of the function in synthesis @@ -1776,7 +1689,7 @@ impl<'arena> Context<'arena> { self.coerce(range, expr, &type_value, expected_type) } Value::Stuck(Head::Prim(Prim::ReportedError), _) => { - core::Term::Prim(file_range.into(), Prim::ReportedError) + core::Term::error(file_range) } _ => { self.push_message(Message::UnexpectedParameter { @@ -1785,7 +1698,7 @@ impl<'arena> Context<'arena> { // TODO: For improved error recovery, bind the rest of // the parameters, and check the body of the function // literal using the expected body type. - core::Term::Prim(file_range.into(), Prim::ReportedError) + core::Term::error(file_range) } } } @@ -1800,45 +1713,32 @@ impl<'arena> Context<'arena> { body_expr: &Term<'_, ByteRange>, body_type: Option<&Term<'_, ByteRange>>, ) -> (core::Term<'arena>, core::Term<'arena>) { + let file_range = self.file_range(range); self.local_env.reserve(params.len()); - let initial_local_len = self.local_env.len(); - - let params = self.synth_and_push_params(params); - - let (mut fun_lit, mut fun_type) = match body_type { - Some(body_type) => { - let body_type = self.check(body_type, &self.universe.clone()); - let body_type_value = self.eval_env().eval(&body_type); - (self.check(body_expr, &body_type_value), body_type) - } - None => { - let (body_expr, body_type) = self.synth(body_expr); - (body_expr, self.quote_env().quote(self.scope, &body_type)) - } - }; - self.local_env.truncate(initial_local_len); + let (params, mut fun_lit, mut fun_type) = self.with_scope(|this| { + let params = this.synth_and_push_params(file_range, params); - // Construct the function literal and type from the parameters in reverse - for (i, (param_range, plicity, name, r#type)) in params.into_iter().enumerate().rev() { - let range = match i { - 0 => range, // Use the range of the full function literal - _ => ByteRange::merge(param_range, body_expr.range()), + let (fun_lit, fun_type) = match body_type { + Some(body_type) => { + let body_type = this.check(body_type, &this.universe.clone()); + let body_type_value = this.eval_env().eval(&body_type); + (this.check(body_expr, &body_type_value), body_type) + } + None => { + let (body_expr, body_type) = this.synth(body_expr); + (body_expr, this.quote_env().quote(this.scope, &body_type)) + } }; + (params, fun_lit, fun_type) + }); - fun_lit = core::Term::FunLit( - self.file_range(range).into(), - plicity, - name, - self.scope.to_scope(fun_lit), - ); - fun_type = core::Term::FunType( - Span::Empty, - plicity, - name, - self.scope.to_scope(r#type), - self.scope.to_scope(fun_type), - ); + // Construct the function literal and type from the parameters in reverse + for (param_range, plicity, name, r#type) in params.into_iter().rev() { + fun_lit = self.builder().fun_lit(param_range, plicity, name, fun_lit); + fun_type = self + .builder() + .fun_type(Span::Empty, plicity, name, r#type, fun_type); } (fun_lit, fun_type) @@ -1983,18 +1883,12 @@ impl<'arena> Context<'arena> { } }; - let fun_head = core::Term::Prim(self.file_range(op.range()).into(), fun); - let fun_app = core::Term::FunApp( - self.file_range(range).into(), - Plicity::Explicit, - self.scope.to_scope(core::Term::FunApp( - Span::merge(&lhs_expr.span(), &rhs_expr.span()), - Plicity::Explicit, - self.scope.to_scope(fun_head), - self.scope.to_scope(lhs_expr), - )), - self.scope.to_scope(rhs_expr), - ); + let term_span = self.file_range(range); + let op_span = self.file_range(op.range()); + + let fun_app = self + .builder() + .binop(term_span, op_span, fun, lhs_expr, rhs_expr); // TODO: Maybe it would be good to reuse lhs_type here if body_type is the same ( @@ -2075,23 +1969,16 @@ impl<'arena> Context<'arena> { let lhs_expr = self.check(lhs, &expected_type); let rhs_expr = self.check(rhs, &expected_type); - let fun_head = core::Term::Prim(self.file_range(op.range()).into(), fun); - core::Term::FunApp( - self.file_range(range).into(), - Plicity::Explicit, - self.scope.to_scope(core::Term::FunApp( - Span::merge(&lhs_expr.span(), &rhs_expr.span()), - Plicity::Explicit, - self.scope.to_scope(fun_head), - self.scope.to_scope(lhs_expr), - )), - self.scope.to_scope(rhs_expr), - ) + let term_span = self.file_range(range); + let op_span = self.file_range(op.range()); + + self.builder() + .binop(term_span, op_span, fun, lhs_expr, rhs_expr) } fn synth_reported_error(&mut self, range: ByteRange) -> (core::Term<'arena>, ArcValue<'arena>) { let file_range = self.file_range(range); - let expr = core::Term::Prim(file_range.into(), Prim::ReportedError); + let expr = core::Term::error(file_range); let r#type = self.push_unsolved_type(MetaSource::ReportedErrorType(file_range)); (expr, r#type) } @@ -2135,12 +2022,10 @@ impl<'arena> Context<'arena> { let cond_expr = self.check(pred, &self.bool_type.clone()); let field_span = Span::merge(&label_range.into(), &cond_expr.span()); - formats.push(core::Term::FormatCond( - field_span, - *label, - self.scope.to_scope(format), - self.scope.to_scope(cond_expr), - )); + let format = self + .builder() + .format_cond(field_span, *label, format, cond_expr); + formats.push(format); } } } @@ -2164,19 +2049,13 @@ impl<'arena> Context<'arena> { }; let field_span = Span::merge(&label_range.into(), &expr.span()); - let format = core::Term::FunApp( - field_span, - Plicity::Explicit, - self.scope.to_scope(core::Term::FunApp( - field_span, - Plicity::Explicit, - self.scope - .to_scope(core::Term::Prim(field_span, Prim::FormatSucceed)), - self.scope.to_scope(r#type), - )), - self.scope.to_scope(expr), + let format = self.builder().fun_apps( + core::Term::Prim(field_span, Prim::FormatSucceed), + [ + (field_span, Plicity::Explicit, r#type), + (field_span, Plicity::Explicit, expr), + ], ); - // Assume that `Repr ${type_value} ${expr} = ${type_value}` self.local_env.push_param(Some(*label), type_value); formats.push(format); @@ -2189,6 +2068,76 @@ impl<'arena> Context<'arena> { (labels, formats.into()) } + fn check_tuple_fields( + &mut self, + range: ByteRange, + fields: &[F], + get_range: fn(&F) -> ByteRange, + expected_labels: &[Symbol], + ) -> Result<(), ()> { + if fields.len() == expected_labels.len() { + return Ok(()); + } + + let mut found_labels = Vec::with_capacity(fields.len()); + let mut fields_iter = fields.iter().enumerate().peekable(); + let mut expected_labels_iter = expected_labels.iter(); + + // use the label names from the expected labels + while let Some(((_, field), label)) = + Option::zip(fields_iter.peek(), expected_labels_iter.next()) + { + found_labels.push((self.file_range(get_range(field)), *label)); + fields_iter.next(); + } + + // use numeric labels for excess fields + for (index, field) in fields_iter { + found_labels.push(( + self.file_range(get_range(field)), + Symbol::get_tuple_label(index), + )); + } + + self.push_message(Message::MismatchedFieldLabels { + range: self.file_range(range), + found_labels, + expected_labels: expected_labels.to_vec(), + }); + Err(()) + } + + fn check_record_fields( + &mut self, + range: ByteRange, + fields: &[F], + get_label: impl Fn(&F) -> (ByteRange, Symbol), + labels: &'arena [Symbol], + ) -> Result<(), ()> { + if fields.len() == labels.len() + && fields + .iter() + .zip(labels.iter()) + .all(|(field, type_label)| get_label(field).1 == *type_label) + { + return Ok(()); + } + + // TODO: improve handling of duplicate labels + self.push_message(Message::MismatchedFieldLabels { + range: self.file_range(range), + found_labels: fields + .iter() + .map(|field| { + let (range, label) = get_label(field); + (self.file_range(range), label) + }) + .collect(), + expected_labels: labels.to_vec(), + }); + Err(()) + } + /// Elaborate a match expression in checking mode fn check_match( &mut self, @@ -2242,18 +2191,20 @@ impl<'arena> Context<'arena> { let def_type_value = match_info.scrutinee.r#type.clone(); let def_type = self.quote_env().quote(self.scope, &def_type_value); - self.local_env.push_def(def_name, def_expr, def_type_value); - let body_expr = self.check(body_expr, &match_info.expected_type); - self.local_env.pop(); + let body_expr = self.with_def(def_name, def_expr, def_type_value, |this| { + this.check(body_expr, &match_info.expected_type) + }); self.elab_match_unreachable(match_info, equations); - core::Term::Let( + self.builder().r#let( Span::merge(&range.into(), &body_expr.span()), - def_name, - self.scope.to_scope(def_type), - match_info.scrutinee.expr, - self.scope.to_scope(body_expr), + core::LetDef { + name: def_name, + r#type: def_type, + expr: match_info.scrutinee.expr.clone(), + }, + body_expr, ) } // Placeholder patterns just elaborate to the body @@ -2281,7 +2232,7 @@ impl<'arena> Context<'arena> { CheckedPattern::ReportedError(range) => { self.check(body_expr, &match_info.expected_type); self.elab_match_unreachable(match_info, equations); - core::Term::Prim(range.into(), Prim::ReportedError) + core::Term::error(range) } } } @@ -2311,24 +2262,18 @@ impl<'arena> Context<'arena> { let mut branches = vec![(r#const, body_expr)]; // Elaborate a run of constant patterns. - 'patterns: while let Some((pattern, body_expr)) = equations.next() { + while let Some((pattern, body_expr)) = equations.next() { // Update the range up to the end of the next body expression full_span = Span::merge(&full_span, &self.file_range(body_expr.range()).into()); - // Default expression, defined if we arrive at a default case - let default_branch; - - match self.check_pattern(pattern, &match_info.scrutinee.r#type) { - // Accumulate constant pattern. Search for it in the accumulated - // branches and insert it in order. + let pattern = self.check_pattern(pattern, &match_info.scrutinee.r#type); + match pattern { CheckedPattern::ConstLit(range, r#const) => { let body_expr = self.check(body_expr, &match_info.expected_type); // Find insertion index of the branch - let insertion_index = branches.binary_search_by(|(probe_const, _)| { - Const::partial_cmp(probe_const, &r#const) - .expect("attempt to compare non-ordered value") - }); + let insertion_index = branches + .binary_search_by(|(probe_const, _)| Const::cmp(probe_const, &r#const)); match insertion_index { Ok(_) => self.push_message(Message::UnreachablePattern { range }), @@ -2339,66 +2284,56 @@ impl<'arena> Context<'arena> { } } - // No default case yet, continue looking for constant patterns. - continue 'patterns; - } - - // Time to elaborate the default pattern. The default case of - // `core::Term::ConstMatch` binds a variable, so both - // the named and placeholder patterns should bind this. - CheckedPattern::Binder(range, name) => { - self.check_match_reachable(is_reachable, range); - - // TODO: If we know this is an exhaustive match, bind the - // scrutinee to a let binding with the elaborated body, and - // add it to the branches. This will simplify the - // distillation of if expressions. - (self.local_env).push_param(Some(name), match_info.scrutinee.r#type.clone()); - let default_expr = self.check(body_expr, &match_info.expected_type); - default_branch = (Some(name), self.scope.to_scope(default_expr) as &_); - self.local_env.pop(); + if let Some(n) = r#const.num_inhabitants() { + if branches.len() as u128 >= n { + // The match is exhaustive. + // No need to elaborate the rest of the patterns + self.elab_match_unreachable(match_info, equations); + + return core::Term::ConstMatch( + full_span, + match_info.scrutinee.expr, + self.scope.to_scope_from_iter(branches.into_iter()), + None, + ); + } + } } - CheckedPattern::Placeholder(range) => { - self.check_match_reachable(is_reachable, range); + CheckedPattern::Binder(_, _) + | CheckedPattern::Placeholder(_) + | CheckedPattern::ReportedError(_) => { + let name = pattern.name(); + let range = pattern.range(); - (self.local_env).push_param(None, match_info.scrutinee.r#type.clone()); - let default_expr = self.check(body_expr, &match_info.expected_type); - default_branch = (None, self.scope.to_scope(default_expr) as &_); - self.local_env.pop(); - } - CheckedPattern::ReportedError(range) => { - (self.local_env).push_param(None, match_info.scrutinee.r#type.clone()); - let default_expr = core::Term::Prim(range.into(), Prim::ReportedError); - default_branch = (None, self.scope.to_scope(default_expr) as &_); - self.local_env.pop(); - } - }; + if !pattern.is_err() { + self.check_match_reachable(is_reachable, range); + self.elab_match_unreachable(match_info, equations); + } - // A default pattern was found, check any unreachable patterns. - self.elab_match_unreachable(match_info, equations); + let default_expr = + self.with_param(name, match_info.scrutinee.r#type.clone(), |this| { + this.check(body_expr, &match_info.expected_type) + }); - return core::Term::ConstMatch( - full_span, - match_info.scrutinee.expr, - self.scope.to_scope_from_iter(branches.into_iter()), - Some(default_branch), - ); + return core::Term::ConstMatch( + full_span, + match_info.scrutinee.expr, + self.scope.to_scope_from_iter(branches.into_iter()), + Some((name, self.scope.to_scope(default_expr))), + ); + } + } } // Finished all the constant patterns without encountering a default - // case. This should have been an exhaustive match, so check to see if - // all the cases were covered. - let default_expr = match match_info.scrutinee.r#type.match_prim_spine() { - // No need for a default case if all the values were covered - Some((Prim::BoolType, [])) if branches.len() >= 2 => None, - _ => Some(self.elab_match_absurd(is_reachable, match_info)), - }; + // case or an exhaustive match + let default_expr = self.elab_match_absurd(is_reachable, match_info); core::Term::ConstMatch( full_span, match_info.scrutinee.expr, self.scope.to_scope_from_iter(branches.into_iter()), - default_expr.map(|expr| (None, self.scope.to_scope(expr) as &_)), + Some((None, self.scope.to_scope(default_expr))), ) } @@ -2426,10 +2361,7 @@ impl<'arena> Context<'arena> { scrutinee_expr_range: self.file_range(match_info.scrutinee.range), }); } - core::Term::Prim( - self.file_range(match_info.range).into(), - Prim::ReportedError, - ) + core::Term::error(self.file_range(match_info.range)) } } @@ -2465,6 +2397,27 @@ enum CheckedPattern { /// Error sentinel ReportedError(FileRange), } +impl CheckedPattern { + fn name(&self) -> Option { + match self { + CheckedPattern::Binder(_, name) => Some(*name), + _ => None, + } + } + + fn range(&self) -> FileRange { + match self { + CheckedPattern::Binder(range, ..) + | CheckedPattern::Placeholder(range, ..) + | CheckedPattern::ConstLit(range, ..) + | CheckedPattern::ReportedError(range, ..) => *range, + } + } + + fn is_err(&self) -> bool { + matches!(self, Self::ReportedError(..)) + } +} /// Scrutinee of a match expression struct Scrutinee<'arena> { diff --git a/fathom/src/surface/elaboration/order.rs b/fathom/src/surface/elaboration/order.rs index c3a47cfc5..d34165353 100644 --- a/fathom/src/surface/elaboration/order.rs +++ b/fathom/src/surface/elaboration/order.rs @@ -150,10 +150,10 @@ fn item_dependencies( Item::Def(item) => { let initial_locals_names_len = local_names.len(); push_param_deps(item.params, item_names, local_names, &mut deps); - if let Some(r#type) = item.r#type { + if let Some(r#type) = item.r#type.as_ref() { term_deps(r#type, item_names, local_names, &mut deps); } - term_deps(item.expr, item_names, local_names, &mut deps); + term_deps(&item.expr, item_names, local_names, &mut deps); local_names.truncate(initial_locals_names_len); } Item::ReportedError(_) => {} @@ -168,7 +168,6 @@ fn term_deps( deps: &mut Vec, ) { match term { - Term::Paren(_, term) => term_deps(term, item_names, local_names, deps), Term::Name(_, name) => { if local_names.iter().rev().any(|local| name == local) { // local binding, do nothing @@ -178,23 +177,15 @@ fn term_deps( deps.push(*name); } } - Term::Ann(_, expr, r#type) => { - term_deps(expr, item_names, local_names, deps); - term_deps(r#type, item_names, local_names, deps); - } - Term::Let(_, pattern, r#type, def_expr, body_expr) => { - push_pattern(pattern, local_names); - if let Some(r#type) = r#type { + Term::Let(_, def, body_expr) => { + let initial_locals_names_len = local_names.len(); + if let Some(r#type) = def.r#type.as_ref() { term_deps(r#type, item_names, local_names, deps); } - term_deps(def_expr, item_names, local_names, deps); + term_deps(&def.expr, item_names, local_names, deps); + push_pattern(&def.pattern, local_names); term_deps(body_expr, item_names, local_names, deps); - pop_pattern(pattern, local_names); - } - Term::If(_, cond_expr, then_expr, else_expr) => { - term_deps(cond_expr, item_names, local_names, deps); - term_deps(then_expr, item_names, local_names, deps); - term_deps(else_expr, item_names, local_names, deps); + local_names.truncate(initial_locals_names_len); } Term::Match(_, scrutinee, equations) => { let initial_locals_names_len = local_names.len(); @@ -205,28 +196,12 @@ fn term_deps( } local_names.truncate(initial_locals_names_len); } - Term::Arrow(.., param_type, body_type) => { - term_deps(param_type, item_names, local_names, deps); - term_deps(body_type, item_names, local_names, deps); - } - Term::FunType(_, patterns, body_type) => { + Term::FunType(_, params, body) | Term::FunLiteral(_, params, body) => { let initial_locals_names_len = local_names.len(); - push_param_deps(patterns, item_names, local_names, deps); - term_deps(body_type, item_names, local_names, deps); + push_param_deps(params, item_names, local_names, deps); + term_deps(body, item_names, local_names, deps); local_names.truncate(initial_locals_names_len); } - Term::FunLiteral(_, patterns, body_type) => { - let initial_locals_names_len = local_names.len(); - push_param_deps(patterns, item_names, local_names, deps); - term_deps(body_type, item_names, local_names, deps); - local_names.truncate(initial_locals_names_len); - } - Term::App(_, head_expr, args) => { - term_deps(head_expr, item_names, local_names, deps); - for arg in *args { - term_deps(&arg.term, item_names, local_names, deps); - } - } Term::RecordType(_, type_fields) => { let initial_locals_names_len = local_names.len(); for type_field in *type_fields { @@ -235,50 +210,43 @@ fn term_deps( } local_names.truncate(initial_locals_names_len); } - Term::RecordLiteral(_, expr_fields) => { + Term::FormatRecord(_, fields) | Term::FormatOverlap(_, fields) => { let initial_locals_names_len = local_names.len(); - for expr_field in *expr_fields { - if let Some(expr) = expr_field.expr.as_ref() { - term_deps(expr, item_names, local_names, deps); + for field in fields.iter() { + match field { + FormatField::Format { + label: (_, name), + format, + pred, + } => { + term_deps(format, item_names, local_names, deps); + if let Some(pred) = pred { + term_deps(pred, item_names, local_names, deps); + } + local_names.push(*name) + } + FormatField::Computed { + label: (_, name), + r#type, + expr, + } => { + if let Some(r#type) = r#type { + term_deps(r#type, item_names, local_names, deps); + } + term_deps(expr, item_names, local_names, deps); + local_names.push(*name) + } } - local_names.push(expr_field.label.1); } local_names.truncate(initial_locals_names_len); } - Term::Tuple(_, terms) => terms - .iter() - .for_each(|term| term_deps(term, item_names, local_names, deps)), - Term::Proj(_, head_expr, _) => { - term_deps(head_expr, item_names, local_names, deps); - } - Term::ArrayLiteral(_, terms) => { - for term in *terms { - term_deps(term, item_names, local_names, deps); - } - } - Term::FormatRecord(_, fields) => { - field_deps(fields, item_names, local_names, deps); - } - Term::FormatOverlap(_, format_fields) => { - field_deps(format_fields, item_names, local_names, deps); - } Term::FormatCond(_, (_, name), format, cond) => { local_names.push(*name); term_deps(format, item_names, local_names, deps); term_deps(cond, item_names, local_names, deps); local_names.pop(); } - Term::BinOp(_, lhs, _, rhs) => { - term_deps(lhs, item_names, local_names, deps); - term_deps(rhs, item_names, local_names, deps); - } - Term::Hole(_, _) - | Term::Placeholder(_) - | Term::Universe(_) - | Term::StringLiteral(_, _) - | Term::NumberLiteral(_, _) - | Term::BooleanLiteral(_, _) - | Term::ReportedError(_) => {} + _ => term.walk(|term| term_deps(term, item_names, local_names, deps)), } } @@ -296,39 +264,6 @@ fn push_param_deps( } } -fn field_deps( - fields: &[FormatField], - item_names: &FxHashMap, - local_names: &mut Vec, - deps: &mut Vec, -) { - let initial_locals_names_len = local_names.len(); - for field in fields { - match field { - FormatField::Format { - label: (_, label), - format, - .. - } => { - term_deps(format, item_names, local_names, deps); - local_names.push(*label) - } - FormatField::Computed { - label: (_, label), - r#type, - expr, - } => { - if let Some(r#type) = r#type { - term_deps(r#type, item_names, local_names, deps); - } - term_deps(expr, item_names, local_names, deps); - local_names.push(*label) - } - } - } - local_names.truncate(initial_locals_names_len); -} - fn push_pattern(pattern: &Pattern, local_names: &mut Vec) { match pattern { Pattern::Name(_, name) => local_names.push(*name), @@ -338,15 +273,3 @@ fn push_pattern(pattern: &Pattern, local_names: &mut Vec) { Pattern::BooleanLiteral(_, _) => {} } } - -fn pop_pattern(pattern: &Pattern, local_names: &mut Vec) { - match pattern { - Pattern::Name(_, _) => { - local_names.pop(); - } - Pattern::Placeholder(_) => {} - Pattern::StringLiteral(_, _) => {} - Pattern::NumberLiteral(_, _) => {} - Pattern::BooleanLiteral(_, _) => {} - } -} diff --git a/fathom/src/surface/elaboration/reporting.rs b/fathom/src/surface/elaboration/reporting.rs index 8982d02be..ed04c46ab 100644 --- a/fathom/src/surface/elaboration/reporting.rs +++ b/fathom/src/surface/elaboration/reporting.rs @@ -51,8 +51,8 @@ pub enum Message { }, MismatchedFieldLabels { range: FileRange, - expr_labels: Vec<(FileRange, Symbol)>, - type_labels: Vec, + found_labels: Vec<(FileRange, Symbol)>, + expected_labels: Vec, // TODO: add expected type // expected_type: Doc<_>, }, @@ -223,16 +223,16 @@ impl Message { })), Message::MismatchedFieldLabels { range, - expr_labels, - type_labels, + found_labels, + expected_labels, } => { - let mut diagnostic_labels = Vec::with_capacity(expr_labels.len()); + let mut diagnostic_labels = Vec::with_capacity(found_labels.len()); { - let mut type_labels = type_labels.iter().peekable(); + let mut expected_labels = expected_labels.iter().peekable(); - 'expr_labels: for (range, expr_label) in expr_labels.iter() { + 'expr_labels: for (range, expr_label) in found_labels.iter() { 'type_labels: loop { - match type_labels.next() { + match expected_labels.next() { None => { diagnostic_labels.push(primary_label(range).with_message( format!("unexpected field `{}`", expr_label.resolve()), @@ -252,11 +252,11 @@ impl Message { } } - if type_labels.peek().is_some() { + if expected_labels.peek().is_some() { diagnostic_labels.push(primary_label(range).with_message(format!( "missing fields {}", - type_labels - .map(|label|label.resolve()) + expected_labels + .map(|label| label.resolve()) .format_with(", ", |label, f| f(&format_args!("`{label}`"))), ))); } else { @@ -265,10 +265,10 @@ impl Message { } } - let found_labels = (expr_labels.iter()) + let found_labels = (found_labels.iter()) .map(|(_, label)| label.resolve()) .format_with(", ", |label, f| f(&format_args!("`{label}`"))); - let expected_labels = (type_labels.iter()) + let expected_labels = (expected_labels.iter()) .map(|label| label.resolve()) .format_with(", ", |label, f| f(&format_args!("`{label}`"))); diff --git a/fathom/src/surface/elaboration/unification.rs b/fathom/src/surface/elaboration/unification.rs index 852c4e69a..542122261 100644 --- a/fathom/src/surface/elaboration/unification.rs +++ b/fathom/src/surface/elaboration/unification.rs @@ -23,7 +23,7 @@ use crate::alloc::SliceVec; use crate::core::semantics::{ self, ArcValue, Branches, Closure, Elim, Head, SplitBranches, Telescope, Value, }; -use crate::core::{Prim, Term}; +use crate::core::{Builder, Prim, Term}; use crate::env::{EnvLen, Index, Level, SharedEnv, SliceEnv, UniqueEnv}; use crate::source::Spanned; use crate::surface::Plicity; @@ -192,6 +192,38 @@ impl<'arena, 'env> Context<'arena, 'env> { semantics::ElimEnv::new(self.item_exprs, self.meta_exprs) } + fn builder(&self) -> Builder<'arena> { + Builder::new(self.scope) + } + + fn with_local_scope(&mut self, mut f: impl FnMut(&mut Self) -> T) -> T { + let len = self.local_exprs; + let res = f(self); + self.local_exprs.truncate(len); + res + } + + fn with_renaming_scope(&mut self, mut f: impl FnMut(&mut Self) -> T) -> T { + let len = self.renaming.len(); + let res = f(self); + self.renaming.truncate(len); + res + } + + fn with_local(&mut self, mut f: impl FnMut(&mut Self) -> T) -> T { + self.local_exprs.push(); + let res = f(self); + self.local_exprs.pop(); + res + } + + fn with_renaming_local(&mut self, mut f: impl FnMut(&mut Self) -> T) -> T { + self.renaming.push_local(); + let res = f(self); + self.renaming.pop_local(); + res + } + /// Unify two values, updating the solution environment if necessary. pub fn unify( &mut self, @@ -251,16 +283,14 @@ impl<'arena, 'env> Context<'arena, 'env> { self.unify_fun_lit(*plicity, body_expr, &value0) } - (Value::RecordType(labels0, types0), Value::RecordType(labels1, types1)) => { - if labels0 != labels1 { - return Err(Error::Mismatch); - } + (Value::RecordType(labels0, types0), Value::RecordType(labels1, types1)) + if labels0 == labels1 => + { self.unify_telescopes(types0, types1) } - (Value::RecordLit(labels0, exprs0), Value::RecordLit(labels1, exprs1)) => { - if labels0 != labels1 { - return Err(Error::Mismatch); - } + (Value::RecordLit(labels0, exprs0), Value::RecordLit(labels1, exprs1)) + if labels0 == labels1 => + { for (expr0, expr1) in Iterator::zip(exprs0.iter(), exprs1.iter()) { self.unify(expr0, expr1)?; } @@ -269,7 +299,9 @@ impl<'arena, 'env> Context<'arena, 'env> { (Value::RecordLit(labels, exprs), _) => self.unify_record_lit(labels, exprs, &value1), (_, Value::RecordLit(labels, exprs)) => self.unify_record_lit(labels, exprs, &value0), - (Value::ArrayLit(elem_exprs0), Value::ArrayLit(elem_exprs1)) => { + (Value::ArrayLit(elem_exprs0), Value::ArrayLit(elem_exprs1)) + if elem_exprs0.len() == elem_exprs1.len() => + { for (elem_expr0, elem_expr1) in Iterator::zip(elem_exprs0.iter(), elem_exprs1.iter()) { @@ -278,20 +310,16 @@ impl<'arena, 'env> Context<'arena, 'env> { Ok(()) } - (Value::FormatRecord(labels0, formats0), Value::FormatRecord(labels1, formats1)) => { - if labels0 != labels1 { - return Err(Error::Mismatch); - } + (Value::FormatRecord(labels0, formats0), Value::FormatRecord(labels1, formats1)) + if labels0 == labels1 => + { self.unify_telescopes(formats0, formats1) } ( Value::FormatCond(label0, format0, cond0), Value::FormatCond(label1, format1, cond1), - ) => { - if label0 != label1 { - return Err(Error::Mismatch); - } + ) if label0 == label1 => { self.unify(format0, format1)?; self.unify_closures(cond0, cond1) } @@ -347,11 +375,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let value0 = self.elim_env().apply_closure(closure0, var.clone()); let value1 = self.elim_env().apply_closure(closure1, var); - self.local_exprs.push(); - let result = self.unify(&value0, &value1); - self.local_exprs.pop(); - - result + self.with_local(|this| this.unify(&value0, &value1)) } /// Unify two [telescopes][Telescope]. @@ -364,27 +388,22 @@ impl<'arena, 'env> Context<'arena, 'env> { return Err(Error::Mismatch); } - let initial_local_len = self.local_exprs; - let mut telescope0 = telescope0.clone(); - let mut telescope1 = telescope1.clone(); - - while let Some(((value0, next_telescope0), (value1, next_telescope1))) = Option::zip( - self.elim_env().split_telescope(telescope0), - self.elim_env().split_telescope(telescope1), - ) { - if let Err(error) = self.unify(&value0, &value1) { - self.local_exprs.truncate(initial_local_len); - return Err(error); - } - - let var = Spanned::empty(Arc::new(Value::local_var(self.local_exprs.next_level()))); - telescope0 = next_telescope0(var.clone()); - telescope1 = next_telescope1(var); - self.local_exprs.push(); - } + self.with_local_scope(|this| { + let mut telescope0 = telescope0.clone(); + let mut telescope1 = telescope1.clone(); - self.local_exprs.truncate(initial_local_len); - Ok(()) + while let Some(((value0, next_telescope0), (value1, next_telescope1))) = Option::zip( + this.elim_env().split_telescope(telescope0), + this.elim_env().split_telescope(telescope1), + ) { + this.unify(&value0, &value1)?; + let var = Spanned::empty(Arc::new(Value::local_var(this.local_exprs.next_level()))); + telescope0 = next_telescope0(var.clone()); + telescope1 = next_telescope1(var); + this.local_exprs.push(); + } + Ok(()) + }) } /// Unify two [constant branches][Branches]. @@ -437,11 +456,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let value = self.elim_env().fun_app(plicity, value.clone(), var.clone()); let body_expr = self.elim_env().apply_closure(body_expr, var); - self.local_exprs.push(); - let result = self.unify(&body_expr, &value); - self.local_exprs.pop(); - - result + self.with_local(|this| this.unify(&body_expr, &value)) } /// Unify a record literal with a value, using eta-conversion. @@ -519,9 +534,7 @@ impl<'arena, 'env> Context<'arena, 'env> { /// correspond to the given `spine`. fn fun_intros(&self, spine: &[Elim<'arena>], term: Term<'arena>) -> Term<'arena> { spine.iter().fold(term, |term, elim| match elim { - Elim::FunApp(plicity, _) => { - Term::FunLit(term.span(), *plicity, None, self.scope.to_scope(term)) - } + Elim::FunApp(plicity, _) => self.builder().fun_lit(term.span(), *plicity, None, term), Elim::RecordProj(_) | Elim::ConstMatch(_) => { unreachable!("should have been caught by `init_renaming`") } @@ -559,14 +572,14 @@ impl<'arena, 'env> Context<'arena, 'env> { spine.iter().try_fold(head_expr, |head_expr, elim| { Ok(match elim { - Elim::FunApp(plicity, arg_expr) => Term::FunApp( + Elim::FunApp(plicity, arg_expr) => self.builder().fun_app( span, *plicity, - self.scope.to_scope(head_expr), - self.scope.to_scope(self.rename(meta_var, arg_expr)?), + head_expr, + self.rename(meta_var, arg_expr)?, ), Elim::RecordProj(label) => { - Term::RecordProj(span, self.scope.to_scope(head_expr), *label) + self.builder().record_proj(span, head_expr, *label) } Elim::ConstMatch(branches) => { let mut branches = branches.clone(); @@ -608,23 +621,16 @@ impl<'arena, 'env> Context<'arena, 'env> { let param_type = self.rename(meta_var, param_type)?; let body_type = self.rename_closure(meta_var, body_type)?; - Ok(Term::FunType( - span, - *plicity, - *param_name, - self.scope.to_scope(param_type), - self.scope.to_scope(body_type), - )) + Ok(self + .builder() + .fun_type(span, *plicity, *param_name, param_type, body_type)) } Value::FunLit(plicity, param_name, body_expr) => { let body_expr = self.rename_closure(meta_var, body_expr)?; - Ok(Term::FunLit( - span, - *plicity, - *param_name, - self.scope.to_scope(body_expr), - )) + Ok(self + .builder() + .fun_lit(span, *plicity, *param_name, body_expr)) } Value::RecordType(labels, types) => { @@ -658,12 +664,7 @@ impl<'arena, 'env> Context<'arena, 'env> { Value::FormatCond(label, format, cond) => { let format = self.rename(meta_var, format)?; let cond = self.rename_closure(meta_var, cond)?; - Ok(Term::FormatCond( - span, - *label, - self.scope.to_scope(format), - self.scope.to_scope(cond), - )) + Ok(self.builder().format_cond(span, *label, format, cond)) } Value::FormatOverlap(labels, formats) => { let formats = self.rename_telescope(meta_var, formats)?; @@ -684,11 +685,7 @@ impl<'arena, 'env> Context<'arena, 'env> { let source_var = self.renaming.next_local_var(); let value = self.elim_env().apply_closure(closure, source_var); - self.renaming.push_local(); - let term = self.rename(meta_var, &value); - self.renaming.pop_local(); - - term + self.with_renaming_local(|this| this.rename(meta_var, &value)) } /// Rename a telescope back into a [`Term`]. @@ -697,27 +694,22 @@ impl<'arena, 'env> Context<'arena, 'env> { meta_var: Level, telescope: &Telescope<'arena>, ) -> Result<&'arena [Term<'arena>], RenameError> { - let initial_renaming_len = self.renaming.len(); - let mut telescope = telescope.clone(); - let mut terms = SliceVec::new(self.scope, telescope.len()); - - while let Some((value, next_telescope)) = self.elim_env().split_telescope(telescope) { - match self.rename(meta_var, &value) { - Ok(term) => { - terms.push(term); - let source_var = self.renaming.next_local_var(); - telescope = next_telescope(source_var); - self.renaming.push_local(); - } - Err(error) => { - self.renaming.truncate(initial_renaming_len); - return Err(error); - } + self.with_renaming_scope(|this| { + let mut telescope = telescope.clone(); + let mut terms = SliceVec::new(this.scope, telescope.len()); + + while let Some((value, next_telescope)) = + this.elim_env().split_telescope(telescope.clone()) + { + let term = this.rename(meta_var, &value)?; + terms.push(term); + let source_var = this.renaming.next_local_var(); + telescope = next_telescope(source_var); + this.renaming.push_local(); } - } - self.renaming.truncate(initial_renaming_len); - Ok(terms.into()) + Ok(terms.into()) + }) } } diff --git a/fathom/src/surface/grammar.lalrpop b/fathom/src/surface/grammar.lalrpop index d9d0f9c23..1e0973bf8 100644 --- a/fathom/src/surface/grammar.lalrpop +++ b/fathom/src/surface/grammar.lalrpop @@ -3,7 +3,8 @@ use scoped_arena::Scope; use crate::source::{ByteRange, BytePos}; use crate::surface::{ Arg, BinOp, ExprField, FormatField, Item, ItemDef, Module, ParseMessage, - Pattern, Param, Plicity, Term, TypeField, + Pattern, Param, Plicity, Term, TypeField, LetDef, + Builder, }; use crate::surface::lexer::{Error as LexerError, Token}; use crate::symbol::Symbol; @@ -69,19 +70,17 @@ extern { } pub Module: Module<'arena, ByteRange> = { - => Module { - items: scope.to_scope_from_iter(items.into_iter()), - }, + > => Module { items }, }; Item: Item<'arena, ByteRange> = { - "def" )?> "=" ";" => { + "def" > )?> "=" ";" => { Item::Def(ItemDef { range: ByteRange::new(start, end), label, - params: scope.to_scope_from_iter(params), - r#type: r#type.map(|r#type| scope.to_scope(r#type) as &_), - expr: scope.to_scope(expr), + params, + r#type, + expr, }) }, => { @@ -102,53 +101,34 @@ Pattern: Pattern = { pub Term: Term<'arena, ByteRange> = { LetTerm, ":" => { - Term::Ann( - ByteRange::new(start, end), - scope.to_scope(expr), - scope.to_scope(r#type), - ) + Builder::new(scope).ann((start, end), expr, r#type) }, }; LetTerm: Term<'arena, ByteRange> = { FunTerm, - "let" )?> "=" ";" => { - Term::Let( - ByteRange::new(start, end), - def_pattern, - def_type.map(|def_type| scope.to_scope(def_type) as &_), - scope.to_scope(def_expr), - scope.to_scope(body_expr), - ) + "let" ";" => { + Builder::new(scope).r#let((start, end), def, body_expr) }, "if" "then" "else" => { - Term::If(ByteRange::new(start, end), scope.to_scope(cond_expr), scope.to_scope(then_expr), scope.to_scope(else_expr)) + Builder::new(scope).if_then_else((start, end), cond_expr, then_expr, else_expr) }, }; +LetDef: LetDef<'arena, ByteRange> = { + )?> "=" => LetDef {pattern, r#type, expr}, +} + FunTerm: Term<'arena, ByteRange> = { EqExpr, "->" => { - Term::Arrow( - ByteRange::new(start, end), - plicity, - scope.to_scope(param_type), - scope.to_scope(body_type), - ) + Builder::new(scope).arrow((start, end), plicity, param_type, body_type) }, - "fun" "->" => { - Term::FunType( - ByteRange::new(start, end), - scope.to_scope_from_iter(params), - scope.to_scope(output_type), - ) + "fun" > "->" => { + Builder::new(scope).fun_type((start, end), params, output_type) }, - "fun" "=>" => { - Term::FunLiteral( - ByteRange::new(start, end), - scope.to_scope_from_iter(params), - scope.to_scope(output_type), - ) + "fun" > "=>" => { + Builder::new(scope).fun_lit((start, end), params, output_expr) }, }; @@ -180,35 +160,27 @@ MulExpr: Term<'arena, ByteRange> = { AppTerm: Term<'arena, ByteRange> = { ProjTerm, - => { - Term::App( - ByteRange::new(start, end), - scope.to_scope(head_expr), - scope.to_scope_from_iter(args), - ) + > => { + Builder::new(scope).fun_app((start, end), head_expr, args) }, }; ProjTerm: Term<'arena, ByteRange> = { AtomicTerm, - )+> => { - Term::Proj( - ByteRange::new(start, end), - scope.to_scope(head_expr), - scope.to_scope_from_iter(labels), - ) + )>> => { + Builder::new(scope).record_proj((start, end), head_expr, labels) }, }; AtomicTerm: Term<'arena, ByteRange> = { - "(" ")" => Term::Paren(ByteRange::new(start, end), scope.to_scope(term)), + "(" ")" => Builder::new(scope).paren((start, end), term), > => Term::Tuple(ByteRange::new(start, end), terms), => Term::Name(ByteRange::new(start, end), name), "_" => Term::Placeholder(ByteRange::new(start, end)), => Term::Hole(ByteRange::new(start, end), name), "match" "{" "=>" ), ",">> "}" => { - Term::Match(ByteRange::new(start, end), scope.to_scope(scrutinee), equations) + Builder::new(scope).r#match((start, end), scrutinee, equations) }, "Type" => Term::Universe(ByteRange::new(start, end)), => Term::StringLiteral(ByteRange::new(start, end), string), @@ -226,7 +198,7 @@ AtomicTerm: Term<'arena, ByteRange> = { Term::FormatRecord(ByteRange::new(start, end), fields) }, "{" "<-" "|" "}" => { - Term::FormatCond(ByteRange::new(start, end), name, scope.to_scope(format), scope.to_scope(cond)) + Builder::new(scope).format_cond((start, end), name, format, cond) }, "overlap" "{" > "}" => { Term::FormatOverlap(ByteRange::new(start, end), fields) @@ -259,12 +231,7 @@ ExprField: ExprField<'arena, ByteRange> = { BinExpr: Term<'arena, ByteRange> = { => { - Term::BinOp( - ByteRange::new(start, end), - scope.to_scope(lhs), - op, - scope.to_scope(rhs), - ) + Builder::new(scope).binop((start, end), lhs, op, rhs) }, }; @@ -311,6 +278,18 @@ RangedName: (ByteRange, Symbol) = { => (ByteRange::new(start, end), name), }; +List: &'arena [Elem] = { + => { + scope.to_scope_from_iter(elems) + } +} + +List1: &'arena [Elem] = { + => { + scope.to_scope_from_iter(elems) + } +} + Seq: &'arena [Elem] = { Sep)*> => { scope.to_scope_from_iter(elems.into_iter().chain(last)) diff --git a/fathom/src/surface/pretty.rs b/fathom/src/surface/pretty.rs index ebd95a65f..cd6b33173 100644 --- a/fathom/src/surface/pretty.rs +++ b/fathom/src/surface/pretty.rs @@ -42,7 +42,7 @@ impl<'arena> Context<'arena> { .concat([ self.text("def"), self.space(), - match item.r#type { + match item.r#type.as_ref() { None => self.concat([ self.ident(item.label.1), self.params(item.params), @@ -63,7 +63,7 @@ impl<'arena> Context<'arena> { self.space(), self.text("="), self.softline(), - self.term(item.expr), + self.term(&item.expr), self.text(";"), ]) .group(), @@ -148,15 +148,15 @@ impl<'arena> Context<'arena> { self.softline(), self.term(r#type), ]), - Term::Let(_, def_pattern, def_type, def_expr, body_expr) => self.concat([ + Term::Let(_, def, body_expr) => self.concat([ self.concat([ self.text("let"), self.space(), - self.ann_pattern(def_pattern, *def_type), + self.ann_pattern(&def.pattern, def.r#type.as_ref()), self.space(), self.text("="), self.softline(), - self.term(def_expr), + self.term(&def.expr), self.text(";"), ]) .group(),