Skip to content

Commit

Permalink
fix: add constant for HF_HUB_USER_AGENT_ORIGIN and fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugoch committed Feb 17, 2025
1 parent 1ab0787 commit ff706fe
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub mod sync;

const HF_ENDPOINT: &str = "HF_ENDPOINT";

const ENV_UA_ORIGIN: &str = "HF_HUB_USER_AGENT_ORIGIN";

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
Expand Down
18 changes: 14 additions & 4 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{RepoInfo, HF_ENDPOINT};
use super::{RepoInfo, ENV_UA_ORIGIN, HF_ENDPOINT};
use crate::api::sync::ApiError::InvalidHeader;
use crate::api::Progress;
use crate::{Cache, Repo, RepoType};
Expand Down Expand Up @@ -305,8 +305,8 @@ impl ApiBuilder {
let mut headers = HeaderMap::new();
let user_agent = format!("unknown/None; {NAME}/{VERSION}; rust/unknown");
// add origin user-agent if HF_HUB_USER_AGENT_ORIGIN is set in environment
let user_agent = match std::env::var("HF_HUB_USER_AGENT_ORIGIN").ok() {
Some(origin) => format!("{user_agent}; {origin}"),
let user_agent = match std::env::var(ENV_UA_ORIGIN).ok() {
Some(origin) => format!("{user_agent}; origin/{origin}"),
None => user_agent,
};
headers.insert(USER_AGENT, user_agent);
Expand Down Expand Up @@ -1168,7 +1168,7 @@ mod tests {
"gated": false,
"id": "mcpotato/42-eicar-street",
"lastModified": "2022-11-30T19:54:16.000Z",
"likes": 1,
"likes": 2,
"modelId": "mcpotato/42-eicar-street",
"private": false,
"sha": "8b3861f6931c4026b0cd22b38dbc09e7668983ac",
Expand Down Expand Up @@ -1216,6 +1216,7 @@ mod tests {
],
"spaces": [],
"tags": ["pytorch", "region:us"],
"usedStorage": 22,
})
);
}
Expand All @@ -1232,6 +1233,15 @@ mod tests {
assert_eq!(api.endpoint, fake_endpoint);
}

#[test]
fn build_headers() {
std::env::set_var(ENV_UA_ORIGIN, "foo");
let api = ApiBuilder::new().build().unwrap();
let headers = api.client.headers;
println!("{:?}", headers);
assert!(headers.get(USER_AGENT).unwrap().contains("origin/foo"));
}

// #[test]
// fn real() {
// let api = Api::new().unwrap();
Expand Down
9 changes: 5 additions & 4 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Progress as SyncProgress;
use super::{Progress as SyncProgress, ENV_UA_ORIGIN};
use super::{RepoInfo, HF_ENDPOINT};
use crate::{Cache, Repo, RepoType};
use futures::stream::FuturesUnordered;
Expand Down Expand Up @@ -330,8 +330,8 @@ impl ApiBuilder {
let mut headers = HeaderMap::new();
let user_agent = format!("unknown/None; {NAME}/{VERSION}; rust/unknown");
// add origin user-agent if HF_HUB_USER_AGENT_ORIGIN is set in environment
let user_agent = match std::env::var("HF_HUB_USER_AGENT_ORIGIN").ok() {
Some(origin) => format!("{user_agent}; {origin}"),
let user_agent = match std::env::var(ENV_UA_ORIGIN).ok() {
Some(origin) => format!("{user_agent}; origin/{origin}"),
None => user_agent,
};
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
Expand Down Expand Up @@ -1322,7 +1322,7 @@ mod tests {
"gated": false,
"id": "mcpotato/42-eicar-street",
"lastModified": "2022-11-30T19:54:16.000Z",
"likes": 1,
"likes": 2,
"modelId": "mcpotato/42-eicar-street",
"private": false,
"sha": "8b3861f6931c4026b0cd22b38dbc09e7668983ac",
Expand Down Expand Up @@ -1370,6 +1370,7 @@ mod tests {
],
"spaces": [],
"tags": ["pytorch", "region:us"],
"usedStorage": 22,
})
);
}
Expand Down

0 comments on commit ff706fe

Please # to comment.