Throw on 4G overflows

Signed-off-by: Simon Rozman <simon@rozman.si>
This commit is contained in:
Simon Rozman 2023-08-29 15:47:09 +02:00
parent ff8136f708
commit f9090e02f4
5 changed files with 81 additions and 29 deletions

View File

@ -225,9 +225,10 @@ namespace winstd
template<class _Traits, class _Ax> template<class _Traits, class _Ax>
bstr(_In_ const std::basic_string<OLECHAR, _Traits, _Ax> &src) bstr(_In_ const std::basic_string<OLECHAR, _Traits, _Ax> &src)
{ {
if (src.length() >= UINT_MAX) size_t len = src.length();
throw std::invalid_argument("String too long"); if (len > UINT_MAX)
m_h = SysAllocStringLen(src.c_str(), (UINT)src.length()); throw std::invalid_argument("string too long");
m_h = SysAllocStringLen(src.c_str(), static_cast<UINT>(len));
if (!m_h) if (!m_h)
throw std::bad_alloc(); throw std::bad_alloc();
} }

View File

@ -8,6 +8,7 @@
#include <Windows.h> #include <Windows.h>
#include <assert.h> #include <assert.h>
#include <intsafe.h>
#include <stdarg.h> #include <stdarg.h>
#include <tchar.h> #include <tchar.h>
#include <iostream> #include <iostream>
@ -203,6 +204,42 @@ typedef const BYTE *LPCBYTE;
#pragma warning(disable: 4996) #pragma warning(disable: 4996)
#pragma warning(disable: 4505) // Don't warn on unused code #pragma warning(disable: 4505) // Don't warn on unused code
#ifdef _WIN64
inline ULONGLONG ULongLongMult(ULONGLONG a, ULONGLONG b)
{
ULONGLONG result;
if (SUCCEEDED(ULongLongMult(a, b, &result)))
return result;
throw std::invalid_argument("multiply overflow");
}
#else
inline SIZE_T SIZETMult(SIZE_T a, SIZE_T b)
{
SIZE_T result;
if (SUCCEEDED(SIZETMult(a, b, &result)))
return result;
throw std::invalid_argument("multiply overflow");
}
#endif
#ifdef _WIN64
inline ULONGLONG ULongLongAdd(ULONGLONG a, ULONGLONG b)
{
ULONGLONG result;
if (SUCCEEDED(ULongLongAdd(a, b, &result)))
return result;
throw std::invalid_argument("add overflow");
}
#else
inline SIZE_T SIZETAdd(SIZE_T a, SIZE_T b)
{
SIZE_T result;
if (SUCCEEDED(SIZETAdd(a, b, &result)))
return result;
throw std::invalid_argument("add overflow");
}
#endif
/// \addtogroup WinStdStrFormat /// \addtogroup WinStdStrFormat
/// @{ /// @{

View File

@ -176,9 +176,14 @@ static _Success_(return != 0) BOOL CryptExportKey(_In_ HCRYPTKEY hKey, _In_ HCRY
template<class _Ty, class _Ax> template<class _Ty, class _Ax>
static _Success_(return != 0) BOOL CryptEncrypt(_In_ HCRYPTKEY hKey, _In_opt_ HCRYPTHASH hHash, _In_ BOOL Final, _In_ DWORD dwFlags, _Inout_ std::vector<_Ty, _Ax> &aData) static _Success_(return != 0) BOOL CryptEncrypt(_In_ HCRYPTKEY hKey, _In_opt_ HCRYPTHASH hHash, _In_ BOOL Final, _In_ DWORD dwFlags, _Inout_ std::vector<_Ty, _Ax> &aData)
{ {
SIZE_T
sDataLen = SIZETMult(aData.size(), sizeof(_Ty)),
sBufLen = SIZETMult(aData.capacity(), sizeof(_Ty));
if (sDataLen > DWORD_MAX || sBufLen > DWORD_MAX)
throw std::invalid_argument("Data too big");
DWORD DWORD
dwDataLen = (DWORD)(aData.size() * sizeof(_Ty)), dwDataLen = static_cast<DWORD>(sDataLen),
dwBufLen = (DWORD)(aData.capacity() * sizeof(_Ty)), dwBufLen = static_cast<DWORD>(sBufLen),
dwEncLen = dwDataLen, dwEncLen = dwDataLen,
dwResult; dwResult;
@ -226,7 +231,10 @@ static _Success_(return != 0) BOOL CryptEncrypt(_In_ HCRYPTKEY hKey, _In_opt_ HC
template<class _Ty, class _Ax> template<class _Ty, class _Ax>
static _Success_(return != 0) BOOL CryptDecrypt(_In_ HCRYPTKEY hKey, _In_opt_ HCRYPTHASH hHash, _In_ BOOL Final, _In_ DWORD dwFlags, _Inout_ std::vector<_Ty, _Ax> &aData) static _Success_(return != 0) BOOL CryptDecrypt(_In_ HCRYPTKEY hKey, _In_opt_ HCRYPTHASH hHash, _In_ BOOL Final, _In_ DWORD dwFlags, _Inout_ std::vector<_Ty, _Ax> &aData)
{ {
DWORD dwDataLen = (DWORD)(aData.size() * sizeof(_Ty)); SIZE_T sDataLen = SIZETMult(aData.size(), sizeof(_Ty));
if (sDataLen > DWORD_MAX)
throw std::invalid_argument("Data too big");
DWORD dwDataLen = static_cast<DWORD>(sDataLen);
if (CryptDecrypt(hKey, hHash, Final, dwFlags, reinterpret_cast<BYTE*>(aData.data()), &dwDataLen)) { if (CryptDecrypt(hKey, hHash, Final, dwFlags, reinterpret_cast<BYTE*>(aData.data()), &dwDataLen)) {
// Decryption succeeded. // Decryption succeeded.

View File

@ -33,23 +33,25 @@ extern DWORD (WINAPI *pfnWlanReasonCodeToString)(__in DWORD dwReasonCode, __in D
template<class _Traits, class _Ax> template<class _Traits, class _Ax>
static DWORD WlanReasonCodeToString(_In_ DWORD dwReasonCode, _Inout_ std::basic_string<wchar_t, _Traits, _Ax> &sValue, __reserved PVOID pReserved) static DWORD WlanReasonCodeToString(_In_ DWORD dwReasonCode, _Inout_ std::basic_string<wchar_t, _Traits, _Ax> &sValue, __reserved PVOID pReserved)
{ {
DWORD dwSize = 0; SIZE_T sSize = 0;
if (!::pfnWlanReasonCodeToString) if (!::pfnWlanReasonCodeToString)
return ERROR_CALL_NOT_IMPLEMENTED; return ERROR_CALL_NOT_IMPLEMENTED;
for (;;) { for (;;) {
// Increment size and allocate buffer. // Increment size and allocate buffer.
dwSize += 1024; sSize = SIZETAdd(sSize, 1024);
std::unique_ptr<wchar_t[]> szBuffer(new wchar_t[dwSize]); if (sSize > DWORD_MAX)
throw std::runtime_exception("Data too big");
std::unique_ptr<wchar_t[]> szBuffer(new wchar_t[sSize]);
// Try! // Try!
DWORD dwResult = ::pfnWlanReasonCodeToString(dwReasonCode, dwSize, szBuffer.get(), pReserved); DWORD dwResult = ::pfnWlanReasonCodeToString(dwReasonCode, static_cast<DWORD>(sSize), szBuffer.get(), pReserved);
if (dwResult == ERROR_SUCCESS) { if (dwResult == ERROR_SUCCESS) {
DWORD dwLength = (DWORD)wcsnlen(szBuffer.get(), dwSize); SIZE_T sLength = wcsnlen(szBuffer.get(), sSize);
if (dwLength < dwSize - 1) { if (sLength < sSize - 1) {
// Buffer was long enough. // Buffer was long enough.
sValue.assign(szBuffer.get(), dwLength); sValue.assign(szBuffer.get(), sLength);
return ERROR_SUCCESS; return ERROR_SUCCESS;
} }
} else { } else {

View File

@ -177,21 +177,23 @@ static _Success_(return != 0) BOOL GetFileVersionInfoW(_In_z_ LPCWSTR lptstrFile
/// @copydoc ExpandEnvironmentStringsW() /// @copydoc ExpandEnvironmentStringsW()
template<class _Traits, class _Ax> template<class _Traits, class _Ax>
static _Success_(return != 0) DWORD ExpandEnvironmentStringsA(_In_z_ LPCSTR lpSrc, _Out_ std::basic_string<char, _Traits, _Ax> &sValue) noexcept static _Success_(return != 0) DWORD ExpandEnvironmentStringsA(_In_z_ LPCSTR lpSrc, _Out_ std::basic_string<char, _Traits, _Ax> &sValue)
{ {
assert(0); // TODO: Test this code. assert(0); // TODO: Test this code.
for (DWORD dwSizeOut = (DWORD)strlen(lpSrc) + 0x100;;) { for (SIZE_T sSizeOut = SIZETAdd(strlen(lpSrc), 0x100);;) {
DWORD dwSizeIn = dwSizeOut; if (sSizeOut > DWORD_MAX)
throw std::invalid_argument("String too big");
DWORD dwSizeIn = static_cast<DWORD>(sSizeOut);
std::unique_ptr<char[]> szBuffer(new char[(size_t)dwSizeIn + 2]); // Note: ANSI version requires one extra char. std::unique_ptr<char[]> szBuffer(new char[(size_t)dwSizeIn + 2]); // Note: ANSI version requires one extra char.
dwSizeOut = ::ExpandEnvironmentStringsA(lpSrc, szBuffer.get(), dwSizeIn); sSizeOut = ::ExpandEnvironmentStringsA(lpSrc, szBuffer.get(), dwSizeIn);
if (dwSizeOut == 0) { if (sSizeOut == 0) {
// Error or zero-length input. // Error or zero-length input.
break; break;
} else if (dwSizeOut <= dwSizeIn) { } else if (sSizeOut <= dwSizeIn) {
// The buffer was sufficient. // The buffer was sufficient.
sValue.assign(szBuffer.get(), dwSizeOut - 1); sValue.assign(szBuffer.get(), sSizeOut - 1);
return dwSizeOut; return static_cast<DWORD>(sSizeOut);
} }
} }
@ -205,19 +207,21 @@ static _Success_(return != 0) DWORD ExpandEnvironmentStringsA(_In_z_ LPCSTR lpSr
/// \sa [ExpandEnvironmentStrings function](https://msdn.microsoft.com/en-us/library/windows/desktop/ms724265.aspx) /// \sa [ExpandEnvironmentStrings function](https://msdn.microsoft.com/en-us/library/windows/desktop/ms724265.aspx)
/// ///
template<class _Traits, class _Ax> template<class _Traits, class _Ax>
static _Success_(return != 0) DWORD ExpandEnvironmentStringsW(_In_z_ LPCWSTR lpSrc, _Out_ std::basic_string<wchar_t, _Traits, _Ax> &sValue) noexcept static _Success_(return != 0) DWORD ExpandEnvironmentStringsW(_In_z_ LPCWSTR lpSrc, _Out_ std::basic_string<wchar_t, _Traits, _Ax> &sValue)
{ {
for (DWORD dwSizeOut = (DWORD)wcslen(lpSrc) + 0x100;;) { for (SIZE_T sSizeOut = SIZETAdd(wcslen(lpSrc), 0x100);;) {
DWORD dwSizeIn = dwSizeOut; if (sSizeOut > DWORD_MAX)
throw std::invalid_argument("String too big");
DWORD dwSizeIn = static_cast<DWORD>(sSizeOut);
std::unique_ptr<wchar_t[]> szBuffer(new wchar_t[(size_t)dwSizeIn + 1]); std::unique_ptr<wchar_t[]> szBuffer(new wchar_t[(size_t)dwSizeIn + 1]);
dwSizeOut = ::ExpandEnvironmentStringsW(lpSrc, szBuffer.get(), dwSizeIn); sSizeOut = ::ExpandEnvironmentStringsW(lpSrc, szBuffer.get(), dwSizeIn);
if (dwSizeOut == 0) { if (sSizeOut == 0) {
// Error or zero-length input. // Error or zero-length input.
break; break;
} else if (dwSizeOut <= dwSizeIn) { } else if (sSizeOut <= dwSizeIn) {
// The buffer was sufficient. // The buffer was sufficient.
sValue.assign(szBuffer.get(), dwSizeOut - 1); sValue.assign(szBuffer.get(), sSizeOut - 1);
return dwSizeOut; return static_cast<DWORD>(sSizeOut);
} }
} }