linux kernel hook

本文探讨了Linux内核的热修补技术,介绍了一种利用内核API进行函数替换的方法,通过修改函数入口跳转至新函数实现动态更新,避免系统重启。文中详细展示了使用stop_machine API和inline hook技术的具体实现过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

其实这里用的是 Linux 的一种热修补技术:调用内核 API,把旧函数的入口改写成跳转指令,跳转到新函数中执行。用到的 API 是 stop_machine,不过传给它的不只是一个回调函数,还有一个封装好的结构体(作为回调参数)。
stop_machine 及热修补(Ksplice)原理参考:
https://www.ibm.com/developerworks/cn/aix/library/au-spunix_ksplice/
源码地址:
https://github.com/haidragon/inl_hook

/*
 * inline hook usage example.
 */

#define KMSG_COMPONENT "HELLO"
#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/version.h>
#include <linux/kallsyms.h>
#include <linux/stop_machine.h>
#include <linux/stacktrace.h>
#include <asm/stacktrace.h>
#include <net/tcp.h>

#include "util.h"

/* variable */
/*
 * Function pointer for tcp_set_state(). hello_init() passes its address to
 * find_ksymbol(), which stores the kallsyms address of "tcp_set_state"
 * here. Presumably DECLARE_HOOK() below wires it up as the hook's "orig"
 * slot (confirm in util.h); if so, inl_sethook() overwrites it with the
 * trampoline that replays the function's displaced entry bytes.
 */
static void (*tcp_set_state_fn)(struct sock *sk, int state);

/* hook function */
static void my_tcp_set_state(struct sock *sk, int state);

/* Kernel symbols to resolve at load time (consumed by hello_init()). */
static struct symbol_ops hello_ops[] = {
    DECLARE_SYMBOL(&tcp_set_state_fn, "tcp_set_state"),
};

/*
 * Hooks to install via inl_sethook_ops(): redirect the function whose
 * address is held in tcp_set_state_fn to my_tcp_set_state().
 */
static struct hook_ops hello_hooks[] = {
    DECLARE_HOOK(&tcp_set_state_fn, my_tcp_set_state),
};

/* hook function */

/*
 * Map a TCP state number to a printable name for logging.
 *
 * The table covers states 0..TCP_CLOSING. Any value outside it yields
 * "Unknown" rather than reading past the end of the array (the previous
 * code indexed the table with the raw state number).
 */
static const char *
my_tcp_state_name(int state)
{
    static const char *const my_state_name[] = {
        "Unused","Established","Syn Sent","Syn Recv",
        "Fin Wait 1","Fin Wait 2","Time Wait", "Close",
        "Close Wait","Last ACK","Listen","Closing"
    };

    if (state < 0 ||
        (size_t)state >= sizeof(my_state_name) / sizeof(my_state_name[0]))
        return "Unknown";

    return my_state_name[state];
}

/*
 * Replacement for tcp_set_state(), installed through hello_hooks.
 *
 * Re-implements the state-change bookkeeping:
 *  - MIB counter updates;
 *  - unhash and port release on TCP_CLOSE;
 *  - the final sk->sk_state store.
 * It then logs the transition together with the socket's IPv4 4-tuple.
 *
 * NOTE(review): the copied logic must match the running kernel's own
 * tcp_set_state(); newer kernels do extra work there — confirm against the
 * target kernel version.
 * NOTE(review): %pI4 prints only the IPv4 address fields; for IPv6 sockets
 * the output reflects whatever inet_saddr/inet_daddr happen to hold.
 */
static void
my_tcp_set_state(struct sock *sk, int state)
{
    struct inet_sock *inet = inet_sk(sk);
    int oldstate = sk->sk_state;

    switch (state) {
    case TCP_ESTABLISHED:
        if (oldstate != TCP_ESTABLISHED)
            TCP_INC_STATS(sock_net(sk), TCP_MIB_CURRESTAB);
        break;

    case TCP_CLOSE:
        if (oldstate == TCP_CLOSE_WAIT || oldstate == TCP_ESTABLISHED)
            TCP_INC_STATS(sock_net(sk), TCP_MIB_ESTABRESETS);

        sk->sk_prot->unhash(sk);
        if (inet_csk(sk)->icsk_bind_hash &&
            !(sk->sk_userlocks & SOCK_BINDPORT_LOCK))
            inet_put_port(sk);
        /* fall through */
    default:
        if (oldstate == TCP_ESTABLISHED)
            TCP_DEC_STATS(sock_net(sk), TCP_MIB_CURRESTAB);
    }

    /* Change state AFTER socket is unhashed to avoid closed
     * socket sitting in hash tables.
     */
    sk->sk_state = state;

    /* Patch: log every state transition. */
    pr_info("TCP %pI4:%d -> %pI4:%d, State %s -> %s\n",
            &inet->inet_saddr, ntohs(inet->inet_sport),
            &inet->inet_daddr, ntohs(inet->inet_dport),
            my_tcp_state_name(oldstate), my_tcp_state_name(state));

#ifdef STATE_TRACE
    /* statename[] is not defined in this file; use the local bounded table. */
    SOCK_DEBUG(sk, "TCP sk=%p, State %s -> %s\n", sk,
               my_tcp_state_name(oldstate), my_tcp_state_name(state));
#endif
}

/*
 * Module init: resolve the kernel symbols in hello_ops, then install the
 * hooks in hello_hooks.
 *
 * Returns 0 on success. Returns -ENOENT when a symbol cannot be resolved,
 * and -EFAULT when the hooks could not be installed. Previously -1 was
 * returned, which userspace reports as -EPERM ("Operation not permitted").
 */
static int __init hello_init(void)
{
    if (!find_ksymbol(hello_ops, ARRAY_SIZE(hello_ops))) {
        pr_err("hello symbol table not find.\n");
        return -ENOENT;
    }

    if (!inl_sethook_ops(hello_hooks, ARRAY_SIZE(hello_hooks))) {
        pr_err("hijack hello functions fail.\n");
        return -EFAULT;
    }

    pr_info("hello loaded.\n");
    return 0;
}

/*
 * Module exit: remove every hook listed in hello_hooks and log the unload.
 * inl_unhook_ops() keeps retrying until removal succeeds.
 */
static void __exit hello_cleanup(void)
{
    const int nhooks = ARRAY_SIZE(hello_hooks);

    inl_unhook_ops(hello_hooks, nhooks);
    pr_info("hello unloaded.\n");
}

/* Module entry/exit registration. */
module_init(hello_init);
module_exit(hello_cleanup);
MODULE_LICENSE("GPL");
/*
 * kernel function inline hook.
 */
#define KMSG_COMPONENT "KINL_HOOK"
#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt

#include <linux/kernel.h>
#include <linux/kprobes.h>
#include <linux/memory.h>
#include <linux/vmalloc.h>

#include "udis86.h"
#include "inl_hook.h"

/* Size of the near "jmp rel32" written over a hooked function's entry. */
#define JMP_CODE_BYTES          5
/*
 * Room for displaced entry bytes and for the jump appended to a trampoline.
 * disassemble_skip() stops at the first instruction boundary at or beyond
 * JMP_CODE_BYTES, and an x86 instruction is at most 15 bytes, so at most
 * JMP_CODE_BYTES - 1 + 15 bytes are ever displaced.
 */
#define HOOK_MAX_CODE_BYTES     32
/* Input window handed to the disassembler. */
#define MAX_DISASSEMBLE_BYTES   1024

/* Bookkeeping for one installed inline hook. */
struct hook_item {
    void *orig_func;    /* patched entry point (after skip_jumps()) */
    void *hook_func;    /* replacement the patched entry jumps to */
    int stolen;         /* bytes of whole instructions displaced */
    /*
     * Copy of the displaced entry bytes, written back by hook_restore().
     * This is a byte buffer; it was previously declared as an array of
     * u8 pointers, which misdescribed its contents and over-allocated.
     */
    u8 orig_inst[HOOK_MAX_CODE_BYTES];

    /* Executable copy of the displaced bytes followed by a jump back. */
    u8 *trampoline;
    struct list_head list;
};

/* All currently installed hooks. */
LIST_HEAD(hook_list);

/*
 * Clear CR0.WP so kernel text, which is mapped read-only, can be written.
 *
 * Preemption is disabled first so that the task stays on the same CPU until
 * restore_wp() runs. Returns the CR0 value seen before the change; pass it
 * to restore_wp() to undo.
 *
 * NOTE(review): some newer x86 kernels pin CR0.WP inside write_cr0();
 * confirm this still takes effect on the target kernel.
 */
inline unsigned long disable_wp(void)
{
    unsigned long old_cr0;

    preempt_disable();
    barrier();

    old_cr0 = read_cr0();
    write_cr0(old_cr0 & ~X86_CR0_WP);

    return old_cr0;
}

/*
 * Write back the CR0 value returned by disable_wp() and re-enable
 * preemption. Must be paired with a preceding disable_wp().
 */
inline void restore_wp(unsigned long cr0)
{
    write_cr0(cr0);
    barrier();
    preempt_enable();
}

/*
 * Resolve a chain of jump stubs.
 *
 * Starting at @pcode, follow recognised unconditional jumps and return the
 * first address that is not such a jump:
 *   - e9 rel32 and eb rel8;
 *   - ff 25 indirect;
 *   - on x86_64 only, 48 ff 25 indirect.
 * If @pcode does not start with one of these, @pcode itself is returned.
 *
 * On CONFIG_X86_32 a leading "mov edi,edi" and a "55 48 89 e5" sequence are
 * skipped before the jump test. If no jump follows them, the original
 * @pcode is still returned.
 *
 * NOTE(review): there is no depth/cycle limit, so a self-referencing jump
 * such as "eb fe" recurses forever.
 */
static u8 *skip_jumps(u8 *pcode)
{
    u8 *orig_code = pcode;

#if defined(CONFIG_X86_32) || defined(CONFIG_X86_64)
#if defined(CONFIG_X86_32)
    //mov edi,edi: hot patch point
    if (pcode[0] == 0x8b && pcode[1] == 0xff) {
        pcode += 2;
    }

    // push rbp; mov rsp, rbp;
    // 55 48 89 e5
    // NOTE(review): 0x48 is a REX prefix, which exists only in 64-bit mode;
    // in 32-bit code it decodes as "dec eax", so this pattern is unlikely
    // to match a real x86_32 prologue ("55 89 e5").
    if (pcode[0] == 0x55 && pcode[1] == 0x48 && pcode[2] == 0x89 && pcode[3] == 0xe5) {
        pcode += 4;
    }
#endif

    if (pcode[0] == 0xff && pcode[1] == 0x25) {
#if defined(CONFIG_X86_32)
        // on x86 we have an absolute pointer...
        u8 *target = *(u8 **)&pcode[2];
        // ... that shows us an absolute pointer.
        return skip_jumps(*(u8 **)target);
#elif defined(CONFIG_X86_64)
        // on x64 we have a 32-bit offset...
        s32 offset = *(s32 *)&pcode[2];
        // ... that shows us an absolute pointer
        return skip_jumps(*(u8 **)(pcode + 6 + offset));
    } else if (pcode[0] == 0x48 && pcode[1] == 0xff && pcode[2] == 0x25) {
        // or we can have the same with a REX prefix
        s32 offset = *(s32 *)&pcode[3];
        // ... that shows us an absolute pointer
        return skip_jumps(*(u8 **)(pcode + 7 + offset));
#endif
    } else if (pcode[0] == 0xe9) {
        // here the behavior is identical, we have...
        // ...a 32-bit offset to the destination.
        return skip_jumps(pcode + 5 + *(s32 *)&pcode[1]);
    } else if (pcode[0] == 0xeb) {
        // and finally an 8-bit offset to the destination. The CPU
        // sign-extends rel8, so read it as s8: reading it as u8 sent
        // backward short jumps (0x80..0xff) forward by 128..255 bytes.
        return skip_jumps(pcode + 2 + *(s8 *)&pcode[1]);
    }
#else
#error unsupported platform
#endif

    return orig_code;
}

/*
 * Write an unconditional jump at @pcode that transfers control to @jumpto.
 * Returns the address just past the emitted bytes.
 *
 * A 5-byte "jmp rel32" (e9) is used when @jumpto is within 0x7fff0000
 * bytes of the end of that instruction. Otherwise an indirect "jmp [mem]"
 * (ff 25) is emitted, followed by its absolute target:
 *   x86_32: ff 25 <abs32 addr of next 4 bytes> <abs32 target>  (10 bytes)
 *   x86_64: ff 25 <rel32 = 0> <abs64 target>                  (14 bytes)
 *
 * NOTE(review): inl_sethook() only guarantees JMP_CODE_BYTES (5) of
 * displaced bytes at a patched entry, so the long form would overwrite
 * code past the stolen region if a hook were ever that far away.
 */
static u8 *emit_jump(u8 *pcode, u8 *jumpto)
{
#if defined(CONFIG_X86_32) || defined(CONFIG_X86_64)
    u8 *jumpfrom = pcode + 5;
    size_t diff = jumpfrom > jumpto ? jumpfrom - jumpto : jumpto - jumpfrom;

    /* diff is a size_t, so it must be printed with %zu. */
    pr_debug("emit_jumps from %p to %p, diff is %zu\n", jumpfrom, jumpto, diff);

    if (diff <= 0x7fff0000) {
        pcode[0] = 0xe9;
        pcode += 1;
        *((u32 *)pcode) = (u32)(jumpto - jumpfrom);
        pcode += sizeof(u32);
    } else {
        pcode[0] = 0xff;
        pcode[1] = 0x25;
        pcode += 2;
#if defined(CONFIG_X86_32)
        // on x86 we write an absolute address (just behind the instruction)
        *((u32 *)pcode) = (u32)(pcode + sizeof(u32));
        pcode += sizeof(u32);
        *((u32 *)pcode) = (u32)jumpto;
        pcode += sizeof(u32);
#elif defined(CONFIG_X86_64)
        // on x64 we write the relative address of the same location
        *((u32 *)pcode) = (u32)0;
        pcode += sizeof(u32);
        *((u64 *)pcode) = (u64)jumpto;
        pcode += sizeof(u64);
#endif
    }
#else
#error unsupported platform
#endif

    return pcode;
}

/*
 * Return the byte length of the shortest run of whole instructions at
 * @target that covers at least @min_len bytes.
 *
 * A result smaller than @min_len means the disassembler stopped early
 * (ud_disassemble() returned 0). The decoder mode now follows the build
 * architecture; it was previously hard-coded to 64-bit, which mis-decodes
 * CONFIG_X86_32 kernels.
 */
static u32 disassemble_skip(u8 *target, u32 min_len)
{
    ud_t u;
    u32 ret = 0;

    ud_init(&u);
    ud_set_input_buffer(&u, target, MAX_DISASSEMBLE_BYTES);
#if defined(CONFIG_X86_64)
    ud_set_mode(&u, 64);
#else
    ud_set_mode(&u, 32);
#endif
    ud_set_syntax(&u, UD_SYN_INTEL);

    while (ret < min_len && ud_disassemble(&u)) {
        ret += ud_insn_len(&u);
    }

    return ret;
}

/*
 * Allocate a zeroed hook_item plus an executable, zero-filled trampoline
 * buffer of @stolen + HOOK_MAX_CODE_BYTES bytes.
 * Returns NULL if either allocation fails. @target is currently unused.
 *
 * NOTE(review): this is reached from inl_sethook_callback(), which runs
 * under stop_machine(); vzalloc()/__vmalloc(GFP_KERNEL) may sleep there.
 * The three-argument __vmalloc() form is also tied to older kernels —
 * confirm both for the target version.
 */
static struct hook_item *trampoline_alloc(void *target, u32 stolen)
{
    const u32 tramp_size = stolen + HOOK_MAX_CODE_BYTES;
    struct hook_item *hi = vzalloc(sizeof(*hi));

    if (hi == NULL)
        return NULL;

    hi->trampoline = __vmalloc(tramp_size, GFP_KERNEL, PAGE_KERNEL_EXEC);
    if (!hi->trampoline) {
        vfree(hi);
        return NULL;
    }

    memset(hi->trampoline, 0, tramp_size);
    return hi;
}

/*
 * Look up the installed hook whose replacement function is @hook.
 * Returns the first match in hook_list, or NULL if there is none.
 */
static struct hook_item *trampoline_find(u8 *hook)
{
    struct hook_item *cur;

    list_for_each_entry(cur, &hook_list, list)
        if (cur->hook_func == hook)
            return cur;

    return NULL;
}

/*
 * Install the hook described by @item.
 *
 * Steps, in order:
 *  1. Record the target, the hook and the number of displaced bytes.
 *  2. Save the first @stolen bytes of @target for later restoration.
 *  3. Build the trampoline: copy those bytes, then append a jump back to
 *     @target + @stolen.
 *  4. With write protection lifted, overwrite @target's entry with a jump
 *     to @hook.
 *  5. Link @item into hook_list.
 *
 * Returns the trampoline, i.e. the address through which the original
 * behaviour can still be invoked.
 *
 * NOTE(review): the displaced bytes are copied verbatim; any
 * position-dependent instruction among them (relative branch, RIP-relative
 * operand) is not relocated — confirm hooked entries contain none.
 */
static u8 *post_hook(struct hook_item *item, void *target,
        void *hook, u32 stolen)
{
    u8 *entry = target;
    u8 *tramp = item->trampoline;
    unsigned long saved_cr0;

    item->orig_func = target;
    item->hook_func = hook;
    item->stolen = stolen;

    memmove(item->orig_inst, entry, stolen);
    memmove(tramp, entry, stolen);
    emit_jump(tramp + stolen, entry + stolen);

    saved_cr0 = disable_wp();
    emit_jump(entry, hook);
    restore_wp(saved_cr0);

    list_add(&item->list, &hook_list);

    return tramp;
}

/*
 * Undo post_hook():
 *  - write the saved entry bytes back over the patched function;
 *  - unlink @item from hook_list;
 *  - free the trampoline and @item.
 */
static void hook_restore(struct hook_item *item)
{
    unsigned long saved_cr0 = disable_wp();

    memmove(item->orig_func, item->orig_inst, item->stolen);
    restore_wp(saved_cr0);

    list_del(&item->list);
    vfree(item->trampoline);
    vfree(item);
}

int inl_within_trampoline(unsigned long address)
{
    long bytes;
    struct hook_item *item;
    unsigned long start, end;

    list_for_each_entry(item, &hook_list, list) {
        bytes = item->stolen + HOOK_MAX_CODE_BYTES;
        start = (unsigned long)item->trampoline;
        end = (unsigned long)item->trampoline + bytes;

        if (address >= start && address < end) {
            return -EBUSY;
        }
    }

    return 0;
}

int inl_sethook(void **orig, void *hook)
{
    u32 instr_len;
    struct hook_item *item;
    void *target = *orig;

    target = skip_jumps(target);

    pr_debug("Started on the job: %p / %p\n", target, hook);

    instr_len = disassemble_skip(target, JMP_CODE_BYTES);
    if (instr_len < JMP_CODE_BYTES) {
        pr_err("disassemble_skip invalid instruction length: %u\n",
                instr_len);
        return -1;
    }

    pr_debug("disassembly signals %d bytes.\n", instr_len);

    item = trampoline_alloc(target, instr_len);
    if (item == NULL) {
        pr_err("alloc trampoline fail, no memory.\n");
        return -ENOMEM;
    }

    *orig = post_hook(item, target, hook, instr_len);

    return 0;
}

/*
 * Remove the hook whose replacement function is @hook, restoring the
 * original entry bytes and freeing its trampoline.
 * Returns 0 on success, or -1 if no such hook is installed.
 */
int inl_unhook(void *hook)
{
    struct hook_item *found = trampoline_find(hook);

    if (!found) {
        pr_info("no find hook function: %p\n", hook);
        return -1;
    }

    hook_restore(found);
    return 0;
}
/*
 * inl_hook util.
 */
#define KMSG_COMPONENT "KINL_HOOK"
#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/kallsyms.h>
#include <linux/swap.h>
#include <linux/stop_machine.h>
#include <linux/stacktrace.h>
#include <asm/stacktrace.h>

#include "util.h"
#include "inl_hook.h"

/* Half-open address range [start, end) used for stack-address checks. */
struct instr_range {
    unsigned long start;
    unsigned long end;
};

/* Argument bundle handed to the stop_machine() callbacks below. */
struct hook_cbdata {
    struct hook_ops *ops;       /* hook table */
    int count;                  /* number of entries in @ops */
    struct instr_range ir;      /* caller's code range (unhook path only) */
};

/* Bytes after each original entry treated as "about to be patched". */
#define MAX_HOOK_CODE_BYTES     32
#define MAX_STACK_TRACE_DEPTH   64

/*
 * Scratch buffer for save_stack_trace_tsk(), shared by both safe-verify
 * routines in this file (each is called from a stop_machine() callback).
 * Made static: a file-scope global with a name as generic as "trace" has
 * no business in the module's global namespace.
 */
static unsigned long stack_entries[MAX_STACK_TRACE_DEPTH];
static struct stack_trace trace = {
    .max_entries    = ARRAY_SIZE(stack_entries),
    .entries    = &stack_entries[0],
};

/* True when @address lies in the half-open range [ir->start, ir->end). */
static inline bool
within_address(unsigned long address, struct instr_range *ir)
{
    if (address < ir->start)
        return false;
    return address < ir->end;
}

/*
 * Resolve each symbol in @ops[0..n-1] via kallsyms_lookup_name() and store
 * the address through the entry's addr slot.
 *
 * Stops at the first symbol that cannot be found, logs it and returns
 * false. Entries before it keep their resolved addresses. Returns true
 * when every symbol resolves.
 */
bool find_ksymbol(struct symbol_ops *ops, int n)
{
    int k;

    for (k = 0; k < n; k++) {
        struct symbol_ops *op = &ops[k];
        void **slot = op->addr;

        *slot = (void *) kallsyms_lookup_name(op->symbol);
        if (!*slot) {
            pr_err("not find %s.\n", op->symbol);
            return false;
        }
    }

    return true;
}

/*
 * Check whether it is safe to patch the entry of every function in @ops.
 *
 * For every thread, capture its stack trace into the shared `trace` buffer.
 * Then look for any saved address inside the first MAX_HOOK_CODE_BYTES
 * bytes of a function about to be hooked (the address currently stored in
 * *ops[j].orig). Such an address means a task could resume inside bytes
 * that are about to be overwritten.
 *
 * Returns 0 when no task is affected. Returns -EBUSY when a task's trace
 * hits one of those ranges, or when the trace filled the whole buffer.
 * inl_sethook_ops() retries on -EBUSY.
 *
 * NOTE(review): the window starts at *ops[j].orig, while inl_sethook()
 * patches skip_jumps(*orig); these differ when the entry is a jump stub.
 */
static int
inl_sethook_safe_verify(struct hook_ops *ops, int count)
{
    struct instr_range ir;
    struct task_struct *g, *t;
    int i, j;
    void *orig;
    int ret = 0;

    /* Check the stacks of all tasks. */
    do_each_thread(g, t) {

        trace.nr_entries = 0;
        save_stack_trace_tsk(t, &trace);
        /*
         * A full buffer may mean a truncated trace, so the task cannot be
         * proven safe. NOTE(review): a task that always has this many
         * frames makes every attempt fail, and inl_sethook_ops() retries
         * -EBUSY without limit.
         */
        if (trace.nr_entries >= trace.max_entries) {
            ret = -EBUSY;
            pr_err("more than %u trace entries!\n",
                   trace.max_entries);
            goto out;
        }

        for (i = 0; i < trace.nr_entries; i++) {
            /* ULONG_MAX is treated as an end-of-trace marker. */
            if (trace.entries[i] == ULONG_MAX)
                break;

            for (j = 0; j < count; j++) {
                orig = *(ops[j].orig);

                /* Window about to be overwritten: [orig, orig + 32). */
                ir.start = (unsigned long)orig;
                ir.end = (unsigned long)orig + MAX_HOOK_CODE_BYTES;

                /* goto rather than break: presumably do_each_thread is a nested loop. */
                if (within_address(trace.entries[i], &ir)) {
                    ret = -EBUSY;
                    goto out;
                }
            }
        }

    } while_each_thread(g, t);

out:
    return ret;
}

/* Called from stop_machine */
/*
 * Install every hook in the struct hook_cbdata passed as @data.
 *
 * Returns the safe-verify error (e.g. -EBUSY) untouched so that the caller
 * can retry.
 *
 * If installing hook i fails, the hooks 0..i-1 installed by this call are
 * removed again before -EFAULT is returned. Previously they were left live,
 * so a module whose init failed afterwards could be freed while kernel
 * code still jumped into it.
 *
 * NOTE(review): after rollback, the *orig slots of the undone entries
 * still hold the (now freed) trampoline address; callers must not use them
 * after a failure.
 */
static int
inl_sethook_callback(void *data)
{
    int i;
    int ret;
    struct hook_cbdata *cbdata = (struct hook_cbdata *)data;

    ret = inl_sethook_safe_verify(cbdata->ops, cbdata->count);
    if (ret != 0) {
        return ret;
    }

    for (i = 0; i < cbdata->count; i++) {
        if (inl_sethook((void **)cbdata->ops[i].orig, cbdata->ops[i].hook) < 0) {
            pr_err("sethook_ops hook %s fail.\n", cbdata->ops[i].name);
            goto rollback;
        }
    }

    return 0;

rollback:
    /* Undo the hooks installed so far, newest first. */
    while (--i >= 0)
        inl_unhook(cbdata->ops[i].hook);

    return -EFAULT;
}

/*
 * Install every hook in @ops[0..n-1] from within stop_machine().
 *
 * Whenever the callback reports -EBUSY (a task's stack touches code about
 * to be patched), the CPU is yielded and the whole attempt is repeated.
 * Returns true only when the callback reports success.
 *
 * NOTE(review): the retry loop has no upper bound, so a condition that
 * keeps producing -EBUSY makes this spin forever.
 */
bool inl_sethook_ops(struct hook_ops *ops, int n)
{
    struct hook_cbdata cbdata = {
        .ops = ops,
        .count = n,
    };
    int err;

    for (;;) {
        err = stop_machine(inl_sethook_callback, &cbdata, NULL);
        if (err != -EBUSY)
            break;

        yield();
        pr_info("kernel busy, retry again inl_sethook_ops.\n");
    }

    return err == 0;
}

/*
 * Check whether it is safe to remove the hooks and let the module go.
 *
 * Each thread's stack trace is scanned entry by entry.
 *  - Scanning of that task stops (break) at the first address inside @ir
 *    (the caller's own code range, supplied by inl_unhook_ops()) or inside
 *    this function itself (self_ir).
 *  - Before that point, any address inside an installed trampoline or
 *    inside this module's core makes the removal unsafe.
 *
 * Returns 0 when safe. Returns -EBUSY otherwise, in which case the
 * offending task's PID, comm and trace are logged.
 *
 * NOTE(review): self_ir ends at the address of label_unhook_verify_end
 * (GCC "labels as values"). The compiler is not obliged to place that
 * label after all of this function's code — confirm in the generated
 * object.
 */
static int
inl_unhook_safe_verify(struct instr_range *ir)
{
    unsigned long address;
    struct task_struct *g, *t;
    int i;
    int ret = 0;
    struct instr_range self_ir = {
        .start = (unsigned long)inl_unhook_safe_verify,
        .end = (unsigned long)&&label_unhook_verify_end,
    };

    /* Check the stacks of all tasks. */
    do_each_thread(g, t) {

        trace.nr_entries = 0;
        save_stack_trace_tsk(t, &trace);
        /* A full buffer may mean a truncated trace: treat as busy. */
        if (trace.nr_entries >= trace.max_entries) {
            ret = -EBUSY;
            pr_err("more than %u trace entries!\n",
                   trace.max_entries);
            goto out;
        }

        for (i = 0; i < trace.nr_entries; i++) {
            /* ULONG_MAX is treated as an end-of-trace marker. */
            if (trace.entries[i] == ULONG_MAX)
                break;

            address = trace.entries[i];
            // within cleanup method.
            /* Frames from here on are exempt for this task. */
            if (within_address(address, ir)
                || within_address(address, &self_ir)) {
                break;
            }

            /* Still executing in a trampoline or in this module: busy. */
            if (inl_within_trampoline(address)
                || within_module_core(address, THIS_MODULE)) {
                ret = -EBUSY;
                goto out;
            }
        }

    } while_each_thread(g, t);

out:
    /* On failure, t is the task that was being examined; dump its trace. */
    if (ret) {
        pr_err("PID: %d Comm: %.20s\n", t->pid, t->comm);
        for (i = 0; i < trace.nr_entries; i++) {
            if (trace.entries[i] == ULONG_MAX)
                break;
            pr_err("  [<%pK>] %pB\n",
                   (void *)trace.entries[i],
                   (void *)trace.entries[i]);
        }
    }

    return ret;

/* Marks the end of this function for self_ir; never reached. */
label_unhook_verify_end: ;
}

/* Called from stop_machine */
/*
 * Remove every hook in the struct hook_cbdata passed as @data, but only
 * after inl_unhook_safe_verify() confirms that no task is executing inside
 * the hooks. Returns that check's error (e.g. -EBUSY), or 0 once all hooks
 * are gone.
 */
static int
inl_unhook_callback(void *data)
{
    struct hook_cbdata *cb = data;
    int err;
    int k;

    err = inl_unhook_safe_verify(&cb->ir);
    if (err)
        return err;

    for (k = 0; k < cb->count; k++)
        inl_unhook(cb->ops[k].hook);

    return 0;
}

/*
 * Remove every hook in @ops[0..n-1] from within stop_machine().
 *
 * cbdata.ir spans this function's code, from its entry up to the
 * label_unhook_end label. The calling task's stack contains this frame
 * while stop_machine() runs, and this range keeps inl_unhook_safe_verify()
 * from reporting that task as busy.
 *
 * NOTE(review): every non-zero stop_machine() result is retried with no
 * upper bound, so a persistent failure spins here forever. The range also
 * relies on GCC placing label_unhook_end after the retry loop — confirm in
 * the generated code.
 */
void inl_unhook_ops(struct hook_ops *ops, int n)
{
    int ret;
    struct hook_cbdata cbdata = {
        .ops = ops,
        .count = n,
        .ir.start = (unsigned long)inl_unhook_ops,
        .ir.end = (unsigned long)&&label_unhook_end,
    };

try_again_unhook:

    ret = stop_machine(inl_unhook_callback, &cbdata, NULL);

    if (ret) {
        yield();
        pr_info("module busy, retry again inl_unhook_ops.\n");
        goto try_again_unhook;
    }

/* End marker for cbdata.ir. */
label_unhook_end: ;
}

转载于:https://blog.51cto.com/haidragon/2363215

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值