diff --git a/core/src/main/java/org/bouncycastle/pqc/crypto/xmss/XMSSUtil.java b/core/src/main/java/org/bouncycastle/pqc/crypto/xmss/XMSSUtil.java index ea8fa64aa3..79c63ac80a 100644 --- a/core/src/main/java/org/bouncycastle/pqc/crypto/xmss/XMSSUtil.java +++ b/core/src/main/java/org/bouncycastle/pqc/crypto/xmss/XMSSUtil.java @@ -8,6 +8,8 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.ObjectStreamClass; +import java.util.HashSet; +import java.util.Set; import org.bouncycastle.crypto.Digest; import org.bouncycastle.util.Arrays; @@ -382,6 +384,24 @@ public static boolean isNewAuthenticationPathNeeded(long globalIndex, int xmssHe private static class CheckingStream extends ObjectInputStream { + private static final Set components = new HashSet<>(); + + static + { + components.add("java.util.TreeMap"); + components.add("java.lang.Integer"); + components.add("java.lang.Number"); + components.add("org.bouncycastle.pqc.crypto.xmss.BDS"); + components.add("java.util.ArrayList"); + components.add("org.bouncycastle.pqc.crypto.xmss.XMSSNode"); + components.add("[B"); + components.add("java.util.LinkedList"); + components.add("java.util.Stack"); + components.add("java.util.Vector"); + components.add("[Ljava.lang.Object;"); + components.add("org.bouncycastle.pqc.crypto.xmss.BDSTreeHash"); + } + private final Class mainClass; private boolean found = false; @@ -409,6 +429,14 @@ protected Class resolveClass(ObjectStreamClass desc) found = true; } } + else + { + if (!components.contains(desc.getName())) + { + throw new InvalidClassException( + "unexpected class: ", desc.getName()); + } + } return super.resolveClass(desc); } }