@@ -6,11 +6,14 @@
#ifdef CONFIG_PCI_ATS
/* Address Translation Service */
+bool pci_ats_supported(struct pci_dev *dev);
int pci_enable_ats(struct pci_dev *dev, int ps);
void pci_disable_ats(struct pci_dev *dev);
int pci_ats_queue_depth(struct pci_dev *dev);
int pci_ats_page_aligned(struct pci_dev *dev);
#else /* CONFIG_PCI_ATS */
+static inline bool pci_ats_supported(struct pci_dev *d)
+{ return false; }
static inline int pci_enable_ats(struct pci_dev *d, int ps)
{ return -ENODEV; }
static inline void pci_disable_ats(struct pci_dev *d) { }
@@ -30,6 +30,22 @@ void pci_ats_init(struct pci_dev *dev)
dev->ats_cap = pos;
}
+/**
+ * pci_ats_supported - check if the device can use ATS
+ * @dev: the PCI device
+ *
+ * Returns true if the device supports ATS and is allowed to use it, false
+ * otherwise.
+ */
+bool pci_ats_supported(struct pci_dev *dev)
+{
+ if (!dev->ats_cap)
+ return false;
+
+ return (dev->untrusted == 0);
+}
+EXPORT_SYMBOL_GPL(pci_ats_supported);
+
/**
* pci_enable_ats - enable the ATS capability
* @dev: the PCI device
@@ -42,7 +58,7 @@ int pci_enable_ats(struct pci_dev *dev, int ps)
u16 ctrl;
struct pci_dev *pdev;
- if (!dev->ats_cap)
+ if (!pci_ats_supported(dev))
return -EINVAL;
if (WARN_ON(dev->ats_enabled))