COM组件的类厂(COM技术内幕笔记之四)
在上一篇中,介绍了怎么样用动态链接库去实现COM,但组件对我们来说仍是不透明的,我们需要知道实现组件DLL的位置,必须自己来加载组件的CreateInstance函数来获得组件的指针.在书中第一篇就曾经提到过:COM组件可以透明地在网络上(或本地)被重新分配位置,而不会影响本地客户程序.所以,由客户端来调用DLL并不是什么好主意.必须有一种更好的办法让组件的实现更透明,更灵活!
于是,就引入了类厂的概念.什么是类厂,类厂也是一个接口,它的职责是帮我们创造组件的对象.并返回给客户程序一个接口的指针.每个组件都必须有一个与之相关的类厂,这个类厂知道怎么样创建组件.当客户请求一个组件对象的实例时,实际上这个请求交给了类厂,由类厂创建组件实例,然后把实例指针交给客户程序。这么说有点难明白.先看一个伪实例.
1.实现二个接口IX,IY (上二节中有详细介绍)
2.实现一个组件CA,实现了IX,IY接口. (上二节中有详细介绍)
3.对于这个组件进行注册,把组件的信息加入到注册表中.
实现DllRegisterServer和DllUnregisterServer函数.函数具体功能就是把本组件的CLSID,ProgID,DLL的位置放入注册表中.这样程序就可以通过查询注册表来获得组件的位置.
4.创建本组件类厂的实例
class CFactory:public IClassFactory
{
virtual HRESULT __stdcall QueryInterface(const IID& iid,void** ppv);
virtual ULONG __stdcall AddRef();
virtual ULONG __stdcall Release();
virtual HRESULT __stdcall CreateInstance(IUnknown* pUnknownOuter,
const IID& iid,
void** ppv);
}
在类厂实例中,主要的功能就是CreateInstance了,这个函数就是创建组件的相应实例.看它的实现:
HRESULT __stdcall CFactory::CreateInstance(IUnknown* pUnknownOuter,const IID& iid,void** ppv)
{
//...
CA* pA = new CA;
if(pA == NULL)
return E_OUTOFMEMORY;
HRESULT hr = pA->QueryInterface(iid,ppv);
pA->Release();
return hr;
}
5.在这个组件的DLL中导出DllGetClassObject函数.这个函数的功能就是创建类厂的实例对象并查询接口.看其实现:
STDAPI DllGetClassObject(const CLSID& clsid,
const IID& iid,
void** ppv)
{
//....
CFactory* pFactory = new CFactory();
if(pFactory == NULL)
return E_OUTOFMEMORY;
HRESULT hr = pFactory->QueryInterface(iid,ppv);
pFactory->Release();
return hr;
}
组件的实现差不多就这么多,下面在客户端怎么调用组件呢?这就需要用到COM函数库了,由COM函数库去查找注册表,调用组件的类厂,创建组件实例,返回接口.如下所示:
IUnknown* pUnk = NULL;
IX* iX = NULL;
CoInitialize(NULL);
CoCreateInstance(CLSID_Component1,CLSCTX_INPROC_SERVER,IID_IUnknown,(void**)&pUnk);
pUnk->QueryInterface(IID_IX,(void**)&iX);
pUnk->Release();
iX->Fx();
iX->Release();
CoUninitialize();
至于客户是通过CoCreateInstance怎么获得组件的类厂,创建组件实例的.下面摘录的一篇文章很清晰的说明了这一切:
-------------------------------------------------------------------------------------
这部分我们将构造一个创建COM组件的最小框架结构,然后看一看其内部处理流程是怎样的
COM组件的运行机制,即COM是怎么跑起来的。
IUnknown *pUnk=NULL;
IObject *pObject=NULL;
CoInitialize(NULL);
CoCreateInstance(CLSID_Object, CLSCTX_INPROC_SERVER, NULL, IID_IUnknown, (void**)&pUnk);
pUnk->QueryInterface(IID_IOjbect, (void**)&pObject);
pUnk->Release();
pObject->Func();
pObject->Release();
CoUninitialize();
CoCreateInstance身上,让我们来看看它内部做了一些什么事情。以下是它内部实现的一个伪代码:
CoCreateInstance(....)
{
.......
IClassFactory *pClassFactory=NULL;
CoGetClassObject(CLSID_Object, CLSCTX_INPROC_SERVER, NULL, IID_IClassFactory, (void **)&pClassFactory);
pClassFactory->CreateInstance(NULL, IID_IUnknown, (void**)&pUnk);
pClassFactory->Release();
........
}
这段话的意思就是先得到类厂对象,再通过类厂创建组件从而得到IUnknown指针。
继续深入一步,看看CoGetClassObject的内部伪码:
CoGetClassObject(.....)
{
//通过查注册表CLSID_Object,得知组件DLL的位置、文件名
//装入DLL库
//使用函数GetProcAddress(...)得到DLL库中函数DllGetClassObject的函数指针。
//调用DllGetClassObject
}
DllGetClassObject是干什么的,它是用来获得类厂对象的。只有先得到类厂才能去创建组件.
下面是DllGetClassObject的伪码:
DllGetClassObject(...)
{
......
CFactory* pFactory= new CFactory; //类厂对象
pFactory->QueryInterface(IID_IClassFactory, (void**)&pClassFactory);
//查询IClassFactory指针
pFactory->Release();
......
}
CoGetClassObject的流程已经到此为止,现在返回CoCreateInstance,看看CreateInstance的伪码:
CFactory::CreateInstance(.....)
{
...........
CObject *pObject = new CObject; //组件对象
pObject->QueryInterface(IID_IUnknown, (void**)&pUnk);
pObject->Release();
...........
}
下图是从COM+技术内幕中COPY来的一个例图,从图中可以清楚的看到CoCreateInstance的整个流程。
接下来就写下完全的源代码,说明类厂的概念:
Component实现:(FacInterFace.dll)
#ifndef _IFACE_H
#define _IFACE_H
//interfaces
interface IX:IUnknown
{
virtual void __stdcall Fx() = 0;
};
interface IY: IUnknown
{
virtual void __stdcall Fy() = 0;
};
interface IZ: IUnknown
{
virtual void __stdcall Fz() = 0;
};
//Forward references for GUIDs
extern "C"
{
extern const IID IID_IX;
extern const IID IID_IY;
extern const IID IID_IZ;
extern const CLSID CLSID_Component1;
}
extern "C"
{
// {A33D4226-0F56-4e34-91F3-BF4F85761101}
static const IID IID_IX =
{ 0xa33d4226, 0xf56, 0x4e34, { 0x91, 0xf3, 0xbf, 0x4f, 0x85, 0x76, 0x11, 0x1 } };
// {41A5F090-B33A-4ae8-A1BB-EF2D0B4F8B0E}
static const IID IID_IY =
{ 0x41a5f090, 0xb33a, 0x4ae8, { 0xa1, 0xbb, 0xef, 0x2d, 0xb, 0x4f, 0x8b, 0xe } };
// {65411881-4E05-4b71-9CB5-943D5E0787C4}
static const IID IID_IZ =
{ 0x65411881, 0x4e05, 0x4b71, { 0x9c, 0xb5, 0x94, 0x3d, 0x5e, 0x7, 0x87, 0xc4 } };
}
//组件的CLSID,每个组件都有唯一的CLSID,需要把此CLSID添加到注册表中去.如何添加,见Register.cpp文件.
// {282D8F98-BC89-43d5-9225-0B1BB479CBDE}
static const CLSID CLSID_Component1 =
{ 0x282d8f98, 0xbc89, 0x43d5, { 0x92, 0x25, 0xb, 0x1b, 0xb4, 0x79, 0xcb, 0xde } };
#endif
组件的注册:
HRESULT RegisterServer(HMODULE hModule, const CLSID& clsid, const char* szFriendlyName, const char* szVerIndProgID, const char* szProgID);
HRESULT UnRegisterServer(const CLSID& clsid, const char* szVerIndProgID, const char* szProgID);
//In Register.cpp
//此文件是如何注册组件的代码实现,是把CLSID,ProgID,Version,Dll位置添加到
//HKEY_CLASSES_ROOT/CLSID,HKEY_CLASSES_ROOT的子键中去.
#include <objbase.h>
#include <assert.h>
#include "Register.h"
//set the given key and its value;
BOOL setKeyAndValue(const char* pszPath,
const char* szSubkey,
const char* szValue);
//Convert a CLSID into a char string
void CLSIDtochar(const CLSID& clsid,
char* szCLSID,
int length);
//Delete szKeyChild and all of its descendents
LONG recursiveDeleteKey(HKEY hKeyParent,const char* szKeyChild);
//size of a CLSID as a string
const int CLSID_STRING_SIZE = 39;
//Register the component in the registry
HRESULT RegisterServer(HMODULE hModule,
const CLSID& clsid,
const char* szFriendlyName,
const char* szVerIndProgID,
const char* szProgID)
{
//Get the Server location
char szModule[512];
DWORD dwResult = ::GetModuleFileName(hModule,szModule,sizeof(szModule)/sizeof(char));
assert(dwResult!=0);
//Convert the CLSID into a char
char szCLSID[CLSID_STRING_SIZE];
CLSIDtochar(clsid,szCLSID,sizeof(szCLSID));
//Build the key CLSID\\{}
char szKey[64];
strcpy(szKey,"CLSID\\");
strcat(szKey,szCLSID);
//Add the CLSID to the registry
setKeyAndValue(szKey,NULL,szFriendlyName);
//Add the Server filename subkey under the CLSID key
setKeyAndValue(szKey,"InprocServer32",szModule);
setKeyAndValue(szKey,"ProgID",szProgID);
setKeyAndValue(szKey,"VersionIndependentProgID",szVerIndProgID);
//Add the version-independent ProgID subkey under HKEY_CLASSES_ROOT
setKeyAndValue(szVerIndProgID,NULL,szFriendlyName);
setKeyAndValue(szVerIndProgID,"CLSID",szCLSID);
setKeyAndValue(szVerIndProgID,"CurVer",szProgID);
//Add the versioned ProgID subkey under HKEY_CLASSES_ROOT
setKeyAndValue(szProgID,NULL,szFriendlyName);
setKeyAndValue(szProgID,"CLSID",szCLSID);
return S_OK;
}
//
//Remove the component from the register
//
HRESULT UnRegisterServer(const CLSID& clsid, // Class ID
const char* szVerIndProgID, // Programmatic
const char* szProgID) // IDs
{
//Convert the CLSID into a char.
char szCLSID[CLSID_STRING_SIZE];
CLSIDtochar(clsid,szCLSID,sizeof(szCLSID));
//Build the key CLSID\\{}
char szKey[64];
strcpy(szKey,"CLSID\\");
strcat(szKey,szCLSID);
//Delete the CLSID key - CLSID\{}
LONG lResult = recursiveDeleteKey(HKEY_CLASSES_ROOT,szKey);
assert((lResult == ERROR_SUCCESS) || (lResult == ERROR_FILE_NOT_FOUND));
//Delete the version-independent ProgID Key
lResult = recursiveDeleteKey(HKEY_CLASSES_ROOT,szVerIndProgID);
assert((lResult == ERROR_SUCCESS) || (lResult == ERROR_FILE_NOT_FOUND));
//Delete the ProgID key.
lResult = recursiveDeleteKey(HKEY_CLASSES_ROOT,szProgID);
assert((lResult == ERROR_SUCCESS) || (lResult == ERROR_FILE_NOT_FOUND));
return S_OK;
}
//Convert a CLSID to a char string
void CLSIDtochar(const CLSID& clsid,
char* szCLSID,
int length)
{
assert(length>=CLSID_STRING_SIZE);
//Get CLSID
LPOLESTR wszCLSID = NULL;
HRESULT hr = StringFromCLSID(clsid,&wszCLSID);
assert(SUCCEEDED(hr));
//Convert from wide characters to non_wide
wcstombs(szCLSID,wszCLSID,length);
//Free memory
CoTaskMemFree(wszCLSID);
}
//
// Delete a Key and all of its descendents
//
LONG recursiveDeleteKey(HKEY hKeyParent,const char* lpszKeyChild)
{
//Open the child.
HKEY hKeyChild;
LONG lRes = RegOpenKeyEx(hKeyParent,lpszKeyChild,0,KEY_ALL_ACCESS,&hKeyChild);
if(lRes != ERROR_SUCCESS)
return lRes;
//Enumerate all of the decendents of this child
FILETIME time;
char szBuffer[256];
DWORD dwSize = 256 ;
while(RegEnumKeyEx(hKeyChild,0,szBuffer,&dwSize,NULL,
NULL,NULL,&time) == S_OK)
{
//Delete the decendents of this child.
lRes = recursiveDeleteKey(hKeyChild,szBuffer);
if(lRes != ERROR_SUCCESS)
{
RegCloseKey(hKeyChild);
return lRes;
}
dwSize = 256;
}
RegCloseKey(hKeyChild);
return RegDeleteKey(hKeyParent,lpszKeyChild);
}
BOOL setKeyAndValue(const char* szKey,
const char* szSubkey,
const char* szValue)
{
HKEY hKey;
char szKeyBuf[1024];
//Copy keyname into buffer.
strcpy(szKeyBuf,szKey);
//Add subkey name to buffer.
if(szSubkey!=NULL)
{
strcat(szKeyBuf,"\\");
strcat(szKeyBuf,szSubkey);
}
// Create and open key and subkey.
long lResult = RegCreateKeyEx(HKEY_CLASSES_ROOT ,
szKeyBuf,
0, NULL, REG_OPTION_NON_VOLATILE,
KEY_ALL_ACCESS, NULL,
&hKey, NULL) ;
if (lResult != ERROR_SUCCESS)
{
return FALSE ;
}
// Set the Value.
if (szValue != NULL)
{
RegSetValueEx(hKey, NULL, 0, REG_SZ,
(BYTE *)szValue,
strlen(szValue)+1) ;
}
RegCloseKey(hKey) ;
return TRUE ;
}
组件的实现:
//此文件是组件CA,组件类厂CFactory的实现,CA的实现与前面讲述的是一样的,关键在于多引入了
//一个CFactory,还有一个是全局函数DllGetClassObject,另外,除了要导出DllGetClassObject之
//外,还要导出三个函数,分别是DllCanUnloadNow / DllRegisterServer / DllUnregisterServer.
//还有一项工作就是在DllMain中保存模块的信息.
#include <iostream.h>
#include <objbase.h>
#include "..\MYIF2\IFACE.h"
#include "Register.h"
//#ifndef EXPORTAPI
//#define EXPORTAPI extern "C" __declspec(dllexport)
//#endif
void trace(const char* msg){cout<<msg<<endl;}
//Gobal variables
static HMODULE g_hModule = NULL ;
static long g_cComponents = 0; //Count of active components
static long g_cServerLocks = 0; //Count of locks
//Friendly name of component
const char g_szFriendlyName[] = "Inside COM.Chapter 7 Example";
//Version-independent ProgID
const char g_szVerIndProgID[] = "InsideCOM.Chap07";
//ProgID
const char g_szProgID[] = "InsideCOM.Chap07.1";
//Component
class CA:public IX,public IY
{
public:
//IUnknown
virtual HRESULT __stdcall QueryInterface(const IID& iid,void** ppv);
virtual ULONG __stdcall AddRef();
virtual ULONG __stdcall Release();
//Interface IX
virtual void __stdcall Fx(){cout<<"Fx function"<<endl;}
//Interface IY
virtual void __stdcall Fy(){cout<<"Fy function"<<endl;}
//Constructor
CA();
//Destructor
~CA();
private:
long m_cRef;
};
CA::CA():m_cRef(1)
{
InterlockedIncrement(&g_cComponents);
}
CA::~CA()
{
InterlockedDecrement(&g_cComponents);
trace("Component:\t\tDestory self");
}
//IUnknown implementation
HRESULT __stdcall CA::QueryInterface(const IID& iid,void** ppv)
{
if(iid == IID_IUnknown)
{
*ppv = static_cast<IX*>(this);
}
else if(iid == IID_IX)
{
*ppv = static_cast<IX*>(this);
trace("Component:\tReturn pointer to IX.");
}
else if(iid == IID_IY)
{
*ppv = static_cast<IY*>(this);
trace("Component:\tReturn pointer to IY.");
}
else
{
*ppv = NULL;
trace("Component:\tCannot Get pointer to IX/IY");
return E_NOINTERFACE;
}
reinterpret_cast<IUnknown*>(*ppv)->AddRef();
return S_OK;
}
ULONG __stdcall CA::AddRef()
{
return InterlockedIncrement(&m_cRef);
}
ULONG __stdcall CA::Release()
{
if(InterlockedDecrement(&m_cRef)==0)
{
delete this;
return 0;
}
return m_cRef;
}
///////////////////////////////////////////
//class factory
///////////////////////////////////////////
class CFactory:public IClassFactory
{
public:
//IUnknown
virtual HRESULT __stdcall QueryInterface(const IID& iid,void** ppv);
virtual ULONG __stdcall AddRef();
virtual ULONG __stdcall Release();
//Interface IClassFactory
virtual HRESULT __stdcall CreateInstance(IUnknown* pUnknownOuter,
const IID& iid,
void** ppv);
virtual HRESULT __stdcall LockServer(BOOL bLock);
//Constructor
CFactory():m_cRef(1){}
//Destructor
~CFactory() {trace("Class factory:\t\tDestory self.");}
private:
long m_cRef;
};
HRESULT __stdcall CFactory::QueryInterface(const IID& iid,void** ppv)
{
if((iid == IID_IUnknown) || (iid == IID_IClassFactory))
{
*ppv= static_cast<IClassFactory*>(this);
}
else
{
*ppv = NULL;
return E_NOINTERFACE;
}
reinterpret_cast<IUnknown*>(*ppv)->AddRef();
return S_OK;
}
ULONG __stdcall CFactory::AddRef()
{
return InterlockedIncrement(&m_cRef);
}
ULONG __stdcall CFactory::Release()
{
if(InterlockedDecrement(&m_cRef)==0)
{
delete this;
return 0;
}
else
return m_cRef;
}
HRESULT __stdcall CFactory::CreateInstance(IUnknown* pUnknownOuter,const IID& iid,void** ppv)
{
trace("Class factory:\t\tCreate component.");
// Cannot aggregate.
if (pUnknownOuter != NULL)
{
return CLASS_E_NOAGGREGATION ;
}
//if(pUnknownOuter!=NULL)
// return CLASS_E_NOA
CA* pA = new CA;
if(pA == NULL)
return E_OUTOFMEMORY;
//Get the request interface
HRESULT hr = pA->QueryInterface(iid,ppv);
pA->Release();
return hr;
}
//LockServer
HRESULT __stdcall CFactory::LockServer(BOOL bLock)
{
if(bLock)
{
InterlockedIncrement(&g_cServerLocks);
}
else
InterlockedDecrement(&g_cServerLocks);
return S_OK;
}
//Can Dll unload now?
int AddNum(int a,int b)
{
return a+b;
}
STDAPI DllCanUnloadNow()
{
if((g_cComponents ==0 ) && (g_cServerLocks==0))
return S_OK;
else
return S_FALSE;
}
STDAPI DllGetClassObject(const CLSID& clsid,
const IID& iid,
void** ppv)
{
trace("DllGetClassObject:\tCreate Class factory");
if(clsid != CLSID_Component1)
{
return CLASS_E_CLASSNOTAVAILABLE;
}
CFactory* pFactory = new CFactory();
if(pFactory == NULL)
return E_OUTOFMEMORY;
//Get request interfaces
HRESULT hr = pFactory->QueryInterface(iid,ppv);
pFactory->Release();
return hr;
}
//Server registration
STDAPI DllRegisterServer()
{
return RegisterServer(g_hModule,CLSID_Component1,g_szFriendlyName,
g_szVerIndProgID,g_szProgID);
}
//Server unregistration
STDAPI DllUnregisterServer()
{
return UnRegisterServer(CLSID_Component1,g_szVerIndProgID,g_szProgID);
}
BOOL APIENTRY DllMain(HANDLE hModule,
DWORD dwReason,
void* lpReserved)
{
if (dwReason == DLL_PROCESS_ATTACH)
{
g_hModule = (HMODULE)hModule ;
}
return TRUE ;
}
以上是组件的实现。下面是客户端的代码实现:
int main()
{
HRESULT hr;
::CoInitialize(NULL);
trace("Call CoCreateInstance to Create");
trace(" componet and get interface IX");
IX* pIX = NULL;
hr = ::CoCreateInstance(CLSID_Component1,
NULL,
CLSCTX_INPROC_SERVER,
IID_IX,
(void**)&pIX);
if(SUCCEEDED(hr))
{
trace("Succeeded getting IX");
pIX->Fx();
trace("Ask for Interface IY");
IY* pIY = NULL;
hr = pIX->QueryInterface(IID_IY,(void**)&pIY);
if(SUCCEEDED(hr))
{
trace("Succeeded getting IY");
pIY->Fy();
pIY->Release();
trace("Release IY interface");
}
else
{
trace("Could not get interface IY.");
}
pIX->Release();
}
else
cout<<"Client: \t\tCould not create component hr="<<hex<<hr<<endl;
CoUninitialize();
}
再在最后详述一篇,客户端调用CoCreateInstance,导致调用CoGetClassObject,CoGetClassObject通过查找注册表,得知DLL位置,文件名,然后调用DLL中DllGetClassObject,
DllGetClassObject的功能是返回CFactory的实例.
返回后,回到CoCreateInstance,通过CFactory的指针,调用
pClassFactory->CreaetInstance()创建组件实例.
这样就返回了组件实例的指针.
CoCreateInstace --> CoGetClassObject --> DllGetClassObject --> Get CFactory*
<-------------------------------------------------------
--> CFactory->CreateInstance(); --> Get IX*
IX->Fx();