summaryrefslogtreecommitdiff
path: root/hook/src/utils.zig
diff options
context:
space:
mode:
Diffstat (limited to 'hook/src/utils.zig')
-rw-r--r--hook/src/utils.zig198
1 files changed, 198 insertions, 0 deletions
diff --git a/hook/src/utils.zig b/hook/src/utils.zig
new file mode 100644
index 0000000..57fdfec
--- /dev/null
+++ b/hook/src/utils.zig
@@ -0,0 +1,198 @@
+const std = @import("std");
+const w = std.os.windows;
+const builtin = @import("builtin");
+
+const x86 = @import("x86.zig");
+const mem = @import("mem.zig");
+
+const MODULEINFO = extern struct {
+ lpBaseOfDll: w.LPVOID,
+ SizeOfImage: w.DWORD,
+ EntryPoint: w.LPVOID,
+};
+
+extern "kernel32" fn FlushInstructionCache(hProcess: w.HANDLE, lpBaseAddress: w.LPCVOID, dwSize: w.SIZE_T) callconv(.winapi) w.BOOL;
+extern "kernel32" fn GetModuleHandleW(lpModuleName: ?w.LPCWSTR) callconv(.winapi) w.HMODULE;
+extern "kernel32" fn GetModuleInformation(hProcess: w.HANDLE, hModule: w.HMODULE, lpmodinfo: *MODULEINFO, cb: w.DWORD) callconv(.winapi) w.BOOL;
+extern "kernel32" fn GetCurrentProcess() callconv(.winapi) w.HANDLE;
+
+pub inline fn isHex(c: u8) bool {
+ return (c >= '0' and c <= '9') or (c >= 'a' and c <= 'f') or (c >= 'A' and c <= 'F');
+}
+
+pub fn makeHex(comptime str: []const u8) []const u8 {
+ return comptime blk: {
+ @setEvalBranchQuota(10000);
+ var it = std.mem.splitSequence(u8, str, " ");
+ var pat: []const u8 = &.{};
+
+ while (it.next()) |byte| {
+ if (byte.len != 2) {
+ @compileError("Each byte should be 2 characters");
+ }
+ if (isHex(byte[0])) {
+ if (!isHex(byte[1])) {
+ @compileError("The second hex digit is missing");
+ }
+ const n = try std.fmt.parseInt(u8, byte, 16);
+ pat = pat ++ .{n};
+ } else {
+ @compileError("Only hex digits are allowed");
+ }
+ }
+ break :blk pat;
+ };
+}
+
+pub fn patchCode(addr: [*]u8, data: []const u8, restore_protect: u32) !void {
+ if (builtin.os.tag == .windows) {
+ var old_protect: w.DWORD = undefined;
+
+ try w.VirtualProtect(addr, data.len, w.PAGE_EXECUTE_READWRITE, &old_protect);
+ @memcpy(addr, data);
+
+ _ = FlushInstructionCache(GetCurrentProcess(), addr, data.len);
+ try w.VirtualProtect(addr, data.len, old_protect, &old_protect);
+ } else {
+ const page_size = std.heap.page_size_min;
+ const addr_int = @intFromPtr(addr);
+ const page_start = addr_int & ~(page_size - 1);
+ const page_end = addr_int + data.len;
+ const page_len = (page_end - page_start + page_size - 1) & ~(page_size - 1);
+
+ const prot_all = 0b111; // rwx
+
+ if (std.c.mprotect(@ptrFromInt(page_start), page_len, prot_all) != 0)
+ return error.MProtectWritable;
+
+ @memcpy(addr, data);
+
+ if (std.c.mprotect(@ptrFromInt(page_start), page_len, restore_protect) != 0)
+ return error.MProtectRestore;
+ }
+}
+
+// Windows: Return entire module memory
+// Linux: Return the code segment memory
+pub fn getModule(comptime module_name: []const u8) ?[]const u8 {
+ return switch (builtin.os.tag) {
+ .windows => getModuleWindows(module_name),
+ .linux => getModuleLinux(module_name, 0b101) catch return null,
+ else => @compileError("getModule is not available for this target"),
+ };
+}
+
+// Return the entire module memory
+pub fn getEntireModule(comptime module_name: []const u8) ?[]const u8 {
+ return switch (builtin.os.tag) {
+ .windows => getModuleWindows(module_name),
+ .linux => getModuleLinux(module_name, 0) catch return null,
+ else => @compileError("getModule is not available for this target"),
+ };
+}
+
+fn getModuleWindows(comptime module_name: []const u8) ?[]const u8 {
+ const dll_name = module_name ++ ".dll";
+ const path_w = std.unicode.utf8ToUtf16LeStringLiteral(dll_name);
+ const dll = w.GetModuleHandleW(path_w) orelse return null;
+ var info: w.MODULEINFO = undefined;
+ if (GetModuleInformation(GetCurrentProcess(), dll, &info, @sizeOf(MODULEINFO)) == 0) {
+ return null;
+ }
+ const module: [*]const u8 = @ptrCast(dll);
+ return module[0..info.SizeOfImage];
+}
+
+fn getModuleLinux(comptime module_name: []const u8, permission: u32) !?[]const u8 {
+ const file_name = module_name ++ ".so";
+
+ const allocator = std.heap.page_allocator;
+ var file = try std.fs.openFileAbsolute("/proc/self/maps", .{ .mode = .read_only });
+ defer file.close();
+ var reader = file.reader();
+
+ var base: usize = 0;
+ var end: usize = 0;
+ var found = false;
+
+ while (try reader.readUntilDelimiterOrEofAlloc(allocator, '\n', 4096)) |line| {
+ defer allocator.free(line);
+
+ // Example format:
+ // de228000-de229000 r--p 00000000 00:29 1026008 /usr/lib/libstdc++.so.6.0.33
+
+ if (!std.mem.endsWith(u8, line, file_name)) {
+ if (found) break;
+ continue;
+ }
+
+ const pos = line.len - file_name.len;
+ if (line[pos - 1] != '/' and line[pos - 1] != ' ') continue;
+
+ const dash = std.mem.indexOfScalar(u8, line, '-') orelse continue;
+ const space = std.mem.indexOfScalarPos(u8, line, dash + 1, ' ') orelse continue;
+
+ const perms_start = space + 1;
+ if (line.len < perms_start + 4) continue;
+ const read = line[perms_start];
+ const write = line[perms_start + 1];
+ const exec = line[perms_start + 2];
+ if (permission & 0b001 != 0 and read == '-') continue;
+ if (permission & 0b010 != 0 and write == '-') continue;
+ if (permission & 0b100 != 0 and exec == '-') continue;
+
+ const start_hex = line[0..dash];
+ const end_hex = line[dash + 1 .. space];
+
+ const start_addr = try std.fmt.parseInt(usize, start_hex, 16);
+ const end_addr = try std.fmt.parseInt(usize, end_hex, 16);
+
+ if (!found) {
+ base = start_addr;
+ end = end_addr;
+ found = true;
+ } else if (start_addr == end) {
+ end = end_addr;
+ } else {
+ break;
+ }
+ }
+
+ if (!found) return null;
+
+ const size = end - base;
+ const ptr: [*]const u8 = @ptrFromInt(base);
+ return ptr[0..size];
+}
+
+// Match call + add pattern
+// If matched, inst + len will be the start of the imm32
+pub fn matchPIC(inst: [*]const u8) ?u32 {
+ if (inst[0] != x86.Opcode.Op1.call) return null;
+ if (inst[5] == x86.Opcode.Op1.alumiw) {
+ const modrm = inst[6];
+ // mod must be 0b11 (register operand)
+ if ((modrm & 0b1100_0000) != 0b1100_0000) return null;
+ // reg/opcode must be 0b000 (ADD)
+ if ((modrm & 0b0011_1000) != 0b0000_0000) return null;
+
+ // rm should not be 0b100 (ESP)
+ // Although it's rare, compiler occasionally uses EBP for PIC
+ const rm = modrm & 0b0000_0111;
+ if (rm == 0b100) return null;
+ return 7;
+ } else if (inst[5] == x86.Opcode.Op1.addeaxi) {
+ return 6;
+ }
+ return null;
+}
+
+const GOT_pattern = mem.makePattern("E8 ?? ?? ?? ?? 05 ?? ?? ?? ?? 8D 80");
+
+pub fn findGOTAddr(module: []const u8) ?u32 {
+ if (mem.scanFirst(module, GOT_pattern)) |offset| {
+ const imm32 = mem.loadValue(u32, module.ptr + offset + 6);
+ return @intFromPtr(module.ptr + offset + 5) +% imm32;
+ }
+ return null;
+}