Skip to content

Latest commit

 

History

History
219 lines (175 loc) · 11.1 KB

20200301_api_design_for_quantile.md

File metadata and controls

219 lines (175 loc) · 11.1 KB

paddle.Tensor.quantile 设计文档

API名称 paddle.Tensor.quantile
提交作者 陈明
提交时间 2022-03-01
版本号 V1.0
依赖飞桨版本 v2.2.0
文件名 20200301_design_for_quantile.md

一、概述

1、相关背景

为了提升飞桨API丰富度,支持科学计算领域API,Paddle需要扩充APIpaddle.quantile以及paddle.Tensor.quantile

2、功能目标

增加APIpaddle.quantile以及paddle.Tensor.quantile,实现对一个张量沿指定维度计算q分位数的功能。

3、意义

飞桨支持计算分位数

二、飞桨现状

目前paddle缺少相关功能实现。

API方面,已有类似功能的API,paddle.median, 在Paddle中是一个由多个其他API组合成的API,没有实现自己的OP,其主要实现逻辑为:

  1. 如未指定维度,则通过paddle.flatten展平处理
  2. 通过paddle.topk得到原Tensor中较大的半部分(k取对应维度size / 2 + 1)。
  3. 若size是奇数,则能直接取到一个元素,直接通过paddle.slice切分出第size/2个元素即可;若为偶数,通过paddle.slice分别切分出第size/2-1size/2个元素,交错相加并取均值得到结果。

但在实际实现时,不能完全直接复用上述方案,理由如下:

  1. paddle.topk未支持一次计算多个k值,如果仍然采用topk取对应indice的元素,在q值为多个时需要执行多次topk
  2. paddle.topk当前GPU/CPU对NaN值的处理未统一;
  3. paddle.slice只支持取一次索引,仍然无法一次处理取多个索引的情况。

三、业内方案调研

Pytorch

Pytorch中有APItorch.quantile(input, q, dim=None, keepdim=False, *, out=None) -> Tensor,以及对应的torch.Tensor.quantile(q, dim=None, keepdim=False) -> Tensor.在pytorch中,介绍为:

Computes the q-th quantiles of each row of the input tensor along the dimension dim.

To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location of the quantile in the sorted input. If the quantile lies between two data points a < b with indices i and j in the sorted order, result is computed using linear interpolation as follows:

a + (b - a) * fraction, where fraction is the fractional part of the computed quantile index.

If q is a 1D tensor, the first dimension of the output represents the quantiles and has size equal to the size of q, the remaining dimensions are what remains from the reduction.

实现方法

在实现方法上, Pytorch是通过c++ API组合实现的,代码位置。 其中核心代码为,根据对NaN处理方式的不同,同时支持了pytorch.quantilepytorch.nanquantile两个API:

// Convert q in [0, 1] to ranks in [0, reduction_size)
  Tensor ranks;
  if (ignore_nan) {
    // For nanquantile, compute ranks based on number of non-nan values.
    // If all values are nan, set rank to 0 so the quantile computed is nan.
    ranks = q * (sorted.isnan().logical_not_().sum(-1, true) - 1);
    ranks.masked_fill_(ranks < 0, 0);
  } else {
    // For quantile, compute ranks based on reduction size. If there is nan
    // set rank to last index so the quantile computed will be nan.
    int64_t last_index = sorted.size(-1) - 1;
    std::vector<Tensor> tl =
        at::broadcast_tensors({q * last_index, sorted.isnan().any(-1, true)});
    ranks = at::masked_fill(tl[0], tl[1], last_index);
  }

  // adjust ranks based on the interpolation mode
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LOWER) {
    ranks.floor_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::HIGHER) {
    ranks.ceil_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::NEAREST) {
    ranks.round_();
  }

  Tensor ranks_below = ranks.toType(kLong);
  Tensor values_below = sorted.gather(-1, ranks_below);

  // Actual interpolation is only needed for the liner and midpoint modes
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LINEAR ||
      interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT) {
    // calculate weights for linear and midpoint
    Tensor weights = interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT
        ? at::full_like(ranks, 0.5)
        : ranks - ranks_below;

    // Interpolate to compute quantiles and store in values_below
    Tensor ranks_above = ranks.ceil_().toType(kLong);
    Tensor values_above = sorted.gather(-1, ranks_above);
    values_below.lerp_(values_above, weights);

整体逻辑为:

  • 将Tensor的对应维度调整到最后,并进行排序处理
  • broadcast_tensors的方式,将q [0,1]映射到[0, num-1],并检查NaN值。
  • masked_fill将对应为NaN的部分直接赋值到最后的index(对应值也是NaN)
  • 采用lerp处理blowabove的之间的插值, 包含NaN的结果也是NaN.

Numpy

实现方法

以现有numpy python API组合实现,代码位置. 其中核心代码为:

    ap = np.moveaxis(ap, axis, 0)
    del axis

    if np.issubdtype(indices.dtype, np.integer):
        # take the points along axis

        if np.issubdtype(a.dtype, np.inexact):
            # may contain nan, which would sort to the end
            ap.partition(concatenate((indices.ravel(), [-1])), axis=0)
            n = np.isnan(ap[-1])
        else:
            # cannot contain nan
            ap.partition(indices.ravel(), axis=0)
            n = np.array(False, dtype=bool)

        r = take(ap, indices, axis=0, out=out)

    else:
        # weight the points above and below the indices

        indices_below = not_scalar(floor(indices)).astype(intp)
        indices_above = not_scalar(indices_below + 1)
        indices_above[indices_above > Nx - 1] = Nx - 1

        if np.issubdtype(a.dtype, np.inexact):
            # may contain nan, which would sort to the end
            ap.partition(concatenate((
                indices_below.ravel(), indices_above.ravel(), [-1]
            )), axis=0)
            n = np.isnan(ap[-1])
        else:
            # cannot contain nan
            ap.partition(concatenate((
                indices_below.ravel(), indices_above.ravel()
            )), axis=0)
            n = np.array(False, dtype=bool)

        weights_shape = indices.shape + (1,) * (ap.ndim - 1)
        weights_above = not_scalar(indices - indices_below).reshape(weights_shape)

        x_below = take(ap, indices_below, axis=0)
        x_above = take(ap, indices_above, axis=0)

        r = _lerp(x_below, x_above, weights_above, out=out)

    # if any slice contained a nan, then all results on that slice are also nan
    if np.any(n):
        if r.ndim == 0 and out is None:
            # can't write to a scalar
            r = a.dtype.type(np.nan)
        else:
            r[..., n] = a.dtype.type(np.nan)

    return r

整体逻辑为:

  • 若未指定维度,则flatten展平处理。使用np.moveaxis将指定的维度放到0处理;
  • q [0,1]根据shape放缩到 indice [0, nums-1]
  • 如果indice是整数,表示分位数是该Tensor的元素,后续直接按indice取元素即可;如果仍是小数,则找到其相邻位置的两个元素,后续需要用np.lerp插值计算得到对应元素。
  • 对输入Tensor,当indice为整数时,直接通过np.partition将其按每个indice分为两部分(即快速排序算法中的partition部分,不完整执行排序过程以降低时间复杂度),indice位置就是q分位数;当size为偶数时则将两端的indice_belowindice_above都做partition操作,并取出两端的对应结果,并利用np.lerp计算插值结果。
  • NaN的处理:对存在NaN的情况,使用np.isnan确定标志位,标志位对应的位置输出值为NaN.
  • Numpy支持多个维度处理,以tuple形式作为输入。此时的分位数计算是将指定的多个维度合并后计算得到的。

四、对比分析

  • 使用场景与功能:在维度支持上,Pytorch只支持一维,而Numpy支持多维,这里对齐Numpy的实现逻辑,同时支持一维和多维场景。
  • 实现对比:由于pytorch.gatherpaddle.gather实际在秩大于1时的表现不一致;在出现多个q值时,pytorch可直接通过处理后的indice进行多维索引,paddle则需要分别索引再组合到一起。因此这里不再使用paddle.gather索引,改使用paddle.take_along_axisAPI进行索引。

五、方案设计

命名与参数设计

API设计为paddle.quantile(x, q, axis=None, keepdim=False, name=None)paddle.Tensor.quantile(q, axis=None, keepdim=False, name=None) 命名与参数顺序为:形参名input->xdim->axis, 与paddle其他API保持一致性,不影响实际功能使用。 参数类型中,axis支持int1-D Tensor输入,以同时支持一维和多维的场景。

底层OP设计

使用已有API组合实现,不再单独设计OP。

API实现方案

主要按下列步骤进行组合实现,实现位置为paddle/tensor/stat.pymean,median等方法放在一起:

  1. 使用paddle.sort得到排序后的tensor.
  2. q:[0, 1]映射到indice:[0, numel_of_dim-1];并对indice分别做paddle.floorpaddle.ceil求得需要计算的两端元素位置;
  3. 使用paddle.take_along_axis取出对应axisindice的两端元素;
  4. paddle.lerp计算两端元素的加权插值,作为结果。
  5. 根据keepdim参数,确定是否需要对应调整结果shape。
  • NaN的处理,对原tensor采用paddle.isnan检查NaN值,包含NaN的,在步骤4所对应位置的元素置NaN

六、测试和验收的考量

测试考虑的case如下:

  • 数值准确性:和numpy结果的数值的一致性, paddle.quantile,paddle.Tensor.quantilenp.quantile结果是否一致;
  • 数值准确性:输入含NaN结果的正确性;
  • 入参测试:参数q为int和1-D Tensor时输出的正确性;
  • 入参测试:参数axis为int 和1-D Tensor时输出的正确性;
  • 入参测试:keepdim参数的正确性;
  • 入参测试:未输入维度时的输出正确性;
  • 数据类型:输入Tensorxdtypefloat32float64时的结果正确性;
  • 运行设备:在CPU/GPU设备上执行时的结果正确性;
  • 运行模式:动态图、静态图下执行时的结果正确性;
  • 错误检查:q值不在[0,1]时能正确抛出错误;为tensor时维度大于1时正确抛出错误;
  • 错误检查:axis所指维度在当前Tensor中不合法时能正确抛出错误。

七、可行性分析及规划排期

方案主要依赖现有paddle api组合而成,且依赖的paddle.lerp已于前期合入,paddle.take_along_axis将于近期合入。工期上可以满足在当前版本周期内开发完成。

八、影响面

为独立新增API,对其他模块没有影响

名词解释

附件及参考资料