diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala index ecc7228406..d357765789 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala @@ -3,33 +3,32 @@ package com.microsoft.azure.synapse.ml.services.language -import com.microsoft.azure.synapse.ml.services._ -import com.microsoft.azure.synapse.ml.services.text.{TADocument, TextAnalyticsAutoBatch} -import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} +import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging } import com.microsoft.azure.synapse.ml.param.ServiceParam -import com.microsoft.azure.synapse.ml.stages.{FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer} -import org.apache.http.entity.{AbstractHttpEntity, StringEntity} +import com.microsoft.azure.synapse.ml.services._ +import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch } +import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer } +import org.apache.http.entity.{ AbstractHttpEntity, StringEntity } import org.apache.spark.injections.UDFUtils -import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel } +import org.apache.spark.ml.param.{ Param, ParamValidators } import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel} import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.types.{ArrayType, DataType, StructType} -import spray.json.DefaultJsonProtocol._ +import org.apache.spark.sql.types.{ ArrayType, DataType, StructType } import spray.json._ +import spray.json.DefaultJsonProtocol._ -trait AnalyzeTextTaskParameters extends HasServiceParams { +trait HasAnalyzeTextServiceBaseParams extends HasServiceParams { val modelVersion = new ServiceParam[String]( this, name = "modelVersion", "Version of the model") + def getModelVersion: String = getScalarParam(modelVersion) def setModelVersion(v: String): this.type = setScalarParam(modelVersion, v) - + def getModelVersionCol: String = getVectorParam(modelVersion) def setModelVersionCol(v: String): this.type = setVectorParam(modelVersion, v) - def getModelVersion: String = getScalarParam(modelVersion) - def getModelVersionCol: String = getVectorParam(modelVersion) val loggingOptOut = new ServiceParam[Boolean]( this, "loggingOptOut", "loggingOptOut for task" @@ -44,13 +43,15 @@ trait AnalyzeTextTaskParameters extends HasServiceParams { def getLoggingOptOutCol: String = getVectorParam(loggingOptOut) val stringIndexType = new ServiceParam[String](this, "stringIndexType", - "Specifies the method used to interpret string offsets. " + - "Defaults to Text Elements (Graphemes) according to Unicode v8.0.0. " + - "For additional information see https://aka.ms/text-analytics-offsets", - isValid = { - case Left(s) => Set("TextElements_v8", "UnicodeCodePoint", "Utf16CodeUnit")(s) - case _ => true - }) + "Specifies the method used to interpret string offsets. " + + "Defaults to Text Elements(Graphemes) according to Unicode v8.0.0." + + "For more information see https://aka.ms/text-analytics-offsets", + isValid = { + case Left(s) => Set("TextElements_v8", + "UnicodeCodePoint", + "Utf16CodeUnit")(s) + case _ => true + }) def setStringIndexType(v: String): this.type = setScalarParam(stringIndexType, v) @@ -60,6 +61,36 @@ trait AnalyzeTextTaskParameters extends HasServiceParams { def getStringIndexTypeCol: String = getVectorParam(stringIndexType) + val showStats = new ServiceParam[Boolean]( + this, name = "showStats", "Whether to include detailed statistics in the response", + isURLParam = true) + + def setShowStats(v: Boolean): this.type = setScalarParam(showStats, v) + + def getShowStats: Boolean = getScalarParam(showStats) + + // We don't support setKindCol here because output schemas for different kind are different + val kind = new Param[String]( + this, "kind", "Enumeration of supported Text Analysis tasks", + isValid = validKinds.contains(_) + ) + + protected def validKinds: Set[String] + + def setKind(v: String): this.type = set(kind, v) + + def getKind: String = $(kind) + + setDefault( + showStats -> Left(false), + modelVersion -> Left("latest"), + loggingOptOut -> Left(false), + stringIndexType -> Left("TextElements_v8") + ) +} + + +trait AnalyzeTextTaskParameters extends HasAnalyzeTextServiceBaseParams { val opinionMining = new ServiceParam[Boolean]( this, name = "opinionMining", "opinionMining option for SentimentAnalysisTask") @@ -98,9 +129,6 @@ trait AnalyzeTextTaskParameters extends HasServiceParams { def getPiiCategoriesCol: String = getVectorParam(piiCategories) setDefault( - modelVersion -> Left("latest"), - loggingOptOut -> Left(false), - stringIndexType -> Left("TextElements_v8"), opinionMining -> Left(false), domain -> Left("none") ) @@ -131,33 +159,21 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid) def this() = this(Identifiable.randomUID("AnalyzeText")) - val showStats = new ServiceParam[Boolean]( - this, name = "showStats", "Whether to include detailed statistics in the response", - isURLParam = true) - - def setShowStats(v: Boolean): this.type = setScalarParam(showStats, v) - - def getShowStats: Boolean = getScalarParam(showStats) + override protected def validKinds: Set[String] = Set("EntityLinking", + "EntityRecognition", + "KeyPhraseExtraction", + "LanguageDetection", + "PiiEntityRecognition", + "SentimentAnalysis") setDefault( - apiVersion -> Left("2022-05-01"), - showStats -> Left(false) + apiVersion -> Left("2022-05-01") ) override def urlPath: String = "/language/:analyze-text" override private[ml] def internalServiceType: String = "textanalytics" - // We don't support setKindCol here because output schemas for different kind are different - val kind = new Param[String]( - this, "kind", "Enumeration of supported Text Analysis tasks", - isValid = ParamValidators.inArray(Array("EntityLinking", "EntityRecognition", "KeyPhraseExtraction", - "LanguageDetection", "PiiEntityRecognition", "SentimentAnalysis")) - ) - - def setKind(v: String): this.type = set(kind, v) - - def getKind: String = $(kind) override protected def shouldSkip(row: Row): Boolean = if (emptyParamData(row, text)) { true diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala new file mode 100644 index 0000000000..6c27ab2c6c --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala @@ -0,0 +1,594 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.core.schema.SparkBindings +import spray.json. RootJsonFormat + +// scalastyle:off number.of.types +case class DocumentWarning(code: String, + message: String, + targetRef: Option[String]) + +object DocumentWarning extends SparkBindings[DocumentWarning] + + +case class SummaryContext(offset: Int, + length: Int) + +object SummaryContext extends SparkBindings[SummaryContext] + +//------------------------------------------------------------------------------------------------------ +// Extractive Summarization +//------------------------------------------------------------------------------------------------------ +case class ExtractiveSummarizationTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + sentenceCount: Option[Int], + sortBy: Option[String], + stringIndexType: String) + +case class ExtractiveSummarizationLROTask(parameters: ExtractiveSummarizationTaskParameters, + taskName: Option[String], + kind: String) + +case class ExtractiveSummarizationJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[ExtractiveSummarizationLROTask]) + +case class ExtractedSummarySentence(text: String, + rankScore: Double, + offset: Int, + length: Int) + +case class ExtractedSummaryDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + sentences: Seq[ExtractedSummarySentence]) + +case class ExtractiveSummarizationResult(errors: Seq[ATError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[ExtractedSummaryDocumentResult]) + +case class ExtractiveSummarizationLROResult(results: ExtractiveSummarizationResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class ExtractiveSummarizationTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[ExtractiveSummarizationLROResult]]) + +case class ExtractiveSummarizationJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: ExtractiveSummarizationTaskResult, + statistics: Option[RequestStatistics]) + +object ExtractiveSummarizationJobState extends SparkBindings[ExtractiveSummarizationJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Abstractive Summarization +//------------------------------------------------------------------------------------------------------ +object SummaryLength extends Enumeration { + type SummaryLength = Value + val Short, Medium, Long = Value +} + +case class AbstractiveSummarizationTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + sentenceCount: Option[Int], + stringIndexType: String, + summaryLength: Option[String]) + +case class AbstractiveSummarizationLROTask(parameters: AbstractiveSummarizationTaskParameters, + taskName: Option[String], + kind: String) + +case class AbstractiveSummarizationJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[AbstractiveSummarizationLROTask]) + +case class AbstractiveSummary(text: String, + contexts: Option[Seq[SummaryContext]]) + +case class AbstractiveSummaryDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + summaries: Seq[AbstractiveSummary]) + +object AbstractiveSummaryDocumentResult extends SparkBindings[AbstractiveSummaryDocumentResult] + +case class AbstractiveSummarizationResult(errors: Seq[ATError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[AbstractiveSummaryDocumentResult]) + +case class AbstractiveSummarizationLROResult(results: AbstractiveSummarizationResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class AbstractiveSummarizationTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[AbstractiveSummarizationLROResult]]) + +case class AbstractiveSummarizationJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: AbstractiveSummarizationTaskResult, + statistics: Option[RequestStatistics]) + +object AbstractiveSummarizationJobState extends SparkBindings[AbstractiveSummarizationJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// HealthCare +//------------------------------------------------------------------------------------------------------ +case class HealthcareTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + stringIndexType: String) + +case class HealthcareLROTask(parameters: HealthcareTaskParameters, + taskName: Option[String], + kind: String) + +case class HealthcareJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[HealthcareLROTask]) + +case class HealthcareAssertion(conditionality: Option[String], + certainty: Option[String], + association: Option[String], + temporality: Option[String]) + +case class HealthcareEntitiesDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + entities: Seq[HealthcareEntity], + relations: Seq[HealthcareRelation], + fhirBundle: Option[String]) + +case class HealthcareEntity(text: String, + category: String, + subcategory: Option[String], + offset: Int, + length: Int, + confidenceScore: Double, + assertion: Option[HealthcareAssertion], + name: Option[String], + links: Option[Seq[HealthcareEntityLink]]) + +case class HealthcareEntityLink(dataSource: String, + id: String) + +case class HealthcareLROResult(results: HealthcareResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class HealthcareRelation(relationType: String, + entities: Seq[HealthcareRelationEntity], + confidenceScore: Option[Double]) + +case class HealthcareRelationEntity(ref: String, + role: String) + +case class HealthcareResult(errors: Seq[DocumentError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[HealthcareEntitiesDocumentResult]) + +case class HealthcareTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[HealthcareLROResult]]) + +case class HealthcareJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: HealthcareTaskResult, + statistics: Option[RequestStatistics]) + +object HealthcareJobState extends SparkBindings[HealthcareJobState] + +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Sentiment Analysis +//------------------------------------------------------------------------------------------------------ +case class SentimentAnalysisLROTask(parameters: SentimentAnalysisTaskParameters, + taskName: Option[String], + kind: String) + +case class SentimentAnalysisJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[SentimentAnalysisLROTask]) + +case class SentimentAnalysisLROResult(results: SentimentResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class SentimentAnalysisTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[SentimentAnalysisLROResult]]) + +case class SentimentAnalysisJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: SentimentAnalysisTaskResult, + statistics: Option[RequestStatistics]) + +object SentimentAnalysisJobState extends SparkBindings[SentimentAnalysisJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Key Phrase Extraction +//------------------------------------------------------------------------------------------------------ +case class KeyPhraseExtractionLROTask(parameters: KPnLDTaskParameters, + taskName: Option[String], + kind: String) + +case class KeyPhraseExtractionJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[KeyPhraseExtractionLROTask]) + +case class KeyPhraseExtractionLROResult(results: KeyPhraseExtractionResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class KeyPhraseExtractionTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[KeyPhraseExtractionLROResult]]) + +case class KeyPhraseExtractionJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: KeyPhraseExtractionTaskResult, + statistics: Option[RequestStatistics]) + +object KeyPhraseExtractionJobState extends SparkBindings[KeyPhraseExtractionJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// PII Entity Recognition +//------------------------------------------------------------------------------------------------------ +object PiiDomain extends Enumeration { + type PiiDomain = Value + val None, Phi = Value +} + +case class PiiEntityRecognitionLROTask(parameters: PiiTaskParameters, + taskName: Option[String], + kind: String) + +case class PiiEntityRecognitionJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[PiiEntityRecognitionLROTask]) + +case class PiiEntityRecognitionLROResult(results: PIIResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class PiiEntityRecognitionTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[PiiEntityRecognitionLROResult]]) + +case class PiiEntityRecognitionJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: PiiEntityRecognitionTaskResult, + statistics: Option[RequestStatistics]) + +object PiiEntityRecognitionJobState extends SparkBindings[PiiEntityRecognitionJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Entity Linking +//------------------------------------------------------------------------------------------------------ +case class EntityLinkingLROTask(parameters: EntityTaskParameters, + taskName: Option[String], + kind: String) + + +case class EntityLinkingJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[EntityLinkingLROTask]) + + +case class EntityLinkingLROResult(results: EntityLinkingResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class EntityLinkingTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[EntityLinkingLROResult]]) + +case class EntityLinkingJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: EntityLinkingTaskResult, + statistics: Option[RequestStatistics]) + +object EntityLinkingJobState extends SparkBindings[EntityLinkingJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Entity Recognition +//------------------------------------------------------------------------------------------------------ + +case class EntityRecognitionTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + stringIndexType: String, + inclusionList: Option[Seq[String]], + exclusionList: Option[Seq[String]], + overlapPolicy: Option[EntityOverlapPolicy], + inferenceOptions: Option[EntityInferenceOptions]) + +case class EntityOverlapPolicy(policyKind: String) + +case class EntityInferenceOptions(excludeNormalizedValues: Boolean) + +case class EntityRecognitionLROTask(parameters: EntityRecognitionTaskParameters, + taskName: Option[String], + kind: String) + +case class EntityRecognitionJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[EntityRecognitionLROTask]) + +case class EntityRecognitionLROResult(results: EntityRecognitionResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class EntityRecognitionTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[EntityRecognitionLROResult]]) + +case class EntityRecognitionJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: EntityRecognitionTaskResult, + statistics: Option[RequestStatistics]) + +object EntityRecognitionJobState extends SparkBindings[EntityRecognitionJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Custom Entity Recognoition +//------------------------------------------------------------------------------------------------------ +case class CustomEntitiesTaskParameters(loggingOptOut: Boolean, + stringIndexType: String, + deploymentName: String, + projectName: String) + +case class CustomEntityRecognitionLROTask(parameters: CustomEntitiesTaskParameters, + taskName: Option[String], + kind: String) + +case class CustomEntitiesJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[CustomEntityRecognitionLROTask]) +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Custom Label Classification +//------------------------------------------------------------------------------------------------------ +case class CustomLabelTaskParameters(loggingOptOut: Boolean, + deploymentName: String, + projectName: String) + +case class CustomLabelLROTask(parameters: CustomLabelTaskParameters, + taskName: Option[String], + kind: String) + +case class CustomLabelJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[CustomLabelLROTask]) + +case class ClassificationDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + classifications: Seq[ClassificationResult]) + +//object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult] + +case class ClassificationResult(category: String, + confidenceScore: Double) + +object ClassificationResult extends SparkBindings[ClassificationResult] + +case class CustomLabelResult(errors: Seq[DocumentError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[ClassificationDocumentResult]) + +case class CustomLabelLROResult(results: CustomLabelResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class CustomLabelTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[CustomLabelLROResult]]) + +case class CustomLabelJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: CustomLabelTaskResult, + statistics: Option[RequestStatistics]) + +object CustomLabelJobState extends SparkBindings[CustomLabelJobState] +//------------------------------------------------------------------------------------------------------ + + +object ATLROJSONFormat { + + import spray.json.DefaultJsonProtocol._ + import ATJSONFormat._ + + implicit val DocumentWarningFormat: RootJsonFormat[DocumentWarning] = + jsonFormat3(DocumentWarning.apply) + + implicit val ExtractiveSummarizationTaskParametersF: RootJsonFormat[ExtractiveSummarizationTaskParameters] = + jsonFormat5(ExtractiveSummarizationTaskParameters.apply) + + implicit val ExtractiveSummarizationLROTaskF: RootJsonFormat[ExtractiveSummarizationLROTask] = + jsonFormat3(ExtractiveSummarizationLROTask.apply) + + implicit val ExtractiveSummarizationJobsInputF: RootJsonFormat[ExtractiveSummarizationJobsInput] = + jsonFormat3(ExtractiveSummarizationJobsInput.apply) + + implicit val AbstractiveSummarizationTaskParametersF: RootJsonFormat[AbstractiveSummarizationTaskParameters] = + jsonFormat5(AbstractiveSummarizationTaskParameters.apply) + + implicit val AbstractiveSummarizationLROTaskF: RootJsonFormat[AbstractiveSummarizationLROTask] = + jsonFormat3(AbstractiveSummarizationLROTask.apply) + + implicit val AbstractiveSummarizationJobsInputF: RootJsonFormat[AbstractiveSummarizationJobsInput] = + jsonFormat3(AbstractiveSummarizationJobsInput.apply) + + implicit val HealthcareTaskParametersF: RootJsonFormat[HealthcareTaskParameters] = + jsonFormat3(HealthcareTaskParameters.apply) + + implicit val HealthcareLROTaskF: RootJsonFormat[HealthcareLROTask] = + jsonFormat3(HealthcareLROTask.apply) + + implicit val HealthcareJobsInputF: RootJsonFormat[HealthcareJobsInput] = + jsonFormat3(HealthcareJobsInput.apply) + + implicit val SentimentAnalysisLROTaskF: RootJsonFormat[SentimentAnalysisLROTask] = + jsonFormat3(SentimentAnalysisLROTask.apply) + + implicit val SentimentAnalysisJobsInputF: RootJsonFormat[SentimentAnalysisJobsInput] = + jsonFormat3(SentimentAnalysisJobsInput.apply) + + implicit val KeyPhraseExtractionLROTaskF: RootJsonFormat[KeyPhraseExtractionLROTask] = + jsonFormat3(KeyPhraseExtractionLROTask.apply) + + implicit val KeyPhraseExtractionJobsInputF: RootJsonFormat[KeyPhraseExtractionJobsInput] = + jsonFormat3(KeyPhraseExtractionJobsInput.apply) + + implicit val PiiEntityRecognitionLROTaskF: RootJsonFormat[PiiEntityRecognitionLROTask] = + jsonFormat3(PiiEntityRecognitionLROTask.apply) + + implicit val PiiEntityRecognitionJobsInputF: RootJsonFormat[PiiEntityRecognitionJobsInput] = + jsonFormat3(PiiEntityRecognitionJobsInput.apply) + + implicit val EntityLinkingLROTaskF: RootJsonFormat[EntityLinkingLROTask] = + jsonFormat3(EntityLinkingLROTask.apply) + + implicit val EntityLinkingJobsInputF: RootJsonFormat[EntityLinkingJobsInput] = + jsonFormat3(EntityLinkingJobsInput.apply) + + implicit val EntityOverlapPolicyF: RootJsonFormat[EntityOverlapPolicy] = + jsonFormat1(EntityOverlapPolicy.apply) + + implicit val EntityInferenceOptionsF: RootJsonFormat[EntityInferenceOptions] = + jsonFormat1(EntityInferenceOptions.apply) + + implicit val EntityRecognitionTaskParametersF: RootJsonFormat[EntityRecognitionTaskParameters] = + jsonFormat7(EntityRecognitionTaskParameters.apply) + + implicit val EntityRecognitionLROTaskF: RootJsonFormat[EntityRecognitionLROTask] = + jsonFormat3(EntityRecognitionLROTask.apply) + + implicit val EntityRecognitionJobsInputF: RootJsonFormat[EntityRecognitionJobsInput] = + jsonFormat3(EntityRecognitionJobsInput.apply) + + implicit val CustomEntitiesTaskParametersF: RootJsonFormat[CustomEntitiesTaskParameters] = + jsonFormat4(CustomEntitiesTaskParameters.apply) + + implicit val CustomEntityRecognitionLROTaskF: RootJsonFormat[CustomEntityRecognitionLROTask] = + jsonFormat3(CustomEntityRecognitionLROTask.apply) + + implicit val CustomEntitiesJobsInputF: RootJsonFormat[CustomEntitiesJobsInput] = + jsonFormat3(CustomEntitiesJobsInput.apply) + + implicit val CustomSingleLabelTaskParametersF: RootJsonFormat[CustomLabelTaskParameters] = + jsonFormat3(CustomLabelTaskParameters.apply) + + implicit val CustomSingleLabelLROTaskF: RootJsonFormat[CustomLabelLROTask] = + jsonFormat3(CustomLabelLROTask.apply) + + implicit val CustomSingleLabelJobsInputF: RootJsonFormat[CustomLabelJobsInput] = + jsonFormat3(CustomLabelJobsInput.apply) +} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala new file mode 100644 index 0000000000..7f33107e7b --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala @@ -0,0 +1,638 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.io.http.{ EntityData, HTTPResponseData } +import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging +import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.services.HasServiceParams +import com.microsoft.azure.synapse.ml.services.language.ATLROJSONFormat._ +import com.microsoft.azure.synapse.ml.services.language.PiiDomain.PiiDomain +import com.microsoft.azure.synapse.ml.services.language.SummaryLength.SummaryLength +import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply +import org.apache.commons.io.IOUtils +import org.apache.http.impl.client.CloseableHttpClient +import org.apache.spark.ml.param.ParamValidators +import org.apache.spark.sql.Row +import spray.json.DefaultJsonProtocol._ +import spray.json.enrichAny + +import java.net.URI + +object AnalysisTaskKind extends Enumeration { + type AnalysisTaskKind = Value + val SentimentAnalysis, + EntityRecognition, + PiiEntityRecognition, + KeyPhraseExtraction, + EntityLinking, + Healthcare, + CustomEntityRecognition, + CustomSingleLabelClassification, + CustomMultiLabelClassification, + ExtractiveSummarization, + AbstractiveSummarization = Value + + def getKindFromString(kind: String): AnalysisTaskKind = { + AnalysisTaskKind.values.find(_.toString == kind).getOrElse( + throw new IllegalArgumentException(s"Invalid kind: $kind") + ) + } +} + +private[language] trait HasSummarizationBaseParameter extends HasServiceParams { + val sentenceCount = new ServiceParam[Int]( + this, + name = "sentenceCount", + doc = "Specifies the number of sentences in the extracted summary.", + isValid = { case Left(value) => value >= 1 case Right(_) => true } + ) + + def getSentenceCount: Int = getScalarParam(sentenceCount) + + def setSentenceCount(value: Int): this.type = setScalarParam(sentenceCount, value) + + def getSentenceCountCol: String = getVectorParam(sentenceCount) + + def setSentenceCountCol(value: String): this.type = setVectorParam(sentenceCount, value) +} + +/** + * This trait is used to handle the extractive summarization request. It provides the necessary + * parameters to create the request and the method to create the request. There are two + * parameters for extractive summarization: sentenceCount and sortBy. Both of them are optional. + * If the user does not provide any value for sentenceCount, the service will return the default + * number of sentences in the summary. If the user does not provide any value for sortBy, the + * service will return the summary in the order of the sentences in the input text. The possible values + * for sortBy are "Rank" and "Offset". If the user provides an invalid value for sortBy, the service + * will return an error. SentenceCount is an integer value, and it should be greater than 0. This parameter + * specifies the number of sentences in the extracted summary. If the user provides an invalid value for + * sentenceCount, the service will return an error. For more details about the parameters, please refer to + * the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]] + */ +private[language] trait HandleExtractiveSummarization extends HasServiceParams + with HasSummarizationBaseParameter { + val sortBy = new ServiceParam[String]( + this, + name = "sortBy", + doc = "Specifies how to sort the extracted summaries. This can be either 'Rank' or 'Offset'.", + isValid = { + case Left(value) => ParamValidators.inArray(Array("Rank", "Offset"))(value) + case Right(_) => true + }) + + def getSortBy: String = getScalarParam(sortBy) + + def setSortBy(value: String): this.type = setScalarParam(sortBy, value) + + def getSortByCol: String = getVectorParam(sortBy) + + def setSortByCol(value: String): this.type = setVectorParam(sortBy, value) + + private[language] def createExtractiveSummarizationRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = ExtractiveSummarizationLROTask( + parameters = ExtractiveSummarizationTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + sentenceCount = getValueOpt(row, sentenceCount), + sortBy = getValueOpt(row, sortBy), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.ExtractiveSummarization.toString + ) + + ExtractiveSummarizationJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the abstractive summarization request. It provides the necessary + * parameters to create the request and the method to create the request. There are two + * parameters for abstractive summarization: sentenceCount and summaryLength. Both of them are optional. + * It is recommended to use summaryLength over sentenceCount. Service may ignore sentenceCount parameter. + * SummaryLength is a string value, and it should be one of "short", "medium", or "long". This parameter + * controls the approximate length of the output summaries. If the user provides an invalid value for + * summaryLength, the service will return an error. For more details about the parameters, please refer to + * the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]] + */ +private[language] trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizationBaseParameter { + val summaryLength = new ServiceParam[String]( + this, + name = "summaryLength", + doc = "(NOTE: Recommended to use summaryLength over sentenceCount) Controls the" + + " approximate length of the output summaries.", + isValid = { + case Left(value) => ParamValidators.inArray(Array("short", "medium", "long"))(value) + case Right(_) => true + } + + ) + + def getSummaryLength: String = getScalarParam(summaryLength) + + def setSummaryLength(value: String): this.type = setScalarParam(summaryLength, value) + + def setSummaryLength(value: SummaryLength): this.type = setScalarParam(summaryLength, value.toString.toLowerCase) + + def getSummaryLengthCol: String = getVectorParam(summaryLength) + + def setSummaryLengthCol(value: String): this.type = setVectorParam(summaryLength, value) + + private[language] def createAbstractiveSummarizationRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val paramerter = AbstractiveSummarizationLROTask( + parameters = AbstractiveSummarizationTaskParameters( + sentenceCount = getValueOpt(row, sentenceCount), + summaryLength = getValueOpt(row, summaryLength), + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + stringIndexType = stringIndexType), + taskName = None, + kind = AnalysisTaskKind.AbstractiveSummarization.toString + ) + AbstractiveSummarizationJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(paramerter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the healthcare text analytics request. It provides the necessary parameters + * to create the request and the method to create the request. There are three parameters for healthcare text + * analytics: modelVersion, stringIndexType, and loggingOptOut. All of them are optional. For more details about + * the parameters, please refer to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/overview]] + */ +private[language] trait HandleHealthcareTextAnalystics extends HasServiceParams { + private[language] def createHealthcareTextAnalyticsRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = HealthcareLROTask( + parameters = HealthcareTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.Healthcare.toString + ) + HealthcareJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the text analytics request. It provides the necessary parameters to create + * the request and the method to create the request. There are three parameters for text analytics: modelVersion, + * stringIndexType, and loggingOptOut. All of them are optional. For more details about the parameters, please refer + * to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/sentiment-opinion-mining/overview]] + */ +private[language] trait HandleSentimentAnalysis extends HasServiceParams { + val opinionMining = new ServiceParam[Boolean]( + this, + name = "opinionMining", + doc = "Whether to use opinion mining in the request or not." + ) + + def getOpinionMining: Boolean = getScalarParam(opinionMining) + + def setOpinionMining(value: Boolean): this.type = setScalarParam(opinionMining, value) + + def getOpinionMiningCol: String = getVectorParam(opinionMining) + + def setOpinionMiningCol(value: String): this.type = setVectorParam(opinionMining, value) + + setDefault( + opinionMining -> Left(false) + ) + + private[language] def createSentimentAnalysisRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = SentimentAnalysisLROTask( + parameters = SentimentAnalysisTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + opinionMining = getValue(row, opinionMining), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.SentimentAnalysis.toString + ) + SentimentAnalysisJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the key phrase extraction request. It provides the necessary parameters to create + * the request and the method to create the request. There are two parameters for key phrase extraction: modelVersion + * and loggingOptOut. Both of them are optional. For more details about the parameters, + * please refer to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/key-phrase-extraction/overview]] + */ +private[language] trait HandleKeyPhraseExtraction extends HasServiceParams { + private[language] def createKeyPhraseExtractionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + // This parameter is not used and only exists for compatibility + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = KeyPhraseExtractionLROTask( + parameters = KPnLDTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion + ), + taskName = None, + kind = AnalysisTaskKind.KeyPhraseExtraction.toString + ) + KeyPhraseExtractionJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +private[language] trait HandleEntityLinking extends HasServiceParams { + private[language] def createEntityLinkingRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = EntityLinkingLROTask( + parameters = EntityTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.EntityLinking.toString + ) + EntityLinkingJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the PII entity recognition request. It provides the necessary parameters to create + * the request and the method to create the request. There are three parameters for PII entity recognition: domain, + * piiCategories, and loggingOptOut. All of them are optional. For more details about the parameters, please refer to + * the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/personally-identifiable-information/overview]] + */ +private[language] trait HandlePiiEntityRecognition extends HasServiceParams { + val domain = new ServiceParam[String]( + this, + name = "domain", + doc = "The domain of the PII entity recognition request.", + isValid = { + case Left(value) => PiiDomain.values.map(_.toString.toLowerCase).contains(value) + case Right(_) => true + } + ) + + def getDomain: String = getScalarParam(domain) + + def setDomain(value: String): this.type = setScalarParam(domain, value) + + def setDomain(value: PiiDomain): this.type = setScalarParam(domain, value.toString.toLowerCase) + + def getDomainCol: String = getVectorParam(domain) + + def setDomainCol(value: String): this.type = setVectorParam(domain, value) + + val piiCategories = new ServiceParam[Seq[String]](this, "piiCategories", + "describes the PII categories to return") + + def setPiiCategories(v: Seq[String]): this.type = setScalarParam(piiCategories, v) + + def getPiiCategories: Seq[String] = getScalarParam(piiCategories) + + def setPiiCategoriesCol(v: String): this.type = setVectorParam(piiCategories, v) + + def getPiiCategoriesCol: String = getVectorParam(piiCategories) + + setDefault( + domain -> Left("none") + ) + + private[language] def createPiiEntityRecognitionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = PiiEntityRecognitionLROTask( + parameters = PiiTaskParameters( + domain = getValue(row, domain), + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + piiCategories = getValueOpt(row, piiCategories), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.PiiEntityRecognition.toString + ) + PiiEntityRecognitionJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the entity recognition request. It provides the necessary parameters to create + * the request and the method to create the request. There are five parameters for entity recognition: inclusionList, + * exclusionList, overlapPolicy, excludeNormalizedValues, and loggingOptOut. All of them are optional. For more details + * about the parameters, please refer to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/named-entity-recognition/overview]] + */ +private[language] trait HandleEntityRecognition extends HasServiceParams { + val inclusionList = new ServiceParam[Seq[String]]( + this, + name = "inclusionList", + doc = "(Optional) request parameter that limits the output to the requested entity" + + " types included in this list. We will apply inclusionList before" + + " exclusionList" + ) + + def getInclusionList: Seq[String] = getScalarParam(inclusionList) + + def setInclusionList(value: Seq[String]): this.type = setScalarParam(inclusionList, value) + + def getInclusionListCol: String = getVectorParam(inclusionList) + + def setInclusionListCol(value: String): this.type = setVectorParam(inclusionList, value) + + val exclusionList = new ServiceParam[Seq[String]]( + this, + name = "exclusionList", + doc = "(Optional) request parameter that filters out any entities that are" + + " included the excludeList. When a user specifies an excludeList, they cannot" + + " get a prediction returned with an entity in that list. We will apply" + + " inclusionList before exclusionList" + ) + + def getExclusionList: Seq[String] = getScalarParam(exclusionList) + + def setExclusionList(value: Seq[String]): this.type = setScalarParam(exclusionList, value) + + def getExclusionListCol: String = getVectorParam(exclusionList) + + def setExclusionListCol(value: String): this.type = setVectorParam(exclusionList, value) + + val overlapPolicy = new ServiceParam[String]( + this, + name = "overlapPolicy", + doc = "(Optional) describes the type of overlap policy to apply to the ner output.", + isValid = { + case Left(value) => value == "matchLongest" || value == "allowOverlap" + case Right(_) => true + }) + + def getOverlapPolicy: String = getScalarParam(overlapPolicy) + + def setOverlapPolicy(value: String): this.type = setScalarParam(overlapPolicy, value) + + def getOverlapPolicyCol: String = getVectorParam(overlapPolicy) + + def setOverlapPolicyCol(value: String): this.type = setVectorParam(overlapPolicy, value) + + val excludeNormalizedValues = new ServiceParam[Boolean]( + this, + name = "excludeNormalizedValues", + doc = "(Optional) request parameter that allows the user to provide settings for" + + " running the inference. If set to true, the service will exclude normalized" + ) + + def getExcludeNormalizedValues: Boolean = getScalarParam(excludeNormalizedValues) + + def setExcludeNormalizedValues(value: Boolean): this.type = setScalarParam(excludeNormalizedValues, value) + + def getExcludeNormalizedValuesCol: String = getVectorParam(excludeNormalizedValues) + + def setexcludeNormalizedValuesCol(value: String): this.type = setVectorParam(excludeNormalizedValues, value) + + private[language] def createEntityRecognitionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val serviceOverlapPolicy: Option[EntityOverlapPolicy] = getValueOpt(row, overlapPolicy) match { + case Some(policy) => Some(EntityOverlapPolicy(policy)) + case None => None + } + + val inferenceOptions: Option[EntityInferenceOptions] = getValueOpt(row, excludeNormalizedValues) match { + case Some(value) => Some(EntityInferenceOptions(value)) + case None => None + } + val taskParameter = EntityRecognitionLROTask( + parameters = EntityRecognitionTaskParameters( + exclusionList = getValueOpt(row, exclusionList), + inclusionList = getValueOpt(row, inclusionList), + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + overlapPolicy = serviceOverlapPolicy, + stringIndexType = stringIndexType, + inferenceOptions = inferenceOptions + ), + taskName = None, + kind = AnalysisTaskKind.EntityRecognition.toString + ) + EntityRecognitionJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +private[language] trait HasCustomLanguageModelParam extends HasServiceParams { + val projectName = new ServiceParam[String]( + this, + name = "projectName", + doc = "This field indicates the project name for the model. This is a required field" + ) + + def getProjectName: String = getScalarParam(projectName) + + def setProjectName(value: String): this.type = setScalarParam(projectName, value) + + def getProjectNameCol: String = getVectorParam(projectName) + + def setProjectNameCol(value: String): this.type = setVectorParam(projectName, value) + + val deploymentName = new ServiceParam[String]( + this, + name = "deploymentName", + doc = "This field indicates the deployment name for the model. This is a required field." + ) + + def getDeploymentName: String = getScalarParam(deploymentName) + + def setDeploymentName(value: String): this.type = setScalarParam(deploymentName, value) + + def getDeploymentNameCol: String = getVectorParam(deploymentName) + + def setDeploymentNameCol(value: String): this.type = setVectorParam(deploymentName, value) +} + +private[language] trait HandleCustomEntityRecognition extends HasServiceParams + with HasCustomLanguageModelParam { + + private[language] def createCustomEntityRecognitionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + // This paremeter is not used and only exists for compatibility + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = CustomEntityRecognitionLROTask( + parameters = CustomEntitiesTaskParameters( + loggingOptOut = loggingOptOut, + projectName = getValue(row, projectName), + deploymentName = getValue(row, deploymentName), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.CustomEntityRecognition.toString) + CustomEntitiesJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * Trait `ModifiableAsyncReply` extends `BasicAsyncReply` and provides a mechanism to modify the HTTP response + * received from an asynchronous service call. This trait is designed to be mixed into classes that require + * custom handling of the response data. + * + * The primary purpose of this trait is to allow modification of the response before it is processed further. + * This is particularly useful in scenarios where the response needs to be transformed or certain fields need + * to be renamed to comply with specific requirements or constraints. + * + * In this implementation, the `queryForResult` method is overridden and marked as `final` to prevent further + * overriding. This ensures that the response modification logic is consistently applied across all subclasses. + * + * @note This trait is designed to be used with the `SynapseMLLogging` trait for consistent logging. + */ +trait ModifiableAsyncReply extends BasicAsyncReply { + self: SynapseMLLogging => + + protected def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = response + + /** + * Queries for the result of an asynchronous service call and applies the response modification logic. + */ + override final protected def queryForResult(key: Option[String], + client: CloseableHttpClient, + location: URI): Option[HTTPResponseData] = { + val originalResponse = super.queryForResult(key, client, location) + logDebug(s"Original response: $originalResponse") + modifyResponse(originalResponse) + } +} + + +/** + * Trait `HandleCustomLabelClassification` extends `HasServiceParams` and `HasCustomLanguageModelParam` to handle + * custom label classification tasks. This trait provides the necessary methods to create requests for custom + * multi-label classification and to modify the response to comply with specific requirements. + * + * The primary purpose of this trait is to address the limitation in Spark where fields named "class" cannot be + * directly bound. To work around this limitation, the response is modified to rename the "class" field to + * "classifications". + * + * This trait is designed to be mixed into classes that require custom label classification functionality and + * response modification logic. + * + * @note This trait is designed to be used with the `ModifiableAsyncReply` and `SynapseMLLogging` traits for + * consistent response handling and logging. + */ +private[language] trait HandleCustomLabelClassification extends HasServiceParams + with HasCustomLanguageModelParam { + self: ModifiableAsyncReply + with SynapseMLLogging => + + private def isCustomLabelClassification: Boolean = { + val kind = getKind + kind == AnalysisTaskKind.CustomSingleLabelClassification.toString || + kind == AnalysisTaskKind.CustomMultiLabelClassification.toString + } + + /** + * Modifies the entity in the HTTP response to rename the "class" field to "classifications". + * + * @param response The original HTTP response. + * @return The modified HTTP response with the "class" field renamed to "classifications". + */ + private def modifyEntity(response: HTTPResponseData): HTTPResponseData = { + val modifiedEntity = response.entity.flatMap { entity => + val strEntity = IOUtils.toString(entity.content, "UTF-8") + val modifiedEntity = strEntity.replace("\"class\":", "\"classifications\":") + logDebug(s"Original entity: $strEntity\t Modified entity: $modifiedEntity") + Some(new EntityData( + content = modifiedEntity.getBytes, + contentEncoding = entity.contentEncoding, + contentLength = Some(strEntity.length), + contentType = entity.contentType, + isChunked = entity.isChunked, + isRepeatable = entity.isRepeatable, + isStreaming = entity.isStreaming + )) + } + new HTTPResponseData(response.headers, modifiedEntity, response.statusLine, response.locale) + } + + /** + * Modifies the HTTP response if the task kind is custom label classification. + */ + override def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = { + if (!isCustomLabelClassification) { + logDebug(s"Kind is not CustomSingleLabelClassification or CustomMultiLabelClassification. Kind: $getKind") + response + } else { + response.map(modifyEntity) + } + } + + + def getKind: String + + private[language] def createCustomMultiLabelRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + // This paremeter is not used and only exists for compatibility + modelVersion: String, + // This paremeter is not used and only exists for compatibility + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = CustomLabelLROTask( + parameters = CustomLabelTaskParameters( + loggingOptOut = loggingOptOut, + projectName = getValue(row, projectName), + deploymentName = getValue(row, deploymentName) + ), + taskName = None, + kind = getKind + ) + CustomLabelJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala new file mode 100644 index 0000000000..cc5d0f1ed4 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala @@ -0,0 +1,243 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging } +import com.microsoft.azure.synapse.ml.services.{ CognitiveServicesBaseNoHandler, HasAPIVersion, + HasCognitiveServiceInput, HasInternalJsonOutputParser, HasSetLocation } +import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch } +import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply +import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer } +import org.apache.http.entity.{ AbstractHttpEntity, StringEntity } +import org.apache.spark.injections.UDFUtils +import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel } +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ ArrayType, DataType, StructType } +import org.apache.spark.sql.Row +import org.apache.spark.sql.expressions.UserDefinedFunction + +import java.net.URI + +object AnalyzeTextLongRunningOperations extends ComplexParamsReadable[AnalyzeTextLongRunningOperations] + with Serializable + +/** + *
+ * This transformer is used to analyze text using the Azure AI Language service. It uses AI service asynchronously. + * For more details please visit + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/concepts/use-asynchronously]] + * For each row, it submits a job to the service and polls the service until the job is complete. Delay between + * polling requests is controlled by the [[pollingDelay]] parameter, which is set to 1000 milliseconds by default. + * Number of polling attempts is controlled by the [[maxPollingRetries]] parameter, which is set to 1000 by default. + *
+ *+ * This transformer will use the field specified as TextCol to submit the text to the service. The response from the + * service will be stored in the field specified as OutputCol. The response will be a struct with the + * following fields: + *
+ * This transformer support only single task per row. The task to be performed is specified by the [[kind]] parameter. + * The supported tasks are: + *