diff mbox series

[3/4] rust: pci: fix unrestricted &mut pci::Device

Message ID 20250313021550.133041-4-dakr@kernel.org (mailing list archive)
State Handled Elsewhere
Delegated to: Bjorn Helgaas
Headers show
Series Improve soundness of bus device abstractions | expand

Commit Message

Danilo Krummrich March 13, 2025, 2:13 a.m. UTC
As of now, pci::Device is implemented as:

	#[derive(Clone)]
	pub struct Device(ARef<device::Device>);

This may be convenient, but has the implication that drivers can call
device methods that require a mutable reference concurrently at any
point of time.

Instead define pci::Device as

	pub struct Device<Ctx: DeviceContext = Normal>(
		Opaque<bindings::pci_dev>,
		PhantomData<Ctx>,
	);

and manually implement the AlwaysRefCounted trait.

With this we can implement methods that should only be called from
bus callbacks (such as probe()) for pci::Device<Core>. Consequently, we
make this type accessible in bus callbacks only.

Arbitrary references taken by the driver are still of type
ARef<pci::Device> and hence don't provide access to methods that are
reserved for bus callbacks.

Fixes: 1bd8b6b2c5d3 ("rust: pci: add basic PCI device / driver abstractions")
Signed-off-by: Danilo Krummrich <dakr@kernel.org>
---
 drivers/gpu/nova-core/driver.rs |   4 +-
 rust/kernel/pci.rs              | 126 ++++++++++++++++++++------------
 samples/rust/rust_driver_pci.rs |   8 +-
 3 files changed, 85 insertions(+), 53 deletions(-)

Comments

Benno Lossin March 13, 2025, 10:44 a.m. UTC | #1
On Thu Mar 13, 2025 at 3:13 AM CET, Danilo Krummrich wrote:
> As by now, pci::Device is implemented as:
>
> 	#[derive(Clone)]
> 	pub struct Device(ARef<device::Device>);
>
> This may be convenient, but has the implication that drivers can call
> device methods that require a mutable reference concurrently at any
> point of time.

Which methods take mutable references? The `set_master` method you
mentioned also took a shared reference before this patch.

> Instead define pci::Device as
>
> 	pub struct Device<Ctx: DeviceContext = Normal>(
> 		Opaque<bindings::pci_dev>,
> 		PhantomData<Ctx>,
> 	);
>
> and manually implement the AlwaysRefCounted trait.
>
> With this we can implement methods that should only be called from
> bus callbacks (such as probe()) for pci::Device<Core>. Consequently, we
> make this type accessible in bus callbacks only.
>
> Arbitrary references taken by the driver are still of type
> ARef<pci::Device> and hence don't provide access to methods that are
> reserved for bus callbacks.
>
> Fixes: 1bd8b6b2c5d3 ("rust: pci: add basic PCI device / driver abstractions")
> Signed-off-by: Danilo Krummrich <dakr@kernel.org>

Two small nits below, but it already looks good:

Reviewed-by: Benno Lossin <benno.lossin@proton.me>

> ---
>  drivers/gpu/nova-core/driver.rs |   4 +-
>  rust/kernel/pci.rs              | 126 ++++++++++++++++++++------------
>  samples/rust/rust_driver_pci.rs |   8 +-
>  3 files changed, 85 insertions(+), 53 deletions(-)
>

> @@ -351,20 +361,8 @@ fn deref(&self) -> &Self::Target {
>  }
>  
>  impl Device {

One alternative to implementing `Deref` below would be to change this to
`impl<Ctx: DeviceContext> Device<Ctx>`. But then one would lose the
ability to just do `&pdev` to get a `Device` from a `Device<Core>`... So
I think the deref way is better. Just wanted to mention this in case
someone re-uses this pattern.

> -    /// Create a PCI Device instance from an existing `device::Device`.
> -    ///
> -    /// # Safety
> -    ///
> -    /// `dev` must be an `ARef<device::Device>` whose underlying `bindings::device` is a member of
> -    /// a `bindings::pci_dev`.
> -    pub unsafe fn from_dev(dev: ARef<device::Device>) -> Self {
> -        Self(dev)
> -    }
> -
>      fn as_raw(&self) -> *mut bindings::pci_dev {
> -        // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device`
> -        // embedded in `struct pci_dev`.
> -        unsafe { container_of!(self.0.as_raw(), bindings::pci_dev, dev) as _ }
> +        self.0.get()
>      }
>  
>      /// Returns the PCI vendor ID.

>  impl AsRef<device::Device> for Device {
>      fn as_ref(&self) -> &device::Device {
> -        &self.0
> +        // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
> +        // `struct pci_dev`.
> +        let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
> +
> +        // SAFETY: `dev` points to a valid `struct device`.
> +        unsafe { device::Device::as_ref(dev) }

Why not use `&**self` instead (ie go through the `Deref` impl)?

> @@ -77,7 +77,7 @@ fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>
>  
>          let drvdata = KBox::new(
>              Self {
> -                pdev: pdev.clone(),
> +                pdev: (&**pdev).into(),

It might be possible to do:

    impl From<&pci::Device<Core>> for ARef<pci::Device> { ... }

Then this line could become `pdev: pdev.into()`.

---
Cheers,
Benno

>                  bar,
>              },
>              GFP_KERNEL,
Danilo Krummrich March 13, 2025, 2:25 p.m. UTC | #2
On Thu, Mar 13, 2025 at 10:44:38AM +0000, Benno Lossin wrote:
> On Thu Mar 13, 2025 at 3:13 AM CET, Danilo Krummrich wrote:
> > As by now, pci::Device is implemented as:
> >
> > 	#[derive(Clone)]
> > 	pub struct Device(ARef<device::Device>);
> >
> > This may be convenient, but has the implication that drivers can call
> > device methods that require a mutable reference concurrently at any
> > point of time.
> 
> Which methods take mutable references? The `set_master` method you
> mentioned also took a shared reference before this patch.

Yeah, that's basically a bug that I never fixed (until now), since making it
take a mutable reference would not have changed anything in terms of
accessibility.

> 
> > Instead define pci::Device as
> >
> > 	pub struct Device<Ctx: DeviceContext = Normal>(
> > 		Opaque<bindings::pci_dev>,
> > 		PhantomData<Ctx>,
> > 	);
> >
> > and manually implement the AlwaysRefCounted trait.
> >
> > With this we can implement methods that should only be called from
> > bus callbacks (such as probe()) for pci::Device<Core>. Consequently, we
> > make this type accessible in bus callbacks only.
> >
> > Arbitrary references taken by the driver are still of type
> > ARef<pci::Device> and hence don't provide access to methods that are
> > reserved for bus callbacks.
> >
> > Fixes: 1bd8b6b2c5d3 ("rust: pci: add basic PCI device / driver abstractions")
> > Signed-off-by: Danilo Krummrich <dakr@kernel.org>
> 
> Two small nits below, but it already looks good:
> 
> Reviewed-by: Benno Lossin <benno.lossin@proton.me>
> 
> > ---
> >  drivers/gpu/nova-core/driver.rs |   4 +-
> >  rust/kernel/pci.rs              | 126 ++++++++++++++++++++------------
> >  samples/rust/rust_driver_pci.rs |   8 +-
> >  3 files changed, 85 insertions(+), 53 deletions(-)
> >
> 
> > @@ -351,20 +361,8 @@ fn deref(&self) -> &Self::Target {
> >  }
> >  
> >  impl Device {
> 
> One alternative to implementing `Deref` below would be to change this to
> `impl<Ctx: DeviceContext> Device<Ctx>`. But then one would lose the
> ability to just do `&pdev` to get a `Device` from a `Device<Core>`... So
> I think the deref way is better. Just wanted to mention this in case
> someone re-uses this pattern.
> 
> > -    /// Create a PCI Device instance from an existing `device::Device`.
> > -    ///
> > -    /// # Safety
> > -    ///
> > -    /// `dev` must be an `ARef<device::Device>` whose underlying `bindings::device` is a member of
> > -    /// a `bindings::pci_dev`.
> > -    pub unsafe fn from_dev(dev: ARef<device::Device>) -> Self {
> > -        Self(dev)
> > -    }
> > -
> >      fn as_raw(&self) -> *mut bindings::pci_dev {
> > -        // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device`
> > -        // embedded in `struct pci_dev`.
> > -        unsafe { container_of!(self.0.as_raw(), bindings::pci_dev, dev) as _ }
> > +        self.0.get()
> >      }
> >  
> >      /// Returns the PCI vendor ID.
> 
> >  impl AsRef<device::Device> for Device {
> >      fn as_ref(&self) -> &device::Device {
> > -        &self.0
> > +        // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
> > +        // `struct pci_dev`.
> > +        let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
> > +
> > +        // SAFETY: `dev` points to a valid `struct device`.
> > +        unsafe { device::Device::as_ref(dev) }
> 
> Why not use `&**self` instead (ie go through the `Deref` impl)?

`&**self` gives us a `Device` (i.e. `pci::Device`), not a `device::Device`.

> 
> > @@ -77,7 +77,7 @@ fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>
> >  
> >          let drvdata = KBox::new(
> >              Self {
> > -                pdev: pdev.clone(),
> > +                pdev: (&**pdev).into(),
> 
> It might be possible to do:
> 
>     impl From<&pci::Device<Core>> for ARef<pci::Device> { ... }
> 
> Then this line could become `pdev: pdev.into()`.

Yeah, having to write `&**pdev` was bothering me too, and I actually tried what
you suggest, but it didn't compile -- I'll double check.
Benno Lossin March 13, 2025, 2:30 p.m. UTC | #3
On Thu Mar 13, 2025 at 3:25 PM CET, Danilo Krummrich wrote:
> On Thu, Mar 13, 2025 at 10:44:38AM +0000, Benno Lossin wrote:
>> On Thu Mar 13, 2025 at 3:13 AM CET, Danilo Krummrich wrote:
>> > As by now, pci::Device is implemented as:
>> >
>> > 	#[derive(Clone)]
>> > 	pub struct Device(ARef<device::Device>);
>> >
>> > This may be convenient, but has the implication that drivers can call
>> > device methods that require a mutable reference concurrently at any
>> > point of time.
>> 
>> Which methods take mutable references? The `set_master` method you
>> mentioned also took a shared reference before this patch.
>
> Yeah, that's basically a bug that I never fixed (until now), since making it
> take a mutable reference would not have changed anything in terms of
> accessibility.

Gotcha.

>> >  impl AsRef<device::Device> for Device {
>> >      fn as_ref(&self) -> &device::Device {
>> > -        &self.0
>> > +        // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
>> > +        // `struct pci_dev`.
>> > +        let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
>> > +
>> > +        // SAFETY: `dev` points to a valid `struct device`.
>> > +        unsafe { device::Device::as_ref(dev) }
>> 
>> Why not use `&**self` instead (ie go through the `Deref` impl)?
>
> `&**self` gives us a `Device` (i.e. `pci::Device`), not a `device::Device`.

Ah, yeah then you'll have to use `unsafe`.

---
Cheers,
Benno
diff mbox series

Patch

diff --git a/drivers/gpu/nova-core/driver.rs b/drivers/gpu/nova-core/driver.rs
index 63c19f140fbd..a08fb6599267 100644
--- a/drivers/gpu/nova-core/driver.rs
+++ b/drivers/gpu/nova-core/driver.rs
@@ -1,6 +1,6 @@ 
 // SPDX-License-Identifier: GPL-2.0
 
-use kernel::{bindings, c_str, pci, prelude::*};
+use kernel::{bindings, c_str, device::Core, pci, prelude::*};
 
 use crate::gpu::Gpu;
 
@@ -27,7 +27,7 @@  impl pci::Driver for NovaCore {
     type IdInfo = ();
     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 
-    fn probe(pdev: &mut pci::Device, _info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
+    fn probe(pdev: &pci::Device<Core>, _info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
         dev_dbg!(pdev.as_ref(), "Probe Nova Core GPU driver.\n");
 
         pdev.enable_device_mem()?;
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs
index 386484dcf36e..6357b4ff8d65 100644
--- a/rust/kernel/pci.rs
+++ b/rust/kernel/pci.rs
@@ -6,7 +6,7 @@ 
 
 use crate::{
     alloc::flags::*,
-    bindings, container_of, device,
+    bindings, device,
     device_id::RawDeviceId,
     devres::Devres,
     driver,
@@ -17,7 +17,11 @@ 
     types::{ARef, ForeignOwnable, Opaque},
     ThisModule,
 };
-use core::{ops::Deref, ptr::addr_of_mut};
+use core::{
+    marker::PhantomData,
+    ops::Deref,
+    ptr::{addr_of_mut, NonNull},
+};
 use kernel::prelude::*;
 
 /// An adapter for the registration of PCI drivers.
@@ -60,17 +64,16 @@  extern "C" fn probe_callback(
     ) -> kernel::ffi::c_int {
         // SAFETY: The PCI bus only ever calls the probe callback with a valid pointer to a
         // `struct pci_dev`.
-        let dev = unsafe { device::Device::get_device(addr_of_mut!((*pdev).dev)) };
-        // SAFETY: `dev` is guaranteed to be embedded in a valid `struct pci_dev` by the call
-        // above.
-        let mut pdev = unsafe { Device::from_dev(dev) };
+        //
+        // INVARIANT: `pdev` is valid for the duration of `probe_callback()`.
+        let pdev = unsafe { &*pdev.cast::<Device<device::Core>>() };
 
         // SAFETY: `DeviceId` is a `#[repr(transparent)` wrapper of `struct pci_device_id` and
         // does not add additional invariants, so it's safe to transmute.
         let id = unsafe { &*id.cast::<DeviceId>() };
         let info = T::ID_TABLE.info(id.index());
 
-        match T::probe(&mut pdev, info) {
+        match T::probe(pdev, info) {
             Ok(data) => {
                 // Let the `struct pci_dev` own a reference of the driver's private data.
                 // SAFETY: By the type invariant `pdev.as_raw` returns a valid pointer to a
@@ -192,7 +195,7 @@  macro_rules! pci_device_table {
 /// # Example
 ///
 ///```
-/// # use kernel::{bindings, pci};
+/// # use kernel::{bindings, device::Core, pci};
 ///
 /// struct MyDriver;
 ///
@@ -210,7 +213,7 @@  macro_rules! pci_device_table {
 ///     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 ///
 ///     fn probe(
-///         _pdev: &mut pci::Device,
+///         _pdev: &pci::Device<Core>,
 ///         _id_info: &Self::IdInfo,
 ///     ) -> Result<Pin<KBox<Self>>> {
 ///         Err(ENODEV)
@@ -234,20 +237,23 @@  pub trait Driver {
     ///
     /// Called when a new platform device is added or discovered.
     /// Implementers should attempt to initialize the device here.
-    fn probe(dev: &mut Device, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>;
+    fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>;
 }
 
 /// The PCI device representation.
 ///
-/// A PCI device is based on an always reference counted `device:Device` instance. Cloning a PCI
-/// device, hence, also increments the base device' reference count.
+/// This structure represents the Rust abstraction for a C `struct pci_dev`. The implementation
+/// abstracts the usage of an already existing C `struct pci_dev` within Rust code that we get
+/// passed from the C side.
 ///
 /// # Invariants
 ///
-/// `Device` hold a valid reference of `ARef<device::Device>` whose underlying `struct device` is a
-/// member of a `struct pci_dev`.
-#[derive(Clone)]
-pub struct Device(ARef<device::Device>);
+/// A [`Device`] instance represents a valid `struct pci_dev` created by the C side of the kernel.
+#[repr(transparent)]
+pub struct Device<Ctx: device::DeviceContext = device::Normal>(
+    Opaque<bindings::pci_dev>,
+    PhantomData<Ctx>,
+);
 
 /// A PCI BAR to perform I/O-Operations on.
 ///
@@ -256,13 +262,13 @@  pub trait Driver {
 /// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O
 /// memory mapped PCI bar and its size.
 pub struct Bar<const SIZE: usize = 0> {
-    pdev: Device,
+    pdev: ARef<Device>,
     io: IoRaw<SIZE>,
     num: i32,
 }
 
 impl<const SIZE: usize> Bar<SIZE> {
-    fn new(pdev: Device, num: u32, name: &CStr) -> Result<Self> {
+    fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> {
         let len = pdev.resource_len(num)?;
         if len == 0 {
             return Err(ENOMEM);
@@ -300,12 +306,16 @@  fn new(pdev: Device, num: u32, name: &CStr) -> Result<Self> {
                 // `pdev` is valid by the invariants of `Device`.
                 // `ioptr` is guaranteed to be the start of a valid I/O mapped memory region.
                 // `num` is checked for validity by a previous call to `Device::resource_len`.
-                unsafe { Self::do_release(&pdev, ioptr, num) };
+                unsafe { Self::do_release(pdev, ioptr, num) };
                 return Err(err);
             }
         };
 
-        Ok(Bar { pdev, io, num })
+        Ok(Bar {
+            pdev: pdev.into(),
+            io,
+            num,
+        })
     }
 
     /// # Safety
@@ -351,20 +361,8 @@  fn deref(&self) -> &Self::Target {
 }
 
 impl Device {
-    /// Create a PCI Device instance from an existing `device::Device`.
-    ///
-    /// # Safety
-    ///
-    /// `dev` must be an `ARef<device::Device>` whose underlying `bindings::device` is a member of
-    /// a `bindings::pci_dev`.
-    pub unsafe fn from_dev(dev: ARef<device::Device>) -> Self {
-        Self(dev)
-    }
-
     fn as_raw(&self) -> *mut bindings::pci_dev {
-        // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device`
-        // embedded in `struct pci_dev`.
-        unsafe { container_of!(self.0.as_raw(), bindings::pci_dev, dev) as _ }
+        self.0.get()
     }
 
     /// Returns the PCI vendor ID.
@@ -379,18 +377,6 @@  pub fn device_id(&self) -> u16 {
         unsafe { (*self.as_raw()).device }
     }
 
-    /// Enable memory resources for this device.
-    pub fn enable_device_mem(&self) -> Result {
-        // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
-        to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) })
-    }
-
-    /// Enable bus-mastering for this device.
-    pub fn set_master(&self) {
-        // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
-        unsafe { bindings::pci_set_master(self.as_raw()) };
-    }
-
     /// Returns the size of the given PCI bar resource.
     pub fn resource_len(&self, bar: u32) -> Result<bindings::resource_size_t> {
         if !Bar::index_is_valid(bar) {
@@ -410,7 +396,7 @@  pub fn iomap_region_sized<const SIZE: usize>(
         bar: u32,
         name: &CStr,
     ) -> Result<Devres<Bar<SIZE>>> {
-        let bar = Bar::<SIZE>::new(self.clone(), bar, name)?;
+        let bar = Bar::<SIZE>::new(self, bar, name)?;
         let devres = Devres::new(self.as_ref(), bar, GFP_KERNEL)?;
 
         Ok(devres)
@@ -422,8 +408,54 @@  pub fn iomap_region(&self, bar: u32, name: &CStr) -> Result<Devres<Bar>> {
     }
 }
 
+impl Device<device::Core> {
+    /// Enable memory resources for this device.
+    pub fn enable_device_mem(&self) -> Result {
+        // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
+        to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) })
+    }
+
+    /// Enable bus-mastering for this device.
+    pub fn set_master(&self) {
+        // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
+        unsafe { bindings::pci_set_master(self.as_raw()) };
+    }
+}
+
+impl Deref for Device<device::Core> {
+    type Target = Device;
+
+    fn deref(&self) -> &Self::Target {
+        let ptr: *const Self = self;
+
+        // CAST: `Device<Ctx>` is a transparent wrapper of `Opaque<bindings::pci_dev>`.
+        let ptr = ptr.cast::<Device>();
+
+        // SAFETY: `ptr` was derived from `&self`.
+        unsafe { &*ptr }
+    }
+}
+
+// SAFETY: Instances of `Device` are always reference-counted.
+unsafe impl crate::types::AlwaysRefCounted for Device {
+    fn inc_ref(&self) {
+        // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+        unsafe { bindings::pci_dev_get(self.as_raw()) };
+    }
+
+    unsafe fn dec_ref(obj: NonNull<Self>) {
+        // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+        unsafe { bindings::pci_dev_put(obj.cast().as_ptr()) }
+    }
+}
+
 impl AsRef<device::Device> for Device {
     fn as_ref(&self) -> &device::Device {
-        &self.0
+        // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+        // `struct pci_dev`.
+        let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
+
+        // SAFETY: `dev` points to a valid `struct device`.
+        unsafe { device::Device::as_ref(dev) }
     }
 }
diff --git a/samples/rust/rust_driver_pci.rs b/samples/rust/rust_driver_pci.rs
index 1fb6e44f3395..b90df5f9d1d0 100644
--- a/samples/rust/rust_driver_pci.rs
+++ b/samples/rust/rust_driver_pci.rs
@@ -4,7 +4,7 @@ 
 //!
 //! To make this driver probe, QEMU must be run with `-device pci-testdev`.
 
-use kernel::{bindings, c_str, devres::Devres, pci, prelude::*};
+use kernel::{bindings, c_str, device::Core, devres::Devres, pci, prelude::*, types::ARef};
 
 struct Regs;
 
@@ -26,7 +26,7 @@  impl TestIndex {
 }
 
 struct SampleDriver {
-    pdev: pci::Device,
+    pdev: ARef<pci::Device>,
     bar: Devres<Bar0>,
 }
 
@@ -62,7 +62,7 @@  impl pci::Driver for SampleDriver {
 
     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 
-    fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
+    fn probe(pdev: &pci::Device<Core>, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
         dev_dbg!(
             pdev.as_ref(),
             "Probe Rust PCI driver sample (PCI ID: 0x{:x}, 0x{:x}).\n",
@@ -77,7 +77,7 @@  fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>
 
         let drvdata = KBox::new(
             Self {
-                pdev: pdev.clone(),
+                pdev: (&**pdev).into(),
                 bar,
             },
             GFP_KERNEL,