Efficient element-wise comparisons #2529

Open
joshhansen opened this issue Nov 24, 2024 · 1 comment
Comments

@joshhansen

Feature description

Allow element-wise comparisons (eq, neq, gt, lt, gte, lte) where the value compared against need not be instantiated as a tensor of the same shape.

Feature motivation

Please correct me if I am missing something, but I have been unable to find a way to efficiently perform element-wise comparisons. For example, I want to check that all values are in the range [0, 1]. The only way I see to do this is using Tensor::greater and friends, which require matching-shape tensors as input. That means allocating a full tensor of zeros, or ones, or whatever, just to compare against. For my use case this is prohibitive.

(Optional) Suggest a Solution

Option 1: generalized Rust comparison traits

The dream would be a Pythonic use of operators, as in t1 < t2. Unfortunately, Rust's std::cmp traits do not allow the output type of the comparison to be specified (the comparison methods always return bool), so this is likely not soon forthcoming.
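To make the limitation concrete, here is a minimal sketch using only the standard library. `Vec1` is a hypothetical stand-in for a rank-1 tensor, not a Burn type. Arithmetic operators such as `Add` carry an associated `Output` type, so `a + b` can return a new container; `PartialOrd::lt` has no such associated type, so an element-wise comparison has to be a named method.

```rust
use std::ops::Add;

/// Hypothetical stand-in for a rank-1 tensor; not part of any real library.
#[derive(Debug, Clone, PartialEq)]
struct Vec1(Vec<f32>);

// `Add` has an associated `Output` type, so `a + b` may return another `Vec1`.
// This sketch assumes both operands have the same length.
impl Add for Vec1 {
    type Output = Vec1;

    fn add(self, rhs: Vec1) -> Vec1 {
        Vec1(self.0.iter().zip(rhs.0.iter()).map(|(a, b)| a + b).collect())
    }
}

impl Vec1 {
    // `PartialOrd::lt` is fixed to return `bool`, so `a < b` cannot yield a mask.
    // The element-wise version is therefore exposed as a named method.
    fn lower(&self, rhs: &Vec1) -> Vec<bool> {
        self.0.iter().zip(rhs.0.iter()).map(|(a, b)| a < b).collect()
    }
}
```

This is why tensor libraries in Rust tend to expose comparisons as methods while still overloading arithmetic operators.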

Option 2: infer D on element-wise operations, then broadcast

One approach would be to let the D dimensionality of the comparison tensor be inferred, and specialize the implementation using standardized broadcast rules. For example:

pub fn greater<const D2: usize>(self, other: Tensor<B, D2, K>) -> Tensor<B, D, Bool>
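As a rough illustration of what "standardized broadcast rules" could mean here, the sketch below computes the result shape of two operands under NumPy-style rules. That choice of rules is my assumption, and `broadcast_shape` is a hypothetical helper, not a Burn function.

```rust
/// Computes the broadcast shape of two shapes under NumPy-style rules
/// (an assumption; the proposal only says "standardized broadcast rules").
/// Returns `None` if the shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Walk both shapes from the trailing dimension;
        // a missing leading dimension counts as size 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}
```

With rules like these, a `greater` that accepts a lower-rank `other` would only need to validate the shapes up front and could reject mismatches before launching any kernel.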

Option 3: single-element comparisons only

A perhaps less disruptive stopgap would be to implement only the most common case, that of comparing to a single element, which can always be broadcast to a non-empty shape:

impl<B: Backend, const D: usize, K: TensorKind<B>> Tensor<B, D, K> {
  fn eq1(self, other: Tensor<B, 1, K>) -> Tensor<B, D, Bool> { ... }
  fn neq1(self, other: Tensor<B, 1, K>) -> Tensor<B, D, Bool> { ... }
  fn greater1(self, other: Tensor<B, 1, K>) -> Tensor<B, D, Bool> { ... }
  fn lower1(self, other: Tensor<B, 1, K>) -> Tensor<B, D, Bool> { ... }
  fn greater_equal1(self, other: Tensor<B, 1, K>) -> Tensor<B, D, Bool> { ... }
  fn lower_equal1(self, other: Tensor<B, 1, K>) -> Tensor<B, D, Bool> { ... }
}
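The point of the single-element case is that the right-hand side never has to be materialized at the full shape. A minimal sketch over a flat `f32` buffer, with `greater_scalar` as a hypothetical name used only for illustration:

```rust
/// Compares every element of a flat buffer against one scalar.
/// `greater_scalar` is a hypothetical name used only for this sketch.
fn greater_scalar(data: &[f32], rhs: f32) -> Vec<bool> {
    // The same scalar is reused for each element,
    // so no same-shape operand is ever allocated.
    data.iter().map(|&x| x > rhs).collect()
}
```

Only the output mask is allocated; the memory cost of the comparison operand drops from one value per element to a single value.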

Option 4: improve optimization

On the other hand, actually instantiating the full comparison array only to call any() on it could be seen as a failure of optimization. Consider my assert_in_zero_one:

fn assert_in_zero_one<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B> + Numeric<B>>(
    x: Tensor<B, D, K>,
) {
    {
        let zero = Tensor::zeros_like(&x);
        let negative = x.clone().lower(zero);
        assert!(!negative.any().into_scalar());
    }

    let one = Tensor::ones_like(&x);
    let over_one = x.greater(one);
    assert!(!over_one.any().into_scalar());
}

In both cases, the array is instantiated only to be compared and then reduced with any(). This implies that a much more efficient algorithm exists, but one which the Fusion and Jit components have not arrived at.
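For contrast, here is a plain-Rust sketch of the two strategies on a flat `f32` slice. Both function names are hypothetical, and this says nothing about how Fusion or Jit would actually lower the operations; it only shows the shape of the fused algorithm I have in mind.

```rust
/// Unfused version: builds two full boolean masks, then reduces each one.
/// This mirrors the structure of `assert_in_zero_one` above.
fn in_unit_interval_masked(data: &[f32]) -> bool {
    let negative: Vec<bool> = data.iter().map(|&x| x < 0.0).collect();
    let over_one: Vec<bool> = data.iter().map(|&x| x > 1.0).collect();
    !negative.iter().any(|&m| m) && !over_one.iter().any(|&m| m)
}

/// Fused version: one pass over the data, no intermediate buffers,
/// and the scan stops at the first out-of-range value.
/// Note: NaN is rejected here, whereas the masked version accepts it.
fn in_unit_interval_fused(data: &[f32]) -> bool {
    data.iter().all(|&x| (0.0..=1.0).contains(&x))
}
```

The fused form reads each element once and allocates nothing, which is the behavior I would hope an optimizer could reach from the higher-level formulation.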

Automatically handling such cases would be wonderful, allowing the API to be used at a higher level of abstraction. Of course, that is also the downside.

@nathanielsimard
Member

We have scalar comparison:

let tensor = Tensor::random(..);
let mask = tensor.greater_elem(0.5);

We also support broadcasting, so:

let tensor = Tensor::random(..); // rank 6
let bools = Tensor::from_bool(..); // rank 1
let mask = tensor.greater(bools.unsqueeze());
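To illustrate what the broadcast comparison computes, here is a std-only sketch for the rank-2 case. It assumes the unsqueeze step gives the rank-1 operand a leading dimension of size 1, so a `[cols]` operand is treated as `[1, cols]` and reused for every row. `greater_broadcast_rows` is a hypothetical helper, not a Burn API.

```rust
/// Element-wise `data > rhs`, where `data` is row-major with shape [rows, cols]
/// and `rhs` has shape [cols], treated as [1, cols] and reused for each row.
/// Hypothetical helper for illustration only.
fn greater_broadcast_rows(data: &[f32], cols: usize, rhs: &[f32]) -> Vec<bool> {
    assert_eq!(rhs.len(), cols, "rhs must have one value per column");
    assert!(cols > 0 && data.len() % cols == 0, "data must hold whole rows");
    data.iter()
        .enumerate()
        // `i % cols` maps a flat index back to its column.
        .map(|(i, &x)| x > rhs[i % cols])
        .collect()
}
```

The smaller operand is indexed rather than copied, so the comparison needs no expanded tensor on the right-hand side.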
