From 9d9e15437b351f394a2d2826b570d841d63c39df Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Tue, 10 Sep 2024 10:40:39 +0800 Subject: [PATCH] update --- .../apache/spark/ml/regression/spark313/LinearRegression.scala | 3 ++- .../apache/spark/ml/regression/spark333/LinearRegression.scala | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala index 84737560c..043b59f77 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala @@ -462,7 +462,8 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) dataset.count() } - val paramSupported = ($(regParam) == 0) || ($(regParam) != 0 && $(elasticNetParam) == 0) + val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0) + && (!isDefined(weightCol) || getWeightCol.isEmpty)) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala index ef71ee0d8..8f2078d2f 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala @@ -460,7 +460,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String dataset.persist(StorageLevel.MEMORY_AND_DISK) dataset.count() } - val paramSupported = ($(regParam) == 0) || ($(regParam) != 0 && $(elasticNetParam) == 0) + val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0) + && (!isDefined(weightCol) || getWeightCol.isEmpty)) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility(