Skip to content

Commit

Permalink
Rollback to minhook for hooking
Browse files Browse the repository at this point in the history
  • Loading branch information
AkaiiKitsune committed Nov 25, 2024
1 parent d0374dc commit 4bf913c
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 84 deletions.
3 changes: 3 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_project_link_arguments(
language: 'cpp',
)

minhook = subproject('minhook')
tomlc99 = subproject('tomlc99')
sdl2 = subproject('sdl2', default_options: ['default_library=static', 'test=false', 'use_render=disabled'])
xxhash = subproject('xxhash', default_options: ['default_library=static', 'cli=false'])
Expand All @@ -52,6 +53,7 @@ pugixml_dep = pugixml.get_variable('pugixml_static_dep')
library(
'bnusio',
link_with: [
minhook.get_variable('minhook_lib'),
tomlc99.get_variable('tomlc99_lib'),
sdl2.get_variable('sdl2'),
xxhash.get_variable('xxhash'),
Expand All @@ -61,6 +63,7 @@ library(
link_args : '-Wl,--allow-multiple-definition',
include_directories: [
'src',
minhook.get_variable('minhook_inc'),
tomlc99.get_variable('tomlc99_inc'),
sdl2.get_variable('core_inc'),
xxhash.get_variable('inc'),
Expand Down
2 changes: 1 addition & 1 deletion src/bnusio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ HOOK (u64, bngrw_ReqWaitTouch, PROC_ADDRESS ("bngrw.dll", "BngRwReqWaitTouch"),
return 1;
} else {
// This is called when we use an original card reader and acceptInvalidCards is set to true
return originalbngrw_ReqWaitTouch.call<u64> (a1, a2, a3, InspectWaitTouch, _touchData);
return originalbngrw_ReqWaitTouch (a1, a2, a3, InspectWaitTouch, _touchData);
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/dllmain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ std::string logLevelStr = "INFO";
bool logToFile = true;

HWND hGameWnd;
HOOK (i32, ShowMouse, PROC_ADDRESS ("user32.dll", "ShowCursor"), bool) { return originalShowMouse.call<i32> (true); }
HOOK (i32, ShowMouse, PROC_ADDRESS ("user32.dll", "ShowCursor"), bool) { return originalShowMouse (true); }
HOOK (i32, ExitWindows, PROC_ADDRESS ("user32.dll", "ExitWindowsEx")) { ExitProcess (0); }
HOOK (HWND, CreateWindow, PROC_ADDRESS ("user32.dll", "CreateWindowExW"), DWORD dwExStyle, LPCWSTR lpClassName, LPCWSTR lpWindowName, DWORD dwStyle,
i32 X, i32 Y, i32 nWidth, i32 nHeight, HWND hWndParent, HMENU hMenu, HINSTANCE hInstance, LPVOID lpParam) {
if (lpWindowName != NULL) {
if (wcscmp (lpWindowName, L"Taiko") == 0) {
if (windowed) dwStyle = WS_TILEDWINDOW ^ WS_MAXIMIZEBOX ^ WS_THICKFRAME;

hGameWnd = originalCreateWindow.call<HWND> (dwExStyle, lpClassName, lpWindowName, dwStyle, X, Y, nWidth, nHeight, hWndParent, hMenu,
hGameWnd = originalCreateWindow (dwExStyle, lpClassName, lpWindowName, dwStyle, X, Y, nWidth, nHeight, hWndParent, hMenu,
hInstance, lpParam);
return hGameWnd;
}
}
return originalCreateWindow.call<HWND> (dwExStyle, lpClassName, lpWindowName, dwStyle, X, Y, nWidth, nHeight, hWndParent, hMenu, hInstance,
return originalCreateWindow (dwExStyle, lpClassName, lpWindowName, dwStyle, X, Y, nWidth, nHeight, hWndParent, hMenu, hInstance,
lpParam);
}
HOOK (bool, SetWindowPosition, PROC_ADDRESS ("user32.dll", "SetWindowPos"), HWND hWnd, HWND hWndInsertAfter, i32 X, i32 Y, i32 cx, i32 cy,
Expand All @@ -60,12 +60,12 @@ HOOK (bool, SetWindowPosition, PROC_ADDRESS ("user32.dll", "SetWindowPos"), HWND
cx = (rw.right - rw.left) - (rc.right - rc.left) + cx;
cy = (rw.bottom - rw.top) - (rc.bottom - rc.top) + cy;
}
return originalSetWindowPosition.call<bool> (hWnd, hWndInsertAfter, X, Y, cx, cy, uFlags);
return originalSetWindowPosition (hWnd, hWndInsertAfter, X, Y, cx, cy, uFlags);
}

HOOK (void, ExitProcessHook, PROC_ADDRESS ("kernel32.dll", "ExitProcess"), u32 uExitCode) {
bnusio::Close ();
originalExitProcessHook.call<void> (uExitCode);
originalExitProcessHook (uExitCode);
}

HOOK (i32, XinputGetState, PROC_ADDRESS ("xinput9_1_0.dll", "XInputGetState")) { return ERROR_DEVICE_NOT_CONNECTED; }
Expand All @@ -82,7 +82,7 @@ HOOK (i64, UsbFinderGetSerialNumber, PROC_ADDRESS ("nbamUsbFinder.dll", "nbamUsb
}

HOOK (i32, ws2_getaddrinfo, PROC_ADDRESS ("ws2_32.dll", "getaddrinfo"), const char *node, char *service, void *hints, void *out) {
return originalws2_getaddrinfo.call<i32> (server.c_str (), service, hints, out);
return originalws2_getaddrinfo (server.c_str (), service, hints, out);
}

void
Expand Down
63 changes: 34 additions & 29 deletions src/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <map>
#include <mutex>
#include <safetyhook.hpp>
#include <MinHook.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
Expand Down Expand Up @@ -36,35 +37,41 @@ const HMODULE MODULE_HANDLE = GetModuleHandle (nullptr);
#define ASLR(address) ((u64)MODULE_HANDLE + (u64)address - (u64)BASE_ADDRESS)
#endif

#define HOOK(returnType, functionName, location, ...) \
SafetyHookInline original##functionName{}; \
void *where##functionName = (void *)location; \
#define HOOK(returnType, functionName, location, ...) \
typedef returnType (*functionName) (__VA_ARGS__); \
functionName original##functionName = NULL; \
void *where##functionName = (void *)(location); \
returnType implOf##functionName (__VA_ARGS__)

#define HOOK_DYNAMIC(returnType, functionName, ...) \
SafetyHookInline original##functionName{}; \
void *where##functionName = NULL; \
#define HOOK_DYNAMIC(returnType, functionName, ...) \
typedef returnType (*functionName) (__VA_ARGS__); \
functionName original##functionName = NULL; \
void *where##functionName = NULL; \
returnType implOf##functionName (__VA_ARGS__)

#define VTABLE_HOOK(returnType, className, functionName, ...) \
SafetyHookInline original##className##functionName{}; \
void *where##className##functionName = NULL; \
#define VTABLE_HOOK(returnType, className, functionName, ...) \
typedef returnType (*className##functionName) (className * This, __VA_ARGS__); \
className##functionName original##className##functionName = NULL; \
void *where##className##functionName = NULL; \
returnType implOf##className##functionName (className *This, __VA_ARGS__)

#define MID_HOOK(functionName, location, ...) \
SafetyHookMid midHook##functionName{}; \
void *where##functionName = (void *)location; \
#define MID_HOOK(functionName, location, ...) \
typedef void (*functionName) (__VA_ARGS__); \
SafetyHookMid midHook##functionName{}; \
u64 where##functionName = (location); \
void implOf##functionName (SafetyHookContext &ctx)

#define MID_HOOK_DYNAMIC(functionName, ...) \
SafetyHookMid midHook##functionName{}; \
void *where##functionName = NULL; \
#define MID_HOOK_DYNAMIC(functionName, ...) \
typedef void (*functionName) (__VA_ARGS__); \
std::map<u64, SafetyHookMid> mapOf##functionName; \
void implOf##functionName (SafetyHookContext &ctx)

#define INSTALL_HOOK(functionName) \
{ \
LogMessage (LOG_LEVEL_DEBUG, (std::string ("Installing hook for ") + #functionName).c_str ()); \
original##functionName = safetyhook::create_inline (where##functionName, implOf##functionName); \
#define INSTALL_HOOK(functionName) \
{ \
LogMessage (LOG_LEVEL_DEBUG, (std::string ("Installing hook for ") + #functionName).c_str ()); \
MH_Initialize (); \
MH_CreateHook ((void *)where##functionName, (void *)implOf##functionName, (void **)(&original##functionName)); \
MH_EnableHook ((void *)where##functionName); \
}

#define INSTALL_HOOK_DYNAMIC(functionName, location) \
Expand All @@ -76,7 +83,9 @@ const HMODULE MODULE_HANDLE = GetModuleHandle (nullptr);
#define INSTALL_HOOK_DIRECT(location, locationOfHook) \
{ \
LogMessage (LOG_LEVEL_DEBUG, (std::string ("Installing direct hook for ") + #location).c_str ()); \
directHooks.push_back (safetyhook::create_inline ((void *)location, (void *)locationOfHook)); \
MH_Initialize (); \
MH_CreateHook ((void *)(location), (void *)(locationOfHook), NULL); \
MH_EnableHook ((void *)(location)); \
}

#define INSTALL_VTABLE_HOOK(className, object, functionName, functionIndex) \
Expand All @@ -92,32 +101,29 @@ const HMODULE MODULE_HANDLE = GetModuleHandle (nullptr);
}

#define INSTALL_MID_HOOK_DYNAMIC(functionName, location) \
{ \
where##functionName = (void *)location; \
INSTALL_MID_HOOK (functionName); \
}
{ mapOf##functionName[location] = safetyhook::create_mid (location, implOf##functionName); }

bool sendFlag = false;
#define SCENE_RESULT_HOOK(functionName, location) \
HOOK (void, functionName, location, i64 a1, i64 a2, i64 a3) { \
if (TestMode::ReadTestModeValue (L"ModInstantResult") != 1 && TestMode::ReadTestModeValue (L"NumberOfStageItem") <= 4) { \
original##functionName.call (a1, a2, a3); \
original##functionName (a1, a2, a3); \
return; \
} \
sendFlag = true; \
original##functionName.call (a1, a2, a3); \
original##functionName (a1, a2, a3); \
ExecuteSendResultData (); \
}

#define SEND_RESULT_HOOK(functionName, location) \
HOOK (void, functionName, location, i64 a1) { \
if (TestMode::ReadTestModeValue (L"ModInstantResult") != 1 && TestMode::ReadTestModeValue (L"NumberOfStageItem") <= 4) { \
original##functionName.call (a1); \
original##functionName (a1); \
return; \
} \
if (sendFlag) { \
sendFlag = false; \
original##functionName.call (a1); \
original##functionName (a1); \
} \
}

Expand Down Expand Up @@ -189,7 +195,6 @@ const std::string readConfigString (toml_table_t *table, const std::string &key,
std::vector<int64_t> readConfigIntArray (toml_table_t *table, const std::string &key, std::vector<int64_t> notFoundValue);
std::wstring replace (const std::wstring orignStr, const std::wstring oldStr, const std::wstring newStr);
std::string replace (const std::string orignStr, const std::string oldStr, const std::string newStr);
std::vector<SafetyHookInline> directHooks = {};
const char *GameVersionToString (GameVersion version);
const char *languageStr (int language);
std::string ConvertWideToUtf8 (const std::wstring &wstr);
Expand Down
12 changes: 8 additions & 4 deletions src/patches/amauth.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "helpers.h"
#include <bits/stdc++.h>
#include <format>
#include <safetyhook.hpp>
#include <MinHook.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
Expand Down Expand Up @@ -533,7 +533,8 @@ class CAuthFactory final : public IClassFactory {
virtual HRESULT LockServer (int32_t lock) { return 0; }
};

SafetyHookInline g_origCoCreateInstance{};
static HRESULT (STDAPICALLTYPE *g_origCoCreateInstance) (const IID *const rclsid, LPUNKNOWN pUnkOuter, DWORD dwClsContext, const IID *const riid,
LPVOID *ppv);

static HRESULT STDAPICALLTYPE
CoCreateInstanceHook (const IID *const rclsid, LPUNKNOWN pUnkOuter, DWORD dwClsContext, const IID *const riid, LPVOID *ppv) {
Expand All @@ -548,7 +549,7 @@ CoCreateInstanceHook (const IID *const rclsid, LPUNKNOWN pUnkOuter, DWORD dwClsC
auto cauth = new CAuth ();
result = cauth->QueryInterface (*riid, ppv);
} else {
result = g_origCoCreateInstance.call<HRESULT> (rclsid, pUnkOuter, dwClsContext, riid, ppv);
result = g_origCoCreateInstance (rclsid, pUnkOuter, dwClsContext, riid, ppv);
}

CoTaskMemFree (clsidStr);
Expand All @@ -560,7 +561,10 @@ void
Init () {
LogMessage (LOG_LEVEL_DEBUG, "Init AmAuth patches");

g_origCoCreateInstance = safetyhook::create_inline (PROC_ADDRESS ("ole32.dll", "CoCreateInstance"), CoCreateInstanceHook);
MH_Initialize ();
MH_CreateHookApi (L"ole32.dll", "CoCreateInstance", (LPVOID)CoCreateInstanceHook,
(void **)&g_origCoCreateInstance); // NOLINT(clang-diagnostic-microsoft-cast)
MH_EnableHook (nullptr);

struct addrinfo *res = 0;
getaddrinfo (server.c_str (), "", 0, &res);
Expand Down
4 changes: 2 additions & 2 deletions src/patches/audio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ HOOK_DYNAMIC (i64, NUSCDeviceInit, void *a1, nusc_init_config_t *a2, nusc_init_c
a2->device_mode = asio;
a2->asio_driver_name = asio ? asioDriver.c_str () : "";
a2->wasapi_exclusive = asio ? 1 : wasapiShared ? 0 : 1;
return originalNUSCDeviceInit.call<i64> (a1, a2, a3, a4);
return originalNUSCDeviceInit (a1, a2, a3, a4);
}
HOOK_DYNAMIC (bool, LoadASIODriver, void *a1, const char *a2) {
auto result = originalLoadASIODriver.call<bool> (a1, a2);
auto result = originalLoadASIODriver (a1, a2);
if (!result) {
LogMessage (LOG_LEVEL_ERROR, (std::string ("Failed to load ASIO driver ") + asioDriver).c_str ());
MessageBoxA (nullptr, "Failed to load ASIO driver", nullptr, MB_OK);
Expand Down
Loading

0 comments on commit 4bf913c

Please # to comment.