diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 1499804c8592..43e97babf2a8 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -1,5 +1,5 @@ /** - * Copyright 2015-2023 by XGBoost Contributors + * Copyright 2015-2025, XGBoost Contributors * \file learner.h * \brief Learner interface that integrates objective, gbm and evaluation together. * This is the user facing XGBoost training module. @@ -151,9 +151,6 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { void LoadModel(Json const& in) override = 0; void SaveModel(Json* out) const override = 0; - virtual void LoadModel(dmlc::Stream* fi) = 0; - virtual void SaveModel(dmlc::Stream* fo) const = 0; - /*! * \brief Set multiple parameters at once. * diff --git a/include/xgboost/model.h b/include/xgboost/model.h index 610c7a0f5c48..c9c045234501 100644 --- a/include/xgboost/model.h +++ b/include/xgboost/model.h @@ -1,15 +1,12 @@ -/*! - * Copyright (c) 2019 by Contributors - * \file model.h - * \brief Defines the abstract interface for different components in XGBoost. +/** + * Copyright 2019-2025, XGBoost Contributors + * + * @file model.h + * @brief Defines the abstract interface for different components in XGBoost. */ #ifndef XGBOOST_MODEL_H_ #define XGBOOST_MODEL_H_ -namespace dmlc { -class Stream; -} // namespace dmlc - namespace xgboost { class Json; diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 921fc5a1ebc8..7deb5f7d2f78 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -365,17 +365,6 @@ class RegTree : public Model { return stats_[nid]; } - /*! - * \brief load model from stream - * \param fi input stream - */ - void Load(dmlc::Stream* fi); - /*! - * \brief save model to stream - * \param fo output stream - */ - void Save(dmlc::Stream* fo) const; - void LoadModel(Json const& in) override; void SaveModel(Json* out) const override; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3fbfbc8b9792..c0af1258e514 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1368,7 +1368,38 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar(BoosterHandle handle, char const *, } #endif // !defined(XGBOOST_USE_CUDA) -XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { +namespace { +template +Json DispatchModelType(Buffer const &buffer, StringView ext, bool warn) { + auto first_non_space = [&](Iter beg, Iter end) { + for (auto i = beg; i != end; ++i) { + if (!std::isspace(*i)) { + return i; + } + } + return end; + }; + + Json model; + auto it = first_non_space(buffer.cbegin() + 1, buffer.cend()); + if (it != buffer.cend() && *it == '"') { + if (warn) { + LOG(WARNING) << "Unknown file format: `" << ext << "`. Using JSON as a guess."; + } + model = Json::Load(StringView{buffer.data(), buffer.size()}); + } else if (it != buffer.cend() && std::isalpha(*it)) { + if (warn) { + LOG(WARNING) << "Unknown file format: `" << ext << "`. 
Using UBJ as a guess."; + } + model = Json::Load(StringView{buffer.data(), buffer.size()}, std::ios::binary); + } else { + LOG(FATAL) << "Invalid model format"; + } + return model; +} +} // namespace + +XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char *fname) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(fname); @@ -1378,28 +1409,23 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { CHECK_EQ(str[0], '{'); return str; }; - if (common::FileExtension(fname) == "json") { + auto ext = common::FileExtension(fname); + if (ext == "json") { auto buffer = read_file(); Json in{Json::Load(StringView{buffer.data(), buffer.size()})}; - static_cast(handle)->LoadModel(in); - } else if (common::FileExtension(fname) == "ubj") { + static_cast(handle)->LoadModel(in); + } else if (ext == "ubj") { auto buffer = read_file(); Json in = Json::Load(StringView{buffer.data(), buffer.size()}, std::ios::binary); static_cast(handle)->LoadModel(in); } else { - std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); - static_cast(handle)->LoadModel(fi.get()); + auto buffer = read_file(); + auto in = DispatchModelType(buffer, ext, true); + static_cast(handle)->LoadModel(in); } API_END(); } -namespace { -void WarnOldModel() { - LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or " - "`ubj`. Model format is default to UBJSON in XGBoost 2.1 if not specified."; -} -} // anonymous namespace - XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) { API_BEGIN(); CHECK_HANDLE(); @@ -1419,13 +1445,9 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) { save_json(std::ios::out); } else if (common::FileExtension(fname) == "ubj") { save_json(std::ios::binary); - } else if (common::FileExtension(fname) == "deprecated") { - WarnOldModel(); - auto *bst = static_cast(handle); - bst->SaveModel(fo.get()); } else { LOG(WARNING) << "Saving model in the UBJSON format as default. You can use file extension:" - " `json`, `ubj` or `deprecated` to choose between formats."; + " `json` or `ubj` to choose between formats."; save_json(std::ios::binary); } API_END(); @@ -1436,9 +1458,11 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf, API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(buf); - + auto buffer = common::Span{static_cast(buf), len}; + // Don't warn, we have to guess the format with buffer input. 
+ auto in = DispatchModelType(buffer, "", false); common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*) - static_cast(handle)->LoadModel(&fs); + static_cast(handle)->LoadModel(in); API_END(); } @@ -1471,15 +1495,6 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co save_json(std::ios::out); } else if (format == "ubj") { save_json(std::ios::binary); - } else if (format == "deprecated") { - WarnOldModel(); - auto &raw_str = learner->GetThreadLocal().ret_str; - raw_str.clear(); - common::MemoryBufferStream fo(&raw_str); - learner->SaveModel(&fo); - - *out_dptr = dmlc::BeginPtr(raw_str); - *out_len = static_cast(raw_str.size()); } else { LOG(FATAL) << "Unknown format: `" << format << "`"; } diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 2cacfe078b4b..5c9208f39a75 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -101,13 +101,6 @@ class GBLinear : public GradientBooster { bool ModelFitted() const override { return BoostedRounds() != 0; } - void Load(dmlc::Stream* fi) override { - model_.Load(fi); - } - void Save(dmlc::Stream* fo) const override { - model_.Save(fo); - } - void SaveModel(Json* p_out) const override { auto& out = *p_out; out["name"] = String{"gblinear"}; diff --git a/src/gbm/gblinear_model.cc b/src/gbm/gblinear_model.cc index 5e6f5dda9a1f..0be4b5a2914f 100644 --- a/src/gbm/gblinear_model.cc +++ b/src/gbm/gblinear_model.cc @@ -1,9 +1,8 @@ -/*! - * Copyright 2019-2022 by Contributors +/** + * Copyright 2019-2025, XGBoost Contributors */ #include #include -#include #include "xgboost/json.h" #include "gblinear_model.h" diff --git a/src/gbm/gblinear_model.h b/src/gbm/gblinear_model.h index 91760346ca47..c82627067b2c 100644 --- a/src/gbm/gblinear_model.h +++ b/src/gbm/gblinear_model.h @@ -71,17 +71,6 @@ class GBLinearModel : public Model { void SaveModel(Json *p_out) const override; void LoadModel(Json const &in) override; - // save the model to file - void Save(dmlc::Stream *fo) const { - fo->Write(¶m_, sizeof(param_)); - fo->Write(weight); - } - // load model from file - void Load(dmlc::Stream *fi) { - CHECK_EQ(fi->Read(¶m_, sizeof(param_)), sizeof(param_)); - fi->Read(&weight); - } - // model bias inline bst_float *Bias() { return &weight[learner_model_param->num_feature * diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 1fbf0ebdaf7f..2df185e6db92 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -188,11 +188,6 @@ class GBTree : public GradientBooster { [[nodiscard]] GBTreeTrainParam const& GetTrainParam() const { return tparam_; } - void Load(dmlc::Stream* fi) override { model_.Load(fi); } - void Save(dmlc::Stream* fo) const override { - model_.Save(fo); - } - void LoadConfig(Json const& in) override; void SaveConfig(Json* p_out) const override; diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 2edb456c95de..cba4e0f9b8ee 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -50,62 +50,6 @@ void Validate(GBTreeModel const& model) { } } // namespace -void GBTreeModel::Save(dmlc::Stream* fo) const { - CHECK_EQ(param.num_trees, static_cast(trees.size())); - - if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(¶m, sizeof(param)); - } else { - auto x = param.ByteSwap(); - fo->Write(&x, sizeof(x)); - } - for (const auto & tree : trees) { - tree->Save(fo); - } - if (tree_info.size() != 0) { - if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(dmlc::BeginPtr(tree_info), sizeof(int32_t) * tree_info.size()); - } else { - for (const auto& e : tree_info) { - auto x = e; - dmlc::ByteSwap(&x, sizeof(x), 
1); - fo->Write(&x, sizeof(x)); - } - } - } -} - -void GBTreeModel::Load(dmlc::Stream* fi) { - CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param)) - << "GBTree: invalid model file"; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - param = param.ByteSwap(); - } - trees.clear(); - trees_to_update.clear(); - for (int32_t i = 0; i < param.num_trees; ++i) { - std::unique_ptr ptr(new RegTree()); - ptr->Load(fi); - trees.push_back(std::move(ptr)); - } - tree_info.resize(param.num_trees); - if (param.num_trees != 0) { - if (DMLC_IO_NO_ENDIAN_SWAP) { - CHECK_EQ( - fi->Read(dmlc::BeginPtr(tree_info), sizeof(int32_t) * param.num_trees), - sizeof(int32_t) * param.num_trees); - } else { - for (auto& info : tree_info) { - CHECK_EQ(fi->Read(&info, sizeof(int32_t)), sizeof(int32_t)); - dmlc::ByteSwap(&info, sizeof(info), 1); - } - } - } - - MakeIndptr(this); - Validate(*this); -} - void GBTreeModel::SaveModel(Json* p_out) const { auto& out = *p_out; CHECK_EQ(param.num_trees, static_cast(trees.size())); diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 32fa868638bb..c561fcd066ee 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -106,9 +106,6 @@ struct GBTreeModel : public Model { } } - void Load(dmlc::Stream* fi); - void Save(dmlc::Stream* fo) const; - void SaveModel(Json* p_out) const override; void LoadModel(Json const& p_out) override; diff --git a/src/learner.cc b/src/learner.cc index 34f395beb34b..7e33ab5fab18 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -853,11 +853,6 @@ class LearnerConfiguration : public Learner { std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT class LearnerIO : public LearnerConfiguration { - private: - // Used to identify the offset of JSON string when - // Will be removed once JSON takes over. Right now we still loads some RDS files from R. - std::string const serialisation_header_ { u8"CONFIG-offset:" }; - protected: void ClearCaches() { this->prediction_container_ = PredictionContainer{}; } @@ -956,186 +951,34 @@ class LearnerIO : public LearnerConfiguration { void LoadModel(dmlc::Stream* fi) override { ctx_.UpdateAllowUnknown(Args{}); tparam_.Init(std::vector>{}); + // TODO(tqchen) mark deprecation of old format. common::PeekableInStream fp(fi); // backward compatible header check. std::string header; header.resize(4); + StringView msg = "Only `json` and `ubj` is supported starting from 3.1."; if (fp.PeekRead(&header[0], 4) == 4) { - CHECK_NE(header, "bs64") - << "Base64 format is no longer supported in brick."; - if (header == "binf") { - CHECK_EQ(fp.Read(&header[0], 4), 4U); - } + CHECK_NE(header, "bs64") << msg; + CHECK_NE(header, "binf") << msg; } + CHECK_EQ(header[0], '{') << msg; // FIXME(jiamingy): Move this out of learner after the old binary model is remove. - auto first_non_space = [&](std::string::const_iterator beg, std::string::const_iterator end) { - for (auto i = beg; i != end; ++i) { - if (!std::isspace(*i)) { - return i; - } - } - return end; - }; - if (header[0] == '{') { // Dispatch to JSON - auto buffer = common::ReadAll(fi, &fp); - Json model; - auto it = first_non_space(buffer.cbegin() + 1, buffer.cend()); - if (it != buffer.cend() && *it == '"') { - model = Json::Load(StringView{buffer}); - } else if (it != buffer.cend() && std::isalpha(*it)) { - model = Json::Load(StringView{buffer}, std::ios::binary); - } else { - LOG(FATAL) << "Invalid model format"; - } - this->LoadModel(model); - return; - } - - // use the peekable reader. 
- fi = &fp; - // read parameter - CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_)) - << "BoostLearner: wrong model format"; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - mparam_ = mparam_.ByteSwap(); - } - CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format"; - CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format"; - - obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_)); - gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_)); - gbm_->Load(fi); - if (mparam_.contain_extra_attrs != 0) { - std::vector > attr; - fi->Read(&attr); - attributes_ = std::map(attr.begin(), attr.end()); - } - bool warn_old_model { false }; - if (attributes_.find("count_poisson_max_delta_step") != attributes_.cend()) { - // Loading model from < 1.0.0, objective is not saved. - cfg_["max_delta_step"] = attributes_.at("count_poisson_max_delta_step"); - attributes_.erase("count_poisson_max_delta_step"); - warn_old_model = true; + // Dispatch to JSON + auto buffer = common::ReadAll(fi, &fp); + Json model; + auto it = first_non_space(buffer.cbegin() + 1, buffer.cend()); + if (it != buffer.cend() && *it == '"') { + model = Json::Load(StringView{buffer}); + } else if (it != buffer.cend() && std::isalpha(*it)) { + model = Json::Load(StringView{buffer}, std::ios::binary); } else { - warn_old_model = false; - } - - if (mparam_.major_version < 1) { - // Before 1.0.0, base_score is saved as a transformed value, and there's no version - // attribute (saved a 0) in the saved model. - std::string multi{"multi:"}; - if (!std::equal(multi.cbegin(), multi.cend(), tparam_.objective.cbegin())) { - HostDeviceVector t; - t.HostVector().resize(1); - t.HostVector().at(0) = mparam_.base_score; - this->obj_->PredTransform(&t); - auto base_score = t.HostVector().at(0); - mparam_.base_score = base_score; - } - warn_old_model = true; - } - - learner_model_param_ = - LearnerModelParam(&ctx_, mparam_, - linalg::Tensor{{std::isnan(mparam_.base_score) - ? std::numeric_limits::quiet_NaN() - : obj_->ProbToMargin(mparam_.base_score)}, - {1}, - DeviceOrd::CPU()}, - obj_->Task(), tparam_.multi_strategy); - - if (attributes_.find("objective") != attributes_.cend()) { - auto obj_str = attributes_.at("objective"); - auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()}); - obj_->LoadConfig(j_obj); - attributes_.erase("objective"); - } else { - warn_old_model = true; - } - if (attributes_.find("metrics") != attributes_.cend()) { - auto metrics_str = attributes_.at("metrics"); - std::vector names { common::Split(metrics_str, ';') }; - attributes_.erase("metrics"); - for (auto const& n : names) { - this->SetParam(kEvalMetric, n); - } - } - - if (warn_old_model) { - LOG(WARNING) << "Loading model from XGBoost < 1.0.0, consider saving it " - "again for improved compatibility"; - } - - // Renew the version. - mparam_.major_version = std::get<0>(Version::Self()); - mparam_.minor_version = std::get<1>(Version::Self()); - - cfg_["num_feature"] = std::to_string(mparam_.num_feature); - - auto n = tparam_.__DICT__(); - cfg_.insert(n.cbegin(), n.cend()); - - this->need_configuration_ = true; - this->ClearCaches(); - } - - // Save model into binary format. The code is about to be deprecated by more robust - // JSON serialization format. 
- void SaveModel(dmlc::Stream* fo) const override { - this->CheckModelInitialized(); - CHECK(!this->learner_model_param_.IsVectorLeaf()) - << "Please use JSON/UBJ format for model serialization with multi-output models."; - - LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify - std::vector > extra_attr; - mparam.contain_extra_attrs = 1; - - if (!this->feature_names_.empty() || !this->feature_types_.empty()) { - LOG(WARNING) << "feature names and feature types are being disregarded, use JSON/UBJSON " - "format instead."; - } - - { - // Similar to JSON model IO, we save the objective. - Json j_obj { Object() }; - obj_->SaveConfig(&j_obj); - std::string obj_doc; - Json::Dump(j_obj, &obj_doc); - extra_attr.emplace_back("objective", obj_doc); - } - // As of 1.0.0, JVM Package and R Package uses Save/Load model for serialization. - // Remove this part once they are ported to use actual serialization methods. - if (mparam.contain_eval_metrics != 0) { - std::stringstream os; - for (auto& ev : metrics_) { - os << ev->Name() << ";"; - } - extra_attr.emplace_back("metrics", os.str()); - } - - std::string header {"binf"}; - fo->Write(header.data(), 4); - if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(&mparam, sizeof(LearnerModelParamLegacy)); - } else { - LearnerModelParamLegacy x = mparam.ByteSwap(); - fo->Write(&x, sizeof(LearnerModelParamLegacy)); - } - fo->Write(tparam_.objective); - fo->Write(tparam_.booster); - gbm_->Save(fo); - if (mparam.contain_extra_attrs != 0) { - std::map attr(attributes_); - for (const auto& kv : extra_attr) { - attr[kv.first] = kv.second; - } - fo->Write(std::vector>( - attr.begin(), attr.end())); + LOG(FATAL) << "Invalid model format"; } + this->LoadModel(model); } void Save(dmlc::Stream* fo) const override { @@ -1158,46 +1001,28 @@ class LearnerIO : public LearnerConfiguration { common::PeekableInStream fp(fi); char header[2]; fp.PeekRead(header, 2); - if (header[0] == '{') { - auto buffer = common::ReadAll(fi, &fp); - Json memory_snapshot; - if (header[1] == '"') { - memory_snapshot = Json::Load(StringView{buffer}); - error::WarnOldSerialization(); - } else if (std::isalpha(header[1])) { - memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary); - } else { - LOG(FATAL) << "Invalid serialization file."; - } - if (IsA(memory_snapshot["Model"])) { - // R has xgb.load that doesn't distinguish whether configuration is saved. - // We should migrate to use `xgb.load.raw` instead. - this->LoadModel(memory_snapshot); - } else { - this->LoadModel(memory_snapshot["Model"]); - this->LoadConfig(memory_snapshot["Config"]); - } + StringView msg = "Invalid serialization file."; + CHECK_EQ(header[0], '{') << msg; + + auto buffer = common::ReadAll(fi, &fp); + Json memory_snapshot; + CHECK(std::isalpha(header[1])) << msg; + if (header[1] == '"') { + memory_snapshot = Json::Load(StringView{buffer}); + error::WarnOldSerialization(); + } else if (std::isalpha(header[1])) { + memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary); } else { - std::string header; - header.resize(serialisation_header_.size()); - CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size()); - // Avoid printing the content in loaded header, which might be random binary code. 
- CHECK(header == serialisation_header_) << error::OldSerialization(); - int64_t sz {-1}; - CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz)); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(&sz, sizeof(sz), 1); - } - CHECK_GT(sz, 0); - size_t json_offset = static_cast(sz); - std::string buffer; - common::FixedSizeStream{&fp}.Take(&buffer); - - common::MemoryFixSizeBuffer binary_buf(&buffer[0], json_offset); - this->LoadModel(&binary_buf); + LOG(FATAL) << "Invalid serialization file."; + } - auto config = Json::Load({buffer.c_str() + json_offset, buffer.size() - json_offset}); - this->LoadConfig(config); + if (IsA(memory_snapshot["Model"])) { + // R has xgb.load that doesn't distinguish whether configuration is saved. + // We should migrate to use `xgb.load.raw` instead. + this->LoadModel(memory_snapshot); + } else { + this->LoadModel(memory_snapshot["Model"]); + this->LoadConfig(memory_snapshot["Config"]); } } }; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 0639233510f7..a35df0424b49 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -911,76 +911,6 @@ void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index, this->split_categories_segments_.at(nid).size = split_cat.size(); } -void RegTree::Load(dmlc::Stream* fi) { - CHECK_EQ(fi->Read(¶m_, sizeof(TreeParam)), sizeof(TreeParam)); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - param_ = param_.ByteSwap(); - } - nodes_.resize(param_.num_nodes); - stats_.resize(param_.num_nodes); - CHECK_NE(param_.num_nodes, 0); - CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()), - sizeof(Node) * nodes_.size()); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - for (Node& node : nodes_) { - node = node.ByteSwap(); - } - } - CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()), - sizeof(RTreeNodeStat) * stats_.size()); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - for (RTreeNodeStat& stat : stats_) { - stat = stat.ByteSwap(); - } - } - // chg deleted nodes - deleted_nodes_.resize(0); - for (int i = 1; i < param_.num_nodes; ++i) { - if (nodes_[i].IsDeleted()) { - deleted_nodes_.push_back(i); - } - } - CHECK_EQ(static_cast(deleted_nodes_.size()), param_.num_deleted); - - split_types_.resize(param_.num_nodes, FeatureType::kNumerical); - split_categories_segments_.resize(param_.num_nodes); -} - -void RegTree::Save(dmlc::Stream* fo) const { - CHECK_EQ(param_.num_nodes, static_cast(nodes_.size())); - CHECK_EQ(param_.num_nodes, static_cast(stats_.size())); - CHECK_EQ(param_.deprecated_num_roots, 1); - CHECK_NE(param_.num_nodes, 0); - CHECK(!IsMultiTarget()) - << "Please use JSON/UBJSON for saving models with multi-target trees."; - CHECK(!HasCategoricalSplit()) - << "Please use JSON/UBJSON for saving models with categorical splits."; - - if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(¶m_, sizeof(TreeParam)); - } else { - TreeParam x = param_.ByteSwap(); - fo->Write(&x, sizeof(x)); - } - - if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()); - } else { - for (const Node& node : nodes_) { - Node x = node.ByteSwap(); - fo->Write(&x, sizeof(x)); - } - } - if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size()); - } else { - for (const RTreeNodeStat& stat : stats_) { - RTreeNodeStat x = stat.ByteSwap(); - fo->Write(&x, sizeof(x)); - } - } -} - template void RegTree::LoadCategoricalSplit(Json const& in) { auto const& categories_segments = get>(in["categories_segments"]);
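
The dispatch helper introduced in src/c_api/c_api.cc and the stricter check in LearnerIO::LoadModel both rely on the same heuristic: every model XGBoost writes in either format begins with '{', but in text JSON the first non-space byte after the brace is the '"' that opens an object key, while in UBJSON it is an alphabetic length-type marker (for example 'i' or 'U'). Below is a minimal, self-contained sketch of that heuristic only; the enum and function names (ModelFormat, SniffModelFormat) and the main() driver are illustrative and do not exist in the XGBoost sources, where the real helper parses the buffer into a Json object instead of returning a tag.

#include <cassert>
#include <cctype>
#include <cstddef>
#include <iostream>
#include <string_view>

// Illustrative only: a tag describing how a serialized model buffer should be parsed.
enum class ModelFormat { kJsonText, kUbjson, kInvalid };

// Mirrors the first_non_space() scan used by DispatchModelType in the patch above.
ModelFormat SniffModelFormat(std::string_view buffer) {
  // Both text JSON and UBJSON documents produced by XGBoost start with '{'.
  if (buffer.empty() || buffer.front() != '{') {
    return ModelFormat::kInvalid;
  }
  for (std::size_t i = 1; i < buffer.size(); ++i) {
    auto c = static_cast<unsigned char>(buffer[i]);
    if (std::isspace(c)) {
      continue;  // Skip whitespace after the opening brace; text JSON may be pretty-printed.
    }
    if (c == '"') {
      return ModelFormat::kJsonText;  // An object key in text JSON starts with a quote.
    }
    if (std::isalpha(c)) {
      return ModelFormat::kUbjson;  // UBJSON follows '{' with a type marker such as 'i' or 'U'.
    }
    return ModelFormat::kInvalid;
  }
  return ModelFormat::kInvalid;
}

int main() {
  assert(SniffModelFormat(R"({ "learner": {} })") == ModelFormat::kJsonText);
  // A UBJSON object key is encoded as [length type][length][bytes], e.g. 'i', 7, "learner".
  assert(SniffModelFormat("{i\x07" "learner") == ModelFormat::kUbjson);
  // The legacy base64 header rejected in LearnerIO::LoadModel is not a valid model.
  assert(SniffModelFormat("bs64:...") == ModelFormat::kInvalid);
  std::cout << "format sniffing behaves as expected\n";
  return 0;
}

In the patched C API this guess is only a fallback: XGBoosterLoadModel still trusts the `json` and `ubj` file extensions and warns before guessing on anything else, while XGBoosterLoadModelFromBuffer guesses silently because a raw buffer carries no extension.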