|
16 | 16 | use Rindow\Math\Matrix\Complex; |
17 | 17 | use Rindow\Math\Matrix\ComplexUtils; |
18 | 18 | use Rindow\Math\Matrix\Drivers\Service; |
19 | | - |
20 | | -//use Rindow\Math\Matrix\MatrixOperator; |
21 | 19 | use Rindow\Math\Matrix\Range; |
22 | 20 | use RuntimeException; |
23 | 21 | use Serializable; |
@@ -390,6 +388,16 @@ public static function fromString(string $string, int $dtype, array $shape): sta |
390 | 388 | return new static($buffer, $dtype, $shape, 0); |
391 | 389 | } |
392 | 390 |
|
| 391 | + public static function random(array $shape, ?int $dtype = null): static |
| 392 | + { |
| 393 | + $dtype ??= NDArray::float32; |
| 394 | + $size = array_product($shape); |
| 395 | + |
| 396 | + $buffer = Tensor::newBuffer($size, $dtype); |
| 397 | + $buffer->load(random_bytes($size * TensorBuffer::$valueSize[$dtype])); |
| 398 | + return new static($buffer, shape: $shape, offset: 0); |
| 399 | + } |
| 400 | + |
393 | 401 | /** |
394 | 402 | * Convert the tensor into an array. |
395 | 403 | */ |
@@ -680,6 +688,15 @@ public function multiply(Tensor|float|int $value): self |
680 | 688 | return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset()); |
681 | 689 | } |
682 | 690 |
|
| 691 | + public function matmul(Tensor $other, ?bool $transposeA = null, ?bool $transposeB = null): Tensor |
| 692 | + { |
| 693 | + $mo = self::mo(); |
| 694 | + |
| 695 | + $result = $mo->la()->matmul($this, $other, $transposeA, $transposeB); |
| 696 | + |
| 697 | + return new static($result->buffer(), $result->dtype(), $result->shape(), $result->offset()); |
| 698 | + } |
| 699 | + |
683 | 700 | public function log(): self |
684 | 701 | { |
685 | 702 | $mo = self::mo(); |
@@ -942,7 +959,7 @@ public function to(int $dtype): static |
942 | 959 | /** |
943 | 960 | * Returns the mean value of each row of the tensor in the given axis. |
944 | 961 | */ |
945 | | - public function mean(?int $axis = null, bool $keepShape = false): static|float|int |
| 962 | + public function mean(?int $axis = null, bool $keepShape = false): static|float|int|Tensor |
946 | 963 | { |
947 | 964 | $mo = self::mo(); |
948 | 965 |
|
@@ -1017,7 +1034,7 @@ public function stdMean(?int $axis = null, int $correction = 1, bool $keepShape |
1017 | 1034 | $num = floor($num / $size); |
1018 | 1035 | } |
1019 | 1036 |
|
1020 | | - $result->buffer[$resultIndex] += pow($this->buffer[$i] - $mean->buffer()[$resultIndex], 2); |
| 1037 | + $result->buffer[$resultIndex] += pow($this->buffer[$i] - $mean->buffer[$resultIndex], 2); |
1021 | 1038 | } |
1022 | 1039 |
|
1023 | 1040 | for ($i = 0; $i < count($result->buffer); ++$i) { |
@@ -1226,9 +1243,9 @@ protected function softmax2D(): static |
1226 | 1243 | * |
1227 | 1244 | * @return array The top k values and indices of the tensor. |
1228 | 1245 | */ |
1229 | | - public function topk(int $k = null, bool $sorted = true): array |
| 1246 | + public function topk(int $k = -1, bool $sorted = true): array |
1230 | 1247 | { |
1231 | | - if ($k === null) { |
| 1248 | + if ($k === -1) { |
1232 | 1249 | $k = $this->shape[0]; |
1233 | 1250 | } |
1234 | 1251 |
|
|
0 commit comments