References:
http://www.codeguru.com/cpp/com-tech/atl/misc/article.php/c37/Asynchronous-Pluggable-Protocol-Implementation-with-ATL.htm
http://blog.youkuaiyun.com/cumtzly/article/details/40072613
// http_protocol.cc
#include "trident/glue/protocol_impl/http_protocol.h"
#include "base/logging.h"
#include <WinInet.h>
#include <ExDisp.h>
namespace trident {
// Constructs the protocol object.
// |pOuterUnknown| is the controlling IUnknown when this object is created as
// part of an aggregate, or NULL for a standalone object. The reference count
// starts at 0; the creator takes the first reference through
// NonDelegatingAddRef/NonDelegatingQueryInterface.
HttpProtocol::HttpProtocol(IUnknown* pOuterUnknown)
    : reference_count_(0),
      outer_unknown_(pOuterUnknown),
      inner_unknown_(NULL),
      grf_BindF_(0),
      has_lock_request_(false) {  // Was left uninitialized.
  // INonDelegatingUnknown's three methods occupy the same vtable slots as
  // IUnknown's QueryInterface/AddRef/Release, so the non-delegating interface
  // can be used wherever an IUnknown* is expected.
  inner_unknown_ = reinterpret_cast<IUnknown*>(
      static_cast<INonDelegatingUnknown*>(this));
  ZeroMemory(&bind_info_, sizeof(BINDINFO));
  bind_info_.cbSize = sizeof(BINDINFO);
}
// Releases the resources owned by bind_info_.
HttpProtocol::~HttpProtocol() {
  // IInternetBindInfo::GetBindInfo (called from Start) transfers ownership of
  // the BINDINFO's allocations (szExtraInfo, szCustomVerb, stgmedData, pUnk)
  // to the caller; ReleaseBindInfo frees them. The constructor zero-fills the
  // structure, so this is expected to be a no-op if Start never ran.
  ::ReleaseBindInfo(&bind_info_);
}
// INonDelegatingUnknown
// Non-delegating QueryInterface: the real interface lookup for this object.
// An aggregating outer object reaches it through inner_unknown_; the
// delegating QueryInterface reaches it when there is no outer object.
// Returns S_OK with an AddRef'd pointer in *ppvObject, E_NOINTERFACE (with
// *ppvObject set to NULL) for unsupported IIDs, or E_POINTER if ppvObject is
// NULL.
STDMETHODIMP HttpProtocol::NonDelegatingQueryInterface(REFIID riid, void** ppvObject) {
  if (ppvObject == NULL) {
    return E_POINTER;
  }
  *ppvObject = NULL;
  if (riid == IID_IUnknown) {
    // Hand out the non-delegating unknown so an aggregating outer object can
    // control this object's own lifetime.
    *ppvObject = static_cast<INonDelegatingUnknown*>(this);
  } else if (riid == IID_IInternetProtocolRoot) {
    *ppvObject = static_cast<IInternetProtocolRoot*>(this);
  } else if (riid == IID_IInternetProtocol) {
    *ppvObject = static_cast<IInternetProtocol*>(this);
  } else if (riid == IID_IInternetProtocolEx) {
    *ppvObject = static_cast<IInternetProtocolEx*>(this);
  } else if (riid == IID_IInternetProtocolInfo) {
    *ppvObject = static_cast<IInternetProtocolInfo*>(this);
  }
  if (*ppvObject == NULL) {
    return E_NOINTERFACE;
  }
  // AddRef through the pointer actually returned. For IID_IUnknown that slot
  // is NonDelegatingAddRef. For every other interface it is the delegating
  // AddRef, which targets the outer object when aggregated, as the COM
  // aggregation rules require. The caller's matching Release() goes through
  // the same pointer, so the counts stay balanced.
  static_cast<IUnknown*>(*ppvObject)->AddRef();
  return S_OK;
}
// Atomically increments this object's own reference count and returns the
// incremented value.
STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingAddRef() {
  const LONG new_count = ::InterlockedIncrement(&reference_count_);
  return static_cast<ULONG>(new_count);
}
// Atomically decrements this object's own reference count, deletes the object
// when it reaches zero, and returns the decremented value.
STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingRelease() {
  // Use the value returned by InterlockedDecrement. Re-reading
  // reference_count_ races with other threads, and once the count hits zero
  // it would read the member of an already-deleted object.
  const LONG remaining = ::InterlockedDecrement(&reference_count_);
  if (remaining == 0) {
    delete this;
  }
  return static_cast<ULONG>(remaining);
}
// IUnknown
// Delegating IUnknown::QueryInterface: forwards to the outer object when
// aggregated, otherwise to this object's non-delegating implementation.
STDMETHODIMP HttpProtocol::QueryInterface(REFIID riid, void** ppvObject) {
  IUnknown* const target = (outer_unknown_ != NULL) ? outer_unknown_ : inner_unknown_;
  return target->QueryInterface(riid, ppvObject);
}
// Delegating IUnknown::AddRef: forwards to the outer object when aggregated,
// otherwise to this object's non-delegating implementation.
STDMETHODIMP_(ULONG) HttpProtocol::AddRef() {
  IUnknown* const target = (outer_unknown_ != NULL) ? outer_unknown_ : inner_unknown_;
  return target->AddRef();
}
// Delegating IUnknown::Release: forwards to the outer object when aggregated,
// otherwise to this object's non-delegating implementation.
STDMETHODIMP_(ULONG) HttpProtocol::Release() {
  IUnknown* const target = (outer_unknown_ != NULL) ? outer_unknown_ : inner_unknown_;
  return target->Release();
}
// IInternetProtocolRoot::Start. XP SP2 and earlier call this entry point
// directly; StartEx forwards here on later systems.
// Records the URL, sink and bind-info provider, looks up an IServiceProvider,
// and fetches the BINDINFO for the request. Returns E_INVALIDARG if any
// pointer argument is NULL, or the failing HRESULT from GetBindInfo.
STDMETHODIMP HttpProtocol::Start(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
  if (bind_info == NULL || protocol_sink == NULL || url == NULL)
    return E_INVALIDARG;
  bind_url_ = GURL(url);
  spSink_ = protocol_sink;
  spBindinfo_ = bind_info;

  // CComPtr::operator& asserts that the pointer is empty, so drop any service
  // provider left over from an earlier Start() before querying again.
  spServiceProvider_.Release();
  spSink_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_);
  if (!spServiceProvider_)
    spBindinfo_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_);
  DCHECK(spServiceProvider_);

  // BINDINFO
  // http://msdn.microsoft.com/en-us/library/ie/aa767897(v=vs.85).aspx
  // http://msdn.microsoft.com/en-us/library/ie/aa741006(v=vs.85).aspx#Handling_BINDINFO_St
  // Free anything a previous GetBindInfo() allocated, then refill from a
  // clean, correctly sized structure.
  ::ReleaseBindInfo(&bind_info_);
  ZeroMemory(&bind_info_, sizeof(BINDINFO));
  bind_info_.cbSize = sizeof(BINDINFO);
  HRESULT result = spBindinfo_->GetBindInfo(&grf_BindF_, &bind_info_);
  DCHECK(SUCCEEDED(result));
  if (FAILED(result))
    return result;
  if (!bind_info_.dwCodePage)
    bind_info_.dwCodePage = ::GetACP();

  // TODO: report progress and data through the sink (IInternetProtocolSink),
  // e.g.:
  //   spSink_->ReportProgress(BINDSTATUS_FINDINGRESOURCE, ...);
  //   spSink_->ReportProgress(BINDSTATUS_CONNECTING, ...);
  //   spSink_->ReportProgress(BINDSTATUS_SENDINGREQUEST, ...);
  //   spSink_->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, ...);
  //   spSink_->ReportData(BSCF_FIRSTDATANOTIFICATION, 0, total);
  //   spSink_->ReportData(BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE,
  //                       total, total);
  return S_OK;
}
// IInternetProtocolRoot::Continue. The PROTOCOLDATA is not processed yet;
// always reports success.
STDMETHODIMP HttpProtocol::Continue(PROTOCOLDATA* pProtocolData) {
  (void)pProtocolData;
  return S_OK;
}
// IInternetProtocolRoot::Abort. Under IE6/IE8 an assertion showed that Abort
// is still called after Terminate. Currently a no-op that reports success.
STDMETHODIMP HttpProtocol::Abort(HRESULT reason, DWORD options) {
  (void)reason;
  (void)options;
  return S_OK;
}
// IInternetProtocolRoot::Terminate. Currently a no-op that reports success.
STDMETHODIMP HttpProtocol::Terminate(DWORD options) {
  (void)options;
  return S_OK;
}
// IInternetProtocolRoot::Suspend. Not supported by this protocol.
STDMETHODIMP HttpProtocol::Suspend() {
  const HRESULT kNotSupported = E_NOTIMPL;
  return kNotSupported;
}
// IInternetProtocolRoot::Resume. Not supported by this protocol.
STDMETHODIMP HttpProtocol::Resume() {
  const HRESULT kNotSupported = E_NOTIMPL;
  return kNotSupported;
}
// IInternetProtocol::Read. This protocol does not produce any data yet.
// It therefore reports zero bytes read and S_FALSE ("all data has been
// read"). Returning S_OK with *pcbRead left unset would hand the caller a
// garbage byte count and suggest more data may follow.
STDMETHODIMP HttpProtocol::Read(void* pv, ULONG size, ULONG* pcbRead) {
  (void)pv;
  (void)size;
  if (pcbRead != NULL)
    *pcbRead = 0;
  return S_FALSE;
}
// IInternetProtocol::Seek. Seeking is not supported, so report failure rather
// than returning S_OK with *new_position left unset.
STDMETHODIMP HttpProtocol::Seek(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position) {
  (void)move;
  (void)origin;
  (void)new_position;
  return E_FAIL;
}
// IInternetProtocol::LockRequest. Only records that the request is locked;
// |options| is unused.
STDMETHODIMP HttpProtocol::LockRequest(DWORD options) {
  (void)options;
  has_lock_request_ = true;
  return S_OK;
}
// IInternetProtocol::UnlockRequest. Clears the flag set by LockRequest().
STDMETHODIMP HttpProtocol::UnlockRequest() {
  const bool kUnlocked = false;
  has_lock_request_ = kUnlocked;
  return S_OK;
}
// IInternetProtocolEx::StartEx, used on XP SP3 and later. Converts |uri| to
// its absolute-URI string and forwards to Start().
// |uri| is an [in] parameter owned by the caller. It must not be released
// here; the previous uri->Release() over-released the caller's reference.
// Returns E_INVALIDARG for a NULL |uri|, the failing HRESULT from
// GetAbsoluteUri, or whatever Start() returns.
STDMETHODIMP HttpProtocol::StartEx(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
  if (uri == NULL) {
    return E_INVALIDARG;
  }
  BSTR uri_URL = NULL;
  HRESULT result = uri->GetAbsoluteUri(&uri_URL);
  if (FAILED(result)) {
    return result;
  }
  std::wstring url;
  if (uri_URL != NULL) {
    url = uri_URL;
    ::SysFreeString(uri_URL);
  }
  return Start(url.c_str(), protocol_sink, bind_info, flags, reserved);
}
// IInternetProtocolInfo::ParseUrl. Not implemented by this protocol.
// Returns INET_E_DEFAULT_ACTION so urlmon applies its default parsing.
// S_OK would claim success with pwzResult and *pcchResult never written.
STDMETHODIMP HttpProtocol::ParseUrl(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult,
                                    DWORD cchResult, DWORD *pcchResult, DWORD dwReserved) {
  return INET_E_DEFAULT_ACTION;
}
// IInternetProtocolInfo::CombineUrl. Not implemented by this protocol.
// Returns INET_E_DEFAULT_ACTION so urlmon combines the URLs itself. S_OK
// would claim success with an unwritten result buffer.
STDMETHODIMP HttpProtocol::CombineUrl(LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags, LPWSTR pwzResult,
                                      DWORD cchResult, DWORD *pcchResult, DWORD dwReserved) {
  return INET_E_DEFAULT_ACTION;
}
// IInternetProtocolInfo::CompareUrl. Not implemented by this protocol.
// Returning S_OK would report that every pair of URLs is equal. Return
// INET_E_DEFAULT_ACTION so urlmon performs its default comparison.
STDMETHODIMP HttpProtocol::CompareUrl(LPCWSTR pwzUrl1, LPCWSTR pwzUrl2, DWORD dwCompareFlags) {
  return INET_E_DEFAULT_ACTION;
}
// IInternetProtocolInfo::QueryInfo. Not implemented by this protocol.
// Returns INET_E_DEFAULT_ACTION so urlmon answers the query itself. S_OK would
// claim success with pBuffer and *pcbBuf never written.
STDMETHODIMP HttpProtocol::QueryInfo(LPCWSTR pwzUrl, QUERYOPTION QueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer,
                                     DWORD *pcbBuf, DWORD dwReserved) {
  return INET_E_DEFAULT_ACTION;
}
// Returns the request verb from the BINDINFO fetched in Start(): "GET",
// "POST", "PUT", or bind_info_.szCustomVerb for BINDVERB_CUSTOM.
// For an unrecognized verb, or a custom verb with no string, it DCHECKs and
// returns an empty string. Constructing std::wstring from a NULL pointer is
// undefined behavior.
std::wstring HttpProtocol::GetVerbStr() const {
  const wchar_t* verb = NULL;  // const: string literals are not modifiable.
  switch (bind_info_.dwBindVerb) {
    case BINDVERB_GET:
      verb = L"GET";
      break;
    case BINDVERB_POST:
      verb = L"POST";
      break;
    case BINDVERB_PUT:
      verb = L"PUT";
      break;
    case BINDVERB_CUSTOM:
      verb = bind_info_.szCustomVerb;
      break;
    default:
      break;
  }
  DCHECK(verb);
  return verb ? std::wstring(verb) : std::wstring();
}
bool HttpProtocol::GetDataToSend(char** lplpData, DWORD* pdwSize) const {
if(bind_info_.dwBindVerb == BINDVERB_GET)
return false;
if (bind_info_.stgmedData.tymed == TYMED_HGLOBAL) {
if(lplpData)
*lplpData = (char*)bind_info_.stgmedData.hGlobal;
if(pdwSize)
*pdwSize = bind_info_.cbstgmedData;
return true;
} else {
return false;
}
}
}
// http_protocol.h
#ifndef TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_
#define TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_
// Implementation modeled on the Windows 2000 source code:
// private\inet\urlmon\iapp\cnet.cxx
// private\inet\urlmon\iapp\cnethttp.cxx
#include <atlbase.h>
#include <urlmon.h>
#include <vector>
#include "base/basictypes.h"
#include "url/gurl.h"
namespace trident {
// Helper interface for COM aggregation: the non-delegating IUnknown that an
// aggregatable object hands to its outer (controlling) object.
// Reference: http://msdn.microsoft.com/en-us/library/windows/desktop/dd390339(v=vs.85).aspx
// NOTE: these three methods occupy the same vtable slots, in the same order,
// as IUnknown's QueryInterface/AddRef/Release. HttpProtocol reinterprets an
// INonDelegatingUnknown* as an IUnknown* (inner_unknown_), so do not add,
// remove, or reorder virtual methods here.
struct INonDelegatingUnknown {
STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject) = 0;
STDMETHOD_(ULONG, NonDelegatingAddRef)() = 0;
STDMETHOD_(ULONG, NonDelegatingRelease)() = 0;
};
// Asynchronous pluggable protocol handler for HTTP/HTTPS.
// Implements IInternetProtocolEx (and, through it, IInternetProtocol and
// IInternetProtocolRoot) plus IInternetProtocolInfo, and supports COM
// aggregation. When created with an outer IUnknown, the IUnknown methods
// delegate to it, while INonDelegatingUnknown manages this object's own
// reference count. The object deletes itself when that count reaches zero.
class HttpProtocol : public INonDelegatingUnknown,
                     public IInternetProtocolEx,
                     public IInternetProtocolInfo {
 public:
  // |pOuterUnknown| is the controlling IUnknown when aggregated, or NULL.
  // Explicit so an IUnknown* never silently converts to an HttpProtocol.
  explicit HttpProtocol(IUnknown* pOuterUnknown);
  virtual ~HttpProtocol();

 public:
  // INonDelegatingUnknown
  // Only the protocol interfaces can be queried; the sink interfaces are not
  // exposed.
  STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject);
  STDMETHOD_(ULONG, NonDelegatingAddRef)();
  STDMETHOD_(ULONG, NonDelegatingRelease)();

  // IUnknown (delegates to the outer object when aggregated).
  STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject);
  STDMETHOD_(ULONG, AddRef)();
  STDMETHOD_(ULONG, Release)();

  // IInternetProtocolRoot
  STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved);
  STDMETHOD(Continue)(PROTOCOLDATA* pProtocolData);
  STDMETHOD(Abort)(HRESULT reason, DWORD options);
  STDMETHOD(Terminate)(DWORD options);
  STDMETHOD(Suspend)();
  STDMETHOD(Resume)();

  // IInternetProtocol : public IInternetProtocolRoot
  STDMETHOD(Read)(void* pv, ULONG size, ULONG* pcbRead);
  STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position);
  STDMETHOD(LockRequest)(DWORD options);
  STDMETHOD(UnlockRequest)();

  // IInternetProtocolEx : public IInternetProtocol
  STDMETHOD(StartEx)(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved);

  // IInternetProtocolInfo
  STDMETHOD(ParseUrl)(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult, DWORD cchResult, DWORD* pcchResult, DWORD dwReserved);
  STDMETHOD(CombineUrl)(LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags, LPWSTR pwzResult, DWORD cchResult, DWORD* pcchResult, DWORD dwReserved);
  STDMETHOD(CompareUrl)(LPCWSTR pwzUrl1, LPCWSTR pwzUrl2, DWORD dwCompareFlags);
  STDMETHOD(QueryInfo)(LPCWSTR pwzUrl, QUERYOPTION QueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD* pcbBuf, DWORD dwReserved);

 private:
  // Returns the request verb taken from bind_info_.
  std::wstring GetVerbStr() const;
  // Reports the request body held in bind_info_, if there is one.
  bool GetDataToSend(char** lplpData, DWORD* pdwSize) const;

 private:
  // This object's own reference count (NonDelegatingAddRef/Release).
  volatile LONG reference_count_;
  // Controlling IUnknown when aggregated; stored without an AddRef.
  IUnknown* outer_unknown_;
  // This object's INonDelegatingUnknown viewed as an IUnknown.
  IUnknown* inner_unknown_;
  // Sink, bind-info provider and service provider captured in Start().
  CComPtr<IInternetProtocolSink> spSink_;
  CComPtr<IInternetBindInfo> spBindinfo_;
  CComPtr<IServiceProvider> spServiceProvider_;
  // Bind information and bind flags filled by GetBindInfo() in Start().
  BINDINFO bind_info_;
  DWORD grf_BindF_;
  // Set by LockRequest(), cleared by UnlockRequest().
  bool has_lock_request_;
  // URL passed to Start().
  GURL bind_url_;

  DISALLOW_COPY_AND_ASSIGN(HttpProtocol);
};
} //namespace trident
#endif // TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_
// http_protocol_factory.cc
#include "trident/glue/protocol_impl/http_protocol_factory.h"
#include "base/logging.h"
#include "trident/glue/protocol_impl/http_protocol.h"
namespace trident {
// Creates the factory holding one reference for the caller. It also obtains
// the system's own http or https protocol class factory into origin_factory_.
HttpProtocolFactory::HttpProtocolFactory(bool is_https_protocol)
    : reference_count_(1) {
  const CLSID& system_clsid =
      is_https_protocol ? CLSID_HttpSProtocol : CLSID_HttpProtocol;
  const HRESULT hr = ::CoGetClassObject(system_clsid, CLSCTX_INPROC_SERVER, NULL,
                                        IID_IClassFactory,
                                        reinterpret_cast<void**>(&origin_factory_));
  DCHECK(hr == S_OK);
  DCHECK(origin_factory_ != NULL);
}
// Nothing to do explicitly; the CComPtr member releases origin_factory_.
HttpProtocolFactory::~HttpProtocolFactory() {}
// IUnknown
// IUnknown::QueryInterface: exposes IUnknown and IClassFactory only.
// Returns S_OK with an AddRef'd pointer, E_NOINTERFACE (and *ppvObject =
// NULL) for other IIDs, or E_POINTER when ppvObject is NULL, as the
// IUnknown::QueryInterface contract specifies.
STDMETHODIMP HttpProtocolFactory::QueryInterface(REFIID riid, void** ppvObject) {
  if (!ppvObject) {
    return E_POINTER;
  }
  *ppvObject = NULL;
  if (riid == IID_IUnknown) {
    *ppvObject = static_cast<IUnknown*>(this);
  } else if (riid == IID_IClassFactory) {
    *ppvObject = static_cast<IClassFactory*>(this);
  }
  if (*ppvObject == NULL) {
    return E_NOINTERFACE;
  }
  static_cast<IUnknown*>(*ppvObject)->AddRef();
  return S_OK;
}
// Atomically increments the factory's reference count and returns the new
// value.
STDMETHODIMP_(ULONG) HttpProtocolFactory::AddRef() {
  const ULONG new_count = ::InterlockedIncrement(&reference_count_);
  return new_count;
}
// Atomically decrements the factory's reference count. When it reaches zero,
// the factory deletes itself and 0 is returned; otherwise the remaining count
// is returned.
STDMETHODIMP_(ULONG) HttpProtocolFactory::Release() {
  const ULONG remaining = ::InterlockedDecrement(&reference_count_);
  if (remaining != 0)
    return remaining;
  delete this;
  return 0;
}
// IClassFactory
// IClassFactory::CreateInstance. Creates a new HttpProtocol, optionally
// aggregated by |pUnkOuter|. Aggregation requires riid == IID_IUnknown,
// otherwise CLASS_E_NOAGGREGATION is returned. The requested interface comes
// back in *ppvObject (NULL on failure); E_POINTER if ppvObject is NULL.
STDMETHODIMP HttpProtocolFactory::CreateInstance(IUnknown* pUnkOuter, REFIID riid, void** ppvObject) {
  if (ppvObject == NULL) {
    return E_POINTER;
  }
  *ppvObject = NULL;
  if (pUnkOuter && riid != IID_IUnknown) {
    return CLASS_E_NOAGGREGATION;
  }
  HttpProtocol* http_protocol = new HttpProtocol(pUnkOuter);
  // Hold a temporary reference across the query. The object's lifetime is
  // owned by its reference count: if the query fails, the final
  // NonDelegatingRelease below destroys it. Deleting it directly as well
  // would be a double delete.
  http_protocol->NonDelegatingAddRef();
  HRESULT result = http_protocol->NonDelegatingQueryInterface(riid, ppvObject);
  http_protocol->NonDelegatingRelease();
  return result;
}
// IClassFactory::LockServer. Implemented by adding (lock) or dropping
// (unlock) a reference on this factory object.
STDMETHODIMP HttpProtocolFactory::LockServer(BOOL fLock) {
  if (!fLock) {
    Release();
  } else {
    AddRef();
  }
  return S_OK;
}
}
// http_protocol_factory.h
#ifndef TRIDENT_HTTP_PROTOCOL_FACTORY_H_
#define TRIDENT_HTTP_PROTOCOL_FACTORY_H_
#include <atlbase.h>
#include <Unknwn.h>
#include "base/basictypes.h"
namespace trident {
class HttpProtocolFactory : public IClassFactory {
public:
explicit HttpProtocolFactory(bool is_https_protocol);
// IUnknown
STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject);
STDMETHOD_(ULONG, AddRef)();
STDMETHOD_(ULONG, Release)();
// IClassFactory
STDMETHOD(CreateInstance)(IUnknown* pUnkOuter, REFIID riid, void** ppvObject);
STDMETHOD(LockServer)(BOOL fLock);
private:
virtual ~HttpProtocolFactory();
volatile ULONG reference_count_;
CComPtr<IClassFactory> origin_factory_;
DISALLOW_IMPLICIT_CONSTRUCTORS(HttpProtocolFactory);
};
}
#endif // TRIDENT_HTTP_PROTOCOL_FACTORY_H_