1
- use std:: { borrow:: Cow , io:: Write , net:: SocketAddr , sync:: Arc } ;
1
+ use std:: { borrow:: Cow , collections :: HashMap , io:: Write , net:: SocketAddr , sync:: Arc } ;
2
2
3
3
use axum:: { response:: IntoResponse , Router } ;
4
4
use serde:: { Deserialize , Serialize } ;
5
5
use tokio_stream:: StreamExt ;
6
6
7
+ /** レートリミット対象の処理が終わった時に破棄する*/
8
+ struct RateLimitTracker ( Option < RateLimit > , Option < String > , tokio:: runtime:: Handle ) ;
9
+ impl Drop for RateLimitTracker {
10
+ fn drop ( & mut self ) {
11
+ let s=self . 1 . take ( ) . unwrap ( ) ;
12
+ let hosts=self . 0 . take ( ) . unwrap ( ) . hosts ;
13
+ self . 2 . spawn ( async move {
14
+ tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( 500 ) ) . await ;
15
+ let mut wlock=hosts. write ( ) . await ;
16
+ let active_tasks=wlock. remove ( & s) . unwrap_or ( 0 ) ;
17
+ if active_tasks>1 {
18
+ wlock. insert ( s, active_tasks-1 ) ;
19
+ }
20
+ } ) ;
21
+ }
22
+ }
23
+ #[ derive( Clone , Debug ) ]
24
+ struct RateLimit {
25
+ hosts : Arc < tokio:: sync:: RwLock < HashMap < String , u32 > > > ,
26
+ }
27
+ impl RateLimit {
28
+ /** 処理を実行しても良いか確認し、ロックを取得する。Err(true)は待てば成功する可能性がある*/
29
+ async fn request ( & self , url : & str ) ->Result < RateLimitTracker , bool > {
30
+ let host=reqwest:: Url :: parse ( url) . map_err ( |_|false ) ?;
31
+ let host=host. host ( ) . ok_or ( false ) ?. to_string ( ) ;
32
+ let rlock=self . hosts . read ( ) . await ;
33
+ let active_tasks=rlock. get ( & host) . copied ( ) . unwrap_or ( 0 ) ;
34
+ drop ( rlock) ;
35
+ let active_tasks=active_tasks+1 ;
36
+ if active_tasks<3 {
37
+ //処理中がこれ含め3件未満であれば即時実行
38
+ let mut wlock=self . hosts . write ( ) . await ;
39
+ wlock. insert ( host. clone ( ) , active_tasks) ;
40
+ let handle=tokio:: runtime:: Handle :: current ( ) ;
41
+ Ok ( RateLimitTracker ( Some ( self . clone ( ) ) , Some ( host) , handle) )
42
+ } else {
43
+ //再試行するべき失敗
44
+ Err ( true )
45
+ }
46
+ }
47
+ }
7
48
#[ derive( Clone , Debug , Serialize , Deserialize ) ]
8
49
pub struct ConfigFile {
9
50
bind_addr : String ,
@@ -137,7 +178,10 @@ fn main() {
137
178
let client=reqwest:: ClientBuilder :: new ( ) ;
138
179
let client=client. build ( ) . unwrap ( ) ;
139
180
let rt=tokio:: runtime:: Builder :: new_multi_thread ( ) . enable_all ( ) . build ( ) . unwrap ( ) ;
140
- let arg_tup=( client, config) ;
181
+ let limit=RateLimit {
182
+ hosts : Arc :: new ( tokio:: sync:: RwLock :: new ( HashMap :: new ( ) ) )
183
+ } ;
184
+ let arg_tup=( client, config, limit) ;
141
185
rt. block_on ( async {
142
186
let http_addr: SocketAddr = arg_tup. 1 . bind_addr . parse ( ) . unwrap ( ) ;
143
187
let app = Router :: new ( ) ;
@@ -154,7 +198,7 @@ fn main() {
154
198
async fn get_file (
155
199
_path : Option < axum:: extract:: Path < String > > ,
156
200
request_headers : axum:: http:: HeaderMap ,
157
- ( client, config) : ( reqwest:: Client , Arc < ConfigFile > ) ,
201
+ ( client, config, limit ) : ( reqwest:: Client , Arc < ConfigFile > , RateLimit ) ,
158
202
axum:: extract:: Query ( q) : axum:: extract:: Query < RequestParams > ,
159
203
) ->axum:: response:: Response {
160
204
println ! ( "{}\t {}\t lang:{:?}\t response_timeout:{:?}\t content_length_limit:{:?}\t user_agent:{:?}" ,
@@ -171,6 +215,32 @@ async fn get_file(
171
215
config. append_headers ( & mut headers) ;
172
216
return ( axum:: http:: StatusCode :: IM_A_TEAPOT , headers) . into_response ( )
173
217
}
218
+ for _ in 0 ..3 {
219
+ match limit. request ( & q. url ) . await {
220
+ Ok ( _) =>{
221
+ return remote_request ( request_headers, ( client, config) , q) . await ;
222
+ } ,
223
+ Err ( true ) =>{
224
+ tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( 1000 ) ) . await ;
225
+ } ,
226
+ _=>{
227
+ let mut headers=axum:: http:: HeaderMap :: new ( ) ;
228
+ headers. append ( axum:: http:: header:: CACHE_CONTROL , "public, max-age=30" . parse ( ) . unwrap ( ) ) ;
229
+ config. append_headers ( & mut headers) ;
230
+ return ( axum:: http:: StatusCode :: BAD_REQUEST , headers) . into_response ( ) ;
231
+ }
232
+ }
233
+ }
234
+ let mut headers=axum:: http:: HeaderMap :: new ( ) ;
235
+ headers. append ( axum:: http:: header:: CACHE_CONTROL , "public, max-age=30" . parse ( ) . unwrap ( ) ) ;
236
+ config. append_headers ( & mut headers) ;
237
+ ( axum:: http:: StatusCode :: TOO_MANY_REQUESTS , headers) . into_response ( )
238
+ }
239
+ async fn remote_request (
240
+ request_headers : axum:: http:: HeaderMap ,
241
+ ( client, config) : ( reqwest:: Client , Arc < ConfigFile > ) ,
242
+ q : RequestParams ,
243
+ ) ->axum:: response:: Response {
174
244
let builder=client. get ( & q. url ) ;
175
245
let user_agent=q. user_agent . as_ref ( ) . unwrap_or_else ( ||& config. user_agent ) ;
176
246
let builder=builder. header ( reqwest:: header:: USER_AGENT , user_agent) ;
0 commit comments