1818#include " sampleUtils.h"
1919#include " bfloat16.h"
2020#include " half.h"
21+ #include < cuda_fp8.h>
2122
2223using namespace nvinfer1 ;
2324
@@ -433,6 +434,11 @@ void print(std::ostream& os, __half v)
433434 os << static_cast <float >(v);
434435}
435436
437+ void print (std::ostream& os, __nv_fp8_e4m3 v)
438+ {
439+ os << static_cast <float >(v);
440+ }
441+
436442template <typename T>
437443void dumpBuffer (void const * buffer, std::string const & separator, std::ostream& os, Dims const & dims,
438444 Dims const & strides, int32_t vectorDim, int32_t spv)
@@ -482,6 +488,8 @@ template void dumpBuffer<uint8_t>(void const* buffer, std::string const& separat
482488 Dims const & strides, int32_t vectorDim, int32_t spv);
483489template void dumpBuffer<int64_t >(void const * buffer, std::string const & separator, std::ostream& os, Dims const & dims,
484490 Dims const & strides, int32_t vectorDim, int32_t spv);
491+ template void dumpBuffer<__nv_fp8_e4m3>(void const * buffer, std::string const & separator, std::ostream& os, Dims const & dims,
492+ Dims const & strides, int32_t vectorDim, int32_t spv);
485493
486494template <typename T>
487495void sparsify (T const * values, int64_t count, int32_t k, int32_t trs, std::vector<int8_t >& sparseWeights)
@@ -566,7 +574,7 @@ void fillBuffer(void* buffer, int64_t volume, T min, T max)
566574{
567575 T* typedBuffer = static_cast <T*>(buffer);
568576 std::default_random_engine engine;
569- std::uniform_real_distribution<float > distribution (min, max);
577+ std::uniform_real_distribution<float > distribution (( float ) min, ( float ) max);
570578 auto generator = [&engine, &distribution]() { return static_cast <T>(distribution (engine)); };
571579 std::generate (typedBuffer, typedBuffer + volume, generator);
572580}
@@ -580,6 +588,7 @@ template void fillBuffer<int8_t>(void* buffer, int64_t volume, int8_t min, int8_
580588template void fillBuffer<__half>(void * buffer, int64_t volume, __half min, __half max);
581589template void fillBuffer<BFloat16>(void * buffer, int64_t volume, BFloat16 min, BFloat16 max);
582590template void fillBuffer<uint8_t >(void * buffer, int64_t volume, uint8_t min, uint8_t max);
591+ template void fillBuffer<__nv_fp8_e4m3>(void * buffer, int64_t volume, __nv_fp8_e4m3 min, __nv_fp8_e4m3 max);
583592
584593bool matchStringWithOneWildcard (std::string const & pattern, std::string const & target)
585594{
0 commit comments