Skip to content

Commit b62ac02

Browse files
committed
Improved error classes
1 parent 9c0e550 commit b62ac02

File tree

6 files changed

+11
-7
lines changed

6 files changed

+11
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.22.0 (unreleased)
22

33
- Updated LibTorch to 2.9.0
4+
- Improved error classes
45

56
## 0.21.0 (2025-08-07)
67

ext/torch/ext.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
1616
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
1717
void init_random(Rice::Module& m);
1818

19+
VALUE rb_eTorchError = Qnil;
20+
1921
extern "C"
2022
void Init_ext() {
2123
auto m = Rice::define_module("Torch");
2224

25+
rb_eTorchError = Rice::define_class_under(m, "Error", rb_eStandardError);
26+
2327
// need to define certain classes up front to keep Rice happy
2428
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
2529
.define_constructor(Rice::Constructor<torch::IValue>());

ext/torch/templates.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using torch::nn::init::NonlinearityType;
3131

3232
#define END_HANDLE_TH_ERRORS \
3333
} catch (const torch::Error& ex) { \
34-
rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
34+
rb_raise(rb_eTorchError, "%s", ex.what_without_backtrace()); \
3535
} catch (const Rice::Exception& ex) { \
3636
rb_raise(ex.class_of(), "%s", ex.what()); \
3737
} catch (const std::exception& ex) { \

ext/torch/utils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ static_assert(
1212
"Incompatible LibTorch version"
1313
);
1414

15+
extern VALUE rb_eTorchError;
16+
1517
inline void handle_global_error(const torch::Error& ex) {
16-
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
18+
throw Rice::Exception(rb_eTorchError, ex.what_without_backtrace());
1719
}
1820

1921
// keep THP prefix for now to make it easier to compare code

lib/torch.rb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@
210210
require_relative "torch/hub"
211211

212212
module Torch
213-
class Error < StandardError; end
214213
class NotImplementedYet < StandardError
215214
def message
216215
"This feature has not been implemented yet. Consider submitting a PR."

test/torch_test.rb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,15 @@ def test_tutorial
6666
assert_kind_of Numo::SFloat, b
6767
end
6868

69-
# TODO use Torch::Error
7069
def test_friendly_error_tensor_no_cpp_trace
71-
error = assert_raises(RuntimeError) do
70+
error = assert_raises(Torch::Error) do
7271
Torch.arange(0, 100).view([10, 10]).select(2, 0)
7372
end
7473
assert_equal "Dimension out of range (expected to be in range of [-2, 1], but got 2)", error.message
7574
end
7675

77-
# TODO use Torch::Error
7876
def test_friendly_error_torch_no_cpp_trace
79-
error = assert_raises(RuntimeError) do
77+
error = assert_raises(Torch::Error) do
8078
Torch.select(Torch.arange(0, 100).view([10, 10]), 2, 0)
8179
end
8280
assert_equal "Dimension out of range (expected to be in range of [-2, 1], but got 2)", error.message

0 commit comments

Comments
 (0)