From f6e1a219861db8009ab1af76412dc9328a4dab9a Mon Sep 17 00:00:00 2001 From: beauxq Date: Thu, 28 Apr 2022 09:03:44 -0700 Subject: [PATCH] typing, mostly in AutoWorld.py includes a bugfix (that was found by static type checking) in `get_filler_item_name` --- BaseClasses.py | 2 +- Utils.py | 22 +++++++++------ docs/api.md | 10 +++---- worlds/AutoWorld.py | 69 +++++++++++++++++++++++++-------------------- 4 files changed, 57 insertions(+), 46 deletions(-) diff --git a/BaseClasses.py b/BaseClasses.py index e5d92b95b68f..e3b8aaf7d107 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -36,7 +36,7 @@ class MultiWorld(): plando_texts: List[Dict[str, str]] plando_items: List[List[Dict[str, Any]]] plando_connections: List - worlds: Dict[int, Any] + worlds: Dict[int, auto_world] groups: Dict[int, Group] itempool: List[Item] is_race: bool = False diff --git a/Utils.py b/Utils.py index 2068d48baaaa..a764dea95dbf 100644 --- a/Utils.py +++ b/Utils.py @@ -36,41 +36,45 @@ class Version(typing.NamedTuple): from yaml import Loader -def int16_as_bytes(value): +def int16_as_bytes(value: int) -> typing.List[int]: value = value & 0xFFFF return [value & 0xFF, (value >> 8) & 0xFF] -def int32_as_bytes(value): +def int32_as_bytes(value: int) -> typing.List[int]: value = value & 0xFFFFFFFF return [value & 0xFF, (value >> 8) & 0xFF, (value >> 16) & 0xFF, (value >> 24) & 0xFF] -def pc_to_snes(value): +def pc_to_snes(value: int) -> int: return ((value << 1) & 0x7F0000) | (value & 0x7FFF) | 0x8000 -def snes_to_pc(value): +def snes_to_pc(value: int) -> int: return ((value & 0x7F0000) >> 1) | (value & 0x7FFF) -def cache_argsless(function): +RetType = typing.TypeVar("RetType") + + +def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]: if function.__code__.co_argcount: raise Exception("Can only cache 0 argument functions with this cache.") - result = sentinel = object() + sentinel = object() + result: typing.Union[object, RetType] = sentinel - def _wrap(): + def _wrap() -> RetType: nonlocal result if result is sentinel: result = function() - return result + return typing.cast(RetType, result) return _wrap def is_frozen() -> bool: - return getattr(sys, 'frozen', False) + return typing.cast(bool, getattr(sys, 'frozen', False)) def local_path(*path: str) -> str: diff --git a/docs/api.md b/docs/api.md index cb1b1e4bddc1..f81c29b3db4f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -430,7 +430,7 @@ In addition, the following methods can be implemented and attributes can be set #### generate_early ```python -def generate_early(self): +def generate_early(self) -> None: # read player settings to world instance self.final_boss_hp = self.world.final_boss_hp[self.player].value ``` @@ -456,7 +456,7 @@ def create_event(self, event: str): #### create_items ```python -def create_items(self): +def create_items(self) -> None: # Add items to the Multiworld. # If there are two of the same item, the item has to be twice in the pool. # Which items are added to the pool may depend on player settings, @@ -483,7 +483,7 @@ def create_items(self): #### create_regions ```python -def create_regions(self): +def create_regions(self) -> None: # Add regions to the multiworld. "Menu" is the required starting point. # Arguments to Region() are name, type, human_readable_name, player, world r = Region("Menu", None, "Menu", self.player, self.world) @@ -518,7 +518,7 @@ def create_regions(self): #### generate_basic ```python -def generate_basic(self): +def generate_basic(self) -> None: # place "Victory" at "Final Boss" and set collection as win condition self.world.get_location("Final Boss", self.player)\ .place_locked_item(self.create_event("Victory")) @@ -539,7 +539,7 @@ def generate_basic(self): from ..generic.Rules import add_rule, set_rule, forbid_item from Items import get_item_type -def set_rules(self): +def set_rules(self) -> None: # For some worlds this step can be omitted if either a Logic mixin # (see below) is used, it's easier to apply the rules from data during # location generation or everything is in generate_basic diff --git a/worlds/AutoWorld.py b/worlds/AutoWorld.py index b2df8486d09d..ee7f618d7a41 100644 --- a/worlds/AutoWorld.py +++ b/worlds/AutoWorld.py @@ -1,16 +1,16 @@ from __future__ import annotations import logging -from typing import Dict, Set, Tuple, List, Optional, TextIO, Any, Callable, Union +from typing import Dict, FrozenSet, Set, Tuple, List, Optional, TextIO, Any, Callable, Union from BaseClasses import MultiWorld, Item, CollectionState, Location from Options import Option class AutoWorldRegister(type): - world_types: Dict[str, World] = {} + world_types: Dict[str, AutoWorldRegister] = {} - def __new__(cls, name: str, bases, dct: Dict[str, Any]): + def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoWorldRegister: if "web" in dct: assert isinstance(dct["web"], WebWorld), "WebWorld has to be instantiated." # filter out any events @@ -34,7 +34,8 @@ def __new__(cls, name: str, bases, dct: Dict[str, Any]): if "required_client_version" in dct and bases: for base in bases: if "required_client_version" in base.__dict__: - dct["required_client_version"] = max(dct["required_client_version"], base.required_client_version) + dct["required_client_version"] = max(dct["required_client_version"], + base.__dict__["required_client_version"]) # construct class new_class = super().__new__(cls, name, bases, dct) @@ -44,9 +45,9 @@ def __new__(cls, name: str, bases, dct: Dict[str, Any]): class AutoLogicRegister(type): - def __new__(cls, name, bases, dct): + def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoLogicRegister: new_class = super().__new__(cls, name, bases, dct) - function: Callable + function: Callable[..., Any] for item_name, function in dct.items(): if item_name == "copy_mixin": CollectionState.additional_copy_functions.append(function) @@ -59,13 +60,13 @@ def __new__(cls, name, bases, dct): return new_class -def call_single(world: MultiWorld, method_name: str, player: int, *args): +def call_single(world: MultiWorld, method_name: str, player: int, *args: Any) -> Any: method = getattr(world.worlds[player], method_name) return method(*args) -def call_all(world: MultiWorld, method_name: str, *args): - world_types = set() +def call_all(world: MultiWorld, method_name: str, *args: Any) -> None: + world_types: Set[AutoWorldRegister] = set() for player in world.player_ids: world_types.add(world.worlds[player].__class__) call_single(world, method_name, player, *args) @@ -76,7 +77,7 @@ def call_all(world: MultiWorld, method_name: str, *args): stage_callable(world, *args) -def call_stage(world: MultiWorld, method_name: str, *args): +def call_stage(world: MultiWorld, method_name: str, *args: Any) -> None: world_types = {world.worlds[player].__class__ for player in world.player_ids} for world_type in world_types: stage_callable = getattr(world_type, f"stage_{method_name}", None) @@ -101,10 +102,12 @@ class World(metaclass=AutoWorldRegister): """A World object encompasses a game's Items, Locations, Rules and additional data or functionality required. A Game should have its own subclass of World in which it defines the required data structures.""" - options: Dict[str, type(Option)] = {} # link your Options mapping + options: Dict[str, Option[Any]] = {} # link your Options mapping game: str # name the game topology_present: bool = False # indicate if world type has any meaningful layout/pathing - all_item_and_group_names: Set[str] = frozenset() # gets automatically populated with all item and item group names + + # gets automatically populated with all item and item group names + all_item_and_group_names: FrozenSet[str] = frozenset() # map names to their IDs item_name_to_id: Dict[str, int] = {} @@ -126,7 +129,7 @@ class World(metaclass=AutoWorldRegister): # update this if the resulting multidata breaks forward-compatibility of the server required_server_version: Tuple[int, int, int] = (0, 2, 4) - hint_blacklist: Set[str] = frozenset() # any names that should not be hintable + hint_blacklist: FrozenSet[str] = frozenset() # any names that should not be hintable # NOTE: remote_items and remote_start_inventory are now available in the network protocol for the client to set. # These values will be removed. @@ -168,61 +171,65 @@ def __init__(self, world: MultiWorld, player: int): # can also be implemented as a classmethod and called "stage_", # in that case the MultiWorld object is passed as an argument and it gets called once for the entire multiworld. # An example of this can be found in alttp as stage_pre_fill - def generate_early(self): + def generate_early(self) -> None: pass - def create_regions(self): + def create_regions(self) -> None: pass - def create_items(self): + def create_items(self) -> None: pass - def set_rules(self): + def set_rules(self) -> None: pass - def generate_basic(self): + def generate_basic(self) -> None: pass - def pre_fill(self): + def pre_fill(self) -> None: """Optional method that is supposed to be used for special fill stages. This is run *after* plando.""" pass @classmethod - def fill_hook(cls, progitempool: List[Item], nonexcludeditempool: List[Item], - localrestitempool: Dict[int, List[Item]], nonlocalrestitempool: Dict[int, List[Item]], - restitempool: List[Item], fill_locations: List[Location]): + def fill_hook(cls, + progitempool: List[Item], + nonexcludeditempool: List[Item], + localrestitempool: Dict[int, List[Item]], + nonlocalrestitempool: Dict[int, List[Item]], + restitempool: List[Item], + fill_locations: List[Location]) -> None: """Special method that gets called as part of distribute_items_restrictive (main fill). This gets called once per present world type.""" pass - def post_fill(self): + def post_fill(self) -> None: """Optional Method that is called after regular fill. Can be used to do adjustments before output generation.""" - def generate_output(self, output_directory: str): + def generate_output(self, output_directory: str) -> None: """This method gets called from a threadpool, do not use world.random here. If you need any last-second randomization, use MultiWorld.slot_seeds[slot] instead.""" pass - def fill_slot_data(self) -> dict: + def fill_slot_data(self) -> Dict[str, Any]: # json of WebHostLib.models.Slot """Fill in the slot_data field in the Connected network package.""" return {} - def modify_multidata(self, multidata: dict): + def modify_multidata(self, multidata: Dict[str, Any]) -> None: # TODO: TypedDict for multidata? """For deeper modification of server multidata.""" pass # Spoiler writing is optional, these may not get called. - def write_spoiler_header(self, spoiler_handle: TextIO): + def write_spoiler_header(self, spoiler_handle: TextIO) -> None: """Write to the spoiler header. If individual it's right at the end of that player's options, if as stage it's right under the common header before per-player options.""" pass - def write_spoiler(self, spoiler_handle: TextIO): + def write_spoiler(self, spoiler_handle: TextIO) -> None: """Write to the spoiler "middle", this is after the per-player options and before locations, meant for useful or interesting info.""" pass - def write_spoiler_end(self, spoiler_handle: TextIO): + def write_spoiler_end(self, spoiler_handle: TextIO) -> None: """Write to the end of the spoiler""" pass @@ -236,7 +243,7 @@ def create_item(self, name: str) -> Item: def get_filler_item_name(self) -> str: """Called when the item pool needs to be filled with additional items to match location count.""" logging.warning(f"World {self} is generating a filler item without custom filler pool.") - return self.world.random.choice(self.item_name_to_id) + return self.world.random.choice(tuple(self.item_name_to_id.keys())) # decent place to implement progressive items, in most cases can stay as-is def collect_item(self, state: CollectionState, item: Item, remove: bool = False) -> Optional[str]: @@ -247,6 +254,7 @@ def collect_item(self, state: CollectionState, item: Item, remove: bool = False) :param remove: indicate if this is meant to remove from state instead of adding.""" if item.advancement: return item.name + return None # called to create all_state, return Items that are created during pre_fill def get_pre_fill_items(self) -> List[Item]: @@ -277,4 +285,3 @@ def create_filler(self) -> Item: # please use a prefix as all of them get clobbered together class LogicMixin(metaclass=AutoLogicRegister): pass -