From 36412dc8f0fd061798a701ca632766dbe6f069c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Nussbaumer?= <rafnuss@gmail.com>
Date: Sun, 14 Jan 2024 19:46:27 +0100
Subject: [PATCH] Move computation of wind to edge_add_wind

---
 R/edge_add_wind.R  | 316 ++++++++++++++++++++++++-------------------
 R/graph_add_wind.R | 324 +++------------------------------------------
 2 files changed, 200 insertions(+), 440 deletions(-)

diff --git a/R/edge_add_wind.R b/R/edge_add_wind.R
index dc38ae1b..9e462990 100644
--- a/R/edge_add_wind.R
+++ b/R/edge_add_wind.R
@@ -1,43 +1,53 @@
 #' Retrieve ERA5 variable along edge
 #'
 #' @description
-#' Reads the NetCDF files downloaded and interpolate the average windspeed experienced by the
-#' bird on each possible edge, as well as the corresponding airspeed.
+#' Reads the NetCDF files and extract the variable requested along each flight defined by the edges.
+#'
+#' - Time: linear interpolation using the resolution requested with `rounding_interval`
+#' - Space: nearet neighbor interpolation by default or bi-linear with `pracma::interp2` if
+#' `interp_spatial_linear=TRUE`
+#' - Pressure/altitude: linear interpolation using the exact `pressure` values
 #'
 #' @param tag_graph either a `tag` or a `graph` GeoPressureR object.
-#' @param edge_s a matrix of the index of the source node of the edge
-#' @param edge_t a matrix of the index of the target node of the edge
+#' @param edge_s a index of the source node of the edge. Either a vector with 3D index or a matrix
+#' of 3 columns, one for each dimension.
+#' @param edge_t a index of the target node of the edge. Either a vector with 3D index or a matrix
+#' of 3 columns, one for each dimension.
 #' @param pressure pressure measurement of the associated `tag` data used to estimate the pressure
 #' level (i.e., altitude) of the bird during the flights. This data.frame needs to contain `date` as
 #' POSIXt and `value` in hPa.
-#' @param variable variable the the
-#' @param return_average_var logical
-#' @param interp_spatial_linear logical
+#' @inheritParams tag_download_wind
+#' @param return_average_var logical to return the variable for each timestep or average for the
+#' entire flight.
+#' @param interp_spatial_linear logical to interpolate the variable linearly over space, if `FALSE`
+#' takes the nearest neighbour. ERA5 native resolution is 0.25°
+#' @param rounding_interval temporal resolution on which to query the variable (min). Default is to
+#' macth ERA5 native resolution (1hr).
 #' @param quiet logical to hide messages about the progress
 #'
-#' @return a `graph` object with windspeed and airspeed as `ws` and `as` respectively.
+#' @return ...
 #'
-#' @family graph
-#' @family movement
-#' @references{ Nussbaumer, Raphaël, Mathieu Gravey, Martins Briedis, Felix Liechti, and Daniel
-#' Sheldon. 2023. Reconstructing bird trajectories from pressure and wind data using a highly
-#' optimized hidden Markov model. *Methods in Ecology and Evolution*, 14, 1118–1129
-#' <https://doi.org/10.1111/2041-210X.14082>.}
 #' @seealso [GeoPressureManual](
 #' https://raphaelnussbaumer.com/GeoPressureManual/trajectory-with-wind.html)
 #' @export
-edge_add_wind <- function(tag_graph,
-                     edge_s,
-                     edge_t,
-                     pressure = NULL,
-                     variable = c("u", "v"),
-                     rounding_interval = 60,
-                     interp_spatial_linear = FALSE,
-                     return_average_variable = FALSE,
-                     file = \(stap_id) glue::glue("./data/wind/{tag_graph$param$id}/{tag_graph$param$id}_{stap_id}.nc"),
-                     quiet = FALSE) {
+edge_add_wind <- function(
+    tag_graph,
+    edge_s,
+    edge_t,
+    pressure = NULL,
+    variable = c("u", "v"),
+    rounding_interval = 60,
+    interp_spatial_linear = FALSE,
+    return_average_variable = FALSE,
+    file = \(stap_id) {
+      glue::glue("./data/wind/{tag_graph$param$id}/{tag_graph$param$id}_{stap_id}.nc")
+    },
+    quiet = FALSE) {
+  if (is.null(pressure) && inherits(tag_graph, "tag")) {
+    pressure <- tag_graph$pressure
+  }
 
-  assertthat::assert_that(inherits(tag_graph, "tag") | inherits(tag_graph, "graph"))
+  edge_add_wind_check(tag_graph, pressure, file)
 
   # Compute lat-lon coordinate of the grid
   g <- map_expand(tag_graph$param$extent, tag_graph$param$scale)
@@ -45,41 +55,37 @@ edge_add_wind <- function(tag_graph,
   # Compute flight from stap
   flight <- stap2flight(tag_graph$stap, format = "list")
 
-  # Check pressure
-  if (is.null(pressure) & inherits(tag_graph, "tag")){
-    pressure <- tag_graph$pressure
+  # Check edges
+  if (!is.matrix(edge_s)) {
+    edge_s <- arrayInd(edge_s, c(g$dim, nrow(tag_graph$stap)))
+  }
+  if (!is.matrix(edge_t)) {
+    edge_t <- arrayInd(edge_t, c(g$dim, nrow(tag_graph$stap)))
   }
-  assertthat::assert_that(is.data.frame(pressure))
-  assertthat::assert_that(assertthat::has_name(pressure, c("date", "value")))
-  assertthat::assert_that(assertthat::is.time(pressure$date))
-  assertthat::assert_that(is.numeric(pressure$value))
-
-  # Check file
-  assertthat::assert_that(is.function(file))
-  edge_add_wind_check(file, flight, pressure, g)
 
-  # Check edges
   assertthat::assert_that(assertthat::are_equal(dim(edge_s), dim(edge_t)))
-  assertthat::assert_that(assertthat::are_equal(edge_t[,1], as.integer(edge_t[,1])))
-  assertthat::assert_that(assertthat::are_equal(edge_t[,2], as.integer(edge_t[,2])))
-  assertthat::assert_that(assertthat::are_equal(edge_t[,3], as.integer(edge_t[,3])))
-  assertthat::assert_that(assertthat::are_equal(edge_s[,1], as.integer(edge_s[,1])))
-  assertthat::assert_that(assertthat::are_equal(edge_s[,2], as.integer(edge_s[,2])))
-  assertthat::assert_that(assertthat::are_equal(edge_s[,3], as.integer(edge_s[,3])))
-  assertthat::assert_that(all(edge_t[,3]>1 & edge_t[,3]<=max(tag_graph$stap$stap_id)))
-  assertthat::assert_that(all(edge_s[,3]>=1 & edge_s[,3]<max(tag_graph$stap$stap_id)))
-  assertthat::assert_that(all(edge_t[,1]>=1 & edge_t[,1]<=g$dim[1]))
-  assertthat::assert_that(all(edge_t[,2]>=1 & edge_t[,2]<=g$dim[2]))
-  assertthat::assert_that(all(edge_s[,1]>=1 & edge_s[,1]<=g$dim[1]))
-  assertthat::assert_that(all(edge_s[,2]>=1 & edge_s[,2]<=g$dim[2]))
+  assertthat::assert_that(assertthat::are_equal(dim(edge_s)[2], 3))
+  assertthat::assert_that(assertthat::are_equal(dim(edge_t)[2], 3))
+  assertthat::assert_that(assertthat::are_equal(edge_t[, 1], as.integer(edge_t[, 1])))
+  assertthat::assert_that(assertthat::are_equal(edge_t[, 2], as.integer(edge_t[, 2])))
+  assertthat::assert_that(assertthat::are_equal(edge_t[, 3], as.integer(edge_t[, 3])))
+  assertthat::assert_that(assertthat::are_equal(edge_s[, 1], as.integer(edge_s[, 1])))
+  assertthat::assert_that(assertthat::are_equal(edge_s[, 2], as.integer(edge_s[, 2])))
+  assertthat::assert_that(assertthat::are_equal(edge_s[, 3], as.integer(edge_s[, 3])))
+  assertthat::assert_that(all(edge_t[, 3] > 1 & edge_t[, 3] <= max(tag_graph$stap$stap_id)))
+  assertthat::assert_that(all(edge_s[, 3] >= 1 & edge_s[, 3] < max(tag_graph$stap$stap_id)))
+  assertthat::assert_that(all(edge_t[, 1] >= 1 & edge_t[, 1] <= g$dim[1]))
+  assertthat::assert_that(all(edge_t[, 2] >= 1 & edge_t[, 2] <= g$dim[2]))
+  assertthat::assert_that(all(edge_s[, 1] >= 1 & edge_s[, 1] <= g$dim[1]))
+  assertthat::assert_that(all(edge_s[, 2] >= 1 & edge_s[, 2] <= g$dim[2]))
 
   # Prepare the matrix of speed to return
-  if (return_average_variable){
-    VAR <- matrix(NA, nrow = length(edge_s), ncol = length(variable))
+  if (return_average_variable) {
+    var <- matrix(NA, nrow = nrow(edge_s), ncol = length(variable))
   } else {
-    VAR <- list()
-    for (var_i in seq_len(length(variable))){
-      VAR[[var_i]] <- vector("list", length(flight))
+    var <- list()
+    for (var_i in seq_len(length(variable))) {
+      var[[var_i]] <- vector("list", length(flight))
     }
   }
 
@@ -119,15 +125,15 @@ edge_add_wind <- function(tag_graph,
     ratio_stap <- as.matrix(c(0, cumsum(fl_s$duration) / sum(fl_s$duration)))
 
 
-    if (return_average_variable){
+    if (return_average_variable) {
       # Prepare the u- and v- windspeed for each flight (row) and edge (col)
-      VAR_stap <- list()
-      for (var_i in seq_len(length(variable))){
-        VAR_stap[[var_i]] <- matrix(NA, nrow = length(fl_s$stap_s), ncol = length(st_id))
+      var_stap <- list()
+      for (var_i in seq_len(length(variable))) {
+        var_stap[[var_i]] <- matrix(NA, nrow = length(fl_s$stap_s), ncol = length(st_id))
       }
     } else {
-      for (var_i in seq_len(length(variable))){
-        VAR[[var_i]][[i_stap]] <- vector("list", nrow(fl_s))
+      for (var_i in seq_len(length(variable))) {
+        var[[var_i]][[i_stap]] <- vector("list", nrow(fl_s))
       }
     }
 
@@ -141,7 +147,7 @@ edge_add_wind <- function(tag_graph,
 
       # Read data from netCDF file and convert the time of data to posixt
       time <- as.POSIXct(ncdf4::ncvar_get(nc, "time") * 60 * 60,
-                         origin = "1900-01-01", tz = "UTC"
+        origin = "1900-01-01", tz = "UTC"
       )
       pres <- ncdf4::ncvar_get(nc, "level")
       lat <- ncdf4::ncvar_get(nc, "latitude")
@@ -163,8 +169,10 @@ edge_add_wind <- function(tag_graph,
       # start and end time of the flight. Thus, we first round the start end end time.
 
       # Round down to the lower n-minute interval
-      t_s <- as.POSIXct(trunc(as.numeric(fl_s$start[i_fl]) / (60 * rounding_interval)) * (60 * rounding_interval), origin = "1970-01-01", tz = "UTC")
-      t_e <- as.POSIXct(ceiling(as.numeric(fl_s$end[i_fl]) / (60 * rounding_interval)) * (60 * rounding_interval), origin = "1970-01-01", tz = "UTC")
+      t_s <- as.POSIXct(trunc(as.numeric(fl_s$start[i_fl]) / (60 * rounding_interval)) *
+        (60 * rounding_interval), origin = "1970-01-01", tz = "UTC")
+      t_e <- as.POSIXct(ceiling(as.numeric(fl_s$end[i_fl]) / (60 * rounding_interval)) *
+        (60 * rounding_interval), origin = "1970-01-01", tz = "UTC")
       t_q <- seq(from = t_s, to = t_e, by = 60 * rounding_interval)
 
       # We assume that the bird is moving with a constant groundspeed between `flight$start` and
@@ -181,21 +189,23 @@ edge_add_wind <- function(tag_graph,
       lat_int <- lat_s + w2 * replicate(length(w), dlat_se)
       lon_int <- lon_s + w2 * replicate(length(w), dlon_se)
 
-      if (return_average_variable){
-        # As we are interesting in the average windspeed experienced during the entire flight, we need
-        # to find the weights of each 1hr interval extracted from ERA5. We can estimate these weight
-        # assuming a linear integration of the time (trapezoidal rule) or a step integration (Riemann
-        # sum)
+      if (TRUE) { # we use w for both return_average_variable TRUE and FALSE
+        # As we are interesting in the average windspeed experienced during the entire flight, we
+        # need to find the weights of each 1hr interval extracted from ERA5. We can estimate these
+        # weight assuming a linear integration of the time (trapezoidal rule) or a step integration
+        # (Riemann sum)
 
         # Linear integration
         w <- numeric(length(t_q))
         assertthat::assert_that(length(w) > 1)
 
-        alpha <- 1 - as.numeric(difftime(fl_s$start[i_fl], t_q[1], units = "mins"))/rounding_interval
+        alpha <- 1 - as.numeric(difftime(fl_s$start[i_fl], t_q[1], units = "mins")) /
+          rounding_interval
         assertthat::assert_that(alpha >= 0 & alpha <= 1)
         w[c(1, 2)] <- w[c(1, 2)] + c(alpha, 1 - alpha) * alpha
 
-        alpha <- 1 - as.numeric(difftime(utils::tail(t_q, 1), fl_s$end[i_fl], units = "mins"))/rounding_interval
+        alpha <- 1 - as.numeric(difftime(utils::tail(t_q, 1), fl_s$end[i_fl], units = "mins")) /
+          rounding_interval
         assertthat::assert_that(alpha >= 0 & alpha <= 1)
         w[length(w) - c(1, 0)] <- w[length(w) - c(1, 0)] + c(1 - alpha, alpha) * alpha
 
@@ -214,21 +224,20 @@ edge_add_wind <- function(tag_graph,
         # w <- difftime(pmin(pmax(t_q+60*60/2,fl_s$start[i_fl]),fl_s$end[i_fl]),
         #               pmin(pmax(t_q-60*60/2,fl_s$start[i_fl]),fl_s$end[i_fl]),
         #               units = "hours")
-
       }
 
       # Prepare the interpolated variable for each flight
-      VAR_fl <- list()
-      for (var_i in seq_len(length(variable))){
-        VAR_fl[[var_i]] <- matrix(NA, nrow = length(t_q), ncol = length(st_id))
+      var_fl <- list()
+      for (var_i in seq_len(length(variable))) {
+        var_fl[[var_i]] <- matrix(NA, nrow = length(t_q), ncol = length(st_id))
       }
 
       p_q <- numeric(length(t_q))
 
       # Find the index of lat and lon
-      if (!interp_spatial_linear){
-        lat_int_ind <- matrix(match(as.vector(round(lat_int*4)/4), lat), nrow=nrow(lat_int))
-        lon_int_ind <- matrix(match(as.vector(round(lon_int*4)/4), lon), nrow=nrow(lon_int))
+      if (!interp_spatial_linear) {
+        lat_int_ind <- matrix(match(as.vector(round(lat_int * 4) / 4), lat), nrow = nrow(lat_int))
+        lon_int_ind <- matrix(match(as.vector(round(lon_int * 4) / 4), lon), nrow = nrow(lon_int))
       }
 
       # Loop through the 1hr interval
@@ -248,48 +257,50 @@ edge_add_wind <- function(tag_graph,
         n_time <- ifelse(id_time == length(time) | time[id_time] == t_q[i_time], 1, 2)
 
         # Find the index of lat and longitude necessary
-        id_lon <- which(lon >= (min(lon_int[, i_time]) - dlon) & (max(lon_int[, i_time]) + dlon) >= lon)
-        id_lat <- which(lat >= (min(lat_int[, i_time]) - dlat) & (max(lat_int[, i_time]) + dlat) >= lat)
+        id_lon <- which(lon >= (min(lon_int[, i_time]) - dlon) &
+          (max(lon_int[, i_time]) + dlon) >= lon)
+        id_lat <- which(lat >= (min(lat_int[, i_time]) - dlat) &
+          (max(lat_int[, i_time]) + dlat) >= lat)
 
         # get the two maps of u- and v-
-        VAR_nc <- list()
-        for (var_i in seq_len(length(variable))){
-          VAR_nc[[var_i]] <- ncdf4::ncvar_get(nc, variable[var_i],
-                                              start = c(id_lon[1], id_lat[1], id_pres, id_time),
-                                              count = c(length(id_lon), length(id_lat), n_pres, n_time),
-                                              collapse_degen = FALSE
+        var_nc <- list()
+        for (var_i in seq_len(length(variable))) {
+          var_nc[[var_i]] <- ncdf4::ncvar_get(nc, variable[var_i],
+            start = c(id_lon[1], id_lat[1], id_pres, id_time),
+            count = c(length(id_lon), length(id_lat), n_pres, n_time),
+            collapse_degen = FALSE
           )
         }
 
         # Interpolate linearly along time
         if (n_time == 2) {
           w_time <- as.numeric(difftime(t_q[i_time], time[id_time], units = "hours")) /
-            as.numeric(difftime(time[id_time+1], time[id_time], units = "hours"))
-          for (var_i in seq_len(length(variable))){
-            VAR_nc[[var_i]] <- VAR_nc[[var_i]][, , , 1] +
-              w_time * (VAR_nc[[var_i]][, , , 2] - VAR_nc[[var_i]][, , , 1])
+            as.numeric(difftime(time[id_time + 1], time[id_time], units = "hours"))
+          for (var_i in seq_len(length(variable))) {
+            var_nc[[var_i]] <- var_nc[[var_i]][, , , 1] +
+              w_time * (var_nc[[var_i]][, , , 2] - var_nc[[var_i]][, , , 1])
           }
         } else {
-          for (var_i in seq_len(length(variable))){
-            VAR_nc[[var_i]] <- VAR_nc[[var_i]][, , , 1]
+          for (var_i in seq_len(length(variable))) {
+            var_nc[[var_i]] <- var_nc[[var_i]][, , , 1]
           }
         }
 
         # Interpolate linearly along altitude/pressure.
         if (n_pres == 2) {
-          w_pres <- (p_q[i_time] - pres[id_pres]) / (pres[id_pres+1] - pres[id_pres])
-          for (var_i in seq_len(length(variable))){
-            VAR_nc[[var_i]] <- VAR_nc[[var_i]][, , 1] +
-              w_pres * (VAR_nc[[var_i]][, , 2] - VAR_nc[[var_i]][, , 1])
+          w_pres <- (p_q[i_time] - pres[id_pres]) / (pres[id_pres + 1] - pres[id_pres])
+          for (var_i in seq_len(length(variable))) {
+            var_nc[[var_i]] <- var_nc[[var_i]][, , 1] +
+              w_pres * (var_nc[[var_i]][, , 2] - var_nc[[var_i]][, , 1])
           }
         }
 
-        if (interp_spatial_linear){
+        if (interp_spatial_linear) {
           # Interpolation the u- and v- component at the interpolated position at the current time
           # step.
           # Because lat_int and lon_int are so big, we round their value and only interpolate on the
-          # unique value that are needed. Then, we give the interpolated value back to all the lat_int
-          # lon_int dimension
+          # unique value that are needed. Then, we give the interpolated value back to all the
+          # lat_int lon_int dimension
           # Convert the coordinate to 1d to have a more efficient unique.
           ll_int_1d <- (round(lat_int[, i_time], 1) + 90) * 10 * 10000 +
             (round(lon_int[, i_time], 1) + 180) * 10 + 1
@@ -304,53 +315,62 @@ edge_add_wind <- function(tag_graph,
 
           id_uniq <- match(ll_int_1d, ll_int_1d_uniq)
 
-          for (var_i in seq_len(length(variable))){
+          for (var_i in seq_len(length(variable))) {
             tmp <- pracma::interp2(rev(lat[id_lat]), lon[id_lon],
-                                   VAR_nc[[var_i]][, rev(seq_len(ncol(VAR_nc[[var_i]])))],
-                                   lat_int_uniq, lon_int_uniq,
-                                   method = "linear"
+              var_nc[[var_i]][, rev(seq_len(ncol(var_nc[[var_i]])))],
+              lat_int_uniq, lon_int_uniq,
+              method = "linear"
             )
             assertthat::assert_that(all(!is.na(tmp)))
-            VAR_fl[[var_i]][i_time, ] <- tmp[id_uniq]
+            var_fl[[var_i]][i_time, ] <- tmp[id_uniq]
           }
         } else {
           # Take the closest value
-          for (var_i in seq_len(length(variable))){
-            ind <- (lon_int_ind[,i_time] - id_lon[1]) * ncol(VAR_nc[[var_i]]) + (lat_int_ind[,i_time] - id_lat[1] + 1)
-            VAR_fl[[var_i]][i_time, ] <- VAR_nc[[var_i]][ind]
+          for (var_i in seq_len(length(variable))) {
+            # Compute the index of lat, lon in the spatial extent extracted for var_nc
+            lon_int_ind_off <- lon_int_ind[, i_time] - id_lon[1] + 1
+            lat_int_ind_off <- lat_int_ind[, i_time] - id_lat[1] + 1
+
+            # compute the 2d index
+            ind <- (lat_int_ind_off - 1) * nrow(var_nc[[var_i]]) + lon_int_ind_off
+
+            # Extract variable
+            var_fl[[var_i]][i_time, ] <- var_nc[[var_i]][ind]
           }
         }
       }
 
-      if (return_average_variable){
+      if (return_average_variable) {
         # Compute the average wind component of the flight accounting for the weighting scheme
-        for (var_i in seq_len(length(variable))){
-          VAR_stap[[var_i]][i_fl, ] <- colSums(VAR_fl[[var_i]] * w)
+        for (var_i in seq_len(length(variable))) {
+          var_stap[[var_i]][i_fl, ] <- colSums(var_fl[[var_i]] * w)
         }
       } else {
-        VAR[[var_i]][[i_stap]][[i_fl]] <- data.frame(
-          edge_id = rep(st_id, each = length(t_q)),
-          val = as.vector(VAR_fl[[var_i]]),
-          pressure = rep(p_q, length(st_id)),
-          time = rep(t_q, length(st_id)),
-          w = rep(w, length(st_id))
-        )
-        VAR[[var_i]][[i_stap]][[i_fl]]$var <- variable[var_i]
-        if (interp_spatial_linear){
-          VAR[[var_i]][[i_stap]][[i_fl]]$lat <- as.vector(round(t(lat_int), 1))
-          VAR[[var_i]][[i_stap]][[i_fl]]$lon <- as.vector(round(t(lon_int), 1))
-        } else {
-          VAR[[var_i]][[i_stap]][[i_fl]]$lat <- as.vector(lat[t(lat_int_ind)])
-          VAR[[var_i]][[i_stap]][[i_fl]]$lon <- as.vector(lon[t(lon_int_ind)])
+        for (var_i in seq_len(length(variable))) {
+          var[[var_i]][[i_stap]][[i_fl]] <- data.frame(
+            edge_id = rep(st_id, each = length(t_q)),
+            val = as.vector(var_fl[[var_i]]),
+            pressure = rep(p_q, length(st_id)),
+            time = rep(t_q, length(st_id)),
+            w = rep(w, length(st_id))
+          )
+          var[[var_i]][[i_stap]][[i_fl]]$var <- variable[var_i]
+          if (interp_spatial_linear) {
+            var[[var_i]][[i_stap]][[i_fl]]$lat <- as.vector(round(t(lat_int), 1))
+            var[[var_i]][[i_stap]][[i_fl]]$lon <- as.vector(round(t(lon_int), 1))
+          } else {
+            var[[var_i]][[i_stap]][[i_fl]]$lat <- as.vector(lat[t(lat_int_ind)])
+            var[[var_i]][[i_stap]][[i_fl]]$lon <- as.vector(lon[t(lon_int_ind)])
+          }
         }
       }
     }
 
-    if (return_average_variable){
+    if (return_average_variable) {
       # Compute the average over all the flight of the transition accounting for the duration of the
       # flight.
-      for (var_i in seq_len(length(variable))){
-        VAR[st_id, var_i] <- colSums(VAR_stap[[var_i]] * fl_s$duration / sum(fl_s$duration))
+      for (var_i in seq_len(length(variable))) {
+        var[st_id, var_i] <- colSums(var_stap[[var_i]] * fl_s$duration / sum(fl_s$duration))
       }
     }
 
@@ -359,16 +379,41 @@ edge_add_wind <- function(tag_graph,
     }
   }
 
-  if (!return_average_variable){
-    VAR <- do.call(rbind, unlist(unlist(VAR, recursive = FALSE), recursive = FALSE))
+  if (!return_average_variable) {
+    var <- do.call(rbind, unlist(unlist(var, recursive = FALSE), recursive = FALSE))
   }
 
-  return(VAR)
+  return(var)
 }
 
 
 #' @noRd
-edge_add_wind_check <- function (file, flight, pressure, g){
+edge_add_wind_check <- function(
+    tag_graph,
+    pressure = NULL,
+    file = \(stap_id) {
+      glue::glue("./data/wind/{tag_graph$param$id}/{tag_graph$param$id}_{stap_id}.nc")
+    }) {
+  assertthat::assert_that(inherits(tag_graph, "tag") | inherits(tag_graph, "graph"))
+
+  # Compute lat-lon coordinate of the grid
+  g <- map_expand(tag_graph$param$extent, tag_graph$param$scale)
+
+  # Compute flight from stap
+  flight <- stap2flight(tag_graph$stap, format = "list")
+
+  # Check pressure
+  if (is.null(pressure) && inherits(tag_graph, "tag")) {
+    pressure <- tag_graph$pressure
+  }
+  assertthat::assert_that(is.data.frame(pressure))
+  assertthat::assert_that(assertthat::has_name(pressure, c("date", "value")))
+  assertthat::assert_that(assertthat::is.time(pressure$date))
+  assertthat::assert_that(is.numeric(pressure$value))
+
+  # Check file
+  assertthat::assert_that(is.function(file))
+
   # Check that all the files of wind_speed exist and match the data request
   for (i_stap in seq_len(length(flight))) {
     fl_s <- flight[[i_stap]]
@@ -380,6 +425,10 @@ edge_add_wind_check <- function (file, flight, pressure, g){
       }
       nc <- ncdf4::nc_open(file(i_s))
 
+      # Check that the variables u and v are present
+      ncdf4::ncvar_get(nc, "v")
+      ncdf4::ncvar_get(nc, "u")
+
       time <- as.POSIXct(ncdf4::ncvar_get(nc, "time") * 60 * 60, origin = "1900-01-01", tz = "UTC")
       t_s <- as.POSIXct(format(fl_s$start[i_fl], "%Y-%m-%d %H:00:00"), tz = "UTC")
       t_e <- as.POSIXct(format(fl_s$end[i_fl] + 60 * 60, "%Y-%m-%d %H:00:00"), tz = "UTC")
@@ -392,11 +441,10 @@ edge_add_wind_check <- function (file, flight, pressure, g){
       }
 
       pres <- ncdf4::ncvar_get(nc, "level")
-      t_q <- seq(from = t_s, to = t_e, by = 60 * 60)
       pres_value <- pressure$value[pressure$date > t_s & pressure$date < t_e]
       if (length(pres_value) == 0 ||
-          !(min(pres) <= min(pres_value) &&
-            max(pres) >= min(1000, max(pres_value)))) {
+        !(min(pres) <= min(pres_value) &&
+          max(pres) >= min(1000, max(pres_value)))) {
         cli::cli_abort(c(
           x = "Time between graph data and the wind file ({.file {file(i_s)}}) are not matching.",
           "!" = "You might have modified your stationary periods without updating your wind file? ",
@@ -408,7 +456,7 @@ edge_add_wind_check <- function (file, flight, pressure, g){
       lat <- ncdf4::ncvar_get(nc, "latitude")
       lon <- ncdf4::ncvar_get(nc, "longitude")
       if (min(g$lat) < min(lat) || max(g$lat) > max(lat) ||
-          min(g$lon) < min(lon) || max(g$lon) > max(lon)) {
+        min(g$lon) < min(lon) || max(g$lon) > max(lon)) {
         cli::cli_abort(c(x = "Spatial extend not matching for {.file {file(i_s)}}"))
       }
 
diff --git a/R/graph_add_wind.R b/R/graph_add_wind.R
index 7588ea42..707ca390 100644
--- a/R/graph_add_wind.R
+++ b/R/graph_add_wind.R
@@ -13,12 +13,8 @@
 #' illustration on how to use it.
 #'
 #' @param graph a GeoPressureR graph object.
-#' @param pressure pressure measurement of the associated `tag` data used to estimate the pressure
-#' level (i.e., altitude) of the bird during the flights. This data.frame needs to contain `date` as
-#' POSIXt and `value` in hPa.
 #' @param thr_as threshold of airspeed (km/h).
-#' @inheritParams tag_download_wind
-#' @param quiet logical to hide messages about the progress
+#' @inheritParams edge_add_wind
 #'
 #' @return a `graph` object with windspeed and airspeed as `ws` and `as` respectively.
 #'
@@ -34,324 +30,40 @@
 graph_add_wind <- function(
     graph,
     pressure,
-    thr_as = Inf,
+    rounding_interval = 60,
+    interp_spatial_linear = FALSE,
     file = \(stap_id) glue::glue("./data/wind/{graph$param$id}/{graph$param$id}_{stap_id}.nc"),
+    thr_as = Inf,
     quiet = FALSE) {
   graph_assert(graph, "full")
-  assertthat::assert_that(is.data.frame(pressure))
-  assertthat::assert_that(assertthat::has_name(pressure, c("date", "value")))
-  assertthat::assert_that(assertthat::is.time(pressure$date))
-  assertthat::assert_that(is.numeric(pressure$value))
-  assertthat::assert_that(is.function(file))
   assertthat::assert_that(is.numeric(thr_as))
   assertthat::assert_that(length(thr_as) == 1)
   assertthat::assert_that(thr_as >= 0)
 
-  # Compute flight from stap
-  flight <- stap2flight(graph$stap, format = "list")
-
-  # Compute lat-lon coordinate of the grid
-  g <- map_expand(graph$param$extent, graph$param$scale)
-
   # Check that all the files of wind_speed exist and match the data request
-  for (i1 in seq_len(graph$sz[3] - 1)) {
-    fl_s <- flight[[i1]]
-    for (i2 in seq_len(length(fl_s$stap_s))) {
-      i_s <- fl_s$stap_s[i2]
-
-      if (!file.exists(file(i_s))) {
-        cli::cli_abort(c(x = "No wind file {.file {file(i_s)}}"))
-      }
-      nc <- ncdf4::nc_open(file(i_s))
-
-      time <- as.POSIXct(ncdf4::ncvar_get(nc, "time") * 60 * 60, origin = "1900-01-01", tz = "UTC")
-      t_s <- as.POSIXct(format(fl_s$start[i2], "%Y-%m-%d %H:00:00"), tz = "UTC")
-      t_e <- as.POSIXct(format(fl_s$end[i2] + 60 * 60, "%Y-%m-%d %H:00:00"), tz = "UTC")
-      if (!(min(time) <= t_e && max(time) >= t_s)) {
-        cli::cli_abort(c(
-          x = "Time between graph data and the wind file ({.file {file(i_s)}}) are not matching.",
-          "!" = "You might have modified your stationary periods without updating your wind file? ",
-          ">" = "If so, run {.run tag_download_wind(tag)}"
-        ))
-      }
-
-      pres <- ncdf4::ncvar_get(nc, "level")
-      t_q <- seq(from = t_s, to = t_e, by = 60 * 60)
-      pres_value <- pressure$value[pressure$date > t_s & pressure$date < t_e]
-      if (length(pres_value) == 0 ||
-        !(min(pres) <= min(pres_value) &&
-          max(pres) >= min(1000, max(pres_value)))) {
-        cli::cli_abort(c(
-          x = "Time between graph data and the wind file ({.file {file(i_s)}}) are not matching.",
-          "!" = "You might have modified your stationary periods without updating your wind file? ",
-          ">" = "If so, run {.run tag_download_wind(tag)}"
-        ))
-      }
-
-      # Check if spatial extend match
-      lat <- ncdf4::ncvar_get(nc, "latitude")
-      lon <- ncdf4::ncvar_get(nc, "longitude")
-      if (min(g$lat) < min(lat) || max(g$lat) > max(lat) ||
-        min(g$lon) < min(lon) || max(g$lon) > max(lon)) {
-        cli::cli_abort(c(x = "Spatial extend not matching for {.file {file(i_s)}}"))
-      }
-
-      # Check if flight duration is
-      if (fl_s$start[i2] >= fl_s$end[i2]) {
-        cli::cli_abort(c(
-          x = "Flight starting on stap {fl_s$stap_s[i2]} has a start time equal or greater than \\
-                         the end time. Please review your labelling file."
-        ))
-      }
-    }
-  }
-
-  if (!quiet) {
-    cli::cli_progress_step("Extract edge information")
-  }
-
-  # Extract the index in lat, lon, stap from the source and target of all edges
-  s <- arrayInd(graph$s, graph$sz)
-  t <- arrayInd(graph$t, graph$sz)
-
-  # Prepare the matrix of speed to return
-  uv <- matrix(NA, nrow = length(graph$s), ncol = 2)
-
-  # Start progress bar
-  if (!quiet) {
-    i1 <- 0
-    cli::cli_progress_bar(
-      "Compute wind speed for edges of stationary period:",
-      format = "{cli::pb_name} {i1}/{graph$sz[3] - 1} {cli::pb_bar} {cli::pb_percent} | \\
-      {cli::pb_eta_str} [{cli::pb_elapsed}]",
-      format_done = "Compute wind speed for edges of stationary periods [{cli::pb_elapsed}]",
-      clear = FALSE,
-      total = sum(table(s[, 3]))
-    )
-  }
-
-  # Loop through the stationary period kept in the graph
-  for (i1 in seq_len(graph$sz[3] - 1)) {
-    # Extract the flight information from the current stap to the next one considered in the graph.
-    # It can be the next, or if some stap are skipped at construction, it can contains multiples
-    # flights
-    fl_s <- flight[[i1]]
-
-    # Extract the duration of each flights.
-    fl_s_dur <- stap2duration(fl_s, units = "hours")
-
-    # Determine the id of edges of the graph corresponding to this/these flight(s).
-    st_id <- which(s[, 3] == i1)
-
-    # We are assuming that the bird flight as a straight line between the source and the target node
-    # of each edge. If multiple flights happen during this transition, we assume that the bird flew
-    # with a constant groundspeed during each flight, thus considering its stopover position to be
-    # spread according to the flight duration. This does not account for habitat, so that it would
-    # assume a bird can stop over water. While we could improve this part of the code to assume
-    # constant airspeed rather than groundspeed, we suggest to create the graph considering all
-    # stopovers.
-    ratio_stap <- as.matrix(c(0, cumsum(fl_s_dur) / sum(fl_s_dur)))
-
-    # Prepare the u- and v- windspeed for each flight (row) and edge (col)
-    u_stap <- matrix(NA, nrow = length(fl_s$stap_s), ncol = length(st_id))
-    v_stap <- matrix(NA, nrow = length(fl_s$stap_s), ncol = length(st_id))
-
-    # Loop through each flight of the transition
-    for (i2 in seq_len(length(fl_s$stap_s))) {
-      # Find the stationary period ID from this specific flight (source)
-      i_s <- fl_s$stap_s[i2]
-
-      # Read the netCDF file
-      nc <- ncdf4::nc_open(file(i_s))
-
-      # Read data from netCDF file and convert the time of data to posixt
-      time <- as.POSIXct(ncdf4::ncvar_get(nc, "time") * 60 * 60,
-        origin = "1900-01-01", tz = "UTC"
-      )
-      pres <- ncdf4::ncvar_get(nc, "level")
-      lat <- ncdf4::ncvar_get(nc, "latitude")
-      lon <- ncdf4::ncvar_get(nc, "longitude")
-
-      # Find the start and end latitude and longitude of each edge
-      lat_s <- g$lat[s[st_id, 1]] +
-        ratio_stap[i2] * (g$lat[t[st_id, 1]] - g$lat[s[st_id, 1]])
-      lon_s <- g$lon[s[st_id, 2]] +
-        ratio_stap[i2] * (g$lon[t[st_id, 2]] - g$lon[s[st_id, 2]])
-      lat_e <- g$lat[s[st_id, 1]] +
-        ratio_stap[i2 + 1] * (g$lat[t[st_id, 1]] - g$lat[s[st_id, 1]])
-      lon_e <- g$lon[s[st_id, 2]] +
-        ratio_stap[i2 + 1] * (g$lon[t[st_id, 2]] - g$lon[s[st_id, 2]])
-
-      # As ERA5 data is available every hour, we build a one hour resolution time series including
-      # the start and end time of the flight. Thus, we first round the start end end time.
-      t_s <- as.POSIXct(format(fl_s$start[i2], "%Y-%m-%d %H:00:00"),
-        tz = "UTC"
-      )
-      t_e <- as.POSIXct(format(fl_s$end[i2] + 60 * 60, "%Y-%m-%d %H:00:00"),
-        tz = "UTC"
-      )
-      t_q <- seq(from = t_s, to = t_e, by = 60 * 60)
-
-      # We assume that the bird is moving with a constant groundspeed between `flight$start` and
-      # `flight$end`. Using a linear interpolation, we extract the position (lat, lon) at every hour
-      # on `t_q`. Extrapolation outside (before the bird departure or after he arrived) is with a
-      # nearest neighbour.
-
-      dt <- fl_s_dur[i2] # old code not tested replacement as.numeric(difftime(fl_s$end[i2],
-      # fl_s$start[i2], units = "hours"))
-      dlat <- (lat_e - lat_s) / dt
-      dlon <- (lon_e - lon_s) / dt
-      w <- pmax(pmin(as.numeric(
-        difftime(t_q, fl_s$start[i2], units = "hours")
-      ), dt), 0)
-      w2 <- matrix(w, nrow = length(dlat), ncol = length(w), byrow = TRUE)
-      lat_int <- lat_s + w2 * replicate(length(w), dlat)
-      lon_int <- lon_s + w2 * replicate(length(w), dlon)
-
-      # As we are interesting in the average windspeed experienced during the entire flight, we need
-      # to find the weights of each 1hr interval extracted from ERA5. We can estimate these weight
-      # assuming a linear integration of the time (trapezoidal rule) or a step integration (Riemann
-      # sum)
-
-      # Linear integration
-      w <- numeric(length(t_q))
-      assertthat::assert_that(length(w) > 1)
-      alpha <- 1 - as.numeric(difftime(fl_s$start[i2], t_q[1],
-        units = "hours"
-      ))
-      assertthat::assert_that(alpha >= 0 & alpha <= 1)
-      w[c(1, 2)] <- w[c(1, 2)] + c(alpha, 1 - alpha) * alpha
-      alpha <- 1 - as.numeric(difftime(utils::tail(t_q, 1), fl_s$end[i2],
-        units = "hours"
-      ))
-      assertthat::assert_that(alpha >= 0 & alpha <= 1)
-      w[length(w) - c(1, 0)] <- w[length(w) - c(1, 0)] +
-        c(1 - alpha, alpha) * alpha
-
-      if (length(w) >= 4) {
-        w[c(2, length(w) - 1)] <- w[c(2, length(w) - 1)] + 0.5
-      }
-      if (length(w) >= 5) {
-        w[seq(3, length(w) - 2)] <- w[seq(3, length(w) - 2)] + 1
-      }
-      # normalize the weight
-      w <- w / sum(w)
-
-      assertthat::assert_that(all(!is.na(w)))
-
-      # step integration
-      # w <- difftime(pmin(pmax(t_q+60*60/2,fl_s$start[i2]),fl_s$end[i2]),
-      #               pmin(pmax(t_q-60*60/2,fl_s$start[i2]),fl_s$end[i2]),
-      #               units = "hours")
-
-      # Prepare the interpolated u- v- vector for each flight
-      u_int <- matrix(NA, nrow = length(t_q), ncol = length(st_id))
-      v_int <- matrix(NA, nrow = length(t_q), ncol = length(st_id))
-
-      # Loop through the 1hr interval
-      for (i3 in seq_len(length(t_q))) {
-        # find the time step to query in ERA5
-        id_time <- which(time == t_q[i3])
-        # find the two pressure level to query (one above, one under) based on the geolocator
-        # pressure at this timestep
-        pres_value <- stats::approx(pressure$date, pressure$value, t_q[i3])$y
-        df <- pres_value - pres
-        df[df < 0] <- NA
-        id_pres <- which.min(df)
-        # if the pressure is higher than the highest level (i.e. bird below the ground level
-        # pressure), we extract only the last layer
-        n_pres <- ifelse(id_pres == length(df), 1, 2)
-
-        dlon <- lon[2] - lon[1]
-        id_lon <- which(lon >= (min(lon_int[, i3]) - dlon) &
-          (max(lon_int[, i3]) + dlon) >= lon)
-
-        dlat <- abs(lat[2] - lat[1])
-        id_lat <- which(lat >= (min(lat_int[, i3]) - dlat) &
-          (max(lat_int[, i3]) + dlat) >= lat)
-
-
-        # get the two maps of u- and v-
-        u <- ncdf4::ncvar_get(nc, "u",
-          start = c(id_lon[1], id_lat[1], id_pres, id_time),
-          count = c(length(id_lon), length(id_lat), n_pres, 1)
-        )
-        v <- ncdf4::ncvar_get(nc, "v",
-          start = c(id_lon[1], id_lat[1], id_pres, id_time),
-          count = c(length(id_lon), length(id_lat), n_pres, 1)
-        )
-
-        # Interpolate linearly the map of wind based on pressure.
-        if (n_pres == 2) {
-          w2 <- abs(pres[id_pres + c(0, 1)] - pres_value)
-          w2 <- w2 / sum(w2)
-          u <- w2[1] * u[, , 1] + w2[2] * u[, , 2]
-          v <- w2[1] * v[, , 1] + w2[2] * v[, , 2]
-        }
-
-        # Interpolation the u- and v- component at the interpolated position at the current time
-        # step.
-        # Because lat_int and lon_int are so big, we round their value and only interpolate on the
-        # unique value that are needed. Then, we give the interpolated value back to all the lat_int
-        # lon_int dimension
-        # Convert the coordinate to 1d to have a more efficient unique.
-        ll_int_1d <- (round(lat_int[, i3], 1) + 90) * 10 * 10000 +
-          (round(lon_int[, i3], 1) + 180) * 10 + 1
-        ll_int_1d_uniq <- unique(ll_int_1d)
-
-        lat_int_uniq <- ((ll_int_1d_uniq - 1) %/% 10000) / 10 - 90
-        lon_int_uniq <- ((ll_int_1d_uniq - 1) %% 10000) / 10 - 180
-        # CHeck that the transofmration is correct with
-        # cbind((round(lat_int[, i3], 1)+90)*10, (ll_int_1d - 1) %/% 10000)
-        # cbind((round(lon_int[, i3],1)+180)*10, (ll_int_1d - 1) %% 10000)
-        # cbind(lat_int_uniq, lon_int_uniq, lat_int[, i3], lon_int[, i3])
-
-        id_uniq <- match(ll_int_1d, ll_int_1d_uniq)
-
-        tmp <- pracma::interp2(rev(lat[id_lat]), lon[id_lon],
-          u[, rev(seq_len(ncol(u)))],
-          lat_int_uniq, lon_int_uniq,
-          method = "linear"
-        )
-
-        assertthat::assert_that(all(!is.na(tmp)))
-        u_int[i3, ] <- tmp[id_uniq]
-        tmp <- pracma::interp2(rev(lat[id_lat]), lon[id_lon],
-          v[, rev(seq_len(ncol(u)))],
-          lat_int_uniq, lon_int_uniq,
-          method = "linear"
-        )
-        assertthat::assert_that(all(!is.na(tmp)))
-        v_int[i3, ] <- tmp[id_uniq]
-      }
-      # Compute the average wind component of the flight accounting for the weighting scheme
-      u_stap[i2, ] <- colSums(u_int * w)
-      v_stap[i2, ] <- colSums(v_int * w)
-    }
-    # Compute the average  over all the flight of the transition accounting for the duration of the
-    # flight.
-    uv[st_id, 1] <- colSums(u_stap * fl_s_dur / sum(fl_s_dur))
-    uv[st_id, 2] <- colSums(v_stap * fl_s_dur / sum(fl_s_dur))
-
-    if (!quiet) {
-      cli::cli_progress_update(set = sum(table(s[, 3])[seq(1, i1)]), force = TRUE)
-    }
-  }
+  uv <- edge_add_wind(graph,
+    edge_s = graph$s,
+    edge_t = graph$t,
+    pressure = pressure,
+    variable = c("u", "v"),
+    rounding_interval = rounding_interval,
+    interp_spatial_linear = interp_spatial_linear,
+    return_average_variable = TRUE,
+    file = file,
+    quiet = quiet
+  )
 
   # save windspeed in complex notation and convert from m/s to km/h
   graph$ws <- (uv[, 1] + 1i * uv[, 2]) / 1000 * 60 * 60
 
-  # compute airspeed
-  as <- graph$gs - graph$ws
-
   # filter edges based on airspeed
-  id <- abs(as) <= thr_as
-  sta_pass <- which(!(seq_len(graph$sz[3] - 1) %in% unique(s[id, 3])))
+  id <- abs(graph$gs - graph$ws) <= thr_as
+  sta_pass <- which(!(seq_len(graph$sz[3] - 1) %in% unique(edge_s[id, 3])))
   if (length(sta_pass) > 0) {
     cli::cli_abort(c(
       x = "Using the {.val thr_as} of {thr_as} km/h provided with the exact distance of edges, \\
       there are not any nodes left for the stationary period: {sta_pass} with a minimum airspeed \\
-      of {min(abs(as[s[, 3] == sta_pass]))} km/h."
+      of {min(abs(as[edge_s[, 3] == sta_pass]))} km/h."
     ))
   }