diff options
Diffstat (limited to 'hook/src/utils.zig')
-rw-r--r-- | hook/src/utils.zig | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/hook/src/utils.zig b/hook/src/utils.zig new file mode 100644 index 0000000..57fdfec --- /dev/null +++ b/hook/src/utils.zig @@ -0,0 +1,198 @@ +const std = @import("std"); +const w = std.os.windows; +const builtin = @import("builtin"); + +const x86 = @import("x86.zig"); +const mem = @import("mem.zig"); + +const MODULEINFO = extern struct { + lpBaseOfDll: w.LPVOID, + SizeOfImage: w.DWORD, + EntryPoint: w.LPVOID, +}; + +extern "kernel32" fn FlushInstructionCache(hProcess: w.HANDLE, lpBaseAddress: w.LPCVOID, dwSize: w.SIZE_T) callconv(.winapi) w.BOOL; +extern "kernel32" fn GetModuleHandleW(lpModuleName: ?w.LPCWSTR) callconv(.winapi) w.HMODULE; +extern "kernel32" fn GetModuleInformation(hProcess: w.HANDLE, hModule: w.HMODULE, lpmodinfo: *MODULEINFO, cb: w.DWORD) callconv(.winapi) w.BOOL; +extern "kernel32" fn GetCurrentProcess() callconv(.winapi) w.HANDLE; + +pub inline fn isHex(c: u8) bool { + return (c >= '0' and c <= '9') or (c >= 'a' and c <= 'f') or (c >= 'A' and c <= 'F'); +} + +pub fn makeHex(comptime str: []const u8) []const u8 { + return comptime blk: { + @setEvalBranchQuota(10000); + var it = std.mem.splitSequence(u8, str, " "); + var pat: []const u8 = &.{}; + + while (it.next()) |byte| { + if (byte.len != 2) { + @compileError("Each byte should be 2 characters"); + } + if (isHex(byte[0])) { + if (!isHex(byte[1])) { + @compileError("The second hex digit is missing"); + } + const n = try std.fmt.parseInt(u8, byte, 16); + pat = pat ++ .{n}; + } else { + @compileError("Only hex digits are allowed"); + } + } + break :blk pat; + }; +} + +pub fn patchCode(addr: [*]u8, data: []const u8, restore_protect: u32) !void { + if (builtin.os.tag == .windows) { + var old_protect: w.DWORD = undefined; + + try w.VirtualProtect(addr, data.len, w.PAGE_EXECUTE_READWRITE, &old_protect); + @memcpy(addr, data); + + _ = FlushInstructionCache(GetCurrentProcess(), addr, data.len); + try w.VirtualProtect(addr, data.len, old_protect, &old_protect); + } else { + const page_size = std.heap.page_size_min; + const addr_int = @intFromPtr(addr); + const page_start = addr_int & ~(page_size - 1); + const page_end = addr_int + data.len; + const page_len = (page_end - page_start + page_size - 1) & ~(page_size - 1); + + const prot_all = 0b111; // rwx + + if (std.c.mprotect(@ptrFromInt(page_start), page_len, prot_all) != 0) + return error.MProtectWritable; + + @memcpy(addr, data); + + if (std.c.mprotect(@ptrFromInt(page_start), page_len, restore_protect) != 0) + return error.MProtectRestore; + } +} + +// Windows: Return entire module memory +// Linux: Return the code segment memory +pub fn getModule(comptime module_name: []const u8) ?[]const u8 { + return switch (builtin.os.tag) { + .windows => getModuleWindows(module_name), + .linux => getModuleLinux(module_name, 0b101) catch return null, + else => @compileError("getModule is not available for this target"), + }; +} + +// Return the entire module memory +pub fn getEntireModule(comptime module_name: []const u8) ?[]const u8 { + return switch (builtin.os.tag) { + .windows => getModuleWindows(module_name), + .linux => getModuleLinux(module_name, 0) catch return null, + else => @compileError("getModule is not available for this target"), + }; +} + +fn getModuleWindows(comptime module_name: []const u8) ?[]const u8 { + const dll_name = module_name ++ ".dll"; + const path_w = std.unicode.utf8ToUtf16LeStringLiteral(dll_name); + const dll = w.GetModuleHandleW(path_w) orelse return null; + var info: w.MODULEINFO = undefined; + if (GetModuleInformation(GetCurrentProcess(), dll, &info, @sizeOf(MODULEINFO)) == 0) { + return null; + } + const module: [*]const u8 = @ptrCast(dll); + return module[0..info.SizeOfImage]; +} + +fn getModuleLinux(comptime module_name: []const u8, permission: u32) !?[]const u8 { + const file_name = module_name ++ ".so"; + + const allocator = std.heap.page_allocator; + var file = try std.fs.openFileAbsolute("/proc/self/maps", .{ .mode = .read_only }); + defer file.close(); + var reader = file.reader(); + + var base: usize = 0; + var end: usize = 0; + var found = false; + + while (try reader.readUntilDelimiterOrEofAlloc(allocator, '\n', 4096)) |line| { + defer allocator.free(line); + + // Example format: + // de228000-de229000 r--p 00000000 00:29 1026008 /usr/lib/libstdc++.so.6.0.33 + + if (!std.mem.endsWith(u8, line, file_name)) { + if (found) break; + continue; + } + + const pos = line.len - file_name.len; + if (line[pos - 1] != '/' and line[pos - 1] != ' ') continue; + + const dash = std.mem.indexOfScalar(u8, line, '-') orelse continue; + const space = std.mem.indexOfScalarPos(u8, line, dash + 1, ' ') orelse continue; + + const perms_start = space + 1; + if (line.len < perms_start + 4) continue; + const read = line[perms_start]; + const write = line[perms_start + 1]; + const exec = line[perms_start + 2]; + if (permission & 0b001 != 0 and read == '-') continue; + if (permission & 0b010 != 0 and write == '-') continue; + if (permission & 0b100 != 0 and exec == '-') continue; + + const start_hex = line[0..dash]; + const end_hex = line[dash + 1 .. space]; + + const start_addr = try std.fmt.parseInt(usize, start_hex, 16); + const end_addr = try std.fmt.parseInt(usize, end_hex, 16); + + if (!found) { + base = start_addr; + end = end_addr; + found = true; + } else if (start_addr == end) { + end = end_addr; + } else { + break; + } + } + + if (!found) return null; + + const size = end - base; + const ptr: [*]const u8 = @ptrFromInt(base); + return ptr[0..size]; +} + +// Match call + add pattern +// If matched, inst + len will be the start of the imm32 +pub fn matchPIC(inst: [*]const u8) ?u32 { + if (inst[0] != x86.Opcode.Op1.call) return null; + if (inst[5] == x86.Opcode.Op1.alumiw) { + const modrm = inst[6]; + // mod must be 0b11 (register operand) + if ((modrm & 0b1100_0000) != 0b1100_0000) return null; + // reg/opcode must be 0b000 (ADD) + if ((modrm & 0b0011_1000) != 0b0000_0000) return null; + + // rm should not be 0b100 (ESP) + // Although it's rare, compiler occasionally uses EBP for PIC + const rm = modrm & 0b0000_0111; + if (rm == 0b100) return null; + return 7; + } else if (inst[5] == x86.Opcode.Op1.addeaxi) { + return 6; + } + return null; +} + +const GOT_pattern = mem.makePattern("E8 ?? ?? ?? ?? 05 ?? ?? ?? ?? 8D 80"); + +pub fn findGOTAddr(module: []const u8) ?u32 { + if (mem.scanFirst(module, GOT_pattern)) |offset| { + const imm32 = mem.loadValue(u32, module.ptr + offset + 6); + return @intFromPtr(module.ptr + offset + 5) +% imm32; + } + return null; +} |