Skip to content

Commit f3645da

Browse files
committed
Merge remote-tracking branch 'upstream/develop' into numpy-generic-kernels
2 parents 261ee70 + 45249c2 commit f3645da

10 files changed

Lines changed: 93 additions & 82 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Also see the [`/examples`](examples) directory and [usage documentation](https:/
7575

7676
### 📖 Documentation & usage guides
7777

78-
| | |
78+
| Documentation | Description |
7979
| --------------------------------------------------------------------------------- | ----------------------------------------------------- |
8080
| [Introduction](https://thinc.ai/docs) | Everything you need to know. |
8181
| [Concept & Design](https://thinc.ai/docs/concept) | Thinc's conceptual model and how it works. |

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
matrix:
2222
Python36Windows:
23-
imageName: 'windows-latest'
23+
imageName: 'windows-2019'
2424
python.version: '3.6'
2525
Python37Mac:
2626
imageName: 'macos-10.15'

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ contextvars>=2.4,<3; python_version < "3.7"
1717
# Development dependencies
1818
cython>=0.25.0,<3.0
1919
hypothesis>=3.27.0,<7.0.0
20-
pytest>=5.2.0,<7.1.0
20+
pytest>=5.2.0,!=7.1.0
2121
pytest-cov>=2.7.0,<2.8.0
2222
coverage>=5.0.0,<6.0.0
2323
mock>=2.0.0,<3.0.0

thinc/backends/_custom_kernels.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def compile_mmh(src):
126126

127127
def clipped_linear(
128128
X,
129+
*,
129130
inplace=False,
130131
slope=1.0,
131132
offset=0.0,
@@ -154,7 +155,7 @@ def clipped_linear(
154155
return out
155156

156157

157-
def gelu(X, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
158+
def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
158159
_is_float_array(X)
159160

160161
out = X
@@ -179,32 +180,32 @@ def check_seq2col_lengths(lengths, B):
179180
return lengths
180181

181182

182-
def seq2col(X, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
183-
_is_float_array(X)
183+
def seq2col(seq, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
184+
_is_float_array(seq)
184185

185-
B = X.shape[0]
186+
B = seq.shape[0]
186187
nF = nW * 2 + 1
187-
I = X.shape[1]
188+
I = seq.shape[1]
188189

189190
lengths = check_seq2col_lengths(lengths, B)
190191
nL = lengths.shape[0]
191192

192-
out = cupy.zeros((B, I * nF), dtype=X.dtype)
193+
out = cupy.zeros((B, I * nF), dtype=seq.dtype)
193194

194-
if X.size != 0 and lengths.size != 0:
195-
if X.dtype == "float32":
195+
if seq.size != 0 and lengths.size != 0:
196+
if seq.dtype == "float32":
196197
seq2col_kernel_float(
197-
(num_blocks,), (threads_per_block,), (out, X, lengths, nW, B, I, nL)
198+
(num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
198199
)
199200
else:
200201
seq2col_kernel_double(
201-
(num_blocks,), (threads_per_block,), (out, X, lengths, nW, B, I, nL)
202+
(num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
202203
)
203204

204205
return out
205206

206207

207-
def maxout(X, threads_per_block=128, num_blocks=128):
208+
def maxout(X, *, threads_per_block=128, num_blocks=128):
208209
_is_float_array(X)
209210

210211
B, I, P = X.shape
@@ -225,7 +226,7 @@ def maxout(X, threads_per_block=128, num_blocks=128):
225226
return best, which
226227

227228

228-
def mish(X, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
229+
def mish(X, *, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
229230
_is_float_array(X)
230231

231232
out = X
@@ -244,7 +245,7 @@ def mish(X, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
244245
return out
245246

246247

247-
def reduce_sum(X, lengths, threads_per_block=128, num_blocks=128):
248+
def reduce_sum(X, lengths, *, threads_per_block=128, num_blocks=128):
248249
_is_float_array(X)
249250

250251
B = len(lengths)
@@ -267,7 +268,7 @@ def reduce_sum(X, lengths, threads_per_block=128, num_blocks=128):
267268
return out
268269

269270

270-
def reduce_mean(X, lengths, threads_per_block=128, num_blocks=128):
271+
def reduce_mean(X, lengths, *, threads_per_block=128, num_blocks=128):
271272
_is_float_array(X)
272273

273274
B = len(lengths)
@@ -292,7 +293,7 @@ def reduce_mean(X, lengths, threads_per_block=128, num_blocks=128):
292293
return out
293294

294295

295-
def reduce_max(X, lengths, threads_per_block=128, num_blocks=128):
296+
def reduce_max(X, lengths, *, threads_per_block=128, num_blocks=128):
296297
_is_float_array(X)
297298

298299
B = len(lengths)
@@ -317,7 +318,7 @@ def reduce_max(X, lengths, threads_per_block=128, num_blocks=128):
317318
return maxes, which
318319

319320

320-
def swish(X, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128):
321+
def swish(X, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128):
321322
_is_float_array(X)
322323

323324
out = X
@@ -362,6 +363,7 @@ def backprop_seq2col(dY, nW, *, lengths=None, threads_per_block=128, num_blocks=
362363
def backprop_clipped_linear(
363364
dY,
364365
X,
366+
*,
365367
slope: float = 1.0,
366368
offset: float = 0.0,
367369
min_val: float = 0.0,
@@ -394,7 +396,7 @@ def backprop_clipped_linear(
394396

395397

396398
def backprop_hard_swish(
397-
dY, X, inplace: bool = False, threads_per_block=128, num_blocks=128
399+
dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
398400
):
399401
_is_float_array(dY)
400402
_is_float_array(X, shape=dY.shape)
@@ -416,7 +418,7 @@ def backprop_hard_swish(
416418

417419

418420
def backprop_hard_swish_mobilenet(
419-
dY, X, inplace: bool = False, threads_per_block=128, num_blocks=128
421+
dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
420422
):
421423
_is_float_array(dY)
422424
_is_float_array(X, shape=dY.shape)
@@ -438,7 +440,13 @@ def backprop_hard_swish_mobilenet(
438440

439441

440442
def backprop_gelu(
441-
dY, X, inplace: bool = False, threshold=6.0, threads_per_block=128, num_blocks=128
443+
dY,
444+
X,
445+
*,
446+
inplace: bool = False,
447+
threshold=6.0,
448+
threads_per_block=128,
449+
num_blocks=128,
442450
):
443451
_is_float_array(dY)
444452
_is_float_array(X, shape=dY.shape)
@@ -459,7 +467,7 @@ def backprop_gelu(
459467
return out
460468

461469

462-
def backprop_maxout(dY, which, P, threads_per_block=128, num_blocks=128):
470+
def backprop_maxout(dY, which, P, *, threads_per_block=128, num_blocks=128):
463471
_is_float_array(dY)
464472

465473
B = dY.shape[0]
@@ -482,7 +490,7 @@ def backprop_maxout(dY, which, P, threads_per_block=128, num_blocks=128):
482490

483491

484492
def backprop_mish(
485-
dY, X, inplace: bool = False, threshold=5, threads_per_block=128, num_blocks=128
493+
dY, X, *, inplace: bool = False, threshold=5, threads_per_block=128, num_blocks=128
486494
):
487495
_is_float_array(dY)
488496
_is_float_array(X, shape=dY.shape)
@@ -503,51 +511,53 @@ def backprop_mish(
503511
return out
504512

505513

506-
def backprop_reduce_sum(d_sum, lengths, threads_per_block=128, num_blocks=128):
507-
_is_float_array(d_sum)
514+
def backprop_reduce_sum(d_sums, lengths, *, threads_per_block=128, num_blocks=128):
515+
_is_float_array(d_sums)
508516

509517
B = len(lengths)
510518
T = int(lengths.sum())
511-
O = d_sum.shape[1]
519+
O = d_sums.shape[1]
512520
_check_lengths(lengths, T)
513521

514-
out = cupy.zeros((T, O), dtype=d_sum.dtype)
522+
out = cupy.zeros((T, O), dtype=d_sums.dtype)
515523

516-
if d_sum.dtype == "float32":
524+
if d_sums.dtype == "float32":
517525
backprop_reduce_sum_kernel_float(
518-
(num_blocks,), (threads_per_block,), (out, d_sum, lengths, B, T, O)
526+
(num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
519527
)
520528
else:
521529
backprop_reduce_sum_kernel_double(
522-
(num_blocks,), (threads_per_block,), (out, d_sum, lengths, B, T, O)
530+
(num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
523531
)
524532

525533
return out
526534

527535

528-
def backprop_reduce_mean(d_mean, lengths, threads_per_block=128, num_blocks=128):
529-
_is_float_array(d_mean)
536+
def backprop_reduce_mean(d_means, lengths, *, threads_per_block=128, num_blocks=128):
537+
_is_float_array(d_means)
530538

531539
B = len(lengths)
532540
T = int(lengths.sum())
533-
O = d_mean.shape[1]
541+
O = d_means.shape[1]
534542
_check_lengths(lengths, T)
535543

536-
out = cupy.zeros((T, O), dtype=d_mean.dtype)
544+
out = cupy.zeros((T, O), dtype=d_means.dtype)
537545

538-
if d_mean.dtype == "float32":
546+
if d_means.dtype == "float32":
539547
backprop_reduce_mean_kernel_float(
540-
(num_blocks,), (threads_per_block,), (out, d_mean, lengths, B, T, O)
548+
(num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
541549
)
542550
else:
543551
backprop_reduce_mean_kernel_double(
544-
(num_blocks,), (threads_per_block,), (out, d_mean, lengths, B, T, O)
552+
(num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
545553
)
546554

547555
return out
548556

549557

550-
def backprop_reduce_max(d_maxes, which, lengths, threads_per_block=128, num_blocks=128):
558+
def backprop_reduce_max(
559+
d_maxes, which, lengths, *, threads_per_block=128, num_blocks=128
560+
):
551561
_is_float_array(d_maxes)
552562

553563
B = len(lengths)
@@ -572,7 +582,7 @@ def backprop_reduce_max(d_maxes, which, lengths, threads_per_block=128, num_bloc
572582

573583

574584
def backprop_swish(
575-
dY, X, Y, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128
585+
dY, X, Y, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128
576586
):
577587
_is_float_array(dY)
578588
_is_float_array(X, shape=dY.shape)
@@ -594,7 +604,7 @@ def backprop_swish(
594604
return out
595605

596606

597-
def hash(ids, seed, threads_per_block=128, num_blocks=128):
607+
def hash(ids, seed, *, threads_per_block=128, num_blocks=128):
598608
out = cupy.zeros((ids.shape[0], 4), dtype="uint32")
599609

600610
# sizeof(uint32_t) * 4

thinc/backends/cupy_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def backprop_clipped_linear(
150150
):
151151
if X.dtype == dY.dtype and X.dtype in ("float32", "float64"):
152152
return _custom_kernels.backprop_clipped_linear(
153-
dY=dY,
154-
X=X,
153+
dY,
154+
X,
155155
slope=slope,
156156
offset=offset,
157157
min_val=min_val,
@@ -243,7 +243,7 @@ def backprop_seq2col(self, dY, nW, *, lengths=None):
243243

244244
def reduce_mean(self, X, lengths):
245245
if X.dtype in ("float32", "float64") and lengths.dtype == "int32":
246-
return _custom_kernels.reduce_mean(X, lengths)
246+
return _custom_kernels.reduce_mean(X, lengths=lengths)
247247
else:
248248
super().reduce_mean(X, lengths)
249249

thinc/backends/ops.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -256,18 +256,17 @@ def unflatten(self, X: Floats2d, lengths: Ints1d, pad: int = 0) -> List[Floats2d
256256
"""The reverse/backward operation of the `flatten` function: unflatten
257257
a large array into a list of arrays according to the given lengths.
258258
"""
259-
unflat = []
260-
pad = int(pad)
261-
for length in lengths:
262-
length = int(length)
263-
if pad >= 1 and length != 0:
264-
X = X[pad:]
265-
unflat.append(X[:length])
266-
X = X[length:]
267-
if pad >= 1:
268-
X = X[pad:]
269-
assert len(X) == 0
259+
# cupy.split requires lengths to be in CPU memory.
260+
lengths = to_numpy(lengths)
261+
262+
if pad > 0:
263+
lengths = numpy.where(lengths > 0, lengths + pad, 0) # type: ignore
264+
unflat = self.xp.split(X, numpy.cumsum(lengths))[:-1] # type: ignore
265+
if pad > 0:
266+
unflat = [a[pad:] for a in unflat]
267+
270268
assert len(unflat) == len(lengths)
269+
271270
return unflat
272271

273272
@overload

thinc/shims/pytorch.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ def __init__(
4343
mixed_precision: bool = False,
4444
grad_scaler: Optional[PyTorchGradScaler] = None,
4545
):
46-
if mixed_precision and not has_torch_amp:
47-
raise ValueError(
48-
"Mixed-precision training is not supported, requires capable GPU and torch>=1.9.0"
49-
)
50-
5146
super().__init__(model, config, optimizer)
5247

5348
if grad_scaler is None:

thinc/shims/pytorch_grad_scaler.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@ def __init__(
5050
When no overflows were found for this number of steps, the scale will
5151
be multiplied by "growth_factor".
5252
"""
53-
if enabled and not has_torch_amp:
54-
raise ValueError(
55-
"Gradient scaling is not supported, requires capable GPU and torch>=1.9.0"
56-
)
57-
5853
self._enabled = enabled
5954
self._growth_factor = growth_factor
6055
self._backoff_factor = backoff_factor
@@ -107,7 +102,18 @@ def _scale_tensor(
107102
scale_per_device: Dict["torch.device", "torch.Tensor"],
108103
inplace: bool,
109104
):
110-
assert tensor.is_cuda, "Gradient scaling is only supported for CUDA tensors"
105+
if not has_torch_amp:
106+
raise ValueError(
107+
"Gradient scaling is not supported, requires capable GPU and torch>=1.9.0"
108+
)
109+
110+
if not tensor.is_cuda:
111+
msg = (
112+
"Gradient scaling is only supported for CUDA tensors. "
113+
"If you are using PyTorch models, you can avoid this "
114+
"error by disabling mixed-precision support."
115+
)
116+
raise ValueError(msg)
111117

112118
device = tensor.device
113119

thinc/tests/layers/test_pytorch_wrapper.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -148,18 +148,3 @@ def test_pytorch_convert_inputs(data, n_args, kwargs_keys):
148148
convert_inputs = model.attrs["convert_inputs"]
149149
Y, backprop = convert_inputs(model, data, is_train=True)
150150
check_input_converters(Y, backprop, data, n_args, kwargs_keys, torch.Tensor)
151-
152-
153-
@pytest.mark.skipif(not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU")
154-
@pytest.mark.skipif(
155-
has_torch_amp, reason="needs PyTorch without mixed-precision support"
156-
)
157-
def test_raises_on_old_pytorch():
158-
import torch.nn
159-
160-
pytorch_layer = torch.nn.Linear(5, 5)
161-
with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
162-
PyTorchWrapper_v2(
163-
pytorch_layer.cuda(),
164-
mixed_precision=True,
165-
)

0 commit comments

Comments
 (0)