summaryrefslogtreecommitdiff
path: root/hook/src/mem.zig
diff options
context:
space:
mode:
Diffstat (limited to 'hook/src/mem.zig')
-rw-r--r--hook/src/mem.zig199
1 files changed, 199 insertions, 0 deletions
diff --git a/hook/src/mem.zig b/hook/src/mem.zig
new file mode 100644
index 0000000..62aec00
--- /dev/null
+++ b/hook/src/mem.zig
@@ -0,0 +1,199 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const testing = std.testing;
+
+const utils = @import("utils.zig");
+
+const isHex = utils.isHex;
+const makeHex = utils.makeHex;
+
+pub fn makePattern(comptime str: []const u8) []const ?u8 {
+ return comptime blk: {
+ @setEvalBranchQuota(10000);
+ var it = std.mem.splitSequence(u8, str, " ");
+ var pat: []const ?u8 = &.{};
+
+ while (it.next()) |byte| {
+ if (byte.len != 2) {
+ @compileError("Each byte should be 2 characters");
+ }
+ if (byte[0] == '?') {
+ if (byte[1] != '?') {
+ @compileError("The second question mark is missing");
+ }
+ pat = pat ++ .{null};
+ } else if (isHex(byte[0])) {
+ if (!isHex(byte[1])) {
+ @compileError("The second hex digit is missing");
+ }
+ const n = try std.fmt.parseInt(u8, byte, 16);
+ pat = pat ++ .{n};
+ } else {
+ @compileError("Only hex digits, spaces and question marks are allowed");
+ }
+ }
+ break :blk pat;
+ };
+}
+
+pub fn makePatterns(comptime arr: anytype) []const []const ?u8 {
+ return comptime blk: {
+ var patterns: []const []const ?u8 = &.{};
+ for (arr) |str| {
+ const pat: []const []const ?u8 = &.{makePattern(str)};
+ patterns = patterns ++ pat;
+ }
+ break :blk patterns;
+ };
+}
+
+pub fn scanFirst(mem: []const u8, pattern: []const ?u8) ?usize {
+ if (mem.len < pattern.len) {
+ return null;
+ }
+
+ var offset: usize = 0;
+ outer: while (offset < mem.len - pattern.len + 1) : (offset += 1) {
+ for (pattern, 0..) |byte, j| {
+ if (byte) |b| {
+ if (b != mem[offset + j]) {
+ continue :outer;
+ }
+ }
+ }
+ return offset;
+ }
+
+ return null;
+}
+
+pub fn scanUnique(mem: []const u8, pattern: []const ?u8) ?usize {
+ if (scanFirst(mem, pattern)) |offset| {
+ if (scanFirst(mem[offset + pattern.len ..], pattern) != null) {
+ return null;
+ }
+ return offset;
+ }
+
+ return null;
+}
+
+pub const MatchedPattern = struct {
+ index: usize,
+ ptr: [*]const u8,
+};
+
+pub fn scanAllPatterns(mem: []const u8, patterns: []const []const ?u8, data: *std.ArrayList(MatchedPattern)) !void {
+ for (patterns, 0..) |pattern, i| {
+ var base: usize = 0;
+ while (scanFirst(mem[base..], pattern)) |offset| {
+ try data.append(MatchedPattern{
+ .index = i,
+ .ptr = mem.ptr + base + offset,
+ });
+ base += offset + pattern.len;
+ }
+ }
+}
+
+pub fn scanUniquePatterns(mem: []const u8, patterns: []const []const ?u8) ?MatchedPattern {
+ var match: ?MatchedPattern = null;
+ for (patterns, 0..) |pattern, i| {
+ if (scanFirst(mem, pattern)) |offset| {
+ if (scanFirst(mem[offset + pattern.len ..], pattern) != null) {
+ return null;
+ }
+
+ if (match != null) {
+ return null;
+ }
+
+ match = .{
+ .index = i,
+ .ptr = mem.ptr + offset,
+ };
+ }
+ }
+
+ return match;
+}
+
+test "Scan first pattern" {
+ const mem = makeHex("F6 05 12 34 56 78 12");
+
+ // Match at the start
+ const test_pattern1 = makePattern("F6 05 12");
+ const result1 = scanFirst(mem, test_pattern1);
+ try testing.expect(result1 != null);
+ if (result1) |offset| {
+ try testing.expectEqual(0, offset);
+ }
+
+ // Match at the middle
+ const test_pattern2 = makePattern("12 34 56");
+ const result2 = scanFirst(mem, test_pattern2);
+ try testing.expect(result2 != null);
+ if (result2) |offset| {
+ try testing.expectEqual(2, offset);
+ }
+
+ // Match at the end
+ const test_pattern3 = makePattern("56 78 12");
+ const result3 = scanFirst(mem, test_pattern3);
+ try testing.expect(result3 != null);
+ if (result3) |offset| {
+ try testing.expectEqual(4, offset);
+ }
+}
+
+test "Scan unique patterns" {
+ const mem = makeHex("F6 05 12 34 56 78 12");
+ const test_patterns = makePatterns(.{
+ "00 00 ?? ?? 12",
+ "12 ?? 56",
+ "F6 05 00 34",
+ });
+
+ const result = scanUniquePatterns(mem, test_patterns);
+ try testing.expect(result != null);
+ if (result) |r| {
+ try testing.expectEqual(1, r.index);
+ try testing.expectEqual(mem.ptr + 2, r.ptr);
+ }
+}
+
+test "Scan unique patterns with multiple matches" {
+ const mem = makeHex("12 34 56 12 34 56 78 9A BC DE");
+
+ const test_patterns1 = makePatterns(.{
+ "12 34 56", // Non-unique match
+ });
+ try testing.expect(scanUniquePatterns(mem, test_patterns1) == null);
+
+ const test_patterns2 = makePatterns(.{
+ "12 ?? ?? 12", // Unique match
+ "9A BC DE", // Unique match
+ });
+ try testing.expect(scanUniquePatterns(mem, test_patterns2) == null);
+
+ const test_patterns3 = makePatterns(.{
+ "12 34 56", // Non-unique match
+ "9A BC DE", // Unique match
+ });
+ try testing.expect(scanUniquePatterns(mem, test_patterns3) == null);
+}
+
+pub fn loadValue(T: type, ptr: [*]const u8) T {
+ const val: *align(1) const T = @ptrCast(ptr);
+ return val.*;
+}
+
+pub fn setValue(T: type, ptr: [*]u8, value: T) void {
+ const val: *align(1) T = @ptrCast(ptr);
+ val.* = value;
+}
+
+test "Load value from memory" {
+ const mem = makeHex("E9 B1 9A 78 56"); // jmp
+ try testing.expectEqual(0x56789AB1, loadValue(u32, mem.ptr + 1));
+}