diff mbox series

[7/9] vhost: allow userspace to create workers

Message ID 20210525180600.6349-8-michael.christie@oracle.com (mailing list archive)
State New, archived
Headers show
Series [1/9] vhost: move worker thread fields to new struct | expand

Commit Message

Mike Christie May 25, 2021, 6:05 p.m. UTC
This patch allows userspace to create workers and bind them to vqs, so you
can have N workers per dev and also share N workers with M vqs. The next
patch will allow sharing across devices.

Signed-off-by: Mike Christie <michael.christie@oracle.com>
---
 drivers/vhost/vhost.c            | 94 +++++++++++++++++++++++++++++++-
 drivers/vhost/vhost.h            |  3 +
 include/uapi/linux/vhost.h       |  6 ++
 include/uapi/linux/vhost_types.h | 12 ++++
 4 files changed, 113 insertions(+), 2 deletions(-)

Comments

Stefan Hajnoczi June 3, 2021, 2:30 p.m. UTC | #1
On Tue, May 25, 2021 at 01:05:58PM -0500, Mike Christie wrote:
> This patch allows userspace to create workers and bind them to vqs, so you
> can have N workers per dev and also share N workers with M vqs. The next
> patch will allow sharing across devices.
> 
> Signed-off-by: Mike Christie <michael.christie@oracle.com>
> ---
>  drivers/vhost/vhost.c            | 94 +++++++++++++++++++++++++++++++-
>  drivers/vhost/vhost.h            |  3 +
>  include/uapi/linux/vhost.h       |  6 ++
>  include/uapi/linux/vhost_types.h | 12 ++++
>  4 files changed, 113 insertions(+), 2 deletions(-)
> 
> diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
> index 345ade0af133..981e9bac7a31 100644
> --- a/drivers/vhost/vhost.c
> +++ b/drivers/vhost/vhost.c
> @@ -30,6 +30,7 @@
>  #include <linux/interval_tree_generic.h>
>  #include <linux/nospec.h>
>  #include <linux/kcov.h>
> +#include <linux/hashtable.h>
>  
>  #include "vhost.h"
>  
> @@ -42,6 +43,9 @@ module_param(max_iotlb_entries, int, 0444);
>  MODULE_PARM_DESC(max_iotlb_entries,
>  	"Maximum number of iotlb entries. (default: 2048)");
>  
> +static DEFINE_HASHTABLE(vhost_workers, 5);
> +static DEFINE_SPINLOCK(vhost_workers_lock);
> +
>  enum {
>  	VHOST_MEMORY_F_LOG = 0x1,
>  };
> @@ -617,8 +621,17 @@ static void vhost_detach_mm(struct vhost_dev *dev)
>  	dev->mm = NULL;
>  }
>  
> -static void vhost_worker_free(struct vhost_worker *worker)
> +static void vhost_worker_put(struct vhost_worker *worker)
>  {
> +	spin_lock(&vhost_workers_lock);
> +	if (!refcount_dec_and_test(&worker->refcount)) {
> +		spin_unlock(&vhost_workers_lock);
> +		return;
> +	}
> +
> +	hash_del(&worker->h_node);
> +	spin_unlock(&vhost_workers_lock);
> +
>  	WARN_ON(!llist_empty(&worker->work_list));
>  	kthread_stop(worker->task);
>  	kfree(worker);
> @@ -632,7 +645,7 @@ static void vhost_workers_free(struct vhost_dev *dev)
>  		return;
>  
>  	for (i = 0; i < dev->num_workers; i++)
> -		vhost_worker_free(dev->workers[i]);
> +		vhost_worker_put(dev->workers[i]);
>  
>  	kfree(dev->workers);
>  	dev->num_workers = 0;
> @@ -652,6 +665,8 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
>  	worker->id = dev->num_workers;
>  	worker->dev = dev;
>  	init_llist_head(&worker->work_list);
> +	INIT_HLIST_NODE(&worker->h_node);
> +	refcount_set(&worker->refcount, 1);
>  
>  	task = kthread_create(vhost_worker, worker, "vhost-%d", current->pid);
>  	if (IS_ERR(task))
> @@ -664,6 +679,9 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
>  	if (ret)
>  		goto stop_worker;
>  
> +	spin_lock(&vhost_workers_lock);
> +	hash_add(vhost_workers, &worker->h_node, worker->task->pid);
> +	spin_unlock(&vhost_workers_lock);
>  	return worker;
>  
>  stop_worker:
> @@ -673,6 +691,67 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
>  	return NULL;
>  }
>  
> +static struct vhost_worker *vhost_worker_find(struct vhost_dev *dev, pid_t pid)
> +{
> +	struct vhost_worker *worker, *found_worker = NULL;
> +
> +	spin_lock(&vhost_workers_lock);
> +	hash_for_each_possible(vhost_workers, worker, h_node, pid) {
> +		if (worker->task->pid == pid) {
> +			/* tmp - next patch allows sharing across devs */
> +			if (worker->dev != dev)
> +				break;
> +
> +			found_worker = worker;
> +			refcount_inc(&worker->refcount);
> +			break;
> +		}
> +	}
> +	spin_unlock(&vhost_workers_lock);
> +	return found_worker;
> +}
> +
> +/* Caller must have device mutex */
> +static int vhost_vq_set_worker(struct vhost_virtqueue *vq,
> +			       struct vhost_vring_worker *info)
> +{
> +	struct vhost_dev *dev = vq->dev;
> +	struct vhost_worker *worker;
> +
> +	if (vq->worker) {
> +		/* TODO - support changing while works are running */
> +		return -EBUSY;
> +	}
> +
> +	if (info->pid == VHOST_VRING_NEW_WORKER) {
> +		worker = vhost_worker_create(dev);

The maximum number of kthreads created is limited by
vhost_dev_init(nvqs)? For example VHOST_SCSI_MAX_VQ 128.

IIUC kthread_create is not limited by or accounted against the current
task, so I'm a little worried that a process can create a lot of
kthreads.

I haven't investigated other kthread_create() users reachable from
userspace applications to see how they bound the number of threads
effectively.

Any thoughts?

> +		if (!worker)
> +			return -ENOMEM;
> +
> +		info->pid = worker->task->pid;
> +	} else {
> +		worker = vhost_worker_find(dev, info->pid);
> +		if (!worker)
> +			return -ENODEV;
> +	}
> +
> +	if (!dev->workers) {
> +		dev->workers = kcalloc(vq->dev->nvqs,
> +				       sizeof(struct vhost_worker *),
> +				       GFP_KERNEL);

Another candidate for GFP_KERNEL_ACCOUNT.

> +		if (!dev->workers) {
> +			vhost_worker_put(worker);
> +			return -ENOMEM;
> +		}
> +	}
> +
> +	vq->worker = worker;
> +
> +	dev->workers[dev->num_workers] = worker;
> +	dev->num_workers++;

Hmm...should we really append to workers[] in the vhost_worker_find()
case?

> +	return 0;
> +}
> +
>  /* Caller must have device mutex */
>  static int vhost_worker_try_create_def(struct vhost_dev *dev)
>  {
> @@ -1680,6 +1759,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
>  	struct eventfd_ctx *ctx = NULL;
>  	u32 __user *idxp = argp;
>  	struct vhost_virtqueue *vq;
> +	struct vhost_vring_worker w;
>  	struct vhost_vring_state s;
>  	struct vhost_vring_file f;
>  	u32 idx;
> @@ -1794,6 +1874,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
>  		if (copy_to_user(argp, &s, sizeof(s)))
>  			r = -EFAULT;
>  		break;
> +	case VHOST_SET_VRING_WORKER:
> +		if (copy_from_user(&w, argp, sizeof(w))) {
> +			r = -EFAULT;
> +			break;
> +		}
> +		r = vhost_vq_set_worker(vq, &w);
> +		if (!r && copy_to_user(argp, &w, sizeof(w)))
> +			r = -EFAULT;
> +		break;
>  	default:
>  		r = -ENOIOCTLCMD;
>  	}
> @@ -2726,6 +2815,7 @@ EXPORT_SYMBOL_GPL(vhost_set_backend_features);
>  
>  static int __init vhost_init(void)
>  {
> +	hash_init(vhost_workers);
>  	return 0;
>  }
>  
> diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
> index 0a252dd45101..75b884ad1f17 100644
> --- a/drivers/vhost/vhost.h
> +++ b/drivers/vhost/vhost.h
> @@ -14,6 +14,7 @@
>  #include <linux/atomic.h>
>  #include <linux/vhost_iotlb.h>
>  #include <linux/irqbypass.h>
> +#include <linux/refcount.h>
>  
>  struct vhost_work;
>  typedef void (*vhost_work_fn_t)(struct vhost_work *work);
> @@ -28,6 +29,8 @@ struct vhost_work {
>  struct vhost_worker {
>  	struct task_struct	*task;
>  	struct llist_head	work_list;
> +	struct hlist_node	h_node;

h_node is a generic name. If you're willing to use a longer name then
vhost_workers_node would make it clear which hlist this is associated
with.

> +	refcount_t		refcount;
>  	struct vhost_dev	*dev;
>  	int			id;
>  };
> diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
> index c998860d7bbc..ce32119cb139 100644
> --- a/include/uapi/linux/vhost.h
> +++ b/include/uapi/linux/vhost.h
> @@ -70,6 +70,12 @@
>  #define VHOST_VRING_BIG_ENDIAN 1
>  #define VHOST_SET_VRING_ENDIAN _IOW(VHOST_VIRTIO, 0x13, struct vhost_vring_state)
>  #define VHOST_GET_VRING_ENDIAN _IOW(VHOST_VIRTIO, 0x14, struct vhost_vring_state)
> +/* Create/bind a vhost worker to a virtqueue. If pid > 0 and matches an existing
> + * vhost_worker thread it will be bound to the vq. If pid is
> + * VHOST_VRING_NEW_WORKER, then a new worker will be created and bound to the
> + * vq.
> + */

Please document when this ioctl must be called (before kick is set up).

> +#define VHOST_SET_VRING_WORKER _IOWR(VHOST_VIRTIO, 0x15, struct vhost_vring_worker)
>  
>  /* The following ioctls use eventfd file descriptors to signal and poll
>   * for events. */
> diff --git a/include/uapi/linux/vhost_types.h b/include/uapi/linux/vhost_types.h
> index f7f6a3a28977..5113baa8bc3e 100644
> --- a/include/uapi/linux/vhost_types.h
> +++ b/include/uapi/linux/vhost_types.h
> @@ -47,6 +47,18 @@ struct vhost_vring_addr {
>  	__u64 log_guest_addr;
>  };
>  
> +#define VHOST_VRING_NEW_WORKER -1
> +
> +struct vhost_vring_worker {
> +	unsigned int index;
> +	/*
> +	 * The pid of the vhost worker that the vq will be bound to. If
> +	 * pid is VHOST_VRING_NEW_WORKER a new worker will be created and it's

s/it's/its/

> +	 * pid will be returned in pid.
> +	 */
> +	__kernel_pid_t pid;
> +};
> +
>  /* no alignment requirement */
>  struct vhost_iotlb_msg {
>  	__u64 iova;
> -- 
> 2.25.1
>
Mike Christie June 5, 2021, 11:53 p.m. UTC | #2
On 6/3/21 9:30 AM, Stefan Hajnoczi wrote:
>> +	if (info->pid == VHOST_VRING_NEW_WORKER) {
>> +		worker = vhost_worker_create(dev);
> 
> The maximum number of kthreads created is limited by
> vhost_dev_init(nvqs)? For example VHOST_SCSI_MAX_VQ 128.
> 
> IIUC kthread_create is not limited by or accounted against the current
> task, so I'm a little worried that a process can create a lot of
> kthreads.
> 
> I haven't investigated other kthread_create() users reachable from
> userspace applications to see how they bound the number of threads
> effectively.

Do we want to use copy_process like io_uring does? It's what fork uses,
so we get checks like RLIMIT_NPROC and max_threads.

I know I didn't look at everything, but it looks like for some software
drivers we just allow the user to run wild. For example, when nbd creates
a device it calls alloc_workqueue with the default max_active value
(256). We then don't have a limit on connections, so we could end up
with 256 workqueue threads per device. And then there is no limit on the
number of devices a user can make.


> 
> Any thoughts?
>

Is the concern a bad VM could create N devs each with 128 vqs/threads
and it would slow down other VMs? How do we handle the case where
some VM makes M * N devs and that is equal to N * 128 so we would end
up with the same number of threads either way? Is there a limit to the
number of vhost devices a VM can make and can I just stick in a similar
check for workers?

For vhost-scsi specifically, the 128 limit does not make a lot of sense.
I think we want the max to be the number of vCPUs the VM has so we can
add checks for that. Then we would assume someone making a VM with lots of
CPUs is going to have the resources to support them.

Note: It does make sense from the point of view that we don't know the
number of vCPUs when vhost-scsi calls vhost_dev_init, so I get we had to
select an initial limit.



>> +		if (!dev->workers) {
>> +			vhost_worker_put(worker);
>> +			return -ENOMEM;
>> +		}
>> +	}
>> +
>> +	vq->worker = worker;
>> +
>> +	dev->workers[dev->num_workers] = worker;
>> +	dev->num_workers++;
> 
> Hmm...should we really append to workers[] in the vhost_worker_find()
> case?


As it's coded now, yes. Every successful vhost_worker_find call does a
get on the worker's refcount. Later when we delete the device, we loop
over the workers array and for every entry we do a put.

I can add in some code to first check if the worker is already in the
dev's worker list. If so then skip the refcount and skip adding to the
workers array. If not in the dev's worker list then do a vhost_worker_find.

I thought it might be nicer how it is now with the single path. It's less
code at least. Later if we add support to change a vq's worker then we also
don't have to worry about refcounts as much. We just always drop the count
taken from when it was added.
Stefan Hajnoczi June 7, 2021, 3:19 p.m. UTC | #3
On Sat, Jun 05, 2021 at 06:53:58PM -0500, michael.christie@oracle.com wrote:
> On 6/3/21 9:30 AM, Stefan Hajnoczi wrote:
> >> +	if (info->pid == VHOST_VRING_NEW_WORKER) {
> >> +		worker = vhost_worker_create(dev);
> > 
> > The maximum number of kthreads created is limited by
> > vhost_dev_init(nvqs)? For example VHOST_SCSI_MAX_VQ 128.
> > 
> > IIUC kthread_create is not limited by or accounted against the current
> > task, so I'm a little worried that a process can create a lot of
> > kthreads.
> > 
> > I haven't investigated other kthread_create() users reachable from
> > userspace applications to see how they bound the number of threads
> > effectively.
> 
> Do we want something like io_uring's copy_process use? It's what fork uses,
> so we get checks like RLIMIT_NPROC and max_threads.
> 
> I know I didn't look at everything, but it looks like for some software
> drivers we just allow the user to run wild. For example for nbd, when we
> create the device to do alloc_workqueue and use the default max_active
> value (256). We then don't have a limit on connections, so we could end
> up with 256 workqueue threads per device. And then there is no limit on
> devices a user can make.
> 
> 
> > 
> > Any thoughts?
> >
> 
> Is the concern a bad VM could create N devs each with 128 vqs/threads
> and it would slow down other VMs? How do we handle the case where
> some VM makes M * N devs and that is equal to N * 128 so we would end
> up with the same number of threads either way? Is there a limit to the
> number of vhost devices a VM can make and can I just stick in a similar
> check for workers?
> 
> For vhost-scsi specifically, the 128 limit does not make a lot of sense.
> I think we want the max to be the number of vCPUs the VM has so we can
> add checks for that. Then we would assume someone making a VM with lots of
> CPUs is going to have the resources to support them.
> 
> Note: It does make sense from the point of view that we don't know the
> number of vCPUs when vhost-scsi calls vhost_dev_init, so I get we had to
> select an initial limit.

My concern is that threads should probably be accounted against
RLIMIT_NPROC and max_threads rather than something indirect like 128 *
RLIMIT_NOFILE (a userspace process can only have RLIMIT_NOFILE
vhost-user file descriptors open).

> >> +		if (!dev->workers) {
> >> +			vhost_worker_put(worker);
> >> +			return -ENOMEM;
> >> +		}
> >> +	}
> >> +
> >> +	vq->worker = worker;
> >> +
> >> +	dev->workers[dev->num_workers] = worker;
> >> +	dev->num_workers++;
> > 
> > Hmm...should we really append to workers[] in the vhost_worker_find()
> > case?
> 
> 
> As it's coded now, yes. Every successful vhost_worker_find call does a
> get on the worker's refcount. Later when we delete the device, we loop
> over the workers array and for every entry we do a put.
> 
> I can add in some code to first check if the worker is already in the
> dev's worker list. If so then skip the refcount and skip adding to the
> workers array. If not in the dev's worker list then do a vhost_worker_find.
> 
> I thought it might be nicer how it is now with the single path. It's less
> code at least. Later if we add support to change a vq's worker then we also
> don't have to worry about refcounts as much. We just always drop the count
> taken from when it was added.

Thanks for explaining.

Stefan
Mike Christie June 9, 2021, 9:03 p.m. UTC | #4
On 6/7/21 10:19 AM, Stefan Hajnoczi wrote:
> My concern is that threads should probably accounted against
> RLIMIT_NPROC and max_threads rather than something indirect like 128 *
> RLIMIT_NOFILE (a userspace process can only have RLIMIT_NOFILE
> vhost-user file descriptors open).
> 

Ah ok, I see what you want I think.

Ok, I think the options are:

0. Nothing. Just use existing indirect/RLIMIT_NOFILE.

1. Do something like io_uring's create_io_thread/copy_process. If we call
copy_process from the vhost ioctl context, then the userspace process that
did the ioctl will have its process count incremented and checked against
its rlimit.

The drawbacks:
- This gets a little more complicated than just calling copy_process though.
We end up duplicating a lot of the kthread API.
- We have to deal with new error cases like the parent exiting early.
- I think all devs sharing a worker have to have the same owner. kthread_use_mm
and kthread_unuse_mm to switch between mm's for different owners' devs seem to
be causing lots of errors. I'm still looking into this one though.

2.  It's not really what you want, but for unbound work io_uring has a check for
RLIMIT_NPROC in the io_uring code. It does:

wqe->acct[IO_WQ_ACCT_UNBOUND].max_workers =
					task_rlimit(current, RLIMIT_NPROC);

then does:

if (!ret && acct->nr_workers < acct->max_workers) {

Drawbacks:
In vhost.c, we could do something similar. It would make sure that vhost.c does
not create more worker threads than the rlimit value, but we wouldn't be
incrementing the userspace process's process count. The userspace process could
then create RLIMIT_NPROC threads and vhost.c could also create RLIMIT_NPROC
threads, so we end up with 2 * RLIMIT_NPROC threads.

3. Change the kthread and copy_process code so we can pass in the thread
(or its creds, or some struct holding the values that need to be checked)
whose limits need to be checked and updated.

Drawback:
This might be considered too ugly for how special case vhost is. For example, we
need checks/code like the io_thread/PF_IO_WORKER code in copy_process for io_uring.
I can see why that was added for io_uring, since it affects so many users, but I
can also see how vhost might not be considered special enough.
Stefan Hajnoczi June 10, 2021, 8:06 a.m. UTC | #5
On Wed, Jun 09, 2021 at 04:03:55PM -0500, Mike Christie wrote:
> On 6/7/21 10:19 AM, Stefan Hajnoczi wrote:
> > My concern is that threads should probably accounted against
> > RLIMIT_NPROC and max_threads rather than something indirect like 128 *
> > RLIMIT_NOFILE (a userspace process can only have RLIMIT_NOFILE
> > vhost-user file descriptors open).
> > 
> 
> Ah ok, I see what you want I think.
> 
> Ok, I think the options are:
> 
> 0. Nothing. Just use existing indirect/RLIMIT_NOFILE.
> 
> 1. Do something like io_uring's create_io_thread/copy_process. If we call
> copy_process from the vhost ioctl context, then the userspace process that
> did the ioctl will have it's processes count incremented and checked against
> its rlimit.
> 
> The drawbacks:
> - This gets a little more complicated than just calling copy_process though.
> We end up duplicating a lot of the kthread API.
> - We have to deal with new error cases like the parent exiting early.
> - I think all devs sharing a worker have to have the same owner. kthread_use_mm
> and kthread_unuse_mm to switch between mm's for differrent owner's devs seem to
> be causing lots of errors. I'm still looking into this one though.
> 
> 2.  It's not really what you want, but for unbound work io_uring has a check for
> RLIMIT_NPROC in the io_uring code. It does:
> 
> wqe->acct[IO_WQ_ACCT_UNBOUND].max_workers =
> 					task_rlimit(current, RLIMIT_NPROC);
> 
> then does:
> 
> if (!ret && acct->nr_workers < acct->max_workers) {
> 
> Drawbacks:
> In vhost.c, we could do something similar. It would make sure that vhost.c does
> not create more worker threads than the rlimit value, but we wouldn't be
> incrementing the userspace process's process count. The userspace process could
> then create RLIMIT_NPROC threads and vhost.c could also create RLIMIT_NPROC
> threads, so we end up with 2 * RLIMIT_NPROC threads.

Yes, in that case we might as well go with Option 0, so I think this
option can be eliminated.

> 3. Change the kthread and copy_process code so we can pass in the thread
> (or it's creds or some struct that has the values that need to be check) that
> needs to be checked and updated.
> 
> Drawback:
> This might be considered too ugly for how special case vhost is. For example, we
> need checks/code like the io_thread/PF_IO_WORKER code in copy_process for io_uring.
> I can see how added that for io_uring because it affects so many users, but I can
> see how vhost is not special enough.

This seems like the most general solution. If you try it and get
negative feedback then maybe the maintainers can help suggest how to
solve this problem :).

Stefan
Mike Christie June 18, 2021, 2:49 a.m. UTC | #6
On 6/10/21 3:06 AM, Stefan Hajnoczi wrote:
> On Wed, Jun 09, 2021 at 04:03:55PM -0500, Mike Christie wrote:
>> On 6/7/21 10:19 AM, Stefan Hajnoczi wrote:
>>> My concern is that threads should probably accounted against
>>> RLIMIT_NPROC and max_threads rather than something indirect like 128 *
>>> RLIMIT_NOFILE (a userspace process can only have RLIMIT_NOFILE
>>> vhost-user file descriptors open).
>>>
>>
>> Ah ok, I see what you want I think.
>>
>> Ok, I think the options are:
>>
>> 0. Nothing. Just use existing indirect/RLIMIT_NOFILE.
>>
>> 1. Do something like io_uring's create_io_thread/copy_process. If we call
>> copy_process from the vhost ioctl context, then the userspace process that
>> did the ioctl will have it's processes count incremented and checked against
>> its rlimit.
>>
>> The drawbacks:
>> - This gets a little more complicated than just calling copy_process though.
>> We end up duplicating a lot of the kthread API.
>> - We have to deal with new error cases like the parent exiting early.
>> - I think all devs sharing a worker have to have the same owner. kthread_use_mm
>> and kthread_unuse_mm to switch between mm's for differrent owner's devs seem to
>> be causing lots of errors. I'm still looking into this one though.
>>
>> 2.  It's not really what you want, but for unbound work io_uring has a check for
>> RLIMIT_NPROC in the io_uring code. It does:
>>
>> wqe->acct[IO_WQ_ACCT_UNBOUND].max_workers =
>> 					task_rlimit(current, RLIMIT_NPROC);
>>
>> then does:
>>
>> if (!ret && acct->nr_workers < acct->max_workers) {
>>
>> Drawbacks:
>> In vhost.c, we could do something similar. It would make sure that vhost.c does
>> not create more worker threads than the rlimit value, but we wouldn't be
>> incrementing the userspace process's process count. The userspace process could
>> then create RLIMIT_NPROC threads and vhost.c could also create RLIMIT_NPROC
>> threads, so we end up with 2 * RLIMIT_NPROC threads.
> 
> Yes, in that case we might as well go with Option 0, so I think this
> option can be eliminated.
> 
>> 3. Change the kthread and copy_process code so we can pass in the thread
>> (or it's creds or some struct that has the values that need to be check) that
>> needs to be checked and updated.
>>
>> Drawback:
>> This might be considered too ugly for how special case vhost is. For example, we
>> need checks/code like the io_thread/PF_IO_WORKER code in copy_process for io_uring.
>> I can see how added that for io_uring because it affects so many users, but I can
>> see how vhost is not special enough.
> 
> This seems like the most general solution. If you try it and get
> negative feedback then maybe the maintainers can help suggest how to
> solve this problem :).

Hey, I implemented #3 here:

https://github.com/mikechristie/linux/commit/76f7a555a85147420a22d0163c15259e01e02193

in this patchset:

https://github.com/mikechristie/linux/commits/kthread-node-user

but before I post I wanted to bring up an option 4 someone mentioned to me
offlist.

Again it's io_uring. Check out fs/io_uring.c:__io_account_mem(). For RLIMIT_MEMLOCK
it just does the check and increments the user's counter itself. It's simple like
option 2, and it handles the issue where the process doing the ioctl wasn't having
its RLIMIT_NPROC checked/updated.
Stefan Hajnoczi June 21, 2021, 1:41 p.m. UTC | #7
On Thu, Jun 17, 2021 at 09:49:07PM -0500, Mike Christie wrote:
> Again it's io_uring. Check out fs/io_uring.c:__io_account_mem(). For RLIMIT_MEMLOCK
> it just does the check and increments the user's counter itself. It's simple like
> option 2, and it handles the issue where the process doing the ioctl wasn't having
> its RLIMIT_NPROC checked/updated.

This can work too. It doesn't cover cases where code called indirectly
acquires resources, but that's probably fine for the vhost worker thread
case.

Stefan
diff mbox series

Patch

diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 345ade0af133..981e9bac7a31 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -30,6 +30,7 @@ 
 #include <linux/interval_tree_generic.h>
 #include <linux/nospec.h>
 #include <linux/kcov.h>
+#include <linux/hashtable.h>
 
 #include "vhost.h"
 
@@ -42,6 +43,9 @@  module_param(max_iotlb_entries, int, 0444);
 MODULE_PARM_DESC(max_iotlb_entries,
 	"Maximum number of iotlb entries. (default: 2048)");
 
+static DEFINE_HASHTABLE(vhost_workers, 5);
+static DEFINE_SPINLOCK(vhost_workers_lock);
+
 enum {
 	VHOST_MEMORY_F_LOG = 0x1,
 };
@@ -617,8 +621,17 @@  static void vhost_detach_mm(struct vhost_dev *dev)
 	dev->mm = NULL;
 }
 
-static void vhost_worker_free(struct vhost_worker *worker)
+/* Drop a reference; the final put unhashes, stops and frees the worker. */
+static void vhost_worker_put(struct vhost_worker *worker)
 {
+	/* True (with vhost_workers_lock held) only when dropping the last ref. */
+	if (!refcount_dec_and_lock(&worker->refcount, &vhost_workers_lock))
+		return;
+
+	/* Unhash before unlocking so vhost_worker_find() can't revive it. */
+	hash_del(&worker->h_node);
+	spin_unlock(&vhost_workers_lock);
+
 	WARN_ON(!llist_empty(&worker->work_list));
 	kthread_stop(worker->task);
 	kfree(worker);
@@ -632,7 +645,7 @@  static void vhost_workers_free(struct vhost_dev *dev)
 		return;
 
 	for (i = 0; i < dev->num_workers; i++)
-		vhost_worker_free(dev->workers[i]);
+		vhost_worker_put(dev->workers[i]);
 
 	kfree(dev->workers);
 	dev->num_workers = 0;
@@ -652,6 +665,8 @@  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 	worker->id = dev->num_workers;
 	worker->dev = dev;
 	init_llist_head(&worker->work_list);
+	INIT_HLIST_NODE(&worker->h_node);
+	refcount_set(&worker->refcount, 1);
 
 	task = kthread_create(vhost_worker, worker, "vhost-%d", current->pid);
 	if (IS_ERR(task))
@@ -664,6 +679,9 @@  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 	if (ret)
 		goto stop_worker;
 
+	spin_lock(&vhost_workers_lock);
+	hash_add(vhost_workers, &worker->h_node, worker->task->pid);
+	spin_unlock(&vhost_workers_lock);
 	return worker;
 
 stop_worker:
@@ -673,6 +691,67 @@  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 	return NULL;
 }
 
+/* Get a ref on @dev's worker whose kthread pid is @pid; NULL if none. */
+static struct vhost_worker *vhost_worker_find(struct vhost_dev *dev, pid_t pid)
+{
+	struct vhost_worker *worker, *found_worker = NULL;
+
+	spin_lock(&vhost_workers_lock);
+	hash_for_each_possible(vhost_workers, worker, h_node, pid) {
+		if (worker->task->pid == pid) {
+			/* tmp - next patch allows sharing across devs */
+			if (worker->dev != dev)
+				break;
+
+			found_worker = worker;
+			refcount_inc(&worker->refcount);
+			break;
+		}
+	}
+	spin_unlock(&vhost_workers_lock);
+	return found_worker;
+}
+
+/* Caller must have device mutex */
+static int vhost_vq_set_worker(struct vhost_virtqueue *vq,
+			       struct vhost_vring_worker *info)
+{
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_worker *worker;
+
+	/* TODO - support changing while works are running */
+	if (vq->worker)
+		return -EBUSY;
+
+	if (info->pid == VHOST_VRING_NEW_WORKER) {
+		worker = vhost_worker_create(dev);
+		if (!worker)
+			return -ENOMEM;
+
+		info->pid = worker->task->pid;
+	} else {
+		worker = vhost_worker_find(dev, info->pid);
+		if (!worker)
+			return -ENODEV;
+	}
+
+	if (!dev->workers) {
+		dev->workers = kcalloc(dev->nvqs, sizeof(*dev->workers),
+				       GFP_KERNEL_ACCOUNT);
+		if (!dev->workers) {
+			vhost_worker_put(worker);
+			return -ENOMEM;
+		}
+	}
+
+	vq->worker = worker;
+
+	/* Each entry owns one reference, dropped in vhost_workers_free(). */
+	dev->workers[dev->num_workers] = worker;
+	dev->num_workers++;
+	return 0;
+}
+
 /* Caller must have device mutex */
 static int vhost_worker_try_create_def(struct vhost_dev *dev)
 {
@@ -1680,6 +1759,7 @@  long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 	struct eventfd_ctx *ctx = NULL;
 	u32 __user *idxp = argp;
 	struct vhost_virtqueue *vq;
+	struct vhost_vring_worker w;
 	struct vhost_vring_state s;
 	struct vhost_vring_file f;
 	u32 idx;
@@ -1794,6 +1874,15 @@  long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 		if (copy_to_user(argp, &s, sizeof(s)))
 			r = -EFAULT;
 		break;
+	case VHOST_SET_VRING_WORKER:
+		if (copy_from_user(&w, argp, sizeof(w))) {
+			r = -EFAULT;
+			break;
+		}
+		r = vhost_vq_set_worker(vq, &w);
+		if (!r && copy_to_user(argp, &w, sizeof(w)))
+			r = -EFAULT;
+		break;
 	default:
 		r = -ENOIOCTLCMD;
 	}
@@ -2726,6 +2815,7 @@  EXPORT_SYMBOL_GPL(vhost_set_backend_features);
 
 static int __init vhost_init(void)
 {
+	hash_init(vhost_workers);
 	return 0;
 }
 
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 0a252dd45101..75b884ad1f17 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -14,6 +14,7 @@ 
 #include <linux/atomic.h>
 #include <linux/vhost_iotlb.h>
 #include <linux/irqbypass.h>
+#include <linux/refcount.h>
 
 struct vhost_work;
 typedef void (*vhost_work_fn_t)(struct vhost_work *work);
@@ -28,6 +29,8 @@  struct vhost_work {
 struct vhost_worker {
 	struct task_struct	*task;
 	struct llist_head	work_list;
+	struct hlist_node	h_node;		/* vhost_workers entry */
+	refcount_t		refcount;	/* see vhost_worker_put() */
 	struct vhost_dev	*dev;
 	int			id;
 };
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
index c998860d7bbc..ce32119cb139 100644
--- a/include/uapi/linux/vhost.h
+++ b/include/uapi/linux/vhost.h
@@ -70,6 +70,12 @@ 
 #define VHOST_VRING_BIG_ENDIAN 1
 #define VHOST_SET_VRING_ENDIAN _IOW(VHOST_VIRTIO, 0x13, struct vhost_vring_state)
 #define VHOST_GET_VRING_ENDIAN _IOW(VHOST_VIRTIO, 0x14, struct vhost_vring_state)
+/* Create/bind a vhost worker to a virtqueue. Must be called before the vq
+ * has a worker (e.g. before its kick is set up), else -EBUSY is returned.
+ * If pid is VHOST_VRING_NEW_WORKER, a new worker is created and bound; else
+ * pid must name an existing vhost worker thread, or -ENODEV is returned.
+ */
+#define VHOST_SET_VRING_WORKER _IOWR(VHOST_VIRTIO, 0x15, struct vhost_vring_worker)
 
 /* The following ioctls use eventfd file descriptors to signal and poll
  * for events. */
diff --git a/include/uapi/linux/vhost_types.h b/include/uapi/linux/vhost_types.h
index f7f6a3a28977..5113baa8bc3e 100644
--- a/include/uapi/linux/vhost_types.h
+++ b/include/uapi/linux/vhost_types.h
@@ -47,6 +47,18 @@  struct vhost_vring_addr {
 	__u64 log_guest_addr;
 };
 
+#define VHOST_VRING_NEW_WORKER (-1)
+
+struct vhost_vring_worker {
+	unsigned int index;
+	/*
+	 * The pid of the vhost worker that the vq will be bound to. If
+	 * pid is VHOST_VRING_NEW_WORKER a new worker will be created and its
+	 * pid will be returned in pid.
+	 */
+	__kernel_pid_t pid;
+};
+
 /* no alignment requirement */
 struct vhost_iotlb_msg {
 	__u64 iova;