//===- KillTheDoctor - Prevent Dr. Watson from stopping tests ---*- C++ -*-===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // This program provides an extremely hacky way to stop Dr. Watson from starting // due to unhandled exceptions in child processes. // // This simply starts the program named in the first positional argument with // the arguments following it under a debugger. All this debugger does is catch // any unhandled exceptions thrown in the child process and close the program // (and hopefully tells someone about it). // // This also provides another really hacky method to prevent assert dialog boxes // from poping up. When --no-user32 is passed, if any process loads user32.dll, // we assume it is trying to call MessageBoxEx and terminate it. The proper way // to do this would be to actually set a break point, but there's quite a bit // of code involved to get the address of MessageBoxEx in the remote process's // address space due to Address space layout randomization (ASLR). This can be // added if it's ever actually needed. // // If the subprocess exits for any reason other than sucessful termination, -1 // is returned. If the process exits normally the value it returned is returned. // // I hate Windows. // //===----------------------------------------------------------------------===// #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/type_traits.h" #include "llvm/Support/Signals.h" #include "llvm/Support/system_error.h" #include #include #include #include #include #include #include #include #include using namespace llvm; #undef max namespace { cl::opt ProgramToRun(cl::Positional, cl::desc("")); cl::list Argv(cl::ConsumeAfter, cl::desc("...")); cl::opt TraceExecution("x", cl::desc("Print detailed output about what is being run to stderr.")); cl::opt Timeout("t", cl::init(0), cl::desc("Set maximum runtime in seconds. Defaults to infinite.")); cl::opt NoUser32("no-user32", cl::desc("Terminate process if it loads user32.dll.")); StringRef ToolName; template class ScopedHandle { typedef typename HandleType::handle_type handle_type; handle_type Handle; public: ScopedHandle() : Handle(HandleType::GetInvalidHandle()) {} explicit ScopedHandle(handle_type handle) : Handle(handle) {} ~ScopedHandle() { HandleType::Destruct(Handle); } ScopedHandle& operator=(handle_type handle) { // Cleanup current handle. if (!HandleType::isValid(Handle)) HandleType::Destruct(Handle); Handle = handle; return *this; } operator bool() const { return HandleType::isValid(Handle); } operator handle_type() { return Handle; } }; // This implements the most common handle in the Windows API. struct CommonHandle { typedef HANDLE handle_type; static handle_type GetInvalidHandle() { return INVALID_HANDLE_VALUE; } static void Destruct(handle_type Handle) { ::CloseHandle(Handle); } static bool isValid(handle_type Handle) { return Handle != GetInvalidHandle(); } }; struct FileMappingHandle { typedef HANDLE handle_type; static handle_type GetInvalidHandle() { return NULL; } static void Destruct(handle_type Handle) { ::CloseHandle(Handle); } static bool isValid(handle_type Handle) { return Handle != GetInvalidHandle(); } }; struct MappedViewOfFileHandle { typedef LPVOID handle_type; static handle_type GetInvalidHandle() { return NULL; } static void Destruct(handle_type Handle) { ::UnmapViewOfFile(Handle); } static bool isValid(handle_type Handle) { return Handle != GetInvalidHandle(); } }; struct ProcessHandle : CommonHandle {}; struct ThreadHandle : CommonHandle {}; struct TokenHandle : CommonHandle {}; struct FileHandle : CommonHandle {}; typedef ScopedHandle FileMappingScopedHandle; typedef ScopedHandle MappedViewOfFileScopedHandle; typedef ScopedHandle ProcessScopedHandle; typedef ScopedHandle ThreadScopedHandle; typedef ScopedHandle TokenScopedHandle; typedef ScopedHandle FileScopedHandle; error_code get_windows_last_error() { return make_error_code(windows_error(::GetLastError())); } } static error_code GetFileNameFromHandle(HANDLE FileHandle, std::string& Name) { char Filename[MAX_PATH+1]; bool Sucess = false; Name.clear(); // Get the file size. LARGE_INTEGER FileSize; Sucess = ::GetFileSizeEx(FileHandle, &FileSize); if (!Sucess) return get_windows_last_error(); // Create a file mapping object. FileMappingScopedHandle FileMapping( ::CreateFileMappingA(FileHandle, NULL, PAGE_READONLY, 0, 1, NULL)); if (!FileMapping) return get_windows_last_error(); // Create a file mapping to get the file name. MappedViewOfFileScopedHandle MappedFile( ::MapViewOfFile(FileMapping, FILE_MAP_READ, 0, 0, 1)); if (!MappedFile) return get_windows_last_error(); Sucess = ::GetMappedFileNameA(::GetCurrentProcess(), MappedFile, Filename, array_lengthof(Filename) - 1); if (!Sucess) return get_windows_last_error(); else { Name = Filename; return windows_error::success; } } static std::string QuoteProgramPathIfNeeded(StringRef Command) { if (Command.find_first_of(' ') == StringRef::npos) return Command; else { std::string ret; ret.reserve(Command.size() + 3); ret.push_back('"'); ret.append(Command.begin(), Command.end()); ret.push_back('"'); return ret; } } /// @brief Find program using shell lookup rules. /// @param Program This is either an absolute path, relative path, or simple a /// program name. Look in PATH for any programs that match. If no /// extension is present, try all extensions in PATHEXT. /// @return If ec == errc::success, The absolute path to the program. Otherwise /// the return value is undefined. static std::string FindProgram(const std::string &Program, error_code &ec) { char PathName[MAX_PATH + 1]; typedef SmallVector pathext_t; pathext_t pathext; // Check for the program without an extension (in case it already has one). pathext.push_back(""); SplitString(std::getenv("PATHEXT"), pathext, ";"); for (pathext_t::iterator i = pathext.begin(), e = pathext.end(); i != e; ++i){ SmallString<5> ext; for (std::size_t ii = 0, e = i->size(); ii != e; ++ii) ext.push_back(::tolower((*i)[ii])); LPCSTR Extension = NULL; if (ext.size() && ext[0] == '.') Extension = ext.c_str(); DWORD length = ::SearchPathA(NULL, Program.c_str(), Extension, array_lengthof(PathName), PathName, NULL); if (length == 0) ec = get_windows_last_error(); else if (length > array_lengthof(PathName)) { // This may have been the file, return with error. ec = error_code(windows_error::buffer_overflow); break; } else { // We found the path! Return it. ec = windows_error::success; break; } } // Make sure PathName is valid. PathName[MAX_PATH] = 0; return PathName; } static error_code EnableDebugPrivileges() { HANDLE TokenHandle; BOOL success = ::OpenProcessToken(::GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &TokenHandle); if (!success) return error_code(::GetLastError(), system_category()); TokenScopedHandle Token(TokenHandle); TOKEN_PRIVILEGES TokenPrivileges; LUID LocallyUniqueID; success = ::LookupPrivilegeValueA(NULL, SE_DEBUG_NAME, &LocallyUniqueID); if (!success) return error_code(::GetLastError(), system_category()); TokenPrivileges.PrivilegeCount = 1; TokenPrivileges.Privileges[0].Luid = LocallyUniqueID; TokenPrivileges.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; success = ::AdjustTokenPrivileges(Token, FALSE, &TokenPrivileges, sizeof(TOKEN_PRIVILEGES), NULL, NULL); // The value of success is basically useless. Either way we are just returning // the value of ::GetLastError(). return error_code(::GetLastError(), system_category()); } static StringRef ExceptionCodeToString(DWORD ExceptionCode) { switch(ExceptionCode) { case EXCEPTION_ACCESS_VIOLATION: return "EXCEPTION_ACCESS_VIOLATION"; case EXCEPTION_ARRAY_BOUNDS_EXCEEDED: return "EXCEPTION_ARRAY_BOUNDS_EXCEEDED"; case EXCEPTION_BREAKPOINT: return "EXCEPTION_BREAKPOINT"; case EXCEPTION_DATATYPE_MISALIGNMENT: return "EXCEPTION_DATATYPE_MISALIGNMENT"; case EXCEPTION_FLT_DENORMAL_OPERAND: return "EXCEPTION_FLT_DENORMAL_OPERAND"; case EXCEPTION_FLT_DIVIDE_BY_ZERO: return "EXCEPTION_FLT_DIVIDE_BY_ZERO"; case EXCEPTION_FLT_INEXACT_RESULT: return "EXCEPTION_FLT_INEXACT_RESULT"; case EXCEPTION_FLT_INVALID_OPERATION: return "EXCEPTION_FLT_INVALID_OPERATION"; case EXCEPTION_FLT_OVERFLOW: return "EXCEPTION_FLT_OVERFLOW"; case EXCEPTION_FLT_STACK_CHECK: return "EXCEPTION_FLT_STACK_CHECK"; case EXCEPTION_FLT_UNDERFLOW: return "EXCEPTION_FLT_UNDERFLOW"; case EXCEPTION_ILLEGAL_INSTRUCTION: return "EXCEPTION_ILLEGAL_INSTRUCTION"; case EXCEPTION_IN_PAGE_ERROR: return "EXCEPTION_IN_PAGE_ERROR"; case EXCEPTION_INT_DIVIDE_BY_ZERO: return "EXCEPTION_INT_DIVIDE_BY_ZERO"; case EXCEPTION_INT_OVERFLOW: return "EXCEPTION_INT_OVERFLOW"; case EXCEPTION_INVALID_DISPOSITION: return "EXCEPTION_INVALID_DISPOSITION"; case EXCEPTION_NONCONTINUABLE_EXCEPTION: return "EXCEPTION_NONCONTINUABLE_EXCEPTION"; case EXCEPTION_PRIV_INSTRUCTION: return "EXCEPTION_PRIV_INSTRUCTION"; case EXCEPTION_SINGLE_STEP: return "EXCEPTION_SINGLE_STEP"; case EXCEPTION_STACK_OVERFLOW: return "EXCEPTION_STACK_OVERFLOW"; default: return ""; } } int main(int argc, char **argv) { // Print a stack trace if we signal out. sys::PrintStackTraceOnErrorSignal(); PrettyStackTraceProgram X(argc, argv); llvm_shutdown_obj Y; // Call llvm_shutdown() on exit. ToolName = argv[0]; cl::ParseCommandLineOptions(argc, argv, "Dr. Watson Assassin.\n"); if (ProgramToRun.size() == 0) { cl::PrintHelpMessage(); return -1; } if (Timeout > std::numeric_limits::max() / 1000) { errs() << ToolName << ": Timeout value too large, must be less than: " << std::numeric_limits::max() / 1000 << '\n'; return -1; } std::string CommandLine(ProgramToRun); error_code ec; ProgramToRun = FindProgram(ProgramToRun, ec); if (ec) { errs() << ToolName << ": Failed to find program: '" << CommandLine << "': " << ec.message() << '\n'; return -1; } if (TraceExecution) errs() << ToolName << ": Found Program: " << ProgramToRun << '\n'; for (std::vector::iterator i = Argv.begin(), e = Argv.end(); i != e; ++i) { CommandLine.push_back(' '); CommandLine.append(*i); } if (TraceExecution) errs() << ToolName << ": Program Image Path: " << ProgramToRun << '\n' << ToolName << ": Command Line: " << CommandLine << '\n'; STARTUPINFO StartupInfo; PROCESS_INFORMATION ProcessInfo; std::memset(&StartupInfo, 0, sizeof(StartupInfo)); StartupInfo.cb = sizeof(StartupInfo); std::memset(&ProcessInfo, 0, sizeof(ProcessInfo)); // Set error mode to not display any message boxes. The child process inherets // this. ::SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); ::_set_error_mode(_OUT_TO_STDERR); BOOL success = ::CreateProcessA(ProgramToRun.c_str(), LPSTR(CommandLine.c_str()), NULL, NULL, FALSE, DEBUG_PROCESS, NULL, NULL, &StartupInfo, &ProcessInfo); if (!success) { errs() << ToolName << ": Failed to run program: '" << ProgramToRun << "': " << error_code(::GetLastError(), system_category()).message() << '\n'; return -1; } // Make sure ::CloseHandle is called on exit. std::map ProcessIDToHandle; DEBUG_EVENT DebugEvent; std::memset(&DebugEvent, 0, sizeof(DebugEvent)); DWORD dwContinueStatus = DBG_CONTINUE; // Run the program under the debugger until either it exits, or throws an // exception. if (TraceExecution) errs() << ToolName << ": Debugging...\n"; while(true) { DWORD TimeLeft = INFINITE; if (Timeout > 0) { FILETIME CreationTime, ExitTime, KernelTime, UserTime; ULARGE_INTEGER a, b; success = ::GetProcessTimes(ProcessInfo.hProcess, &CreationTime, &ExitTime, &KernelTime, &UserTime); if (!success) { ec = error_code(::GetLastError(), system_category()); errs() << ToolName << ": Failed to get process times: " << ec.message() << '\n'; return -1; } a.LowPart = KernelTime.dwLowDateTime; a.HighPart = KernelTime.dwHighDateTime; b.LowPart = UserTime.dwLowDateTime; b.HighPart = UserTime.dwHighDateTime; // Convert 100-nanosecond units to miliseconds. uint64_t TotalTimeMiliseconds = (a.QuadPart + b.QuadPart) / 10000; // Handle the case where the process has been running for more than 49 // days. if (TotalTimeMiliseconds > std::numeric_limits::max()) { errs() << ToolName << ": Timeout Failed: Process has been running for" "more than 49 days.\n"; return -1; } // We check with > instead of using Timeleft because if // TotalTimeMiliseconds is greater than Timeout * 1000, TimeLeft would // underflow. if (TotalTimeMiliseconds > (Timeout * 1000)) { errs() << ToolName << ": Process timed out.\n"; ::TerminateProcess(ProcessInfo.hProcess, -1); // Otherwise other stuff starts failing... return -1; } TimeLeft = (Timeout * 1000) - static_cast(TotalTimeMiliseconds); } success = WaitForDebugEvent(&DebugEvent, TimeLeft); if (!success) { ec = error_code(::GetLastError(), system_category()); if (ec == error_condition(errc::timed_out)) { errs() << ToolName << ": Process timed out.\n"; ::TerminateProcess(ProcessInfo.hProcess, -1); // Otherwise other stuff starts failing... return -1; } errs() << ToolName << ": Failed to wait for debug event in program: '" << ProgramToRun << "': " << ec.message() << '\n'; return -1; } switch(DebugEvent.dwDebugEventCode) { case CREATE_PROCESS_DEBUG_EVENT: // Make sure we remove the handle on exit. if (TraceExecution) errs() << ToolName << ": Debug Event: CREATE_PROCESS_DEBUG_EVENT\n"; ProcessIDToHandle[DebugEvent.dwProcessId] = DebugEvent.u.CreateProcessInfo.hProcess; ::CloseHandle(DebugEvent.u.CreateProcessInfo.hFile); break; case EXIT_PROCESS_DEBUG_EVENT: { if (TraceExecution) errs() << ToolName << ": Debug Event: EXIT_PROCESS_DEBUG_EVENT\n"; // If this is the process we origionally created, exit with its exit // code. if (DebugEvent.dwProcessId == ProcessInfo.dwProcessId) return DebugEvent.u.ExitProcess.dwExitCode; // Otherwise cleanup any resources we have for it. std::map::iterator ExitingProcess = ProcessIDToHandle.find(DebugEvent.dwProcessId); if (ExitingProcess == ProcessIDToHandle.end()) { errs() << ToolName << ": Got unknown process id!\n"; return -1; } ::CloseHandle(ExitingProcess->second); ProcessIDToHandle.erase(ExitingProcess); } break; case CREATE_THREAD_DEBUG_EVENT: ::CloseHandle(DebugEvent.u.CreateThread.hThread); break; case LOAD_DLL_DEBUG_EVENT: { // Cleanup the file handle. FileScopedHandle DLLFile(DebugEvent.u.LoadDll.hFile); std::string DLLName; ec = GetFileNameFromHandle(DLLFile, DLLName); if (ec) { DLLName = " : "; DLLName += ec.message(); } if (TraceExecution) { errs() << ToolName << ": Debug Event: LOAD_DLL_DEBUG_EVENT\n"; errs().indent(ToolName.size()) << ": DLL Name : " << DLLName << '\n'; } if (NoUser32 && sys::Path(DLLName).getBasename() == "user32") { // Program is loading user32.dll, in the applications we are testing, // this only happens if an assert has fired. By now the message has // already been printed, so simply close the program. errs() << ToolName << ": user32.dll loaded!\n"; errs().indent(ToolName.size()) << ": This probably means that assert was called. Closing " "program to prevent message box from poping up.\n"; dwContinueStatus = DBG_CONTINUE; ::TerminateProcess(ProcessIDToHandle[DebugEvent.dwProcessId], -1); return -1; } } break; case EXCEPTION_DEBUG_EVENT: { // Close the application if this exception will not be handled by the // child application. if (TraceExecution) errs() << ToolName << ": Debug Event: EXCEPTION_DEBUG_EVENT\n"; EXCEPTION_DEBUG_INFO &Exception = DebugEvent.u.Exception; if (Exception.dwFirstChance > 0) { if (TraceExecution) { errs().indent(ToolName.size()) << ": Debug Info : "; errs() << "First chance exception at " << Exception.ExceptionRecord.ExceptionAddress << ", exception code: " << ExceptionCodeToString( Exception.ExceptionRecord.ExceptionCode) << " (" << Exception.ExceptionRecord.ExceptionCode << ")\n"; } dwContinueStatus = DBG_EXCEPTION_NOT_HANDLED; } else { errs() << ToolName << ": Unhandled exception in: " << ProgramToRun << "!\n"; errs().indent(ToolName.size()) << ": location: "; errs() << Exception.ExceptionRecord.ExceptionAddress << ", exception code: " << ExceptionCodeToString( Exception.ExceptionRecord.ExceptionCode) << " (" << Exception.ExceptionRecord.ExceptionCode << ")\n"; dwContinueStatus = DBG_CONTINUE; ::TerminateProcess(ProcessIDToHandle[DebugEvent.dwProcessId], -1); return -1; } } break; default: // Do nothing. if (TraceExecution) errs() << ToolName << ": Debug Event: \n"; break; } success = ContinueDebugEvent(DebugEvent.dwProcessId, DebugEvent.dwThreadId, dwContinueStatus); if (!success) { ec = error_code(::GetLastError(), system_category()); errs() << ToolName << ": Failed to continue debugging program: '" << ProgramToRun << "': " << ec.message() << '\n'; return -1; } dwContinueStatus = DBG_CONTINUE; } assert(0 && "Fell out of debug loop. This shouldn't be possible!"); return -1; }