diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 037f8477b3..c2ed5b8305 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -4,12 +4,13 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.codegen.GenerationUtils -import com.microsoft.azure.synapse.ml.fabric.OpenAITokenLibrary +import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting, OpenAITokenLibrary} import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion, - HasCognitiveServiceInput, HasServiceParams} import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.services._ +import org.apache.spark.ml.PipelineModel import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ import spray.json.DefaultJsonProtocol._ import scala.language.existentials @@ -256,10 +257,21 @@ trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput { } else { providedCustomHeader } - } } -abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) { +abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) + with HasOpenAISharedParams with OpenAIFabricSetting { setDefault(timeout -> 360.0) + + private def usingDefaultOpenAIEndpoint(): Boolean = { + getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/" + } + + override protected def getInternalTransformer(schema: StructType): PipelineModel = { + if (PlatformDetails.runningOnFabric() && usingDefaultOpenAIEndpoint) { + getModelStatus(getDeploymentName) + } + super.getInternalTransformer(schema) + } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAIFabricSetting.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAIFabricSetting.scala new file mode 100644 index 0000000000..d5a1bc5900 --- /dev/null +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAIFabricSetting.scala @@ -0,0 +1,56 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.fabric + +import spray.json.{JsValue, JsString} + +trait OpenAIFabricSetting extends RESTUtils { + + private def getHeaders: Map[String, String] = { + Map( + "Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}", + "Content-Type" -> "application/json" + ) + } + + def usagePost(url: String, body: String): JsValue = { + usagePost(url, body, getHeaders); + } + + def getModelStatus(modelName: String): Boolean = { + + val payload = + s"""["${modelName}"]""" + + val mlWorkloadEndpointML = FabricClient.MLWorkloadEndpointML + val url = mlWorkloadEndpointML + "cognitive/openai/tenantsetting" + val modelStatus = usagePost(url, payload).asJsObject.fields.get(modelName.toLowerCase).get + + // Allowed, Disallowed, DisallowedForCrossGeo, ModelNotFound, InvalidResult + val resultString: String = modelStatus match { + case JsString(value) => value + case _ => throw new RuntimeException("Unexpected result from type conversion " + + "when checking the fabric tenant settings API.") + } + + resultString match { + case "Disallowed" => throw new RuntimeException(s"Default OpenAI model ${modelName} is Disallowed, " + + s"please contact your admin if you want to use default fabric LLM model. " + + s"Or you can set your Azure OpenAI credentials.") + case "DisallowedForCrossGeo" => throw new RuntimeException(s"Default OpenAI model ${modelName} is Disallowed " + + s"for Cross Geo, please contact your admin if you want to use default fabric LLM model. " + + s"Or you can set your Azure OpenAI credentials." + + s"Refer to https://learn.microsoft.com/en-us/fabric/data-science/ai-services/ai-services-overview " + + s"for more detials") + case "ModelNotFound" => throw new RuntimeException(s"Default OpenAI model ${modelName} not found, " + + s"please check your deployment name. " + + s"Refer to https://learn.microsoft.com/en-us/fabric/data-science/ai-services/ai-services-overview " + + s"for the models available.") + case "InvalidResult" => throw new RuntimeException("Cannot get tenant admin setting status correctly") + case "Allowed" => true + case _ => throw new RuntimeException("Unexpected result from checking the Fabric tenant settings API.") + } + } + +}