diff mbox series

[v3,3/4] rust: block: convert `block::mq` to use `Refcount`

Message ID 20250219201602.1898383-4-gary@garyguo.net (mailing list archive)
State New
Headers show
Series None | expand

Commit Message

Gary Guo Feb. 19, 2025, 8:15 p.m. UTC
Currently there's a custom reference counting in `block::mq`, which uses
`AtomicU64` Rust atomics, and this type doesn't exist on some 32-bit
architectures. We cannot just change it to use 32-bit atomics, because
doing so will make it vulnerable to refcount overflow. So switch it to
use the kernel refcount `kernel::sync::Refcount` instead.

There is an operation needed by `block::mq`, atomically decreasing
refcount from 2 to 0, which is not available through refcount.h, so
I exposed `Refcount::as_atomic` which allows accessing the refcount
directly.

Acked-by: Andreas Hindborg <a.hindborg@kernel.org>
Signed-off-by: Gary Guo <gary@garyguo.net>
---
 rust/kernel/block/mq/operations.rs |  7 +--
 rust/kernel/block/mq/request.rs    | 70 ++++++++++--------------------
 rust/kernel/sync/refcount.rs       | 14 ++++++
 3 files changed, 40 insertions(+), 51 deletions(-)

Comments

Tamir Duberstein Feb. 19, 2025, 10:26 p.m. UTC | #1
On Wed, Feb 19, 2025 at 3:17 PM Gary Guo <gary@garyguo.net> wrote:
>
> Currently there's a custom reference counting in `block::mq`, which uses
> `AtomicU64` Rust atomics, and this type doesn't exist on some 32-bit
> architectures. We cannot just change it to use 32-bit atomics, because
> doing so will make it vulnerable to refcount overflow. So switch it to
> use the kernel refcount `kernel::sync::Refcount` instead.
>
> There is an operation needed by `block::mq`, atomically decreasing
> refcount from 2 to 0, which is not available through refcount.h, so
> I exposed `Refcount::as_atomic` which allows accessing the refcount
> directly.
>
> Acked-by: Andreas Hindborg <a.hindborg@kernel.org>
> Signed-off-by: Gary Guo <gary@garyguo.net>
> ---
>  rust/kernel/block/mq/operations.rs |  7 +--
>  rust/kernel/block/mq/request.rs    | 70 ++++++++++--------------------
>  rust/kernel/sync/refcount.rs       | 14 ++++++
>  3 files changed, 40 insertions(+), 51 deletions(-)
>
> diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs
> index 864ff379dc91..c399dcaa6740 100644
> --- a/rust/kernel/block/mq/operations.rs
> +++ b/rust/kernel/block/mq/operations.rs
> @@ -10,9 +10,10 @@
>      block::mq::Request,
>      error::{from_result, Result},
>      prelude::*,
> +    sync::Refcount,
>      types::ARef,
>  };
> -use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering};
> +use core::marker::PhantomData;
>
>  /// Implement this trait to interface blk-mq as block devices.
>  ///
> @@ -78,7 +79,7 @@ impl<T: Operations> OperationsVTable<T> {
>          let request = unsafe { &*(*bd).rq.cast::<Request<T>>() };
>
>          // One refcount for the ARef, one for being in flight
> -        request.wrapper_ref().refcount().store(2, Ordering::Relaxed);
> +        request.wrapper_ref().refcount().set(2);
>
>          // SAFETY:
>          //  - We own a refcount that we took above. We pass that to `ARef`.
> @@ -187,7 +188,7 @@ impl<T: Operations> OperationsVTable<T> {
>
>              // SAFETY: The refcount field is allocated but not initialized, so
>              // it is valid for writes.
> -            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(AtomicU64::new(0)) };
> +            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(Refcount::new(0)) };

Could we just make the field pub and remove refcount_ptr? I believe a
few callers of `wrapper_ptr` could be replaced with `wrapper_ref`.

>
>              Ok(0)
>          })
> diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs
> index 7943f43b9575..7c782d70935e 100644
> --- a/rust/kernel/block/mq/request.rs
> +++ b/rust/kernel/block/mq/request.rs
> @@ -8,12 +8,13 @@
>      bindings,
>      block::mq::Operations,
>      error::Result,
> +    sync::Refcount,
>      types::{ARef, AlwaysRefCounted, Opaque},
>  };
>  use core::{
>      marker::PhantomData,
>      ptr::{addr_of_mut, NonNull},
> -    sync::atomic::{AtomicU64, Ordering},
> +    sync::atomic::Ordering,
>  };
>
>  /// A wrapper around a blk-mq [`struct request`]. This represents an IO request.
> @@ -37,6 +38,9 @@
>  /// We need to track 3 and 4 to ensure that it is safe to end the request and hand
>  /// back ownership to the block layer.
>  ///
> +/// Note that driver can still obtain new `ARef` even if there is no `ARef`s in existence by using

Is this missing an article? "The driver".

> +/// `tag_to_rq`, hence the need to distinct B and C.

s/distinct/distinguish/, I think.

> +///
>  /// The states are tracked through the private `refcount` field of
>  /// `RequestDataWrapper`. This structure lives in the private data area of the C
>  /// [`struct request`].
> @@ -98,13 +102,17 @@ pub(crate) unsafe fn start_unchecked(this: &ARef<Self>) {
>      ///
>      /// [`struct request`]: srctree/include/linux/blk-mq.h
>      fn try_set_end(this: ARef<Self>) -> Result<*mut bindings::request, ARef<Self>> {
> -        // We can race with `TagSet::tag_to_rq`
> -        if let Err(_old) = this.wrapper_ref().refcount().compare_exchange(
> -            2,
> -            0,
> -            Ordering::Relaxed,
> -            Ordering::Relaxed,
> -        ) {
> +        // To hand back the ownership, we need the current refcount to be 2.
> +        // Since we can race with `TagSet::tag_to_rq`, this needs to atomically reduce
> +        // refcount to 0. `Refcount` does not provide a way to do this, so use the underlying
> +        // atomics directly.
> +        if this
> +            .wrapper_ref()
> +            .refcount()
> +            .as_atomic()
> +            .compare_exchange(2, 0, Ordering::Relaxed, Ordering::Relaxed)
> +            .is_err()

The previous `if let` was a bit more clear about what's being
discarded here (the previous value). This information is lost with
`is_err()`.

> +        {
>              return Err(this);
>          }
>
> @@ -168,13 +176,13 @@ pub(crate) struct RequestDataWrapper {
>      /// - 0: The request is owned by C block layer.
>      /// - 1: The request is owned by Rust abstractions but there are no [`ARef`] references to it.
>      /// - 2+: There are [`ARef`] references to the request.
> -    refcount: AtomicU64,
> +    refcount: Refcount,
>  }
>
>  impl RequestDataWrapper {
>      /// Return a reference to the refcount of the request that is embedding
>      /// `self`.
> -    pub(crate) fn refcount(&self) -> &AtomicU64 {
> +    pub(crate) fn refcount(&self) -> &Refcount {
>          &self.refcount
>      }
>
> @@ -184,7 +192,7 @@ pub(crate) fn refcount(&self) -> &AtomicU64 {
>      /// # Safety
>      ///
>      /// - `this` must point to a live allocation of at least the size of `Self`.
> -    pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut AtomicU64 {
> +    pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut Refcount {
>          // SAFETY: Because of the safety requirements of this function, the
>          // field projection is safe.
>          unsafe { addr_of_mut!((*this).refcount) }
> @@ -200,47 +208,13 @@ unsafe impl<T: Operations> Send for Request<T> {}
>  // mutate `self` are internally synchronized`
>  unsafe impl<T: Operations> Sync for Request<T> {}
>
> -/// Store the result of `op(target.load())` in target, returning new value of
> -/// target.
> -fn atomic_relaxed_op_return(target: &AtomicU64, op: impl Fn(u64) -> u64) -> u64 {
> -    let old = target.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(op(x)));
> -
> -    // SAFETY: Because the operation passed to `fetch_update` above always
> -    // return `Some`, `old` will always be `Ok`.
> -    let old = unsafe { old.unwrap_unchecked() };
> -
> -    op(old)
> -}
> -
> -/// Store the result of `op(target.load)` in `target` if `target.load() !=
> -/// pred`, returning [`true`] if the target was updated.
> -fn atomic_relaxed_op_unless(target: &AtomicU64, op: impl Fn(u64) -> u64, pred: u64) -> bool {
> -    target
> -        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| {
> -            if x == pred {
> -                None
> -            } else {
> -                Some(op(x))
> -            }
> -        })
> -        .is_ok()
> -}
> -
>  // SAFETY: All instances of `Request<T>` are reference counted. This
>  // implementation of `AlwaysRefCounted` ensure that increments to the ref count
>  // keeps the object alive in memory at least until a matching reference count
>  // decrement is executed.
>  unsafe impl<T: Operations> AlwaysRefCounted for Request<T> {
>      fn inc_ref(&self) {
> -        let refcount = &self.wrapper_ref().refcount();
> -
> -        #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
> -        let updated = atomic_relaxed_op_unless(refcount, |x| x + 1, 0);
> -
> -        #[cfg(CONFIG_DEBUG_MISC)]
> -        if !updated {
> -            panic!("Request refcount zero on clone")
> -        }
> +        self.wrapper_ref().refcount().inc();
>      }
>
>      unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
> @@ -252,10 +226,10 @@ unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
>          let refcount = unsafe { &*RequestDataWrapper::refcount_ptr(wrapper_ptr) };
>
>          #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
> -        let new_refcount = atomic_relaxed_op_return(refcount, |x| x - 1);
> +        let is_zero = refcount.dec_and_test();

Should this call .dec() if not(CONFIG_DEBUG_MISC)?

>
>          #[cfg(CONFIG_DEBUG_MISC)]
> -        if new_refcount == 0 {
> +        if is_zero {
>              panic!("Request reached refcount zero in Rust abstractions");
>          }
>      }
> diff --git a/rust/kernel/sync/refcount.rs b/rust/kernel/sync/refcount.rs
> index a6a683f5d7b8..3d7a1ffb3a46 100644
> --- a/rust/kernel/sync/refcount.rs
> +++ b/rust/kernel/sync/refcount.rs
> @@ -4,6 +4,8 @@
>  //!
>  //! C header: [`include/linux/refcount.h`](srctree/include/linux/refcount.h)
>
> +use core::sync::atomic::AtomicI32;
> +
>  use crate::types::Opaque;
>
>  /// Atomic reference counter.
> @@ -30,6 +32,18 @@ fn as_ptr(&self) -> *mut bindings::refcount_t {
>          self.0.get()
>      }
>
> +    /// Get the underlying atomic counter that backs the refcount.
> +    ///
> +    /// NOTE: This will be changed to LKMM atomic in the future.
> +    #[inline]
> +    pub fn as_atomic(&self) -> &AtomicI32 {
> +        let ptr = self.0.get() as *const AtomicI32;

Prefer `.cast()` to raw pointer casting please.

> +        // SAFETY: `refcount_t` is a transparent wrapper of `atomic_t`, which is an atomic 32-bit
> +        // integer that is layout-wise compatible with `AtomicI32`. All values are valid for
> +        // `refcount_t`, despite some of the values are considered saturated and "bad".

Grammer: s/are/being/.

Is there a citation you can link to here?

> +        unsafe { &*ptr }
> +    }
> +
>      /// Set a refcount's value.
>      #[inline]
>      pub fn set(&self, value: i32) {
> --
> 2.47.2
>
>
Tamir Duberstein Feb. 19, 2025, 10:53 p.m. UTC | #2
On Wed, Feb 19, 2025 at 5:26 PM Tamir Duberstein <tamird@gmail.com> wrote:
>
> On Wed, Feb 19, 2025 at 3:17 PM Gary Guo <gary@garyguo.net> wrote:
> >
> > Currently there's a custom reference counting in `block::mq`, which uses
> > `AtomicU64` Rust atomics, and this type doesn't exist on some 32-bit
> > architectures. We cannot just change it to use 32-bit atomics, because
> > doing so will make it vulnerable to refcount overflow. So switch it to
> > use the kernel refcount `kernel::sync::Refcount` instead.
> >
> > There is an operation needed by `block::mq`, atomically decreasing
> > refcount from 2 to 0, which is not available through refcount.h, so
> > I exposed `Refcount::as_atomic` which allows accessing the refcount
> > directly.
> >
> > Acked-by: Andreas Hindborg <a.hindborg@kernel.org>
> > Signed-off-by: Gary Guo <gary@garyguo.net>
> > ---
> >  rust/kernel/block/mq/operations.rs |  7 +--
> >  rust/kernel/block/mq/request.rs    | 70 ++++++++++--------------------
> >  rust/kernel/sync/refcount.rs       | 14 ++++++
> >  3 files changed, 40 insertions(+), 51 deletions(-)
> >
> > diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs
> > index 864ff379dc91..c399dcaa6740 100644
> > --- a/rust/kernel/block/mq/operations.rs
> > +++ b/rust/kernel/block/mq/operations.rs
> > @@ -10,9 +10,10 @@
> >      block::mq::Request,
> >      error::{from_result, Result},
> >      prelude::*,
> > +    sync::Refcount,
> >      types::ARef,
> >  };
> > -use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering};
> > +use core::marker::PhantomData;
> >
> >  /// Implement this trait to interface blk-mq as block devices.
> >  ///
> > @@ -78,7 +79,7 @@ impl<T: Operations> OperationsVTable<T> {
> >          let request = unsafe { &*(*bd).rq.cast::<Request<T>>() };
> >
> >          // One refcount for the ARef, one for being in flight
> > -        request.wrapper_ref().refcount().store(2, Ordering::Relaxed);
> > +        request.wrapper_ref().refcount().set(2);
> >
> >          // SAFETY:
> >          //  - We own a refcount that we took above. We pass that to `ARef`.
> > @@ -187,7 +188,7 @@ impl<T: Operations> OperationsVTable<T> {
> >
> >              // SAFETY: The refcount field is allocated but not initialized, so
> >              // it is valid for writes.
> > -            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(AtomicU64::new(0)) };
> > +            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(Refcount::new(0)) };
>
> Could we just make the field pub and remove refcount_ptr? I believe a
> few callers of `wrapper_ptr` could be replaced with `wrapper_ref`.

I took a stab at this to check it was possible:
https://gist.github.com/tamird/c9de7fa6e54529996f433950268f3f87
David Gow Feb. 20, 2025, 12:27 p.m. UTC | #3
On Thu, 20 Feb 2025 at 04:17, Gary Guo <gary@garyguo.net> wrote:
>
> Currently there's a custom reference counting in `block::mq`, which uses
> `AtomicU64` Rust atomics, and this type doesn't exist on some 32-bit
> architectures. We cannot just change it to use 32-bit atomics, because
> doing so will make it vulnerable to refcount overflow. So switch it to
> use the kernel refcount `kernel::sync::Refcount` instead.
>
> There is an operation needed by `block::mq`, atomically decreasing
> refcount from 2 to 0, which is not available through refcount.h, so
> I exposed `Refcount::as_atomic` which allows accessing the refcount
> directly.
>
> Acked-by: Andreas Hindborg <a.hindborg@kernel.org>
> Signed-off-by: Gary Guo <gary@garyguo.net>
> ---

Thanks very much. I can confirm this fixes 32-bit UML.

Tested-by: David Gow <davidgow@google.com>

Cheers,
-- David

>  rust/kernel/block/mq/operations.rs |  7 +--
>  rust/kernel/block/mq/request.rs    | 70 ++++++++++--------------------
>  rust/kernel/sync/refcount.rs       | 14 ++++++
>  3 files changed, 40 insertions(+), 51 deletions(-)
>
> diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs
> index 864ff379dc91..c399dcaa6740 100644
> --- a/rust/kernel/block/mq/operations.rs
> +++ b/rust/kernel/block/mq/operations.rs
> @@ -10,9 +10,10 @@
>      block::mq::Request,
>      error::{from_result, Result},
>      prelude::*,
> +    sync::Refcount,
>      types::ARef,
>  };
> -use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering};
> +use core::marker::PhantomData;
>
>  /// Implement this trait to interface blk-mq as block devices.
>  ///
> @@ -78,7 +79,7 @@ impl<T: Operations> OperationsVTable<T> {
>          let request = unsafe { &*(*bd).rq.cast::<Request<T>>() };
>
>          // One refcount for the ARef, one for being in flight
> -        request.wrapper_ref().refcount().store(2, Ordering::Relaxed);
> +        request.wrapper_ref().refcount().set(2);
>
>          // SAFETY:
>          //  - We own a refcount that we took above. We pass that to `ARef`.
> @@ -187,7 +188,7 @@ impl<T: Operations> OperationsVTable<T> {
>
>              // SAFETY: The refcount field is allocated but not initialized, so
>              // it is valid for writes.
> -            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(AtomicU64::new(0)) };
> +            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(Refcount::new(0)) };
>
>              Ok(0)
>          })
> diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs
> index 7943f43b9575..7c782d70935e 100644
> --- a/rust/kernel/block/mq/request.rs
> +++ b/rust/kernel/block/mq/request.rs
> @@ -8,12 +8,13 @@
>      bindings,
>      block::mq::Operations,
>      error::Result,
> +    sync::Refcount,
>      types::{ARef, AlwaysRefCounted, Opaque},
>  };
>  use core::{
>      marker::PhantomData,
>      ptr::{addr_of_mut, NonNull},
> -    sync::atomic::{AtomicU64, Ordering},
> +    sync::atomic::Ordering,
>  };
>
>  /// A wrapper around a blk-mq [`struct request`]. This represents an IO request.
> @@ -37,6 +38,9 @@
>  /// We need to track 3 and 4 to ensure that it is safe to end the request and hand
>  /// back ownership to the block layer.
>  ///
> +/// Note that driver can still obtain new `ARef` even if there is no `ARef`s in existence by using
> +/// `tag_to_rq`, hence the need to distinct B and C.
> +///
>  /// The states are tracked through the private `refcount` field of
>  /// `RequestDataWrapper`. This structure lives in the private data area of the C
>  /// [`struct request`].
> @@ -98,13 +102,17 @@ pub(crate) unsafe fn start_unchecked(this: &ARef<Self>) {
>      ///
>      /// [`struct request`]: srctree/include/linux/blk-mq.h
>      fn try_set_end(this: ARef<Self>) -> Result<*mut bindings::request, ARef<Self>> {
> -        // We can race with `TagSet::tag_to_rq`
> -        if let Err(_old) = this.wrapper_ref().refcount().compare_exchange(
> -            2,
> -            0,
> -            Ordering::Relaxed,
> -            Ordering::Relaxed,
> -        ) {
> +        // To hand back the ownership, we need the current refcount to be 2.
> +        // Since we can race with `TagSet::tag_to_rq`, this needs to atomically reduce
> +        // refcount to 0. `Refcount` does not provide a way to do this, so use the underlying
> +        // atomics directly.
> +        if this
> +            .wrapper_ref()
> +            .refcount()
> +            .as_atomic()
> +            .compare_exchange(2, 0, Ordering::Relaxed, Ordering::Relaxed)
> +            .is_err()
> +        {
>              return Err(this);
>          }
>
> @@ -168,13 +176,13 @@ pub(crate) struct RequestDataWrapper {
>      /// - 0: The request is owned by C block layer.
>      /// - 1: The request is owned by Rust abstractions but there are no [`ARef`] references to it.
>      /// - 2+: There are [`ARef`] references to the request.
> -    refcount: AtomicU64,
> +    refcount: Refcount,
>  }
>
>  impl RequestDataWrapper {
>      /// Return a reference to the refcount of the request that is embedding
>      /// `self`.
> -    pub(crate) fn refcount(&self) -> &AtomicU64 {
> +    pub(crate) fn refcount(&self) -> &Refcount {
>          &self.refcount
>      }
>
> @@ -184,7 +192,7 @@ pub(crate) fn refcount(&self) -> &AtomicU64 {
>      /// # Safety
>      ///
>      /// - `this` must point to a live allocation of at least the size of `Self`.
> -    pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut AtomicU64 {
> +    pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut Refcount {
>          // SAFETY: Because of the safety requirements of this function, the
>          // field projection is safe.
>          unsafe { addr_of_mut!((*this).refcount) }
> @@ -200,47 +208,13 @@ unsafe impl<T: Operations> Send for Request<T> {}
>  // mutate `self` are internally synchronized`
>  unsafe impl<T: Operations> Sync for Request<T> {}
>
> -/// Store the result of `op(target.load())` in target, returning new value of
> -/// target.
> -fn atomic_relaxed_op_return(target: &AtomicU64, op: impl Fn(u64) -> u64) -> u64 {
> -    let old = target.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(op(x)));
> -
> -    // SAFETY: Because the operation passed to `fetch_update` above always
> -    // return `Some`, `old` will always be `Ok`.
> -    let old = unsafe { old.unwrap_unchecked() };
> -
> -    op(old)
> -}
> -
> -/// Store the result of `op(target.load)` in `target` if `target.load() !=
> -/// pred`, returning [`true`] if the target was updated.
> -fn atomic_relaxed_op_unless(target: &AtomicU64, op: impl Fn(u64) -> u64, pred: u64) -> bool {
> -    target
> -        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| {
> -            if x == pred {
> -                None
> -            } else {
> -                Some(op(x))
> -            }
> -        })
> -        .is_ok()
> -}
> -
>  // SAFETY: All instances of `Request<T>` are reference counted. This
>  // implementation of `AlwaysRefCounted` ensure that increments to the ref count
>  // keeps the object alive in memory at least until a matching reference count
>  // decrement is executed.
>  unsafe impl<T: Operations> AlwaysRefCounted for Request<T> {
>      fn inc_ref(&self) {
> -        let refcount = &self.wrapper_ref().refcount();
> -
> -        #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
> -        let updated = atomic_relaxed_op_unless(refcount, |x| x + 1, 0);
> -
> -        #[cfg(CONFIG_DEBUG_MISC)]
> -        if !updated {
> -            panic!("Request refcount zero on clone")
> -        }
> +        self.wrapper_ref().refcount().inc();
>      }
>
>      unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
> @@ -252,10 +226,10 @@ unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
>          let refcount = unsafe { &*RequestDataWrapper::refcount_ptr(wrapper_ptr) };
>
>          #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
> -        let new_refcount = atomic_relaxed_op_return(refcount, |x| x - 1);
> +        let is_zero = refcount.dec_and_test();
>
>          #[cfg(CONFIG_DEBUG_MISC)]
> -        if new_refcount == 0 {
> +        if is_zero {
>              panic!("Request reached refcount zero in Rust abstractions");
>          }
>      }
> diff --git a/rust/kernel/sync/refcount.rs b/rust/kernel/sync/refcount.rs
> index a6a683f5d7b8..3d7a1ffb3a46 100644
> --- a/rust/kernel/sync/refcount.rs
> +++ b/rust/kernel/sync/refcount.rs
> @@ -4,6 +4,8 @@
>  //!
>  //! C header: [`include/linux/refcount.h`](srctree/include/linux/refcount.h)
>
> +use core::sync::atomic::AtomicI32;
> +
>  use crate::types::Opaque;
>
>  /// Atomic reference counter.
> @@ -30,6 +32,18 @@ fn as_ptr(&self) -> *mut bindings::refcount_t {
>          self.0.get()
>      }
>
> +    /// Get the underlying atomic counter that backs the refcount.
> +    ///
> +    /// NOTE: This will be changed to LKMM atomic in the future.
> +    #[inline]
> +    pub fn as_atomic(&self) -> &AtomicI32 {
> +        let ptr = self.0.get() as *const AtomicI32;
> +        // SAFETY: `refcount_t` is a transparent wrapper of `atomic_t`, which is an atomic 32-bit
> +        // integer that is layout-wise compatible with `AtomicI32`. All values are valid for
> +        // `refcount_t`, despite some of the values are considered saturated and "bad".
> +        unsafe { &*ptr }
> +    }
> +
>      /// Set a refcount's value.
>      #[inline]
>      pub fn set(&self, value: i32) {
> --
> 2.47.2
>
>
Andreas Hindborg Feb. 20, 2025, 7:18 p.m. UTC | #4
Tamir Duberstein <tamird@gmail.com> writes:

> On Wed, Feb 19, 2025 at 5:26 PM Tamir Duberstein <tamird@gmail.com> wrote:
>>
>> On Wed, Feb 19, 2025 at 3:17 PM Gary Guo <gary@garyguo.net> wrote:
>> >
>> > Currently there's a custom reference counting in `block::mq`, which uses
>> > `AtomicU64` Rust atomics, and this type doesn't exist on some 32-bit
>> > architectures. We cannot just change it to use 32-bit atomics, because
>> > doing so will make it vulnerable to refcount overflow. So switch it to
>> > use the kernel refcount `kernel::sync::Refcount` instead.
>> >
>> > There is an operation needed by `block::mq`, atomically decreasing
>> > refcount from 2 to 0, which is not available through refcount.h, so
>> > I exposed `Refcount::as_atomic` which allows accessing the refcount
>> > directly.
>> >
>> > Acked-by: Andreas Hindborg <a.hindborg@kernel.org>
>> > Signed-off-by: Gary Guo <gary@garyguo.net>
>> > ---
>> >  rust/kernel/block/mq/operations.rs |  7 +--
>> >  rust/kernel/block/mq/request.rs    | 70 ++++++++++--------------------
>> >  rust/kernel/sync/refcount.rs       | 14 ++++++
>> >  3 files changed, 40 insertions(+), 51 deletions(-)
>> >
>> > diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs
>> > index 864ff379dc91..c399dcaa6740 100644
>> > --- a/rust/kernel/block/mq/operations.rs
>> > +++ b/rust/kernel/block/mq/operations.rs
>> > @@ -10,9 +10,10 @@
>> >      block::mq::Request,
>> >      error::{from_result, Result},
>> >      prelude::*,
>> > +    sync::Refcount,
>> >      types::ARef,
>> >  };
>> > -use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering};
>> > +use core::marker::PhantomData;
>> >
>> >  /// Implement this trait to interface blk-mq as block devices.
>> >  ///
>> > @@ -78,7 +79,7 @@ impl<T: Operations> OperationsVTable<T> {
>> >          let request = unsafe { &*(*bd).rq.cast::<Request<T>>() };
>> >
>> >          // One refcount for the ARef, one for being in flight
>> > -        request.wrapper_ref().refcount().store(2, Ordering::Relaxed);
>> > +        request.wrapper_ref().refcount().set(2);
>> >
>> >          // SAFETY:
>> >          //  - We own a refcount that we took above. We pass that to `ARef`.
>> > @@ -187,7 +188,7 @@ impl<T: Operations> OperationsVTable<T> {
>> >
>> >              // SAFETY: The refcount field is allocated but not initialized, so
>> >              // it is valid for writes.
>> > -            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(AtomicU64::new(0)) };
>> > +            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(Refcount::new(0)) };
>>
>> Could we just make the field pub and remove refcount_ptr? I believe a
>> few callers of `wrapper_ptr` could be replaced with `wrapper_ref`.
>
> I took a stab at this to check it was possible:
> https://gist.github.com/tamird/c9de7fa6e54529996f433950268f3f87

The access method uses a raw pointer because it is not always safe to
reference the field.

I think line 25 in your patch is UB as the field is not initialized.

At any rate, such a change is orthogonal. You could submit a separate
patch with that refactor.


Best regards,
Andreas Hindborg
diff mbox series

Patch

diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs
index 864ff379dc91..c399dcaa6740 100644
--- a/rust/kernel/block/mq/operations.rs
+++ b/rust/kernel/block/mq/operations.rs
@@ -10,9 +10,10 @@ 
     block::mq::Request,
     error::{from_result, Result},
     prelude::*,
+    sync::Refcount,
     types::ARef,
 };
-use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering};
+use core::marker::PhantomData;
 
 /// Implement this trait to interface blk-mq as block devices.
 ///
@@ -78,7 +79,7 @@  impl<T: Operations> OperationsVTable<T> {
         let request = unsafe { &*(*bd).rq.cast::<Request<T>>() };
 
         // One refcount for the ARef, one for being in flight
-        request.wrapper_ref().refcount().store(2, Ordering::Relaxed);
+        request.wrapper_ref().refcount().set(2);
 
         // SAFETY:
         //  - We own a refcount that we took above. We pass that to `ARef`.
@@ -187,7 +188,7 @@  impl<T: Operations> OperationsVTable<T> {
 
             // SAFETY: The refcount field is allocated but not initialized, so
             // it is valid for writes.
-            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(AtomicU64::new(0)) };
+            unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(Refcount::new(0)) };
 
             Ok(0)
         })
diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs
index 7943f43b9575..7c782d70935e 100644
--- a/rust/kernel/block/mq/request.rs
+++ b/rust/kernel/block/mq/request.rs
@@ -8,12 +8,13 @@ 
     bindings,
     block::mq::Operations,
     error::Result,
+    sync::Refcount,
     types::{ARef, AlwaysRefCounted, Opaque},
 };
 use core::{
     marker::PhantomData,
     ptr::{addr_of_mut, NonNull},
-    sync::atomic::{AtomicU64, Ordering},
+    sync::atomic::Ordering,
 };
 
 /// A wrapper around a blk-mq [`struct request`]. This represents an IO request.
@@ -37,6 +38,9 @@ 
 /// We need to track 3 and 4 to ensure that it is safe to end the request and hand
 /// back ownership to the block layer.
 ///
+/// Note that driver can still obtain new `ARef` even if there is no `ARef`s in existence by using
+/// `tag_to_rq`, hence the need to distinct B and C.
+///
 /// The states are tracked through the private `refcount` field of
 /// `RequestDataWrapper`. This structure lives in the private data area of the C
 /// [`struct request`].
@@ -98,13 +102,17 @@  pub(crate) unsafe fn start_unchecked(this: &ARef<Self>) {
     ///
     /// [`struct request`]: srctree/include/linux/blk-mq.h
     fn try_set_end(this: ARef<Self>) -> Result<*mut bindings::request, ARef<Self>> {
-        // We can race with `TagSet::tag_to_rq`
-        if let Err(_old) = this.wrapper_ref().refcount().compare_exchange(
-            2,
-            0,
-            Ordering::Relaxed,
-            Ordering::Relaxed,
-        ) {
+        // To hand back the ownership, we need the current refcount to be 2.
+        // Since we can race with `TagSet::tag_to_rq`, this needs to atomically reduce
+        // refcount to 0. `Refcount` does not provide a way to do this, so use the underlying
+        // atomics directly.
+        if this
+            .wrapper_ref()
+            .refcount()
+            .as_atomic()
+            .compare_exchange(2, 0, Ordering::Relaxed, Ordering::Relaxed)
+            .is_err()
+        {
             return Err(this);
         }
 
@@ -168,13 +176,13 @@  pub(crate) struct RequestDataWrapper {
     /// - 0: The request is owned by C block layer.
     /// - 1: The request is owned by Rust abstractions but there are no [`ARef`] references to it.
     /// - 2+: There are [`ARef`] references to the request.
-    refcount: AtomicU64,
+    refcount: Refcount,
 }
 
 impl RequestDataWrapper {
     /// Return a reference to the refcount of the request that is embedding
     /// `self`.
-    pub(crate) fn refcount(&self) -> &AtomicU64 {
+    pub(crate) fn refcount(&self) -> &Refcount {
         &self.refcount
     }
 
@@ -184,7 +192,7 @@  pub(crate) fn refcount(&self) -> &AtomicU64 {
     /// # Safety
     ///
     /// - `this` must point to a live allocation of at least the size of `Self`.
-    pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut AtomicU64 {
+    pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut Refcount {
         // SAFETY: Because of the safety requirements of this function, the
         // field projection is safe.
         unsafe { addr_of_mut!((*this).refcount) }
@@ -200,47 +208,13 @@  unsafe impl<T: Operations> Send for Request<T> {}
 // mutate `self` are internally synchronized`
 unsafe impl<T: Operations> Sync for Request<T> {}
 
-/// Store the result of `op(target.load())` in target, returning new value of
-/// target.
-fn atomic_relaxed_op_return(target: &AtomicU64, op: impl Fn(u64) -> u64) -> u64 {
-    let old = target.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(op(x)));
-
-    // SAFETY: Because the operation passed to `fetch_update` above always
-    // return `Some`, `old` will always be `Ok`.
-    let old = unsafe { old.unwrap_unchecked() };
-
-    op(old)
-}
-
-/// Store the result of `op(target.load)` in `target` if `target.load() !=
-/// pred`, returning [`true`] if the target was updated.
-fn atomic_relaxed_op_unless(target: &AtomicU64, op: impl Fn(u64) -> u64, pred: u64) -> bool {
-    target
-        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| {
-            if x == pred {
-                None
-            } else {
-                Some(op(x))
-            }
-        })
-        .is_ok()
-}
-
 // SAFETY: All instances of `Request<T>` are reference counted. This
 // implementation of `AlwaysRefCounted` ensure that increments to the ref count
 // keeps the object alive in memory at least until a matching reference count
 // decrement is executed.
 unsafe impl<T: Operations> AlwaysRefCounted for Request<T> {
     fn inc_ref(&self) {
-        let refcount = &self.wrapper_ref().refcount();
-
-        #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
-        let updated = atomic_relaxed_op_unless(refcount, |x| x + 1, 0);
-
-        #[cfg(CONFIG_DEBUG_MISC)]
-        if !updated {
-            panic!("Request refcount zero on clone")
-        }
+        self.wrapper_ref().refcount().inc();
     }
 
     unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
@@ -252,10 +226,10 @@  unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) {
         let refcount = unsafe { &*RequestDataWrapper::refcount_ptr(wrapper_ptr) };
 
         #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))]
-        let new_refcount = atomic_relaxed_op_return(refcount, |x| x - 1);
+        let is_zero = refcount.dec_and_test();
 
         #[cfg(CONFIG_DEBUG_MISC)]
-        if new_refcount == 0 {
+        if is_zero {
             panic!("Request reached refcount zero in Rust abstractions");
         }
     }
diff --git a/rust/kernel/sync/refcount.rs b/rust/kernel/sync/refcount.rs
index a6a683f5d7b8..3d7a1ffb3a46 100644
--- a/rust/kernel/sync/refcount.rs
+++ b/rust/kernel/sync/refcount.rs
@@ -4,6 +4,8 @@ 
 //!
 //! C header: [`include/linux/refcount.h`](srctree/include/linux/refcount.h)
 
+use core::sync::atomic::AtomicI32;
+
 use crate::types::Opaque;
 
 /// Atomic reference counter.
@@ -30,6 +32,18 @@  fn as_ptr(&self) -> *mut bindings::refcount_t {
         self.0.get()
     }
 
+    /// Get the underlying atomic counter that backs the refcount.
+    ///
+    /// NOTE: This will be changed to LKMM atomic in the future.
+    #[inline]
+    pub fn as_atomic(&self) -> &AtomicI32 {
+        let ptr = self.0.get() as *const AtomicI32;
+        // SAFETY: `refcount_t` is a transparent wrapper of `atomic_t`, which is an atomic 32-bit
+        // integer that is layout-wise compatible with `AtomicI32`. All values are valid for
+        // `refcount_t`, despite some of the values are considered saturated and "bad".
+        unsafe { &*ptr }
+    }
+
     /// Set a refcount's value.
     #[inline]
     pub fn set(&self, value: i32) {