Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Drop the deprecated binary format. #11307

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by XGBoost Contributors
* Copyright 2015-2025, XGBoost Contributors
* \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module.
Expand Down Expand Up @@ -151,9 +151,6 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;

virtual void LoadModel(dmlc::Stream* fi) = 0;
virtual void SaveModel(dmlc::Stream* fo) const = 0;

/*!
* \brief Set multiple parameters at once.
*
Expand Down
13 changes: 5 additions & 8 deletions include/xgboost/model.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
/*!
* Copyright (c) 2019 by Contributors
* \file model.h
* \brief Defines the abstract interface for different components in XGBoost.
/**
* Copyright 2019-2025, XGBoost Contributors
*
* @file model.h
* @brief Defines the abstract interface for different components in XGBoost.
*/
#ifndef XGBOOST_MODEL_H_
#define XGBOOST_MODEL_H_

namespace dmlc {
class Stream;
} // namespace dmlc

namespace xgboost {

class Json;
Expand Down
11 changes: 0 additions & 11 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,6 @@ class RegTree : public Model {
return stats_[nid];
}

/*!
* \brief load model from stream
* \param fi input stream
*/
void Load(dmlc::Stream* fi);
/*!
* \brief save model to stream
* \param fo output stream
*/
void Save(dmlc::Stream* fo) const;

void LoadModel(Json const& in) override;
void SaveModel(Json* out) const override;

Expand Down
73 changes: 44 additions & 29 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,38 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar(BoosterHandle handle, char const *,
}
#endif // !defined(XGBOOST_USE_CUDA)

XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
namespace {
/**
 * @brief Guess whether a model buffer holds text JSON or UBJSON, then parse it.
 *
 * The first byte of the buffer is skipped and the next non-whitespace byte decides the
 * format: a `"` selects text JSON, an alphabetic character selects UBJSON. Anything
 * else, including an empty buffer or a remainder made only of whitespace, is reported
 * through `LOG(FATAL)` with the message "Invalid model format".
 *
 * NOTE(review): the skipped first byte is presumably the opening `{` of the root
 * object. This function does not verify it; confirm every caller does.
 *
 * @tparam Buffer Character container exposing `cbegin()`, `cend()`, `data()` and `size()`.
 * @tparam Iter   Const iterator type of `Buffer`.
 *
 * @param buffer Bytes of the serialized model.
 * @param ext    File extension; only used in the warning text.
 * @param warn   Whether to log a warning that names the guessed format.
 *
 * @return The parsed model document.
 */
template <typename Buffer, typename Iter = typename Buffer::const_iterator>
Json DispatchModelType(Buffer const &buffer, StringView ext, bool warn) {
  // The <cctype> classifiers are undefined for negative `char` values, which can show up
  // in binary payloads, so every byte is converted to `unsigned char` first.
  auto is_space = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };
  auto is_alpha = [](char c) { return std::isalpha(static_cast<unsigned char>(c)) != 0; };

  auto first_non_space = [&](Iter beg, Iter end) {
    for (auto i = beg; i != end; ++i) {
      if (!is_space(*i)) {
        return i;
      }
    }
    return end;
  };

  Iter const beg = buffer.cbegin();
  Iter const end = buffer.cend();
  // Only step over the leading byte when there is one: `beg + 1` on an empty buffer
  // would move past the end. An empty buffer falls through to the fatal branch below.
  Iter const it = beg == end ? end : first_non_space(beg + 1, end);

  Json model;
  if (it != end && *it == '"') {
    if (warn) {
      LOG(WARNING) << "Unknown file format: `" << ext << "`. Using JSON as a guess.";
    }
    model = Json::Load(StringView{buffer.data(), buffer.size()});
  } else if (it != end && is_alpha(*it)) {
    if (warn) {
      LOG(WARNING) << "Unknown file format: `" << ext << "`. Using UBJ as a guess.";
    }
    model = Json::Load(StringView{buffer.data(), buffer.size()}, std::ios::binary);
  } else {
    LOG(FATAL) << "Invalid model format";
  }
  return model;
}
}  // namespace

XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char *fname) {
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(fname);
Expand All @@ -1378,28 +1409,23 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
CHECK_EQ(str[0], '{');
return str;
};
if (common::FileExtension(fname) == "json") {
auto ext = common::FileExtension(fname);
if (ext == "json") {
auto buffer = read_file();
Json in{Json::Load(StringView{buffer.data(), buffer.size()})};
static_cast<Learner*>(handle)->LoadModel(in);
} else if (common::FileExtension(fname) == "ubj") {
static_cast<Learner *>(handle)->LoadModel(in);
} else if (ext == "ubj") {
auto buffer = read_file();
Json in = Json::Load(StringView{buffer.data(), buffer.size()}, std::ios::binary);
static_cast<Learner *>(handle)->LoadModel(in);
} else {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
static_cast<Learner*>(handle)->LoadModel(fi.get());
auto buffer = read_file();
auto in = DispatchModelType(buffer, ext, true);
static_cast<Learner *>(handle)->LoadModel(in);
}
API_END();
}

namespace {
/**
 * @brief Log a warning that the model is being written in the deprecated binary format.
 */
void WarnOldModel() {
  char const *message =
      "Saving into deprecated binary model format, please consider using `json` or "
      "`ubj`. Model format is default to UBJSON in XGBoost 2.1 if not specified.";
  LOG(WARNING) << message;
}
}  // anonymous namespace

XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) {
API_BEGIN();
CHECK_HANDLE();
Expand All @@ -1419,13 +1445,9 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) {
save_json(std::ios::out);
} else if (common::FileExtension(fname) == "ubj") {
save_json(std::ios::binary);
} else if (common::FileExtension(fname) == "deprecated") {
WarnOldModel();
auto *bst = static_cast<Learner *>(handle);
bst->SaveModel(fo.get());
} else {
LOG(WARNING) << "Saving model in the UBJSON format as default. You can use file extension:"
" `json`, `ubj` or `deprecated` to choose between formats.";
" `json` or `ubj` to choose between formats.";
save_json(std::ios::binary);
}
API_END();
Expand All @@ -1436,9 +1458,11 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf,
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(buf);

auto buffer = common::Span<char const>{static_cast<char const *>(buf), len};
// Don't warn, we have to guess the format with buffer input.
auto in = DispatchModelType(buffer, "", false);
common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*)
static_cast<Learner *>(handle)->LoadModel(&fs);
static_cast<Learner *>(handle)->LoadModel(in);
API_END();
}

Expand Down Expand Up @@ -1471,15 +1495,6 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
save_json(std::ios::out);
} else if (format == "ubj") {
save_json(std::ios::binary);
} else if (format == "deprecated") {
WarnOldModel();
auto &raw_str = learner->GetThreadLocal().ret_str;
raw_str.clear();
common::MemoryBufferStream fo(&raw_str);
learner->SaveModel(&fo);

*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
} else {
LOG(FATAL) << "Unknown format: `" << format << "`";
}
Expand Down
7 changes: 0 additions & 7 deletions src/gbm/gblinear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,6 @@ class GBLinear : public GradientBooster {

bool ModelFitted() const override { return BoostedRounds() != 0; }

// Binary-stream load: forwards the stream to `model_.Load`.
void Load(dmlc::Stream* fi) override {
model_.Load(fi);
}
// Binary-stream save: forwards the stream to `model_.Save`.
void Save(dmlc::Stream* fo) const override {
model_.Save(fo);
}

void SaveModel(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String{"gblinear"};
Expand Down
5 changes: 2 additions & 3 deletions src/gbm/gblinear_model.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
/*!
* Copyright 2019-2022 by Contributors
/**
* Copyright 2019-2025, XGBoost Contributors
*/
#include <algorithm>
#include <utility>
#include <limits>
#include "xgboost/json.h"
#include "gblinear_model.h"

Expand Down
11 changes: 0 additions & 11 deletions src/gbm/gblinear_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,6 @@ class GBLinearModel : public Model {
void SaveModel(Json *p_out) const override;
void LoadModel(Json const &in) override;

// Save the model to a stream: the raw bytes of `param_` are written first,
// followed by `weight`.
void Save(dmlc::Stream *fo) const {
fo->Write(&param_, sizeof(param_));
fo->Write(weight);
}
// Load the model from a stream: `sizeof(param_)` bytes are read into `param_`
// (the number of bytes actually read is checked), then `weight` is read.
void Load(dmlc::Stream *fi) {
CHECK_EQ(fi->Read(&param_, sizeof(param_)), sizeof(param_));
fi->Read(&weight);
}

// model bias
inline bst_float *Bias() {
return &weight[learner_model_param->num_feature *
Expand Down
5 changes: 0 additions & 5 deletions src/gbm/gbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ class GBTree : public GradientBooster {

[[nodiscard]] GBTreeTrainParam const& GetTrainParam() const { return tparam_; }

// Binary-stream load: forwards the stream to `model_.Load`.
void Load(dmlc::Stream* fi) override { model_.Load(fi); }
// Binary-stream save: forwards the stream to `model_.Save`.
void Save(dmlc::Stream* fo) const override {
model_.Save(fo);
}

void LoadConfig(Json const& in) override;
void SaveConfig(Json* p_out) const override;

Expand Down
56 changes: 0 additions & 56 deletions src/gbm/gbtree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,62 +50,6 @@ void Validate(GBTreeModel const& model) {
}
} // namespace

/**
 * @brief Write the model to a binary stream.
 *
 * The stream receives, in order: the raw bytes of `param`, every tree (through the
 * tree's own `Save`), and, when `tree_info` is not empty, its entries as 32-bit
 * integers. When `DMLC_IO_NO_ENDIAN_SWAP` is false, byte-swapped copies of `param`
 * and of each `tree_info` entry are written instead of the in-memory values.
 *
 * @param fo Output stream.
 */
void GBTreeModel::Save(dmlc::Stream* fo) const {
// The recorded tree count must agree with the trees actually held.
CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));

if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(&param, sizeof(param));
} else {
// Swap a copy so the in-memory parameter stays untouched.
auto x = param.ByteSwap();
fo->Write(&x, sizeof(x));
}
for (const auto & tree : trees) {
tree->Save(fo);
}
if (tree_info.size() != 0) {
if (DMLC_IO_NO_ENDIAN_SWAP) {
// No swap needed: write the whole array in one call.
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int32_t) * tree_info.size());
} else {
// Swap and write one entry at a time, again working on a copy.
for (const auto& e : tree_info) {
auto x = e;
dmlc::ByteSwap(&x, sizeof(x), 1);
fo->Write(&x, sizeof(x));
}
}
}
}

/**
 * @brief Read the model from a binary stream.
 *
 * Reads, in order: the raw bytes of `param`, `param.num_trees` trees (through
 * `RegTree::Load`), and `param.num_trees` 32-bit `tree_info` entries. When
 * `DMLC_IO_NO_ENDIAN_SWAP` is false, `param` and every `tree_info` entry are
 * byte-swapped after reading. Existing `trees` and `trees_to_update` are discarded
 * first. Each read of `param` and `tree_info` is checked for the expected byte count.
 * Finally `MakeIndptr` and `Validate` are run on the loaded model.
 *
 * @param fi Input stream.
 */
void GBTreeModel::Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param))
<< "GBTree: invalid model file";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
param = param.ByteSwap();
}
// Drop any previously held trees before loading the new ones.
trees.clear();
trees_to_update.clear();
for (int32_t i = 0; i < param.num_trees; ++i) {
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->Load(fi);
trees.push_back(std::move(ptr));
}
tree_info.resize(param.num_trees);
if (param.num_trees != 0) {
if (DMLC_IO_NO_ENDIAN_SWAP) {
// No swap needed: read the whole array in one call.
CHECK_EQ(
fi->Read(dmlc::BeginPtr(tree_info), sizeof(int32_t) * param.num_trees),
sizeof(int32_t) * param.num_trees);
} else {
// Read and swap one entry at a time.
for (auto& info : tree_info) {
CHECK_EQ(fi->Read(&info, sizeof(int32_t)), sizeof(int32_t));
dmlc::ByteSwap(&info, sizeof(info), 1);
}
}
}

MakeIndptr(this);
Validate(*this);
}

void GBTreeModel::SaveModel(Json* p_out) const {
auto& out = *p_out;
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
Expand Down
3 changes: 0 additions & 3 deletions src/gbm/gbtree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ struct GBTreeModel : public Model {
}
}

void Load(dmlc::Stream* fi);
void Save(dmlc::Stream* fo) const;

void SaveModel(Json* p_out) const override;
void LoadModel(Json const& p_out) override;

Expand Down
Loading
Loading