|
11 | 11 | #include <CL/__spirv/spirv_ops.hpp> |
12 | 12 | #include <CL/sycl/detail/defines_elementary.hpp> |
13 | 13 | #include <CL/sycl/feature_test.hpp> |
14 | | -#include <sycl/ext/oneapi/experimental/bfloat16.hpp> |
15 | 14 |
|
16 | 15 | __SYCL_INLINE_NAMESPACE(cl) { |
17 | 16 | namespace sycl { |
@@ -454,156 +453,6 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> { |
454 | 453 | #undef OP |
455 | 454 | }; |
456 | 455 |
|
457 | | -template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group> |
458 | | -class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, |
459 | | - Layout, Group> { |
460 | | - joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, |
461 | | - Layout, Group> &M; |
462 | | - std::size_t idx; |
463 | | - |
464 | | -public: |
465 | | - wi_element(joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, |
466 | | - NumCols, Layout, Group> &Mat, |
467 | | - std::size_t i) |
468 | | - : M(Mat), idx(i) {} |
469 | | - operator sycl::ext::oneapi::experimental::bfloat16() { |
470 | | -#ifdef __SYCL_DEVICE_ONLY__ |
471 | | - return __spirv_VectorExtractDynamic(M.spvm, idx); |
472 | | -#else |
473 | | - throw runtime_error("joint matrix is not supported on host device.", |
474 | | - PI_INVALID_DEVICE); |
475 | | -#endif // __SYCL_DEVICE_ONLY__ |
476 | | - } |
477 | | - |
478 | | - explicit operator bool() { |
479 | | -#ifdef __SYCL_DEVICE_ONLY__ |
480 | | - return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic( |
481 | | - M.spvm, idx))) >= std::numeric_limits<float>::epsilon(); |
482 | | -#else |
483 | | - throw runtime_error("joint matrix is not supported on host device.", |
484 | | - PI_INVALID_DEVICE); |
485 | | -#endif // __SYCL_DEVICE_ONLY__ |
486 | | - } |
487 | | - |
488 | | - wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) { |
489 | | -#ifdef __SYCL_DEVICE_ONLY__ |
490 | | - M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); |
491 | | - return *this; |
492 | | -#else |
493 | | - (void)rhs; |
494 | | - throw runtime_error("joint matrix is not supported on host device.", |
495 | | - PI_INVALID_DEVICE); |
496 | | -#endif // __SYCL_DEVICE_ONLY__ |
497 | | - } |
498 | | - |
499 | | - wi_element & |
500 | | - operator=(const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, |
501 | | - NumCols, Layout, Group> &rhs) { |
502 | | -#ifdef __SYCL_DEVICE_ONLY__ |
503 | | - M.spvm = __spirv_VectorInsertDynamic( |
504 | | - M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); |
505 | | - return *this; |
506 | | -#else |
507 | | - (void)rhs; |
508 | | - throw runtime_error("joint matrix is not supported on host device.", |
509 | | - PI_INVALID_DEVICE); |
510 | | -#endif // __SYCL_DEVICE_ONLY__ |
511 | | - } |
512 | | - |
513 | | -#if __SYCL_DEVICE_ONLY__ |
514 | | -#define OP(opassign, op) \ |
515 | | - wi_element &operator opassign( \ |
516 | | - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
517 | | - M.spvm = __spirv_VectorInsertDynamic( \ |
518 | | - M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \ |
519 | | - return *this; \ |
520 | | - } |
521 | | -#else // __SYCL_DEVICE_ONLY__ |
522 | | -#define OP(opassign, op) \ |
523 | | - wi_element &operator opassign( \ |
524 | | - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
525 | | - (void)rhs; \ |
526 | | - throw runtime_error("joint matrix is not supported on host device.", \ |
527 | | - PI_INVALID_DEVICE); \ |
528 | | - } |
529 | | -#endif // __SYCL_DEVICE_ONLY__ |
530 | | - OP(+=, +) |
531 | | - OP(-=, -) |
532 | | - OP(*=, *) |
533 | | - OP(/=, /) |
534 | | -#undef OP |
535 | | - |
536 | | -#if __SYCL_DEVICE_ONLY__ |
537 | | -#define OP(type, op) \ |
538 | | - friend type operator op( \ |
539 | | - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
540 | | - NumCols, Layout, Group> &lhs, \ |
541 | | - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
542 | | - return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \ |
543 | | - } \ |
544 | | - friend type operator op( \ |
545 | | - const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ |
546 | | - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
547 | | - NumCols, Layout, Group> &rhs) { \ |
548 | | - return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \ |
549 | | - } |
550 | | - OP(sycl::ext::oneapi::experimental::bfloat16, +) |
551 | | - OP(sycl::ext::oneapi::experimental::bfloat16, -) |
552 | | - OP(sycl::ext::oneapi::experimental::bfloat16, *) |
553 | | - OP(sycl::ext::oneapi::experimental::bfloat16, /) |
554 | | -#undef OP |
555 | | -#define OP(type, op) \ |
556 | | - friend type operator op( \ |
557 | | - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
558 | | - NumCols, Layout, Group> &lhs, \ |
559 | | - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
560 | | - return type{static_cast<float>(__spirv_VectorExtractDynamic( \ |
561 | | - lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \ |
562 | | - } \ |
563 | | - friend type operator op( \ |
564 | | - const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ |
565 | | - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
566 | | - NumCols, Layout, Group> &rhs) { \ |
567 | | - return type{static_cast<float>(__spirv_VectorExtractDynamic( \ |
568 | | - rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \ |
569 | | - } |
570 | | - OP(bool, ==) |
571 | | - OP(bool, !=) |
572 | | - OP(bool, <) |
573 | | - OP(bool, >) |
574 | | - OP(bool, <=) |
575 | | - OP(bool, >=) |
576 | | -#undef OP |
577 | | -#else // __SYCL_DEVICE_ONLY__ |
578 | | -#define OP(type, op) \ |
579 | | - friend type operator op( \ |
580 | | - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
581 | | - NumCols, Layout, Group> &, \ |
582 | | - const sycl::ext::oneapi::experimental::bfloat16 &) { \ |
583 | | - throw runtime_error("joint matrix is not supported on host device.", \ |
584 | | - PI_INVALID_DEVICE); \ |
585 | | - } \ |
586 | | - friend type operator op( \ |
587 | | - const sycl::ext::oneapi::experimental::bfloat16 &, \ |
588 | | - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
589 | | - NumCols, Layout, Group> &) { \ |
590 | | - throw runtime_error("joint matrix is not supported on host device.", \ |
591 | | - PI_INVALID_DEVICE); \ |
592 | | - } |
593 | | - OP(sycl::ext::oneapi::experimental::bfloat16, +) |
594 | | - OP(sycl::ext::oneapi::experimental::bfloat16, -) |
595 | | - OP(sycl::ext::oneapi::experimental::bfloat16, *) |
596 | | - OP(sycl::ext::oneapi::experimental::bfloat16, /) |
597 | | - OP(bool, ==) |
598 | | - OP(bool, !=) |
599 | | - OP(bool, <) |
600 | | - OP(bool, >) |
601 | | - OP(bool, <=) |
602 | | - OP(bool, >=) |
603 | | -#undef OP |
604 | | -#endif // __SYCL_DEVICE_ONLY__ |
605 | | -}; |
606 | | - |
607 | 456 | template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout, |
608 | 457 | typename Group> |
609 | 458 | class wi_data { |
|
0 commit comments