6
6
7
7
import os
8
8
import re
9
+ from datetime import datetime
9
10
from typing import Tuple , Dict , List , Union
10
11
11
12
from mypy_boto3_lambda import LambdaClient
12
13
from pyriandx .client import Client
13
14
import json
14
15
import pandas as pd
15
16
import time
17
+ import jwt
18
+ from jwt import DecodeError
19
+
16
20
17
21
from pyriandx .utils import retry_session
18
22
19
23
from .globals import \
20
24
PIERIANDX_CDK_SSM_LIST , \
21
25
PIERIANDX_CDK_SSM_PATH , \
22
26
MAX_ATTEMPTS_GET_CASES , LIST_CASES_RETRY_TIME , \
23
- PanelType , SampleType , PIERIANDX_USER_AUTH_TOKEN_LAMBDA_PATH
27
+ PanelType , SampleType , PIERIANDX_USER_AUTH_TOKEN_LAMBDA_PATH , JWT_EXPIRY_BUFFER
24
28
25
29
from .miscell import \
26
30
change_case
@@ -76,7 +80,7 @@ def get_pieriandx_env_vars() -> Tuple:
76
80
output_dict [env_var ] = parameter_value
77
81
78
82
# Set PIERIANDX_USER_AUTH_TOKEN based on secret
79
- if "PIERIANDX_USER_AUTH_TOKEN" in os .environ :
83
+ if "PIERIANDX_USER_AUTH_TOKEN" in os .environ and jwt_is_valid ( os . environ [ "PIERIANDX_USER_AUTH_TOKEN" ]) :
80
84
# Already here!
81
85
output_dict ["PIERIANDX_USER_AUTH_TOKEN" ] = os .environ ["PIERIANDX_USER_AUTH_TOKEN" ]
82
86
else :
@@ -91,8 +95,12 @@ def get_pieriandx_env_vars() -> Tuple:
91
95
InvocationType = "RequestResponse"
92
96
)
93
97
auth_token_resp = response ['Payload' ].read ().decode ('utf-8' )
98
+ if auth_token_resp is None or auth_token_resp == 'null' or json .loads (auth_token_resp ).get ("auth_token" ) is None :
99
+ logger .info ("Could not get valid auth token from lambda, trying again in five seconds" )
100
+ time .sleep (5 )
94
101
95
102
output_dict ["PIERIANDX_USER_AUTH_TOKEN" ] = json .loads (auth_token_resp ).get ("auth_token" )
103
+ os .environ ["PIERIANDX_USER_AUTH_TOKEN" ] = output_dict ["PIERIANDX_USER_AUTH_TOKEN" ]
96
104
97
105
return (
98
106
output_dict .get ("PIERIANDX_USER_EMAIL" ),
@@ -479,3 +487,25 @@ def get_pieriandx_status_for_missing_sample(case_id: str) -> pd.Series:
479
487
case_dict ["pieriandx_report_status" ] = report ["status" ]
480
488
481
489
return pd .Series (case_dict )
490
+
491
+
492
+ def decode_jwt (jwt_string : str ) -> Dict :
493
+ return jwt .decode (
494
+ jwt_string ,
495
+ algorithms = ["HS256" ],
496
+ options = {"verify_signature" : False }
497
+ )
498
+
499
+
500
+ def jwt_is_valid (jwt_string : str ) -> bool :
501
+ try :
502
+ decode_jwt (jwt_string )
503
+ timestamp_exp = decode_jwt (jwt_string ).get ("exp" )
504
+
505
+ # If timestamp will expire in less than one minute's time, return False
506
+ if int (timestamp_exp ) < (int (datetime .now ().timestamp ()) + JWT_EXPIRY_BUFFER ):
507
+ return False
508
+ else :
509
+ return True
510
+ except DecodeError as e :
511
+ return False
0 commit comments