diff --git a/atlex.h b/atlex.h index e9a5a60..adf0e93 100644 --- a/atlex.h +++ b/atlex.h @@ -19,6 +19,7 @@ #pragma once +#include #include #include @@ -267,4 +268,57 @@ namespace ATL typedef CStrFormatMsgT< wchar_t, StrTraitATL< wchar_t, ChTraitsCRT< wchar_t > > > CStrFormatMsgW; typedef CStrFormatMsgT< char, StrTraitATL< char, ChTraitsCRT< char > > > CStrFormatMsgA; typedef CStrFormatMsgT< TCHAR, StrTraitATL< TCHAR, ChTraitsCRT< TCHAR > > > CStrFormatMsg; + + + // + // CParanoidHeap + // + template + class CParanoidHeap : public BaseHeap { + public: + virtual void Free(_In_opt_ void* p) throw() + { + // Sanitize then free. + SecureZeroMemory(p, GetSize(p)); + BaseHeap::Free(p); + } + + _Ret_opt_bytecap_(nBytes) virtual void* Reallocate(_In_opt_ void* p, _In_ size_t nBytes) throw() + { + // Create a new sized copy. + void *pNew = Allocate(nBytes); + size_t nSizePrev = GetSize(p); + memcpy(pNew, p, nSizePrev); + + // Sanitize the old data then free. + SecureZeroMemory(p, nSizePrev); + Free(p); + + return pNew; + } + }; + + + // + // CW2AParanoidEX + // + template + class CW2AParanoidEX : public CW2AEX { + public: + CW2AParanoidEX(_In_z_ LPCWSTR psz) throw(...) : CW2AEX(psz) {} + CW2AParanoidEX(_In_z_ LPCWSTR psz, _In_ UINT nCodePage) throw(...) : CW2AEX(psz, nCodePage) {} + ~CW2AParanoidEX() throw() + { + // Sanitize before free. + if (m_psz != m_szBuffer) + SecureZeroMemory(m_psz, _msize(m_psz)); + else + SecureZeroMemory(m_szBuffer, sizeof(m_szBuffer)); + } + }; + + // + // CW2AParanoid + // + typedef CW2AParanoidEX<> CW2AParanoid; }