Skip to content

Commit fa3ce13

Browse files
authored
Merge pull request #18 from kozakura913/main
レートリミットの実装
2 parents 56ff62a + 2cf0eec commit fa3ce13

File tree

1 file changed

+73
-3
lines changed

1 file changed

+73
-3
lines changed

src/main.rs

+73-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,50 @@
1-
use std::{borrow::Cow, io::Write, net::SocketAddr, sync::Arc};
1+
use std::{borrow::Cow, collections::HashMap, io::Write, net::SocketAddr, sync::Arc};
22

33
use axum::{response::IntoResponse, Router};
44
use serde::{Deserialize, Serialize};
55
use tokio_stream::StreamExt;
66

7+
/** レートリミット対象の処理が終わった時に破棄する*/
8+
struct RateLimitTracker(Option<RateLimit>,Option<String>,tokio::runtime::Handle);
9+
impl Drop for RateLimitTracker{
10+
fn drop(&mut self) {
11+
let s=self.1.take().unwrap();
12+
let hosts=self.0.take().unwrap().hosts;
13+
self.2.spawn(async move{
14+
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
15+
let mut wlock=hosts.write().await;
16+
let active_tasks=wlock.remove(&s).unwrap_or(0);
17+
if active_tasks>1{
18+
wlock.insert(s,active_tasks-1);
19+
}
20+
});
21+
}
22+
}
23+
#[derive(Clone,Debug)]
24+
struct RateLimit{
25+
hosts:Arc<tokio::sync::RwLock<HashMap<String,u32>>>,
26+
}
27+
impl RateLimit{
28+
/** 処理を実行しても良いか確認し、ロックを取得する。Err(true)は待てば成功する可能性がある*/
29+
async fn request(&self,url:&str)->Result<RateLimitTracker,bool>{
30+
let host=reqwest::Url::parse(url).map_err(|_|false)?;
31+
let host=host.host().ok_or(false)?.to_string();
32+
let rlock=self.hosts.read().await;
33+
let active_tasks=rlock.get(&host).copied().unwrap_or(0);
34+
drop(rlock);
35+
let active_tasks=active_tasks+1;
36+
if active_tasks<3{
37+
//処理中がこれ含め3件未満であれば即時実行
38+
let mut wlock=self.hosts.write().await;
39+
wlock.insert(host.clone(),active_tasks);
40+
let handle=tokio::runtime::Handle::current();
41+
Ok(RateLimitTracker(Some(self.clone()),Some(host),handle))
42+
}else{
43+
//再試行するべき失敗
44+
Err(true)
45+
}
46+
}
47+
}
748
#[derive(Clone,Debug,Serialize,Deserialize)]
849
pub struct ConfigFile{
950
bind_addr: String,
@@ -137,7 +178,10 @@ fn main() {
137178
let client=reqwest::ClientBuilder::new();
138179
let client=client.build().unwrap();
139180
let rt=tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
140-
let arg_tup=(client,config);
181+
let limit=RateLimit{
182+
hosts:Arc::new(tokio::sync::RwLock::new(HashMap::new()))
183+
};
184+
let arg_tup=(client,config,limit);
141185
rt.block_on(async{
142186
let http_addr:SocketAddr = arg_tup.1.bind_addr.parse().unwrap();
143187
let app = Router::new();
@@ -154,7 +198,7 @@ fn main() {
154198
async fn get_file(
155199
_path:Option<axum::extract::Path<String>>,
156200
request_headers:axum::http::HeaderMap,
157-
(client,config):(reqwest::Client,Arc<ConfigFile>),
201+
(client,config,limit):(reqwest::Client,Arc<ConfigFile>,RateLimit),
158202
axum::extract::Query(q):axum::extract::Query<RequestParams>,
159203
)->axum::response::Response{
160204
println!("{}\t{}\tlang:{:?}\tresponse_timeout:{:?}\tcontent_length_limit:{:?}\tuser_agent:{:?}",
@@ -171,6 +215,32 @@ async fn get_file(
171215
config.append_headers(&mut headers);
172216
return (axum::http::StatusCode::IM_A_TEAPOT,headers).into_response()
173217
}
218+
for _ in 0..3{
219+
match limit.request(&q.url).await{
220+
Ok(_)=>{
221+
return remote_request(request_headers,(client,config),q).await;
222+
},
223+
Err(true)=>{
224+
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
225+
},
226+
_=>{
227+
let mut headers=axum::http::HeaderMap::new();
228+
headers.append(axum::http::header::CACHE_CONTROL,"public, max-age=30".parse().unwrap());
229+
config.append_headers(&mut headers);
230+
return (axum::http::StatusCode::BAD_REQUEST,headers).into_response();
231+
}
232+
}
233+
}
234+
let mut headers=axum::http::HeaderMap::new();
235+
headers.append(axum::http::header::CACHE_CONTROL,"public, max-age=30".parse().unwrap());
236+
config.append_headers(&mut headers);
237+
(axum::http::StatusCode::TOO_MANY_REQUESTS,headers).into_response()
238+
}
239+
async fn remote_request(
240+
request_headers:axum::http::HeaderMap,
241+
(client,config):(reqwest::Client,Arc<ConfigFile>),
242+
q:RequestParams,
243+
)->axum::response::Response{
174244
let builder=client.get(&q.url);
175245
let user_agent=q.user_agent.as_ref().unwrap_or_else(||&config.user_agent);
176246
let builder=builder.header(reqwest::header::USER_AGENT,user_agent);

0 commit comments

Comments
 (0)