diff mbox series

[v6,1/3] mm: add new api to enable ksm per process

Message ID 20230412031648.2206875-2-shr@devkernel.io (mailing list archive)
State New
Headers show
Series mm: process/cgroup ksm support | expand

Commit Message

Stefan Roesch April 12, 2023, 3:16 a.m. UTC
So far KSM can only be enabled by calling madvise for memory regions.  To
be able to use KSM for more workloads, KSM needs to have the ability to be
enabled / disabled at the process / cgroup level.

1. New options for prctl system command

   This patch series adds two new options to the prctl system call.
   The first one allows to enable KSM at the process level and the second
   one to query the setting.

   The setting will be inherited by child processes.

   With the above setting, KSM can be enabled for the seed process of a
   cgroup and all processes in the cgroup will inherit the setting.

2. Changes to KSM processing

   When KSM is enabled at the process level, the KSM code will iterate
   over all the VMA's and enable KSM for the eligible VMA's.

   When forking a process that has KSM enabled, the setting will be
   inherited by the new child process.

  1) Introduce new MMF_VM_MERGE_ANY flag

     This introduces the new flag MMF_VM_MERGE_ANY flag.  When this flag
     is set, kernel samepage merging (ksm) gets enabled for all vma's of a
     process.

  2) Setting VM_MERGEABLE on VMA creation

     When a VMA is created, if the MMF_VM_MERGE_ANY flag is set, the
     VM_MERGEABLE flag will be set for this VMA.

  3) add flag to __ksm_enter

     This change adds the flag parameter to __ksm_enter.  This allows to
     distinguish if ksm was called by prctl or madvise.

  4) add flag to __ksm_exit call

     This adds the flag parameter to the __ksm_exit() call.  This allows
     to distinguish if this call is for an prctl or madvise invocation.

  5) support disabling of ksm for a process

     This adds the ability to disable ksm for a process if ksm has been
     enabled for the process.

  6) add new prctl option to get and set ksm for a process

     This adds two new options to the prctl system call
     - enable ksm for all vmas of a process (if the vmas support it).
     - query if ksm has been enabled for a process.

3. Disabling MMF_VM_MERGE_ANY for storage keys in s390

   In the s390 architecture when storage keys are used, the
   MMF_VM_MERGE_ANY will be disabled.

Signed-off-by: Stefan Roesch <shr@devkernel.io>
Cc: David Hildenbrand <david@redhat.com>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Michal Hocko <mhocko@suse.com>
Cc: Rik van Riel <riel@surriel.com>
Cc: Bagas Sanjaya <bagasdotme@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
---
 arch/s390/mm/gmap.c            |   1 +
 include/linux/ksm.h            |  23 +++++--
 include/linux/sched/coredump.h |   1 +
 include/uapi/linux/prctl.h     |   2 +
 kernel/fork.c                  |   1 +
 kernel/sys.c                   |  23 +++++++
 mm/ksm.c                       | 111 ++++++++++++++++++++++++++-------
 mm/mmap.c                      |   7 +++
 8 files changed, 142 insertions(+), 27 deletions(-)

Comments

Matthew Wilcox April 12, 2023, 1:20 p.m. UTC | #1
On Tue, Apr 11, 2023 at 08:16:46PM -0700, Stefan Roesch wrote:
>  	case PR_SET_VMA:
>  		error = prctl_set_vma(arg2, arg3, arg4, arg5);
>  		break;
> +#ifdef CONFIG_KSM
> +	case PR_SET_MEMORY_MERGE:
> +		if (mmap_write_lock_killable(me->mm))
> +			return -EINTR;
> +
> +		if (arg2) {
> +			int err = ksm_add_mm(me->mm);
> +
> +			if (!err)
> +				ksm_add_vmas(me->mm);

in the last version of this patch, you reported the error.  Now you
swallow the error.  I have no idea which is correct, but you've
changed the behaviour without explaining it, so I assume it's wrong.

> +		} else {
> +			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
> +		}
> +		mmap_write_unlock(me->mm);
> +		break;
> +	case PR_GET_MEMORY_MERGE:
> +		if (arg2 || arg3 || arg4 || arg5)
> +			return -EINVAL;
> +
> +		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
> +		break;

Why do we need a GET?  Just for symmetry, or is there an actual need for
it?
David Hildenbrand April 12, 2023, 3:40 p.m. UTC | #2
[...]

Thanks for giving mu sugegstions a churn. I think we can further 
improve/simplify some things. I added some comments, but might have more 
regarding MMF_VM_MERGE_ANY / MMF_VM_MERGEABLE.

[I'll try reowkring your patch after I send this mail to play with some 
simplifications]

>   arch/s390/mm/gmap.c            |   1 +
>   include/linux/ksm.h            |  23 +++++--
>   include/linux/sched/coredump.h |   1 +
>   include/uapi/linux/prctl.h     |   2 +
>   kernel/fork.c                  |   1 +
>   kernel/sys.c                   |  23 +++++++
>   mm/ksm.c                       | 111 ++++++++++++++++++++++++++-------
>   mm/mmap.c                      |   7 +++
>   8 files changed, 142 insertions(+), 27 deletions(-)
> 
> diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
> index 5a716bdcba05..9d85e5589474 100644
> --- a/arch/s390/mm/gmap.c
> +++ b/arch/s390/mm/gmap.c
> @@ -2591,6 +2591,7 @@ int gmap_mark_unmergeable(void)
>   	int ret;
>   	VMA_ITERATOR(vmi, mm, 0);
>   
> +	clear_bit(MMF_VM_MERGE_ANY, &mm->flags);

Okay, that should keep the existing mechanism working. (but users can 
still mess it up)

Might be worth a comment

/*
  * Make sure to disable KSM (if enabled for the whole process or
  * individual VMAs). Note that nothing currently hinders user space
  * from re-enabling it.
  */

>   	for_each_vma(vmi, vma) {
>   		/* Copy vm_flags to avoid partial modifications in ksm_madvise */
>   		vm_flags = vma->vm_flags;
> diff --git a/include/linux/ksm.h b/include/linux/ksm.h
> index 7e232ba59b86..f24f9faf1561 100644
> --- a/include/linux/ksm.h
> +++ b/include/linux/ksm.h
> @@ -18,20 +18,29 @@
>   #ifdef CONFIG_KSM
>   int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>   		unsigned long end, int advice, unsigned long *vm_flags);
> -int __ksm_enter(struct mm_struct *mm);
> -void __ksm_exit(struct mm_struct *mm);
> +
> +int ksm_add_mm(struct mm_struct *mm);
> +void ksm_add_vma(struct vm_area_struct *vma);
> +void ksm_add_vmas(struct mm_struct *mm);
> +
> +int __ksm_enter(struct mm_struct *mm, int flag);
> +void __ksm_exit(struct mm_struct *mm, int flag);
>   
>   static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
>   {
> +	if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
> +		return ksm_add_mm(mm);

ksm_fork() runs before copying any VMAs. Copying the bit should be 
sufficient.

Would it be possible to rework to something like:

if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
	set_bit(MMF_VM_MERGE_ANY, &mm->flags)
if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
	return __ksm_enter(mm);

work? IOW, not exporting ksm_add_mm() and not passing a flag to 
__ksm_enter() -- it would simply set MMF_VM_MERGEABLE ?


I rememebr proposing that enabling MMF_VM_MERGE_ANY would simply enable 
MMF_VM_MERGEABLE.

>   	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
> -		return __ksm_enter(mm);
> +		return __ksm_enter(mm, MMF_VM_MERGEABLE);
>   	return 0;
>   }
>   
>   static inline void ksm_exit(struct mm_struct *mm)
>   {
> -	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
> -		__ksm_exit(mm);
> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
> +		__ksm_exit(mm, MMF_VM_MERGE_ANY);
> +	else if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
> +		__ksm_exit(mm, MMF_VM_MERGEABLE);

Can we do

if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
	__ksm_exit(mm);

And simply let __ksm_exit() clear both bits?

>   }
>   
>   /*
> @@ -53,6 +62,10 @@ void folio_migrate_ksm(struct folio *newfolio, struct folio *folio);
>   
>   #else  /* !CONFIG_KSM */
>   

[...]

>   #endif /* _LINUX_SCHED_COREDUMP_H */
> diff --git a/include/uapi/linux/prctl.h b/include/uapi/linux/prctl.h
> index 1312a137f7fb..759b3f53e53f 100644
> --- a/include/uapi/linux/prctl.h
> +++ b/include/uapi/linux/prctl.h
> @@ -290,4 +290,6 @@ struct prctl_mm_map {
>   #define PR_SET_VMA		0x53564d41
>   # define PR_SET_VMA_ANON_NAME		0
>   
> +#define PR_SET_MEMORY_MERGE		67
> +#define PR_GET_MEMORY_MERGE		68
>   #endif /* _LINUX_PRCTL_H */
> diff --git a/kernel/fork.c b/kernel/fork.c
> index f68954d05e89..1520697cf6c7 100644
> --- a/kernel/fork.c
> +++ b/kernel/fork.c
> @@ -686,6 +686,7 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
>   		if (vma_iter_bulk_store(&vmi, tmp))
>   			goto fail_nomem_vmi_store;
>   
> +		ksm_add_vma(tmp);

Is this really required? The relevant VMAs should have VM_MERGEABLE set.

>   		mm->map_count++;
>   		if (!(tmp->vm_flags & VM_WIPEONFORK))
>   			retval = copy_page_range(tmp, mpnt);
> diff --git a/kernel/sys.c b/kernel/sys.c
> index 495cd87d9bf4..9bba163d2d04 100644
> --- a/kernel/sys.c
> +++ b/kernel/sys.c
> @@ -15,6 +15,7 @@
>   #include <linux/highuid.h>
>   #include <linux/fs.h>
>   #include <linux/kmod.h>
> +#include <linux/ksm.h>
>   #include <linux/perf_event.h>
>   #include <linux/resource.h>
>   #include <linux/kernel.h>
> @@ -2661,6 +2662,28 @@ SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3,
>   	case PR_SET_VMA:
>   		error = prctl_set_vma(arg2, arg3, arg4, arg5);
>   		break;
> +#ifdef CONFIG_KSM
> +	case PR_SET_MEMORY_MERGE:
> +		if (mmap_write_lock_killable(me->mm))
> +			return -EINTR;
> +
> +		if (arg2) {
> +			int err = ksm_add_mm(me->mm);
> +
> +			if (!err)
> +				ksm_add_vmas(me->mm);
> +		} else {
> +			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);

Okay, so disabling doesn't actually unshare anything.

> +		}
> +		mmap_write_unlock(me->mm);
> +		break;
> +	case PR_GET_MEMORY_MERGE:
> +		if (arg2 || arg3 || arg4 || arg5)
> +			return -EINVAL;
> +
> +		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
> +		break;
> +#endif
>   	default:
>   		error = -EINVAL;
>   		break;
> diff --git a/mm/ksm.c b/mm/ksm.c
> index d7bd28199f6c..ab95ae0f9def 100644
> --- a/mm/ksm.c
> +++ b/mm/ksm.c
> @@ -534,10 +534,33 @@ static int break_ksm(struct vm_area_struct *vma, unsigned long addr,
>   	return (ret & VM_FAULT_OOM) ? -ENOMEM : 0;
>   }
>   
> +static bool vma_ksm_compatible(struct vm_area_struct *vma)
> +{
> +	if (vma->vm_flags & (VM_SHARED  | VM_MAYSHARE   | VM_PFNMAP  |
> +			     VM_IO      | VM_DONTEXPAND | VM_HUGETLB |
> +			     VM_MIXEDMAP))
> +		return false;		/* just ignore the advice */
> +
> +	if (vma_is_dax(vma))
> +		return false;
> +
> +#ifdef VM_SAO
> +	if (vma->vm_flags & VM_SAO)
> +		return false;
> +#endif
> +#ifdef VM_SPARC_ADI
> +	if (vma->vm_flags & VM_SPARC_ADI)
> +		return false;
> +#endif
> +
> +	return true;
> +}
> +
>   static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
>   		unsigned long addr)
>   {
>   	struct vm_area_struct *vma;
> +

unrelated change

>   	if (ksm_test_exit(mm))
>   		return NULL;
>   	vma = vma_lookup(mm, addr);
> @@ -1065,6 +1088,7 @@ static int unmerge_and_remove_all_rmap_items(void)
>   
>   			mm_slot_free(mm_slot_cache, mm_slot);
>   			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
> +			clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>   			mmdrop(mm);
>   		} else
>   			spin_unlock(&ksm_mmlist_lock);
> @@ -2495,6 +2519,7 @@ static struct ksm_rmap_item *scan_get_next_rmap_item(struct page **page)
>   
>   		mm_slot_free(mm_slot_cache, mm_slot);
>   		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
> +		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>   		mmap_read_unlock(mm);
>   		mmdrop(mm);
>   	} else {
> @@ -2571,6 +2596,63 @@ static int ksm_scan_thread(void *nothing)
>   	return 0;
>   }
>   
> +static void __ksm_add_vma(struct vm_area_struct *vma)
> +{
> +	unsigned long vm_flags = vma->vm_flags;
> +
> +	if (vm_flags & VM_MERGEABLE)
> +		return;
> +
> +	if (vma_ksm_compatible(vma)) {
> +		vm_flags |= VM_MERGEABLE;
> +		vm_flags_reset(vma, vm_flags);
> +	}
> +}
> +
> +/**
> + * ksm_add_vma - Mark vma as mergeable

"if compatible"

> + *
> + * @vma:  Pointer to vma
> + */
> +void ksm_add_vma(struct vm_area_struct *vma)
> +{
> +	struct mm_struct *mm = vma->vm_mm;
> +
> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
> +		__ksm_add_vma(vma);
> +}
> +
> +/**
> + * ksm_add_vmas - Mark all vma's of a process as mergeable
> + *
> + * @mm:  Pointer to mm
> + */
> +void ksm_add_vmas(struct mm_struct *mm)

I'd suggest calling this

> +{
> +	struct vm_area_struct *vma;
> +
> +	VMA_ITERATOR(vmi, mm, 0);
> +	for_each_vma(vmi, vma)
> +		__ksm_add_vma(vma);
> +}
> +
> +/**
> + * ksm_add_mm - Add mm to mm ksm list
> + *
> + * @mm:  Pointer to mm
> + *
> + * Returns 0 on success, otherwise error code
> + */
> +int ksm_add_mm(struct mm_struct *mm)
> +{
> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
> +		return -EINVAL;
> +	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
> +		return -EINVAL;
> +
> +	return __ksm_enter(mm, MMF_VM_MERGE_ANY);
> +}
> +
>   int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>   		unsigned long end, int advice, unsigned long *vm_flags)
>   {
> @@ -2579,28 +2661,13 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>   
>   	switch (advice) {
>   	case MADV_MERGEABLE:
> -		/*
> -		 * Be somewhat over-protective for now!
> -		 */
> -		if (*vm_flags & (VM_MERGEABLE | VM_SHARED  | VM_MAYSHARE   |
> -				 VM_PFNMAP    | VM_IO      | VM_DONTEXPAND |
> -				 VM_HUGETLB | VM_MIXEDMAP))
> -			return 0;		/* just ignore the advice */
> -
> -		if (vma_is_dax(vma))
> +		if (vma->vm_flags & VM_MERGEABLE)
>   			return 0;
> -
> -#ifdef VM_SAO
> -		if (*vm_flags & VM_SAO)
> +		if (!vma_ksm_compatible(vma))
>   			return 0;
> -#endif
> -#ifdef VM_SPARC_ADI
> -		if (*vm_flags & VM_SPARC_ADI)
> -			return 0;
> -#endif
>   
>   		if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
> -			err = __ksm_enter(mm);
> +			err = __ksm_enter(mm, MMF_VM_MERGEABLE);
>   			if (err)
>   				return err;
>   		}
> @@ -2626,7 +2693,7 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>   }
>   EXPORT_SYMBOL_GPL(ksm_madvise);
>   
> -int __ksm_enter(struct mm_struct *mm)
> +int __ksm_enter(struct mm_struct *mm, int flag)
>   {
>   	struct ksm_mm_slot *mm_slot;
>   	struct mm_slot *slot;
> @@ -2659,7 +2726,7 @@ int __ksm_enter(struct mm_struct *mm)
>   		list_add_tail(&slot->mm_node, &ksm_scan.mm_slot->slot.mm_node);
>   	spin_unlock(&ksm_mmlist_lock);
>   
> -	set_bit(MMF_VM_MERGEABLE, &mm->flags);
> +	set_bit(flag, &mm->flags);
>   	mmgrab(mm);
>   
>   	if (needs_wakeup)
> @@ -2668,7 +2735,7 @@ int __ksm_enter(struct mm_struct *mm)
>   	return 0;
>   }
>   
> -void __ksm_exit(struct mm_struct *mm)
> +void __ksm_exit(struct mm_struct *mm, int flag)
>   {
>   	struct ksm_mm_slot *mm_slot;
>   	struct mm_slot *slot;
> @@ -2700,7 +2767,7 @@ void __ksm_exit(struct mm_struct *mm)
>   
>   	if (easy_to_free) {
>   		mm_slot_free(mm_slot_cache, mm_slot);
> -		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
> +		clear_bit(flag, &mm->flags);
>   		mmdrop(mm);
>   	} else if (mm_slot) {
>   		mmap_write_lock(mm);
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 740b54be3ed4..483e182e0b9d 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -46,6 +46,7 @@
>   #include <linux/pkeys.h>
>   #include <linux/oom.h>
>   #include <linux/sched/mm.h>
> +#include <linux/ksm.h>
>   
>   #include <linux/uaccess.h>
>   #include <asm/cacheflush.h>
> @@ -2213,6 +2214,8 @@ int __split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
>   	/* vma_complete stores the new vma */
>   	vma_complete(&vp, vmi, vma->vm_mm);
>   
> +	ksm_add_vma(new);
> +

Splitting a VMA shouldn't modify VM_MERGEABLE, so I assume this is not 
required?

>   	/* Success. */
>   	if (new_below)
>   		vma_next(vmi);
> @@ -2664,6 +2667,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>   	if (file && vm_flags & VM_SHARED)
>   		mapping_unmap_writable(file->f_mapping);
>   	file = vma->vm_file;
> +	ksm_add_vma(vma);
>   expanded:
>   	perf_event_mmap(vma);
>   
> @@ -2936,6 +2940,7 @@ static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
>   		goto mas_store_fail;
>   
>   	mm->map_count++;
> +	ksm_add_vma(vma);
>   out:
>   	perf_event_mmap(vma);
>   	mm->total_vm += len >> PAGE_SHIFT;
> @@ -3180,6 +3185,7 @@ struct vm_area_struct *copy_vma(struct vm_area_struct **vmap,
>   		if (vma_link(mm, new_vma))
>   			goto out_vma_link;
>   		*need_rmap_locks = false;
> +		ksm_add_vma(new_vma);

Copying shouldn't modify VM_MERGEABLE, so I think this is not required?

>   	}
>   	validate_mm_mt(mm);
>   	return new_vma;
> @@ -3356,6 +3362,7 @@ static struct vm_area_struct *__install_special_mapping(
>   	vm_stat_account(mm, vma->vm_flags, len >> PAGE_SHIFT);
>   
>   	perf_event_mmap(vma);
> +	ksm_add_vma(vma);

IIUC, special mappings will never be considered a reasonable target for 
KSM (especially, because at least VM_DONTEXPAND is always set).

I think you can just drop this call.
Stefan Roesch April 12, 2023, 4:08 p.m. UTC | #3
Matthew Wilcox <willy@infradead.org> writes:

> On Tue, Apr 11, 2023 at 08:16:46PM -0700, Stefan Roesch wrote:
>>  	case PR_SET_VMA:
>>  		error = prctl_set_vma(arg2, arg3, arg4, arg5);
>>  		break;
>> +#ifdef CONFIG_KSM
>> +	case PR_SET_MEMORY_MERGE:
>> +		if (mmap_write_lock_killable(me->mm))
>> +			return -EINTR;
>> +
>> +		if (arg2) {
>> +			int err = ksm_add_mm(me->mm);
>> +
>> +			if (!err)
>> +				ksm_add_vmas(me->mm);
>
> in the last version of this patch, you reported the error.  Now you
> swallow the error.  I have no idea which is correct, but you've
> changed the behaviour without explaining it, so I assume it's wrong.
>

I don't see how the error is swallowed in the arg2 case. If there is
an error ksm_add_vmas is not executedd and at the end of the function
the error is returned. Am I missing something?

>> +		} else {
>> +			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
>> +		}
>> +		mmap_write_unlock(me->mm);
>> +		break;
>> +	case PR_GET_MEMORY_MERGE:
>> +		if (arg2 || arg3 || arg4 || arg5)
>> +			return -EINVAL;
>> +
>> +		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
>> +		break;
>
> Why do we need a GET?  Just for symmetry, or is there an actual need for
> it?

There are three reasons:
- For symmetry
- The ksm sharing is inherited by child processes. This allows the test
  programs to verify that this is working.
- For child processes it might be useful to have the ability to check if
  ksm sharing has been enabled
Matthew Wilcox April 12, 2023, 4:29 p.m. UTC | #4
On Wed, Apr 12, 2023 at 09:08:11AM -0700, Stefan Roesch wrote:
> 
> Matthew Wilcox <willy@infradead.org> writes:
> 
> > On Tue, Apr 11, 2023 at 08:16:46PM -0700, Stefan Roesch wrote:
> >>  	case PR_SET_VMA:
> >>  		error = prctl_set_vma(arg2, arg3, arg4, arg5);
> >>  		break;
> >> +#ifdef CONFIG_KSM
> >> +	case PR_SET_MEMORY_MERGE:
> >> +		if (mmap_write_lock_killable(me->mm))
> >> +			return -EINTR;
> >> +
> >> +		if (arg2) {
> >> +			int err = ksm_add_mm(me->mm);
> >> +
> >> +			if (!err)
> >> +				ksm_add_vmas(me->mm);
> >
> > in the last version of this patch, you reported the error.  Now you
> > swallow the error.  I have no idea which is correct, but you've
> > changed the behaviour without explaining it, so I assume it's wrong.
> >
> 
> I don't see how the error is swallowed in the arg2 case. If there is
> an error ksm_add_vmas is not executedd and at the end of the function
> the error is returned. Am I missing something?

You said 'int err' which declares a new variable.  If you want it
reported, just use 'error = ksm_add_mm(me->mm);'.
Stefan Roesch April 12, 2023, 4:44 p.m. UTC | #5
David Hildenbrand <david@redhat.com> writes:

> [...]
>
> Thanks for giving mu sugegstions a churn. I think we can further
> improve/simplify some things. I added some comments, but might have more
> regarding MMF_VM_MERGE_ANY / MMF_VM_MERGEABLE.
>
> [I'll try reowkring your patch after I send this mail to play with some
> simplifications]
>
>>   arch/s390/mm/gmap.c            |   1 +
>>   include/linux/ksm.h            |  23 +++++--
>>   include/linux/sched/coredump.h |   1 +
>>   include/uapi/linux/prctl.h     |   2 +
>>   kernel/fork.c                  |   1 +
>>   kernel/sys.c                   |  23 +++++++
>>   mm/ksm.c                       | 111 ++++++++++++++++++++++++++-------
>>   mm/mmap.c                      |   7 +++
>>   8 files changed, 142 insertions(+), 27 deletions(-)
>> diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
>> index 5a716bdcba05..9d85e5589474 100644
>> --- a/arch/s390/mm/gmap.c
>> +++ b/arch/s390/mm/gmap.c
>> @@ -2591,6 +2591,7 @@ int gmap_mark_unmergeable(void)
>>   	int ret;
>>   	VMA_ITERATOR(vmi, mm, 0);
>>   +	clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>
> Okay, that should keep the existing mechanism working. (but users can still mess
> it up)
>
> Might be worth a comment
>
> /*
>  * Make sure to disable KSM (if enabled for the whole process or
>  * individual VMAs). Note that nothing currently hinders user space
>  * from re-enabling it.
>  */
>

I'll add the comment.

>>   	for_each_vma(vmi, vma) {
>>   		/* Copy vm_flags to avoid partial modifications in ksm_madvise */
>>   		vm_flags = vma->vm_flags;
>> diff --git a/include/linux/ksm.h b/include/linux/ksm.h
>> index 7e232ba59b86..f24f9faf1561 100644
>> --- a/include/linux/ksm.h
>> +++ b/include/linux/ksm.h
>> @@ -18,20 +18,29 @@
>>   #ifdef CONFIG_KSM
>>   int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>>   		unsigned long end, int advice, unsigned long *vm_flags);
>> -int __ksm_enter(struct mm_struct *mm);
>> -void __ksm_exit(struct mm_struct *mm);
>> +
>> +int ksm_add_mm(struct mm_struct *mm);
>> +void ksm_add_vma(struct vm_area_struct *vma);
>> +void ksm_add_vmas(struct mm_struct *mm);
>> +
>> +int __ksm_enter(struct mm_struct *mm, int flag);
>> +void __ksm_exit(struct mm_struct *mm, int flag);
>>     static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
>>   {
>> +	if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
>> +		return ksm_add_mm(mm);
>
> ksm_fork() runs before copying any VMAs. Copying the bit should be sufficient.
>
> Would it be possible to rework to something like:
>
> if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
> 	set_bit(MMF_VM_MERGE_ANY, &mm->flags)
> if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
> 	return __ksm_enter(mm);
>

That will work.
> work? IOW, not exporting ksm_add_mm() and not passing a flag to __ksm_enter() --
> it would simply set MMF_VM_MERGEABLE ?
>

ksm_add_mm() is also used in prctl (kernel/sys.c). Do you want to make a
similar change there?
>
> I rememebr proposing that enabling MMF_VM_MERGE_ANY would simply enable
> MMF_VM_MERGEABLE.
>
>>   	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
>> -		return __ksm_enter(mm);
>> +		return __ksm_enter(mm, MMF_VM_MERGEABLE);
>>   	return 0;
>>   }
>>     static inline void ksm_exit(struct mm_struct *mm)
>>   {
>> -	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
>> -		__ksm_exit(mm);
>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>> +		__ksm_exit(mm, MMF_VM_MERGE_ANY);
>> +	else if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
>> +		__ksm_exit(mm, MMF_VM_MERGEABLE);
>
> Can we do
>
> if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
> 	__ksm_exit(mm);
>
> And simply let __ksm_exit() clear both bits?
>
Yes, I'll make the change.
>>   }
>>     /*
>> @@ -53,6 +62,10 @@ void folio_migrate_ksm(struct folio *newfolio, struct folio *folio);
>>     #else  /* !CONFIG_KSM */
>>
>
> [...]
>
>>   #endif /* _LINUX_SCHED_COREDUMP_H */
>> diff --git a/include/uapi/linux/prctl.h b/include/uapi/linux/prctl.h
>> index 1312a137f7fb..759b3f53e53f 100644
>> --- a/include/uapi/linux/prctl.h
>> +++ b/include/uapi/linux/prctl.h
>> @@ -290,4 +290,6 @@ struct prctl_mm_map {
>>   #define PR_SET_VMA		0x53564d41
>>   # define PR_SET_VMA_ANON_NAME		0
>>   +#define PR_SET_MEMORY_MERGE		67
>> +#define PR_GET_MEMORY_MERGE		68
>>   #endif /* _LINUX_PRCTL_H */
>> diff --git a/kernel/fork.c b/kernel/fork.c
>> index f68954d05e89..1520697cf6c7 100644
>> --- a/kernel/fork.c
>> +++ b/kernel/fork.c
>> @@ -686,6 +686,7 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
>>   		if (vma_iter_bulk_store(&vmi, tmp))
>>   			goto fail_nomem_vmi_store;
>>   +		ksm_add_vma(tmp);
>
> Is this really required? The relevant VMAs should have VM_MERGEABLE set.
>
I'll fix it.

>>   		mm->map_count++;
>>   		if (!(tmp->vm_flags & VM_WIPEONFORK))
>>   			retval = copy_page_range(tmp, mpnt);
>> diff --git a/kernel/sys.c b/kernel/sys.c
>> index 495cd87d9bf4..9bba163d2d04 100644
>> --- a/kernel/sys.c
>> +++ b/kernel/sys.c
>> @@ -15,6 +15,7 @@
>>   #include <linux/highuid.h>
>>   #include <linux/fs.h>
>>   #include <linux/kmod.h>
>> +#include <linux/ksm.h>
>>   #include <linux/perf_event.h>
>>   #include <linux/resource.h>
>>   #include <linux/kernel.h>
>> @@ -2661,6 +2662,28 @@ SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3,
>>   	case PR_SET_VMA:
>>   		error = prctl_set_vma(arg2, arg3, arg4, arg5);
>>   		break;
>> +#ifdef CONFIG_KSM
>> +	case PR_SET_MEMORY_MERGE:
>> +		if (mmap_write_lock_killable(me->mm))
>> +			return -EINTR;
>> +
>> +		if (arg2) {
>> +			int err = ksm_add_mm(me->mm);
>> +
>> +			if (!err)
>> +				ksm_add_vmas(me->mm);
>> +		} else {
>> +			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
>
> Okay, so disabling doesn't actually unshare anything.
>
>> +		}
>> +		mmap_write_unlock(me->mm);
>> +		break;
>> +	case PR_GET_MEMORY_MERGE:
>> +		if (arg2 || arg3 || arg4 || arg5)
>> +			return -EINVAL;
>> +
>> +		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
>> +		break;
>> +#endif
>>   	default:
>>   		error = -EINVAL;
>>   		break;
>> diff --git a/mm/ksm.c b/mm/ksm.c
>> index d7bd28199f6c..ab95ae0f9def 100644
>> --- a/mm/ksm.c
>> +++ b/mm/ksm.c
>> @@ -534,10 +534,33 @@ static int break_ksm(struct vm_area_struct *vma, unsigned long addr,
>>   	return (ret & VM_FAULT_OOM) ? -ENOMEM : 0;
>>   }
>>   +static bool vma_ksm_compatible(struct vm_area_struct *vma)
>> +{
>> +	if (vma->vm_flags & (VM_SHARED  | VM_MAYSHARE   | VM_PFNMAP  |
>> +			     VM_IO      | VM_DONTEXPAND | VM_HUGETLB |
>> +			     VM_MIXEDMAP))
>> +		return false;		/* just ignore the advice */
>> +
>> +	if (vma_is_dax(vma))
>> +		return false;
>> +
>> +#ifdef VM_SAO
>> +	if (vma->vm_flags & VM_SAO)
>> +		return false;
>> +#endif
>> +#ifdef VM_SPARC_ADI
>> +	if (vma->vm_flags & VM_SPARC_ADI)
>> +		return false;
>> +#endif
>> +
>> +	return true;
>> +}
>> +
>>   static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
>>   		unsigned long addr)
>>   {
>>   	struct vm_area_struct *vma;
>> +
>
> unrelated change
>
Removed.

>>   	if (ksm_test_exit(mm))
>>   		return NULL;
>>   	vma = vma_lookup(mm, addr);
>> @@ -1065,6 +1088,7 @@ static int unmerge_and_remove_all_rmap_items(void)
>>     			mm_slot_free(mm_slot_cache, mm_slot);
>>   			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
>> +			clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>>   			mmdrop(mm);
>>   		} else
>>   			spin_unlock(&ksm_mmlist_lock);
>> @@ -2495,6 +2519,7 @@ static struct ksm_rmap_item *scan_get_next_rmap_item(struct page **page)
>>     		mm_slot_free(mm_slot_cache, mm_slot);
>>   		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
>> +		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>>   		mmap_read_unlock(mm);
>>   		mmdrop(mm);
>>   	} else {
>> @@ -2571,6 +2596,63 @@ static int ksm_scan_thread(void *nothing)
>>   	return 0;
>>   }
>>   +static void __ksm_add_vma(struct vm_area_struct *vma)
>> +{
>> +	unsigned long vm_flags = vma->vm_flags;
>> +
>> +	if (vm_flags & VM_MERGEABLE)
>> +		return;
>> +
>> +	if (vma_ksm_compatible(vma)) {
>> +		vm_flags |= VM_MERGEABLE;
>> +		vm_flags_reset(vma, vm_flags);
>> +	}
>> +}
>> +
>> +/**
>> + * ksm_add_vma - Mark vma as mergeable
>
> "if compatible"
>
I'll added the above.

>> + *
>> + * @vma:  Pointer to vma
>> + */
>> +void ksm_add_vma(struct vm_area_struct *vma)
>> +{
>> +	struct mm_struct *mm = vma->vm_mm;
>> +
>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>> +		__ksm_add_vma(vma);
>> +}
>> +
>> +/**
>> + * ksm_add_vmas - Mark all vma's of a process as mergeable
>> + *
>> + * @mm:  Pointer to mm
>> + */
>> +void ksm_add_vmas(struct mm_struct *mm)
>
> I'd suggest calling this
>
I guess you forgot your name suggestion?

>> +{
>> +	struct vm_area_struct *vma;
>> +
>> +	VMA_ITERATOR(vmi, mm, 0);
>> +	for_each_vma(vmi, vma)
>> +		__ksm_add_vma(vma);
>> +}
>> +
>> +/**
>> + * ksm_add_mm - Add mm to mm ksm list
>> + *
>> + * @mm:  Pointer to mm
>> + *
>> + * Returns 0 on success, otherwise error code
>> + */
>> +int ksm_add_mm(struct mm_struct *mm)
>> +{
>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>> +		return -EINVAL;
>> +	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
>> +		return -EINVAL;
>> +
>> +	return __ksm_enter(mm, MMF_VM_MERGE_ANY);
>> +}
>> +
>>   int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>>   		unsigned long end, int advice, unsigned long *vm_flags)
>>   {
>> @@ -2579,28 +2661,13 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>>     	switch (advice) {
>>   	case MADV_MERGEABLE:
>> -		/*
>> -		 * Be somewhat over-protective for now!
>> -		 */
>> -		if (*vm_flags & (VM_MERGEABLE | VM_SHARED  | VM_MAYSHARE   |
>> -				 VM_PFNMAP    | VM_IO      | VM_DONTEXPAND |
>> -				 VM_HUGETLB | VM_MIXEDMAP))
>> -			return 0;		/* just ignore the advice */
>> -
>> -		if (vma_is_dax(vma))
>> +		if (vma->vm_flags & VM_MERGEABLE)
>>   			return 0;
>> -
>> -#ifdef VM_SAO
>> -		if (*vm_flags & VM_SAO)
>> +		if (!vma_ksm_compatible(vma))
>>   			return 0;
>> -#endif
>> -#ifdef VM_SPARC_ADI
>> -		if (*vm_flags & VM_SPARC_ADI)
>> -			return 0;
>> -#endif
>>     		if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
>> -			err = __ksm_enter(mm);
>> +			err = __ksm_enter(mm, MMF_VM_MERGEABLE);
>>   			if (err)
>>   				return err;
>>   		}
>> @@ -2626,7 +2693,7 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>>   }
>>   EXPORT_SYMBOL_GPL(ksm_madvise);
>>   -int __ksm_enter(struct mm_struct *mm)
>> +int __ksm_enter(struct mm_struct *mm, int flag)
>>   {
>>   	struct ksm_mm_slot *mm_slot;
>>   	struct mm_slot *slot;
>> @@ -2659,7 +2726,7 @@ int __ksm_enter(struct mm_struct *mm)
>>   		list_add_tail(&slot->mm_node, &ksm_scan.mm_slot->slot.mm_node);
>>   	spin_unlock(&ksm_mmlist_lock);
>>   -	set_bit(MMF_VM_MERGEABLE, &mm->flags);
>> +	set_bit(flag, &mm->flags);
>>   	mmgrab(mm);
>>     	if (needs_wakeup)
>> @@ -2668,7 +2735,7 @@ int __ksm_enter(struct mm_struct *mm)
>>   	return 0;
>>   }
>>   -void __ksm_exit(struct mm_struct *mm)
>> +void __ksm_exit(struct mm_struct *mm, int flag)
>>   {
>>   	struct ksm_mm_slot *mm_slot;
>>   	struct mm_slot *slot;
>> @@ -2700,7 +2767,7 @@ void __ksm_exit(struct mm_struct *mm)
>>     	if (easy_to_free) {
>>   		mm_slot_free(mm_slot_cache, mm_slot);
>> -		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
>> +		clear_bit(flag, &mm->flags);
>>   		mmdrop(mm);
>>   	} else if (mm_slot) {
>>   		mmap_write_lock(mm);
>> diff --git a/mm/mmap.c b/mm/mmap.c
>> index 740b54be3ed4..483e182e0b9d 100644
>> --- a/mm/mmap.c
>> +++ b/mm/mmap.c
>> @@ -46,6 +46,7 @@
>>   #include <linux/pkeys.h>
>>   #include <linux/oom.h>
>>   #include <linux/sched/mm.h>
>> +#include <linux/ksm.h>
>>     #include <linux/uaccess.h>
>>   #include <asm/cacheflush.h>
>> @@ -2213,6 +2214,8 @@ int __split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
>>   	/* vma_complete stores the new vma */
>>   	vma_complete(&vp, vmi, vma->vm_mm);
>>   +	ksm_add_vma(new);
>> +
>
> Splitting a VMA shouldn't modify VM_MERGEABLE, so I assume this is not required?
>
I'll fix it.

>>   	/* Success. */
>>   	if (new_below)
>>   		vma_next(vmi);
>> @@ -2664,6 +2667,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>>   	if (file && vm_flags & VM_SHARED)
>>   		mapping_unmap_writable(file->f_mapping);
>>   	file = vma->vm_file;
>> +	ksm_add_vma(vma);
>>   expanded:
>>   	perf_event_mmap(vma);
>>   @@ -2936,6 +2940,7 @@ static int do_brk_flags(struct vma_iterator *vmi,
>> struct vm_area_struct *vma,
>>   		goto mas_store_fail;
>>     	mm->map_count++;
>> +	ksm_add_vma(vma);
>>   out:
>>   	perf_event_mmap(vma);
>>   	mm->total_vm += len >> PAGE_SHIFT;
>> @@ -3180,6 +3185,7 @@ struct vm_area_struct *copy_vma(struct vm_area_struct **vmap,
>>   		if (vma_link(mm, new_vma))
>>   			goto out_vma_link;
>>   		*need_rmap_locks = false;
>> +		ksm_add_vma(new_vma);
>
> Copying shouldn't modify VM_MERGEABLE, so I think this is not required?
>
I'll fix it.

>>   	}
>>   	validate_mm_mt(mm);
>>   	return new_vma;
>> @@ -3356,6 +3362,7 @@ static struct vm_area_struct *__install_special_mapping(
>>   	vm_stat_account(mm, vma->vm_flags, len >> PAGE_SHIFT);
>>     	perf_event_mmap(vma);
>> +	ksm_add_vma(vma);
>
> IIUC, special mappings will never be considered a reasonable target for KSM
> (especially, because at least VM_DONTEXPAND is always set).
>
> I think you can just drop this call.
I dropped it.
David Hildenbrand April 12, 2023, 6:41 p.m. UTC | #6
[...]
> That will work.
>> work? IOW, not exporting ksm_add_mm() and not passing a flag to __ksm_enter() --
>> it would simply set MMF_VM_MERGEABLE ?
>>
> 
> ksm_add_mm() is also used in prctl (kernel/sys.c). Do you want to make a
> similar change there?

Yes.

>>> + *
>>> + * @vma:  Pointer to vma
>>> + */
>>> +void ksm_add_vma(struct vm_area_struct *vma)
>>> +{
>>> +	struct mm_struct *mm = vma->vm_mm;
>>> +
>>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>>> +		__ksm_add_vma(vma);
>>> +}
>>> +
>>> +/**
>>> + * ksm_add_vmas - Mark all vma's of a process as mergeable
>>> + *
>>> + * @mm:  Pointer to mm
>>> + */
>>> +void ksm_add_vmas(struct mm_struct *mm)
>>
>> I'd suggest calling this
>>
> I guess you forgot your name suggestion?

Yeah, I reconsidered because the first idea I had was not particularly 
good. Maybe

ksm_enable_for_all_vmas()

But not so sure. If you think the "add" terminology is a good fit, keep 
it like that.


Thanks for bearing with me :)
David Hildenbrand April 12, 2023, 7:08 p.m. UTC | #7
On 12.04.23 20:41, David Hildenbrand wrote:
> [...]
>> That will work.
>>> work? IOW, not exporting ksm_add_mm() and not passing a flag to __ksm_enter() --
>>> it would simply set MMF_VM_MERGEABLE ?
>>>
>>
>> ksm_add_mm() is also used in prctl (kernel/sys.c). Do you want to make a
>> similar change there?
> 
> Yes.
> 
>>>> + *
>>>> + * @vma:  Pointer to vma
>>>> + */
>>>> +void ksm_add_vma(struct vm_area_struct *vma)
>>>> +{
>>>> +	struct mm_struct *mm = vma->vm_mm;
>>>> +
>>>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>>>> +		__ksm_add_vma(vma);
>>>> +}
>>>> +
>>>> +/**
>>>> + * ksm_add_vmas - Mark all vma's of a process as mergeable
>>>> + *
>>>> + * @mm:  Pointer to mm
>>>> + */
>>>> +void ksm_add_vmas(struct mm_struct *mm)
>>>
>>> I'd suggest calling this
>>>
>> I guess you forgot your name suggestion?
> 
> Yeah, I reconsidered because the first idea I had was not particularly
> good. Maybe
> 
> ksm_enable_for_all_vmas()
> 
> But not so sure. If you think the "add" terminology is a good fit, keep
> it like that.
> 
> 
> Thanks for bearing with me :)
> 

I briefly played with your patch to see how much it can be simplified.
Always enabling ksm (setting MMF_VM_MERGEABLE) before setting
MMF_VM_MERGE_ANY might simplify things. ksm_enable_merge_any() [or however it should
be called] and ksm_fork() contain the interesting bits.


Feel free to incorporate what you consider valuable (uncompiled,
untested).


diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
index 5a716bdcba05..5b2eef31398e 100644
--- a/arch/s390/mm/gmap.c
+++ b/arch/s390/mm/gmap.c
@@ -2591,6 +2591,12 @@ int gmap_mark_unmergeable(void)
  	int ret;
  	VMA_ITERATOR(vmi, mm, 0);
  
+	/*
+	 * Make sure to disable KSM (if enabled for the whole process or
+	 * individual VMAs). Note that nothing currently hinders user space
+	 * from re-enabling it.
+	 */
+	clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
  	for_each_vma(vmi, vma) {
  		/* Copy vm_flags to avoid partial modifications in ksm_madvise */
  		vm_flags = vma->vm_flags;
diff --git a/include/linux/ksm.h b/include/linux/ksm.h
index 7e232ba59b86..c638b034d586 100644
--- a/include/linux/ksm.h
+++ b/include/linux/ksm.h
@@ -18,13 +18,24 @@
  #ifdef CONFIG_KSM
  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
  		unsigned long end, int advice, unsigned long *vm_flags);
+
+void ksm_add_vma(struct vm_area_struct *vma);
+int ksm_enable_merge_any(struct mm_struct *mm);
+
  int __ksm_enter(struct mm_struct *mm);
  void __ksm_exit(struct mm_struct *mm);
  
  static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
  {
-	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
-		return __ksm_enter(mm);
+	int ret;
+
+	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags)) {
+		ret = __ksm_enter(mm);
+		if (ret)
+			return ret;
+	}
+	if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
+		set_bit(MMF_VM_MERGE_ANY, &mm->flags);
  	return 0;
  }
  
@@ -53,6 +64,10 @@ void folio_migrate_ksm(struct folio *newfolio, struct folio *folio);
  
  #else  /* !CONFIG_KSM */
  
+static inline void ksm_add_vma(struct vm_area_struct *vma)
+{
+}
+
  static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
  {
  	return 0;
diff --git a/include/linux/sched/coredump.h b/include/linux/sched/coredump.h
index 0e17ae7fbfd3..0ee96ea7a0e9 100644
--- a/include/linux/sched/coredump.h
+++ b/include/linux/sched/coredump.h
@@ -90,4 +90,5 @@ static inline int get_dumpable(struct mm_struct *mm)
  #define MMF_INIT_MASK		(MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK |\
  				 MMF_DISABLE_THP_MASK | MMF_HAS_MDWE_MASK)
  
+#define MMF_VM_MERGE_ANY	29
  #endif /* _LINUX_SCHED_COREDUMP_H */
diff --git a/include/uapi/linux/prctl.h b/include/uapi/linux/prctl.h
index 1312a137f7fb..759b3f53e53f 100644
--- a/include/uapi/linux/prctl.h
+++ b/include/uapi/linux/prctl.h
@@ -290,4 +290,6 @@ struct prctl_mm_map {
  #define PR_SET_VMA		0x53564d41
  # define PR_SET_VMA_ANON_NAME		0
  
+#define PR_SET_MEMORY_MERGE		67
+#define PR_GET_MEMORY_MERGE		68
  #endif /* _LINUX_PRCTL_H */
diff --git a/kernel/sys.c b/kernel/sys.c
index 495cd87d9bf4..8c2e50edeb18 100644
--- a/kernel/sys.c
+++ b/kernel/sys.c
@@ -15,6 +15,7 @@
  #include <linux/highuid.h>
  #include <linux/fs.h>
  #include <linux/kmod.h>
+#include <linux/ksm.h>
  #include <linux/perf_event.h>
  #include <linux/resource.h>
  #include <linux/kernel.h>
@@ -2661,6 +2662,30 @@ SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3,
  	case PR_SET_VMA:
  		error = prctl_set_vma(arg2, arg3, arg4, arg5);
  		break;
+#ifdef CONFIG_KSM
+	case PR_SET_MEMORY_MERGE:
+		if (mmap_write_lock_killable(me->mm))
+			return -EINTR;
+
+		if (arg2) {
+			error = ksm_enable_merge_any(me->mm);
+		} else {
+			/*
+			 * TODO: we might want disable KSM on all VMAs and
+			 * trigger unsharing to completely disable KSM.
+			 */
+			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
+			error = 0;
+		}
+		mmap_write_unlock(me->mm);
+		break;
+	case PR_GET_MEMORY_MERGE:
+		if (arg2 || arg3 || arg4 || arg5)
+			return -EINVAL;
+
+		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
+		break;
+#endif
  	default:
  		error = -EINVAL;
  		break;
diff --git a/mm/ksm.c b/mm/ksm.c
index 2b8d30068cbb..76ceec35395c 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -512,6 +512,28 @@ static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
  	return (ret & VM_FAULT_OOM) ? -ENOMEM : 0;
  }
  
+static bool vma_ksm_compatible(struct vm_area_struct *vma)
+{
+	if (vma->vm_flags & (VM_SHARED  | VM_MAYSHARE   | VM_PFNMAP  |
+			     VM_IO      | VM_DONTEXPAND | VM_HUGETLB |
+			     VM_MIXEDMAP))
+		return false;		/* just ignore the advice */
+
+	if (vma_is_dax(vma))
+		return false;
+
+#ifdef VM_SAO
+	if (vma->vm_flags & VM_SAO)
+		return false;
+#endif
+#ifdef VM_SPARC_ADI
+	if (vma->vm_flags & VM_SPARC_ADI)
+		return false;
+#endif
+
+	return true;
+}
+
  static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
  		unsigned long addr)
  {
@@ -1020,6 +1042,7 @@ static int unmerge_and_remove_all_rmap_items(void)
  
  			mm_slot_free(mm_slot_cache, mm_slot);
  			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+			clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
  			mmdrop(mm);
  		} else
  			spin_unlock(&ksm_mmlist_lock);
@@ -2395,6 +2418,7 @@ static struct ksm_rmap_item *scan_get_next_rmap_item(struct page **page)
  
  		mm_slot_free(mm_slot_cache, mm_slot);
  		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
  		mmap_read_unlock(mm);
  		mmdrop(mm);
  	} else {
@@ -2471,6 +2495,52 @@ static int ksm_scan_thread(void *nothing)
  	return 0;
  }
  
+static void __ksm_add_vma(struct vm_area_struct *vma)
+{
+	unsigned long vm_flags = vma->vm_flags;
+
+	if (vm_flags & VM_MERGEABLE)
+		return;
+
+	if (vma_ksm_compatible(vma)) {
+		vm_flags |= VM_MERGEABLE;
+		vm_flags_reset(vma, vm_flags);
+	}
+}
+
+/**
+ * ksm_add_vma - Mark vma as mergeable
+ *
+ * @vma:  Pointer to vma
+ */
+void ksm_add_vma(struct vm_area_struct *vma)
+{
+	struct mm_struct *mm = vma->vm_mm;
+
+	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+		__ksm_add_vma(vma);
+}
+
+int ksm_enable_merge_any(struct mm_struct *mm)
+{
+	struct vm_area_struct *vma;
+	int ret;
+
+	if (test_bit(MMF_VM_MERGE_ANY, mm->flags))
+		return 0;
+
+	if (!test_bit(MMF_VM_MERGEABLE, mm->flags)) {
+		ret = __ksm_enter(mm);
+		if (ret)
+			return ret;
+	}
+	set_bit(MMF_VM_MERGE_ANY, &mm->flags);
+
+	VMA_ITERATOR(vmi, mm, 0);
+	for_each_vma(vmi, vma)
+		__ksm_add_vma(vma);
+}
+
  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
  		unsigned long end, int advice, unsigned long *vm_flags)
  {
@@ -2479,25 +2549,10 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
  
  	switch (advice) {
  	case MADV_MERGEABLE:
-		/*
-		 * Be somewhat over-protective for now!
-		 */
-		if (*vm_flags & (VM_MERGEABLE | VM_SHARED  | VM_MAYSHARE   |
-				 VM_PFNMAP    | VM_IO      | VM_DONTEXPAND |
-				 VM_HUGETLB | VM_MIXEDMAP))
-			return 0;		/* just ignore the advice */
-
-		if (vma_is_dax(vma))
+		if (vma->vm_flags & VM_MERGEABLE)
  			return 0;
-
-#ifdef VM_SAO
-		if (*vm_flags & VM_SAO)
-			return 0;
-#endif
-#ifdef VM_SPARC_ADI
-		if (*vm_flags & VM_SPARC_ADI)
+		if (!vma_ksm_compatible(vma))
  			return 0;
-#endif
  
  		if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
  			err = __ksm_enter(mm);
@@ -2601,6 +2656,7 @@ void __ksm_exit(struct mm_struct *mm)
  	if (easy_to_free) {
  		mm_slot_free(mm_slot_cache, mm_slot);
  		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
  		mmdrop(mm);
  	} else if (mm_slot) {
  		mmap_write_lock(mm);
diff --git a/mm/mmap.c b/mm/mmap.c
index ff68a67a2a7c..1f8619ff58ca 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -46,6 +46,7 @@
  #include <linux/pkeys.h>
  #include <linux/oom.h>
  #include <linux/sched/mm.h>
+#include <linux/ksm.h>
  
  #include <linux/uaccess.h>
  #include <asm/cacheflush.h>
@@ -2659,6 +2660,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
  	if (file && vm_flags & VM_SHARED)
  		mapping_unmap_writable(file->f_mapping);
  	file = vma->vm_file;
+	ksm_add_vma(vma);
  expanded:
  	perf_event_mmap(vma);
  
@@ -2931,6 +2933,7 @@ static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
  		goto mas_store_fail;
  
  	mm->map_count++;
+	ksm_add_vma(vma);
  out:
  	perf_event_mmap(vma);
  	mm->total_vm += len >> PAGE_SHIFT;
Stefan Roesch April 12, 2023, 7:55 p.m. UTC | #8
David Hildenbrand <david@redhat.com> writes:

> On 12.04.23 20:41, David Hildenbrand wrote:
>> [...]
>>> That will work.
>>>> work? IOW, not exporting ksm_add_mm() and not passing a flag to __ksm_enter() --
>>>> it would simply set MMF_VM_MERGEABLE ?
>>>>
>>>
>>> ksm_add_mm() is also used in prctl (kernel/sys.c). Do you want to make a
>>> similar change there?
>> Yes.
>>
>>>>> + *
>>>>> + * @vma:  Pointer to vma
>>>>> + */
>>>>> +void ksm_add_vma(struct vm_area_struct *vma)
>>>>> +{
>>>>> +	struct mm_struct *mm = vma->vm_mm;
>>>>> +
>>>>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>>>>> +		__ksm_add_vma(vma);
>>>>> +}
>>>>> +
>>>>> +/**
>>>>> + * ksm_add_vmas - Mark all vma's of a process as mergeable
>>>>> + *
>>>>> + * @mm:  Pointer to mm
>>>>> + */
>>>>> +void ksm_add_vmas(struct mm_struct *mm)
>>>>
>>>> I'd suggest calling this
>>>>
>>> I guess you forgot your name suggestion?
>> Yeah, I reconsidered because the first idea I had was not particularly
>> good. Maybe
>> ksm_enable_for_all_vmas()
>> But not so sure. If you think the "add" terminology is a good fit, keep
>> it like that.
>> Thanks for bearing with me :)
>>
>
> I briefly played with your patch to see how much it can be simplified.
> Always enabling ksm (setting MMF_VM_MERGEABLE) before setting
> MMF_VM_MERGE_ANY might simplify things. ksm_enable_merge_any() [or however it should
> be called] and ksm_fork() contain the interesting bits.
>
>
> Feel free to incorporate what you consider valuable (uncompiled,
> untested).
>
I added most of it. The only change is that I kept ksm_add_vmas as a
static function, otherwise I need to define the VMA_ITERATOR at the top
of the function.

>
> diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
> index 5a716bdcba05..5b2eef31398e 100644
> --- a/arch/s390/mm/gmap.c
> +++ b/arch/s390/mm/gmap.c
> @@ -2591,6 +2591,12 @@ int gmap_mark_unmergeable(void)
>  	int ret;
>  	VMA_ITERATOR(vmi, mm, 0);
>  +	/*
> +	 * Make sure to disable KSM (if enabled for the whole process or
> +	 * individual VMAs). Note that nothing currently hinders user space
> +	 * from re-enabling it.
> +	 */
> +	clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>  	for_each_vma(vmi, vma) {
>  		/* Copy vm_flags to avoid partial modifications in ksm_madvise */
>  		vm_flags = vma->vm_flags;
> diff --git a/include/linux/ksm.h b/include/linux/ksm.h
> index 7e232ba59b86..c638b034d586 100644
> --- a/include/linux/ksm.h
> +++ b/include/linux/ksm.h
> @@ -18,13 +18,24 @@
>  #ifdef CONFIG_KSM
>  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>  		unsigned long end, int advice, unsigned long *vm_flags);
> +
> +void ksm_add_vma(struct vm_area_struct *vma);
> +int ksm_enable_merge_any(struct mm_struct *mm);
> +
>  int __ksm_enter(struct mm_struct *mm);
>  void __ksm_exit(struct mm_struct *mm);
>    static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
>  {
> -	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
> -		return __ksm_enter(mm);
> +	int ret;
> +
> +	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags)) {
> +		ret = __ksm_enter(mm);
> +		if (ret)
> +			return ret;
> +	}
> +	if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
> +		set_bit(MMF_VM_MERGE_ANY, &mm->flags);
>  	return 0;
>  }
>  @@ -53,6 +64,10 @@ void folio_migrate_ksm(struct folio *newfolio, struct folio
> *folio);
>    #else  /* !CONFIG_KSM */
>  +static inline void ksm_add_vma(struct vm_area_struct *vma)
> +{
> +}
> +
>  static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
>  {
>  	return 0;
> diff --git a/include/linux/sched/coredump.h b/include/linux/sched/coredump.h
> index 0e17ae7fbfd3..0ee96ea7a0e9 100644
> --- a/include/linux/sched/coredump.h
> +++ b/include/linux/sched/coredump.h
> @@ -90,4 +90,5 @@ static inline int get_dumpable(struct mm_struct *mm)
>  #define MMF_INIT_MASK		(MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK |\
>  				 MMF_DISABLE_THP_MASK | MMF_HAS_MDWE_MASK)
>  +#define MMF_VM_MERGE_ANY	29
>  #endif /* _LINUX_SCHED_COREDUMP_H */
> diff --git a/include/uapi/linux/prctl.h b/include/uapi/linux/prctl.h
> index 1312a137f7fb..759b3f53e53f 100644
> --- a/include/uapi/linux/prctl.h
> +++ b/include/uapi/linux/prctl.h
> @@ -290,4 +290,6 @@ struct prctl_mm_map {
>  #define PR_SET_VMA		0x53564d41
>  # define PR_SET_VMA_ANON_NAME		0
>  +#define PR_SET_MEMORY_MERGE		67
> +#define PR_GET_MEMORY_MERGE		68
>  #endif /* _LINUX_PRCTL_H */
> diff --git a/kernel/sys.c b/kernel/sys.c
> index 495cd87d9bf4..8c2e50edeb18 100644
> --- a/kernel/sys.c
> +++ b/kernel/sys.c
> @@ -15,6 +15,7 @@
>  #include <linux/highuid.h>
>  #include <linux/fs.h>
>  #include <linux/kmod.h>
> +#include <linux/ksm.h>
>  #include <linux/perf_event.h>
>  #include <linux/resource.h>
>  #include <linux/kernel.h>
> @@ -2661,6 +2662,30 @@ SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3,
>  	case PR_SET_VMA:
>  		error = prctl_set_vma(arg2, arg3, arg4, arg5);
>  		break;
> +#ifdef CONFIG_KSM
> +	case PR_SET_MEMORY_MERGE:
> +		if (mmap_write_lock_killable(me->mm))
> +			return -EINTR;
> +
> +		if (arg2) {
> +			error = ksm_enable_merge_any(me->mm);
> +		} else {
> +			/*
> +			 * TODO: we might want disable KSM on all VMAs and
> +			 * trigger unsharing to completely disable KSM.
> +			 */
> +			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
> +			error = 0;
> +		}
> +		mmap_write_unlock(me->mm);
> +		break;
> +	case PR_GET_MEMORY_MERGE:
> +		if (arg2 || arg3 || arg4 || arg5)
> +			return -EINVAL;
> +
> +		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
> +		break;
> +#endif
>  	default:
>  		error = -EINVAL;
>  		break;
> diff --git a/mm/ksm.c b/mm/ksm.c
> index 2b8d30068cbb..76ceec35395c 100644
> --- a/mm/ksm.c
> +++ b/mm/ksm.c
> @@ -512,6 +512,28 @@ static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
>  	return (ret & VM_FAULT_OOM) ? -ENOMEM : 0;
>  }
>  +static bool vma_ksm_compatible(struct vm_area_struct *vma)
> +{
> +	if (vma->vm_flags & (VM_SHARED  | VM_MAYSHARE   | VM_PFNMAP  |
> +			     VM_IO      | VM_DONTEXPAND | VM_HUGETLB |
> +			     VM_MIXEDMAP))
> +		return false;		/* just ignore the advice */
> +
> +	if (vma_is_dax(vma))
> +		return false;
> +
> +#ifdef VM_SAO
> +	if (vma->vm_flags & VM_SAO)
> +		return false;
> +#endif
> +#ifdef VM_SPARC_ADI
> +	if (vma->vm_flags & VM_SPARC_ADI)
> +		return false;
> +#endif
> +
> +	return true;
> +}
> +
>  static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
>  		unsigned long addr)
>  {
> @@ -1020,6 +1042,7 @@ static int unmerge_and_remove_all_rmap_items(void)
>    			mm_slot_free(mm_slot_cache, mm_slot);
>  			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
> +			clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>  			mmdrop(mm);
>  		} else
>  			spin_unlock(&ksm_mmlist_lock);
> @@ -2395,6 +2418,7 @@ static struct ksm_rmap_item *scan_get_next_rmap_item(struct page **page)
>    		mm_slot_free(mm_slot_cache, mm_slot);
>  		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
> +		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>  		mmap_read_unlock(mm);
>  		mmdrop(mm);
>  	} else {
> @@ -2471,6 +2495,52 @@ static int ksm_scan_thread(void *nothing)
>  	return 0;
>  }
>  +static void __ksm_add_vma(struct vm_area_struct *vma)
> +{
> +	unsigned long vm_flags = vma->vm_flags;
> +
> +	if (vm_flags & VM_MERGEABLE)
> +		return;
> +
> +	if (vma_ksm_compatible(vma)) {
> +		vm_flags |= VM_MERGEABLE;
> +		vm_flags_reset(vma, vm_flags);
> +	}
> +}
> +
> +/**
> + * ksm_add_vma - Mark vma as mergeable
> + *
> + * @vma:  Pointer to vma
> + */
> +void ksm_add_vma(struct vm_area_struct *vma)
> +{
> +	struct mm_struct *mm = vma->vm_mm;
> +
> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
> +		__ksm_add_vma(vma);
> +}
> +
> +int ksm_enable_merge_any(struct mm_struct *mm)
> +{
> +	struct vm_area_struct *vma;
> +	int ret;
> +
> +	if (test_bit(MMF_VM_MERGE_ANY, mm->flags))
> +		return 0;
> +
> +	if (!test_bit(MMF_VM_MERGEABLE, mm->flags)) {
> +		ret = __ksm_enter(mm);
> +		if (ret)
> +			return ret;
> +	}
> +	set_bit(MMF_VM_MERGE_ANY, &mm->flags);
> +
> +	VMA_ITERATOR(vmi, mm, 0);
> +	for_each_vma(vmi, vma)
> +		__ksm_add_vma(vma);
> +}
> +
>  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>  		unsigned long end, int advice, unsigned long *vm_flags)
>  {
> @@ -2479,25 +2549,10 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
>    	switch (advice) {
>  	case MADV_MERGEABLE:
> -		/*
> -		 * Be somewhat over-protective for now!
> -		 */
> -		if (*vm_flags & (VM_MERGEABLE | VM_SHARED  | VM_MAYSHARE   |
> -				 VM_PFNMAP    | VM_IO      | VM_DONTEXPAND |
> -				 VM_HUGETLB | VM_MIXEDMAP))
> -			return 0;		/* just ignore the advice */
> -
> -		if (vma_is_dax(vma))
> +		if (vma->vm_flags & VM_MERGEABLE)
>  			return 0;
> -
> -#ifdef VM_SAO
> -		if (*vm_flags & VM_SAO)
> -			return 0;
> -#endif
> -#ifdef VM_SPARC_ADI
> -		if (*vm_flags & VM_SPARC_ADI)
> +		if (!vma_ksm_compatible(vma))
>  			return 0;
> -#endif
>    		if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
>  			err = __ksm_enter(mm);
> @@ -2601,6 +2656,7 @@ void __ksm_exit(struct mm_struct *mm)
>  	if (easy_to_free) {
>  		mm_slot_free(mm_slot_cache, mm_slot);
>  		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
> +		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
>  		mmdrop(mm);
>  	} else if (mm_slot) {
>  		mmap_write_lock(mm);
> diff --git a/mm/mmap.c b/mm/mmap.c
> index ff68a67a2a7c..1f8619ff58ca 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -46,6 +46,7 @@
>  #include <linux/pkeys.h>
>  #include <linux/oom.h>
>  #include <linux/sched/mm.h>
> +#include <linux/ksm.h>
>    #include <linux/uaccess.h>
>  #include <asm/cacheflush.h>
> @@ -2659,6 +2660,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  	if (file && vm_flags & VM_SHARED)
>  		mapping_unmap_writable(file->f_mapping);
>  	file = vma->vm_file;
> +	ksm_add_vma(vma);
>  expanded:
>  	perf_event_mmap(vma);
>  @@ -2931,6 +2933,7 @@ static int do_brk_flags(struct vma_iterator *vmi, struct
> vm_area_struct *vma,
>  		goto mas_store_fail;
>    	mm->map_count++;
> +	ksm_add_vma(vma);
>  out:
>  	perf_event_mmap(vma);
>  	mm->total_vm += len >> PAGE_SHIFT;
> --
> 2.39.2
David Hildenbrand April 13, 2023, 9:46 a.m. UTC | #9
On 12.04.23 21:55, Stefan Roesch wrote:
> 
> David Hildenbrand <david@redhat.com> writes:
> 
>> On 12.04.23 20:41, David Hildenbrand wrote:
>>> [...]
>>>> That will work.
>>>>> work? IOW, not exporting ksm_add_mm() and not passing a flag to __ksm_enter() --
>>>>> it would simply set MMF_VM_MERGEABLE ?
>>>>>
>>>>
>>>> ksm_add_mm() is also used in prctl (kernel/sys.c). Do you want to make a
>>>> similar change there?
>>> Yes.
>>>
>>>>>> + *
>>>>>> + * @vma:  Pointer to vma
>>>>>> + */
>>>>>> +void ksm_add_vma(struct vm_area_struct *vma)
>>>>>> +{
>>>>>> +	struct mm_struct *mm = vma->vm_mm;
>>>>>> +
>>>>>> +	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
>>>>>> +		__ksm_add_vma(vma);
>>>>>> +}
>>>>>> +
>>>>>> +/**
>>>>>> + * ksm_add_vmas - Mark all vma's of a process as mergeable
>>>>>> + *
>>>>>> + * @mm:  Pointer to mm
>>>>>> + */
>>>>>> +void ksm_add_vmas(struct mm_struct *mm)
>>>>>
>>>>> I'd suggest calling this
>>>>>
>>>> I guess you forgot your name suggestion?
>>> Yeah, I reconsidered because the first idea I had was not particularly
>>> good. Maybe
>>> ksm_enable_for_all_vmas()
>>> But not so sure. If you think the "add" terminology is a good fit, keep
>>> it like that.
>>> Thanks for bearing with me :)
>>>
>>
>> I briefly played with your patch to see how much it can be simplified.
>> Always enabling ksm (setting MMF_VM_MERGEABLE) before setting
>> MMF_VM_MERGE_ANY might simplify things. ksm_enable_merge_any() [or however it should
>> be called] and ksm_fork() contain the interesting bits.
>>
>>
>> Feel free to incorporate what you consider valuable (uncompiled,
>> untested).
>>
> I added most of it. The only change is that I kept ksm_add_vmas as a
> static function, otherwise I need to define the VMA_ITERATOR at the top
> of the function.


Makes sense. I'll review patch #3 later, so we can hopefully get this 
into the 6.4 merge window after letting it rest at least some days in -next.
diff mbox series

Patch

diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
index 5a716bdcba05..9d85e5589474 100644
--- a/arch/s390/mm/gmap.c
+++ b/arch/s390/mm/gmap.c
@@ -2591,6 +2591,7 @@  int gmap_mark_unmergeable(void)
 	int ret;
 	VMA_ITERATOR(vmi, mm, 0);
 
+	clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
 	for_each_vma(vmi, vma) {
 		/* Copy vm_flags to avoid partial modifications in ksm_madvise */
 		vm_flags = vma->vm_flags;
diff --git a/include/linux/ksm.h b/include/linux/ksm.h
index 7e232ba59b86..f24f9faf1561 100644
--- a/include/linux/ksm.h
+++ b/include/linux/ksm.h
@@ -18,20 +18,29 @@ 
 #ifdef CONFIG_KSM
 int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 		unsigned long end, int advice, unsigned long *vm_flags);
-int __ksm_enter(struct mm_struct *mm);
-void __ksm_exit(struct mm_struct *mm);
+
+int ksm_add_mm(struct mm_struct *mm);
+void ksm_add_vma(struct vm_area_struct *vma);
+void ksm_add_vmas(struct mm_struct *mm);
+
+int __ksm_enter(struct mm_struct *mm, int flag);
+void __ksm_exit(struct mm_struct *mm, int flag);
 
 static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 {
+	if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
+		return ksm_add_mm(mm);
 	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
-		return __ksm_enter(mm);
+		return __ksm_enter(mm, MMF_VM_MERGEABLE);
 	return 0;
 }
 
 static inline void ksm_exit(struct mm_struct *mm)
 {
-	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
-		__ksm_exit(mm);
+	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+		__ksm_exit(mm, MMF_VM_MERGE_ANY);
+	else if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
+		__ksm_exit(mm, MMF_VM_MERGEABLE);
 }
 
 /*
@@ -53,6 +62,10 @@  void folio_migrate_ksm(struct folio *newfolio, struct folio *folio);
 
 #else  /* !CONFIG_KSM */
 
+static inline void ksm_add_vma(struct vm_area_struct *vma)
+{
+}
+
 static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 {
 	return 0;
diff --git a/include/linux/sched/coredump.h b/include/linux/sched/coredump.h
index 0e17ae7fbfd3..0ee96ea7a0e9 100644
--- a/include/linux/sched/coredump.h
+++ b/include/linux/sched/coredump.h
@@ -90,4 +90,5 @@  static inline int get_dumpable(struct mm_struct *mm)
 #define MMF_INIT_MASK		(MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK |\
 				 MMF_DISABLE_THP_MASK | MMF_HAS_MDWE_MASK)
 
+#define MMF_VM_MERGE_ANY	29
 #endif /* _LINUX_SCHED_COREDUMP_H */
diff --git a/include/uapi/linux/prctl.h b/include/uapi/linux/prctl.h
index 1312a137f7fb..759b3f53e53f 100644
--- a/include/uapi/linux/prctl.h
+++ b/include/uapi/linux/prctl.h
@@ -290,4 +290,6 @@  struct prctl_mm_map {
 #define PR_SET_VMA		0x53564d41
 # define PR_SET_VMA_ANON_NAME		0
 
+#define PR_SET_MEMORY_MERGE		67
+#define PR_GET_MEMORY_MERGE		68
 #endif /* _LINUX_PRCTL_H */
diff --git a/kernel/fork.c b/kernel/fork.c
index f68954d05e89..1520697cf6c7 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -686,6 +686,7 @@  static __latent_entropy int dup_mmap(struct mm_struct *mm,
 		if (vma_iter_bulk_store(&vmi, tmp))
 			goto fail_nomem_vmi_store;
 
+		ksm_add_vma(tmp);
 		mm->map_count++;
 		if (!(tmp->vm_flags & VM_WIPEONFORK))
 			retval = copy_page_range(tmp, mpnt);
diff --git a/kernel/sys.c b/kernel/sys.c
index 495cd87d9bf4..9bba163d2d04 100644
--- a/kernel/sys.c
+++ b/kernel/sys.c
@@ -15,6 +15,7 @@ 
 #include <linux/highuid.h>
 #include <linux/fs.h>
 #include <linux/kmod.h>
+#include <linux/ksm.h>
 #include <linux/perf_event.h>
 #include <linux/resource.h>
 #include <linux/kernel.h>
@@ -2661,6 +2662,28 @@  SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3,
 	case PR_SET_VMA:
 		error = prctl_set_vma(arg2, arg3, arg4, arg5);
 		break;
+#ifdef CONFIG_KSM
+	case PR_SET_MEMORY_MERGE:
+		if (mmap_write_lock_killable(me->mm))
+			return -EINTR;
+
+		if (arg2) {
+			int err = ksm_add_mm(me->mm);
+
+			if (!err)
+				ksm_add_vmas(me->mm);
+		} else {
+			clear_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
+		}
+		mmap_write_unlock(me->mm);
+		break;
+	case PR_GET_MEMORY_MERGE:
+		if (arg2 || arg3 || arg4 || arg5)
+			return -EINVAL;
+
+		error = !!test_bit(MMF_VM_MERGE_ANY, &me->mm->flags);
+		break;
+#endif
 	default:
 		error = -EINVAL;
 		break;
diff --git a/mm/ksm.c b/mm/ksm.c
index d7bd28199f6c..ab95ae0f9def 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -534,10 +534,33 @@  static int break_ksm(struct vm_area_struct *vma, unsigned long addr,
 	return (ret & VM_FAULT_OOM) ? -ENOMEM : 0;
 }
 
+static bool vma_ksm_compatible(struct vm_area_struct *vma)
+{
+	if (vma->vm_flags & (VM_SHARED  | VM_MAYSHARE   | VM_PFNMAP  |
+			     VM_IO      | VM_DONTEXPAND | VM_HUGETLB |
+			     VM_MIXEDMAP))
+		return false;		/* just ignore the advice */
+
+	if (vma_is_dax(vma))
+		return false;
+
+#ifdef VM_SAO
+	if (vma->vm_flags & VM_SAO)
+		return false;
+#endif
+#ifdef VM_SPARC_ADI
+	if (vma->vm_flags & VM_SPARC_ADI)
+		return false;
+#endif
+
+	return true;
+}
+
 static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
 		unsigned long addr)
 {
 	struct vm_area_struct *vma;
+
 	if (ksm_test_exit(mm))
 		return NULL;
 	vma = vma_lookup(mm, addr);
@@ -1065,6 +1088,7 @@  static int unmerge_and_remove_all_rmap_items(void)
 
 			mm_slot_free(mm_slot_cache, mm_slot);
 			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+			clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
 			mmdrop(mm);
 		} else
 			spin_unlock(&ksm_mmlist_lock);
@@ -2495,6 +2519,7 @@  static struct ksm_rmap_item *scan_get_next_rmap_item(struct page **page)
 
 		mm_slot_free(mm_slot_cache, mm_slot);
 		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
 		mmap_read_unlock(mm);
 		mmdrop(mm);
 	} else {
@@ -2571,6 +2596,63 @@  static int ksm_scan_thread(void *nothing)
 	return 0;
 }
 
+static void __ksm_add_vma(struct vm_area_struct *vma)
+{
+	unsigned long vm_flags = vma->vm_flags;
+
+	if (vm_flags & VM_MERGEABLE)
+		return;
+
+	if (vma_ksm_compatible(vma)) {
+		vm_flags |= VM_MERGEABLE;
+		vm_flags_reset(vma, vm_flags);
+	}
+}
+
+/**
+ * ksm_add_vma - Mark vma as mergeable
+ *
+ * @vma:  Pointer to vma
+ */
+void ksm_add_vma(struct vm_area_struct *vma)
+{
+	struct mm_struct *mm = vma->vm_mm;
+
+	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+		__ksm_add_vma(vma);
+}
+
+/**
+ * ksm_add_vmas - Mark all vma's of a process as mergeable
+ *
+ * @mm:  Pointer to mm
+ */
+void ksm_add_vmas(struct mm_struct *mm)
+{
+	struct vm_area_struct *vma;
+
+	VMA_ITERATOR(vmi, mm, 0);
+	for_each_vma(vmi, vma)
+		__ksm_add_vma(vma);
+}
+
+/**
+ * ksm_add_mm - Add mm to mm ksm list
+ *
+ * @mm:  Pointer to mm
+ *
+ * Returns 0 on success, otherwise error code
+ */
+int ksm_add_mm(struct mm_struct *mm)
+{
+	if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+		return -EINVAL;
+	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
+		return -EINVAL;
+
+	return __ksm_enter(mm, MMF_VM_MERGE_ANY);
+}
+
 int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 		unsigned long end, int advice, unsigned long *vm_flags)
 {
@@ -2579,28 +2661,13 @@  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 
 	switch (advice) {
 	case MADV_MERGEABLE:
-		/*
-		 * Be somewhat over-protective for now!
-		 */
-		if (*vm_flags & (VM_MERGEABLE | VM_SHARED  | VM_MAYSHARE   |
-				 VM_PFNMAP    | VM_IO      | VM_DONTEXPAND |
-				 VM_HUGETLB | VM_MIXEDMAP))
-			return 0;		/* just ignore the advice */
-
-		if (vma_is_dax(vma))
+		if (vma->vm_flags & VM_MERGEABLE)
 			return 0;
-
-#ifdef VM_SAO
-		if (*vm_flags & VM_SAO)
+		if (!vma_ksm_compatible(vma))
 			return 0;
-#endif
-#ifdef VM_SPARC_ADI
-		if (*vm_flags & VM_SPARC_ADI)
-			return 0;
-#endif
 
 		if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
-			err = __ksm_enter(mm);
+			err = __ksm_enter(mm, MMF_VM_MERGEABLE);
 			if (err)
 				return err;
 		}
@@ -2626,7 +2693,7 @@  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 }
 EXPORT_SYMBOL_GPL(ksm_madvise);
 
-int __ksm_enter(struct mm_struct *mm)
+int __ksm_enter(struct mm_struct *mm, int flag)
 {
 	struct ksm_mm_slot *mm_slot;
 	struct mm_slot *slot;
@@ -2659,7 +2726,7 @@  int __ksm_enter(struct mm_struct *mm)
 		list_add_tail(&slot->mm_node, &ksm_scan.mm_slot->slot.mm_node);
 	spin_unlock(&ksm_mmlist_lock);
 
-	set_bit(MMF_VM_MERGEABLE, &mm->flags);
+	set_bit(flag, &mm->flags);
 	mmgrab(mm);
 
 	if (needs_wakeup)
@@ -2668,7 +2735,7 @@  int __ksm_enter(struct mm_struct *mm)
 	return 0;
 }
 
-void __ksm_exit(struct mm_struct *mm)
+void __ksm_exit(struct mm_struct *mm, int flag)
 {
 	struct ksm_mm_slot *mm_slot;
 	struct mm_slot *slot;
@@ -2700,7 +2767,7 @@  void __ksm_exit(struct mm_struct *mm)
 
 	if (easy_to_free) {
 		mm_slot_free(mm_slot_cache, mm_slot);
-		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		clear_bit(flag, &mm->flags);
 		mmdrop(mm);
 	} else if (mm_slot) {
 		mmap_write_lock(mm);
diff --git a/mm/mmap.c b/mm/mmap.c
index 740b54be3ed4..483e182e0b9d 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -46,6 +46,7 @@ 
 #include <linux/pkeys.h>
 #include <linux/oom.h>
 #include <linux/sched/mm.h>
+#include <linux/ksm.h>
 
 #include <linux/uaccess.h>
 #include <asm/cacheflush.h>
@@ -2213,6 +2214,8 @@  int __split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
 	/* vma_complete stores the new vma */
 	vma_complete(&vp, vmi, vma->vm_mm);
 
+	ksm_add_vma(new);
+
 	/* Success. */
 	if (new_below)
 		vma_next(vmi);
@@ -2664,6 +2667,7 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 	if (file && vm_flags & VM_SHARED)
 		mapping_unmap_writable(file->f_mapping);
 	file = vma->vm_file;
+	ksm_add_vma(vma);
 expanded:
 	perf_event_mmap(vma);
 
@@ -2936,6 +2940,7 @@  static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
 		goto mas_store_fail;
 
 	mm->map_count++;
+	ksm_add_vma(vma);
 out:
 	perf_event_mmap(vma);
 	mm->total_vm += len >> PAGE_SHIFT;
@@ -3180,6 +3185,7 @@  struct vm_area_struct *copy_vma(struct vm_area_struct **vmap,
 		if (vma_link(mm, new_vma))
 			goto out_vma_link;
 		*need_rmap_locks = false;
+		ksm_add_vma(new_vma);
 	}
 	validate_mm_mt(mm);
 	return new_vma;
@@ -3356,6 +3362,7 @@  static struct vm_area_struct *__install_special_mapping(
 	vm_stat_account(mm, vma->vm_flags, len >> PAGE_SHIFT);
 
 	perf_event_mmap(vma);
+	ksm_add_vma(vma);
 
 	validate_mm_mt(mm);
 	return vma;