diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index 57837ad276..703fc7f471 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -38,7 +38,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions" } - override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { + override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { r => lazy val optionalParams: Map[String, Any] = getOptionalParams(r) val messages = r.getAs[Seq[Row]](getMessagesCol) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala index 304addaf66..953138bc36 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala @@ -37,7 +37,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid) s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions" } - override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { + override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = { r => lazy val optionalParams: Map[String, Any] = getOptionalParams(r) getValueOpt(r, prompt) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index b17b5c59c1..47f2de6204 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -9,12 +9,13 @@ import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.StringStringMapParam +import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T} import scala.collection.JavaConverters._ @@ -25,6 +26,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer with HasErrorCol with HasOutputCol with HasURL with HasCustomCogServiceDomain with ConcurrencyParams with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader + with HasOpenAICognitiveServiceInput with ComplexParamsWritable with SynapseMLLogging { logClass(FeatureNames.AiServices.OpenAI) @@ -174,6 +176,16 @@ class OpenAIPrompt(override val uid: String) extends Transformer completion } + override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { + r => + openAICompletion match { + case chatCompletion: OpenAIChatCompletion => + chatCompletion.prepareEntity(r) + case completion: OpenAICompletion => + completion.prepareEntity(r) + } + } + private def getParser: OutputParser = { val opts = getPostProcessingOptions diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 6282067b0d..773cbe7547 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -112,6 +112,33 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.get(0) != null)) } + ignore("Custom EndPoint") { + lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") + lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") + lazy val customHeadersValues: Map[String, String] = Map("X-ModelType" -> "gpt-4-turbo-chat-completions") + + lazy val customPromptGpt4: OpenAIPrompt = new OpenAIPrompt() + .setCustomUrlRoot(customRootUrlValue) + .setOutputCol("outParsed") + .setTemperature(0) + + if (accessToken.isEmpty) { + customPromptGpt4.setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentNameGpt4) + .setCustomServiceName(openAIServiceName) + } else { + customPromptGpt4.setAADToken(accessToken) + .setCustomHeaders(customHeadersValues) + } + + customPromptGpt4.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setPostProcessing("csv") + .transform(df) + .select("outParsed") + .collect() + .count(r => Option(r.getSeq[String](0)).isDefined) + } + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) }