Skip to content

Commit f85be44

Browse files
author
Montana Low
committed
removed upstream dmlc/xgboost/pull/6505
1 parent b4337b2 commit f85be44

File tree

1 file changed

+3
-19
lines changed

1 file changed

+3
-19
lines changed

src/booster.rs

+3-19
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ impl Booster {
152152
//let num_parallel_tree = 1;
153153

154154
// load distributed code checkpoint from rabit
155-
let version = bst.load_rabit_checkpoint()?;
155+
let version = unsafe { xgboost_sys::RabitVersionNumber() };
156156
debug!("Loaded Rabit checkpoint: version={}", version);
157157
assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 });
158158

@@ -171,7 +171,6 @@ impl Booster {
171171
debug!("Updating in round: {}", i);
172172
bst.update(params.dtrain, i)?;
173173
}
174-
bst.save_rabit_checkpoint()?;
175174
}
176175

177176
assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() });
@@ -328,7 +327,7 @@ impl Booster {
328327
let name = "default";
329328
let mut eval = self.eval_set(&[(dmat, name)], 0)?;
330329
let mut result = HashMap::new();
331-
eval.remove(name).unwrap().into_iter().for_each(|(k, v)| {
330+
eval.swap_remove(name).unwrap().into_iter().for_each(|(k, v)| {
332331
result.insert(k.to_owned(), v);
333332
});
334333

@@ -564,16 +563,6 @@ impl Booster {
564563
Ok(out_vec.join("\n"))
565564
}
566565

567-
pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult<i32> {
568-
let mut version = 0;
569-
xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?;
570-
Ok(version)
571-
}
572-
573-
pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> {
574-
xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle))
575-
}
576-
577566
pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> {
578567
let name = ffi::CString::new(name).unwrap();
579568
let value = ffi::CString::new(value).unwrap();
@@ -742,11 +731,6 @@ mod tests {
742731
assert!(res.is_ok());
743732
}
744733

745-
#[test]
746-
fn load_rabit_version() {
747-
let version = load_test_booster().load_rabit_checkpoint().unwrap();
748-
assert_eq!(version, 0);
749-
}
750734

751735
#[test]
752736
fn get_set_attr() {
@@ -841,7 +825,7 @@ mod tests {
841825
assert_eq!(*train_metrics.get("map@4-").unwrap(), 1.0);
842826

843827
let test_metrics = booster.evaluate(&dmat_test).unwrap();
844-
assert_eq!(*test_metrics.get("logloss").unwrap(), 0.006919953);
828+
assert_eq!(*test_metrics.get("logloss").unwrap(), 0.0069199526);
845829
assert_eq!(*test_metrics.get("map@4-").unwrap(), 1.0);
846830

847831
let v = booster.predict(&dmat_test).unwrap();

0 commit comments

Comments
 (0)