@@ -8,6 +8,7 @@
#include <linux/pagemap.h>
#include <linux/genalloc.h>
#include <linux/syscalls.h>
+#include <linux/memblock.h>
#include <linux/pseudo_fs.h>
#include <linux/set_memory.h>
#include <linux/sched/signal.h>
@@ -35,6 +36,39 @@ struct secretmem_ctx {
unsigned int mode;
};
+struct secretmem_pool {
+ struct gen_pool *pool;
+ unsigned long reserved_size;
+ void *reserved;
+};
+
+static struct secretmem_pool secretmem_pool;
+
+static struct page *secretmem_alloc_huge_page(gfp_t gfp)
+{
+ struct gen_pool *pool = secretmem_pool.pool;
+ unsigned long addr = 0;
+ struct page *page = NULL;
+
+ if (pool) {
+ if (gen_pool_avail(pool) < PMD_SIZE)
+ return NULL;
+
+ addr = gen_pool_alloc(pool, PMD_SIZE);
+ if (!addr)
+ return NULL;
+
+ page = virt_to_page(addr);
+ } else {
+ page = alloc_pages(gfp, PMD_PAGE_ORDER);
+
+ if (page)
+ split_page(page, PMD_PAGE_ORDER);
+ }
+
+ return page;
+}
+
static int secretmem_pool_increase(struct secretmem_ctx *ctx, gfp_t gfp)
{
unsigned long nr_pages = (1 << PMD_PAGE_ORDER);
@@ -43,12 +77,11 @@ static int secretmem_pool_increase(struct secretmem_ctx *ctx, gfp_t gfp)
struct page *page;
int err;
- page = alloc_pages(gfp, PMD_PAGE_ORDER);
+ page = secretmem_alloc_huge_page(gfp);
if (!page)
return -ENOMEM;
addr = (unsigned long)page_address(page);
- split_page(page, PMD_PAGE_ORDER);
err = gen_pool_add(pool, addr, PMD_SIZE, NUMA_NO_NODE);
if (err) {
@@ -274,11 +307,13 @@ SYSCALL_DEFINE1(memfd_secret, unsigned long, flags)
return err;
}
-static void secretmem_cleanup_chunk(struct gen_pool *pool,
- struct gen_pool_chunk *chunk, void *data)
+static void secretmem_recycle_range(unsigned long start, unsigned long end)
+{
+ gen_pool_free(secretmem_pool.pool, start, PMD_SIZE);
+}
+
+static void secretmem_release_range(unsigned long start, unsigned long end)
{
- unsigned long start = chunk->start_addr;
- unsigned long end = chunk->end_addr;
unsigned long nr_pages, addr;
nr_pages = (end - start + 1) / PAGE_SIZE;
@@ -288,6 +323,18 @@ static void secretmem_cleanup_chunk(struct gen_pool *pool,
put_page(virt_to_page(addr));
}
+static void secretmem_cleanup_chunk(struct gen_pool *pool,
+ struct gen_pool_chunk *chunk, void *data)
+{
+ unsigned long start = chunk->start_addr;
+ unsigned long end = chunk->end_addr;
+
+ if (secretmem_pool.pool)
+ secretmem_recycle_range(start, end);
+ else
+ secretmem_release_range(start, end);
+}
+
static void secretmem_cleanup_pool(struct secretmem_ctx *ctx)
{
struct gen_pool *pool = ctx->pool;
@@ -327,14 +374,85 @@ static struct file_system_type secretmem_fs = {
.kill_sb = kill_anon_super,
};
+static int secretmem_reserved_mem_init(void)
+{
+ struct gen_pool *pool;
+ struct page *page;
+ void *addr;
+ int err;
+
+ if (!secretmem_pool.reserved)
+ return 0;
+
+ pool = gen_pool_create(PMD_SHIFT, NUMA_NO_NODE);
+ if (!pool)
+ return -ENOMEM;
+
+ err = gen_pool_add(pool, (unsigned long)secretmem_pool.reserved,
+ secretmem_pool.reserved_size, NUMA_NO_NODE);
+ if (err)
+ goto err_destroy_pool;
+
+ for (addr = secretmem_pool.reserved;
+ addr < secretmem_pool.reserved + secretmem_pool.reserved_size;
+ addr += PAGE_SIZE) {
+ page = virt_to_page(addr);
+ __ClearPageReserved(page);
+ set_page_count(page, 1);
+ }
+
+ secretmem_pool.pool = pool;
+ page = virt_to_page(secretmem_pool.reserved);
+ __kernel_map_pages(page, secretmem_pool.reserved_size / PAGE_SIZE, 0);
+ return 0;
+
+err_destroy_pool:
+ gen_pool_destroy(pool);
+ return err;
+}
+
static int secretmem_init(void)
{
- int ret = 0;
+ int ret;
+
+ ret = secretmem_reserved_mem_init();
+ if (ret)
+ return ret;
secretmem_mnt = kern_mount(&secretmem_fs);
- if (IS_ERR(secretmem_mnt))
+ if (IS_ERR(secretmem_mnt)) {
+ gen_pool_destroy(secretmem_pool.pool);
ret = PTR_ERR(secretmem_mnt);
+ }
return ret;
}
fs_initcall(secretmem_init);
+
+static int __init secretmem_setup(char *str)
+{
+ phys_addr_t align = PMD_SIZE;
+ unsigned long reserved_size;
+ void *reserved;
+
+ reserved_size = memparse(str, NULL);
+ if (!reserved_size)
+ return 0;
+
+ if (reserved_size * 2 > PUD_SIZE)
+ align = PUD_SIZE;
+
+ reserved = memblock_alloc(reserved_size, align);
+ if (!reserved) {
+ pr_err("failed to reserve %lu bytes\n", secretmem_pool.reserved_size);
+ return 0;
+ }
+
+ secretmem_pool.reserved_size = reserved_size;
+ secretmem_pool.reserved = reserved;
+
+ pr_info("reserved %luM\n", reserved_size >> 20);
+
+ return 1;
+}
+__setup("secretmem=", secretmem_setup);