Skip to content

Commit

Permalink
hacky fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Sep 16, 2024
1 parent c1d13df commit 06e8adc
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions examples/mnist/mnist-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,15 @@ void mnist_model_train(mnist_model & model, const float * images, const float *

struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
ggml_build_forward_expand(gf, model.loss);
struct ggml_tensor * fc1wg = model.fc1_weight->grad;
struct ggml_tensor * fc1bg = model.fc1_bias->grad;
struct ggml_tensor * fc2wg = model.fc2_weight->grad;
struct ggml_tensor * fc2bg = model.fc2_bias->grad;

struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients.
ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true, false);

struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients + optimizer.
struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad); // Backward pass, gradients + optimizer.
ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);

model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
Expand All @@ -557,16 +561,58 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));

ggml_backend_graph_compute(model.backend, gf); // Always compute forward pass.

// With a period of nbatch_logical/nbatch_physical iterations:
if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
// For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
ggml_backend_graph_compute(model.backend, gb_grad);
{
std::vector<float> tmp0(ggml_nelements(fc1wg));
std::vector<float> tmp1(ggml_nelements(fc1wg));
ggml_backend_tensor_get(fc1wg, tmp0.data(), 0, ggml_nbytes(fc1wg));
ggml_backend_tensor_get(model.fc1_weight->grad, tmp1.data(), 0, ggml_nbytes(fc1wg));
for (size_t i = 0; i < tmp0.size(); i++) {
tmp0[i] += tmp1[i];
}
ggml_backend_tensor_set(fc1wg, tmp0.data(), 0, ggml_nbytes(fc1wg));
}
{
std::vector<float> tmp0(ggml_nelements(fc1bg));
std::vector<float> tmp1(ggml_nelements(fc1bg));
ggml_backend_tensor_get(fc1bg, tmp0.data(), 0, ggml_nbytes(fc1bg));
ggml_backend_tensor_get(model.fc1_bias->grad, tmp1.data(), 0, ggml_nbytes(fc1bg));
for (size_t i = 0; i < tmp0.size(); i++) {
tmp0[i] += tmp1[i];
}
ggml_backend_tensor_set(fc1bg, tmp0.data(), 0, ggml_nbytes(fc1bg));
}
{
std::vector<float> tmp0(ggml_nelements(fc2wg));
std::vector<float> tmp1(ggml_nelements(fc2wg));
ggml_backend_tensor_get(fc2wg, tmp0.data(), 0, ggml_nbytes(fc2wg));
ggml_backend_tensor_get(model.fc2_weight->grad, tmp1.data(), 0, ggml_nbytes(fc2wg));
for (size_t i = 0; i < tmp0.size(); i++) {
tmp0[i] += tmp1[i];
}
ggml_backend_tensor_set(fc2wg, tmp0.data(), 0, ggml_nbytes(fc2wg));
}
{
std::vector<float> tmp0(ggml_nelements(fc2bg));
std::vector<float> tmp1(ggml_nelements(fc2bg));
ggml_backend_tensor_get(fc2bg, tmp0.data(), 0, ggml_nbytes(fc2bg));
ggml_backend_tensor_get(model.fc2_bias->grad, tmp1.data(), 0, ggml_nbytes(fc2bg));
for (size_t i = 0; i < tmp0.size(); i++) {
tmp0[i] += tmp1[i];
}
ggml_backend_tensor_set(fc2bg, tmp0.data(), 0, ggml_nbytes(fc2bg));
}
} else {
// For the last iteration, calculate gradients and also apply the optimizer:
ggml_backend_graph_compute(model.backend, gb_opt); // gb_opt contains all nodes of gb_grad so no extra call for gb_grad is needed.
ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
ggml_set_zero(fc1wg);
ggml_set_zero(fc1bg);
ggml_set_zero(fc2wg);
ggml_set_zero(fc2bg);
}

ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss));
Expand Down

0 comments on commit 06e8adc

Please sign in to comment.