diff mbox series

[RFC,v2,virtio,3/7] pds_vdpa: virtio bar setup for vdpa

Message ID 20230309013046.23523-4-shannon.nelson@amd.com (mailing list archive)
State Superseded
Headers show
Series pds_vdpa driver | expand

Checks

Context Check Description
netdev/tree_selection success Not a local patch

Commit Message

Nelson, Shannon March 9, 2023, 1:30 a.m. UTC
The PDS vDPA device has a virtio BAR for describing itself, and
the pds_vdpa driver needs to access it.  Here we copy liberally
from the existing drivers/virtio/virtio_pci_modern_dev.c as it
has what we need, but we need to modify it so that it can work
with our device id and so we can use our own DMA mask.

We suspect there is room for discussion here about making the
existing code a little more flexible, but we thought we'd at
least start the discussion here.

Signed-off-by: Shannon Nelson <shannon.nelson@amd.com>
---
 drivers/vdpa/pds/Makefile     |   1 +
 drivers/vdpa/pds/aux_drv.c    |  14 ++
 drivers/vdpa/pds/aux_drv.h    |   1 +
 drivers/vdpa/pds/debugfs.c    |   1 +
 drivers/vdpa/pds/vdpa_dev.c   |   1 +
 drivers/vdpa/pds/virtio_pci.c | 281 ++++++++++++++++++++++++++++++++++
 drivers/vdpa/pds/virtio_pci.h |   8 +
 7 files changed, 307 insertions(+)
 create mode 100644 drivers/vdpa/pds/virtio_pci.c
 create mode 100644 drivers/vdpa/pds/virtio_pci.h

Comments

Jason Wang March 15, 2023, 7:05 a.m. UTC | #1
On Thu, Mar 9, 2023 at 9:31 AM Shannon Nelson <shannon.nelson@amd.com> wrote:
>
> The PDS vDPA device has a virtio BAR for describing itself, and
> the pds_vdpa driver needs to access it.  Here we copy liberally
> from the existing drivers/virtio/virtio_pci_modern_dev.c as it
> has what we need, but we need to modify it so that it can work
> with our device id and so we can use our own DMA mask.

By passing a pointer to a customized id probing routine to vp_modern_probe()?

Thanks


>
> We suspect there is room for discussion here about making the
> existing code a little more flexible, but we thought we'd at
> least start the discussion here.
>
> Signed-off-by: Shannon Nelson <shannon.nelson@amd.com>
> ---
>  drivers/vdpa/pds/Makefile     |   1 +
>  drivers/vdpa/pds/aux_drv.c    |  14 ++
>  drivers/vdpa/pds/aux_drv.h    |   1 +
>  drivers/vdpa/pds/debugfs.c    |   1 +
>  drivers/vdpa/pds/vdpa_dev.c   |   1 +
>  drivers/vdpa/pds/virtio_pci.c | 281 ++++++++++++++++++++++++++++++++++
>  drivers/vdpa/pds/virtio_pci.h |   8 +
>  7 files changed, 307 insertions(+)
>  create mode 100644 drivers/vdpa/pds/virtio_pci.c
>  create mode 100644 drivers/vdpa/pds/virtio_pci.h
>
> diff --git a/drivers/vdpa/pds/Makefile b/drivers/vdpa/pds/Makefile
> index 13b50394ec64..ca2efa8c6eb5 100644
> --- a/drivers/vdpa/pds/Makefile
> +++ b/drivers/vdpa/pds/Makefile
> @@ -4,6 +4,7 @@
>  obj-$(CONFIG_PDS_VDPA) := pds_vdpa.o
>
>  pds_vdpa-y := aux_drv.o \
> +             virtio_pci.o \
>               vdpa_dev.o
>
>  pds_vdpa-$(CONFIG_DEBUG_FS) += debugfs.o
> diff --git a/drivers/vdpa/pds/aux_drv.c b/drivers/vdpa/pds/aux_drv.c
> index 63e40ae68211..28158d0d98a5 100644
> --- a/drivers/vdpa/pds/aux_drv.c
> +++ b/drivers/vdpa/pds/aux_drv.c
> @@ -4,6 +4,7 @@
>  #include <linux/auxiliary_bus.h>
>  #include <linux/pci.h>
>  #include <linux/vdpa.h>
> +#include <linux/virtio_pci_modern.h>
>
>  #include <linux/pds/pds_core.h>
>  #include <linux/pds/pds_auxbus.h>
> @@ -12,6 +13,7 @@
>  #include "aux_drv.h"
>  #include "debugfs.h"
>  #include "vdpa_dev.h"
> +#include "virtio_pci.h"
>
>  static const struct auxiliary_device_id pds_vdpa_id_table[] = {
>         { .name = PDS_VDPA_DEV_NAME, },
> @@ -49,8 +51,19 @@ static int pds_vdpa_probe(struct auxiliary_device *aux_dev,
>         if (err)
>                 goto err_aux_unreg;
>
> +       /* Find the virtio configuration */
> +       vdpa_aux->vd_mdev.pci_dev = padev->vf->pdev;
> +       err = pds_vdpa_probe_virtio(&vdpa_aux->vd_mdev);
> +       if (err) {
> +               dev_err(dev, "Unable to probe for virtio configuration: %pe\n",
> +                       ERR_PTR(err));
> +               goto err_free_mgmt_info;
> +       }
> +
>         return 0;
>
> +err_free_mgmt_info:
> +       pci_free_irq_vectors(padev->vf->pdev);
>  err_aux_unreg:
>         padev->ops->unregister_client(padev);
>  err_free_mem:
> @@ -65,6 +78,7 @@ static void pds_vdpa_remove(struct auxiliary_device *aux_dev)
>         struct pds_vdpa_aux *vdpa_aux = auxiliary_get_drvdata(aux_dev);
>         struct device *dev = &aux_dev->dev;
>
> +       pds_vdpa_remove_virtio(&vdpa_aux->vd_mdev);
>         pci_free_irq_vectors(vdpa_aux->padev->vf->pdev);
>
>         vdpa_aux->padev->ops->unregister_client(vdpa_aux->padev);
> diff --git a/drivers/vdpa/pds/aux_drv.h b/drivers/vdpa/pds/aux_drv.h
> index 94ba7abcaa43..87ac3c01c476 100644
> --- a/drivers/vdpa/pds/aux_drv.h
> +++ b/drivers/vdpa/pds/aux_drv.h
> @@ -16,6 +16,7 @@ struct pds_vdpa_aux {
>
>         int vf_id;
>         struct dentry *dentry;
> +       struct virtio_pci_modern_device vd_mdev;
>
>         int nintrs;
>  };
> diff --git a/drivers/vdpa/pds/debugfs.c b/drivers/vdpa/pds/debugfs.c
> index 7b7e90fd6578..aa5e9677fe74 100644
> --- a/drivers/vdpa/pds/debugfs.c
> +++ b/drivers/vdpa/pds/debugfs.c
> @@ -1,6 +1,7 @@
>  // SPDX-License-Identifier: GPL-2.0-only
>  /* Copyright(c) 2023 Advanced Micro Devices, Inc */
>
> +#include <linux/virtio_pci_modern.h>
>  #include <linux/vdpa.h>
>
>  #include <linux/pds/pds_core.h>
> diff --git a/drivers/vdpa/pds/vdpa_dev.c b/drivers/vdpa/pds/vdpa_dev.c
> index bd840688503c..15d623297203 100644
> --- a/drivers/vdpa/pds/vdpa_dev.c
> +++ b/drivers/vdpa/pds/vdpa_dev.c
> @@ -4,6 +4,7 @@
>  #include <linux/pci.h>
>  #include <linux/vdpa.h>
>  #include <uapi/linux/vdpa.h>
> +#include <linux/virtio_pci_modern.h>
>
>  #include <linux/pds/pds_core.h>
>  #include <linux/pds/pds_adminq.h>
> diff --git a/drivers/vdpa/pds/virtio_pci.c b/drivers/vdpa/pds/virtio_pci.c
> new file mode 100644
> index 000000000000..cb879619dac3
> --- /dev/null
> +++ b/drivers/vdpa/pds/virtio_pci.c
> @@ -0,0 +1,281 @@
> +// SPDX-License-Identifier: GPL-2.0-or-later
> +
> +/*
> + * adapted from drivers/virtio/virtio_pci_modern_dev.c, v6.0-rc1
> + */
> +
> +#include <linux/virtio_pci_modern.h>
> +#include <linux/pci.h>
> +
> +#include "virtio_pci.h"
> +
> +/*
> + * pds_vdpa_map_capability - map a part of virtio pci capability
> + * @mdev: the modern virtio-pci device
> + * @off: offset of the capability
> + * @minlen: minimal length of the capability
> + * @align: align requirement
> + * @start: start from the capability
> + * @size: map size
> + * @len: the length that is actually mapped
> + * @pa: physical address of the capability
> + *
> + * Returns the io address for the part of the capability
> + */
> +static void __iomem *
> +pds_vdpa_map_capability(struct virtio_pci_modern_device *mdev, int off,
> +                       size_t minlen, u32 align, u32 start, u32 size,
> +                       size_t *len, resource_size_t *pa)
> +{
> +       struct pci_dev *dev = mdev->pci_dev;
> +       u8 bar;
> +       u32 offset, length;
> +       void __iomem *p;
> +
> +       pci_read_config_byte(dev, off + offsetof(struct virtio_pci_cap,
> +                                                bar),
> +                            &bar);
> +       pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, offset),
> +                             &offset);
> +       pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, length),
> +                             &length);
> +
> +       /* Check if the BAR may have changed since we requested the region. */
> +       if (bar >= PCI_STD_NUM_BARS || !(mdev->modern_bars & (1 << bar))) {
> +               dev_err(&dev->dev,
> +                       "virtio_pci: bar unexpectedly changed to %u\n", bar);
> +               return NULL;
> +       }
> +
> +       if (length <= start) {
> +               dev_err(&dev->dev,
> +                       "virtio_pci: bad capability len %u (>%u expected)\n",
> +                       length, start);
> +               return NULL;
> +       }
> +
> +       if (length - start < minlen) {
> +               dev_err(&dev->dev,
> +                       "virtio_pci: bad capability len %u (>=%zu expected)\n",
> +                       length, minlen);
> +               return NULL;
> +       }
> +
> +       length -= start;
> +
> +       if (start + offset < offset) {
> +               dev_err(&dev->dev,
> +                       "virtio_pci: map wrap-around %u+%u\n",
> +                       start, offset);
> +               return NULL;
> +       }
> +
> +       offset += start;
> +
> +       if (offset & (align - 1)) {
> +               dev_err(&dev->dev,
> +                       "virtio_pci: offset %u not aligned to %u\n",
> +                       offset, align);
> +               return NULL;
> +       }
> +
> +       if (length > size)
> +               length = size;
> +
> +       if (len)
> +               *len = length;
> +
> +       if (minlen + offset < minlen ||
> +           minlen + offset > pci_resource_len(dev, bar)) {
> +               dev_err(&dev->dev,
> +                       "virtio_pci: map virtio %zu@%u out of range on bar %i length %lu\n",
> +                       minlen, offset,
> +                       bar, (unsigned long)pci_resource_len(dev, bar));
> +               return NULL;
> +       }
> +
> +       p = pci_iomap_range(dev, bar, offset, length);
> +       if (!p)
> +               dev_err(&dev->dev,
> +                       "virtio_pci: unable to map virtio %u@%u on bar %i\n",
> +                       length, offset, bar);
> +       else if (pa)
> +               *pa = pci_resource_start(dev, bar) + offset;
> +
> +       return p;
> +}
> +
> +/**
> + * virtio_pci_find_capability - walk capabilities to find device info.
> + * @dev: the pci device
> + * @cfg_type: the VIRTIO_PCI_CAP_* value we seek
> + * @ioresource_types: IORESOURCE_MEM and/or IORESOURCE_IO.
> + * @bars: the bitmask of BARs
> + *
> + * Returns offset of the capability, or 0.
> + */
> +static inline int virtio_pci_find_capability(struct pci_dev *dev, u8 cfg_type,
> +                                            u32 ioresource_types, int *bars)
> +{
> +       int pos;
> +
> +       for (pos = pci_find_capability(dev, PCI_CAP_ID_VNDR);
> +            pos > 0;
> +            pos = pci_find_next_capability(dev, pos, PCI_CAP_ID_VNDR)) {
> +               u8 type, bar;
> +
> +               pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap,
> +                                                        cfg_type),
> +                                    &type);
> +               pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap,
> +                                                        bar),
> +                                    &bar);
> +
> +               /* Ignore structures with reserved BAR values */
> +               if (bar >= PCI_STD_NUM_BARS)
> +                       continue;
> +
> +               if (type == cfg_type) {
> +                       if (pci_resource_len(dev, bar) &&
> +                           pci_resource_flags(dev, bar) & ioresource_types) {
> +                               *bars |= (1 << bar);
> +                               return pos;
> +                       }
> +               }
> +       }
> +       return 0;
> +}
> +
> +/*
> + * pds_vdpa_probe_virtio: probe the modern virtio pci device, note that the
> + * caller is required to enable PCI device before calling this function.
> + * @mdev: the modern virtio-pci device
> + *
> + * Returns 0 on success, or a negative error code on failure
> + */
> +int pds_vdpa_probe_virtio(struct virtio_pci_modern_device *mdev)
> +{
> +       struct pci_dev *pci_dev = mdev->pci_dev;
> +       int err, common, isr, notify, device;
> +       u32 notify_length;
> +       u32 notify_offset;
> +
> +       /* A common config is mandatory: without it, fail with -ENODEV. */
> +       common = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_COMMON_CFG,
> +                                           IORESOURCE_IO | IORESOURCE_MEM,
> +                                           &mdev->modern_bars);
> +       if (!common) {
> +               dev_info(&pci_dev->dev,
> +                        "virtio_pci: missing common config\n");
> +               return -ENODEV;
> +       }
> +
> +       /* If common is there, these should be too... */
> +       isr = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_ISR_CFG,
> +                                        IORESOURCE_IO | IORESOURCE_MEM,
> +                                        &mdev->modern_bars);
> +       notify = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_NOTIFY_CFG,
> +                                           IORESOURCE_IO | IORESOURCE_MEM,
> +                                           &mdev->modern_bars);
> +       if (!isr || !notify) {
> +               dev_err(&pci_dev->dev,
> +                       "virtio_pci: missing capabilities %i/%i/%i\n",
> +                       common, isr, notify);
> +               return -EINVAL;
> +       }
> +
> +       /* Device capability is only mandatory for devices that have
> +        * device-specific configuration.
> +        */
> +       device = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_DEVICE_CFG,
> +                                           IORESOURCE_IO | IORESOURCE_MEM,
> +                                           &mdev->modern_bars);
> +
> +       err = pci_request_selected_regions(pci_dev, mdev->modern_bars,
> +                                          "virtio-pci-modern");
> +       if (err)
> +               return err;
> +
> +       err = -EINVAL;
> +       mdev->common = pds_vdpa_map_capability(mdev, common,
> +                                              sizeof(struct virtio_pci_common_cfg),
> +                                              4, 0,
> +                                              sizeof(struct virtio_pci_common_cfg),
> +                                              NULL, NULL);
> +       if (!mdev->common)
> +               goto err_map_common;
> +       mdev->isr = pds_vdpa_map_capability(mdev, isr, sizeof(u8), 1,
> +                                           0, 1, NULL, NULL);
> +       if (!mdev->isr)
> +               goto err_map_isr;
> +
> +       /* Read notify_off_multiplier from config space. */
> +       pci_read_config_dword(pci_dev,
> +                             notify + offsetof(struct virtio_pci_notify_cap,
> +                                               notify_off_multiplier),
> +                             &mdev->notify_offset_multiplier);
> +       /* Read notify length and offset from config space. */
> +       pci_read_config_dword(pci_dev,
> +                             notify + offsetof(struct virtio_pci_notify_cap,
> +                                               cap.length),
> +                             &notify_length);
> +
> +       pci_read_config_dword(pci_dev,
> +                             notify + offsetof(struct virtio_pci_notify_cap,
> +                                               cap.offset),
> +                             &notify_offset);
> +
> +       /* We don't know ahead of time how many VQs we'll map.
> +        * If notify length is small, map it all now.
> +        * Otherwise, map each VQ individually later.
> +        */
> +       if ((u64)notify_length + (notify_offset % PAGE_SIZE) <= PAGE_SIZE) {
> +               mdev->notify_base = pds_vdpa_map_capability(mdev, notify,
> +                                                           2, 2,
> +                                                           0, notify_length,
> +                                                           &mdev->notify_len,
> +                                                           &mdev->notify_pa);
> +               if (!mdev->notify_base)
> +                       goto err_map_notify;
> +       } else {
> +               mdev->notify_map_cap = notify;
> +       }
> +
> +       /* Again, we don't know how much we should map, but PAGE_SIZE
> +        * is more than enough for all existing devices.
> +        */
> +       if (device) {
> +               mdev->device = pds_vdpa_map_capability(mdev, device, 0, 4,
> +                                                      0, PAGE_SIZE,
> +                                                      &mdev->device_len,
> +                                                      NULL);
> +               if (!mdev->device)
> +                       goto err_map_device;
> +       }
> +
> +       return 0;
> +
> +err_map_device:
> +       if (mdev->notify_base)
> +               pci_iounmap(pci_dev, mdev->notify_base);
> +err_map_notify:
> +       pci_iounmap(pci_dev, mdev->isr);
> +err_map_isr:
> +       pci_iounmap(pci_dev, mdev->common);
> +err_map_common:
> +       pci_release_selected_regions(pci_dev, mdev->modern_bars);
> +       return err;
> +}
> +
> +void pds_vdpa_remove_virtio(struct virtio_pci_modern_device *mdev)
> +{
> +       struct pci_dev *pci_dev = mdev->pci_dev;
> +
> +       if (mdev->device)
> +               pci_iounmap(pci_dev, mdev->device);
> +       if (mdev->notify_base)
> +               pci_iounmap(pci_dev, mdev->notify_base);
> +       pci_iounmap(pci_dev, mdev->isr);
> +       pci_iounmap(pci_dev, mdev->common);
> +       pci_release_selected_regions(pci_dev, mdev->modern_bars);
> +}
> diff --git a/drivers/vdpa/pds/virtio_pci.h b/drivers/vdpa/pds/virtio_pci.h
> new file mode 100644
> index 000000000000..f017cfa1173c
> --- /dev/null
> +++ b/drivers/vdpa/pds/virtio_pci.h
> @@ -0,0 +1,8 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/* Copyright(c) 2023 Advanced Micro Devices, Inc */
> +
> +#ifndef _PDS_VIRTIO_PCI_H_
> +#define _PDS_VIRTIO_PCI_H_
> +int pds_vdpa_probe_virtio(struct virtio_pci_modern_device *mdev);
> +void pds_vdpa_remove_virtio(struct virtio_pci_modern_device *mdev);
> +#endif /* _PDS_VIRTIO_PCI_H_ */
> --
> 2.17.1
>
Nelson, Shannon March 16, 2023, 3:25 a.m. UTC | #2
On 3/15/23 12:05 AM, Jason Wang wrote:
> On Thu, Mar 9, 2023 at 9:31 AM Shannon Nelson <shannon.nelson@amd.com> wrote:
>>
>> The PDS vDPA device has a virtio BAR for describing itself, and
>> the pds_vdpa driver needs to access it.  Here we copy liberally
>> from the existing drivers/virtio/virtio_pci_modern_dev.c as it
>> has what we need, but we need to modify it so that it can work
>> with our device id and so we can use our own DMA mask.
> 
> By passing a pointer to a customized id probing routine to vp_modern_probe()?

The only real differences are that we needed to cut out the device id 
checks to use our vDPA VF device id, and remove 
dma_set_mask_and_coherent() because we need a different DMA_BIT_MASK().

Maybe a function pointer to something that can validate the device id, 
and a bitmask for setting DMA mapping; if they are 0/NULL, use the 
default device id check and DMA mask.

Adding them as extra arguments to the function call seems a bit messy, 
maybe add them to the struct virtio_pci_modern_device and the caller can 
set them as overrides if needed?

struct virtio_pci_modern_device {

	...

	int (*device_id_check_override)(struct pci_dev *pdev);
	u64 dma_mask_override;
};

sln


> 
> Thanks
> 
> 
>>
>> We suspect there is room for discussion here about making the
>> existing code a little more flexible, but we thought we'd at
>> least start the discussion here.
>>
>> Signed-off-by: Shannon Nelson <shannon.nelson@amd.com>
>> ---
>>   drivers/vdpa/pds/Makefile     |   1 +
>>   drivers/vdpa/pds/aux_drv.c    |  14 ++
>>   drivers/vdpa/pds/aux_drv.h    |   1 +
>>   drivers/vdpa/pds/debugfs.c    |   1 +
>>   drivers/vdpa/pds/vdpa_dev.c   |   1 +
>>   drivers/vdpa/pds/virtio_pci.c | 281 ++++++++++++++++++++++++++++++++++
>>   drivers/vdpa/pds/virtio_pci.h |   8 +
>>   7 files changed, 307 insertions(+)
>>   create mode 100644 drivers/vdpa/pds/virtio_pci.c
>>   create mode 100644 drivers/vdpa/pds/virtio_pci.h
>>
>> diff --git a/drivers/vdpa/pds/Makefile b/drivers/vdpa/pds/Makefile
>> index 13b50394ec64..ca2efa8c6eb5 100644
>> --- a/drivers/vdpa/pds/Makefile
>> +++ b/drivers/vdpa/pds/Makefile
>> @@ -4,6 +4,7 @@
>>   obj-$(CONFIG_PDS_VDPA) := pds_vdpa.o
>>
>>   pds_vdpa-y := aux_drv.o \
>> +             virtio_pci.o \
>>                vdpa_dev.o
>>
>>   pds_vdpa-$(CONFIG_DEBUG_FS) += debugfs.o
>> diff --git a/drivers/vdpa/pds/aux_drv.c b/drivers/vdpa/pds/aux_drv.c
>> index 63e40ae68211..28158d0d98a5 100644
>> --- a/drivers/vdpa/pds/aux_drv.c
>> +++ b/drivers/vdpa/pds/aux_drv.c
>> @@ -4,6 +4,7 @@
>>   #include <linux/auxiliary_bus.h>
>>   #include <linux/pci.h>
>>   #include <linux/vdpa.h>
>> +#include <linux/virtio_pci_modern.h>
>>
>>   #include <linux/pds/pds_core.h>
>>   #include <linux/pds/pds_auxbus.h>
>> @@ -12,6 +13,7 @@
>>   #include "aux_drv.h"
>>   #include "debugfs.h"
>>   #include "vdpa_dev.h"
>> +#include "virtio_pci.h"
>>
>>   static const struct auxiliary_device_id pds_vdpa_id_table[] = {
>>          { .name = PDS_VDPA_DEV_NAME, },
>> @@ -49,8 +51,19 @@ static int pds_vdpa_probe(struct auxiliary_device *aux_dev,
>>          if (err)
>>                  goto err_aux_unreg;
>>
>> +       /* Find the virtio configuration */
>> +       vdpa_aux->vd_mdev.pci_dev = padev->vf->pdev;
>> +       err = pds_vdpa_probe_virtio(&vdpa_aux->vd_mdev);
>> +       if (err) {
>> +               dev_err(dev, "Unable to probe for virtio configuration: %pe\n",
>> +                       ERR_PTR(err));
>> +               goto err_free_mgmt_info;
>> +       }
>> +
>>          return 0;
>>
>> +err_free_mgmt_info:
>> +       pci_free_irq_vectors(padev->vf->pdev);
>>   err_aux_unreg:
>>          padev->ops->unregister_client(padev);
>>   err_free_mem:
>> @@ -65,6 +78,7 @@ static void pds_vdpa_remove(struct auxiliary_device *aux_dev)
>>          struct pds_vdpa_aux *vdpa_aux = auxiliary_get_drvdata(aux_dev);
>>          struct device *dev = &aux_dev->dev;
>>
>> +       pds_vdpa_remove_virtio(&vdpa_aux->vd_mdev);
>>          pci_free_irq_vectors(vdpa_aux->padev->vf->pdev);
>>
>>          vdpa_aux->padev->ops->unregister_client(vdpa_aux->padev);
>> diff --git a/drivers/vdpa/pds/aux_drv.h b/drivers/vdpa/pds/aux_drv.h
>> index 94ba7abcaa43..87ac3c01c476 100644
>> --- a/drivers/vdpa/pds/aux_drv.h
>> +++ b/drivers/vdpa/pds/aux_drv.h
>> @@ -16,6 +16,7 @@ struct pds_vdpa_aux {
>>
>>          int vf_id;
>>          struct dentry *dentry;
>> +       struct virtio_pci_modern_device vd_mdev;
>>
>>          int nintrs;
>>   };
>> diff --git a/drivers/vdpa/pds/debugfs.c b/drivers/vdpa/pds/debugfs.c
>> index 7b7e90fd6578..aa5e9677fe74 100644
>> --- a/drivers/vdpa/pds/debugfs.c
>> +++ b/drivers/vdpa/pds/debugfs.c
>> @@ -1,6 +1,7 @@
>>   // SPDX-License-Identifier: GPL-2.0-only
>>   /* Copyright(c) 2023 Advanced Micro Devices, Inc */
>>
>> +#include <linux/virtio_pci_modern.h>
>>   #include <linux/vdpa.h>
>>
>>   #include <linux/pds/pds_core.h>
>> diff --git a/drivers/vdpa/pds/vdpa_dev.c b/drivers/vdpa/pds/vdpa_dev.c
>> index bd840688503c..15d623297203 100644
>> --- a/drivers/vdpa/pds/vdpa_dev.c
>> +++ b/drivers/vdpa/pds/vdpa_dev.c
>> @@ -4,6 +4,7 @@
>>   #include <linux/pci.h>
>>   #include <linux/vdpa.h>
>>   #include <uapi/linux/vdpa.h>
>> +#include <linux/virtio_pci_modern.h>
>>
>>   #include <linux/pds/pds_core.h>
>>   #include <linux/pds/pds_adminq.h>
>> diff --git a/drivers/vdpa/pds/virtio_pci.c b/drivers/vdpa/pds/virtio_pci.c
>> new file mode 100644
>> index 000000000000..cb879619dac3
>> --- /dev/null
>> +++ b/drivers/vdpa/pds/virtio_pci.c
>> @@ -0,0 +1,281 @@
>> +// SPDX-License-Identifier: GPL-2.0-or-later
>> +
>> +/*
>> + * adapted from drivers/virtio/virtio_pci_modern_dev.c, v6.0-rc1
>> + */
>> +
>> +#include <linux/virtio_pci_modern.h>
>> +#include <linux/pci.h>
>> +
>> +#include "virtio_pci.h"
>> +
>> +/*
>> + * pds_vdpa_map_capability - map a part of virtio pci capability
>> + * @mdev: the modern virtio-pci device
>> + * @off: offset of the capability
>> + * @minlen: minimal length of the capability
>> + * @align: align requirement
>> + * @start: start from the capability
>> + * @size: map size
>> + * @len: the length that is actually mapped
>> + * @pa: physical address of the capability
>> + *
>> + * Returns the io address for the part of the capability
>> + */
>> +static void __iomem *
>> +pds_vdpa_map_capability(struct virtio_pci_modern_device *mdev, int off,
>> +                       size_t minlen, u32 align, u32 start, u32 size,
>> +                       size_t *len, resource_size_t *pa)
>> +{
>> +       struct pci_dev *dev = mdev->pci_dev;
>> +       u8 bar;
>> +       u32 offset, length;
>> +       void __iomem *p;
>> +
>> +       pci_read_config_byte(dev, off + offsetof(struct virtio_pci_cap,
>> +                                                bar),
>> +                            &bar);
>> +       pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, offset),
>> +                             &offset);
>> +       pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, length),
>> +                             &length);
>> +
>> +       /* Check if the BAR may have changed since we requested the region. */
>> +       if (bar >= PCI_STD_NUM_BARS || !(mdev->modern_bars & (1 << bar))) {
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: bar unexpectedly changed to %u\n", bar);
>> +               return NULL;
>> +       }
>> +
>> +       if (length <= start) {
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: bad capability len %u (>%u expected)\n",
>> +                       length, start);
>> +               return NULL;
>> +       }
>> +
>> +       if (length - start < minlen) {
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: bad capability len %u (>=%zu expected)\n",
>> +                       length, minlen);
>> +               return NULL;
>> +       }
>> +
>> +       length -= start;
>> +
>> +       if (start + offset < offset) {
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: map wrap-around %u+%u\n",
>> +                       start, offset);
>> +               return NULL;
>> +       }
>> +
>> +       offset += start;
>> +
>> +       if (offset & (align - 1)) {
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: offset %u not aligned to %u\n",
>> +                       offset, align);
>> +               return NULL;
>> +       }
>> +
>> +       if (length > size)
>> +               length = size;
>> +
>> +       if (len)
>> +               *len = length;
>> +
>> +       if (minlen + offset < minlen ||
>> +           minlen + offset > pci_resource_len(dev, bar)) {
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: map virtio %zu@%u out of range on bar %i length %lu\n",
>> +                       minlen, offset,
>> +                       bar, (unsigned long)pci_resource_len(dev, bar));
>> +               return NULL;
>> +       }
>> +
>> +       p = pci_iomap_range(dev, bar, offset, length);
>> +       if (!p)
>> +               dev_err(&dev->dev,
>> +                       "virtio_pci: unable to map virtio %u@%u on bar %i\n",
>> +                       length, offset, bar);
>> +       else if (pa)
>> +               *pa = pci_resource_start(dev, bar) + offset;
>> +
>> +       return p;
>> +}
>> +
>> +/**
>> + * virtio_pci_find_capability - walk capabilities to find device info.
>> + * @dev: the pci device
>> + * @cfg_type: the VIRTIO_PCI_CAP_* value we seek
>> + * @ioresource_types: IORESOURCE_MEM and/or IORESOURCE_IO.
>> + * @bars: the bitmask of BARs
>> + *
>> + * Returns offset of the capability, or 0.
>> + */
>> +static inline int virtio_pci_find_capability(struct pci_dev *dev, u8 cfg_type,
>> +                                            u32 ioresource_types, int *bars)
>> +{
>> +       int pos;
>> +
>> +       for (pos = pci_find_capability(dev, PCI_CAP_ID_VNDR);
>> +            pos > 0;
>> +            pos = pci_find_next_capability(dev, pos, PCI_CAP_ID_VNDR)) {
>> +               u8 type, bar;
>> +
>> +               pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap,
>> +                                                        cfg_type),
>> +                                    &type);
>> +               pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap,
>> +                                                        bar),
>> +                                    &bar);
>> +
>> +               /* Ignore structures with reserved BAR values */
>> +               if (bar >= PCI_STD_NUM_BARS)
>> +                       continue;
>> +
>> +               if (type == cfg_type) {
>> +                       if (pci_resource_len(dev, bar) &&
>> +                           pci_resource_flags(dev, bar) & ioresource_types) {
>> +                               *bars |= (1 << bar);
>> +                               return pos;
>> +                       }
>> +               }
>> +       }
>> +       return 0;
>> +}
>> +
>> +/*
>> + * pds_vdpa_probe_virtio: probe the modern virtio pci device, note that the
>> + * caller is required to enable PCI device before calling this function.
>> + * @mdev: the modern virtio-pci device
>> + *
>> + * Returns 0 on success, or a negative error code on failure
>> + */
>> +int pds_vdpa_probe_virtio(struct virtio_pci_modern_device *mdev)
>> +{
>> +       struct pci_dev *pci_dev = mdev->pci_dev;
>> +       int err, common, isr, notify, device;
>> +       u32 notify_length;
>> +       u32 notify_offset;
>> +
>> +       /* A common config is mandatory: without it, fail with -ENODEV. */
>> +       common = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_COMMON_CFG,
>> +                                           IORESOURCE_IO | IORESOURCE_MEM,
>> +                                           &mdev->modern_bars);
>> +       if (!common) {
>> +               dev_info(&pci_dev->dev,
>> +                        "virtio_pci: missing common config\n");
>> +               return -ENODEV;
>> +       }
>> +
>> +       /* If common is there, these should be too... */
>> +       isr = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_ISR_CFG,
>> +                                        IORESOURCE_IO | IORESOURCE_MEM,
>> +                                        &mdev->modern_bars);
>> +       notify = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_NOTIFY_CFG,
>> +                                           IORESOURCE_IO | IORESOURCE_MEM,
>> +                                           &mdev->modern_bars);
>> +       if (!isr || !notify) {
>> +               dev_err(&pci_dev->dev,
>> +                       "virtio_pci: missing capabilities %i/%i/%i\n",
>> +                       common, isr, notify);
>> +               return -EINVAL;
>> +       }
>> +
>> +       /* Device capability is only mandatory for devices that have
>> +        * device-specific configuration.
>> +        */
>> +       device = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_DEVICE_CFG,
>> +                                           IORESOURCE_IO | IORESOURCE_MEM,
>> +                                           &mdev->modern_bars);
>> +
>> +       err = pci_request_selected_regions(pci_dev, mdev->modern_bars,
>> +                                          "virtio-pci-modern");
>> +       if (err)
>> +               return err;
>> +
>> +       err = -EINVAL;
>> +       mdev->common = pds_vdpa_map_capability(mdev, common,
>> +                                              sizeof(struct virtio_pci_common_cfg),
>> +                                              4, 0,
>> +                                              sizeof(struct virtio_pci_common_cfg),
>> +                                              NULL, NULL);
>> +       if (!mdev->common)
>> +               goto err_map_common;
>> +       mdev->isr = pds_vdpa_map_capability(mdev, isr, sizeof(u8), 1,
>> +                                           0, 1, NULL, NULL);
>> +       if (!mdev->isr)
>> +               goto err_map_isr;
>> +
>> +       /* Read notify_off_multiplier from config space. */
>> +       pci_read_config_dword(pci_dev,
>> +                             notify + offsetof(struct virtio_pci_notify_cap,
>> +                                               notify_off_multiplier),
>> +                             &mdev->notify_offset_multiplier);
>> +       /* Read notify length and offset from config space. */
>> +       pci_read_config_dword(pci_dev,
>> +                             notify + offsetof(struct virtio_pci_notify_cap,
>> +                                               cap.length),
>> +                             &notify_length);
>> +
>> +       pci_read_config_dword(pci_dev,
>> +                             notify + offsetof(struct virtio_pci_notify_cap,
>> +                                               cap.offset),
>> +                             &notify_offset);
>> +
>> +       /* We don't know how many VQs we'll map, ahead of the time.
>> +        * If notify length is small, map it all now.
>> +        * Otherwise, map each VQ individually later.
>> +        */
>> +       if ((u64)notify_length + (notify_offset % PAGE_SIZE) <= PAGE_SIZE) {
>> +               mdev->notify_base = pds_vdpa_map_capability(mdev, notify,
>> +                                                           2, 2,
>> +                                                           0, notify_length,
>> +                                                           &mdev->notify_len,
>> +                                                           &mdev->notify_pa);
>> +               if (!mdev->notify_base)
>> +                       goto err_map_notify;
>> +       } else {
>> +               mdev->notify_map_cap = notify;
>> +       }
>> +
>> +       /* Again, we don't know how much we should map, but PAGE_SIZE
>> +        * is more than enough for all existing devices.
>> +        */
>> +       if (device) {
>> +               mdev->device = pds_vdpa_map_capability(mdev, device, 0, 4,
>> +                                                      0, PAGE_SIZE,
>> +                                                      &mdev->device_len,
>> +                                                      NULL);
>> +               if (!mdev->device)
>> +                       goto err_map_device;
>> +       }
>> +
>> +       return 0;
>> +
>> +err_map_device:
>> +       if (mdev->notify_base)
>> +               pci_iounmap(pci_dev, mdev->notify_base);
>> +err_map_notify:
>> +       pci_iounmap(pci_dev, mdev->isr);
>> +err_map_isr:
>> +       pci_iounmap(pci_dev, mdev->common);
>> +err_map_common:
>> +       pci_release_selected_regions(pci_dev, mdev->modern_bars);
>> +       return err;
>> +}
>> +
>> +void pds_vdpa_remove_virtio(struct virtio_pci_modern_device *mdev)
>> +{
>> +       struct pci_dev *pci_dev = mdev->pci_dev;
>> +
>> +       if (mdev->device)
>> +               pci_iounmap(pci_dev, mdev->device);
>> +       if (mdev->notify_base)
>> +               pci_iounmap(pci_dev, mdev->notify_base);
>> +       pci_iounmap(pci_dev, mdev->isr);
>> +       pci_iounmap(pci_dev, mdev->common);
>> +       pci_release_selected_regions(pci_dev, mdev->modern_bars);
>> +}
>> diff --git a/drivers/vdpa/pds/virtio_pci.h b/drivers/vdpa/pds/virtio_pci.h
>> new file mode 100644
>> index 000000000000..f017cfa1173c
>> --- /dev/null
>> +++ b/drivers/vdpa/pds/virtio_pci.h
>> @@ -0,0 +1,8 @@
>> +/* SPDX-License-Identifier: GPL-2.0-only */
>> +/* Copyright(c) 2023 Advanced Micro Devices, Inc */
>> +
>> +#ifndef _PDS_VIRTIO_PCI_H_
>> +#define _PDS_VIRTIO_PCI_H_
>> +int pds_vdpa_probe_virtio(struct virtio_pci_modern_device *mdev);
>> +void pds_vdpa_remove_virtio(struct virtio_pci_modern_device *mdev);
>> +#endif /* _PDS_VIRTIO_PCI_H_ */
>> --
>> 2.17.1
>>
>
Jason Wang March 17, 2023, 3:37 a.m. UTC | #3
On Thu, Mar 16, 2023 at 11:25 AM Shannon Nelson <shannon.nelson@amd.com> wrote:
>
> On 3/15/23 12:05 AM, Jason Wang wrote:
> > On Thu, Mar 9, 2023 at 9:31 AM Shannon Nelson <shannon.nelson@amd.com> wrote:
> >>
> >> The PDS vDPA device has a virtio BAR for describing itself, and
> >> the pds_vdpa driver needs to access it.  Here we copy liberally
> >> from the existing drivers/virtio/virtio_pci_modern_dev.c as it
> >> has what we need, but we need to modify it so that it can work
> >> with our device id and so we can use our own DMA mask.
> >
> > By passing a pointer to a customized id probing routine to vp_modern_probe()?
>
> The only real differences are that we needed to cut out the device id
> checks to use our vDPA VF device id, and remove
> dma_set_mask_and_coherent() because we need a different DMA_BIT_MASK().
>
> Maybe a function pointer to something that can validate the device id,
> and a bitmask for setting DMA mapping; if they are 0/NULL, use the
> default device id check and DMA mask.
>
> Adding them as extra arguments to the function call seems a bit messy,
> maybe add them to the struct virtio_pci_modern_device and the caller can
> set them as overrides if needed?
>
> struct virtio_pci_modern_device {
>
>         ...
>
>         int (*device_id_check_override)(struct pci_dev *pdev);
>         u64 dma_mask_override;
> };

Looks fine.

Thanks
diff mbox series

Patch

diff --git a/drivers/vdpa/pds/Makefile b/drivers/vdpa/pds/Makefile
index 13b50394ec64..ca2efa8c6eb5 100644
--- a/drivers/vdpa/pds/Makefile
+++ b/drivers/vdpa/pds/Makefile
@@ -4,6 +4,7 @@ 
 obj-$(CONFIG_PDS_VDPA) := pds_vdpa.o
 
 pds_vdpa-y := aux_drv.o \
+	      virtio_pci.o \
 	      vdpa_dev.o
 
 pds_vdpa-$(CONFIG_DEBUG_FS) += debugfs.o
diff --git a/drivers/vdpa/pds/aux_drv.c b/drivers/vdpa/pds/aux_drv.c
index 63e40ae68211..28158d0d98a5 100644
--- a/drivers/vdpa/pds/aux_drv.c
+++ b/drivers/vdpa/pds/aux_drv.c
@@ -4,6 +4,7 @@ 
 #include <linux/auxiliary_bus.h>
 #include <linux/pci.h>
 #include <linux/vdpa.h>
+#include <linux/virtio_pci_modern.h>
 
 #include <linux/pds/pds_core.h>
 #include <linux/pds/pds_auxbus.h>
@@ -12,6 +13,7 @@ 
 #include "aux_drv.h"
 #include "debugfs.h"
 #include "vdpa_dev.h"
+#include "virtio_pci.h"
 
 static const struct auxiliary_device_id pds_vdpa_id_table[] = {
 	{ .name = PDS_VDPA_DEV_NAME, },
@@ -49,8 +51,19 @@  static int pds_vdpa_probe(struct auxiliary_device *aux_dev,
 	if (err)
 		goto err_aux_unreg;
 
+	/* Find the virtio configuration */
+	vdpa_aux->vd_mdev.pci_dev = padev->vf->pdev;
+	err = pds_vdpa_probe_virtio(&vdpa_aux->vd_mdev);
+	if (err) {
+		dev_err(dev, "Unable to probe for virtio configuration: %pe\n",
+			ERR_PTR(err));
+		goto err_free_mgmt_info;
+	}
+
 	return 0;
 
+err_free_mgmt_info:
+	pci_free_irq_vectors(padev->vf->pdev);
 err_aux_unreg:
 	padev->ops->unregister_client(padev);
 err_free_mem:
@@ -65,6 +78,7 @@  static void pds_vdpa_remove(struct auxiliary_device *aux_dev)
 	struct pds_vdpa_aux *vdpa_aux = auxiliary_get_drvdata(aux_dev);
 	struct device *dev = &aux_dev->dev;
 
+	pds_vdpa_remove_virtio(&vdpa_aux->vd_mdev);
 	pci_free_irq_vectors(vdpa_aux->padev->vf->pdev);
 
 	vdpa_aux->padev->ops->unregister_client(vdpa_aux->padev);
diff --git a/drivers/vdpa/pds/aux_drv.h b/drivers/vdpa/pds/aux_drv.h
index 94ba7abcaa43..87ac3c01c476 100644
--- a/drivers/vdpa/pds/aux_drv.h
+++ b/drivers/vdpa/pds/aux_drv.h
@@ -16,6 +16,7 @@  struct pds_vdpa_aux {
 
 	int vf_id;
 	struct dentry *dentry;
+	struct virtio_pci_modern_device vd_mdev;
 
 	int nintrs;
 };
diff --git a/drivers/vdpa/pds/debugfs.c b/drivers/vdpa/pds/debugfs.c
index 7b7e90fd6578..aa5e9677fe74 100644
--- a/drivers/vdpa/pds/debugfs.c
+++ b/drivers/vdpa/pds/debugfs.c
@@ -1,6 +1,7 @@ 
 // SPDX-License-Identifier: GPL-2.0-only
 /* Copyright(c) 2023 Advanced Micro Devices, Inc */
 
+#include <linux/virtio_pci_modern.h>
 #include <linux/vdpa.h>
 
 #include <linux/pds/pds_core.h>
diff --git a/drivers/vdpa/pds/vdpa_dev.c b/drivers/vdpa/pds/vdpa_dev.c
index bd840688503c..15d623297203 100644
--- a/drivers/vdpa/pds/vdpa_dev.c
+++ b/drivers/vdpa/pds/vdpa_dev.c
@@ -4,6 +4,7 @@ 
 #include <linux/pci.h>
 #include <linux/vdpa.h>
 #include <uapi/linux/vdpa.h>
+#include <linux/virtio_pci_modern.h>
 
 #include <linux/pds/pds_core.h>
 #include <linux/pds/pds_adminq.h>
diff --git a/drivers/vdpa/pds/virtio_pci.c b/drivers/vdpa/pds/virtio_pci.c
new file mode 100644
index 000000000000..cb879619dac3
--- /dev/null
+++ b/drivers/vdpa/pds/virtio_pci.c
@@ -0,0 +1,281 @@ 
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+/*
+ * adapted from drivers/virtio/virtio_pci_modern_dev.c, v6.0-rc1
+ */
+
+#include <linux/virtio_pci_modern.h>
+#include <linux/pci.h>
+
+#include "virtio_pci.h"
+
+/*
+ * pds_vdpa_map_capability - map a part of a virtio pci capability
+ * @mdev: the modern virtio-pci device
+ * @off: offset of the capability in PCI config space
+ * @minlen: minimal length required past @start, in bytes
+ * @align: alignment required of the mapped offset (tested as a bit mask)
+ * @start: where to start mapping, relative to the capability's region
+ * @size: upper bound on the number of bytes to map
+ * @len: if not NULL, set to the length that is actually mapped
+ * @pa: if not NULL, set to the physical address of the mapping
+ *
+ * Returns the io address of the mapped part of the capability, or NULL
+ */
+static void __iomem *
+pds_vdpa_map_capability(struct virtio_pci_modern_device *mdev, int off,
+			size_t minlen, u32 align, u32 start, u32 size,
+			size_t *len, resource_size_t *pa)
+{
+	struct pci_dev *dev = mdev->pci_dev;
+	u8 bar;
+	u32 offset, length;
+	void __iomem *p;
+
+	pci_read_config_byte(dev, off + offsetof(struct virtio_pci_cap,
+						 bar),
+			     &bar);
+	pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, offset),
+			      &offset);
+	pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, length),
+			      &length);
+
+	/* Check if the BAR may have changed since we requested the region. */
+	if (bar >= PCI_STD_NUM_BARS || !(mdev->modern_bars & (1 << bar))) {
+		dev_err(&dev->dev,
+			"virtio_pci: bar unexpectedly changed to %u\n", bar);
+		return NULL;
+	}
+
+	if (length <= start) {
+		dev_err(&dev->dev,
+			"virtio_pci: bad capability len %u (>%u expected)\n",
+			length, start);
+		return NULL;
+	}
+
+	if (length - start < minlen) {
+		dev_err(&dev->dev,
+			"virtio_pci: bad capability len %u (>=%zu expected)\n",
+			length, minlen);
+		return NULL;
+	}
+
+	length -= start;
+
+	if (start + offset < offset) {
+		dev_err(&dev->dev,
+			"virtio_pci: map wrap-around %u+%u\n",
+			start, offset);
+		return NULL;
+	}
+
+	offset += start;
+
+	if (offset & (align - 1)) {
+		dev_err(&dev->dev,
+			"virtio_pci: offset %u not aligned to %u\n",
+			offset, align);
+		return NULL;
+	}
+
+	if (length > size)	/* never map more than the caller asked for */
+		length = size;
+
+	if (len)
+		*len = length;
+
+	if (minlen + offset < minlen ||
+	    minlen + offset > pci_resource_len(dev, bar)) {
+		dev_err(&dev->dev,
+			"virtio_pci: map virtio %zu@%u out of range on bar %i length %lu\n",
+			minlen, offset,
+			bar, (unsigned long)pci_resource_len(dev, bar));
+		return NULL;
+	}
+
+	p = pci_iomap_range(dev, bar, offset, length);
+	if (!p)
+		dev_err(&dev->dev,
+			"virtio_pci: unable to map virtio %u@%u on bar %i\n",
+			length, offset, bar);
+	else if (pa)
+		*pa = pci_resource_start(dev, bar) + offset;
+
+	return p;
+}
+
+/**
+ * virtio_pci_find_capability - walk vendor capabilities to find device info.
+ * @dev: the pci device
+ * @cfg_type: the VIRTIO_PCI_CAP_* value we seek
+ * @ioresource_types: accepted BAR resource flags, IORESOURCE_MEM and/or _IO
+ * @bars: bitmask of BARs; the bit of the matching capability's BAR is set
+ *
+ * Returns the config space offset of the capability, or 0 if none matched.
+ */
+static inline int virtio_pci_find_capability(struct pci_dev *dev, u8 cfg_type,
+					     u32 ioresource_types, int *bars)
+{
+	int pos;
+
+	for (pos = pci_find_capability(dev, PCI_CAP_ID_VNDR);
+	     pos > 0;
+	     pos = pci_find_next_capability(dev, pos, PCI_CAP_ID_VNDR)) {
+		u8 type, bar;
+
+		pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap,
+							 cfg_type),
+				     &type);
+		pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap,
+							 bar),
+				     &bar);
+
+		/* Ignore structures with reserved BAR values */
+		if (bar >= PCI_STD_NUM_BARS)
+			continue;
+
+		if (type == cfg_type) {
+			if (pci_resource_len(dev, bar) &&
+			    pci_resource_flags(dev, bar) & ioresource_types) {
+				*bars |= (1 << bar);
+				return pos;
+			}
+		}
+	}
+	return 0;
+}
+
+/*
+ * pds_vdpa_probe_virtio: probe the modern virtio pci device; note that the
+ * caller is required to enable the PCI device before calling this function.
+ * @mdev: the modern virtio-pci device, with @mdev->pci_dev already set
+ *
+ * Returns 0 on success, otherwise an error code
+ */
+int pds_vdpa_probe_virtio(struct virtio_pci_modern_device *mdev)
+{
+	struct pci_dev *pci_dev = mdev->pci_dev;
+	int err, common, isr, notify, device;
+	u32 notify_length;
+	u32 notify_offset;
+
+	/* check for a common config: without one we give up with -ENODEV. */
+	common = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_COMMON_CFG,
+					    IORESOURCE_IO | IORESOURCE_MEM,
+					    &mdev->modern_bars);
+	if (!common) {
+		dev_info(&pci_dev->dev,
+			 "virtio_pci: missing common config\n");
+		return -ENODEV;
+	}
+
+	/* If common is there, these should be too... */
+	isr = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_ISR_CFG,
+					 IORESOURCE_IO | IORESOURCE_MEM,
+					 &mdev->modern_bars);
+	notify = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_NOTIFY_CFG,
+					    IORESOURCE_IO | IORESOURCE_MEM,
+					    &mdev->modern_bars);
+	if (!isr || !notify) {
+		dev_err(&pci_dev->dev,
+			"virtio_pci: missing capabilities %i/%i/%i\n",
+			common, isr, notify);
+		return -EINVAL;
+	}
+
+	/* Device capability is only mandatory for devices that have
+	 * device-specific configuration.
+	 */
+	device = virtio_pci_find_capability(pci_dev, VIRTIO_PCI_CAP_DEVICE_CFG,
+					    IORESOURCE_IO | IORESOURCE_MEM,
+					    &mdev->modern_bars);
+
+	err = pci_request_selected_regions(pci_dev, mdev->modern_bars,
+					   "virtio-pci-modern");
+	if (err)
+		return err;
+
+	err = -EINVAL;
+	mdev->common = pds_vdpa_map_capability(mdev, common,
+					       sizeof(struct virtio_pci_common_cfg),
+					       4, 0,
+					       sizeof(struct virtio_pci_common_cfg),
+					       NULL, NULL);
+	if (!mdev->common)
+		goto err_map_common;
+	mdev->isr = pds_vdpa_map_capability(mdev, isr, sizeof(u8), 1,
+					    0, 1, NULL, NULL);
+	if (!mdev->isr)
+		goto err_map_isr;
+
+	/* Read notify_off_multiplier from config space. */
+	pci_read_config_dword(pci_dev,
+			      notify + offsetof(struct virtio_pci_notify_cap,
+						notify_off_multiplier),
+			      &mdev->notify_offset_multiplier);
+	/* Read notify length and offset from config space. */
+	pci_read_config_dword(pci_dev,
+			      notify + offsetof(struct virtio_pci_notify_cap,
+						cap.length),
+			      &notify_length);
+
+	pci_read_config_dword(pci_dev,
+			      notify + offsetof(struct virtio_pci_notify_cap,
+						cap.offset),
+			      &notify_offset);
+
+	/* We don't know ahead of time how many VQs we'll map.
+	 * If notify length is small, map it all now.
+	 * Otherwise, map each VQ individually later.
+	 */
+	if ((u64)notify_length + (notify_offset % PAGE_SIZE) <= PAGE_SIZE) {
+		mdev->notify_base = pds_vdpa_map_capability(mdev, notify,
+							    2, 2,
+							    0, notify_length,
+							    &mdev->notify_len,
+							    &mdev->notify_pa);
+		if (!mdev->notify_base)
+			goto err_map_notify;
+	} else {
+		mdev->notify_map_cap = notify;
+	}
+
+	/* Again, we don't know how much we should map, but PAGE_SIZE
+	 * is more than enough for all existing devices.
+	 */
+	if (device) {
+		mdev->device = pds_vdpa_map_capability(mdev, device, 0, 4,
+						       0, PAGE_SIZE,
+						       &mdev->device_len,
+						       NULL);
+		if (!mdev->device)
+			goto err_map_device;
+	}
+
+	return 0;
+
+err_map_device:
+	if (mdev->notify_base)	/* only mapped above when the area was small */
+		pci_iounmap(pci_dev, mdev->notify_base);
+err_map_notify:
+	pci_iounmap(pci_dev, mdev->isr);
+err_map_isr:
+	pci_iounmap(pci_dev, mdev->common);
+err_map_common:
+	pci_release_selected_regions(pci_dev, mdev->modern_bars);
+	return err;
+}
+
+void pds_vdpa_remove_virtio(struct virtio_pci_modern_device *mdev)
+{
+	struct pci_dev *pci_dev = mdev->pci_dev;
+
+	if (mdev->device)	/* device config is optional at probe */
+		pci_iounmap(pci_dev, mdev->device);
+	if (mdev->notify_base)	/* NULL if not mapped up front at probe */
+		pci_iounmap(pci_dev, mdev->notify_base);
+	pci_iounmap(pci_dev, mdev->isr); /* set by any successful probe */
+	pci_iounmap(pci_dev, mdev->common); /* likewise */
+	pci_release_selected_regions(pci_dev, mdev->modern_bars);
+}
diff --git a/drivers/vdpa/pds/virtio_pci.h b/drivers/vdpa/pds/virtio_pci.h
new file mode 100644
index 000000000000..f017cfa1173c
--- /dev/null
+++ b/drivers/vdpa/pds/virtio_pci.h
@@ -0,0 +1,8 @@ 
+/* SPDX-License-Identifier: GPL-2.0-only */
+/* Copyright(c) 2023 Advanced Micro Devices, Inc */
+
+#ifndef _PDS_VIRTIO_PCI_H_
+#define _PDS_VIRTIO_PCI_H_
+int pds_vdpa_probe_virtio(struct virtio_pci_modern_device *mdev);
+void pds_vdpa_remove_virtio(struct virtio_pci_modern_device *mdev);
+#endif /* _PDS_VIRTIO_PCI_H_ */