Skip to content

Commit bf7b663

Browse files
committed
Improved C++ code
1 parent eae9037 commit bf7b663

File tree

7 files changed

+26
-11
lines changed

7 files changed

+26
-11
lines changed

ext/torch/device.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
#include <string>
2+
13
#include <torch/torch.h>
24

35
#include <rice/rice.hpp>
6+
#include <rice/stl.hpp>
47

58
#include "utils.h"
69

ext/torch/ext.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
1717
void init_random(Rice::Module& m);
1818

1919
extern "C"
20-
void Init_ext()
21-
{
20+
void Init_ext() {
2221
auto m = Rice::define_module("Torch");
2322

2423
// need to define certain classes up front to keep Rice happy

ext/torch/ivalue.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <utility>
2+
13
#include <torch/torch.h>
24

35
#include <rice/rice.hpp>

ext/torch/nn.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <utility>
2+
13
#include <torch/torch.h>
24

35
#include <rice/rice.hpp>

ext/torch/ruby_arg_parser.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
// adapted from PyTorch - python_arg_parser.cpp
22

3+
#include <string>
4+
#include <unordered_map>
5+
#include <unordered_set>
6+
#include <vector>
7+
38
#include "ruby_arg_parser.h"
49

510
VALUE THPGeneratorClass = Qnil;
@@ -99,7 +104,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
99104
ruby_name = THPUtils_internSymbol(name);
100105
auto np_compat_it = numpy_compatibility_arg_names.find(name);
101106
if (np_compat_it != numpy_compatibility_arg_names.end()) {
102-
for (const auto& str: np_compat_it->second) {
107+
for (const auto& str : np_compat_it->second) {
103108
numpy_python_names.push_back(THPUtils_internSymbol(str));
104109
}
105110
}
@@ -190,8 +195,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
190195
}
191196

192197
// argnum is needed for raising the TypeError, it's used in the error message.
193-
auto FunctionParameter::check(VALUE obj, int argnum) -> bool
194-
{
198+
auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
195199
switch (type_) {
196200
case ParameterType::TENSOR: {
197201
if (THPVariable_Check(obj)) {

ext/torch/tensor.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#include <string>
2+
#include <vector>
3+
14
#include <torch/torch.h>
25

36
#include <rice/rice.hpp>
@@ -7,7 +10,8 @@
710
#include "templates.h"
811
#include "utils.h"
912

10-
using namespace Rice;
13+
using Rice::Array;
14+
using Rice::Object;
1115
using torch::indexing::TensorIndex;
1216

1317
template<typename T>
@@ -75,8 +79,7 @@ std::vector<TensorIndex> index_vector(Array a) {
7579
// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
7680
// TODO add support for inputs argument
7781
// _backward
78-
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
79-
{
82+
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_) {
8083
HANDLE_TH_ERRORS
8184
Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
8285
static RubyArgParser parser({
@@ -197,7 +200,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
197200
.define_method(
198201
"_dtype",
199202
[](Tensor& self) {
200-
return (int) at::typeMetaToScalarType(self.dtype());
203+
return static_cast<int>(at::typeMetaToScalarType(self.dtype()));
201204
})
202205
.define_method(
203206
"_type",

ext/torch/torch.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
#include <fstream>
2+
#include <string>
3+
#include <vector>
4+
15
#include <torch/torch.h>
26

37
#include <rice/rice.hpp>
48

5-
#include <fstream>
6-
79
#include "torch_functions.h"
810
#include "templates.h"
911
#include "utils.h"

0 commit comments

Comments
 (0)