21 lines
548 B
C++
21 lines
548 B
C++
#include "cuda_utils.h"
|
|
|
|
// Returns the storage size, in bytes, of a single element of TensorRT data
// type `t`.
//
// Handled types: kINT32 and kFLOAT (4), kHALF (2), kBOOL and kINT8 (1).
//
// Throws std::runtime_error("Invalid DataType.") for any other value; the
// function never returns a sentinel such as 0.
inline unsigned int getElementSize(nvinfer1::DataType t)
{
    // There is deliberately no `default:` label, so compilers that diagnose
    // unhandled enumerators in a switch can flag a DataType missing from this
    // list. Unhandled values fall out of the switch and reach the throw.
    switch (t)
    {
    case nvinfer1::DataType::kINT32:
    case nvinfer1::DataType::kFLOAT:
        return 4;
    case nvinfer1::DataType::kHALF:
        return 2;
    case nvinfer1::DataType::kBOOL:
    case nvinfer1::DataType::kINT8:
        return 1;
    }
    throw std::runtime_error("Invalid DataType.");
}
|
|
|
|
// Returns the product of the first `d.nbDims` extents stored in `d.d`, i.e.
// the total number of elements described by the dimensions.
//
// An empty dimension list (nbDims == 0) yields 1.
// NOTE(review): a negative nbDims would form an invalid iterator range, and
// negative extents are multiplied in as-is; callers are presumably expected
// to pass fully specified dimensions -- confirm.
inline int64_t volume(const nvinfer1::Dims& d)
{
    // The seed must be int64_t: std::accumulate uses the type of its initial
    // value as the accumulator type, so seeding with the int literal `1` would
    // narrow every intermediate product back to int and truncate large
    // volumes.
    return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies<int64_t>());
}
|
|
|