diff --git a/include/mlc/core/list.h b/include/mlc/core/list.h index 88e1152..1b12c75 100644 --- a/include/mlc/core/list.h +++ b/include/mlc/core/list.h @@ -174,7 +174,8 @@ struct UList : public ObjectRef { self->erase(i); return ret; }) - .MemFn("_clear", &UListObj::clear); + .MemFn("_clear", &UListObj::clear) + .MemFn("__setitem__", &::mlc::core::ListBase::Accessor::SetItem); }; template struct ListObj : protected UListObj { diff --git a/include/mlc/core/list_base.h b/include/mlc/core/list_base.h index 76cb25e..a27a3be 100644 --- a/include/mlc/core/list_base.h +++ b/include/mlc/core/list_base.h @@ -44,6 +44,7 @@ struct ListBase : public MLCList { template struct Accessor { MLC_INLINE static void New(int32_t num_args, const AnyView *args, Any *ret); MLC_INLINE static Any At(TListObj *self, int64_t i); + MLC_INLINE static void SetItem(TListObj *self, int64_t i, Any value); }; }; @@ -180,6 +181,11 @@ MLC_INLINE Any ListBase::Accessor::At(TListObj *self, int64_t i) { return self->operator[](i); } +template // +MLC_INLINE void ListBase::Accessor::SetItem(TListObj *self, int64_t i, Any value) { + self->operator[](i) = std::move(value); +} + } // namespace core } // namespace mlc diff --git a/python/mlc/core/list.py b/python/mlc/core/list.py index ee04a64..505878b 100644 --- a/python/mlc/core/list.py +++ b/python/mlc/core/list.py @@ -43,6 +43,14 @@ def __getitem__(self, i: int | slice) -> T | Sequence[T]: else: raise TypeError(f"list indices must be integers or slices, not {type(i).__name__}") + def __setitem__(self, index: int, value: T) -> None: + length = len(self) + if not -length <= index < length: + raise IndexError(f"list assignment index out of range: {index}") + if index < 0: + index += length + List._C(b"__setitem__", self, index, value) + def __iter__(self) -> Iterator[T]: return iter(self[i] for i in range(len(self))) diff --git a/tests/python/test_core_list.py b/tests/python/test_core_list.py index 1bbe6ad..b04d775 100644 --- a/tests/python/test_core_list.py +++ b/tests/python/test_core_list.py @@ -163,3 +163,21 @@ def test_list_ne() -> None: b = List([1, 2, 3, 4]) assert a != b assert b != a + + +def test_list_setitem() -> None: + a = List([1, 4, 9, 16]) + a[1] = 8 + assert tuple(a) == (1, 8, 9, 16) + a[-2] = 12 + assert tuple(a) == (1, 8, 12, 16) + a[-4] = 2 + assert tuple(a) == (2, 8, 12, 16) + + with pytest.raises(IndexError) as e: + a[4] = 20 + assert str(e.value) == "list assignment index out of range: 4" + + with pytest.raises(IndexError) as e: + a[-5] = 20 + assert str(e.value) == "list assignment index out of range: -5"