Skip to content

Commit

Permalink
HIVE-26539: Prevent unsafe deserialization in PartitionExpressionForM…
Browse files Browse the repository at this point in the history
…etastore (apache#3605)
  • Loading branch information
dengzhhu653 authored and yeahyung committed Jul 20, 2023
1 parent 4d0dc5f commit a25d61a
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ private static class KryoWithHooks extends Kryo implements Configurable {
private Hook globalHook;
// this should be set on-the-fly after borrowing this instance and needs to be reset on release
private Configuration configuration;
// default false, should be reset on release
private boolean isExprNodeFirst = false;
// total classes we have met during (de)serialization, should be reset on release
private long classCounter = 0;

@SuppressWarnings({"unchecked", "rawtypes"})
private static final class SerializerWithHook extends com.esotericsoftware.kryo.Serializer {
Expand Down Expand Up @@ -228,6 +232,32 @@ public void setConf(Configuration conf) {
public Configuration getConf() {
return configuration;
}

/**
 * Looks up the Kryo registration for {@code type}, enforcing the ExprNodeDesc-first guard.
 *
 * <p>If PartitionExpressionForMetastore performs deserialization at a remote HMS, the first
 * class encountered must be an ExprNodeDesc; anything else is rejected to avoid a potential
 * security problem. The guard only applies while the flag is set and no class has been counted
 * yet on this instance; every lookup that is not rejected increments the counter.
 *
 * @param type the class whose registration is requested
 * @return the registration returned by the superclass lookup
 * @throws UnsupportedOperationException if the guard is active and {@code type} is not
 *         assignable to ExprNodeDesc
 */
@Override
public com.esotericsoftware.kryo.Registration getRegistration(Class type) {
  // Only the very first class seen on this borrowed instance is subject to the check.
  boolean guardFirstClass = isExprNodeFirst && classCounter == 0;
  if (guardFirstClass && !ExprNodeDesc.class.isAssignableFrom(type)) {
    throw new UnsupportedOperationException(
        "The object to be deserialized must be an ExprNodeDesc, but encountered: " + type);
  }
  classCounter++;
  return super.getRegistration(type);
}

/**
 * Sets the flag that makes {@code getRegistration} require an ExprNodeDesc as the first class.
 *
 * @param isPartFilter the new value of the flag
 */
public void setExprNodeFirst(boolean isPartFilter) {
  isExprNodeFirst = isPartFilter;
}

/**
 * Resets the per-use fields on release: clears the configuration, zeroes the class counter
 * and switches the ExprNodeDesc-first flag back to its default of {@code false}.
 */
public void restore() {
  setConf(null);
  classCounter = 0;
  isExprNodeFirst = false;
}
}

private static final Object FAKE_REFERENCE = new Object();
Expand Down Expand Up @@ -294,7 +324,7 @@ public static Kryo borrowKryo(Configuration configuration) {
*/
public static void releaseKryo(Kryo kryo) {
  if (kryo != null) {
    // restore() already clears the configuration via setConf(null) and also resets the
    // ExprNodeDesc-first flag and the class counter, so a separate setConf(null) is redundant.
    ((KryoWithHooks) kryo).restore();
  }
  // NOTE(review): free() is still invoked when kryo is null, as before — confirm the pool
  // tolerates a null argument.
  kryoPool.free(kryo);
}
Expand Down Expand Up @@ -830,10 +860,13 @@ public static byte[] serializeObjectWithTypeInformation(Serializable object) {
/**
* Deserializes expression from Kryo.
* @param bytes Bytes containing the expression.
* @param isPartFilter true if it is a partition filter
* @return Expression; null if deserialization succeeded, but the result type is incorrect.
*/
public static <T> T deserializeObjectWithTypeInformation(byte[] bytes) {
Kryo kryo = borrowKryo();
public static <T> T deserializeObjectWithTypeInformation(byte[] bytes,
boolean isPartFilter) {
KryoWithHooks kryo = (KryoWithHooks) borrowKryo();
kryo.setExprNodeFirst(isPartFilter);
try (Input inp = new Input(new ByteArrayInputStream(bytes))) {
return (T) kryo.readClassAndObject(inp);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
Expand Down Expand Up @@ -108,7 +107,7 @@ public boolean filterPartitionsByExpr(List<FieldSchema> partColumns,
private ExprNodeDesc deserializeExpr(byte[] exprBytes) throws MetaException {
ExprNodeDesc expr = null;
try {
expr = SerializationUtilities.deserializeObjectWithTypeInformation(exprBytes);
expr = SerializationUtilities.deserializeObjectWithTypeInformation(exprBytes, true);
} catch (Exception ex) {
LOG.error("Failed to deserialize the expression", ex);
throw new MetaException(ex.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,10 +33,16 @@
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.io.orc.OrcInputFormat;
import org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.ql.plan.VectorPartitionDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNull;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.junit.Assert;
import org.junit.Test;

Expand Down Expand Up @@ -127,6 +134,25 @@ public void testSkippingAppliesToAllPartitions() throws Exception {
assertPartitionDescPropertyPresence(mapWork, "/warehouse/test_table/p=1", "serialization.ddl", false);
}

/**
 * Checks deserialization with the partition-filter flag set to {@code true}: a payload whose
 * top-level object is not an ExprNodeDesc must be refused with an
 * UnsupportedOperationException, while a serialized ExprNodeDesc must round-trip to an
 * equivalent expression.
 */
@Test
public void testUnsupportedDeserialization() throws Exception {
  // A plain list of longs is not an expression, so deserialization has to fail.
  ArrayList<Long> notAnExpression = new ArrayList<>();
  notAnExpression.add(1L);
  byte[] rejectedBytes =
      SerializationUtilities.serializeObjectWithTypeInformation(notAnExpression);
  try {
    SerializationUtilities.deserializeObjectWithTypeInformation(rejectedBytes, true);
    Assert.fail("Should throw exception as the input is not a valid filter");
  } catch (UnsupportedOperationException expected) {
    // expected: the guard refused the non-expression payload
  }

  // A genuine expression ("_c0 IS NULL") must survive the round trip unchanged.
  ColumnInfo columnInfo = new ColumnInfo("_c0", TypeInfoFactory.stringTypeInfo, "a", false);
  ExprNodeDesc column = new ExprNodeColumnDesc(columnInfo);
  ExprNodeDesc original =
      ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPNull(), Arrays.asList(column));
  byte[] acceptedBytes = SerializationUtilities.serializeObjectWithTypeInformation(original);
  ExprNodeDesc roundTripped =
      SerializationUtilities.deserializeObjectWithTypeInformation(acceptedBytes, true);
  Assert.assertTrue(ExprNodeDescUtils.isSame(original, roundTripped));
}

private MapWork doSerDeser(Configuration configuration) throws Exception, IOException {
MapWork mapWork = mockMapWorkWithSomePartitionDescProperties();
ByteArrayOutputStream baos = new ByteArrayOutputStream();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ private void initialize() {
isInitialized = pm != null;
if (isInitialized) {
dbType = determineDatabaseProduct();
expressionProxy = createExpressionProxy(conf);
expressionProxy = PartFilterExprUtil.createExpressionProxy(conf);
if (MetastoreConf.getBoolVar(getConf(), ConfVars.TRY_DIRECT_SQL)) {
String schema = PersistenceManagerProvider.getProperty("javax.jdo.mapping.Schema");
schema = org.apache.commons.lang3.StringUtils.defaultIfBlank(schema, null);
Expand All @@ -447,25 +447,6 @@ private static String getProductName(PersistenceManager pm) {
}
}

/**
 * Creates the proxy used to evaluate expressions. This is here to prevent circular
 * dependency - ql -&gt; metastore client &lt;-&gt; metastore server -&gt; ql. If server and
 * client are split, this can be removed.
 * @param conf Configuration.
 * @return The partition expression proxy.
 * @throws RuntimeException if resolving or instantiating the class named by
 *         {@code ConfVars.EXPRESSION_PROXY_CLASS} raises a MetaException; that exception is
 *         attached as the cause.
 */
private static PartitionExpressionProxy createExpressionProxy(Configuration conf) {
  String className = MetastoreConf.getVar(conf, ConfVars.EXPRESSION_PROXY_CLASS);
  try {
    Class<? extends PartitionExpressionProxy> clazz =
        JavaUtils.getClass(className, PartitionExpressionProxy.class);
    return JavaUtils.newInstance(clazz, new Class<?>[0], new Object[0]);
  } catch (MetaException e) {
    LOG.error("Error loading PartitionExpressionProxy", e);
    // Chain the original exception so callers keep the full stack trace, not just the message.
    throw new RuntimeException("Error loading PartitionExpressionProxy: " + e.getMessage(), e);
  }
}

/**
* Configure SSL encryption to the database store.
*
Expand Down

0 comments on commit a25d61a

Please # to comment.