diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala index ecc7228406..d357765789 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala @@ -3,33 +3,32 @@ package com.microsoft.azure.synapse.ml.services.language -import com.microsoft.azure.synapse.ml.services._ -import com.microsoft.azure.synapse.ml.services.text.{TADocument, TextAnalyticsAutoBatch} -import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} +import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging } import com.microsoft.azure.synapse.ml.param.ServiceParam -import com.microsoft.azure.synapse.ml.stages.{FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer} -import org.apache.http.entity.{AbstractHttpEntity, StringEntity} +import com.microsoft.azure.synapse.ml.services._ +import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch } +import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer } +import org.apache.http.entity.{ AbstractHttpEntity, StringEntity } import org.apache.spark.injections.UDFUtils -import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel } +import org.apache.spark.ml.param.{ Param, ParamValidators } import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel} import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.types.{ArrayType, DataType, StructType} -import spray.json.DefaultJsonProtocol._ +import 
org.apache.spark.sql.types.{ ArrayType, DataType, StructType } import spray.json._ +import spray.json.DefaultJsonProtocol._ -trait AnalyzeTextTaskParameters extends HasServiceParams { +trait HasAnalyzeTextServiceBaseParams extends HasServiceParams { val modelVersion = new ServiceParam[String]( this, name = "modelVersion", "Version of the model") + def getModelVersion: String = getScalarParam(modelVersion) def setModelVersion(v: String): this.type = setScalarParam(modelVersion, v) - + def getModelVersionCol: String = getVectorParam(modelVersion) def setModelVersionCol(v: String): this.type = setVectorParam(modelVersion, v) - def getModelVersion: String = getScalarParam(modelVersion) - def getModelVersionCol: String = getVectorParam(modelVersion) val loggingOptOut = new ServiceParam[Boolean]( this, "loggingOptOut", "loggingOptOut for task" @@ -44,13 +43,15 @@ trait AnalyzeTextTaskParameters extends HasServiceParams { def getLoggingOptOutCol: String = getVectorParam(loggingOptOut) val stringIndexType = new ServiceParam[String](this, "stringIndexType", - "Specifies the method used to interpret string offsets. " + - "Defaults to Text Elements (Graphemes) according to Unicode v8.0.0. " + - "For additional information see https://aka.ms/text-analytics-offsets", - isValid = { - case Left(s) => Set("TextElements_v8", "UnicodeCodePoint", "Utf16CodeUnit")(s) - case _ => true - }) + "Specifies the method used to interpret string offsets. " + + "Defaults to Text Elements(Graphemes) according to Unicode v8.0.0." 
+ + "For more information see https://aka.ms/text-analytics-offsets", + isValid = { + case Left(s) => Set("TextElements_v8", + "UnicodeCodePoint", + "Utf16CodeUnit")(s) + case _ => true + }) def setStringIndexType(v: String): this.type = setScalarParam(stringIndexType, v) @@ -60,6 +61,36 @@ trait AnalyzeTextTaskParameters extends HasServiceParams { def getStringIndexTypeCol: String = getVectorParam(stringIndexType) + val showStats = new ServiceParam[Boolean]( + this, name = "showStats", "Whether to include detailed statistics in the response", + isURLParam = true) + + def setShowStats(v: Boolean): this.type = setScalarParam(showStats, v) + + def getShowStats: Boolean = getScalarParam(showStats) + + // We don't support setKindCol here because output schemas for different kind are different + val kind = new Param[String]( + this, "kind", "Enumeration of supported Text Analysis tasks", + isValid = validKinds.contains(_) + ) + + protected def validKinds: Set[String] + + def setKind(v: String): this.type = set(kind, v) + + def getKind: String = $(kind) + + setDefault( + showStats -> Left(false), + modelVersion -> Left("latest"), + loggingOptOut -> Left(false), + stringIndexType -> Left("TextElements_v8") + ) +} + + +trait AnalyzeTextTaskParameters extends HasAnalyzeTextServiceBaseParams { val opinionMining = new ServiceParam[Boolean]( this, name = "opinionMining", "opinionMining option for SentimentAnalysisTask") @@ -98,9 +129,6 @@ trait AnalyzeTextTaskParameters extends HasServiceParams { def getPiiCategoriesCol: String = getVectorParam(piiCategories) setDefault( - modelVersion -> Left("latest"), - loggingOptOut -> Left(false), - stringIndexType -> Left("TextElements_v8"), opinionMining -> Left(false), domain -> Left("none") ) @@ -131,33 +159,21 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid) def this() = this(Identifiable.randomUID("AnalyzeText")) - val showStats = new ServiceParam[Boolean]( - this, name = "showStats", "Whether to 
include detailed statistics in the response", - isURLParam = true) - - def setShowStats(v: Boolean): this.type = setScalarParam(showStats, v) - - def getShowStats: Boolean = getScalarParam(showStats) + override protected def validKinds: Set[String] = Set("EntityLinking", + "EntityRecognition", + "KeyPhraseExtraction", + "LanguageDetection", + "PiiEntityRecognition", + "SentimentAnalysis") setDefault( - apiVersion -> Left("2022-05-01"), - showStats -> Left(false) + apiVersion -> Left("2022-05-01") ) override def urlPath: String = "/language/:analyze-text" override private[ml] def internalServiceType: String = "textanalytics" - // We don't support setKindCol here because output schemas for different kind are different - val kind = new Param[String]( - this, "kind", "Enumeration of supported Text Analysis tasks", - isValid = ParamValidators.inArray(Array("EntityLinking", "EntityRecognition", "KeyPhraseExtraction", - "LanguageDetection", "PiiEntityRecognition", "SentimentAnalysis")) - ) - - def setKind(v: String): this.type = set(kind, v) - - def getKind: String = $(kind) override protected def shouldSkip(row: Row): Boolean = if (emptyParamData(row, text)) { true diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala new file mode 100644 index 0000000000..6c27ab2c6c --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala @@ -0,0 +1,594 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.core.schema.SparkBindings +import spray.json. 
RootJsonFormat + +// scalastyle:off number.of.types +case class DocumentWarning(code: String, + message: String, + targetRef: Option[String]) + +object DocumentWarning extends SparkBindings[DocumentWarning] + + +case class SummaryContext(offset: Int, + length: Int) + +object SummaryContext extends SparkBindings[SummaryContext] + +//------------------------------------------------------------------------------------------------------ +// Extractive Summarization +//------------------------------------------------------------------------------------------------------ +case class ExtractiveSummarizationTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + sentenceCount: Option[Int], + sortBy: Option[String], + stringIndexType: String) + +case class ExtractiveSummarizationLROTask(parameters: ExtractiveSummarizationTaskParameters, + taskName: Option[String], + kind: String) + +case class ExtractiveSummarizationJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[ExtractiveSummarizationLROTask]) + +case class ExtractedSummarySentence(text: String, + rankScore: Double, + offset: Int, + length: Int) + +case class ExtractedSummaryDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + sentences: Seq[ExtractedSummarySentence]) + +case class ExtractiveSummarizationResult(errors: Seq[ATError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[ExtractedSummaryDocumentResult]) + +case class ExtractiveSummarizationLROResult(results: ExtractiveSummarizationResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class ExtractiveSummarizationTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[ExtractiveSummarizationLROResult]]) + +case class ExtractiveSummarizationJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: 
Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: ExtractiveSummarizationTaskResult, + statistics: Option[RequestStatistics]) + +object ExtractiveSummarizationJobState extends SparkBindings[ExtractiveSummarizationJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Abstractive Summarization +//------------------------------------------------------------------------------------------------------ +object SummaryLength extends Enumeration { + type SummaryLength = Value + val Short, Medium, Long = Value +} + +case class AbstractiveSummarizationTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + sentenceCount: Option[Int], + stringIndexType: String, + summaryLength: Option[String]) + +case class AbstractiveSummarizationLROTask(parameters: AbstractiveSummarizationTaskParameters, + taskName: Option[String], + kind: String) + +case class AbstractiveSummarizationJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[AbstractiveSummarizationLROTask]) + +case class AbstractiveSummary(text: String, + contexts: Option[Seq[SummaryContext]]) + +case class AbstractiveSummaryDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + summaries: Seq[AbstractiveSummary]) + +object AbstractiveSummaryDocumentResult extends SparkBindings[AbstractiveSummaryDocumentResult] + +case class AbstractiveSummarizationResult(errors: Seq[ATError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[AbstractiveSummaryDocumentResult]) + +case class AbstractiveSummarizationLROResult(results: AbstractiveSummarizationResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + 
kind: String) + +case class AbstractiveSummarizationTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[AbstractiveSummarizationLROResult]]) + +case class AbstractiveSummarizationJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: AbstractiveSummarizationTaskResult, + statistics: Option[RequestStatistics]) + +object AbstractiveSummarizationJobState extends SparkBindings[AbstractiveSummarizationJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// HealthCare +//------------------------------------------------------------------------------------------------------ +case class HealthcareTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + stringIndexType: String) + +case class HealthcareLROTask(parameters: HealthcareTaskParameters, + taskName: Option[String], + kind: String) + +case class HealthcareJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[HealthcareLROTask]) + +case class HealthcareAssertion(conditionality: Option[String], + certainty: Option[String], + association: Option[String], + temporality: Option[String]) + +case class HealthcareEntitiesDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + entities: Seq[HealthcareEntity], + relations: Seq[HealthcareRelation], + fhirBundle: Option[String]) + +case class HealthcareEntity(text: String, + category: String, + subcategory: Option[String], + offset: Int, + length: Int, + confidenceScore: Double, + assertion: Option[HealthcareAssertion], + name: Option[String], + links: Option[Seq[HealthcareEntityLink]]) + 
+case class HealthcareEntityLink(dataSource: String, + id: String) + +case class HealthcareLROResult(results: HealthcareResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class HealthcareRelation(relationType: String, + entities: Seq[HealthcareRelationEntity], + confidenceScore: Option[Double]) + +case class HealthcareRelationEntity(ref: String, + role: String) + +case class HealthcareResult(errors: Seq[DocumentError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[HealthcareEntitiesDocumentResult]) + +case class HealthcareTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[HealthcareLROResult]]) + +case class HealthcareJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: HealthcareTaskResult, + statistics: Option[RequestStatistics]) + +object HealthcareJobState extends SparkBindings[HealthcareJobState] + +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Sentiment Analysis +//------------------------------------------------------------------------------------------------------ +case class SentimentAnalysisLROTask(parameters: SentimentAnalysisTaskParameters, + taskName: Option[String], + kind: String) + +case class SentimentAnalysisJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[SentimentAnalysisLROTask]) + +case class SentimentAnalysisLROResult(results: SentimentResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class SentimentAnalysisTaskResult(completed: Int, + failed: Int, + 
inProgress: Int, + total: Int, + items: Option[Seq[SentimentAnalysisLROResult]]) + +case class SentimentAnalysisJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: SentimentAnalysisTaskResult, + statistics: Option[RequestStatistics]) + +object SentimentAnalysisJobState extends SparkBindings[SentimentAnalysisJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Key Phrase Extraction +//------------------------------------------------------------------------------------------------------ +case class KeyPhraseExtractionLROTask(parameters: KPnLDTaskParameters, + taskName: Option[String], + kind: String) + +case class KeyPhraseExtractionJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[KeyPhraseExtractionLROTask]) + +case class KeyPhraseExtractionLROResult(results: KeyPhraseExtractionResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class KeyPhraseExtractionTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[KeyPhraseExtractionLROResult]]) + +case class KeyPhraseExtractionJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: KeyPhraseExtractionTaskResult, + statistics: Option[RequestStatistics]) + +object KeyPhraseExtractionJobState extends SparkBindings[KeyPhraseExtractionJobState] +//------------------------------------------------------------------------------------------------------ + 
+//------------------------------------------------------------------------------------------------------ +// PII Entity Recognition +//------------------------------------------------------------------------------------------------------ +object PiiDomain extends Enumeration { + type PiiDomain = Value + val None, Phi = Value +} + +case class PiiEntityRecognitionLROTask(parameters: PiiTaskParameters, + taskName: Option[String], + kind: String) + +case class PiiEntityRecognitionJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[PiiEntityRecognitionLROTask]) + +case class PiiEntityRecognitionLROResult(results: PIIResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class PiiEntityRecognitionTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[PiiEntityRecognitionLROResult]]) + +case class PiiEntityRecognitionJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: PiiEntityRecognitionTaskResult, + statistics: Option[RequestStatistics]) + +object PiiEntityRecognitionJobState extends SparkBindings[PiiEntityRecognitionJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Entity Linking +//------------------------------------------------------------------------------------------------------ +case class EntityLinkingLROTask(parameters: EntityTaskParameters, + taskName: Option[String], + kind: String) + + +case class EntityLinkingJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[EntityLinkingLROTask]) + + +case class 
EntityLinkingLROResult(results: EntityLinkingResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class EntityLinkingTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[EntityLinkingLROResult]]) + +case class EntityLinkingJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: EntityLinkingTaskResult, + statistics: Option[RequestStatistics]) + +object EntityLinkingJobState extends SparkBindings[EntityLinkingJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Entity Recognition +//------------------------------------------------------------------------------------------------------ + +case class EntityRecognitionTaskParameters(loggingOptOut: Boolean, + modelVersion: String, + stringIndexType: String, + inclusionList: Option[Seq[String]], + exclusionList: Option[Seq[String]], + overlapPolicy: Option[EntityOverlapPolicy], + inferenceOptions: Option[EntityInferenceOptions]) + +case class EntityOverlapPolicy(policyKind: String) + +case class EntityInferenceOptions(excludeNormalizedValues: Boolean) + +case class EntityRecognitionLROTask(parameters: EntityRecognitionTaskParameters, + taskName: Option[String], + kind: String) + +case class EntityRecognitionJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[EntityRecognitionLROTask]) + +case class EntityRecognitionLROResult(results: EntityRecognitionResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class EntityRecognitionTaskResult(completed: Int, + failed: Int, + inProgress: 
Int, + total: Int, + items: Option[Seq[EntityRecognitionLROResult]]) + +case class EntityRecognitionJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: EntityRecognitionTaskResult, + statistics: Option[RequestStatistics]) + +object EntityRecognitionJobState extends SparkBindings[EntityRecognitionJobState] +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Custom Entity Recognition +//------------------------------------------------------------------------------------------------------ +case class CustomEntitiesTaskParameters(loggingOptOut: Boolean, + stringIndexType: String, + deploymentName: String, + projectName: String) + +case class CustomEntityRecognitionLROTask(parameters: CustomEntitiesTaskParameters, + taskName: Option[String], + kind: String) + +case class CustomEntitiesJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[CustomEntityRecognitionLROTask]) +//------------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------------ +// Custom Label Classification +//------------------------------------------------------------------------------------------------------ +case class CustomLabelTaskParameters(loggingOptOut: Boolean, + deploymentName: String, + projectName: String) + +case class CustomLabelLROTask(parameters: CustomLabelTaskParameters, + taskName: Option[String], + kind: String) + +case class CustomLabelJobsInput(displayName: Option[String], + analysisInput: MultiLanguageAnalysisInput, + tasks: Seq[CustomLabelLROTask]) + 
+case class ClassificationDocumentResult(id: String, + warnings: Seq[DocumentWarning], + statistics: Option[RequestStatistics], + classifications: Seq[ClassificationResult]) + +//object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult] + +case class ClassificationResult(category: String, + confidenceScore: Double) + +object ClassificationResult extends SparkBindings[ClassificationResult] + +case class CustomLabelResult(errors: Seq[DocumentError], + statistics: Option[RequestStatistics], + modelVersion: String, + documents: Seq[ClassificationDocumentResult]) + +case class CustomLabelLROResult(results: CustomLabelResult, + lastUpdateDateTime: String, + status: String, + taskName: Option[String], + kind: String) + +case class CustomLabelTaskResult(completed: Int, + failed: Int, + inProgress: Int, + total: Int, + items: Option[Seq[CustomLabelLROResult]]) + +case class CustomLabelJobState(displayName: Option[String], + createdDateTime: String, + expirationDateTime: Option[String], + jobId: String, + lastUpdatedDateTime: String, + status: String, + errors: Option[Seq[String]], + nextLink: Option[String], + tasks: CustomLabelTaskResult, + statistics: Option[RequestStatistics]) + +object CustomLabelJobState extends SparkBindings[CustomLabelJobState] +//------------------------------------------------------------------------------------------------------ + + +object ATLROJSONFormat { + + import spray.json.DefaultJsonProtocol._ + import ATJSONFormat._ + + implicit val DocumentWarningFormat: RootJsonFormat[DocumentWarning] = + jsonFormat3(DocumentWarning.apply) + + implicit val ExtractiveSummarizationTaskParametersF: RootJsonFormat[ExtractiveSummarizationTaskParameters] = + jsonFormat5(ExtractiveSummarizationTaskParameters.apply) + + implicit val ExtractiveSummarizationLROTaskF: RootJsonFormat[ExtractiveSummarizationLROTask] = + jsonFormat3(ExtractiveSummarizationLROTask.apply) + + implicit val ExtractiveSummarizationJobsInputF: 
RootJsonFormat[ExtractiveSummarizationJobsInput] = + jsonFormat3(ExtractiveSummarizationJobsInput.apply) + + implicit val AbstractiveSummarizationTaskParametersF: RootJsonFormat[AbstractiveSummarizationTaskParameters] = + jsonFormat5(AbstractiveSummarizationTaskParameters.apply) + + implicit val AbstractiveSummarizationLROTaskF: RootJsonFormat[AbstractiveSummarizationLROTask] = + jsonFormat3(AbstractiveSummarizationLROTask.apply) + + implicit val AbstractiveSummarizationJobsInputF: RootJsonFormat[AbstractiveSummarizationJobsInput] = + jsonFormat3(AbstractiveSummarizationJobsInput.apply) + + implicit val HealthcareTaskParametersF: RootJsonFormat[HealthcareTaskParameters] = + jsonFormat3(HealthcareTaskParameters.apply) + + implicit val HealthcareLROTaskF: RootJsonFormat[HealthcareLROTask] = + jsonFormat3(HealthcareLROTask.apply) + + implicit val HealthcareJobsInputF: RootJsonFormat[HealthcareJobsInput] = + jsonFormat3(HealthcareJobsInput.apply) + + implicit val SentimentAnalysisLROTaskF: RootJsonFormat[SentimentAnalysisLROTask] = + jsonFormat3(SentimentAnalysisLROTask.apply) + + implicit val SentimentAnalysisJobsInputF: RootJsonFormat[SentimentAnalysisJobsInput] = + jsonFormat3(SentimentAnalysisJobsInput.apply) + + implicit val KeyPhraseExtractionLROTaskF: RootJsonFormat[KeyPhraseExtractionLROTask] = + jsonFormat3(KeyPhraseExtractionLROTask.apply) + + implicit val KeyPhraseExtractionJobsInputF: RootJsonFormat[KeyPhraseExtractionJobsInput] = + jsonFormat3(KeyPhraseExtractionJobsInput.apply) + + implicit val PiiEntityRecognitionLROTaskF: RootJsonFormat[PiiEntityRecognitionLROTask] = + jsonFormat3(PiiEntityRecognitionLROTask.apply) + + implicit val PiiEntityRecognitionJobsInputF: RootJsonFormat[PiiEntityRecognitionJobsInput] = + jsonFormat3(PiiEntityRecognitionJobsInput.apply) + + implicit val EntityLinkingLROTaskF: RootJsonFormat[EntityLinkingLROTask] = + jsonFormat3(EntityLinkingLROTask.apply) + + implicit val EntityLinkingJobsInputF: 
RootJsonFormat[EntityLinkingJobsInput] = + jsonFormat3(EntityLinkingJobsInput.apply) + + implicit val EntityOverlapPolicyF: RootJsonFormat[EntityOverlapPolicy] = + jsonFormat1(EntityOverlapPolicy.apply) + + implicit val EntityInferenceOptionsF: RootJsonFormat[EntityInferenceOptions] = + jsonFormat1(EntityInferenceOptions.apply) + + implicit val EntityRecognitionTaskParametersF: RootJsonFormat[EntityRecognitionTaskParameters] = + jsonFormat7(EntityRecognitionTaskParameters.apply) + + implicit val EntityRecognitionLROTaskF: RootJsonFormat[EntityRecognitionLROTask] = + jsonFormat3(EntityRecognitionLROTask.apply) + + implicit val EntityRecognitionJobsInputF: RootJsonFormat[EntityRecognitionJobsInput] = + jsonFormat3(EntityRecognitionJobsInput.apply) + + implicit val CustomEntitiesTaskParametersF: RootJsonFormat[CustomEntitiesTaskParameters] = + jsonFormat4(CustomEntitiesTaskParameters.apply) + + implicit val CustomEntityRecognitionLROTaskF: RootJsonFormat[CustomEntityRecognitionLROTask] = + jsonFormat3(CustomEntityRecognitionLROTask.apply) + + implicit val CustomEntitiesJobsInputF: RootJsonFormat[CustomEntitiesJobsInput] = + jsonFormat3(CustomEntitiesJobsInput.apply) + + implicit val CustomSingleLabelTaskParametersF: RootJsonFormat[CustomLabelTaskParameters] = + jsonFormat3(CustomLabelTaskParameters.apply) + + implicit val CustomSingleLabelLROTaskF: RootJsonFormat[CustomLabelLROTask] = + jsonFormat3(CustomLabelLROTask.apply) + + implicit val CustomSingleLabelJobsInputF: RootJsonFormat[CustomLabelJobsInput] = + jsonFormat3(CustomLabelJobsInput.apply) +} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala new file mode 100644 index 0000000000..7f33107e7b --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala @@ -0,0 +1,638 @@ +// Copyright (C) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.io.http.{ EntityData, HTTPResponseData } +import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging +import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.services.HasServiceParams +import com.microsoft.azure.synapse.ml.services.language.ATLROJSONFormat._ +import com.microsoft.azure.synapse.ml.services.language.PiiDomain.PiiDomain +import com.microsoft.azure.synapse.ml.services.language.SummaryLength.SummaryLength +import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply +import org.apache.commons.io.IOUtils +import org.apache.http.impl.client.CloseableHttpClient +import org.apache.spark.ml.param.ParamValidators +import org.apache.spark.sql.Row +import spray.json.DefaultJsonProtocol._ +import spray.json.enrichAny + +import java.net.URI + +object AnalysisTaskKind extends Enumeration { + type AnalysisTaskKind = Value + val SentimentAnalysis, + EntityRecognition, + PiiEntityRecognition, + KeyPhraseExtraction, + EntityLinking, + Healthcare, + CustomEntityRecognition, + CustomSingleLabelClassification, + CustomMultiLabelClassification, + ExtractiveSummarization, + AbstractiveSummarization = Value + + def getKindFromString(kind: String): AnalysisTaskKind = { + AnalysisTaskKind.values.find(_.toString == kind).getOrElse( + throw new IllegalArgumentException(s"Invalid kind: $kind") + ) + } +} + +private[language] trait HasSummarizationBaseParameter extends HasServiceParams { + val sentenceCount = new ServiceParam[Int]( + this, + name = "sentenceCount", + doc = "Specifies the number of sentences in the extracted summary.", + isValid = { case Left(value) => value >= 1 case Right(_) => true } + ) + + def getSentenceCount: Int = getScalarParam(sentenceCount) + + def setSentenceCount(value: Int): 
this.type = setScalarParam(sentenceCount, value) + + def getSentenceCountCol: String = getVectorParam(sentenceCount) + + def setSentenceCountCol(value: String): this.type = setVectorParam(sentenceCount, value) +} + +/** + * This trait is used to handle the extractive summarization request. It provides the necessary + * parameters to create the request and the method to create the request. There are two + * parameters for extractive summarization: sentenceCount and sortBy. Both of them are optional. + * If the user does not provide any value for sentenceCount, the service will return the default + * number of sentences in the summary. If the user does not provide any value for sortBy, the + * service will return the summary in the order of the sentences in the input text. The possible values + * for sortBy are "Rank" and "Offset". If the user provides an invalid value for sortBy, the service + * will return an error. SentenceCount is an integer value, and it should be greater than 0. This parameter + * specifies the number of sentences in the extracted summary. If the user provides an invalid value for + * sentenceCount, the service will return an error. For more details about the parameters, please refer to + * the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]] + */ +private[language] trait HandleExtractiveSummarization extends HasServiceParams + with HasSummarizationBaseParameter { + val sortBy = new ServiceParam[String]( + this, + name = "sortBy", + doc = "Specifies how to sort the extracted summaries. 
This can be either 'Rank' or 'Offset'.", + isValid = { + case Left(value) => ParamValidators.inArray(Array("Rank", "Offset"))(value) + case Right(_) => true + }) + + def getSortBy: String = getScalarParam(sortBy) + + def setSortBy(value: String): this.type = setScalarParam(sortBy, value) + + def getSortByCol: String = getVectorParam(sortBy) + + def setSortByCol(value: String): this.type = setVectorParam(sortBy, value) + + private[language] def createExtractiveSummarizationRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = ExtractiveSummarizationLROTask( + parameters = ExtractiveSummarizationTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + sentenceCount = getValueOpt(row, sentenceCount), + sortBy = getValueOpt(row, sortBy), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.ExtractiveSummarization.toString + ) + + ExtractiveSummarizationJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the abstractive summarization request. It provides the necessary + * parameters to create the request and the method to create the request. There are two + * parameters for abstractive summarization: sentenceCount and summaryLength. Both of them are optional. + * It is recommended to use summaryLength over sentenceCount. Service may ignore sentenceCount parameter. + * SummaryLength is a string value, and it should be one of "short", "medium", or "long". This parameter + * controls the approximate length of the output summaries. If the user provides an invalid value for + * summaryLength, the service will return an error. For more details about the parameters, please refer to + * the documentation. 
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]] + */ +private[language] trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizationBaseParameter { + val summaryLength = new ServiceParam[String]( + this, + name = "summaryLength", + doc = "(NOTE: Recommended to use summaryLength over sentenceCount) Controls the" + + " approximate length of the output summaries.", + isValid = { + case Left(value) => ParamValidators.inArray(Array("short", "medium", "long"))(value) + case Right(_) => true + } + + ) + + def getSummaryLength: String = getScalarParam(summaryLength) + + def setSummaryLength(value: String): this.type = setScalarParam(summaryLength, value) + + def setSummaryLength(value: SummaryLength): this.type = setScalarParam(summaryLength, value.toString.toLowerCase) + + def getSummaryLengthCol: String = getVectorParam(summaryLength) + + def setSummaryLengthCol(value: String): this.type = setVectorParam(summaryLength, value) + + private[language] def createAbstractiveSummarizationRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val paramerter = AbstractiveSummarizationLROTask( + parameters = AbstractiveSummarizationTaskParameters( + sentenceCount = getValueOpt(row, sentenceCount), + summaryLength = getValueOpt(row, summaryLength), + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + stringIndexType = stringIndexType), + taskName = None, + kind = AnalysisTaskKind.AbstractiveSummarization.toString + ) + AbstractiveSummarizationJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(paramerter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the healthcare text analytics request. It provides the necessary parameters + * to create the request and the method to create the request. 
There are three parameters for healthcare text + * analytics: modelVersion, stringIndexType, and loggingOptOut. All of them are optional. For more details about + * the parameters, please refer to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/overview]] + */ +private[language] trait HandleHealthcareTextAnalystics extends HasServiceParams { + private[language] def createHealthcareTextAnalyticsRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = HealthcareLROTask( + parameters = HealthcareTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.Healthcare.toString + ) + HealthcareJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the sentiment analysis request. It provides the necessary parameters to create + * the request and the method to create the request. There are four parameters for sentiment analysis: modelVersion, + * stringIndexType, opinionMining, and loggingOptOut. All of them are optional. For more details + * about the parameters, please refer to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/sentiment-opinion-mining/overview]] + */ +private[language] trait HandleSentimentAnalysis extends HasServiceParams { + val opinionMining = new ServiceParam[Boolean]( + this, + name = "opinionMining", + doc = "Whether to use opinion mining in the request or not." 
+ ) + + def getOpinionMining: Boolean = getScalarParam(opinionMining) + + def setOpinionMining(value: Boolean): this.type = setScalarParam(opinionMining, value) + + def getOpinionMiningCol: String = getVectorParam(opinionMining) + + def setOpinionMiningCol(value: String): this.type = setVectorParam(opinionMining, value) + + setDefault( + opinionMining -> Left(false) + ) + + private[language] def createSentimentAnalysisRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = SentimentAnalysisLROTask( + parameters = SentimentAnalysisTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + opinionMining = getValue(row, opinionMining), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.SentimentAnalysis.toString + ) + SentimentAnalysisJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the key phrase extraction request. It provides the necessary parameters to create + * the request and the method to create the request. There are two parameters for key phrase extraction: modelVersion + * and loggingOptOut. Both of them are optional. For more details about the parameters, + * please refer to the documentation. 
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/key-phrase-extraction/overview]] + */ +private[language] trait HandleKeyPhraseExtraction extends HasServiceParams { + private[language] def createKeyPhraseExtractionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + // This parameter is not used and only exists for compatibility + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = KeyPhraseExtractionLROTask( + parameters = KPnLDTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion + ), + taskName = None, + kind = AnalysisTaskKind.KeyPhraseExtraction.toString + ) + KeyPhraseExtractionJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +private[language] trait HandleEntityLinking extends HasServiceParams { + private[language] def createEntityLinkingRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = EntityLinkingLROTask( + parameters = EntityTaskParameters( + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.EntityLinking.toString + ) + EntityLinkingJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * This trait is used to handle the PII entity recognition request. It provides the necessary parameters to create + * the request and the method to create the request. There are three parameters for PII entity recognition: domain, + * piiCategories, and loggingOptOut. All of them are optional. For more details about the parameters, please refer to + * the documentation. 
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/personally-identifiable-information/overview]] + */ +private[language] trait HandlePiiEntityRecognition extends HasServiceParams { + val domain = new ServiceParam[String]( + this, + name = "domain", + doc = "The domain of the PII entity recognition request.", + isValid = { + case Left(value) => PiiDomain.values.map(_.toString.toLowerCase).contains(value) + case Right(_) => true + } + ) + + def getDomain: String = getScalarParam(domain) + + def setDomain(value: String): this.type = setScalarParam(domain, value) + + def setDomain(value: PiiDomain): this.type = setScalarParam(domain, value.toString.toLowerCase) + + def getDomainCol: String = getVectorParam(domain) + + def setDomainCol(value: String): this.type = setVectorParam(domain, value) + + val piiCategories = new ServiceParam[Seq[String]](this, "piiCategories", + "describes the PII categories to return") + + def setPiiCategories(v: Seq[String]): this.type = setScalarParam(piiCategories, v) + + def getPiiCategories: Seq[String] = getScalarParam(piiCategories) + + def setPiiCategoriesCol(v: String): this.type = setVectorParam(piiCategories, v) + + def getPiiCategoriesCol: String = getVectorParam(piiCategories) + + setDefault( + domain -> Left("none") + ) + + private[language] def createPiiEntityRecognitionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = PiiEntityRecognitionLROTask( + parameters = PiiTaskParameters( + domain = getValue(row, domain), + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + piiCategories = getValueOpt(row, piiCategories), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.PiiEntityRecognition.toString + ) + PiiEntityRecognitionJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} 
+ +/** + * This trait is used to handle the entity recognition request. It provides the necessary parameters to create + * the request and the method to create the request. There are five parameters for entity recognition: inclusionList, + * exclusionList, overlapPolicy, excludeNormalizedValues, and loggingOptOut. All of them are optional. For more details + * about the parameters, please refer to the documentation. + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/named-entity-recognition/overview]] + */ +private[language] trait HandleEntityRecognition extends HasServiceParams { + val inclusionList = new ServiceParam[Seq[String]]( + this, + name = "inclusionList", + doc = "(Optional) request parameter that limits the output to the requested entity" + + " types included in this list. We will apply inclusionList before" + + " exclusionList" + ) + + def getInclusionList: Seq[String] = getScalarParam(inclusionList) + + def setInclusionList(value: Seq[String]): this.type = setScalarParam(inclusionList, value) + + def getInclusionListCol: String = getVectorParam(inclusionList) + + def setInclusionListCol(value: String): this.type = setVectorParam(inclusionList, value) + + val exclusionList = new ServiceParam[Seq[String]]( + this, + name = "exclusionList", + doc = "(Optional) request parameter that filters out any entities that are" + + " included the excludeList. When a user specifies an excludeList, they cannot" + + " get a prediction returned with an entity in that list. 
We will apply" + + " inclusionList before exclusionList" + ) + + def getExclusionList: Seq[String] = getScalarParam(exclusionList) + + def setExclusionList(value: Seq[String]): this.type = setScalarParam(exclusionList, value) + + def getExclusionListCol: String = getVectorParam(exclusionList) + + def setExclusionListCol(value: String): this.type = setVectorParam(exclusionList, value) + + val overlapPolicy = new ServiceParam[String]( + this, + name = "overlapPolicy", + doc = "(Optional) describes the type of overlap policy to apply to the ner output.", + isValid = { + case Left(value) => value == "matchLongest" || value == "allowOverlap" + case Right(_) => true + }) + + def getOverlapPolicy: String = getScalarParam(overlapPolicy) + + def setOverlapPolicy(value: String): this.type = setScalarParam(overlapPolicy, value) + + def getOverlapPolicyCol: String = getVectorParam(overlapPolicy) + + def setOverlapPolicyCol(value: String): this.type = setVectorParam(overlapPolicy, value) + + val excludeNormalizedValues = new ServiceParam[Boolean]( + this, + name = "excludeNormalizedValues", + doc = "(Optional) request parameter that allows the user to provide settings for" + + " running the inference. 
If set to true, the service will exclude normalized" + ) + + def getExcludeNormalizedValues: Boolean = getScalarParam(excludeNormalizedValues) + + def setExcludeNormalizedValues(value: Boolean): this.type = setScalarParam(excludeNormalizedValues, value) + + def getExcludeNormalizedValuesCol: String = getVectorParam(excludeNormalizedValues) + + def setexcludeNormalizedValuesCol(value: String): this.type = setVectorParam(excludeNormalizedValues, value) + + private[language] def createEntityRecognitionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val serviceOverlapPolicy: Option[EntityOverlapPolicy] = getValueOpt(row, overlapPolicy) match { + case Some(policy) => Some(EntityOverlapPolicy(policy)) + case None => None + } + + val inferenceOptions: Option[EntityInferenceOptions] = getValueOpt(row, excludeNormalizedValues) match { + case Some(value) => Some(EntityInferenceOptions(value)) + case None => None + } + val taskParameter = EntityRecognitionLROTask( + parameters = EntityRecognitionTaskParameters( + exclusionList = getValueOpt(row, exclusionList), + inclusionList = getValueOpt(row, inclusionList), + loggingOptOut = loggingOptOut, + modelVersion = modelVersion, + overlapPolicy = serviceOverlapPolicy, + stringIndexType = stringIndexType, + inferenceOptions = inferenceOptions + ), + taskName = None, + kind = AnalysisTaskKind.EntityRecognition.toString + ) + EntityRecognitionJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +private[language] trait HasCustomLanguageModelParam extends HasServiceParams { + val projectName = new ServiceParam[String]( + this, + name = "projectName", + doc = "This field indicates the project name for the model. 
This is a required field" + ) + + def getProjectName: String = getScalarParam(projectName) + + def setProjectName(value: String): this.type = setScalarParam(projectName, value) + + def getProjectNameCol: String = getVectorParam(projectName) + + def setProjectNameCol(value: String): this.type = setVectorParam(projectName, value) + + val deploymentName = new ServiceParam[String]( + this, + name = "deploymentName", + doc = "This field indicates the deployment name for the model. This is a required field." + ) + + def getDeploymentName: String = getScalarParam(deploymentName) + + def setDeploymentName(value: String): this.type = setScalarParam(deploymentName, value) + + def getDeploymentNameCol: String = getVectorParam(deploymentName) + + def setDeploymentNameCol(value: String): this.type = setVectorParam(deploymentName, value) +} + +private[language] trait HandleCustomEntityRecognition extends HasServiceParams + with HasCustomLanguageModelParam { + + private[language] def createCustomEntityRecognitionRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + // This parameter is not used and only exists for compatibility + modelVersion: String, + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = CustomEntityRecognitionLROTask( + parameters = CustomEntitiesTaskParameters( + loggingOptOut = loggingOptOut, + projectName = getValue(row, projectName), + deploymentName = getValue(row, deploymentName), + stringIndexType = stringIndexType + ), + taskName = None, + kind = AnalysisTaskKind.CustomEntityRecognition.toString) + CustomEntitiesJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} + +/** + * Trait `ModifiableAsyncReply` extends `BasicAsyncReply` and provides a mechanism to modify the HTTP response + * received from an asynchronous service call. This trait is designed to be mixed into classes that require + * custom handling of the response data. 
+ * + * The primary purpose of this trait is to allow modification of the response before it is processed further. + * This is particularly useful in scenarios where the response needs to be transformed or certain fields need + * to be renamed to comply with specific requirements or constraints. + * + * In this implementation, the `queryForResult` method is overridden and marked as `final` to prevent further + * overriding. This ensures that the response modification logic is consistently applied across all subclasses. + * + * @note This trait is designed to be used with the `SynapseMLLogging` trait for consistent logging. + */ +trait ModifiableAsyncReply extends BasicAsyncReply { + self: SynapseMLLogging => + + protected def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = response + + /** + * Queries for the result of an asynchronous service call and applies the response modification logic. + */ + override final protected def queryForResult(key: Option[String], + client: CloseableHttpClient, + location: URI): Option[HTTPResponseData] = { + val originalResponse = super.queryForResult(key, client, location) + logDebug(s"Original response: $originalResponse") + modifyResponse(originalResponse) + } +} + + +/** + * Trait `HandleCustomLabelClassification` extends `HasServiceParams` and `HasCustomLanguageModelParam` to handle + * custom label classification tasks. This trait provides the necessary methods to create requests for custom + * multi-label classification and to modify the response to comply with specific requirements. + * + * The primary purpose of this trait is to address the limitation in Spark where fields named "class" cannot be + * directly bound. To work around this limitation, the response is modified to rename the "class" field to + * "classifications". + * + * This trait is designed to be mixed into classes that require custom label classification functionality and + * response modification logic. 
+ * + * @note This trait is designed to be used with the `ModifiableAsyncReply` and `SynapseMLLogging` traits for + * consistent response handling and logging. + */ +private[language] trait HandleCustomLabelClassification extends HasServiceParams + with HasCustomLanguageModelParam { + self: ModifiableAsyncReply + with SynapseMLLogging => + + private def isCustomLabelClassification: Boolean = { + val kind = getKind + kind == AnalysisTaskKind.CustomSingleLabelClassification.toString || + kind == AnalysisTaskKind.CustomMultiLabelClassification.toString + } + + /** + * Modifies the entity in the HTTP response to rename the "class" field to "classifications". + * + * @param response The original HTTP response. + * @return The modified HTTP response with the "class" field renamed to "classifications". + */ + private def modifyEntity(response: HTTPResponseData): HTTPResponseData = { + val modifiedEntity = response.entity.flatMap { entity => + val strEntity = IOUtils.toString(entity.content, "UTF-8") + val modifiedEntity = strEntity.replace("\"class\":", "\"classifications\":") + logDebug(s"Original entity: $strEntity\t Modified entity: $modifiedEntity") + Some(new EntityData( + content = modifiedEntity.getBytes, + contentEncoding = entity.contentEncoding, + contentLength = Some(strEntity.length), + contentType = entity.contentType, + isChunked = entity.isChunked, + isRepeatable = entity.isRepeatable, + isStreaming = entity.isStreaming + )) + } + new HTTPResponseData(response.headers, modifiedEntity, response.statusLine, response.locale) + } + + /** + * Modifies the HTTP response if the task kind is custom label classification. + */ + override def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = { + if (!isCustomLabelClassification) { + logDebug(s"Kind is not CustomSingleLabelClassification or CustomMultiLabelClassification. 
Kind: $getKind") + response + } else { + response.map(modifyEntity) + } + } + + + def getKind: String + + private[language] def createCustomMultiLabelRequest(row: Row, + analysisInput: MultiLanguageAnalysisInput, + // This parameter is not used and only exists for compatibility + modelVersion: String, + // This parameter is not used and only exists for compatibility + stringIndexType: String, + loggingOptOut: Boolean): String = { + val taskParameter = CustomLabelLROTask( + parameters = CustomLabelTaskParameters( + loggingOptOut = loggingOptOut, + projectName = getValue(row, projectName), + deploymentName = getValue(row, deploymentName) + ), + taskName = None, + kind = getKind + ) + CustomLabelJobsInput(displayName = None, + analysisInput = analysisInput, + tasks = Seq(taskParameter)).toJson.compactPrint + } +} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala new file mode 100644 index 0000000000..cc5d0f1ed4 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala @@ -0,0 +1,243 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging } +import com.microsoft.azure.synapse.ml.services.{ CognitiveServicesBaseNoHandler, HasAPIVersion, + HasCognitiveServiceInput, HasInternalJsonOutputParser, HasSetLocation } +import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch } +import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply +import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer } +import org.apache.http.entity.{ AbstractHttpEntity, StringEntity } +import org.apache.spark.injections.UDFUtils +import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel } +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ ArrayType, DataType, StructType } +import org.apache.spark.sql.Row +import org.apache.spark.sql.expressions.UserDefinedFunction + +import java.net.URI + +object AnalyzeTextLongRunningOperations extends ComplexParamsReadable[AnalyzeTextLongRunningOperations] + with Serializable + +/** + *

+ * This transformer is used to analyze text using the Azure AI Language service. It uses AI service asynchronously. + * For more details please visit + * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/concepts/use-asynchronously]] + * For each row, it submits a job to the service and polls the service until the job is complete. Delay between + * polling requests is controlled by the [[pollingDelay]] parameter, which is set to 1000 milliseconds by default. + * Number of polling attempts is controlled by the [[maxPollingRetries]] parameter, which is set to 1000 by default. + *

+ *

+ * This transformer will use the field specified as TextCol to submit the text to the service. The response from the + * service will be stored in the field specified as OutputCol. The response will be a struct with the + * following fields: + *

+ *

+ *

+ * This transformer support only single task per row. The task to be performed is specified by the [[kind]] parameter. + * The supported tasks are: + *

+ * Each task has its own set of parameters that can be set to control the behavior of the service and response + * schema. + *

+ */ +class AnalyzeTextLongRunningOperations(override val uid: String) extends CognitiveServicesBaseNoHandler(uid) + with HasAPIVersion + with ModifiableAsyncReply + with HasCognitiveServiceInput + with HasInternalJsonOutputParser + with HasSetLocation + with TextAnalyticsAutoBatch + with SynapseMLLogging + with HasAnalyzeTextServiceBaseParams + with HasBatchSize + with HandleExtractiveSummarization + with HandleAbstractiveSummarization + with HandleHealthcareTextAnalystics + with HandleSentimentAnalysis + with HandleKeyPhraseExtraction + with HandlePiiEntityRecognition + with HandleEntityLinking + with HandleEntityRecognition + with HandleCustomEntityRecognition + with HandleCustomLabelClassification { + logClass(FeatureNames.AiServices.Language) + + def this() = this(Identifiable.randomUID("AnalyzeTextLongRunningOperations")) + + override private[ml] def internalServiceType: String = "textanalytics" + + override def urlPath: String = "/language/analyze-text/jobs" + + override protected def validKinds: Set[String] = responseDataTypeSchemaMap.keySet.map(_.toString) + + setDefault( + apiVersion -> Left("2023-04-01"), + showStats -> Left(false), + batchSize -> 10, + pollingDelay -> 1000 + ) + + def setKind(value: AnalysisTaskKind.AnalysisTaskKind): this.type = set(kind, value.toString) + + override protected def shouldSkip(row: Row): Boolean = emptyParamData(row, text) || super.shouldSkip(row) + + /** + * Modifies the polling URI to include the showStats parameter if enabled. 
+ */ + override protected def modifyPollingURI(originalURI: URI): URI = { + if (getShowStats) { + new URI(s"${ originalURI.toString }&showStats=true") + } else { + originalURI + } + } + + // This method is made package private for testing purposes + private[language] val responseDataTypeSchemaMap: Map[AnalysisTaskKind.AnalysisTaskKind, StructType] = Map( + AnalysisTaskKind.ExtractiveSummarization -> ExtractiveSummarizationJobState.schema, + AnalysisTaskKind.AbstractiveSummarization -> AbstractiveSummarizationJobState.schema, + AnalysisTaskKind.Healthcare -> HealthcareJobState.schema, + AnalysisTaskKind.SentimentAnalysis -> SentimentAnalysisJobState.schema, + AnalysisTaskKind.KeyPhraseExtraction -> KeyPhraseExtractionJobState.schema, + AnalysisTaskKind.PiiEntityRecognition -> PiiEntityRecognitionJobState.schema, + AnalysisTaskKind.EntityLinking -> EntityLinkingJobState.schema, + AnalysisTaskKind.EntityRecognition -> EntityRecognitionJobState.schema, + AnalysisTaskKind.CustomEntityRecognition -> EntityRecognitionJobState.schema, + AnalysisTaskKind.CustomSingleLabelClassification -> CustomLabelJobState.schema, + AnalysisTaskKind.CustomMultiLabelClassification -> CustomLabelJobState.schema + ) + + override protected def responseDataType: DataType = { + val taskKind = AnalysisTaskKind.getKindFromString(getKind) + responseDataTypeSchemaMap(taskKind) + } + + // This method is made package private for testing purposes + private[language] val requestCreatorMap: Map[AnalysisTaskKind.AnalysisTaskKind, + (Row, MultiLanguageAnalysisInput, String, String, Boolean) => String] = Map( + AnalysisTaskKind.ExtractiveSummarization -> createExtractiveSummarizationRequest, + AnalysisTaskKind.AbstractiveSummarization -> createAbstractiveSummarizationRequest, + AnalysisTaskKind.Healthcare -> createHealthcareTextAnalyticsRequest, + AnalysisTaskKind.SentimentAnalysis -> createSentimentAnalysisRequest, + AnalysisTaskKind.KeyPhraseExtraction -> createKeyPhraseExtractionRequest, + 
AnalysisTaskKind.PiiEntityRecognition -> createPiiEntityRecognitionRequest, + AnalysisTaskKind.EntityLinking -> createEntityLinkingRequest, + AnalysisTaskKind.EntityRecognition -> createEntityRecognitionRequest, + AnalysisTaskKind.CustomEntityRecognition -> createCustomEntityRecognitionRequest, + AnalysisTaskKind.CustomSingleLabelClassification -> createCustomMultiLabelRequest, + AnalysisTaskKind.CustomMultiLabelClassification -> createCustomMultiLabelRequest + ) + + // This method is made package private for testing purposes + override protected[language] def prepareEntity: Row => Option[AbstractHttpEntity] = row => { + val analysisInput = createMultiLanguageAnalysisInput(row) + val taskKind = AnalysisTaskKind.getKindFromString(getKind) + val requestString = requestCreatorMap(taskKind)(row, + analysisInput, + getValue(row, modelVersion), + getValue(row, stringIndexType), + getValue(row, loggingOptOut)) + Some(new StringEntity(requestString, "UTF-8")) + } + + protected def postprocessResponse(responseOpt: Row): Option[Seq[Row]] = { + Option(responseOpt).map { response => + val tasks = response.getAs[Row]("tasks") + val items = tasks.getAs[Seq[Row]]("items") + items.flatMap(item => { + val results = item.getAs[Row]("results") + val stats = results.getAs[Row]("statistics") + val docs = results.getAs[Seq[Row]]("documents").map( + doc => (doc.getAs[String]("id"), doc)).toMap + val errors = results.getAs[Seq[Row]]("errors").map( + error => (error.getAs[String]("id"), error)).toMap + val modelVersion = results.getAs[String]("modelVersion") + (0 until (docs.size + errors.size)).map { i => + Row.fromSeq(Seq( + stats, + docs.get(i.toString), + errors.get(i.toString), + modelVersion + )) + } + }) + } + } + + protected def postprocessResponseUdf: UserDefinedFunction = { + val responseType = responseDataType.asInstanceOf[StructType] + val results = responseType("tasks").dataType.asInstanceOf[StructType]("items") + 
.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]("results") + .dataType.asInstanceOf[StructType] + val outputType = ArrayType( + new StructType() + .add("statistics", results("statistics").dataType) + .add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType) + .add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType) + .add("modelVersion", results("modelVersion").dataType) + ) + UDFUtils.oldUdf(postprocessResponse _, outputType) + } + + override protected def getInternalTransformer(schema: StructType): PipelineModel = { + val batcher = if (shouldAutoBatch(schema)) { + Some(new FixedMiniBatchTransformer().setBatchSize(getBatchSize)) + } else { + None + } + val newSchema = batcher.map(_.transformSchema(schema)).getOrElse(schema) + + val pipe = super.getInternalTransformer(newSchema) + + val postprocess = new UDFTransformer() + .setInputCol(getOutputCol) + .setOutputCol(getOutputCol) + .setUDF(postprocessResponseUdf) + + val flatten = if (shouldAutoBatch(schema)) { + Some(new FlattenBatch()) + } else { + None + } + + NamespaceInjections.pipelineModel( + Array(batcher, Some(pipe), Some(postprocess), flatten).flatten + ) + } + + private def createMultiLanguageAnalysisInput(row: Row): MultiLanguageAnalysisInput = { + val validText = getValue(row, text) + val langs = getValueOpt(row, language).getOrElse(Seq.fill(validText.length)("")) + val validLanguages = (if (langs.length == 1) { + Seq.fill(validText.length)(langs.head) + } else { + langs + }).map(lang => Option(lang).getOrElse("")) + assert(validLanguages.length == validText.length) + MultiLanguageAnalysisInput(validText.zipWithIndex.map { case (t, i) => + TADocument(Some(validLanguages(i)), i.toString, Option(t).getOrElse("")) + }) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala 
b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala new file mode 100644 index 0000000000..03a6afa948 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala @@ -0,0 +1,700 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.core.test.fuzzing.{ TestObject, TransformerFuzzing } +import com.microsoft.azure.synapse.ml.services.text.{ SentimentAssessment, TextEndpoint } +import com.microsoft.azure.synapse.ml.Secrets +import org.apache.spark.ml.util.MLReadable +import org.apache.spark.sql.{ DataFrame, Row } +import org.apache.spark.sql.functions.{ col, flatten, map } +import org.scalactic.{ Equality, TolerantNumerics } +import org.scalatest.funsuite.AnyFunSuiteLike + +trait LanguageServiceEndpoint { + lazy val languageApiKey: String = sys.env.getOrElse("CUSTOM_LANGUAGE_KEY", Secrets.LanguageApiKey) + lazy val languageApiLocation: String = sys.env.getOrElse("LANGUAGE_API_LOCATION", "eastus") +} + +class AnalyzeTextLROSuite extends AnyFunSuiteLike { + test("Validate that response map and creator handle same kinds") { + val transformer = new AnalyzeTextLongRunningOperations() + val responseKinds = transformer.responseDataTypeSchemaMap.keySet + val creatorKinds = transformer.requestCreatorMap.keySet + assert(responseKinds == creatorKinds) + } +} + +class ExtractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + + private val df = Seq( + Seq( + """At Microsoft, we have been on a quest to advance AI beyond existing techniques, by taking a more holistic, + |human-centric approach to learning 
and understanding. As Chief Technology Officer of Azure AI services, + |I have been working with a team of amazing scientists and engineers to turn this quest into a reality. + |In my role, I enjoy a unique perspective in viewing the relationship among three attributes of human + |cognition: monolingual text (X), audio or visual sensory signals, (Y) and multilingual (Z). At the + |intersection of all three, there’s magic—what we call XYZ-code as illustrated in Figure 1—a joint + |representation to create more powerful AI that can speak, hear, see, and understand humans better. We + |believe XYZ-code enables us to fulfill our long-term vision: cross-domain transfer learning, spanning + |modalities and languages. The goal is to have pretrained models that can jointly learn representations to + |support a broad range of downstream AI tasks, much in the way humans do today. Over the past five years, we + |have achieved human performance on benchmarks in conversational speech recognition, machine translation, + |conversational question answering, machine reading comprehension, and image captioning. These five + |breakthroughs provided us with strong signals toward our more ambitious aspiration to produce a leap in AI + |capabilities, achieving multi-sensory and multilingual learning that is closer in line with how humans learn + | and understand. I believe the joint XYZ-code is a foundational component of this aspiration, if grounded + | with external knowledge sources in the downstream AI tasks""".stripMargin, + "", + """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nam ultricies interdum eros, vehicula dignissim + |odio dignissim id. Nam sagittis lacinia enim at fringilla. Nunc imperdiet porta ex. Vestibulum quis nisl + |feugiat, dapibus nulla nec, dictum lorem. Vivamus ut urna a ante cursus egestas. In vulputate facilisis + |nunc, vitae aliquam neque faucibus a. Fusce et venenatis nisi. Duis eleifend condimentum neque. Mauris eu + |pulvinar sapien. 
Nam at nibh sem. Integer sapien ex, viverra vel interdum non, volutpat sed tellus. Aenean + | nec maximus nibh. Maecenas sagittis turpis vel nibh condimentum vulputate. Pellentesque viverra + | ullamcorper urna vitae rutrum. Nunc fermentum sem vitae commodo efficitur.""".stripMargin + ) + ).toDF("text") + + + test("Basic usage") { + val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind(AnalysisTaskKind.ExtractiveSummarization) + .setOutputCol("response") + .setErrorCol("error") + val responses = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("modelVersion", col("response.modelVersion")) + .withColumn("errors", col("response.errors")) + .withColumn("statistics", col("response.statistics")) + .collect() + assert(responses.length == 1) + val response = responses.head + val documents = response.getAs[Seq[Row]]("documents") + val errors = response.getAs[Seq[Row]]("errors") + assert(documents.length == errors.length) + assert(documents.length == 3) + val sentences = documents.head.getAs[Seq[Row]]("sentences") + assert(sentences.nonEmpty) + sentences.foreach { sentence => + assert(sentence.getAs[String]("text").nonEmpty) + assert(sentence.getAs[Double]("rankScore") > 0.0) + assert(sentence.getAs[Int]("offset") >= 0) + assert(sentence.getAs[Int]("length") > 0) + } + } + + + test("show-stats and sentence-count") { + val sentenceCount = 10 + val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind(AnalysisTaskKind.ExtractiveSummarization) + .setOutputCol("response") + .setErrorCol("error") + .setShowStats(true) + .setSentenceCount(sentenceCount) + val responses = model.transform(df) + .withColumn("documents", col("response.documents")) + 
.withColumn("modelVersion", col("response.modelVersion")) + .withColumn("errors", col("response.errors")) + .withColumn("statistics", col("response.statistics")) + .collect() + assert(responses.length == 1) + val response = responses.head + val stats = response.getAs[Seq[Row]]("statistics") + assert(stats.length == 3) + stats.foreach { stat => + assert(stat.getAs[Int]("documentsCount") == 3) + assert(stat.getAs[Int]("validDocumentsCount") == 2) + assert(stat.getAs[Int]("erroneousDocumentsCount") == 1) + assert(stat.getAs[Int]("transactionsCount") == 3) + } + + val documents = response.getAs[Seq[Row]]("documents") + for (document <- documents) { + if (document != null) { + val sentences = document.getAs[Seq[Row]]("sentences") + assert(sentences.length == sentenceCount) + sentences.foreach { sentence => + assert(sentence.getAs[String]("text").nonEmpty) + assert(sentence.getAs[Double]("rankScore") > 0.0) + assert(sentence.getAs[Int]("offset") >= 0) + assert(sentence.getAs[Int]("length") > 0) + } + } + } + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind("ExtractiveSummarization") + .setOutputCol("response"), + df)) + + override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations +} + + +class AbstractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + + private val df = Seq( + Seq( + """At Microsoft, we have been on a quest to advance AI beyond existing techniques, by taking a more holistic, + |human-centric approach to learning and understanding. 
As Chief Technology Officer of Azure AI services, + |I have been working with a team of amazing scientists and engineers to turn this quest into a reality. + |In my role, I enjoy a unique perspective in viewing the relationship among three attributes of human + |cognition: monolingual text (X), audio or visual sensory signals, (Y) and multilingual (Z). At the + |intersection of all three, there’s magic—what we call XYZ-code as illustrated in Figure 1—a joint + |representation to create more powerful AI that can speak, hear, see, and understand humans better. We + |believe XYZ-code enables us to fulfill our long-term vision: cross-domain transfer learning, spanning + |modalities and languages. The goal is to have pretrained models that can jointly learn representations to + |support a broad range of downstream AI tasks, much in the way humans do today. Over the past five years, we + |have achieved human performance on benchmarks in conversational speech recognition, machine translation, + |conversational question answering, machine reading comprehension, and image captioning. These five + |breakthroughs provided us with strong signals toward our more ambitious aspiration to produce a leap in AI + |capabilities, achieving multi-sensory and multilingual learning that is closer in line with how humans learn + | and understand. 
I believe the joint XYZ-code is a foundational component of this aspiration, if grounded + | with external knowledge sources in the downstream AI tasks""".stripMargin + ) + ).toDF("text") + + + test("Basic usage") { + val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind("AbstractiveSummarization") + .setOutputCol("response") + .setErrorCol("error") + .setPollingDelay(5 * 1000) + .setMaxPollingRetries(30) + val responses = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("modelVersion", col("response.modelVersion")) + .withColumn("errors", col("response.errors")) + .withColumn("statistics", col("response.statistics")) + .collect() + assert(responses.length == 1) + val response = responses.head + val documents = response.getAs[Seq[Row]]("documents") + val errors = response.getAs[Seq[Row]]("errors") + assert(documents.length == errors.length) + assert(documents.length == 1) + val summaries = documents.head.getAs[Seq[Row]]("summaries") + assert(summaries.nonEmpty) + } + + + test("show-stats and summary-length") { + val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind(AnalysisTaskKind.AbstractiveSummarization) + .setOutputCol("response") + .setErrorCol("error") + .setShowStats(true) + .setSummaryLength(SummaryLength.Short) + .setPollingDelay(5 * 1000) + .setMaxPollingRetries(30) + val responses = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("modelVersion", col("response.modelVersion")) + .withColumn("errors", col("response.errors")) + .withColumn("statistics", col("response.statistics")) + .collect() + assert(responses.length == 1) + val response = responses.head + val stat = 
response.getAs[Seq[Row]]("statistics").head + assert(stat.getAs[Int]("documentsCount") == 1) + assert(stat.getAs[Int]("validDocumentsCount") == 1) + assert(stat.getAs[Int]("erroneousDocumentsCount") == 0) + + + val document = response.getAs[Seq[Row]]("documents").head + val summaries = document.getAs[Seq[Row]]("summaries") + assert(summaries.length == 1) + val summary = summaries.head.getAs[String]("text") + assert(summary.length <= 750) + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind("AbstractiveSummarization") + .setPollingDelay(5 * 1000) + .setMaxPollingRetries(30) + .setSummaryLength(SummaryLength.Short) + .setOutputCol("response"), + Seq("Microsoft Azure AI Data Fabric").toDF("text"))) + + override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations +} + +class HealthcareSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + private val df = Seq( + "The doctor prescried 200mg Ibuprofen." 
+ ).toDF("text") + + test("Basic usage") { + val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind("Healthcare") + .setOutputCol("response") + .setShowStats(true) + .setErrorCol("error") + val responses = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("modelVersion", col("response.modelVersion")) + .withColumn("errors", col("response.errors")) + .withColumn("statistics", col("response.statistics")) + .collect() + assert(responses.length == 1) + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setKind("Healthcare") + .setOutputCol("response"), + df)) + + override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations +} + +class SentimentAnalysisLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = Seq( + "Great atmosphere. Close to plenty of restaurants, hotels, and transit! Staff are friendly and helpful.", + "What a sad story!" 
+ ).toDF("text") + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setKind(AnalysisTaskKind.SentimentAnalysis) + .setOutputCol("response") + .setErrorCol("error") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("sentiment", col("documents.sentiment")) + .collect() + assert(result.head.getAs[String]("sentiment") == "positive") + assert(result(1).getAs[String]("sentiment") == "negative") + } + + test("api-version 2022-10-01-preview") { + val result = model.setApiVersion("2022-10-01-preview").transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("sentiment", col("documents.sentiment")) + .collect() + assert(result.head.getAs[String]("sentiment") == "positive") + assert(result(1).getAs[String]("sentiment") == "negative") + } + + test("Show stats") { + val result = model.setShowStats(true).transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("sentiment", col("documents.sentiment")) + .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount")) + .collect() + assert(result.head.getAs[String]("sentiment") == "positive") + assert(result(1).getAs[String]("sentiment") == "negative") + assert(result.head.getAs[Int]("validDocumentsCount") == 1) + } + + test("Opinion Mining") { + val result = model.setOpinionMining(true).transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("sentiment", col("documents.sentiment")) + .withColumn("assessments", flatten(col("documents.sentences.assessments"))) + .collect() + assert(result.head.getAs[String]("sentiment") == "positive") + assert(result(1).getAs[String]("sentiment") == "negative") + val fromRow = SentimentAssessment.makeFromRowConverter + assert(result.head.getAs[Seq[Row]]("assessments").map(fromRow).head.sentiment == 
"positive") + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeText +} + + +class KeyPhraseLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = Seq( + ("en", "Microsoft was founded by Bill Gates and Paul Allen."), + ("en", "Text Analytics is one of the Azure Cognitive Services."), + ("en", "My cat might need to see a veterinarian.") + ).toDF("language", "text") + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguageCol("language") + .setTextCol("text") + .setKind("KeyPhraseExtraction") + .setOutputCol("response") + .setErrorCol("error") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("keyPhrases", col("documents.keyPhrases")) + val keyPhrases = result.collect()(1).getAs[Seq[String]]("keyPhrases") + assert(keyPhrases.contains("Azure Cognitive Services")) + assert(keyPhrases.contains("Text Analytics")) + } + + test("api-version 2022-10-01-preview") { + val result = model.setApiVersion("2022-10-01-preview").transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("keyPhrases", col("documents.keyPhrases")) + val keyPhrases = result.collect()(1).getAs[Seq[String]]("keyPhrases") + assert(keyPhrases.contains("Azure Cognitive Services")) + assert(keyPhrases.contains("Text Analytics")) + } + + test("Show stats") { + val result = model.setShowStats(true).transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("keyPhrases", col("documents.keyPhrases")) + .withColumn("validDocumentsCount", 
col("response.statistics.validDocumentsCount")) + val keyPhrases = result.collect()(1).getAs[Seq[String]]("keyPhrases") + assert(keyPhrases.contains("Azure Cognitive Services")) + assert(keyPhrases.contains("Text Analytics")) + assert(result.head.getAs[Int]("validDocumentsCount") == 1) + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations +} + + +class AnalyzeTextPIILORSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = Seq( + "My SSN is 859-98-0987", + "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check.", + "Is 998.214.865-68 your Brazilian CPF number?" + ).toDF("text") + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setKind("PiiEntityRecognition") + .setOutputCol("response") + .setErrorCol("error") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("redactedText", col("documents.redactedText")) + .withColumn("entities", col("documents.entities.text")) + .collect() + val entities = result.head.getAs[Seq[String]]("entities") + assert(entities.contains("859-98-0987")) + val redactedText = result(1).getAs[String]("redactedText") + assert(!redactedText.contains("111000025")) + } + + test("api-version 2022-10-01-preview") { + val result = model.setApiVersion("2022-10-01-preview").transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("redactedText", col("documents.redactedText")) + .withColumn("entities", 
col("documents.entities.text")) + .collect() + val entities = result.head.getAs[Seq[String]]("entities") + assert(entities.contains("859-98-0987")) + val redactedText = result(1).getAs[String]("redactedText") + assert(!redactedText.contains("111000025")) + } + + test("Show stats") { + val result = model.setShowStats(true).transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("redactedText", col("documents.redactedText")) + .withColumn("entities", col("documents.entities.text")) + .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount")) + .collect() + val entities = result.head.getAs[Seq[String]]("entities") + assert(entities.contains("859-98-0987")) + val redactedText = result(1).getAs[String]("redactedText") + assert(!redactedText.contains("111000025")) + assert(result.head.getAs[Int]("validDocumentsCount") == 1) + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations +} + + +class EntityLinkingLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = Seq( + ("en", "Microsoft was founded by Bill Gates and Paul Allen."), + ("en", "Pike place market is my favorite Seattle attraction.") + ).toDF("language", "text") + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguageCol("language") + .setTextCol("text") + .setKind("EntityLinking") + .setOutputCol("response") + .setErrorCol("error") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entityNames", map(col("documents.id"), 
col("documents.entities.name"))) + val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0") + assert(entities.contains("Microsoft")) + assert(entities.contains("Bill Gates")) + } + + test("api-version 2022-10-01-preview") { + val result = model.setApiVersion("2022-10-01-preview").transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entityNames", map(col("documents.id"), col("documents.entities.name"))) + val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0") + assert(entities.contains("Microsoft")) + assert(entities.contains("Bill Gates")) + } + + test("Show stats") { + val result = model.setShowStats(true).transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entityNames", map(col("documents.id"), col("documents.entities.name"))) + .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount")) + val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0") + assert(entities.contains("Microsoft")) + assert(entities.contains("Bill Gates")) + assert(result.head.getAs[Int]("validDocumentsCount") == 1) + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations +} + + +class EntityRecognitionLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = Seq( + ("en", "Microsoft was founded by Bill Gates and Paul Allen."), + ("en", "Pike place market is my favorite Seattle attraction.") + ).toDF("language", "text") + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + 
.setLanguageCol("language") + .setTextCol("text") + .setKind("EntityRecognition") + .setOutputCol("response") + .setErrorCol("error") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entityNames", map(col("documents.id"), col("documents.entities.text"))) + val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0") + assert(entities.contains("Microsoft")) + assert(entities.contains("Bill Gates")) + } + + test("api-version 2022-10-01-preview") { + val result = model.setApiVersion("2022-10-01-preview").transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entityNames", map(col("documents.id"), col("documents.entities.text"))) + val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0") + assert(entities.contains("Microsoft")) + assert(entities.contains("Bill Gates")) + } + + test("Show stats") { + val result = model.setShowStats(true).transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entityNames", map(col("documents.id"), col("documents.entities.text"))) + .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount")) + val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0") + assert(entities.contains("Microsoft")) + assert(entities.contains("Bill Gates")) + assert(result.head.getAs[Int]("validDocumentsCount") == 1) + } + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeText +} + +class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] + with LanguageServiceEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = + Seq("Maria Sullivan with a mailing address 
of 334 Shinn Avenue, City of Wampum, State of Pennsylvania") + .toDF("text") + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(languageApiKey) + .setLocation(languageApiLocation) + .setLanguage("en") + .setTextCol("text") + .setKind(AnalysisTaskKind.CustomEntityRecognition) + .setOutputCol("response") + .setErrorCol("error") + .setDeploymentName("custom-ner-unitest-deployment") + .setProjectName("for-unit-test") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("entities", col("documents.entities")) + .collect() + val entities = result.head.getAs[Seq[Row]]("entities") + assert(entities.length == 4) + val resultMap: Map[String, String] = entities.map { entity => + entity.getAs[String]("text") -> entity.getAs[String]("category") + }.toMap + assert(resultMap("Maria Sullivan") == "BorrowerName") + assert(resultMap("334 Shinn Avenue") == "BorrowerAddress") + assert(resultMap("Wampum") == "BorrowerCity") + assert(resultMap("Pennsylvania") == "BorrowerState") + } + + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeText +} + + +class MultiLableClassificationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] + with LanguageServiceEndpoint { + + import spark.implicits._ + + implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3) + + def df: DataFrame = { + // description of movie Finding Nemo + Seq("In the depths of the ocean, a father's worst nightmare comes to life. A grieving and determined father, " + + "must overcome his fears and navigate, the treacherous waters to find his missing son. 
The journey is " + + "fraught with relentless predators, dark secrets, and the haunting realization that the ocean is a vast, " + + "unforgiving abyss. Will a Father's unwavering resolve be enough to reunite him with his son, or will " + + "the shadows of the deep consume them both? Dive into the darkness and discover the lengths a parent will " + + "go to for their child.") + .toDF("text") + } + + def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations() + .setSubscriptionKey(languageApiKey) + .setLocation(languageApiLocation) + .setLanguage("en") + .setTextCol("text") + .setKind(AnalysisTaskKind.CustomMultiLabelClassification) + .setOutputCol("response") + .setErrorCol("error") + .setDeploymentName("multi-class-movie-dep") + .setProjectName("for-unit-test-muti-class") + + test("Basic usage") { + val result = model.transform(df) + .withColumn("documents", col("response.documents")) + .withColumn("classifications", col("documents.classifications")) + .collect() + val classifications = result.head.getAs[Seq[Row]]("classifications") + assert(classifications.nonEmpty) + assert(classifications.head.getAs[String]("category").nonEmpty) + assert(classifications.head.getAs[Double]("confidenceScore") > 0.0) + } + + + override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] = + Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df)) + + override def reader: MLReadable[_] = AnalyzeText +} + + + diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala index 17eed8a668..ae45753e8b 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala @@ -74,4 +74,5 @@ object Secrets { lazy val Platform: String = getSecret("synapse-platform") lazy val AadResource: String = getSecret("synapse-internal-aad-resource") + lazy val LanguageApiKey: String = 
getSecret("language-api-key") }