Skip to content

Commit

Permalink
merge: more progress on classes (#1129)
Browse files Browse the repository at this point in the history
  • Loading branch information
ice1000 committed Jul 12, 2024
2 parents e8cf3f8 + 8f52aab commit 5cecead
Show file tree
Hide file tree
Showing 27 changed files with 452 additions and 100 deletions.
54 changes: 48 additions & 6 deletions base/src/main/java/org/aya/tyck/ExprTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.jetbrains.annotations.Nullable;

import java.util.Comparator;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

Expand Down Expand Up @@ -134,6 +135,11 @@ case PiTerm(var dom, var cod) -> {
};
}

/**
* @param type expected type
* @param result wellTyped + actual type from synthesize
* @param expr original expr, used for error reporting
*/
private @NotNull Jdg inheritFallbackUnify(@NotNull Term type, @NotNull Jdg result, @NotNull WithPos<Expr> expr) {
type = whnf(type);
var resultType = result.type();
Expand All @@ -155,6 +161,22 @@ case PiTerm(var dom, var cod) -> {
return new Jdg.Default(new LamTerm(closure), eq);
}
}
// Try coercive subtyping between classes
if (type instanceof ClassCall clazz) {
// Try coercive subtyping for `SomeClass (foo := 114514)` into `SomeClass`
resultType = whnf(resultType);
if (resultType instanceof ClassCall resultClazz) {
// TODO: check whether resultClazz <: clazz
if (true) {
// No need to coerce
if (clazz.args().size() == resultClazz.args().size()) return result;
var forget = resultClazz.args().drop(clazz.args().size());
return new Jdg.Default(new ClassCastTerm(clazz.ref(), result.wellTyped(), clazz.args(), forget), type);
} else {
return makeErrorResult(type, result);
}
}
}
if (unifyTyReported(type, resultType, expr)) return result;
return makeErrorResult(type, result);
}
Expand Down Expand Up @@ -315,6 +337,19 @@ yield subscoped(param.ref(), wellParam, () ->
var type = new DataCall(def, 0, ImmutableSeq.of(elementTy));
yield new Jdg.Default(new ListTerm(results, match.recog(), type), type);
}
case Expr.New(var classCall) -> {
var wellTyped = synthesize(classCall);
if (!(wellTyped.wellTyped() instanceof ClassCall call)) {
yield fail(expr.data(), new ClassError.NotClassCall(classCall));
}

// check whether the call is fully applied
if (call.args().size() != call.ref().members().size()) {
yield fail(expr.data(), new ClassError.NotFullyApplied(classCall));
}

yield new Jdg.Default(new NewTerm(call), call);
}
case Expr.Unresolved _ -> Panic.unreachable();
default -> fail(expr.data(), new NoRuleError(expr, null));
};
Expand Down Expand Up @@ -342,9 +377,9 @@ yield subscoped(param.ref(), wellParam, () ->
case LocalVar ref when localLet.contains(ref) -> generateApplication(args, localLet.get(ref)).lift(lift);
case LocalVar lVar -> generateApplication(args,
new Jdg.Default(new FreeTerm(lVar), localCtx().get(lVar))).lift(lift);
case CompiledVar(var content) -> new AppTycker<>(state, sourcePos, args.size(), lift, (params, k) ->
case CompiledVar(var content) -> new AppTycker<>(this, sourcePos, args.size(), lift, (params, k) ->
computeArgs(sourcePos, args, params, k)).checkCompiledApplication(content);
case DefVar<?, ?> defVar -> new AppTycker<>(state, sourcePos, args.size(), lift, (params, k) ->
case DefVar<?, ?> defVar -> new AppTycker<>(this, sourcePos, args.size(), lift, (params, k) ->
computeArgs(sourcePos, args, params, k)).checkDefApplication(defVar);
default -> throw new UnsupportedOperationException("TODO");
};
Expand All @@ -359,10 +394,11 @@ case CompiledVar(var content) -> new AppTycker<>(state, sourcePos, args.size(),

private Jdg computeArgs(
@NotNull SourcePos pos, @NotNull ImmutableSeq<Expr.NamedArg> args,
@NotNull AbstractTele params, @NotNull Function<Term[], Jdg> k
@NotNull AbstractTele params, @NotNull BiFunction<Term[], Term, Jdg> k
) throws NotPi {
int argIx = 0, paramIx = 0;
var result = new Term[params.telescopeSize()];
Term firstType = null;
while (argIx < args.size() && paramIx < params.telescopeSize()) {
var arg = args.get(argIx);
var param = params.telescopeRich(paramIx, result);
Expand All @@ -373,33 +409,39 @@ private Jdg computeArgs(
break;
} else if (arg.name() == null) {
// here, arg.explicit() == true and param.explicit() == false
if (paramIx == 0) firstType = param.type();
result[paramIx++] = insertImplicit(param, arg.sourcePos());
continue;
}
}
if (arg.name() != null && !param.nameEq(arg.name())) {
if (paramIx == 0) firstType = param.type();
result[paramIx++] = insertImplicit(param, arg.sourcePos());
continue;
}
result[paramIx++] = inherit(arg.arg(), param.type()).wellTyped();
var what = inherit(arg.arg(), param.type());
if (paramIx == 0) firstType = param.type();
result[paramIx++] = what.wellTyped();
argIx++;
}
// Trailing implicits
while (paramIx < params.telescopeSize()) {
if (params.telescopeLicit(paramIx)) break;
var param = params.telescopeRich(paramIx, result);
if (paramIx == 0) firstType = param.type();
result[paramIx++] = insertImplicit(param, pos);
}
var extraParams = MutableStack.<Pair<LocalVar, Term>>create();
if (argIx < args.size()) {
return generateApplication(args.drop(argIx), k.apply(result));
return generateApplication(args.drop(argIx), k.apply(result, firstType));
} else while (paramIx < params.telescopeSize()) {
var param = params.telescopeRich(paramIx, result);
var atarashiVar = LocalVar.generate(param.name());
extraParams.push(new Pair<>(atarashiVar, param.type()));
if (paramIx == 0) firstType = param.type();
result[paramIx++] = new FreeTerm(atarashiVar);
}
var generated = k.apply(result);
var generated = k.apply(result, firstType);
while (extraParams.isNotEmpty()) {
var pair = extraParams.pop();
generated = new Jdg.Default(
Expand Down
2 changes: 1 addition & 1 deletion base/src/main/java/org/aya/tyck/StmtTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private void checkMember(@NotNull ClassMember member, @NotNull ExprTycker tycker
new Param("self", classCall, false),
classRef.concrete.sourcePos()
);
new MemberDef(classRef, member.ref, signature.params(), signature.result());
new MemberDef(classRef, member.ref, classRef.concrete.members.indexOf(member), signature.params(), signature.result());
member.ref.signature = signature;
}

Expand Down
28 changes: 28 additions & 0 deletions base/src/main/java/org/aya/tyck/error/ClassError.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck.error;

import org.aya.pretty.doc.Doc;
import org.aya.syntax.concrete.Expr;
import org.aya.util.error.SourcePos;
import org.aya.util.error.WithPos;
import org.aya.util.prettier.PrettierOptions;
import org.jetbrains.annotations.NotNull;

public interface ClassError extends TyckError {
@NotNull WithPos<Expr> problemExpr();

@Override default @NotNull SourcePos sourcePos() { return problemExpr().sourcePos(); }

record NotClassCall(@Override @NotNull WithPos<Expr> problemExpr) implements ClassError {
@Override public @NotNull Doc describe(@NotNull PrettierOptions options) {
return Doc.sep(Doc.english("Unable to new a non-class type:"), Doc.code(problemExpr.data().toDoc(options)));
}
}

record NotFullyApplied(@Override @NotNull WithPos<Expr> problemExpr) implements ClassError {
@Override public @NotNull Doc describe(@NotNull PrettierOptions options) {
return Doc.sep(Doc.english("Unable to new an incomplete class type:"), Doc.code(problemExpr.data().toDoc(options)));
}
}
}
70 changes: 48 additions & 22 deletions base/src/main/java/org/aya/tyck/tycker/AppTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,39 @@
import kala.collection.Seq;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableArray;
import kala.collection.immutable.ImmutableSeq;
import kala.function.CheckedBiFunction;
import org.aya.generic.stmt.Shaped;
import org.aya.syntax.compile.JitCon;
import org.aya.syntax.compile.JitData;
import org.aya.syntax.compile.JitFn;
import org.aya.syntax.compile.JitPrim;
import org.aya.syntax.concrete.stmt.decl.*;
import org.aya.syntax.core.Closure;
import org.aya.syntax.core.def.*;
import org.aya.syntax.core.repr.AyaShape;
import org.aya.syntax.core.term.FreeTerm;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.core.term.*;
import org.aya.syntax.core.term.call.*;
import org.aya.syntax.ref.DefVar;
import org.aya.syntax.ref.LocalVar;
import org.aya.syntax.telescope.AbstractTele;
import org.aya.tyck.Jdg;
import org.aya.tyck.TyckState;
import org.aya.unify.Synthesizer;
import org.aya.util.error.Panic;
import org.aya.util.error.SourcePos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.function.Function;
import java.util.function.BiFunction;

public record AppTycker<Ex extends Exception>(
@NotNull TyckState state, @NotNull SourcePos pos,
int argsCount, int lift, @NotNull Factory<Ex> makeArgs
) {
@Override @NotNull TyckState state,
@NotNull AbstractTycker tycker,
@NotNull SourcePos pos,
int argsCount, int lift,
@NotNull Factory<Ex> makeArgs
) implements Stateful {
/**
* <pre>
* Signature (0th param) --------> Argument Parser (this interface)
Expand All @@ -45,7 +50,14 @@ public record AppTycker<Ex extends Exception>(
*/
@FunctionalInterface
public interface Factory<Ex extends Exception> extends
CheckedBiFunction<AbstractTele, Function<Term[], Jdg>, Jdg, Ex> {
CheckedBiFunction<AbstractTele, BiFunction<Term[], Term, Jdg>, Jdg, Ex> {
}

public AppTycker(
@NotNull AbstractTycker tycker, @NotNull SourcePos pos,
int argsCount, int lift, @NotNull Factory<Ex> makeArgs
) {
this(tycker.state, tycker, pos, argsCount, lift, makeArgs);
}

public @NotNull Jdg checkCompiledApplication(@NotNull AbstractTele def) throws Ex {
Expand Down Expand Up @@ -93,7 +105,7 @@ public interface Factory<Ex extends Exception> extends
// ownerTele + selfTele
var fullSignature = conVar.signature().lift(lift);

return makeArgs.applyChecked(fullSignature, args -> {
return makeArgs.applyChecked(fullSignature, (args, _) -> {
var realArgs = ImmutableArray.from(args);
var ownerArgs = realArgs.take(conVar.ownerTeleSize());
var conArgs = realArgs.drop(conVar.ownerTeleSize());
Expand All @@ -109,14 +121,14 @@ public interface Factory<Ex extends Exception> extends
}
private @NotNull Jdg checkPrimCall(@NotNull PrimDefLike primVar) throws Ex {
var signature = primVar.signature().lift(lift);
return makeArgs.applyChecked(signature, args -> new Jdg.Default(
return makeArgs.applyChecked(signature, (args, _) -> new Jdg.Default(
state.primFactory.unfold(new PrimCall(primVar, 0, ImmutableArray.from(args)), state),
signature.result(args)
));
}
private @NotNull Jdg checkDataCall(@NotNull DataDefLike data) throws Ex {
var signature = data.signature().lift(lift);
return makeArgs.applyChecked(signature, args -> new Jdg.Default(
return makeArgs.applyChecked(signature, (args, _) -> new Jdg.Default(
new DataCall(data, 0, ImmutableArray.from(args)),
signature.result(args)
));
Expand All @@ -125,7 +137,7 @@ public interface Factory<Ex extends Exception> extends
@NotNull FnDefLike fnDef, @Nullable Shaped.Applicable<FnDefLike> operator
) throws Ex {
var signature = fnDef.signature().lift(lift);
return makeArgs.applyChecked(signature, args -> {
return makeArgs.applyChecked(signature, (args, _) -> {
var argsSeq = ImmutableArray.from(args);
var result = signature.result(args);
if (operator != null) {
Expand All @@ -137,10 +149,10 @@ public interface Factory<Ex extends Exception> extends
}

private @NotNull Jdg checkClassCall(@NotNull ClassDefLike clazz) throws Ex {
var appliedParams = ofClassMembers(clazz, argsCount).lift(lift);
var self = LocalVar.generate("self");
var appliedParams = ofClassMembers(self, clazz, argsCount).lift(lift);
state.classThis.push(self);
var result = makeArgs.applyChecked(appliedParams, args -> new Jdg.Default(
var result = makeArgs.applyChecked(appliedParams, (args, _) -> new Jdg.Default(
new ClassCall(clazz, 0, ImmutableArray.from(args).map(x -> x.bind(self))),
appliedParams.result(args)
));
Expand All @@ -150,23 +162,29 @@ public interface Factory<Ex extends Exception> extends

private @NotNull Jdg checkProjCall(@NotNull MemberDefLike member) throws Ex {
var signature = member.signature().lift(lift);
return makeArgs.applyChecked(signature, args -> {
return makeArgs.applyChecked(signature, (args, fstTy) -> {
assert args.length >= 1;
var ofTy = whnf(fstTy);
if (!(ofTy instanceof ClassCall classTy)) throw new UnsupportedOperationException("report"); // TODO
var fieldArgs = ImmutableArray.fill(args.length - 1, i -> args[i + 1]);
return new Jdg.Default(
new MemberCall(args[0], member, 0, fieldArgs),
MemberCall.make(classTy, args[0], member, 0, fieldArgs),
signature.result(args)
);
});
}

static @NotNull AbstractTele ofClassMembers(@NotNull ClassDefLike def, int memberCount) {
private @NotNull AbstractTele ofClassMembers(@NotNull LocalVar self, @NotNull ClassDefLike def, int memberCount) {
var synthesizer = new Synthesizer(tycker);
return switch (def) {
case ClassDef.Delegate delegate -> new TakeMembers(delegate.core(), memberCount);
case ClassDef.Delegate delegate -> new TakeMembers(self, delegate.core(), memberCount, synthesizer);
};
}

record TakeMembers(@NotNull ClassDef clazz, @Override int telescopeSize) implements AbstractTele {
record TakeMembers(
@NotNull LocalVar self, @NotNull ClassDef clazz,
@Override int telescopeSize, @NotNull Synthesizer synthesizer
) implements AbstractTele {
@Override public boolean telescopeLicit(int i) { return true; }
@Override public @NotNull String telescopeName(int i) {
assert i < telescopeSize;
Expand All @@ -175,18 +193,26 @@ record TakeMembers(@NotNull ClassDef clazz, @Override int telescopeSize) impleme

// class Foo
// | foo : A
// | + : A -> A -> A
// | infix + : A -> A -> A
// | bar : Fn (x : Foo A) -> (x.foo) self.+ (self.foo)
// instantiate these! ^ ^
@Override public @NotNull Term telescope(int i, Seq<Term> teleArgs) {
// teleArgs are former members
assert i < telescopeSize;
var member = clazz.members().get(i);
return TyckDef.defSignature(member.ref()).makePi(Seq.of(new FreeTerm(clazz.ref().concrete.self)));
return TyckDef.defSignature(member.ref()).inst(ImmutableSeq.of(new NewTerm(
new ClassCall(new ClassDef.Delegate(clazz.ref()), 0,
ImmutableSeq.fill(clazz.members().size(), idx -> Closure.mkConst(idx < i ? teleArgs.get(idx) : ErrorTerm.DUMMY))
)
))).makePi(Seq.empty());
}

@Override public @NotNull Term result(Seq<Term> teleArgs) {
// Use SigmaTerm::lub
throw new UnsupportedOperationException("TODO");
return clazz.members().view()
.drop(telescopeSize)
.map(member -> TyckDef.defSignature(member.ref()).inst(ImmutableSeq.of(new FreeTerm(self))).makePi(Seq.empty()))
.map(ty -> (SortTerm) synthesizer.synth(ty))
.foldLeft(SortTerm.Type0, SigmaTerm::lub);
}
@Override public @NotNull SeqView<String> namesView() {
return clazz.members().sliceView(0, telescopeSize).map(i -> i.ref().name());
Expand Down
2 changes: 2 additions & 0 deletions base/src/main/java/org/aya/unify/Synthesizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ case MetaCall(var ref, var args) when ref.req() instanceof MetaVar.OfType(var ty
case MetaLitTerm mlt -> mlt.type();
case StringTerm str -> state().primFactory.getCall(PrimDef.ID.STRING);
case ClassCall classCall -> throw new UnsupportedOperationException("TODO");
case NewTerm newTerm -> newTerm.inner();
case ClassCastTerm castTerm -> new ClassCall(castTerm.ref(), 0, castTerm.remember());
};
}

Expand Down
Loading

0 comments on commit 5cecead

Please # to comment.