Skip to content

Commit

Permalink
Tweak the code generation to properly mark self as mutable, closes #33.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 5, 2019
1 parent 2d4743b commit b564bd0
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 267 deletions.
9 changes: 7 additions & 2 deletions gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,18 @@ module Func = struct
Printf.sprintf "%s: %s" (rust_name arg.arg_name) rust_arg_type )
|> String.concat ~sep:", "
in
let self_arg =
if String.is_suffix t.name ~suffix:"_"
then "&mut self"
else "&self"
in
match List.partition_tf t.args ~f:self_tensor with
| [self], args_list ->
(Some self.arg_name, Printf.sprintf "&self, %s" (to_string args_list))
(Some self.arg_name, Printf.sprintf "%s, %s" self_arg (to_string args_list))
| _, _ -> (
match List.partition_tf t.args ~f:input_tensor with
| [self], args_list ->
(Some self.arg_name, Printf.sprintf "&self, %s" (to_string args_list))
(Some self.arg_name, Printf.sprintf "%s, %s" self_arg (to_string args_list))
| _, _ -> (None, to_string t.args) )

let rust_return_type t ~fallible =
Expand Down
2 changes: 1 addition & 1 deletion src/wrappers/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl Tensor {

/// Zeroes the gradient tensor attached to this tensor if defined.
pub fn zero_grad(&mut self) {
let grad = self.grad();
let mut grad = self.grad();
if grad.defined() {
let _ = grad.detach_().zero_();
}
Expand Down
Loading

0 comments on commit b564bd0

Please # to comment.