diff --git a/atlcrypt.h b/atlcrypt.h index d6fba7e..c3b28b6 100644 --- a/atlcrypt.h +++ b/atlcrypt.h @@ -49,3 +49,115 @@ inline DWORD CertGetNameStringW(PCCERT_CONTEXT pCertContext, DWORD dwType, DWORD sNameString.ReleaseBuffer(dwSize); return dwSize; } + + +namespace ATL +{ + namespace Crypt + { + + // + // CCertContext + // + class CCertContext + { + public: + inline CCertContext() throw() : m_pCertContext(NULL) + { + } + + inline CCertContext(PCCERT_CONTEXT p) throw() : m_pCertContext(p) + { + } + + inline ~CCertContext() throw() + { + if (m_pCertContext) + CertFreeCertificateContext(m_pCertContext); + } + + inline operator PCCERT_CONTEXT() const throw() + { + return m_pCertContext; + } + + inline const CERT_CONTEXT& operator*() const + { + ATLENSURE(m_pCertContext != NULL); + return *m_pCertContext; + } + + inline PCCERT_CONTEXT* operator&() throw() + { + ATLASSERT(m_pCertContext == NULL); + return &m_pCertContext; + } + + inline PCCERT_CONTEXT operator->() const throw() + { + ATLASSERT(m_pCertContext != NULL); + return m_pCertContext; + } + + inline bool operator!() const throw() + { + return m_pCertContext == NULL; + } + + inline bool operator<(_In_opt_ PCCERT_CONTEXT p) const throw() + { + return m_pCertContext < p; + } + + inline bool operator!=(_In_opt_ PCCERT_CONTEXT p) const + { + return !operator==(p); + } + + inline bool operator==(_In_opt_ PCCERT_CONTEXT p) const throw() + { + return m_pCertContext == p; + } + + inline void Attach(_In_opt_ PCCERT_CONTEXT p) throw() + { + if (m_pCertContext) + CertFreeCertificateContext(m_pCertContext); + m_pCertContext = p; + } + + inline PCCERT_CONTEXT Detach() throw() + { + PCCERT_CONTEXT p = m_pCertContext; + m_pCertContext = NULL; + return p; + } + + inline BOOL Create(_In_ DWORD dwCertEncodingType, _In_ const BYTE *pbCertEncoded, _In_ DWORD cbCertEncoded) throw() + { + PCCERT_CONTEXT p; + + p = CertCreateCertificateContext(dwCertEncodingType, pbCertEncoded, cbCertEncoded); + if (!p) return FALSE; + + if (m_pCertContext) + CertFreeCertificateContext(m_pCertContext); + m_pCertContext = p; + return TRUE; + } + + inline BOOL Free() throw() + { + if (m_pCertContext) { + BOOL bResult = CertFreeCertificateContext(m_pCertContext); + m_pCertContext = NULL; + return bResult; + } else + return TRUE; + } + + protected: + PCCERT_CONTEXT m_pCertContext; + }; + } +} diff --git a/atlwin.h b/atlwin.h index 645b527..1a5261f 100644 --- a/atlwin.h +++ b/atlwin.h @@ -227,79 +227,97 @@ inline DWORD ExpandEnvironmentStringsW(__in LPCWSTR lpSrc, ATL::CAtlStringW &sVa -inline BOOL RegQueryStringValue(_In_ HKEY hReg, _In_z_ LPCSTR pszName, _Out_ ATL::CAtlStringA &sValue) +inline LSTATUS RegQueryStringValue(_In_ HKEY hReg, _In_z_ LPCSTR pszName, _Out_ ATL::CAtlStringA &sValue) { - DWORD dwSize = 0; - DWORD dwType; + LSTATUS lResult; + DWORD dwSize = 0, dwType; // Determine the type and size first. - if (::RegQueryValueExA(hReg, pszName, NULL, &dwType, NULL, &dwSize) == ERROR_SUCCESS) { + if ((lResult = ::RegQueryValueExA(hReg, pszName, NULL, &dwType, NULL, &dwSize)) == ERROR_SUCCESS) { if (dwType == REG_SZ || dwType == REG_MULTI_SZ) { // The value is REG_SZ or REG_MULTI_SZ. Read it now. LPSTR szTemp = sValue.GetBuffer(dwSize / sizeof(TCHAR)); - if (!szTemp) { - ::SetLastError(ERROR_OUTOFMEMORY); - return FALSE; - } - if (::RegQueryValueExA(hReg, pszName, NULL, NULL, (LPBYTE)szTemp, &dwSize) == ERROR_SUCCESS) { + if (!szTemp) return ERROR_OUTOFMEMORY; + if ((lResult = ::RegQueryValueExA(hReg, pszName, NULL, NULL, (LPBYTE)szTemp, &dwSize)) == ERROR_SUCCESS) { sValue.ReleaseBuffer(); - return TRUE; } else { // Reading of the value failed. sValue.ReleaseBuffer(0); - return FALSE; } } else if (dwType == REG_EXPAND_SZ) { // The value is REG_EXPAND_SZ. Read it and expand environment variables. ATL::CTempBuffer sTemp(dwSize / sizeof(CHAR)); - return - ::RegQueryValueExA(hReg, pszName, NULL, NULL, (LPBYTE)(CHAR*)sTemp, &dwSize) == ERROR_SUCCESS && - ::ExpandEnvironmentStringsA((const CHAR*)sTemp, sValue) != 0; + if ((lResult = ::RegQueryValueExA(hReg, pszName, NULL, NULL, (LPBYTE)(CHAR*)sTemp, &dwSize)) == ERROR_SUCCESS) + if (::ExpandEnvironmentStringsA((const CHAR*)sTemp, sValue) == 0) + lResult = ::GetLastError(); } else { // The value is not a string type. - return FALSE; + lResult = ERROR_INVALID_DATA; } - } else { - // The value with given name doesn't exist in this key. - return FALSE; } + + return lResult; } -inline BOOL RegQueryStringValue(_In_ HKEY hReg, _In_z_ LPCWSTR pszName, _Out_ ATL::CAtlStringW &sValue) +inline LSTATUS RegQueryStringValue(_In_ HKEY hReg, _In_z_ LPCWSTR pszName, _Out_ ATL::CAtlStringW &sValue) { - DWORD dwSize = 0; - DWORD dwType; + LSTATUS lResult; + DWORD dwSize = 0, dwType; // Determine the type and size first. - if (::RegQueryValueExW(hReg, pszName, NULL, &dwType, NULL, &dwSize) == ERROR_SUCCESS) { + if ((lResult = ::RegQueryValueExW(hReg, pszName, NULL, &dwType, NULL, &dwSize)) == ERROR_SUCCESS) { if (dwType == REG_SZ || dwType == REG_MULTI_SZ) { // The value is REG_SZ or REG_MULTI_SZ. Read it now. LPWSTR szTemp = sValue.GetBuffer(dwSize / sizeof(TCHAR)); - if (!szTemp) { - ::SetLastError(ERROR_OUTOFMEMORY); - return FALSE; - } - if (::RegQueryValueExW(hReg, pszName, NULL, NULL, (LPBYTE)szTemp, &dwSize) == ERROR_SUCCESS) { + if (!szTemp) return ERROR_OUTOFMEMORY; + if ((lResult = ::RegQueryValueExW(hReg, pszName, NULL, NULL, (LPBYTE)szTemp, &dwSize)) == ERROR_SUCCESS) { sValue.ReleaseBuffer(); - return TRUE; } else { // Reading of the value failed. sValue.ReleaseBuffer(0); - return FALSE; } } else if (dwType == REG_EXPAND_SZ) { // The value is REG_EXPAND_SZ. Read it and expand environment variables. ATL::CTempBuffer sTemp(dwSize / sizeof(WCHAR)); - return - ::RegQueryValueExW(hReg, pszName, NULL, NULL, (LPBYTE)(WCHAR*)sTemp, &dwSize) == ERROR_SUCCESS && - ::ExpandEnvironmentStringsW((const WCHAR*)sTemp, sValue) != 0; + if ((lResult = ::RegQueryValueExW(hReg, pszName, NULL, NULL, (LPBYTE)(WCHAR*)sTemp, &dwSize)) == ERROR_SUCCESS) + if (::ExpandEnvironmentStringsW((const WCHAR*)sTemp, sValue) == 0) + lResult = ::GetLastError(); } else { // The value is not a string type. - return FALSE; + lResult = ERROR_INVALID_DATA; } - } else { - // The value with given name doesn't exist in this key. - return FALSE; } + + return lResult; +} + + +inline LSTATUS RegQueryValueExA(__in HKEY hKey, __in_opt LPCSTR lpValueName, __reserved LPDWORD lpReserved, __out_opt LPDWORD lpType, __out ATL::CAtlArray &aData) +{ + LSTATUS lResult; + DWORD dwDataSize; + + if ((lResult = RegQueryValueExA(hKey, lpValueName, lpReserved, NULL, NULL, &dwDataSize)) == ERROR_SUCCESS) { + if (!aData.SetCount(dwDataSize)) return ERROR_OUTOFMEMORY; + if ((lResult = RegQueryValueExA(hKey, lpValueName, lpReserved, lpType, aData.GetData(), &dwDataSize)) != ERROR_SUCCESS) + aData.SetCount(0); + } + + return lResult; +} + + +inline LSTATUS RegQueryValueExW(__in HKEY hKey, __in_opt LPCWSTR lpValueName, __reserved LPDWORD lpReserved, __out_opt LPDWORD lpType, __out ATL::CAtlArray &aData) +{ + LSTATUS lResult; + DWORD dwDataSize; + + if ((lResult = RegQueryValueExW(hKey, lpValueName, lpReserved, NULL, NULL, &dwDataSize)) == ERROR_SUCCESS) { + if (!aData.SetCount(dwDataSize)) return ERROR_OUTOFMEMORY; + if ((lResult = RegQueryValueExW(hKey, lpValueName, lpReserved, lpType, aData.GetData(), &dwDataSize)) != ERROR_SUCCESS) + aData.SetCount(0); + } + + return lResult; }