Skip to content

Commit

Permalink
Fixed type conversion for tosa.abs when lowering to linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
lmendesp-amd committed Dec 9, 2024
1 parent 37f5d68 commit 1e53fb5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
23 changes: 13 additions & 10 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
}

static Value createLinalgBodyCalculationForElementwiseOp(
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
ConversionPatternRewriter &rewriter) {
Operation *op, const TypeConverter &converter, ValueRange args,
ArrayRef<Type> resultTypes, ConversionPatternRewriter &rewriter) {
Location loc = op->getLoc();
auto elementTy =
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
Expand All @@ -61,7 +61,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(

if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementTy));
loc, rewriter.getZeroAttr(converter.convertType(elementTy)));
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
}
Expand Down Expand Up @@ -416,17 +416,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (intTy.isUnsignedInteger()) {
minRepresentable = 0;
if (intTy.getIntOrFloatBitWidth() <= 63) {
maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
.getZExtValue();
maxRepresentable =
(int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
.getZExtValue();
}
} else if(intTy.getIntOrFloatBitWidth() <= 64) {
} else if (intTy.getIntOrFloatBitWidth() <= 64) {
// Ensure that min & max fit into signed n-bit constants.
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue();
.getSExtValue();
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue();
.getSExtValue();
}
// Ensure that the bounds are representable as n-bit signed/unsigned integers.
// Ensure that the bounds are representable as n-bit signed/unsigned
// integers.
min = std::max(min, minRepresentable);
max = std::max(max, minRepresentable);
min = std::min(min, maxRepresentable);
Expand Down Expand Up @@ -946,7 +948,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
Value opResult = createLinalgBodyCalculationForElementwiseOp(
operation, blockArgs.take_front(operation->getNumOperands()),
operation, converter,
blockArgs.take_front(operation->getNumOperands()),
{resultType.getElementType()}, rewriter);
if (!opResult) {
encounteredError = true;
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2010,3 +2010,12 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}

// -----
// CHECK-LABEL: @test_abs_conversion
// CHECK: linalg.generic
// CHECK: arith.constant 0 : i64
func.func @test_abs_conversion(%arg0: tensor<9xui64> {func.orig_type = tensor<9xui64>, onnx.name = "in0"}) -> (tensor<9xui64> {func.orig_type = tensor<9xui64>, onnx.name = "out0"}) {
%0 = tosa.abs %arg0 : (tensor<9xui64>) -> tensor<9xui64>
return %0 : tensor<9xui64>
}

0 comments on commit 1e53fb5

Please # to comment.