Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix type hints after adding Branca type checking #2060

Merged
merged 13 commits into from
Dec 29, 2024
2 changes: 1 addition & 1 deletion folium/elements.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to rename these methods, since they really do something different than branca.Element.render. See the discussion I started on the signature redefinition.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand all the typing issues fully, but in general it looks good. I don't mind going ahead with this as is to fix the immediate issues. We can discuss my proposal to rename render from MacroElement and its children at a later moment. Whatever you think is best for now.

Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ class JSCSSMixin(Element):
default_js: List[Tuple[str, str]] = []
default_css: List[Tuple[str, str]] = []

def render(self, **kwargs) -> None:
def render(self, **kwargs):
figure = self.get_root()
assert isinstance(
figure, Figure
78 changes: 50 additions & 28 deletions folium/features.py
Original file line number Diff line number Diff line change
@@ -12,14 +12,24 @@
import numpy as np
import requests
from branca.colormap import ColorMap, LinearColormap, StepColormap
from branca.element import Element, Figure, Html, IFrame, JavascriptLink, MacroElement
from branca.element import (
Div,
Element,
Figure,
Html,
IFrame,
JavascriptLink,
MacroElement,
)
from branca.utilities import color_brewer

from folium.elements import JSCSSMixin
from folium.folium import Map
from folium.map import FeatureGroup, Icon, Layer, Marker, Popup, Tooltip
from folium.template import Template
from folium.utilities import (
TypeBoundsReturn,
TypeContainer,
TypeJsonValue,
TypeLine,
TypePathOptions,
@@ -165,7 +175,7 @@ def __init__(
self.top = _parse_size(top)
self.position = position

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
super().render(**kwargs)

@@ -284,9 +294,15 @@ def __init__(
self.top = _parse_size(top)
self.position = position

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self._parent.html.add_child(
parent = self._parent
if not isinstance(parent, (Figure, Div, Popup)):
raise TypeError(
"VegaLite elements can only be added to a Figure, Div, or Popup"
)

parent.html.add_child(
Element(
Template(
"""
@@ -331,7 +347,7 @@ def render(self, **kwargs) -> None:
embed_vegalite = embed_mapping.get(
self.vegalite_major_version, self._embed_vegalite_v2
)
embed_vegalite(figure)
embed_vegalite(figure=figure, parent=parent)

@property
def vegalite_major_version(self) -> Optional[int]:
@@ -342,8 +358,8 @@ def vegalite_major_version(self) -> Optional[int]:

return int(schema.split("/")[-1].split(".")[0].lstrip("v"))

def _embed_vegalite_v5(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v5(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
@@ -356,8 +372,8 @@ def _embed_vegalite_v5(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v4(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v4(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
@@ -370,8 +386,8 @@ def _embed_vegalite_v4(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v3(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v3(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@4"), name="vega"
@@ -384,8 +400,8 @@ def _embed_vegalite_v3(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v2(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v2(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@3"), name="vega"
@@ -398,8 +414,8 @@ def _embed_vegalite_v2(self, figure: Figure) -> None:
name="vega-embed",
)

def _vega_embed(self) -> None:
self._parent.script.add_child(
def _vega_embed(self, parent: TypeContainer) -> None:
parent.script.add_child(
Element(
Template(
"""
@@ -412,8 +428,8 @@ def _vega_embed(self) -> None:
name=self.get_name(),
)

def _embed_vegalite_v1(self, figure: Figure) -> None:
self._parent.script.add_child(
def _embed_vegalite_v1(self, figure: Figure, parent: TypeContainer) -> None:
parent.script.add_child(
Element(
Template(
"""
@@ -436,19 +452,19 @@ def _embed_vegalite_v1(self, figure: Figure) -> None:
figure.header.add_child(
JavascriptLink("https://cdnjs.cloudflare.com/ajax/libs/vega/2.6.5/vega.js"),
name="vega",
) # noqa
)
figure.header.add_child(
JavascriptLink(
"https://cdnjs.cloudflare.com/ajax/libs/vega-lite/1.3.1/vega-lite.js"
),
name="vega-lite",
) # noqa
)
figure.header.add_child(
JavascriptLink(
"https://cdnjs.cloudflare.com/ajax/libs/vega-embed/2.2.0/vega-embed.js"
),
name="vega-embed",
) # noqa
)


class GeoJson(Layer):
@@ -820,7 +836,7 @@ def _get_self_bounds(self) -> List[List[Optional[float]]]:
"""
return get_bounds(self.data, lonlat=True)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
self.parent_map = get_obj_in_upper_tree(self, Map)
# Need at least one feature, otherwise style mapping fails
if (self.style or self.highlight) and self.data["features"]:
@@ -1041,12 +1057,12 @@ def recursive_get(data, keys):
self.style_function(feature)
) # noqa

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self.style_data()
super().render(**kwargs)

def get_bounds(self) -> List[List[float]]:
def get_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]
@@ -1146,6 +1162,7 @@ def __init__(

def warn_for_geometry_collections(self) -> None:
"""Checks for GeoJson GeometryCollection features to warn user about incompatibility."""
assert isinstance(self._parent, GeoJson)
geom_collections = [
feature.get("properties") if feature.get("properties") is not None else key
for key, feature in enumerate(self._parent.data["features"])
@@ -1160,7 +1177,7 @@ def warn_for_geometry_collections(self) -> None:
UserWarning,
)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
figure = self.get_root()
if isinstance(self._parent, GeoJson):
@@ -1565,7 +1582,7 @@ def __init__(
color_range = color_brewer(fill_color, n=nb_bins)
self.color_scale = StepColormap(
color_range,
index=bin_edges,
index=list(bin_edges),
vmin=bins_min,
vmax=bins_max,
caption=legend_name,
@@ -1625,7 +1642,7 @@ def highlight_function(x):
return {"weight": line_weight + 2, "fillOpacity": fill_opacity + 0.2}

if topojson:
self.geojson = TopoJson(
self.geojson: Union[TopoJson, GeoJson] = TopoJson(
geo_data,
topojson,
style_function=style_function,
@@ -1657,7 +1674,7 @@ def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None
else:
return value

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Render the GeoJson/TopoJson and color scale objects."""
if self.color_scale:
# ColorMap needs Map as its parent
@@ -1963,8 +1980,13 @@ def __init__(
vmin=min(colors),
vmax=max(colors),
).to_step(nb_steps)
else:
elif isinstance(colormap, StepColormap):
cm = colormap
else:
raise TypeError(
f"Unexpected type for argument `colormap`: {type(colormap)}"
)

out: Dict[str, List[List[List[float]]]] = {}
for (lat1, lng1), (lat2, lng2), color in zip(coords[:-1], coords[1:], colors):
out.setdefault(cm(color), []).append([[lat1, lng1], [lat2, lng2]])
2 changes: 1 addition & 1 deletion folium/folium.py
Original file line number Diff line number Diff line change
@@ -377,7 +377,7 @@ def _repr_png_(self) -> Optional[bytes]:
return None
return self._to_png()

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
figure = self.get_root()
assert isinstance(
13 changes: 7 additions & 6 deletions folium/map.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@

import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Optional, Sequence, Union, cast

from branca.element import Element, Figure, Html, MacroElement

@@ -14,6 +14,7 @@
from folium.utilities import (
JsCode,
TypeBounds,
TypeBoundsReturn,
TypeJsonValue,
escape_backticks,
parse_options,
@@ -221,7 +222,7 @@ def reset(self) -> None:
self.base_layers = OrderedDict()
self.overlays = OrderedDict()

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self.reset()
for item in self._parent._children.values():
@@ -396,15 +397,15 @@ def __init__(
tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip))
)

def _get_self_bounds(self) -> List[List[float]]:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""Computes the bounds of the object itself.

Because a marker has only single coordinates, we repeat them.
"""
assert self.location is not None
return [self.location, self.location]
return cast(TypeBoundsReturn, [self.location, self.location])

def render(self) -> None:
def render(self):
if self.location is None:
raise ValueError(
f"{self._name} location must be assigned when added directly to map."
@@ -492,7 +493,7 @@ def __init__(
**kwargs,
)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
for name, child in self._children.items():
child.render(**kwargs)
2 changes: 1 addition & 1 deletion folium/plugins/overlapping_marker_spiderfier.py
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@ def add_to(
) -> Element:
self._parent = parent
self.markers = self._get_all_markers(parent)
super().add_to(parent, name=name, index=index)
return super().add_to(parent, name=name, index=index)

def _get_all_markers(self, element: Element) -> list:
markers = []
16 changes: 9 additions & 7 deletions folium/raster_layers.py
Original file line number Diff line number Diff line change
@@ -12,9 +12,11 @@
from folium.template import Template
from folium.utilities import (
TypeBounds,
TypeBoundsReturn,
TypeJsonValue,
image_to_url,
mercator_transform,
normalize_bounds_type,
parse_options,
remove_empty,
)
@@ -246,7 +248,7 @@ class ImageOverlay(Layer):
* If string, it will be written directly in the output file.
* If file, it's content will be converted as embedded in the output file.
* If array-like, it will be converted to PNG base64 string and embedded in the output.
bounds: list
bounds: list/tuple of list/tuple of float
Image bounds on the map in the form
[[lat_min, lon_min], [lat_max, lon_max]]
opacity: float, default Leaflet's default (1.0)
@@ -319,7 +321,7 @@ def __init__(

self.url = image_to_url(image, origin=origin, colormap=colormap)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
super().render()

figure = self.get_root()
@@ -344,13 +346,13 @@ def render(self, **kwargs) -> None:
Element(pixelated), name="leaflet-image-layer"
) # noqa

def _get_self_bounds(self) -> TypeBounds:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]].

"""
return self.bounds
return normalize_bounds_type(self.bounds)


class VideoOverlay(Layer):
@@ -361,7 +363,7 @@ class VideoOverlay(Layer):
----------
video_url: str
URL of the video
bounds: list
bounds: list/tuple of list/tuple of float
Video bounds on the map in the form
[[lat_min, lon_min], [lat_max, lon_max]]
autoplay: bool, default True
@@ -411,10 +413,10 @@ def __init__(
self.bounds = bounds
self.options = remove_empty(autoplay=autoplay, loop=loop, **kwargs)

def _get_self_bounds(self) -> TypeBounds:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]

"""
return self.bounds
return normalize_bounds_type(self.bounds)
Loading
Loading