diff --git a/include/WinStd/COM.h b/include/WinStd/COM.h index aa4fd601..271248b1 100644 --- a/include/WinStd/COM.h +++ b/include/WinStd/COM.h @@ -225,9 +225,10 @@ namespace winstd template bstr(_In_ const std::basic_string &src) { - if (src.length() >= UINT_MAX) - throw std::invalid_argument("String too long"); - m_h = SysAllocStringLen(src.c_str(), (UINT)src.length()); + size_t len = src.length(); + if (len > UINT_MAX) + throw std::invalid_argument("string too long"); + m_h = SysAllocStringLen(src.c_str(), static_cast(len)); if (!m_h) throw std::bad_alloc(); } diff --git a/include/WinStd/Common.h b/include/WinStd/Common.h index 8532c085..a5c855fc 100644 --- a/include/WinStd/Common.h +++ b/include/WinStd/Common.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -203,6 +204,42 @@ typedef const BYTE *LPCBYTE; #pragma warning(disable: 4996) #pragma warning(disable: 4505) // Don't warn on unused code +#ifdef _WIN64 +inline ULONGLONG ULongLongMult(ULONGLONG a, ULONGLONG b) +{ + ULONGLONG result; + if (SUCCEEDED(ULongLongMult(a, b, &result))) + return result; + throw std::invalid_argument("multiply overflow"); +} +#else +inline SIZE_T SIZETMult(SIZE_T a, SIZE_T b) +{ + SIZE_T result; + if (SUCCEEDED(SIZETMult(a, b, &result))) + return result; + throw std::invalid_argument("multiply overflow"); +} +#endif + +#ifdef _WIN64 +inline ULONGLONG ULongLongAdd(ULONGLONG a, ULONGLONG b) +{ + ULONGLONG result; + if (SUCCEEDED(ULongLongAdd(a, b, &result))) + return result; + throw std::invalid_argument("add overflow"); +} +#else +inline SIZE_T SIZETAdd(SIZE_T a, SIZE_T b) +{ + SIZE_T result; + if (SUCCEEDED(SIZETAdd(a, b, &result))) + return result; + throw std::invalid_argument("add overflow"); +} +#endif + /// \addtogroup WinStdStrFormat /// @{ diff --git a/include/WinStd/Crypt.h b/include/WinStd/Crypt.h index d92d41f3..effaa9bb 100644 --- a/include/WinStd/Crypt.h +++ b/include/WinStd/Crypt.h @@ -176,9 +176,14 @@ static _Success_(return != 0) BOOL CryptExportKey(_In_ HCRYPTKEY hKey, _In_ HCRY template static _Success_(return != 0) BOOL CryptEncrypt(_In_ HCRYPTKEY hKey, _In_opt_ HCRYPTHASH hHash, _In_ BOOL Final, _In_ DWORD dwFlags, _Inout_ std::vector<_Ty, _Ax> &aData) { + SIZE_T + sDataLen = SIZETMult(aData.size(), sizeof(_Ty)), + sBufLen = SIZETMult(aData.capacity(), sizeof(_Ty)); + if (sDataLen > DWORD_MAX || sBufLen > DWORD_MAX) + throw std::invalid_argument("Data too big"); DWORD - dwDataLen = (DWORD)(aData.size() * sizeof(_Ty)), - dwBufLen = (DWORD)(aData.capacity() * sizeof(_Ty)), + dwDataLen = static_cast(sDataLen), + dwBufLen = static_cast(sBufLen), dwEncLen = dwDataLen, dwResult; @@ -226,7 +231,10 @@ static _Success_(return != 0) BOOL CryptEncrypt(_In_ HCRYPTKEY hKey, _In_opt_ HC template static _Success_(return != 0) BOOL CryptDecrypt(_In_ HCRYPTKEY hKey, _In_opt_ HCRYPTHASH hHash, _In_ BOOL Final, _In_ DWORD dwFlags, _Inout_ std::vector<_Ty, _Ax> &aData) { - DWORD dwDataLen = (DWORD)(aData.size() * sizeof(_Ty)); + SIZE_T sDataLen = SIZETMult(aData.size(), sizeof(_Ty)); + if (sDataLen > DWORD_MAX) + throw std::invalid_argument("Data too big"); + DWORD dwDataLen = static_cast(sDataLen); if (CryptDecrypt(hKey, hHash, Final, dwFlags, reinterpret_cast(aData.data()), &dwDataLen)) { // Decryption succeeded. diff --git a/include/WinStd/WLAN.h b/include/WinStd/WLAN.h index e9ae90da..10a5fac0 100644 --- a/include/WinStd/WLAN.h +++ b/include/WinStd/WLAN.h @@ -33,23 +33,25 @@ extern DWORD (WINAPI *pfnWlanReasonCodeToString)(__in DWORD dwReasonCode, __in D template static DWORD WlanReasonCodeToString(_In_ DWORD dwReasonCode, _Inout_ std::basic_string &sValue, __reserved PVOID pReserved) { - DWORD dwSize = 0; + SIZE_T sSize = 0; if (!::pfnWlanReasonCodeToString) return ERROR_CALL_NOT_IMPLEMENTED; for (;;) { // Increment size and allocate buffer. - dwSize += 1024; - std::unique_ptr szBuffer(new wchar_t[dwSize]); + sSize = SIZETAdd(sSize, 1024); + if (sSize > DWORD_MAX) + throw std::runtime_exception("Data too big"); + std::unique_ptr szBuffer(new wchar_t[sSize]); // Try! - DWORD dwResult = ::pfnWlanReasonCodeToString(dwReasonCode, dwSize, szBuffer.get(), pReserved); + DWORD dwResult = ::pfnWlanReasonCodeToString(dwReasonCode, static_cast(sSize), szBuffer.get(), pReserved); if (dwResult == ERROR_SUCCESS) { - DWORD dwLength = (DWORD)wcsnlen(szBuffer.get(), dwSize); - if (dwLength < dwSize - 1) { + SIZE_T sLength = wcsnlen(szBuffer.get(), sSize); + if (sLength < sSize - 1) { // Buffer was long enough. - sValue.assign(szBuffer.get(), dwLength); + sValue.assign(szBuffer.get(), sLength); return ERROR_SUCCESS; } } else { diff --git a/include/WinStd/Win.h b/include/WinStd/Win.h index 677ae684..a106c17d 100644 --- a/include/WinStd/Win.h +++ b/include/WinStd/Win.h @@ -177,21 +177,23 @@ static _Success_(return != 0) BOOL GetFileVersionInfoW(_In_z_ LPCWSTR lptstrFile /// @copydoc ExpandEnvironmentStringsW() template -static _Success_(return != 0) DWORD ExpandEnvironmentStringsA(_In_z_ LPCSTR lpSrc, _Out_ std::basic_string &sValue) noexcept +static _Success_(return != 0) DWORD ExpandEnvironmentStringsA(_In_z_ LPCSTR lpSrc, _Out_ std::basic_string &sValue) { assert(0); // TODO: Test this code. - for (DWORD dwSizeOut = (DWORD)strlen(lpSrc) + 0x100;;) { - DWORD dwSizeIn = dwSizeOut; + for (SIZE_T sSizeOut = SIZETAdd(strlen(lpSrc), 0x100);;) { + if (sSizeOut > DWORD_MAX) + throw std::invalid_argument("String too big"); + DWORD dwSizeIn = static_cast(sSizeOut); std::unique_ptr szBuffer(new char[(size_t)dwSizeIn + 2]); // Note: ANSI version requires one extra char. - dwSizeOut = ::ExpandEnvironmentStringsA(lpSrc, szBuffer.get(), dwSizeIn); - if (dwSizeOut == 0) { + sSizeOut = ::ExpandEnvironmentStringsA(lpSrc, szBuffer.get(), dwSizeIn); + if (sSizeOut == 0) { // Error or zero-length input. break; - } else if (dwSizeOut <= dwSizeIn) { + } else if (sSizeOut <= dwSizeIn) { // The buffer was sufficient. - sValue.assign(szBuffer.get(), dwSizeOut - 1); - return dwSizeOut; + sValue.assign(szBuffer.get(), sSizeOut - 1); + return static_cast(sSizeOut); } } @@ -205,19 +207,21 @@ static _Success_(return != 0) DWORD ExpandEnvironmentStringsA(_In_z_ LPCSTR lpSr /// \sa [ExpandEnvironmentStrings function](https://msdn.microsoft.com/en-us/library/windows/desktop/ms724265.aspx) /// template -static _Success_(return != 0) DWORD ExpandEnvironmentStringsW(_In_z_ LPCWSTR lpSrc, _Out_ std::basic_string &sValue) noexcept +static _Success_(return != 0) DWORD ExpandEnvironmentStringsW(_In_z_ LPCWSTR lpSrc, _Out_ std::basic_string &sValue) { - for (DWORD dwSizeOut = (DWORD)wcslen(lpSrc) + 0x100;;) { - DWORD dwSizeIn = dwSizeOut; + for (SIZE_T sSizeOut = SIZETAdd(wcslen(lpSrc), 0x100);;) { + if (sSizeOut > DWORD_MAX) + throw std::invalid_argument("String too big"); + DWORD dwSizeIn = static_cast(sSizeOut); std::unique_ptr szBuffer(new wchar_t[(size_t)dwSizeIn + 1]); - dwSizeOut = ::ExpandEnvironmentStringsW(lpSrc, szBuffer.get(), dwSizeIn); - if (dwSizeOut == 0) { + sSizeOut = ::ExpandEnvironmentStringsW(lpSrc, szBuffer.get(), dwSizeIn); + if (sSizeOut == 0) { // Error or zero-length input. break; - } else if (dwSizeOut <= dwSizeIn) { + } else if (sSizeOut <= dwSizeIn) { // The buffer was sufficient. - sValue.assign(szBuffer.get(), dwSizeOut - 1); - return dwSizeOut; + sValue.assign(szBuffer.get(), sSizeOut - 1); + return static_cast(sSizeOut); } }