Skip to content

Commit

Permalink
Merge pull request #409 from Xilinx/lmendesp.tosa-abs-to-linalg
Browse files Browse the repository at this point in the history
Fixed type conversion for tosa.abs when lowering to linalg
  • Loading branch information
mgehre-amd authored Dec 9, 2024
2 parents 37f5d68 + 918d719 commit a439f4c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
16 changes: 9 additions & 7 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(

if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementTy));
loc, rewriter.getZeroAttr(convertedElementTy));
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
}
Expand Down Expand Up @@ -416,17 +416,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (intTy.isUnsignedInteger()) {
minRepresentable = 0;
if (intTy.getIntOrFloatBitWidth() <= 63) {
maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
.getZExtValue();
maxRepresentable =
(int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
.getZExtValue();
}
} else if(intTy.getIntOrFloatBitWidth() <= 64) {
} else if (intTy.getIntOrFloatBitWidth() <= 64) {
// Ensure that min & max fit into signed n-bit constants.
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue();
.getSExtValue();
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue();
.getSExtValue();
}
// Ensure that the bounds are representable as n-bit signed/unsigned integers.
// Ensure that the bounds are representable as n-bit signed/unsigned
// integers.
min = std::max(min, minRepresentable);
max = std::max(max, minRepresentable);
min = std::min(min, maxRepresentable);
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2010,3 +2010,12 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}

// -----
// CHECK-LABEL: @test_abs_conversion
// CHECK: linalg.generic
// CHECK: arith.constant 0 : i64
func.func @test_abs_conversion(%arg0: tensor<9xui64> {func.orig_type = tensor<9xui64>, onnx.name = "in0"}) -> (tensor<9xui64> {func.orig_type = tensor<9xui64>, onnx.name = "out0"}) {
%0 = tosa.abs %arg0 : (tensor<9xui64>) -> tensor<9xui64>
return %0 : tensor<9xui64>
}

0 comments on commit a439f4c

Please # to comment.