
[DO NOT CLOSE] Library TODOs and call for contributions #116

Open
1 of 5 tasks
matteobettini opened this issue Jun 26, 2024 · 1 comment

@matteobettini
Member

matteobettini commented Jun 26, 2024

Hello people!

In this issue I will list the things I would really like to have in VMAS and will tick them off as they are implemented!

These were previously listed as TODOs in the README.

It is also a really good place to find something you would like to contribute.

Features

Sensors

  • Implement 1D camera sensor
  • Implement 2D bird's-eye-view camera sensor
@matteobettini matteobettini pinned this issue Jun 26, 2024
@19991006

19991006 commented Apr 2, 2025

Hi,

Thanks for your vectorized MARL environment! It helps a lot!

I noticed that you mentioned a 2D bird's-eye-view sensor is needed. I have written a custom grid sensor that may meet your requirements. Here is the code.

Since I'm not a professional developer, the code may need further debugging.

# -*- coding: utf-8 -*-
# @Time : 2025/3/17 10:41
# @Author : Dong Shaoqian
# @File : sensors.py
# @Software: PyCharm
import torch
from torch import Tensor

from vmas.simulator.core import World, Entity, Agent, Landmark
from vmas.simulator.utils import Color, X, Y
from vmas.simulator.sensors import Sensor
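from typing import TYPE_CHECKING, List, Tuple, Union

# Helper from the author's own project (envs/), not part of VMAS;
# used below as an entity's radius
from envs.simulator.utils import entity_size

if TYPE_CHECKING:
    # Needed only for the type hint on render(); importing it at runtime
    # would also load VMAS's pyglet-based rendering module
    from vmas.simulator.rendering import Geom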


def is_ally(agent: Agent, entity: Entity) -> bool:
    # Same team: both adversaries or both non-adversaries
    return isinstance(entity, Agent) and agent.adversary == entity.adversary


def is_enemy(agent: Agent, entity: Entity) -> bool:
    # Opposite teams: exactly one of the two is an adversary
    return isinstance(entity, Agent) and agent.adversary != entity.adversary


def is_obstacle(_, entity: Entity) -> bool:
    # Landmarks are classified by a substring of their name
    return isinstance(entity, Landmark) and 'obstacle' in entity.name


def is_target(_, entity: Entity) -> bool:
    return isinstance(entity, Landmark) and 'target' in entity.name


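def out_of_sensor_range(resolution, coordinate):
    # True if a cell coordinate (measured from the grid centre) falls outside
    # the half-open window [-resolution / 2, resolution / 2).
    # Not used by GridSensor below.
    return coordinate < -resolution / 2 or coordinate >= resolution / 2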


class GridSensor(Sensor):
    def __init__(
            self,
            world: World,
            resolution: int = 28,
            grid_size: float = 0.02,
            render_color: Union[Color, Tuple[float, float, float]] = (0.7, 0.7, 0.7),
            render: bool = False,
    ):
        super().__init__(world=world)
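        # Cell edge length in world units, scaled by the mean world semi-dimension.
        # Assumes a bounded world (x_semidim and y_semidim are not None).
        self._grid_size = grid_size * 0.5 * (world.x_semidim + world.y_semidim)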
        self._world = world
        self._render = render
        self._render_color = render_color
        self._resolution = resolution
        self._n_grids = resolution ** 2

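        # Only entities already added to the world when the sensor is built are
        # checked; entities added afterwards are not validated here.
        self.check_resolution()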

    @property
    def world(self):
        return self._world

    @property
    def grid_size(self):
        return self._grid_size

    @property
    def n_grids(self):
        return self._n_grids

    @property
    def resolution(self):
        return self._resolution

    def check_resolution(self):
        for entity in self.world.entities:
            radius = entity_size(entity)
            # Conservative check: the radius must be at least a full cell diagonal
            assert self.grid_size * 1.414 <= radius, \
                f'{entity.name} is too small to be detected at this grid size'

    def measure_entities(self, entity_filter):
        grids = []
        for entity in self.world.entities:
            if self.agent is entity or not entity_filter(self.agent, entity=entity):
                continue
            rel_pos = entity.state.pos - self.agent.state.pos  # (batch, 2)
            rel_pos_x = rel_pos[:, X].reshape(-1, 1, 1)  # (batch, 1, 1)
            rel_pos_y = rel_pos[:, Y].reshape(-1, 1, 1)  # (batch, 1, 1)

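            # Cell-centre coordinates, symmetric about the agent:
            # -(res/2 - 0.5) * g, ..., +(res/2 - 0.5) * g.
            # linspace guarantees exactly `resolution` values, whereas arange
            # with a float step can gain or lose an element to rounding.
            center_range = self.grid_size * (self.resolution / 2 - 0.5)
            mesh = torch.linspace(
                -center_range,
                center_range,
                self.resolution,
                device=self.world.device
            )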
            x, y = torch.meshgrid(mesh, mesh, indexing='ij')  # (res, res)
            x = x.expand(
                self.world.batch_dim, self.resolution, self.resolution
            )
            y = y.expand(
                self.world.batch_dim, self.resolution, self.resolution
            )  # (batch, res, res)

            dist = torch.sqrt(
                (x - rel_pos_x) ** 2
                + (y - rel_pos_y) ** 2
            )  # (batch, res, res)

            radius = entity_size(entity)
            grid = (dist <= radius).to(torch.int32)  # (batch, res, res)
            grids.append(grid)

        if len(grids) == 0:
            return torch.zeros((self.world.batch_dim, self.resolution, self.resolution),
                               dtype=torch.float32,
                               device=self.world.device)
        else:
            grids = torch.stack(grids, dim=-1)  # (batch, res, res, num_entities)
            grids = grids.sum(-1).to(torch.bool)  # (batch, res, res)

            return grids.to(torch.float32)

    def measure(self) -> Tensor:
        allies = self.measure_entities(is_ally)
        enemies = self.measure_entities(is_enemy)
        obstacles = self.measure_entities(is_obstacle)
        target = self.measure_entities(is_target)

        return torch.stack(
            [
                allies,
                enemies,
                obstacles,
                target
            ],
            dim=1
        )  # (batch, channels=4, res, res)

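    def render(self, env_index: int = 0) -> "List[Geom]":
        # Rendering is not implemented yet; `render` and `render_color` are unused
        return []

    def to(self, device: torch.device):
        # Nothing to move: the sensor keeps no tensors between calls
        pass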
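For a single environment, `measure_entities` comes down to three steps: place `resolution × resolution` cell centres symmetrically around the agent, mark every cell whose centre lies within an entity's radius of that entity's relative position, and OR the per-entity masks together. The pure-Python sketch below reproduces that logic without batching or torch so the geometry can be checked by hand. `cell_centres` and `rasterise` are hypothetical helper names, and each entity is reduced to an assumed `(rel_x, rel_y, radius)` tuple.

```python
import math
from typing import List, Sequence, Tuple


def cell_centres(resolution: int, grid_size: float) -> List[float]:
    # Symmetric centres: -(res/2 - 0.5)*g, ..., +(res/2 - 0.5)*g
    half = grid_size * (resolution / 2 - 0.5)
    return [-half + k * grid_size for k in range(resolution)]


def rasterise(
    entities: Sequence[Tuple[float, float, float]],  # hypothetical (rel_x, rel_y, radius)
    resolution: int,
    grid_size: float,
) -> List[List[bool]]:
    centres = cell_centres(resolution, grid_size)
    # grid[i][j] uses x = centres[i], y = centres[j], like meshgrid(indexing='ij')
    grid = [[False] * resolution for _ in range(resolution)]
    for rel_x, rel_y, radius in entities:
        for i, x in enumerate(centres):
            for j, y in enumerate(centres):
                # A cell is occupied if its centre lies inside the entity's disc
                if math.hypot(x - rel_x, y - rel_y) <= radius:
                    grid[i][j] = True  # logical OR across entities
    return grid


# 4x4 grid with 1.0-unit cells: centres at -1.5, -0.5, 0.5, 1.5
g = rasterise([(1.5, -1.5, 0.75), (0.0, 0.0, 0.75)], resolution=4, grid_size=1.0)
for row in g:
    print("".join("#" if c else "." for c in row))
# ....
# .##.
# .##.
# #...
```

Row index `i` follows the x coordinate and column `j` the y coordinate, matching `torch.meshgrid(..., indexing='ij')` in the sensor. For scale: with the defaults (`resolution=28`, `grid_size=0.02`) and `x_semidim = y_semidim = 1`, the cell edge is 0.02 × 0.5 × (1 + 1) = 0.02 world units, so the window covers 28 × 0.02 = 0.56 units, i.e. ±0.28 around the agent.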

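On the threshold in `check_resolution`: cell centres form a square lattice with spacing g (`self.grid_size`), so no point inside the window is farther than half a cell diagonal, g·√2/2 ≈ 0.707·g, from its nearest centre. Because a cell is marked when `dist <= radius`, an entity whose radius is at least g·√2/2 is guaranteed to mark at least one cell while its centre lies inside the window; the assertion `grid_size * 1.414 <= radius` asks for twice that, so it is a conservative bound. A brute-force check of the half-diagonal figure in plain Python (helper names are hypothetical):

```python
import math


def dist_to_nearest_centre(x: float, y: float, g: float) -> float:
    # Centres lie on the lattice (i*g, j*g); shifting the lattice by half a
    # cell (as the sensor does) does not change these distances
    cx = round(x / g) * g
    cy = round(y / g) * g
    return math.hypot(x - cx, y - cy)


def worst_case_distance(g: float, samples: int = 101) -> float:
    # One lattice period [0, g] x [0, g] covers every relative position
    worst = 0.0
    for i in range(samples):
        for j in range(samples):
            x = g * i / (samples - 1)
            y = g * j / (samples - 1)
            worst = max(worst, dist_to_nearest_centre(x, y, g))
    return worst


g = 0.02
half_diagonal = g * math.sqrt(2) / 2
print(f"worst case: {worst_case_distance(g):.6f}, half diagonal: {half_diagonal:.6f}")
```

If smaller entities need to be detected, relaxing the assertion toward the half-diagonal bound is one option; an entity at that bound may mark only a single cell.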