Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: Adding Custom Url Endpoints and Headers #2232

Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ trait HasCustomAuthHeader extends HasServiceParams {
}
}

trait HasCustomHeader extends HasServiceParams {
// scalastyle:off field.name
val CustomHeader = new ServiceParam[Map[String, String]](
this, "CustomHeader", "List of Custom Header Key-Value Tuples."
)
// scalastyle:on field.name

def setCustomHeader(v: Map[String, String]): this.type = {
setScalarParam(CustomHeader, v)
}

def getCustomHeader: Map[String, String] = getScalarParam(CustomHeader)
}

trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
def setCustomServiceName(v: String): this.type = {
setUrl(s"https://$v.cognitiveservices.azure.com/" + urlPath.stripPrefix("/"))
Expand Down Expand Up @@ -256,7 +270,15 @@ object URLEncodingUtils {
}

trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with SynapseMLLogging {
with HasCustomHeader with SynapseMLLogging {

val customUrlRoot: Param[String] = new Param[String](
this, "customUrlRoot", "The custom URL root for the service. " +
"This will not append OpenAI specific model path completions (i.e. /chat/completions) to the URL.")

def getCustomUrlRoot: String = $(customUrlRoot)

def setCustomUrlRoot(v: String): this.type = set(customUrlRoot, v)

protected def paramNameToPayloadName(p: Param[_]): String = p match {
case p: ServiceParam[_] => p.payloadName
Expand All @@ -281,7 +303,11 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
} else {
""
}
prepareUrlRoot(row) + appended
if (get(customUrlRoot).nonEmpty) {
$(customUrlRoot)
} else {
prepareUrlRoot(row) + appended
}
}
}

Expand All @@ -296,20 +322,25 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
protected def contentType: Row => String = { _ => "application/json" }

protected def getCustomAuthHeader(row: Row): Option[String] = {
val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomHeader .isEmpty && PlatformDetails.runningOnFabric()) {
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
logInfo("Using Default AAD Token On Fabric")
Option(TokenLibrary.getAuthHeader)
} else {
providedCustomHeader
providedCustomAuthHeader
}
}

protected def getCustomHeader(row: Row): Option[Map[String, String]] = {
getValueOpt(row, CustomHeader)
}

protected def addHeaders(req: HttpRequestBase,
subscriptionKey: Option[String],
aadToken: Option[String],
contentType: String = "",
customAuthHeader: Option[String] = None): Unit = {
customAuthHeader: Option[String] = None,
customHeader: Option[Map[String, String]] = None): Unit = {

if (subscriptionKey.nonEmpty) {
req.setHeader(subscriptionKeyHeaderName, subscriptionKey.get)
Expand All @@ -326,6 +357,13 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
req.setHeader("x-ms-workload-resource-moniker", UUID.randomUUID().toString)
})
}
if (customHeader.nonEmpty) {
customHeader.foreach(m => {
m.foreach {
case (headerName, headerValue) => req.setHeader(headerName, headerValue)
}
})
}
if (contentType != "") req.setHeader("Content-Type", contentType)
}

Expand All @@ -342,7 +380,8 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row))
getCustomAuthHeader(row),
getCustomHeader(row))

req match {
case er: HttpEntityEnclosingRequestBase =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ trait HasPromptInputs extends HasServiceParams {
trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = true)
this, "deploymentName", "The name of the deployment", isRequired = false)

def getDeploymentName: String = getScalarParam(deploymentName)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,29 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
assert(Option(results.apply(2).getAs[Row]("out")).isEmpty)
}

test("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")

val customEndpointCompletion = new OpenAIChatCompletion()
.setCustomUrlRoot(customRootUrlValue)
.setOutputCol("out")
.setMessagesCol("messages")
.setTemperature(0)

if (accessToken.isEmpty) {
customEndpointCompletion.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
} else {
customEndpointCompletion.setAADToken(accessToken)
.setCustomHeader(Map("X-ModelType" -> "gpt-4-turbo-chat-completions",
"X-ScenarioGUID" -> "7687c733-45b0-425b-82b3-05eb4eb70247"))
}

testCompletion(customEndpointCompletion, goodDf)
}

def testCompletion(completion: OpenAIChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = {
val fromRow = ChatCompletionResponse.makeFromRowConverter
completion.transform(df).collect().foreach(r =>
Expand Down
Loading