ONNXRuntime Conv算子实现解析

// Plain record bundling an allocator name and id with a memory type, an
// allocator type and an OrtDevice, plus two pool-related fields.
struct OrtMemoryInfo {
  OrtMemoryInfo() = default;  // to allow default construction of Tensor

  // NOTE(review): these two fields carried an unexplained "TS Elijah Do" tag in
  // the original; they look like a local extension for pooled allocation --
  // confirm their intent with the author.
  size_t HeapOffset = 0;
  bool UsePool = false;

  // use string for name, so we could have customized allocator in execution provider.
  const char* name = nullptr;
  int id = -1;
  OrtMemType mem_type = OrtMemTypeDefault;
  OrtAllocatorType alloc_type = OrtInvalidAllocator;
  OrtDevice device;
};
// Struct to represent a physical device.
struct OrtDevice {
  using DeviceType = int8_t;
  using MemoryType = int8_t;
  using DeviceId = int16_t;

 private:
  // Every field is value-initialized so that a default-constructed OrtDevice
  // (such as the `device` member of a defaulted OrtMemoryInfo) never holds
  // indeterminate values.

  // Device type.
  DeviceType device_type{};

  // Memory type.
  MemoryType memory_type{};

  // Device index.
  DeviceId device_id{};
};
// Tag type used solely to select the BuildKernelCreateInfo specialization
// below (presumably the CPU "Conv" kernel for opset versions 1 through 10).
class kCpuExecutionProvider_Conv_kOnnxDomain_ver1_10;

// Produces the registration record for the float "Conv" kernel: a kernel
// definition built with SinceVersion(1, 10) for kCpuExecutionProvider, paired
// with a factory that instantiates Conv<float>.
template <>
KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Conv_kOnnxDomain_ver1_10>() {
  // Factory: stores a freshly constructed Conv<float> in `kernel` and reports success.
  const auto create_conv = static_cast<KernelCreatePtrFn>(
      [](FuncManager&, const OpKernelInfo& kernel_info, std::unique_ptr<OpKernel>& kernel) -> Status {
        kernel = std::make_unique<Conv<float>>(kernel_info);
        return Status::OK();
      });

  return KernelCreateInfo(
      KernelDefBuilder()
          .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
          .SetName("Conv")
          .SetDomain(kOnnxDomain)
          .SinceVersion(1, 10)
          .Provider(kCpuExecutionProvider)
          .Build(),
      create_conv);
}


// Tag type used solely to select the BuildKernelCreateInfo specialization
// below (presumably the CPU "Conv" kernel from opset version 11 onward).
class kCpuExecutionProvider_Conv_kOnnxDomain_ver11;

// Produces the registration record for the float "Conv" kernel: a kernel
// definition built with SinceVersion(11) for kCpuExecutionProvider, paired
// with a factory that instantiates Conv<float>.
template <>
KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Conv_kOnnxDomain_ver11>() {
  // Factory: stores a freshly constructed Conv<float> in `kernel` and reports success.
  const auto create_conv = static_cast<KernelCreatePtrFn>(
      [](FuncManager&, const OpKernelInfo& kernel_info, std::unique_ptr<OpKernel>& kernel) -> Status {
        kernel = std::make_unique<Conv<float>>(kernel_info);
        return Status::OK();
      });

  return KernelCreateInfo(
      KernelDefBuilder()
          .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
          .SetName("Conv")
          .SetDomain(kOnnxDomain)
          .SinceVersion(11)
          .Provider(kCpuExecutionProvider)
          .Build(),
      create_conv);
}

Tensor 只是一个带有 shape 信息的内存占位符,默认不具有内存的拥有权和管理权;只有当 buffer_deleter_ 非空时,Tensor 才负责释放这块内存(见下面成员注释)。

// Excerpt of Tensor's data members: a raw data pointer plus the shape,
// element type, allocation info and byte offset that describe it.
class Tensor final {
  void* p_data_;

  // If null, the Tensor does not own the memory; otherwise the Tensor is
  // responsible for releasing it.
  AllocatorPtr buffer_deleter_;
  TensorShape shape_;
  const PrimitiveDataTypeBase* dtype_;
  OrtMemoryInfo alloc_info_;
  ptrdiff_t byte_offset_;
};
// Data type descriptor derived from DataTypeImpl that additionally carries an
// integer type id.
class PrimitiveDataTypeBase : public DataTypeImpl {
  // NOTE(review): presumably an ONNXTensorElementDataType value -- confirm.
  const int32_t data_type_;
};
// Base descriptor for a runtime data type: a general category plus a size.
class DataTypeImpl {
 public:
  // Broad category a data type falls into.
  enum class GeneralType {
    kInvalid = 0,
    kNonTensor = 1,
    kTensor = 2,
    kTensorSequence = 3,
    kSparseTensor = 4,
    kOptional = 5,
    kPrimitive = 6,
  };

  const GeneralType type_;
  // NOTE(review): presumably the size of the type in bytes -- confirm.
  const size_t size_;
};

TensorShape

// Excerpt of TensorShape's data members.
class TensorShape {
  // View over the shape's int64_t values; see the span definition below.
  gsl::span<int64_t> values_;

  // Zero-initialized inline buffer of kTensorShapeSmallBufferElementsSize
  // elements (= 5 according to the original note).
  // NOTE(review): the original note says this stores the stride of each
  // dimension; since this class is a shape and values_ is a span of int64_t,
  // it more likely stores the dimension values themselves -- confirm against
  // the upstream ONNX Runtime source.
  int64_t small_buffer_[kTensorShapeSmallBufferElementsSize]{0};
  std::unique_ptr<int64_t[]> allocated_buffer_;

  friend struct ProviderHostImpl;  // So that the shared provider interface can access Allocate
};
// Excerpt of the span class template: only its member type aliases are shown.
// ElementType is the viewed element type; Extent is a std::size_t template
// parameter that the aliases below do not use.
template <class ElementType, std::size_t Extent>
class span
{
public:
    // constants and types
    using element_type = ElementType;
    using value_type = std::remove_cv_t<ElementType>;
    using size_type = std::size_t;
    using pointer = element_type*;
    using const_pointer = const element_type*;
    using reference = element_type&;
    using const_reference = const element_type&;
    using difference_type = std::ptrdiff_t;

    // Iterators are provided by details::span_iterator.
    using iterator = details::span_iterator<ElementType>;
    using reverse_iterator = std::reverse_iterator<iterator>;
};

 // Excerpt of the iterator type used by span: random-access iterator traits
 // plus three pointers, all null by default.
 template <class Type>
 class span_iterator
 {
 public:
     using iterator_category = std::random_access_iterator_tag;
     using value_type = std::remove_cv_t<Type>;
     using difference_type = std::ptrdiff_t;
     using pointer = Type*;
     using reference = Type&;

     // NOTE(review): presumably the bounds of the viewed range and the
     // current position within it -- confirm against the GSL source.
     pointer begin_ = nullptr;
     pointer end_ = nullptr;
     pointer current_ = nullptr;
 };
/** Tensor element types, copied from TensorProto::DataType.
 * Each enumerator's value is spelled out explicitly; the values are the same
 * ones implicit numbering would assign.
 * Currently, Ort doesn't support complex64, complex128
 */
typedef enum ONNXTensorElementDataType {
  ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED = 0,
  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT = 1,    // C type: float
  ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 = 2,    // C type: uint8_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 = 3,     // C type: int8_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 = 4,   // C type: uint16_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 = 5,    // C type: int16_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 = 6,    // C type: int32_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 = 7,    // C type: int64_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING = 8,   // C++ type: std::string
  ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL = 9,
  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 = 10,
  ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE = 11,      // C type: double
  ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 = 12,      // C type: uint32_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 = 13,      // C type: uint64_t
  ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 = 14,   // complex; float32 real and imaginary parts
  ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 = 15,  // complex; float64 real and imaginary parts
  ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 = 16     // non-IEEE format based on IEEE754 single precision
} ONNXTensorElementDataType;

// Data types supported (enabled) in this build for inputs/outputs.
// Key: the input, output, or type-constraint name as defined in the op schema.
// Value: the list of types supported for that name.
// NOTE(review): the trailing underscore suggests this is a class data member
// shown out of context -- confirm which class owns it.
std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;

conv的算法类型:

// Excerpt of a switch over the convolution algorithm: the direct-GEMM path,
// which issues a single MlasGemm call followed by MlasActivation.
case MlasConvAlgorithmGemmDirect:
                    {
                    //
                    // Invoke the threaded GEMM directly with the input tensor.
                    //
                    // The filter is passed untransposed (CblasNoTrans) with K as
                    // its stride argument; the transpose flag and stride for the
                    // input come from Parameters->u.GemmDirect, and the result is
                    // written to Output using Parameters->Beta.
                    // NOTE(review): this reading assumes MlasGemm follows the
                    // usual SGEMM argument order -- confirm against mlas.h.
                    //

                    MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, OutputSize,
                             K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb,
                             Parameters->Beta, Output, OutputSize, ThreadPool);

                    //
                    // Apply the activation with optional bias.
                    //
                    // Operates on the same Output buffer that MlasGemm was given.
                    //

                    MlasActivation(Parameters->Activation, Output, bias, FilterCount,
                        OutputSize, OutputSize);

                    break;
                }
// Identifiers for the available convolution algorithms. Values are written
// out explicitly and equal the ones implicit numbering would assign. The
// depthwise entry exists only when MLAS_TARGET_WASM_SCALAR is defined.
enum MLAS_CONV_ALGORITHM {
    MlasConvAlgorithmGemmDirect = 0,
    MlasConvAlgorithmExpandThenGemm = 1,
    MlasConvAlgorithmExpandThenGemmSegmented = 2,
#if defined(MLAS_TARGET_WASM_SCALAR)
    MlasConvAlgorithmDepthwise = 3,
#endif
};

// Parameter block describing one convolution. The MlasConvAlgorithmGemmDirect
// case shown earlier in this file reads Activation, Beta and u.GemmDirect
// from it.
// NOTE(review): the per-field notes below are inferred from the field names
// and array sizes only -- confirm against the MLAS source before relying on them.
struct MLAS_CONV_PARAMETERS {
    const MLAS_ACTIVATION* Activation;  // passed to MlasActivation in the GemmDirect case
    size_t Dimensions;
    size_t BatchCount;
    size_t GroupCount;
    size_t InputChannels;
    // The shape arrays hold 3 entries each and Padding holds 6; presumably up
    // to three spatial dimensions with two padding values per dimension.
    size_t InputShape[3];
    size_t KernelShape[3];
    size_t DilationShape[3];
    size_t Padding[6];
    size_t StrideShape[3];
    size_t FilterCount;
    size_t OutputShape[3];
    size_t InputSize;
    size_t OutputSize;
    size_t K;
    float Beta;  // passed to MlasGemm in the GemmDirect case
    MLAS_CONV_ALGORITHM Algorithm;
    ptrdiff_t ThreadCount;
    // Algorithm-specific parameters; presumably the active member is the one
    // matching Algorithm.
    union {
        struct {
            CBLAS_TRANSPOSE TransB;
            size_t ldb; // OutputSize
        } GemmDirect;
        struct {
            size_t ThreadStrideN;
        } ExpandThenGemmSegmented;
    } u;
};
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值