Skip to content

Commit

Permalink
feat(core): Support List.__setitem__ (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
potatomashed authored Feb 26, 2025
1 parent 4f414eb commit 1513c93
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
3 changes: 2 additions & 1 deletion include/mlc/core/list.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ struct UList : public ObjectRef {
self->erase(i);
return ret;
})
.MemFn("_clear", &UListObj::clear);
.MemFn("_clear", &UListObj::clear)
.MemFn("__setitem__", &::mlc::core::ListBase::Accessor<UListObj>::SetItem);
};

template <typename T> struct ListObj : protected UListObj {
Expand Down
6 changes: 6 additions & 0 deletions include/mlc/core/list_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct ListBase : public MLCList {
template <typename TListObj> struct Accessor {
MLC_INLINE static void New(int32_t num_args, const AnyView *args, Any *ret);
MLC_INLINE static Any At(TListObj *self, int64_t i);
MLC_INLINE static void SetItem(TListObj *self, int64_t i, Any value);
};
};

Expand Down Expand Up @@ -180,6 +181,11 @@ MLC_INLINE Any ListBase::Accessor<TListObj>::At(TListObj *self, int64_t i) {
return self->operator[](i);
}

template <typename TListObj> //
MLC_INLINE void ListBase::Accessor<TListObj>::SetItem(TListObj *self, int64_t i, Any value) {
self->operator[](i) = std::move(value);
}

} // namespace core
} // namespace mlc

Expand Down
8 changes: 8 additions & 0 deletions python/mlc/core/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def __getitem__(self, i: int | slice) -> T | Sequence[T]:
else:
raise TypeError(f"list indices must be integers or slices, not {type(i).__name__}")

def __setitem__(self, index: int, value: T) -> None:
length = len(self)
if not -length <= index < length:
raise IndexError(f"list assignment index out of range: {index}")
if index < 0:
index += length
List._C(b"__setitem__", self, index, value)

def __iter__(self) -> Iterator[T]:
return iter(self[i] for i in range(len(self)))

Expand Down
18 changes: 18 additions & 0 deletions tests/python/test_core_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,21 @@ def test_list_ne() -> None:
b = List([1, 2, 3, 4])
assert a != b
assert b != a


def test_list_setitem() -> None:
a = List([1, 4, 9, 16])
a[1] = 8
assert tuple(a) == (1, 8, 9, 16)
a[-2] = 12
assert tuple(a) == (1, 8, 12, 16)
a[-4] = 2
assert tuple(a) == (2, 8, 12, 16)

with pytest.raises(IndexError) as e:
a[4] = 20
assert str(e.value) == "list assignment index out of range: 4"

with pytest.raises(IndexError) as e:
a[-5] = 20
assert str(e.value) == "list assignment index out of range: -5"

0 comments on commit 1513c93

Please # to comment.