struct OrtMemoryInfo {
OrtMemoryInfo() = default; // to allow default construction of Tensor
size_t HeapOffset = 0;
bool UsePool = false;
// use string for name, so we could have customized allocator in execution provider.
const char* name = nullptr;
int id = -1;
OrtMemType mem_type = OrtMemTypeDefault;
OrtAllocatorType alloc_type = OrtInvalidAllocator;
OrtDevice device;
};
// Struct to represent a physical device.
struct OrtDevice {
using DeviceType = int8_t;
using MemoryType = int8_t;
using DeviceId = int16_t;
private:
// Device type.
DeviceType device_type;
// Memory type.
MemoryType memory_type;
// Device index.
DeviceId device_id;
};
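For reference, the same information is reachable from user code through the public API as Ort::MemoryInfo. A minimal sketch (assuming onnxruntime_cxx_api.h from a standard ONNX Runtime install) that builds the CPU variant of the struct above:
#include <onnxruntime_cxx_api.h>

int main() {
  // Describes the default CPU device: roughly name = "Cpu",
  // mem_type = OrtMemTypeDefault, alloc_type = OrtArenaAllocator.
  Ort::MemoryInfo cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  return 0;
}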
class kCpuExecutionProvider_Conv_kOnnxDomain_ver1_10;
template <>
KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Conv_kOnnxDomain_ver1_10>() {
return KernelCreateInfo(
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.SetName("Conv")
.SetDomain(kOnnxDomain)
.SinceVersion(1, 10)
.Provider(kCpuExecutionProvider)
.Build(),
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<Conv<float>>(info);
return Status::OK();
})
);
}
class kCpuExecutionProvider_Conv_kOnnxDomain_ver11;
template <>
KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Conv_kOnnxDomain_ver11>() {
return KernelCreateInfo(
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.SetName("Conv")
.SetDomain(kOnnxDomain)
.SinceVersion(11)
.Provider(kCpuExecutionProvider)
.Build(),
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<Conv<float>>(info);
return Status::OK();
})
);
}
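In the ONNX Runtime source these specializations are generated by registration macros, and their addresses are collected into a function-pointer table that the provider's kernel registry iterates at startup. A simplified, self-contained sketch of that pattern (stand-in types and hypothetical names, not the real onnxruntime classes):
#include <iostream>
#include <string>

struct KernelCreateInfo {           // stand-in for onnxruntime::KernelCreateInfo
  std::string op;
  int since_version_start;
  int since_version_end;            // -1 == open ended
};

template <typename T>
KernelCreateInfo BuildKernelCreateInfo();  // one explicit specialization per kernel tag class

class Conv_ver1_10;                 // tag classes, mirroring the forward declarations above
class Conv_ver11;

template <>
KernelCreateInfo BuildKernelCreateInfo<Conv_ver1_10>() { return {"Conv", 1, 10}; }
template <>
KernelCreateInfo BuildKernelCreateInfo<Conv_ver11>() { return {"Conv", 11, -1}; }

int main() {
  using BuildFn = KernelCreateInfo (*)();
  // A registry walks a table like this and registers each KernelCreateInfo it returns.
  static const BuildFn function_table[] = {
      BuildKernelCreateInfo<Conv_ver1_10>,
      BuildKernelCreateInfo<Conv_ver11>,
  };
  for (BuildFn fn : function_table) {
    KernelCreateInfo info = fn();
    std::cout << info.op << " [" << info.since_version_start << ", "
              << info.since_version_end << "]\n";
  }
  return 0;
}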
A Tensor is just a placeholder over a block of memory plus shape information; by itself it does not own or manage that memory (ownership applies only when buffer_deleter_ is set).
class Tensor final {
void* p_data_;
AllocatorPtr buffer_deleter_; // if null, the Tensor does not own the memory; otherwise the Tensor is responsible for freeing it
TensorShape shape_;
const PrimitiveDataTypeBase* dtype_;
OrtMemoryInfo alloc_info_;
ptrdiff_t byte_offset_;
};
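To see the non-owning behaviour from the outside, the public C++ API can wrap a caller-owned buffer into a tensor without copying it. A minimal sketch (assuming onnxruntime_cxx_api.h):
#include <array>
#include <cstdint>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;  // initializes the API and logging

  // Caller-owned buffer: the tensor below only points into it and never frees it.
  std::array<float, 6> data{1, 2, 3, 4, 5, 6};
  std::array<int64_t, 2> shape{2, 3};

  Ort::MemoryInfo mem_info =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  // Wraps `data` without copying; internally this is a Tensor whose
  // buffer_deleter_ stays null, so the buffer is not released by the tensor.
  Ort::Value tensor = Ort::Value::CreateTensor<float>(
      mem_info, data.data(), data.size(), shape.data(), shape.size());
  return 0;
}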
class PrimitiveDataTypeBase : public DataTypeImpl {
const int32_t data_type_;
};
class DataTypeImpl {
public:
enum class GeneralType {
kInvalid = 0,
kNonTensor = 1,
kTensor = 2,
kTensorSequence = 3,
kSparseTensor = 4,
kOptional = 5,
kPrimitive = 6,
};
const GeneralType type_;
const size_t size_;
};
TensorShape
class TensorShape {
gsl::span<int64_t> values_; // points at small_buffer_ or allocated_buffer_; see the span definition below
int64_t small_buffer_[kTensorShapeSmallBufferElementsSize]{0}; // inline storage for the dimension values when rank <= kTensorShapeSmallBufferElementsSize (= 5); larger ranks fall back to allocated_buffer_
std::unique_ptr<int64_t[]> allocated_buffer_;
friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate
};
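A simplified, self-contained illustration of that small-buffer optimization, using std::span (C++20) in place of gsl::span; this is only a sketch, not the real TensorShape, and copy/move handling is omitted:
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <span>

class SmallShape {
  static constexpr std::size_t kInlineCapacity = 5;  // mirrors kTensorShapeSmallBufferElementsSize

 public:
  explicit SmallShape(std::span<const int64_t> dims) {
    int64_t* dst;
    if (dims.size() <= kInlineCapacity) {
      dst = small_buffer_;                            // rank <= 5: no heap allocation
    } else {
      allocated_buffer_ = std::make_unique<int64_t[]>(dims.size());
      dst = allocated_buffer_.get();                  // larger ranks spill to the heap
    }
    std::copy(dims.begin(), dims.end(), dst);
    values_ = std::span<int64_t>(dst, dims.size());
  }

  std::span<const int64_t> dims() const { return values_; }

 private:
  std::span<int64_t> values_;                         // points into one of the two buffers below
  int64_t small_buffer_[kInlineCapacity]{};
  std::unique_ptr<int64_t[]> allocated_buffer_;
};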
template <class ElementType, std::size_t Extent>
class span
{
public:
// constants and types
using element_type = ElementType;
using value_type = std::remove_cv_t<ElementType>;
using size_type = std::size_t;
using pointer = element_type*;
using const_pointer = const element_type*;
using reference = element_type&;
using const_reference = const element_type&;
using difference_type = std::ptrdiff_t;
using iterator = details::span_iterator<ElementType>;
using reverse_iterator = std::reverse_iterator<iterator>;
};
template <class Type>
class span_iterator
{
public:
using iterator_category = std::random_access_iterator_tag;
using value_type = std::remove_cv_t<Type>;
using difference_type = std::ptrdiff_t;
using pointer = Type*;
using reference = Type&;
pointer begin_ = nullptr;
pointer end_ = nullptr;
pointer current_ = nullptr;
};
/** Copied from TensorProto::DataType
* Currently, Ort doesn't support complex64, complex128
*/
typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, // maps to c type uint8_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, // maps to c type int8_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, // maps to c type uint16_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, // maps to c type int16_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision
} ONNXTensorElementDataType;
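For the primitive types, the size_ field of DataTypeImpl corresponds roughly to the byte width of one element. A hypothetical helper (not part of the ONNX Runtime API) making that mapping explicit for the fixed-size types above:
#include <cstddef>
#include <onnxruntime_c_api.h>  // defines ONNXTensorElementDataType

// Byte width per element for the fixed-size types; returns 0 for
// string/complex/undefined, which need separate handling.
std::size_t ElementSizeInBytes(ONNXTensorElementDataType type) {
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return 1;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      return 2;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return 4;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return 8;
    default:
      return 0;
  }
}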
// The data types that are supported in this build (enabled) for inputs/outputs.
// Key is input/output/type constraint name defined in op schema, Value is supported types.
std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;
MLAS Conv algorithm types. The GemmDirect path looks like this:
case MlasConvAlgorithmGemmDirect:
{
//
// Invoke the threaded GEMM directly with the input tensor.
//
MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, OutputSize,
K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb,
Parameters->Beta, Output, OutputSize, ThreadPool);
//
// Apply the activation with optional bias.
//
MlasActivation(Parameters->Activation, Output, bias, FilterCount,
OutputSize, OutputSize);
break;
}
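The GemmDirect path applies when no im2col expansion is needed, e.g. a pointwise 1x1 convolution with unit stride, unit dilation and no padding; the convolution is then literally one matrix multiply, Output[FilterCount x OutputSize] = Filter[FilterCount x K] * Input[K x OutputSize] with K = InputChannels. A plain, MLAS-free sketch of that equivalence:
#include <cstddef>
#include <iostream>

// Pointwise (1x1, stride 1, no padding) convolution written as the GEMM that
// the GemmDirect case hands to MlasGemm.
void Conv1x1AsGemm(const float* filter, const float* input, float* output,
                   std::size_t filter_count, std::size_t channels, std::size_t output_size) {
  for (std::size_t f = 0; f < filter_count; ++f) {
    for (std::size_t o = 0; o < output_size; ++o) {
      float acc = 0.0f;
      for (std::size_t c = 0; c < channels; ++c) {
        acc += filter[f * channels + c] * input[c * output_size + o];
      }
      output[f * output_size + o] = acc;
    }
  }
}

int main() {
  // 2 filters, 3 input channels, 4 output pixels (H * W = 4).
  const float filter[2 * 3] = {1, 0, 0, 0, 1, 1};
  const float input[3 * 4] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  float output[2 * 4];
  Conv1x1AsGemm(filter, input, output, 2, 3, 4);
  for (float v : output) std::cout << v << ' ';  // prints: 1 2 3 4 14 16 18 20
  std::cout << '\n';
  return 0;
}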
enum MLAS_CONV_ALGORITHM {
MlasConvAlgorithmGemmDirect,
MlasConvAlgorithmExpandThenGemm,
MlasConvAlgorithmExpandThenGemmSegmented,
#if defined(MLAS_TARGET_WASM_SCALAR)
MlasConvAlgorithmDepthwise,
#endif
};
struct MLAS_CONV_PARAMETERS {
const MLAS_ACTIVATION* Activation;
size_t Dimensions;
size_t BatchCount;
size_t GroupCount;
size_t InputChannels;
size_t InputShape[3];
size_t KernelShape[3];
size_t DilationShape[3];
size_t Padding[6];
size_t StrideShape[3];
size_t FilterCount;
size_t OutputShape[3];
size_t InputSize;
size_t OutputSize;
size_t K;
float Beta;
MLAS_CONV_ALGORITHM Algorithm;
ptrdiff_t ThreadCount;
union {
struct {
CBLAS_TRANSPOSE TransB;
size_t ldb; // OutputSize
} GemmDirect;
struct {
size_t ThreadStrideN;
} ExpandThenGemmSegmented;
} u;
};