17
17
* Boston, MA 02111-1307, USA.
18
18
*/
19
19
20
+ use actix:: clock:: interval;
20
21
use actix:: prelude:: * ;
22
+ use actix_web:: rt:: pin;
21
23
use actix_web:: web:: Bytes ;
22
24
use actix_web:: { get, rt, web, HttpRequest , HttpResponse } ;
23
25
use actix_web_validator:: Query ;
24
- use actix_ws:: { AggregatedMessage , ProtocolError , Session } ;
26
+ use actix_ws:: { AggregatedMessage , Session } ;
25
27
use db_connector:: models:: wg_keys:: WgKey ;
26
28
use diesel:: prelude:: * ;
29
+ use futures_util:: future:: Either ;
30
+ use futures_util:: StreamExt ;
27
31
use serde:: { Deserialize , Serialize } ;
28
32
use std:: collections:: HashMap ;
29
33
use std:: str:: FromStr ;
30
34
use std:: sync:: Arc ;
31
- use std:: time:: Instant ;
35
+ use std:: time:: { Duration , Instant } ;
32
36
use validator:: { Validate , ValidationError } ;
33
37
use futures_util:: lock:: Mutex ;
34
38
@@ -44,6 +48,9 @@ use crate::{
44
48
AppState , BridgeState ,
45
49
} ;
46
50
51
+ const HEARTBEAT_INTERVAL : Duration = Duration :: from_secs ( 5 ) ;
52
+ const CLIENT_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
53
+
47
54
#[ derive( Deserialize , Serialize , Validate ) ]
48
55
struct WsQuery {
49
56
#[ validate( custom( function = validate_key_id) ) ]
@@ -106,12 +113,16 @@ impl WebClient {
106
113
}
107
114
}
108
115
109
- pub async fn handle_message ( & mut self , msg : Result < AggregatedMessage , ProtocolError > ) {
116
+ pub async fn handle_message ( & mut self , msg : AggregatedMessage , last_heartbeat : & mut Instant ) {
110
117
match msg {
111
- Ok ( AggregatedMessage :: Ping ( msg) ) => {
118
+ AggregatedMessage :: Ping ( msg) => {
112
119
self . session . pong ( & msg) . await . unwrap ( ) ;
120
+ * last_heartbeat = Instant :: now ( ) ;
121
+ } ,
122
+ AggregatedMessage :: Pong ( _) => {
123
+ * last_heartbeat = Instant :: now ( ) ;
113
124
} ,
114
- Ok ( AggregatedMessage :: Binary ( msg) ) => {
125
+ AggregatedMessage :: Binary ( msg) => {
115
126
let peer_sock_addr = {
116
127
let meta = RemoteConnMeta {
117
128
charger_id : self . charger_id . clone ( ) ,
@@ -324,7 +335,7 @@ async fn start_ws(
324
335
bridge_state. port_discovery . clone ( ) ,
325
336
) . await ?;
326
337
327
- let ( resp, session, stream) = actix_ws:: handle ( & req, stream) ?;
338
+ let ( resp, mut session, stream) = actix_ws:: handle ( & req, stream) ?;
328
339
let mut stream = stream. aggregate_continuations ( )
329
340
. max_continuation_size ( 2_usize . pow ( 20 ) ) ;
330
341
@@ -351,14 +362,31 @@ async fn start_ws(
351
362
state,
352
363
bridge_state,
353
364
keys. connection_no ,
354
- session,
365
+ session. clone ( ) ,
355
366
) . await ;
356
- while let Some ( msg) = stream. recv ( ) . await {
357
- if let Ok ( AggregatedMessage :: Close ( _) ) = & msg {
358
- log:: info!( "/ws close connection" ) ;
359
- break ;
367
+
368
+ let mut last_heartbeat = Instant :: now ( ) ;
369
+ let mut interval = interval ( HEARTBEAT_INTERVAL ) ;
370
+ loop {
371
+ let tick = interval. tick ( ) ;
372
+ pin ! ( tick) ;
373
+
374
+ match futures_util:: future:: select ( stream. next ( ) , tick) . await {
375
+ Either :: Left ( ( Some ( Ok ( AggregatedMessage :: Close ( _) ) ) , _) ) => break ,
376
+ Either :: Left ( ( Some ( Ok ( msg) ) , _) ) => client. handle_message ( msg, & mut last_heartbeat) . await ,
377
+ Either :: Left ( ( Some ( err) , _) ) => {
378
+ log:: error!( "Websocket Error during connection: {:?}" , err) ;
379
+ break ;
380
+ } ,
381
+ Either :: Left ( ( None , _) ) => break ,
382
+ Either :: Right ( _) => {
383
+ if Instant :: now ( ) . duration_since ( last_heartbeat) > CLIENT_TIMEOUT {
384
+ log:: debug!( "Client quietly quit." ) ;
385
+ break ;
386
+ }
387
+ let _ = session. ping ( b"" ) . await ;
388
+ }
360
389
}
361
- client. handle_message ( msg) . await ;
362
390
}
363
391
client. stop ( ) . await ;
364
392
} ) ;
0 commit comments