DOE mailbox redirect example - MarekBykowski/readme GitHub Wiki

// doe_redirect.c
#include <linux/module.h>
#include <linux/pci.h>
#include <linux/rcupdate.h>
#include <linux/jump_label.h>
#include <linux/debugfs.h>
#include <linux/uaccess.h>
#include <linux/slab.h>

/* ================================
 * Configuration (edit these)
 * ================================ */

/* Device whose DOE mailbox submissions are redirected (domain:bus:dev.fn). */
#define DEV_A "0000:01:00.0"
/* Device whose DOE mailbox receives the redirected submissions. */
#define DEV_B "0000:02:00.0"

/* ================================
 * Static key (zero overhead when off)
 * ================================ */

/*
 * Gates the redirect lookup in pci_doe_submit_task(); starts off and is
 * flipped by enable_redirect()/disable_redirect() from debugfs "enable".
 */
DEFINE_STATIC_KEY_FALSE(doe_redirect_key);

/* ================================
 * RCU redirect structure
 * ================================ */

/* One redirect rule: submissions addressed to @from are sent to @to. */
struct doe_redirect {
	struct pci_doe_mb *from;
	struct pci_doe_mb *to;
};

/*
 * Active rule. Published with rcu_assign_pointer() in init, read under
 * rcu_read_lock() in pci_doe_submit_task(), freed in exit.
 */
static struct doe_redirect __rcu *redir;

/* ================================
 * DebugFS
 * ================================ */

/* debugfs directory "doe_redirect" holding the "enable" control file. */
static struct dentry *dbg_dir;
/* Last state set through debugfs; reported back by enable_read(). */
static bool redirect_enabled;

/* ================================
 * Lookup helper
 * ================================ */

/*
 * find_doe_mb() - look up the DOE mailbox of the PCI device named by @bdf.
 * @bdf: "DDDD:BB:DD.F" string (hex domain, bus, device, function).
 *
 * The full address, including the PCI domain, is parsed. Domain 0 used to
 * be hard-coded, which silently selected the wrong device on any other
 * segment. Values that PCI_DEVFN() would truncate are rejected instead of
 * being masked onto a different device.
 *
 * NOTE(review): the device reference taken by pci_get_domain_bus_and_slot()
 * is never dropped (only the mailbox is returned), so the pci_dev stays
 * pinned for the module's lifetime. Decide who should own and release it.
 * NOTE(review): assumes struct pci_dev exposes a single, already
 * initialized ->doe_mb; verify against the target kernel.
 *
 * Return: the device's mailbox pointer, or NULL if @bdf is malformed or out
 * of range, or if no such device is present.
 */
static struct pci_doe_mb *find_doe_mb(const char *bdf)
{
	unsigned int domain, bus, slot, func;
	struct pci_dev *pdev;

	if (sscanf(bdf, "%x:%x:%x.%x", &domain, &bus, &slot, &func) != 4)
		return NULL;

	/* PCI_DEVFN() masks slot to 5 bits and func to 3 bits; refuse, don't alias. */
	if (bus > 0xff || slot > 0x1f || func > 0x7)
		return NULL;

	pdev = pci_get_domain_bus_and_slot(domain, bus, PCI_DEVFN(slot, func));
	if (!pdev)
		return NULL;

	/* This assumes DOE already initialized */
	return pdev->doe_mb;
}

/* ================================
 * Redirected submit wrapper
 * ================================ */

/* Original function pointer (saved once) */
static int (*orig_submit)(struct pci_doe_mb *,
			  struct pci_doe_task *);

int pci_doe_submit_task(struct pci_doe_mb *mb,
			struct pci_doe_task *task)
{
	struct doe_redirect *r;
	struct pci_doe_mb *target;

	if (!static_branch_unlikely(&doe_redirect_key))
		return orig_submit(mb, task);

	rcu_read_lock();
	r = rcu_dereference(redir);

	if (r && mb == r->from)
		target = r->to;
	else
		target = mb;

	rcu_read_unlock();

	return orig_submit(target, task);
}
EXPORT_SYMBOL_GPL(pci_doe_submit_task);

/* ================================
 * Enable/Disable Logic
 * ================================ */

/*
 * enable_redirect() - turn redirection on.
 *
 * Enables doe_redirect_key so pci_doe_submit_task() consults redir, then
 * records the new state for enable_read().
 */
static void enable_redirect(void)
{
	static_branch_enable(&doe_redirect_key);
	redirect_enabled = true;
}

/*
 * disable_redirect() - turn redirection off.
 *
 * The static key is cleared first. synchronize_rcu() then waits for
 * submitters still inside their rcu_read_lock() section of
 * pci_doe_submit_task(), and only after that is the state recorded for
 * enable_read().
 *
 * NOTE(review): a submitter that passed the static-branch test but has not
 * yet entered rcu_read_lock() is not covered by this wait. That is harmless
 * here only because redir is not freed on disable.
 */
static void disable_redirect(void)
{
	static_branch_disable(&doe_redirect_key);
	synchronize_rcu();
	redirect_enabled = false;
}

/*
 * enable_write() - write handler for debugfs "enable".
 *
 * Parses a short integer string (base auto-detected by kstrtoint()).
 * Non-zero enables the redirect; zero disables it. Writes of 8 bytes or more
 * and non-integer text are rejected.
 *
 * Return: @len on success, -EFAULT if the user buffer cannot be copied,
 * -EINVAL on oversized or unparsable input.
 */
static ssize_t enable_write(struct file *f,
			    const char __user *buf,
			    size_t len, loff_t *ppos)
{
	char input[8];
	int requested;

	/* One byte is reserved for the terminating NUL. */
	if (len >= sizeof(input))
		return -EINVAL;

	if (copy_from_user(input, buf, len) != 0)
		return -EFAULT;
	input[len] = '\0';

	if (kstrtoint(input, 0, &requested) != 0)
		return -EINVAL;

	if (requested == 0)
		disable_redirect();
	else
		enable_redirect();

	return len;
}

/*
 * enable_read() - read handler for debugfs "enable".
 *
 * Reports the redirect state as "0\n" or "1\n". simple_read_from_buffer()
 * advances *ppos, so a second read returns EOF.
 */
static ssize_t enable_read(struct file *f,
			   char __user *buf,
			   size_t len, loff_t *ppos)
{
	char state[4];
	int n = snprintf(state, sizeof(state), "%d\n", redirect_enabled ? 1 : 0);

	return simple_read_from_buffer(buf, len, ppos, state, n);
}

static const struct file_operations enable_fops = {
	.write = enable_write,
	.read  = enable_read,
};

/* ================================
 * Module Init
 * ================================ */

/*
 * doe_redirect_init() - resolve both mailboxes, publish the A->B rule and
 * create the debugfs control. Redirection stays off until "1" is written
 * to debugfs "enable".
 *
 * Return: 0 on success, -ENODEV if either mailbox is not found, -ENOMEM if
 * the rule cannot be allocated.
 */
static int __init doe_redirect_init(void)
{
	struct pci_doe_mb *a, *b;
	struct doe_redirect *r;

	pr_info("DOE redirect init\n");

	a = find_doe_mb(DEV_A);
	b = find_doe_mb(DEV_B);

	/*
	 * NOTE(review): the PCI device references taken inside find_doe_mb()
	 * are never released, neither on this error path nor at unload.
	 */
	if (!a || !b) {
		pr_err("DOE devices not found\n");
		return -ENODEV;
	}

	r = kzalloc(sizeof(*r), GFP_KERNEL);
	if (!r)
		return -ENOMEM;

	r->from = a;
	r->to   = b;

	/* Publish the rule for RCU readers in pci_doe_submit_task(). */
	rcu_assign_pointer(redir, r);

	/*
	 * Save original pointer once.
	 * NOTE(review): in this file pci_doe_submit_task names the wrapper
	 * defined above, not the PCI core routine, so orig_submit points back
	 * at the wrapper and forwarding through it recurses.
	 */
	orig_submit = pci_doe_submit_task;

	/* debugfs return values are not checked; the module loads regardless. */
	dbg_dir = debugfs_create_dir("doe_redirect", NULL);
	debugfs_create_file("enable", 0644,
			    dbg_dir, NULL, &enable_fops);

	pr_info("DOE redirect ready (echo 1 to enable)\n");
	return 0;
}

/* ================================
 * Module Exit
 * ================================ */

/*
 * doe_redirect_exit() - tear down the redirect.
 *
 * Order matters:
 *  1. Remove the debugfs file first, so no write can re-enable the static
 *     key after it has been switched off below.
 *  2. Switch the redirect off.
 *  3. Unpublish the rule and wait for a grace period before freeing it.
 *     A submitter that passed the static-branch test just before step 2
 *     then either sees NULL or sits in a read-side section that
 *     synchronize_rcu() waits for. It never touches freed memory.
 */
static void __exit doe_redirect_exit(void)
{
	struct doe_redirect *r;

	debugfs_remove_recursive(dbg_dir);
	dbg_dir = NULL;

	disable_redirect();

	r = rcu_dereference_protected(redir, 1);
	RCU_INIT_POINTER(redir, NULL);
	synchronize_rcu();
	kfree(r);

	pr_info("DOE redirect unloaded\n");
}

/* Entry/exit points and modinfo metadata. */
module_init(doe_redirect_init);
module_exit(doe_redirect_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("DOE MITM Example");
MODULE_DESCRIPTION("DOE A->B redirect using debugfs + RCU + static key");

Refined version (hook-based; requires a patched PCI DOE core)

// doe_redirect.c

#include <linux/module.h>
#include <linux/pci.h>
#include <linux/rcupdate.h>
#include <linux/slab.h>

/* Device whose DOE submissions are redirected (domain:bus:dev.fn, hex). */
#define DEV_A "0000:01:00.0"
/* Device on the other side of the redirect. */
#define DEV_B "0000:02:00.0"

/*
 * Symbols expected from a locally modified PCI DOE core (not mainline):
 *  - redirect_hook: set to doe_redirect_submit() in init and cleared in
 *    exit. Presumably the patched core's submit path calls it when it is
 *    non-NULL; confirm against that patch.
 *  - __pci_doe_submit_task: the unhooked submit routine this module
 *    forwards to.
 */
extern int (*redirect_hook)(struct pci_doe_mb *,
                            struct pci_doe_task *);

extern int __pci_doe_submit_task(struct pci_doe_mb *,
                                 struct pci_doe_task *);

/* ============================= */

/* Devices involved in the redirect; both references are held until exit. */
struct redirect_ctx {
	struct pci_dev *pdev_a;
	struct pci_dev *pdev_b;
};

/* Allocated in init before the hook is installed. */
static struct redirect_ctx *ctx;

/* ============================= */
/* BDF helper                    */
/* ============================= */

/*
 * get_pdev_from_bdf() - look up a PCI device by its "DDDD:BB:DD.F" address.
 * @bdf: hex domain, bus, device and function, e.g. "0000:01:00.0".
 *
 * Return: a referenced struct pci_dev (release with pci_dev_put()), or NULL
 * if @bdf is malformed, names an impossible device/function number, or the
 * device is not present.
 */
static struct pci_dev *get_pdev_from_bdf(const char *bdf)
{
	unsigned int dom, bus, dev, fn;

	if (sscanf(bdf, "%04x:%02x:%02x.%1x",
		   &dom, &bus, &dev, &fn) != 4)
		return NULL;

	/*
	 * PCI_DEVFN() silently masks the device to 5 bits and the function to
	 * 3 bits, so "..:20.0" would otherwise resolve to device 0.
	 */
	if (dev > 0x1f || fn > 0x7)
		return NULL;

	return pci_get_domain_bus_and_slot(dom, bus,
					   PCI_DEVFN(dev, fn));
}

/* ============================= */
/* Redirect hook                 */
/* ============================= */

/*
 * doe_redirect_submit() - redirect_hook implementation.
 * @mb:   mailbox the caller addressed.
 * @task: DOE task to submit.
 *
 * Tasks on a mailbox owned by device A or B are logged and resubmitted
 * through __pci_doe_submit_task(); all other tasks pass straight through.
 * ctx is dereferenced without a NULL check. Init assigns it before
 * installing this hook.
 *
 * NOTE(review): this never targets a mailbox of the *other* device. ctx
 * records only pci_dev pointers, and the ternaries choose between
 * task->doe_mb and @mb. Unless the patched core pre-loads task->doe_mb with
 * the peer's mailbox, nothing is actually redirected.
 * NOTE(review): mb->pdev requires struct pci_doe_mb to be visible here. In
 * mainline it is private to drivers/pci/doe.c.
 * NOTE(review): pr_info() fires on every matching task and may flood the
 * log. Consider pr_debug() or a ratelimited print.
 */
static int doe_redirect_submit(struct pci_doe_mb *mb,
			       struct pci_doe_task *task)
{
	struct pci_dev *pdev = mb->pdev;

	/* Redirect A -> B */
	if (pdev == ctx->pdev_a) {
		pr_info("DOE redirect: A -> B\n");
		/*
		 * Reuse the original submit with a different mailbox. Falls back
		 * to @mb while device B has no bound driver.
		 */
		return __pci_doe_submit_task(
			ctx->pdev_b->driver ?
			task->doe_mb : mb,  /* safety fallback */
			task);
	}

	/* Redirect B -> A (optional); same fallback, keyed on device A. */
	if (pdev == ctx->pdev_b) {
		pr_info("DOE redirect: B -> A\n");
		return __pci_doe_submit_task(
			ctx->pdev_a->driver ?
			task->doe_mb : mb,
			task);
	}

	/* Otherwise normal path */
	return __pci_doe_submit_task(mb, task);
}

/* ============================= */
/* Module init                   */
/* ============================= */

/*
 * doe_redirect_init() - look up both devices and install the redirect hook.
 *
 * module_exit is not run after a failed init, so every PCI device reference
 * taken here is dropped again on the failure paths.
 *
 * Return: 0 on success, -ENODEV if either device is missing, -ENOMEM if the
 * context cannot be allocated.
 */
static int __init doe_redirect_init(void)
{
	struct pci_dev *pdev_a;
	struct pci_dev *pdev_b;
	int ret;

	pr_info("DOE redirect init\n");

	pdev_a = get_pdev_from_bdf(DEV_A);
	pdev_b = get_pdev_from_bdf(DEV_B);

	if (!pdev_a || !pdev_b) {
		pr_err("Failed to locate PCI devices\n");
		ret = -ENODEV;
		goto err_put;
	}

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx) {
		ret = -ENOMEM;
		goto err_put;
	}

	ctx->pdev_a = pdev_a;
	ctx->pdev_b = pdev_b;

	/* Register redirect hook only once ctx is fully populated. */
	rcu_assign_pointer(redirect_hook, doe_redirect_submit);

	pr_info("DOE redirect active\n");
	return 0;

err_put:
	/* pci_dev_put() ignores NULL, so a single missing device is fine. */
	pci_dev_put(pdev_b);
	pci_dev_put(pdev_a);
	return ret;
}

/* ============================= */
/* Module exit                   */
/* ============================= */

/*
 * doe_redirect_exit() - uninstall the hook, then release the context.
 *
 * A grace period elapses between clearing the hook and freeing ctx.
 * NOTE(review): that only protects in-flight doe_redirect_submit() calls if
 * the patched core invokes redirect_hook under rcu_read_lock(). Confirm.
 */
static void __exit doe_redirect_exit(void)
{
	pr_info("DOE redirect exit\n");

	/* Disable hook and wait out current users. */
	rcu_assign_pointer(redirect_hook, NULL);
	synchronize_rcu();

	if (!ctx)
		return;

	/* pci_dev_put() is a no-op for NULL. */
	pci_dev_put(ctx->pdev_a);
	pci_dev_put(ctx->pdev_b);
	kfree(ctx);
}

/* Entry/exit points and modinfo metadata. */
module_init(doe_redirect_init);
module_exit(doe_redirect_exit);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("DOE mailbox redirect (MITM)");
MODULE_AUTHOR("Marek");

Third revision (debugfs-controlled; redirects between already-existing mailboxes)

// doe_redirect.c

#include <linux/module.h>
#include <linux/pci.h>
#include <linux/rcupdate.h>
#include <linux/slab.h>
#include <linux/debugfs.h>
#include <linux/uaccess.h>
#include <linux/xarray.h>

/* Devices whose DOE mailboxes are swapped (domain:bus:dev.fn, hex). */
#define DEV_A "0000:01:00.0"
#define DEV_B "0000:02:00.0"

/*
 * Provided by a locally modified PCI DOE core (not mainline):
 * redirect_hook is installed in init and cleared in exit, and
 * __pci_doe_submit_task is the unhooked submit routine.
 */
extern int (*redirect_hook)(struct pci_doe_mb *,
                            struct pci_doe_task *);

extern int __pci_doe_submit_task(struct pci_doe_mb *,
                                 struct pci_doe_task *);

/*
 * Forward declaration only. Members of struct cxl_dev_state cannot be
 * accessed in this file without the CXL driver's private header.
 */
struct cxl_dev_state; /* forward */

/* ============================================= */

/* Both devices and their mailboxes, resolved once in init. */
struct redirect_ctx {
        struct pci_dev *pdev_a;
        struct pci_dev *pdev_b;
        struct pci_doe_mb *mb_a;
        struct pci_doe_mb *mb_b;
};

static struct redirect_ctx *ctx;
/* debugfs directory "doe_redirect" holding the "enable" file. */
static struct dentry *dbg_dir;
/* Toggled via debugfs "enable"; false after load, so redirection starts off. */
static bool redirect_enabled;

/* ============================================= */

/*
 * get_pdev_from_bdf() - look up a PCI device by its "DDDD:BB:DD.F" address.
 * @bdf: hex domain, bus, device and function, e.g. "0000:01:00.0".
 *
 * Return: a referenced struct pci_dev (release with pci_dev_put()), or NULL
 * if @bdf is malformed, names an impossible device/function number, or the
 * device is not present.
 */
static struct pci_dev *get_pdev_from_bdf(const char *bdf)
{
        unsigned int dom, bus, dev, fn;

        if (sscanf(bdf, "%04x:%02x:%02x.%1x",
                   &dom, &bus, &dev, &fn) != 4)
                return NULL;

        /*
         * PCI_DEVFN() silently masks the device to 5 bits and the function
         * to 3 bits, so "..:20.0" would otherwise resolve to device 0.
         */
        if (dev > 0x1f || fn > 0x7)
                return NULL;

        return pci_get_domain_bus_and_slot(dom, bus,
                                           PCI_DEVFN(dev, fn));
}

/* ============================================= */
/* Resolve existing mailbox via CXL xarray       */
/* ============================================= */

static struct pci_doe_mb *
resolve_existing_mb(struct pci_dev *pdev)
{
        struct cxl_dev_state *cxlds;
        struct pci_doe_mb *mb;
        unsigned long index;

        cxlds = pci_get_drvdata(pdev);
        if (!cxlds)
                return NULL;

        rcu_read_lock();
        xa_for_each(&cxlds->doe_mbs, index, mb) {
                rcu_read_unlock();
                return mb; /* usually only one */
        }
        rcu_read_unlock();

        return NULL;
}

/* ============================================= */
/* Redirect hook                                 */
/* ============================================= */

/*
 * doe_redirect_submit() - redirect_hook implementation.
 * @mb:   mailbox the caller addressed.
 * @task: DOE task to submit.
 *
 * While redirection is enabled, tasks for mailbox A go to mailbox B and vice
 * versa. All other tasks, and all tasks while disabled, go to @mb unchanged.
 *
 * redirect_enabled is flipped from debugfs concurrently with submitters. It
 * is therefore read exactly once with READ_ONCE(), which makes the lockless
 * access explicit and keeps the routing decision consistent for this call.
 *
 * Return: result of __pci_doe_submit_task() on the chosen mailbox.
 */
static int doe_redirect_submit(struct pci_doe_mb *mb,
                               struct pci_doe_task *task)
{
        struct pci_doe_mb *target = mb;

        if (READ_ONCE(redirect_enabled)) {
                if (mb == ctx->mb_a)
                        target = ctx->mb_b;
                else if (mb == ctx->mb_b)
                        target = ctx->mb_a;
        }

        return __pci_doe_submit_task(target, task);
}

/* ============================================= */
/* DebugFS control                               */
/* ============================================= */

/*
 * enable_write() - write handler for debugfs "enable".
 *
 * Parses a short integer string (base auto-detected by kstrtoint()).
 * Non-zero enables the redirect; zero disables it.
 *
 * The flag is stored with WRITE_ONCE() because doe_redirect_submit() reads
 * it without a lock. The log message uses the local value rather than
 * re-reading the shared flag, which a concurrent writer may have changed.
 *
 * Return: @len on success, -EFAULT if the user buffer cannot be copied,
 * -EINVAL on oversized (8+ bytes) or unparsable input.
 */
static ssize_t enable_write(struct file *f,
                            const char __user *buf,
                            size_t len, loff_t *ppos)
{
        char kbuf[8];
        bool on;
        int val;

        if (len > sizeof(kbuf) - 1)
                return -EINVAL;

        if (copy_from_user(kbuf, buf, len))
                return -EFAULT;

        kbuf[len] = 0;

        if (kstrtoint(kbuf, 0, &val))
                return -EINVAL;

        on = !!val;
        WRITE_ONCE(redirect_enabled, on);

        pr_info("DOE redirect %s\n",
                on ? "enabled" : "disabled");

        return len;
}

/*
 * enable_read() - read handler for debugfs "enable".
 *
 * Reports the redirect state as "0\n" or "1\n". simple_read_from_buffer()
 * advances *ppos, so a second read returns EOF.
 */
static ssize_t enable_read(struct file *f,
                           char __user *buf,
                           size_t len, loff_t *ppos)
{
        char state[4];
        int n = snprintf(state, sizeof(state), "%d\n",
                         redirect_enabled ? 1 : 0);

        return simple_read_from_buffer(buf, len, ppos, state, n);
}

/*
 * debugfs "enable": reading reports the state, writing toggles it. The owner
 * field ties open files to this module's refcount.
 */
static const struct file_operations enable_fops = {
        .read  = enable_read,
        .write = enable_write,
        .owner = THIS_MODULE,
};

/* ============================================= */
/* Module init                                   */
/* ============================================= */

/*
 * doe_redirect_init() - resolve both devices and mailboxes, install the hook
 * and create the debugfs control. Redirection stays off until "1" is
 * written to debugfs "enable".
 *
 * module_exit is not called after a failed init. The failure paths
 * therefore drop the device references, free ctx and leave it NULL.
 *
 * Return: 0 on success, -ENOMEM if the context cannot be allocated,
 * -ENODEV if a device or its mailbox cannot be found.
 */
static int __init doe_redirect_init(void)
{
        int ret;

        pr_info("DOE redirect init\n");

        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
        if (!ctx)
                return -ENOMEM;

        ctx->pdev_a = get_pdev_from_bdf(DEV_A);
        ctx->pdev_b = get_pdev_from_bdf(DEV_B);

        if (!ctx->pdev_a || !ctx->pdev_b) {
                pr_err("Failed to locate PCI devices\n");
                ret = -ENODEV;
                goto err_free;
        }

        ctx->mb_a = resolve_existing_mb(ctx->pdev_a);
        ctx->mb_b = resolve_existing_mb(ctx->pdev_b);

        if (!ctx->mb_a || !ctx->mb_b) {
                pr_err("Failed to resolve DOE mailboxes\n");
                ret = -ENODEV;
                goto err_free;
        }

        /* Install the hook only once ctx is fully populated. */
        rcu_assign_pointer(redirect_hook,
                           doe_redirect_submit);

        dbg_dir = debugfs_create_dir("doe_redirect", NULL);
        debugfs_create_file("enable", 0644,
                            dbg_dir, NULL,
                            &enable_fops);

        pr_info("DOE redirect ready (echo 1 to enable)\n");
        return 0;

err_free:
        /* pci_dev_put() accepts NULL, so a partial lookup is fine. */
        pci_dev_put(ctx->pdev_b);
        pci_dev_put(ctx->pdev_a);
        kfree(ctx);
        ctx = NULL;
        return ret;
}

/* ============================================= */
/* Module exit                                   */
/* ============================================= */

/*
 * doe_redirect_exit() - uninstall the hook, remove debugfs, drop references.
 *
 * The hook is cleared and a grace period elapses before ctx is released.
 * NOTE(review): that only protects in-flight doe_redirect_submit() calls if
 * the patched core invokes redirect_hook under rcu_read_lock(). Confirm.
 */
static void __exit doe_redirect_exit(void)
{
        struct redirect_ctx *c = ctx;

        /* Stop new hook invocations, then wait out the ones in flight. */
        rcu_assign_pointer(redirect_hook, NULL);
        synchronize_rcu();

        debugfs_remove_recursive(dbg_dir);

        if (c) {
                /* pci_dev_put() is a no-op for NULL. */
                pci_dev_put(c->pdev_a);
                pci_dev_put(c->pdev_b);
        }
        kfree(c);

        pr_info("DOE redirect unloaded\n");
}

/* Entry/exit points and modinfo metadata. */
module_init(doe_redirect_init);
module_exit(doe_redirect_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Marek");
MODULE_DESCRIPTION("Strict DOE redirect without first-transfer leak");