21 lines
		
	
	
		
			548 B
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			21 lines
		
	
	
		
			548 B
		
	
	
	
		
			C++
		
	
	
	
| #include "cuda_utils.h"
 | |
| 
 | |
| inline unsigned int getElementSize(nvinfer1::DataType t)
 | |
| {
 | |
|     switch (t){
 | |
|         case nvinfer1::DataType::kINT32: return 4;
 | |
|         case nvinfer1::DataType::kFLOAT: return 4;
 | |
|         case nvinfer1::DataType::kHALF: return 2;
 | |
|         case nvinfer1::DataType::kBOOL:
 | |
|         case nvinfer1::DataType::kINT8: return 1;
 | |
|     }
 | |
|     throw std::runtime_error("Invalid DataType.");
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| inline int64_t volume(const nvinfer1::Dims& d)
 | |
| {
 | |
|     return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
 | |
| }
 | |
| 
 |