[net-next,v13,12/14] net: replace page_frag with page_frag_cache

Message ID 20240808123714.462740-13-linyunsheng@huawei.com (mailing list archive)
State Handled Elsewhere, archived
Delegated to: Mat Martineau

Commit Message

Yunsheng Lin Aug. 8, 2024, 12:37 p.m. UTC
Use the newly introduced prepare/probe/commit API to
replace page_frag with page_frag_cache for sk_page_frag().
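
In short, callers move from the sk_page_frag_refill() pattern to a
prepare/commit pair. A condensed sketch drawn from the diff below
(variable names taken from the chtls hunk):

	/* Condensed from the diff below: prepare a fragment, copy into
	 * it, then commit with or without taking a new page reference.
	 */
	struct page_frag_cache *pfrag = sk_page_frag(sk);
	unsigned int offset, fragsz = 32U;
	struct page *page;
	void *va;

	page = page_frag_alloc_prepare(pfrag, &offset, &fragsz, &va,
				       sk->sk_allocation);
	if (unlikely(!page))
		goto wait_for_memory;

	copy = min_t(int, copy, fragsz);
	/* ... copy 'copy' bytes of data into va ... */

	if (merge)	/* coalesced into an existing skb frag: no new ref */
		page_frag_alloc_commit_noref(pfrag, copy);
	else		/* new skb frag: commit takes a page reference */
		page_frag_alloc_commit(pfrag, copy);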

CC: Alexander Duyck <alexander.duyck@gmail.com>
Signed-off-by: Yunsheng Lin <linyunsheng@huawei.com>
Acked-by: Mat Martineau <martineau@kernel.org>
---
 .../chelsio/inline_crypto/chtls/chtls.h       |   3 -
 .../chelsio/inline_crypto/chtls/chtls_io.c    | 100 ++++---------
 .../chelsio/inline_crypto/chtls/chtls_main.c  |   3 -
 drivers/net/tun.c                             |  48 +++---
 include/linux/sched.h                         |   2 +-
 include/net/sock.h                            |  14 +-
 kernel/exit.c                                 |   3 +-
 kernel/fork.c                                 |   3 +-
 net/core/skbuff.c                             |  59 +++++---
 net/core/skmsg.c                              |  22 +--
 net/core/sock.c                               |  46 ++++--
 net/ipv4/ip_output.c                          |  33 +++--
 net/ipv4/tcp.c                                |  32 ++--
 net/ipv4/tcp_output.c                         |  28 ++--
 net/ipv6/ip6_output.c                         |  33 +++--
 net/kcm/kcmsock.c                             |  27 ++--
 net/mptcp/protocol.c                          |  67 +++++----
 net/sched/em_meta.c                           |   2 +-
 net/tls/tls_device.c                          | 137 ++++++++++--------
 19 files changed, 347 insertions(+), 315 deletions(-)

Comments

Alexander Duyck Aug. 14, 2024, 10:01 p.m. UTC | #1
On Thu, 2024-08-08 at 20:37 +0800, Yunsheng Lin wrote:
> Use the newly introduced prepare/probe/commit API to
> replace page_frag with page_frag_cache for sk_page_frag().
> 
> CC: Alexander Duyck <alexander.duyck@gmail.com>
> Signed-off-by: Yunsheng Lin <linyunsheng@huawei.com>
> Acked-by: Mat Martineau <martineau@kernel.org>
> ---
>  .../chelsio/inline_crypto/chtls/chtls.h       |   3 -
>  .../chelsio/inline_crypto/chtls/chtls_io.c    | 100 ++++---------
>  .../chelsio/inline_crypto/chtls/chtls_main.c  |   3 -
>  drivers/net/tun.c                             |  48 +++---
>  include/linux/sched.h                         |   2 +-
>  include/net/sock.h                            |  14 +-
>  kernel/exit.c                                 |   3 +-
>  kernel/fork.c                                 |   3 +-
>  net/core/skbuff.c                             |  59 +++++---
>  net/core/skmsg.c                              |  22 +--
>  net/core/sock.c                               |  46 ++++--
>  net/ipv4/ip_output.c                          |  33 +++--
>  net/ipv4/tcp.c                                |  32 ++--
>  net/ipv4/tcp_output.c                         |  28 ++--
>  net/ipv6/ip6_output.c                         |  33 +++--
>  net/kcm/kcmsock.c                             |  27 ++--
>  net/mptcp/protocol.c                          |  67 +++++----
>  net/sched/em_meta.c                           |   2 +-
>  net/tls/tls_device.c                          | 137 ++++++++++--------
>  19 files changed, 347 insertions(+), 315 deletions(-)
> 
> diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
> index 7ff82b6778ba..fe2b6a8ef718 100644
> --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
> +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
> @@ -234,7 +234,6 @@ struct chtls_dev {
>  	struct list_head list_node;
>  	struct list_head rcu_node;
>  	struct list_head na_node;
> -	unsigned int send_page_order;
>  	int max_host_sndbuf;
>  	u32 round_robin_cnt;
>  	struct key_map kmap;
> @@ -453,8 +452,6 @@ enum {
>  
>  /* The ULP mode/submode of an skbuff */
>  #define skb_ulp_mode(skb)  (ULP_SKB_CB(skb)->ulp_mode)
> -#define TCP_PAGE(sk)   (sk->sk_frag.page)
> -#define TCP_OFF(sk)    (sk->sk_frag.offset)
>  
>  static inline struct chtls_dev *to_chtls_dev(struct tls_toe_device *tlsdev)
>  {
> diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
> index d567e42e1760..334381c1587f 100644
> --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
> +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
> @@ -825,12 +825,6 @@ void skb_entail(struct sock *sk, struct sk_buff *skb, int flags)
>  	ULP_SKB_CB(skb)->flags = flags;
>  	__skb_queue_tail(&csk->txq, skb);
>  	sk->sk_wmem_queued += skb->truesize;
> -
> -	if (TCP_PAGE(sk) && TCP_OFF(sk)) {
> -		put_page(TCP_PAGE(sk));
> -		TCP_PAGE(sk) = NULL;
> -		TCP_OFF(sk) = 0;
> -	}
>  }
>  
>  static struct sk_buff *get_tx_skb(struct sock *sk, int size)
> @@ -882,16 +876,12 @@ static void push_frames_if_head(struct sock *sk)
>  		chtls_push_frames(csk, 1);
>  }
>  
> -static int chtls_skb_copy_to_page_nocache(struct sock *sk,
> -					  struct iov_iter *from,
> -					  struct sk_buff *skb,
> -					  struct page *page,
> -					  int off, int copy)
> +static int chtls_skb_copy_to_va_nocache(struct sock *sk, struct iov_iter *from,
> +					struct sk_buff *skb, char *va, int copy)
>  {
>  	int err;
>  
> -	err = skb_do_copy_data_nocache(sk, skb, from, page_address(page) +
> -				       off, copy, skb->len);
> +	err = skb_do_copy_data_nocache(sk, skb, from, va, copy, skb->len);
>  	if (err)
>  		return err;
>  
> @@ -1114,82 +1104,44 @@ int chtls_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
>  			if (err)
>  				goto do_fault;
>  		} else {
> +			struct page_frag_cache *pfrag = &sk->sk_frag;

Is this even valid? Shouldn't it be using sk_page_frag to get the
reference here? Seems like it might be trying to instantiate an unused
cache.

As per my earlier suggestion this could be made very simple if we are
just pulling a bio_vec out from the page cache at the start. With that
we could essentially plug it into the TCP_PAGE/TCP_OFF block here and
most of it would just function the same.
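
A rough sketch of that suggestion (the probe/commit helpers here are
hypothetical, not part of this series):

	struct bio_vec bv;

	/* Borrow one bio_vec out of the cache up front... */
	if (!page_frag_cache_probe(pfrag, &bv, sk->sk_allocation))
		goto wait_for_memory;

	/* ...so the existing TCP_PAGE/TCP_OFF flow keeps working
	 * against bv.bv_page/bv.bv_offset/bv.bv_len...
	 */
	copy = min_t(int, copy, bv.bv_len);

	/* ...and only what was actually used is committed at the end. */
	page_frag_cache_commit(pfrag, &bv, copy);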

>  			int i = skb_shinfo(skb)->nr_frags;
> -			struct page *page = TCP_PAGE(sk);
> -			int pg_size = PAGE_SIZE;
> -			int off = TCP_OFF(sk);
> -			bool merge;
> -
> -			if (page)
> -				pg_size = page_size(page);
> -			if (off < pg_size &&
> -			    skb_can_coalesce(skb, i, page, off)) {
> +			unsigned int offset, fragsz;
> +			bool merge = false;
> +			struct page *page;
> +			void *va;
> +
> +			fragsz = 32U;
> +			page = page_frag_alloc_prepare(pfrag, &offset, &fragsz,
> +						       &va, sk->sk_allocation);
> +			if (unlikely(!page))
> +				goto wait_for_memory;
> +
> +			if (skb_can_coalesce(skb, i, page, offset))
>  				merge = true;
> -				goto copy;
> -			}
> -			merge = false;
> -			if (i == (is_tls_tx(csk) ? (MAX_SKB_FRAGS - 1) :
> -			    MAX_SKB_FRAGS))
> +			else if (i == (is_tls_tx(csk) ? (MAX_SKB_FRAGS - 1) :
> +				       MAX_SKB_FRAGS))
>  				goto new_buf;
>  
> -			if (page && off == pg_size) {
> -				put_page(page);
> -				TCP_PAGE(sk) = page = NULL;
> -				pg_size = PAGE_SIZE;
> -			}
> -
> -			if (!page) {
> -				gfp_t gfp = sk->sk_allocation;
> -				int order = cdev->send_page_order;
> -
> -				if (order) {
> -					page = alloc_pages(gfp | __GFP_COMP |
> -							   __GFP_NOWARN |
> -							   __GFP_NORETRY,
> -							   order);
> -					if (page)
> -						pg_size <<= order;
> -				}
> -				if (!page) {
> -					page = alloc_page(gfp);
> -					pg_size = PAGE_SIZE;
> -				}
> -				if (!page)
> -					goto wait_for_memory;
> -				off = 0;
> -			}
> -copy:
> -			if (copy > pg_size - off)
> -				copy = pg_size - off;
> +			copy = min_t(int, copy, fragsz);
>  			if (is_tls_tx(csk))
>  				copy = min_t(int, copy, csk->tlshws.txleft);
>  
> -			err = chtls_skb_copy_to_page_nocache(sk, &msg->msg_iter,
> -							     skb, page,
> -							     off, copy);
> -			if (unlikely(err)) {
> -				if (!TCP_PAGE(sk)) {
> -					TCP_PAGE(sk) = page;
> -					TCP_OFF(sk) = 0;
> -				}
> +			err = chtls_skb_copy_to_va_nocache(sk, &msg->msg_iter,
> +							   skb, va, copy);
> +			if (unlikely(err))
>  				goto do_fault;
> -			}
> +
>  			/* Update the skb. */
>  			if (merge) {
>  				skb_frag_size_add(
>  						&skb_shinfo(skb)->frags[i - 1],
>  						copy);
> +				page_frag_alloc_commit_noref(pfrag, copy);
>  			} else {
> -				skb_fill_page_desc(skb, i, page, off, copy);
> -				if (off + copy < pg_size) {
> -					/* space left keep page */
> -					get_page(page);
> -					TCP_PAGE(sk) = page;
> -				} else {
> -					TCP_PAGE(sk) = NULL;
> -				}
> +				skb_fill_page_desc(skb, i, page, offset, copy);
> +				page_frag_alloc_commit(pfrag, copy);
>  			}
> -			TCP_OFF(sk) = off + copy;
>  		}
>  		if (unlikely(skb->len == mss))
>  			tx_skb_finalize(skb);

Really there is so much refactor here it is hard to tell what is what.
I would suggest just trying to plug in an intermediary value and you
can save the refactor for later.

> diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
> index 455a54708be4..ba88b2fc7cd8 100644
> --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
> +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
> @@ -34,7 +34,6 @@ static DEFINE_MUTEX(notify_mutex);
>  static RAW_NOTIFIER_HEAD(listen_notify_list);
>  static struct proto chtls_cpl_prot, chtls_cpl_protv6;
>  struct request_sock_ops chtls_rsk_ops, chtls_rsk_opsv6;
> -static uint send_page_order = (14 - PAGE_SHIFT < 0) ? 0 : 14 - PAGE_SHIFT;
>  
>  static void register_listen_notifier(struct notifier_block *nb)
>  {
> @@ -273,8 +272,6 @@ static void *chtls_uld_add(const struct cxgb4_lld_info *info)
>  	INIT_WORK(&cdev->deferq_task, process_deferq);
>  	spin_lock_init(&cdev->listen_lock);
>  	spin_lock_init(&cdev->idr_lock);
> -	cdev->send_page_order = min_t(uint, get_order(32768),
> -				      send_page_order);
>  	cdev->max_host_sndbuf = 48 * 1024;
>  
>  	if (lldi->vr->key.size)
> diff --git a/drivers/net/tun.c b/drivers/net/tun.c
> index 1d06c560c5e6..51df92fd60db 100644
> --- a/drivers/net/tun.c
> +++ b/drivers/net/tun.c
> @@ -1598,21 +1598,19 @@ static bool tun_can_build_skb(struct tun_struct *tun, struct tun_file *tfile,
>  }
>  
>  static struct sk_buff *__tun_build_skb(struct tun_file *tfile,
> -				       struct page_frag *alloc_frag, char *buf,
> -				       int buflen, int len, int pad)
> +				       char *buf, int buflen, int len, int pad)
>  {
>  	struct sk_buff *skb = build_skb(buf, buflen);
>  
> -	if (!skb)
> +	if (!skb) {
> +		page_frag_free_va(buf);
>  		return ERR_PTR(-ENOMEM);
> +	}
>  
>  	skb_reserve(skb, pad);
>  	skb_put(skb, len);
>  	skb_set_owner_w(skb, tfile->socket.sk);
>  
> -	get_page(alloc_frag->page);
> -	alloc_frag->offset += buflen;
> -

Rather than freeing the buf it would be better if you were to just
stick to the existing pattern and commit the alloc_frag at the end here
instead of calling get_page.
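
That is, keep the alloc_frag parameter and commit on success rather
than freeing on failure only; a sketch using the series' own
page_frag_alloc_commit():

	static struct sk_buff *__tun_build_skb(struct tun_file *tfile,
					       struct page_frag_cache *alloc_frag,
					       char *buf, int buflen, int len,
					       int pad)
	{
		struct sk_buff *skb = build_skb(buf, buflen);

		if (!skb)
			return ERR_PTR(-ENOMEM);

		skb_reserve(skb, pad);
		skb_put(skb, len);
		skb_set_owner_w(skb, tfile->socket.sk);

		/* mirrors the old get_page() + offset bump */
		page_frag_alloc_commit(alloc_frag, buflen);

		return skb;
	}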

>  	return skb;
>  }
>  
> @@ -1660,7 +1658,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>  				     struct virtio_net_hdr *hdr,
>  				     int len, int *skb_xdp)
>  {
> -	struct page_frag *alloc_frag = &current->task_frag;
> +	struct page_frag_cache *alloc_frag = &current->task_frag;
>  	struct bpf_net_context __bpf_net_ctx, *bpf_net_ctx;
>  	struct bpf_prog *xdp_prog;
>  	int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
> @@ -1676,16 +1674,16 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>  	buflen += SKB_DATA_ALIGN(len + pad);
>  	rcu_read_unlock();
>  
> -	alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
> -	if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL)))
> +	buf = page_frag_alloc_va_align(alloc_frag, buflen, GFP_KERNEL,
> +				       SMP_CACHE_BYTES);
> +	if (unlikely(!buf))
>  		return ERR_PTR(-ENOMEM);
>  
> -	buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
> -	copied = copy_page_from_iter(alloc_frag->page,
> -				     alloc_frag->offset + pad,
> -				     len, from);
> -	if (copied != len)
> +	copied = copy_from_iter(buf + pad, len, from);
> +	if (copied != len) {
> +		page_frag_alloc_abort(alloc_frag, buflen);
>  		return ERR_PTR(-EFAULT);
> +	}
>  
>  	/* There's a small window that XDP may be set after the check
>  	 * of xdp_prog above, this should be rare and for simplicity
> @@ -1693,8 +1691,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>  	 */
>  	if (hdr->gso_type || !xdp_prog) {
>  		*skb_xdp = 1;
> -		return __tun_build_skb(tfile, alloc_frag, buf, buflen, len,
> -				       pad);
> +		return __tun_build_skb(tfile, buf, buflen, len, pad);
>  	}
>  
>  	*skb_xdp = 0;
> @@ -1711,21 +1708,16 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>  		xdp_prepare_buff(&xdp, buf, pad, len, false);
>  
>  		act = bpf_prog_run_xdp(xdp_prog, &xdp);
> -		if (act == XDP_REDIRECT || act == XDP_TX) {
> -			get_page(alloc_frag->page);
> -			alloc_frag->offset += buflen;
> -		}
>  		err = tun_xdp_act(tun, xdp_prog, &xdp, act);
> -		if (err < 0) {
> -			if (act == XDP_REDIRECT || act == XDP_TX)
> -				put_page(alloc_frag->page);
> -			goto out;
> -		}
> -
>  		if (err == XDP_REDIRECT)
>  			xdp_do_flush();
> -		if (err != XDP_PASS)
> +
> +		if (err == XDP_REDIRECT || err == XDP_TX) {
> +			goto out;
> +		} else if (err < 0 || err != XDP_PASS) {
> +			page_frag_alloc_abort(alloc_frag, buflen);
>  			goto out;
> +		}
>  

Your abort function here is not necessarily safe. It is assuming that
nothing else might have taken a reference to the page or modified it in
some way. Generally we shouldn't allow rewinding the pointer until we
check the page count to guarantee nobody else is now working with a
copy of the page.
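
A sketch of a guarded rewind (the ownership-check helper is assumed,
not something this series provides):

	/* Only roll the offset back if we still hold every reference
	 * to the page; otherwise some other entity may be looking at
	 * the memory we would be rewinding into.
	 */
	if (page_frag_cache_is_exclusive(alloc_frag))	/* assumed helper */
		page_frag_alloc_abort(alloc_frag, buflen);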

>  		pad = xdp.data - xdp.data_hard_start;
>  		len = xdp.data_end - xdp.data;
> @@ -1734,7 +1726,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>  	rcu_read_unlock();
>  	local_bh_enable();
>  
> -	return __tun_build_skb(tfile, alloc_frag, buf, buflen, len, pad);
> +	return __tun_build_skb(tfile, buf, buflen, len, pad);
>  
>  out:
>  	bpf_net_ctx_clear(bpf_net_ctx);
> diff --git a/include/linux/sched.h b/include/linux/sched.h
> index f8d150343d42..bb9a8e9d6d2d 100644
> --- a/include/linux/sched.h
> +++ b/include/linux/sched.h
> @@ -1355,7 +1355,7 @@ struct task_struct {
>  	/* Cache last used pipe for splice(): */
>  	struct pipe_inode_info		*splice_pipe;
>  
> -	struct page_frag		task_frag;
> +	struct page_frag_cache		task_frag;
>  
>  #ifdef CONFIG_TASK_DELAY_ACCT
>  	struct task_delay_info		*delays;
> diff --git a/include/net/sock.h b/include/net/sock.h
> index b5e702298ab7..8f6cc0dd2f4f 100644
> 

It occurs to me that bio_vec and page_frag are essentially the same
thing. Instead of having your functions pass a bio_vec it might make
more sense to work with just a page_frag as the unit to be probed and
committed with the page_frag_cache being what is borrowed from.

With that I think you could clean up a bunch of the churn this code is
generating, as there is currently too much refactoring to make it easy
to review. If you were to change things so that you keep working with a
page_frag, just probing it out of the page_frag_cache and committing
your change back in, it would make the diff much more readable in my
opinion.

The general idea would be that the page and offset should not change
from probe to commit, and the size would only be allowed to shrink
relative to what remains. That would be more readable than returning a
page while passing pointers to offset and length/size.

Also, as I mentioned earlier, we cannot roll the offset backwards. It
needs to always make forward progress unless we own all references to
the page, as a section may have been shared out from underneath us when
we showed some other entity the memory.
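
Something along these lines (hypothetical signatures; page and offset
stay fixed between probe and commit, size may only shrink):

	struct page_frag frag;

	/* Probe borrows a page_frag view out of the cache... */
	if (!page_frag_cache_probe(pfrag, &frag, sk->sk_allocation))
		goto wait_for_memory;

	copy = min_t(int, copy, frag.size);
	/* ... fill page_address(frag.page) + frag.offset ... */

	/* ...and commit records how much was actually consumed. */
	frag.size = copy;
	page_frag_cache_commit(pfrag, &frag);
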
Yunsheng Lin Aug. 18, 2024, 2:17 p.m. UTC | #2
On 8/15/2024 6:01 AM, Alexander H Duyck wrote:
> On Thu, 2024-08-08 at 20:37 +0800, Yunsheng Lin wrote:
>> Use the newly introduced prepare/probe/commit API to
>> replace page_frag with page_frag_cache for sk_page_frag().
>>
>> CC: Alexander Duyck <alexander.duyck@gmail.com>
>> Signed-off-by: Yunsheng Lin <linyunsheng@huawei.com>
>> Acked-by: Mat Martineau <martineau@kernel.org>
>> ---
>>   .../chelsio/inline_crypto/chtls/chtls.h       |   3 -
>>   .../chelsio/inline_crypto/chtls/chtls_io.c    | 100 ++++---------
>>   .../chelsio/inline_crypto/chtls/chtls_main.c  |   3 -
>>   drivers/net/tun.c                             |  48 +++---
>>   include/linux/sched.h                         |   2 +-
>>   include/net/sock.h                            |  14 +-
>>   kernel/exit.c                                 |   3 +-
>>   kernel/fork.c                                 |   3 +-
>>   net/core/skbuff.c                             |  59 +++++---
>>   net/core/skmsg.c                              |  22 +--
>>   net/core/sock.c                               |  46 ++++--
>>   net/ipv4/ip_output.c                          |  33 +++--
>>   net/ipv4/tcp.c                                |  32 ++--
>>   net/ipv4/tcp_output.c                         |  28 ++--
>>   net/ipv6/ip6_output.c                         |  33 +++--
>>   net/kcm/kcmsock.c                             |  27 ++--
>>   net/mptcp/protocol.c                          |  67 +++++----
>>   net/sched/em_meta.c                           |   2 +-
>>   net/tls/tls_device.c                          | 137 ++++++++++--------
>>   19 files changed, 347 insertions(+), 315 deletions(-)
>>
>> diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
>> index 7ff82b6778ba..fe2b6a8ef718 100644
>> --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
>> +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
>> @@ -234,7 +234,6 @@ struct chtls_dev {
>>   	struct list_head list_node;
>>   	struct list_head rcu_node;
>>   	struct list_head na_node;
>> -	unsigned int send_page_order;
>>   	int max_host_sndbuf;
>>   	u32 round_robin_cnt;
>>   	struct key_map kmap;
>> @@ -453,8 +452,6 @@ enum {
>>   
>>   /* The ULP mode/submode of an skbuff */
>>   #define skb_ulp_mode(skb)  (ULP_SKB_CB(skb)->ulp_mode)
>> -#define TCP_PAGE(sk)   (sk->sk_frag.page)
>> -#define TCP_OFF(sk)    (sk->sk_frag.offset)
>>   
>>   static inline struct chtls_dev *to_chtls_dev(struct tls_toe_device *tlsdev)
>>   {
>> diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
>> index d567e42e1760..334381c1587f 100644
>> --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
>> +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
>> @@ -825,12 +825,6 @@ void skb_entail(struct sock *sk, struct sk_buff *skb, int flags)
>>   	ULP_SKB_CB(skb)->flags = flags;
>>   	__skb_queue_tail(&csk->txq, skb);
>>   	sk->sk_wmem_queued += skb->truesize;
>> -
>> -	if (TCP_PAGE(sk) && TCP_OFF(sk)) {
>> -		put_page(TCP_PAGE(sk));
>> -		TCP_PAGE(sk) = NULL;
>> -		TCP_OFF(sk) = 0;
>> -	}
>>   }
>>   
>>   static struct sk_buff *get_tx_skb(struct sock *sk, int size)
>> @@ -882,16 +876,12 @@ static void push_frames_if_head(struct sock *sk)
>>   		chtls_push_frames(csk, 1);
>>   }
>>   
>> -static int chtls_skb_copy_to_page_nocache(struct sock *sk,
>> -					  struct iov_iter *from,
>> -					  struct sk_buff *skb,
>> -					  struct page *page,
>> -					  int off, int copy)
>> +static int chtls_skb_copy_to_va_nocache(struct sock *sk, struct iov_iter *from,
>> +					struct sk_buff *skb, char *va, int copy)
>>   {
>>   	int err;
>>   
>> -	err = skb_do_copy_data_nocache(sk, skb, from, page_address(page) +
>> -				       off, copy, skb->len);
>> +	err = skb_do_copy_data_nocache(sk, skb, from, va, copy, skb->len);
>>   	if (err)
>>   		return err;
>>   
>> @@ -1114,82 +1104,44 @@ int chtls_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
>>   			if (err)
>>   				goto do_fault;
>>   		} else {
>> +			struct page_frag_cache *pfrag = &sk->sk_frag;
> 
> Is this even valid? Shouldn't it be using sk_page_frag to get the

chtls_sendmsg() only uses sk->sk_frag; see below.

> reference here? Seems like it might be trying to instantiate an unused
> cache.

I am not sure if I understand what you meant by "trying to instantiate
an unused cache". sk->sk_frag is supposed to be instantiated in
sock_init_data_uid() by calling page_frag_cache_init() in this patch.

> 
> As per my earlier suggestion this could be made very simple if we are
> just pulling a bio_vec out from the page cache at the start. With that
> we could essentially plug it into the TCP_PAGE/TCP_OFF block here and
> most of it would just function the same.

I am not sure we share the same understanding of why chtls needs more
changes than other places when replacing page_frag with
page_frag_cache.

chtls_sendmsg() duplicated its own implementation of page_frag
refilling instead of using skb_page_frag_refill(). We can remove that
implementation by using the new API, which is why there are more
changes for chtls than elsewhere. Are you suggesting keeping chtls's
own page_frag refilling implementation by 'plugging it into the
TCP_PAGE/TCP_OFF block' here?

> 
>>   			int i = skb_shinfo(skb)->nr_frags;
>> -			struct page *page = TCP_PAGE(sk);

The TCP_PAGE macro is defined as below; that is why sk->sk_frag is used
instead of sk_page_frag() in the chtls case, if I understand your
question correctly:

#define TCP_PAGE(sk)   (sk->sk_frag.page)
#define TCP_OFF(sk)    (sk->sk_frag.offset)

>> -			int pg_size = PAGE_SIZE;
>> -			int off = TCP_OFF(sk);
>> -			bool merge;
>> -
>> -			if (page)
>> -				pg_size = page_size(page);
>> -			if (off < pg_size &&
>> -			    skb_can_coalesce(skb, i, page, off)) {
>> +			unsigned int offset, fragsz;
>> +			bool merge = false;
>> +			struct page *page;
>> +			void *va;
>> +
>> +			fragsz = 32U;
>> +			page = page_frag_alloc_prepare(pfrag, &offset, &fragsz,
>> +						       &va, sk->sk_allocation);
>> +			if (unlikely(!page))
>> +				goto wait_for_memory;
>> +
>> +			if (skb_can_coalesce(skb, i, page, offset))
>>   				merge = true;
>> -				goto copy;
>> -			}
>> -			merge = false;
>> -			if (i == (is_tls_tx(csk) ? (MAX_SKB_FRAGS - 1) :
>> -			    MAX_SKB_FRAGS))
>> +			else if (i == (is_tls_tx(csk) ? (MAX_SKB_FRAGS - 1) :
>> +				       MAX_SKB_FRAGS))
>>   				goto new_buf;
>>   
>> -			if (page && off == pg_size) {
>> -				put_page(page);
>> -				TCP_PAGE(sk) = page = NULL;
>> -				pg_size = PAGE_SIZE;
>> -			}
>> -
>> -			if (!page) {
>> -				gfp_t gfp = sk->sk_allocation;
>> -				int order = cdev->send_page_order;
>> -
>> -				if (order) {
>> -					page = alloc_pages(gfp | __GFP_COMP |
>> -							   __GFP_NOWARN |
>> -							   __GFP_NORETRY,
>> -							   order);
>> -					if (page)
>> -						pg_size <<= order;
>> -				}
>> -				if (!page) {
>> -					page = alloc_page(gfp);
>> -					pg_size = PAGE_SIZE;
>> -				}
>> -				if (!page)
>> -					goto wait_for_memory;
>> -				off = 0;
>> -			}
>> -copy:
>> -			if (copy > pg_size - off)
>> -				copy = pg_size - off;
>> +			copy = min_t(int, copy, fragsz);
>>   			if (is_tls_tx(csk))
>>   				copy = min_t(int, copy, csk->tlshws.txleft);
>>   
>> -			err = chtls_skb_copy_to_page_nocache(sk, &msg->msg_iter,
>> -							     skb, page,
>> -							     off, copy);
>> -			if (unlikely(err)) {
>> -				if (!TCP_PAGE(sk)) {
>> -					TCP_PAGE(sk) = page;
>> -					TCP_OFF(sk) = 0;
>> -				}
>> +			err = chtls_skb_copy_to_va_nocache(sk, &msg->msg_iter,
>> +							   skb, va, copy);
>> +			if (unlikely(err))
>>   				goto do_fault;
>> -			}
>> +
>>   			/* Update the skb. */
>>   			if (merge) {
>>   				skb_frag_size_add(
>>   						&skb_shinfo(skb)->frags[i - 1],
>>   						copy);
>> +				page_frag_alloc_commit_noref(pfrag, copy);
>>   			} else {
>> -				skb_fill_page_desc(skb, i, page, off, copy);
>> -				if (off + copy < pg_size) {
>> -					/* space left keep page */
>> -					get_page(page);
>> -					TCP_PAGE(sk) = page;
>> -				} else {
>> -					TCP_PAGE(sk) = NULL;
>> -				}
>> +				skb_fill_page_desc(skb, i, page, offset, copy);
>> +				page_frag_alloc_commit(pfrag, copy);
>>   			}
>> -			TCP_OFF(sk) = off + copy;
>>   		}
>>   		if (unlikely(skb->len == mss))
>>   			tx_skb_finalize(skb);
> 
> Really there is so much refactor here it is hard to tell what is what.
> I would suggest just trying to plug in an intermediary value and you
> can save the refactor for later.

I am not sure your above suggestion works; even if it does, I am not
sure it is simple enough to just plug in an intermediary value, as the
fields in 'struct page_frag_cache' are quite different from the fields
in 'struct page_frag' when replacing page_frag with page_frag_cache for
sk->sk_frag (both shown below, with an illustrative sketch after them):

struct page_frag_cache {
	unsigned long encoded_va;

#if (PAGE_SIZE < PAGE_FRAG_CACHE_MAX_SIZE) && (BITS_PER_LONG <= 32)
	__u16 remaining;
	__u16 pagecnt_bias;
#else
	__u32 remaining;
	__u32 pagecnt_bias;
#endif
};

struct page_frag {
	struct page *page;
#if (BITS_PER_LONG > 32) || (PAGE_SIZE >= 65536)
	__u32 offset;
	__u32 size;
#else
	__u16 offset;
	__u16 size;
#endif
};
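
For illustration only (none of these decode helpers exist in the series
as posted), reconstructing a page_frag-style view from the cache would
itself need decoding of encoded_va:

	static void page_frag_from_cache(const struct page_frag_cache *nc,
					 struct page_frag *pfrag)
	{
		/* Hypothetical decode helpers; names are assumptions. */
		void *va = encoded_page_address(nc->encoded_va);
		unsigned int size = page_frag_cache_page_size(nc);

		pfrag->page = virt_to_page(va);
		pfrag->size = size;
		pfrag->offset = size - nc->remaining;
	}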


> 
>> diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
>> index 455a54708be4..ba88b2fc7cd8 100644
>> --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
>> +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
>> @@ -34,7 +34,6 @@ static DEFINE_MUTEX(notify_mutex);
>>   static RAW_NOTIFIER_HEAD(listen_notify_list);
>>   static struct proto chtls_cpl_prot, chtls_cpl_protv6;
>>   struct request_sock_ops chtls_rsk_ops, chtls_rsk_opsv6;
>> -static uint send_page_order = (14 - PAGE_SHIFT < 0) ? 0 : 14 - PAGE_SHIFT;
>>   
>>   static void register_listen_notifier(struct notifier_block *nb)
>>   {
>> @@ -273,8 +272,6 @@ static void *chtls_uld_add(const struct cxgb4_lld_info *info)
>>   	INIT_WORK(&cdev->deferq_task, process_deferq);
>>   	spin_lock_init(&cdev->listen_lock);
>>   	spin_lock_init(&cdev->idr_lock);
>> -	cdev->send_page_order = min_t(uint, get_order(32768),
>> -				      send_page_order);
>>   	cdev->max_host_sndbuf = 48 * 1024;
>>   
>>   	if (lldi->vr->key.size)
>> diff --git a/drivers/net/tun.c b/drivers/net/tun.c
>> index 1d06c560c5e6..51df92fd60db 100644
>> --- a/drivers/net/tun.c
>> +++ b/drivers/net/tun.c
>> @@ -1598,21 +1598,19 @@ static bool tun_can_build_skb(struct tun_struct *tun, struct tun_file *tfile,
>>   }
>>   
>>   static struct sk_buff *__tun_build_skb(struct tun_file *tfile,
>> -				       struct page_frag *alloc_frag, char *buf,
>> -				       int buflen, int len, int pad)
>> +				       char *buf, int buflen, int len, int pad)
>>   {
>>   	struct sk_buff *skb = build_skb(buf, buflen);
>>   
>> -	if (!skb)
>> +	if (!skb) {
>> +		page_frag_free_va(buf);
>>   		return ERR_PTR(-ENOMEM);
>> +	}
>>   
>>   	skb_reserve(skb, pad);
>>   	skb_put(skb, len);
>>   	skb_set_owner_w(skb, tfile->socket.sk);
>>   
>> -	get_page(alloc_frag->page);
>> -	alloc_frag->offset += buflen;
>> -
> 
> Rather than freeing the buf it would be better if you were to just
> stick to the existing pattern and commit the alloc_frag at the end here
> instead of calling get_page.
> 
>>   	return skb;
>>   }
>>   
>> @@ -1660,7 +1658,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>>   				     struct virtio_net_hdr *hdr,
>>   				     int len, int *skb_xdp)
>>   {
>> -	struct page_frag *alloc_frag = &current->task_frag;
>> +	struct page_frag_cache *alloc_frag = &current->task_frag;
>>   	struct bpf_net_context __bpf_net_ctx, *bpf_net_ctx;
>>   	struct bpf_prog *xdp_prog;
>>   	int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
>> @@ -1676,16 +1674,16 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>>   	buflen += SKB_DATA_ALIGN(len + pad);
>>   	rcu_read_unlock();
>>   
>> -	alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
>> -	if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL)))
>> +	buf = page_frag_alloc_va_align(alloc_frag, buflen, GFP_KERNEL,
>> +				       SMP_CACHE_BYTES);
>> +	if (unlikely(!buf))
>>   		return ERR_PTR(-ENOMEM);
>>   
>> -	buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
>> -	copied = copy_page_from_iter(alloc_frag->page,
>> -				     alloc_frag->offset + pad,
>> -				     len, from);
>> -	if (copied != len)
>> +	copied = copy_from_iter(buf + pad, len, from);
>> +	if (copied != len) {
>> +		page_frag_alloc_abort(alloc_frag, buflen);
>>   		return ERR_PTR(-EFAULT);
>> +	}
>>   
>>   	/* There's a small window that XDP may be set after the check
>>   	 * of xdp_prog above, this should be rare and for simplicity
>> @@ -1693,8 +1691,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>>   	 */
>>   	if (hdr->gso_type || !xdp_prog) {
>>   		*skb_xdp = 1;
>> -		return __tun_build_skb(tfile, alloc_frag, buf, buflen, len,
>> -				       pad);
>> +		return __tun_build_skb(tfile, buf, buflen, len, pad);
>>   	}
>>   
>>   	*skb_xdp = 0;
>> @@ -1711,21 +1708,16 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>>   		xdp_prepare_buff(&xdp, buf, pad, len, false);
>>   
>>   		act = bpf_prog_run_xdp(xdp_prog, &xdp);
>> -		if (act == XDP_REDIRECT || act == XDP_TX) {
>> -			get_page(alloc_frag->page);
>> -			alloc_frag->offset += buflen;

The above is only executed for the XDP_REDIRECT and XDP_TX cases.

>> -		}
>>   		err = tun_xdp_act(tun, xdp_prog, &xdp, act);
>> -		if (err < 0) {
>> -			if (act == XDP_REDIRECT || act == XDP_TX)
>> -				put_page(alloc_frag->page);

And there is a put_page() immediately when xdp_do_redirect() or
tun_xdp_tx() fails in tun_xdp_act(), so could something else really
have taken a reference to the page and modified it in some way even
when tun_xdp_act() returns an error? Could you be more specific about
why the above happens?

>> -			goto out;
>> -		}
>> -
>>   		if (err == XDP_REDIRECT)
>>   			xdp_do_flush();
>> -		if (err != XDP_PASS)
>> +
>> +		if (err == XDP_REDIRECT || err == XDP_TX) {
>> +			goto out;
>> +		} else if (err < 0 || err != XDP_PASS) {
>> +			page_frag_alloc_abort(alloc_frag, buflen);
>>   			goto out;
>> +		}
>>   
> 
> Your abort function here is not necessarily safe. It is assuming that
> nothing else might have taken a reference to the page or modified it in
> some way. Generally we shouldn't allow rewinding the pointer until we
> check the page count to guarantee nobody else is now working with a
> copy of the page.
> 
>>   		pad = xdp.data - xdp.data_hard_start;
>>   		len = xdp.data_end - xdp.data;
>> @@ -1734,7 +1726,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>>   	rcu_read_unlock();
>>   	local_bh_enable();
>>   
>> -	return __tun_build_skb(tfile, alloc_frag, buf, buflen, len, pad);
>> +	return __tun_build_skb(tfile, buf, buflen, len, pad);
>>   
>>   out:
>>   	bpf_net_ctx_clear(bpf_net_ctx);
>

...

Patch

diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
index 7ff82b6778ba..fe2b6a8ef718 100644
--- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
+++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
@@ -234,7 +234,6 @@  struct chtls_dev {
 	struct list_head list_node;
 	struct list_head rcu_node;
 	struct list_head na_node;
-	unsigned int send_page_order;
 	int max_host_sndbuf;
 	u32 round_robin_cnt;
 	struct key_map kmap;
@@ -453,8 +452,6 @@  enum {
 
 /* The ULP mode/submode of an skbuff */
 #define skb_ulp_mode(skb)  (ULP_SKB_CB(skb)->ulp_mode)
-#define TCP_PAGE(sk)   (sk->sk_frag.page)
-#define TCP_OFF(sk)    (sk->sk_frag.offset)
 
 static inline struct chtls_dev *to_chtls_dev(struct tls_toe_device *tlsdev)
 {
diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
index d567e42e1760..334381c1587f 100644
--- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
+++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
@@ -825,12 +825,6 @@  void skb_entail(struct sock *sk, struct sk_buff *skb, int flags)
 	ULP_SKB_CB(skb)->flags = flags;
 	__skb_queue_tail(&csk->txq, skb);
 	sk->sk_wmem_queued += skb->truesize;
-
-	if (TCP_PAGE(sk) && TCP_OFF(sk)) {
-		put_page(TCP_PAGE(sk));
-		TCP_PAGE(sk) = NULL;
-		TCP_OFF(sk) = 0;
-	}
 }
 
 static struct sk_buff *get_tx_skb(struct sock *sk, int size)
@@ -882,16 +876,12 @@  static void push_frames_if_head(struct sock *sk)
 		chtls_push_frames(csk, 1);
 }
 
-static int chtls_skb_copy_to_page_nocache(struct sock *sk,
-					  struct iov_iter *from,
-					  struct sk_buff *skb,
-					  struct page *page,
-					  int off, int copy)
+static int chtls_skb_copy_to_va_nocache(struct sock *sk, struct iov_iter *from,
+					struct sk_buff *skb, char *va, int copy)
 {
 	int err;
 
-	err = skb_do_copy_data_nocache(sk, skb, from, page_address(page) +
-				       off, copy, skb->len);
+	err = skb_do_copy_data_nocache(sk, skb, from, va, copy, skb->len);
 	if (err)
 		return err;
 
@@ -1114,82 +1104,44 @@  int chtls_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			if (err)
 				goto do_fault;
 		} else {
+			struct page_frag_cache *pfrag = &sk->sk_frag;
 			int i = skb_shinfo(skb)->nr_frags;
-			struct page *page = TCP_PAGE(sk);
-			int pg_size = PAGE_SIZE;
-			int off = TCP_OFF(sk);
-			bool merge;
-
-			if (page)
-				pg_size = page_size(page);
-			if (off < pg_size &&
-			    skb_can_coalesce(skb, i, page, off)) {
+			unsigned int offset, fragsz;
+			bool merge = false;
+			struct page *page;
+			void *va;
+
+			fragsz = 32U;
+			page = page_frag_alloc_prepare(pfrag, &offset, &fragsz,
+						       &va, sk->sk_allocation);
+			if (unlikely(!page))
+				goto wait_for_memory;
+
+			if (skb_can_coalesce(skb, i, page, offset))
 				merge = true;
-				goto copy;
-			}
-			merge = false;
-			if (i == (is_tls_tx(csk) ? (MAX_SKB_FRAGS - 1) :
-			    MAX_SKB_FRAGS))
+			else if (i == (is_tls_tx(csk) ? (MAX_SKB_FRAGS - 1) :
+				       MAX_SKB_FRAGS))
 				goto new_buf;
 
-			if (page && off == pg_size) {
-				put_page(page);
-				TCP_PAGE(sk) = page = NULL;
-				pg_size = PAGE_SIZE;
-			}
-
-			if (!page) {
-				gfp_t gfp = sk->sk_allocation;
-				int order = cdev->send_page_order;
-
-				if (order) {
-					page = alloc_pages(gfp | __GFP_COMP |
-							   __GFP_NOWARN |
-							   __GFP_NORETRY,
-							   order);
-					if (page)
-						pg_size <<= order;
-				}
-				if (!page) {
-					page = alloc_page(gfp);
-					pg_size = PAGE_SIZE;
-				}
-				if (!page)
-					goto wait_for_memory;
-				off = 0;
-			}
-copy:
-			if (copy > pg_size - off)
-				copy = pg_size - off;
+			copy = min_t(int, copy, fragsz);
 			if (is_tls_tx(csk))
 				copy = min_t(int, copy, csk->tlshws.txleft);
 
-			err = chtls_skb_copy_to_page_nocache(sk, &msg->msg_iter,
-							     skb, page,
-							     off, copy);
-			if (unlikely(err)) {
-				if (!TCP_PAGE(sk)) {
-					TCP_PAGE(sk) = page;
-					TCP_OFF(sk) = 0;
-				}
+			err = chtls_skb_copy_to_va_nocache(sk, &msg->msg_iter,
+							   skb, va, copy);
+			if (unlikely(err))
 				goto do_fault;
-			}
+
 			/* Update the skb. */
 			if (merge) {
 				skb_frag_size_add(
 						&skb_shinfo(skb)->frags[i - 1],
 						copy);
+				page_frag_alloc_commit_noref(pfrag, copy);
 			} else {
-				skb_fill_page_desc(skb, i, page, off, copy);
-				if (off + copy < pg_size) {
-					/* space left keep page */
-					get_page(page);
-					TCP_PAGE(sk) = page;
-				} else {
-					TCP_PAGE(sk) = NULL;
-				}
+				skb_fill_page_desc(skb, i, page, offset, copy);
+				page_frag_alloc_commit(pfrag, copy);
 			}
-			TCP_OFF(sk) = off + copy;
 		}
 		if (unlikely(skb->len == mss))
 			tx_skb_finalize(skb);
diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
index 455a54708be4..ba88b2fc7cd8 100644
--- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
+++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_main.c
@@ -34,7 +34,6 @@  static DEFINE_MUTEX(notify_mutex);
 static RAW_NOTIFIER_HEAD(listen_notify_list);
 static struct proto chtls_cpl_prot, chtls_cpl_protv6;
 struct request_sock_ops chtls_rsk_ops, chtls_rsk_opsv6;
-static uint send_page_order = (14 - PAGE_SHIFT < 0) ? 0 : 14 - PAGE_SHIFT;
 
 static void register_listen_notifier(struct notifier_block *nb)
 {
@@ -273,8 +272,6 @@  static void *chtls_uld_add(const struct cxgb4_lld_info *info)
 	INIT_WORK(&cdev->deferq_task, process_deferq);
 	spin_lock_init(&cdev->listen_lock);
 	spin_lock_init(&cdev->idr_lock);
-	cdev->send_page_order = min_t(uint, get_order(32768),
-				      send_page_order);
 	cdev->max_host_sndbuf = 48 * 1024;
 
 	if (lldi->vr->key.size)
diff --git a/drivers/net/tun.c b/drivers/net/tun.c
index 1d06c560c5e6..51df92fd60db 100644
--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -1598,21 +1598,19 @@  static bool tun_can_build_skb(struct tun_struct *tun, struct tun_file *tfile,
 }
 
 static struct sk_buff *__tun_build_skb(struct tun_file *tfile,
-				       struct page_frag *alloc_frag, char *buf,
-				       int buflen, int len, int pad)
+				       char *buf, int buflen, int len, int pad)
 {
 	struct sk_buff *skb = build_skb(buf, buflen);
 
-	if (!skb)
+	if (!skb) {
+		page_frag_free_va(buf);
 		return ERR_PTR(-ENOMEM);
+	}
 
 	skb_reserve(skb, pad);
 	skb_put(skb, len);
 	skb_set_owner_w(skb, tfile->socket.sk);
 
-	get_page(alloc_frag->page);
-	alloc_frag->offset += buflen;
-
 	return skb;
 }
 
@@ -1660,7 +1658,7 @@  static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 				     struct virtio_net_hdr *hdr,
 				     int len, int *skb_xdp)
 {
-	struct page_frag *alloc_frag = &current->task_frag;
+	struct page_frag_cache *alloc_frag = &current->task_frag;
 	struct bpf_net_context __bpf_net_ctx, *bpf_net_ctx;
 	struct bpf_prog *xdp_prog;
 	int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
@@ -1676,16 +1674,16 @@  static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 	buflen += SKB_DATA_ALIGN(len + pad);
 	rcu_read_unlock();
 
-	alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
-	if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL)))
+	buf = page_frag_alloc_va_align(alloc_frag, buflen, GFP_KERNEL,
+				       SMP_CACHE_BYTES);
+	if (unlikely(!buf))
 		return ERR_PTR(-ENOMEM);
 
-	buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
-	copied = copy_page_from_iter(alloc_frag->page,
-				     alloc_frag->offset + pad,
-				     len, from);
-	if (copied != len)
+	copied = copy_from_iter(buf + pad, len, from);
+	if (copied != len) {
+		page_frag_alloc_abort(alloc_frag, buflen);
 		return ERR_PTR(-EFAULT);
+	}
 
 	/* There's a small window that XDP may be set after the check
 	 * of xdp_prog above, this should be rare and for simplicity
@@ -1693,8 +1691,7 @@  static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 	 */
 	if (hdr->gso_type || !xdp_prog) {
 		*skb_xdp = 1;
-		return __tun_build_skb(tfile, alloc_frag, buf, buflen, len,
-				       pad);
+		return __tun_build_skb(tfile, buf, buflen, len, pad);
 	}
 
 	*skb_xdp = 0;
@@ -1711,21 +1708,16 @@  static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 		xdp_prepare_buff(&xdp, buf, pad, len, false);
 
 		act = bpf_prog_run_xdp(xdp_prog, &xdp);
-		if (act == XDP_REDIRECT || act == XDP_TX) {
-			get_page(alloc_frag->page);
-			alloc_frag->offset += buflen;
-		}
 		err = tun_xdp_act(tun, xdp_prog, &xdp, act);
-		if (err < 0) {
-			if (act == XDP_REDIRECT || act == XDP_TX)
-				put_page(alloc_frag->page);
-			goto out;
-		}
-
 		if (err == XDP_REDIRECT)
 			xdp_do_flush();
-		if (err != XDP_PASS)
+
+		if (err == XDP_REDIRECT || err == XDP_TX) {
+			goto out;
+		} else if (err < 0 || err != XDP_PASS) {
+			page_frag_alloc_abort(alloc_frag, buflen);
 			goto out;
+		}
 
 		pad = xdp.data - xdp.data_hard_start;
 		len = xdp.data_end - xdp.data;
@@ -1734,7 +1726,7 @@  static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 	rcu_read_unlock();
 	local_bh_enable();
 
-	return __tun_build_skb(tfile, alloc_frag, buf, buflen, len, pad);
+	return __tun_build_skb(tfile, buf, buflen, len, pad);
 
 out:
 	bpf_net_ctx_clear(bpf_net_ctx);
diff --git a/include/linux/sched.h b/include/linux/sched.h
index f8d150343d42..bb9a8e9d6d2d 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -1355,7 +1355,7 @@  struct task_struct {
 	/* Cache last used pipe for splice(): */
 	struct pipe_inode_info		*splice_pipe;
 
-	struct page_frag		task_frag;
+	struct page_frag_cache		task_frag;
 
 #ifdef CONFIG_TASK_DELAY_ACCT
 	struct task_delay_info		*delays;
diff --git a/include/net/sock.h b/include/net/sock.h
index b5e702298ab7..8f6cc0dd2f4f 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -461,7 +461,7 @@  struct sock {
 	struct sk_buff_head	sk_write_queue;
 	u32			sk_dst_pending_confirm;
 	u32			sk_pacing_status; /* see enum sk_pacing */
-	struct page_frag	sk_frag;
+	struct page_frag_cache	sk_frag;
 	struct timer_list	sk_timer;
 
 	unsigned long		sk_pacing_rate; /* bytes per second */
@@ -2484,7 +2484,7 @@  static inline void sk_stream_moderate_sndbuf(struct sock *sk)
  * Return: a per task page_frag if context allows that,
  * otherwise a per socket one.
  */
-static inline struct page_frag *sk_page_frag(struct sock *sk)
+static inline struct page_frag_cache *sk_page_frag(struct sock *sk)
 {
 	if (sk->sk_use_task_frag)
 		return &current->task_frag;
@@ -2492,7 +2492,15 @@  static inline struct page_frag *sk_page_frag(struct sock *sk)
 	return &sk->sk_frag;
 }
 
-bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag);
+struct page *sk_page_frag_alloc_prepare(struct sock *sk,
+					struct page_frag_cache *pfrag,
+					unsigned int *size,
+					unsigned int *offset, void **va);
+
+struct page *sk_page_frag_alloc_pg_prepare(struct sock *sk,
+					   struct page_frag_cache *pfrag,
+					   unsigned int *size,
+					   unsigned int *offset);
 
 /*
  *	Default write policy as shown to user space via poll/select/SIGIO
diff --git a/kernel/exit.c b/kernel/exit.c
index 7430852a8571..b5257e74ec1c 100644
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -917,8 +917,7 @@  void __noreturn do_exit(long code)
 	if (tsk->splice_pipe)
 		free_pipe_info(tsk->splice_pipe);
 
-	if (tsk->task_frag.page)
-		put_page(tsk->task_frag.page);
+	page_frag_cache_drain(&tsk->task_frag);
 
 	exit_task_stack_account(tsk);
 
diff --git a/kernel/fork.c b/kernel/fork.c
index cc760491f201..7d380a6fd64a 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -80,6 +80,7 @@ 
 #include <linux/tty.h>
 #include <linux/fs_struct.h>
 #include <linux/magic.h>
+#include <linux/page_frag_cache.h>
 #include <linux/perf_event.h>
 #include <linux/posix-timers.h>
 #include <linux/user-return-notifier.h>
@@ -1157,10 +1158,10 @@  static struct task_struct *dup_task_struct(struct task_struct *orig, int node)
 	tsk->btrace_seq = 0;
 #endif
 	tsk->splice_pipe = NULL;
-	tsk->task_frag.page = NULL;
 	tsk->wake_q.next = NULL;
 	tsk->worker_private = NULL;
 
+	page_frag_cache_init(&tsk->task_frag);
 	kcov_task_init(tsk);
 	kmsan_task_create(tsk);
 	kmap_local_fork(tsk);
diff --git a/net/core/skbuff.c b/net/core/skbuff.c
index bb77c3fd192f..a7e86984bbcd 100644
--- a/net/core/skbuff.c
+++ b/net/core/skbuff.c
@@ -3040,25 +3040,6 @@  static void sock_spd_release(struct splice_pipe_desc *spd, unsigned int i)
 	put_page(spd->pages[i]);
 }
 
-static struct page *linear_to_page(struct page *page, unsigned int *len,
-				   unsigned int *offset,
-				   struct sock *sk)
-{
-	struct page_frag *pfrag = sk_page_frag(sk);
-
-	if (!sk_page_frag_refill(sk, pfrag))
-		return NULL;
-
-	*len = min_t(unsigned int, *len, pfrag->size - pfrag->offset);
-
-	memcpy(page_address(pfrag->page) + pfrag->offset,
-	       page_address(page) + *offset, *len);
-	*offset = pfrag->offset;
-	pfrag->offset += *len;
-
-	return pfrag->page;
-}
-
 static bool spd_can_coalesce(const struct splice_pipe_desc *spd,
 			     struct page *page,
 			     unsigned int offset)
@@ -3069,6 +3050,38 @@  static bool spd_can_coalesce(const struct splice_pipe_desc *spd,
 		 spd->partial[spd->nr_pages - 1].len == offset);
 }
 
+static bool spd_fill_linear_page(struct splice_pipe_desc *spd,
+				 struct page *page, unsigned int offset,
+				 unsigned int *len, struct sock *sk)
+{
+	struct page_frag_cache *pfrag = sk_page_frag(sk);
+	unsigned int frag_len, frag_offset;
+	struct page *frag_page;
+	void *va;
+
+	frag_page = sk_page_frag_alloc_prepare(sk, pfrag, &frag_offset,
+					       &frag_len, &va);
+	if (!frag_page)
+		return true;
+
+	*len = min_t(unsigned int, *len, frag_len);
+	memcpy(va, page_address(page) + offset, *len);
+
+	if (spd_can_coalesce(spd, frag_page, frag_offset)) {
+		spd->partial[spd->nr_pages - 1].len += *len;
+		page_frag_alloc_commit_noref(pfrag, *len);
+		return false;
+	}
+
+	page_frag_alloc_commit(pfrag, *len);
+	spd->pages[spd->nr_pages] = frag_page;
+	spd->partial[spd->nr_pages].len = *len;
+	spd->partial[spd->nr_pages].offset = frag_offset;
+	spd->nr_pages++;
+
+	return false;
+}
+
 /*
  * Fill page/offset/length into spd, if it can hold more pages.
  */
@@ -3081,11 +3094,9 @@  static bool spd_fill_page(struct splice_pipe_desc *spd,
 	if (unlikely(spd->nr_pages == MAX_SKB_FRAGS))
 		return true;
 
-	if (linear) {
-		page = linear_to_page(page, len, &offset, sk);
-		if (!page)
-			return true;
-	}
+	if (linear)
+		return spd_fill_linear_page(spd, page, offset, len,  sk);
+
 	if (spd_can_coalesce(spd, page, offset)) {
 		spd->partial[spd->nr_pages - 1].len += *len;
 		return false;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index bbf40b999713..956fd6103909 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -27,23 +27,25 @@  static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
 		 int elem_first_coalesce)
 {
-	struct page_frag *pfrag = sk_page_frag(sk);
+	struct page_frag_cache *pfrag = sk_page_frag(sk);
 	u32 osize = msg->sg.size;
 	int ret = 0;
 
 	len -= msg->sg.size;
 	while (len > 0) {
+		unsigned int frag_offset, frag_len;
 		struct scatterlist *sge;
-		u32 orig_offset;
+		struct page *page;
 		int use, i;
 
-		if (!sk_page_frag_refill(sk, pfrag)) {
+		page = sk_page_frag_alloc_pg_prepare(sk, pfrag, &frag_offset,
+						     &frag_len);
+		if (!page) {
 			ret = -ENOMEM;
 			goto msg_trim;
 		}
 
-		orig_offset = pfrag->offset;
-		use = min_t(int, len, pfrag->size - orig_offset);
+		use = min_t(int, len, frag_len);
 		if (!sk_wmem_schedule(sk, use)) {
 			ret = -ENOMEM;
 			goto msg_trim;
@@ -54,9 +56,10 @@  int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
 		sge = &msg->sg.data[i];
 
 		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
-		    sg_page(sge) == pfrag->page &&
-		    sge->offset + sge->length == orig_offset) {
+		    sg_page(sge) == page &&
+		    sge->offset + sge->length == frag_offset) {
 			sge->length += use;
+			page_frag_alloc_commit_noref(pfrag, use);
 		} else {
 			if (sk_msg_full(msg)) {
 				ret = -ENOSPC;
@@ -65,14 +68,13 @@  int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
 
 			sge = &msg->sg.data[msg->sg.end];
 			sg_unmark_end(sge);
-			sg_set_page(sge, pfrag->page, use, orig_offset);
-			get_page(pfrag->page);
+			sg_set_page(sge, page, use, frag_offset);
+			page_frag_alloc_commit(pfrag, use);
 			sk_msg_iter_next(msg, end);
 		}
 
 		sk_mem_charge(sk, use);
 		msg->sg.size += use;
-		pfrag->offset += use;
 		len -= use;
 	}
 
diff --git a/net/core/sock.c b/net/core/sock.c
index 9abc4fe25953..26c100ee9001 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2207,10 +2207,7 @@  static void __sk_destruct(struct rcu_head *head)
 		pr_debug("%s: optmem leakage (%d bytes) detected\n",
 			 __func__, atomic_read(&sk->sk_omem_alloc));
 
-	if (sk->sk_frag.page) {
-		put_page(sk->sk_frag.page);
-		sk->sk_frag.page = NULL;
-	}
+	page_frag_cache_drain(&sk->sk_frag);
 
 	/* We do not need to acquire sk->sk_peer_lock, we are the last user. */
 	put_cred(sk->sk_peer_cred);
@@ -2956,16 +2953,43 @@  bool skb_page_frag_refill(unsigned int sz, struct page_frag *pfrag, gfp_t gfp)
 }
 EXPORT_SYMBOL(skb_page_frag_refill);
 
-bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
+struct page *sk_page_frag_alloc_prepare(struct sock *sk,
+					struct page_frag_cache *pfrag,
+					unsigned int *offset,
+					unsigned int *size, void **va)
 {
-	if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation)))
-		return true;
+	struct page *page;
+
+	*size = 32U;
+	page = page_frag_alloc_prepare(pfrag, offset, size, va,
+				       sk->sk_allocation);
+	if (likely(page))
+		return page;
 
 	sk_enter_memory_pressure(sk);
 	sk_stream_moderate_sndbuf(sk);
-	return false;
+	return NULL;
+}
+EXPORT_SYMBOL(sk_page_frag_alloc_prepare);
+
+struct page *sk_page_frag_alloc_pg_prepare(struct sock *sk,
+					   struct page_frag_cache *pfrag,
+					   unsigned int *offset,
+					   unsigned int *size)
+{
+	struct page *page;
+
+	*size = 32U;
+	page = page_frag_alloc_pg_prepare(pfrag, offset, size,
+					  sk->sk_allocation);
+	if (likely(page))
+		return page;
+
+	sk_enter_memory_pressure(sk);
+	sk_stream_moderate_sndbuf(sk);
+	return NULL;
 }
-EXPORT_SYMBOL(sk_page_frag_refill);
+EXPORT_SYMBOL(sk_page_frag_alloc_pg_prepare);
 
 void __lock_sock(struct sock *sk)
 	__releases(&sk->sk_lock.slock)
@@ -3487,8 +3511,8 @@  void sock_init_data_uid(struct socket *sock, struct sock *sk, kuid_t uid)
 	sk->sk_error_report	=	sock_def_error_report;
 	sk->sk_destruct		=	sock_def_destruct;
 
-	sk->sk_frag.page	=	NULL;
-	sk->sk_frag.offset	=	0;
+	page_frag_cache_init(&sk->sk_frag);
+
 	sk->sk_peek_off		=	-1;
 
 	sk->sk_peer_pid 	=	NULL;
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 8a10a7c67834..57499e3ed9e5 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -952,7 +952,7 @@  static int __ip_append_data(struct sock *sk,
 			    struct flowi4 *fl4,
 			    struct sk_buff_head *queue,
 			    struct inet_cork *cork,
-			    struct page_frag *pfrag,
+			    struct page_frag_cache *pfrag,
 			    int getfrag(void *from, char *to, int offset,
 					int len, int odd, struct sk_buff *skb),
 			    void *from, int length, int transhdrlen,
@@ -1228,31 +1228,38 @@  static int __ip_append_data(struct sock *sk,
 			wmem_alloc_delta += copy;
 		} else if (!zc) {
 			int i = skb_shinfo(skb)->nr_frags;
+			unsigned int frag_offset, frag_size;
+			struct page *page;
+			void *va;
 
 			err = -ENOMEM;
-			if (!sk_page_frag_refill(sk, pfrag))
+			page = sk_page_frag_alloc_prepare(sk, pfrag,
+							  &frag_offset,
+							  &frag_size, &va);
+			if (!page)
 				goto error;
 
 			skb_zcopy_downgrade_managed(skb);
-			if (!skb_can_coalesce(skb, i, pfrag->page,
-					      pfrag->offset)) {
+			copy = min_t(int, copy, frag_size);
+
+			if (!skb_can_coalesce(skb, i, page, frag_offset)) {
 				err = -EMSGSIZE;
 				if (i == MAX_SKB_FRAGS)
 					goto error;
 
-				__skb_fill_page_desc(skb, i, pfrag->page,
-						     pfrag->offset, 0);
+				__skb_fill_page_desc(skb, i, page, frag_offset,
+						     copy);
 				skb_shinfo(skb)->nr_frags = ++i;
-				get_page(pfrag->page);
+				page_frag_alloc_commit(pfrag, copy);
+			} else {
+				skb_frag_size_add(
+					&skb_shinfo(skb)->frags[i - 1], copy);
+				page_frag_alloc_commit_noref(pfrag, copy);
 			}
-			copy = min_t(int, copy, pfrag->size - pfrag->offset);
-			if (getfrag(from,
-				    page_address(pfrag->page) + pfrag->offset,
-				    offset, copy, skb->len, skb) < 0)
+
+			if (getfrag(from, va, offset, copy, skb->len, skb) < 0)
 				goto error_efault;
 
-			pfrag->offset += copy;
-			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
 			skb_len_add(skb, copy);
 			wmem_alloc_delta += copy;
 		} else {
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 7c392710ae15..815ec53b16d5 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -1189,13 +1189,17 @@  int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
 		if (zc == 0) {
 			bool merge = true;
 			int i = skb_shinfo(skb)->nr_frags;
-			struct page_frag *pfrag = sk_page_frag(sk);
-
-			if (!sk_page_frag_refill(sk, pfrag))
+			struct page_frag_cache *pfrag = sk_page_frag(sk);
+			unsigned int frag_offset, frag_size;
+			struct page *page;
+			void *va;
+
+			page = sk_page_frag_alloc_prepare(sk, pfrag, &frag_offset,
+							  &frag_size, &va);
+			if (!page)
 				goto wait_for_space;
 
-			if (!skb_can_coalesce(skb, i, pfrag->page,
-					      pfrag->offset)) {
+			if (!skb_can_coalesce(skb, i, page, frag_offset)) {
 				if (i >= READ_ONCE(net_hotdata.sysctl_max_skb_frags)) {
 					tcp_mark_push(tp, skb);
 					goto new_segment;
@@ -1203,7 +1207,7 @@  int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
 				merge = false;
 			}
 
-			copy = min_t(int, copy, pfrag->size - pfrag->offset);
+			copy = min_t(int, copy, frag_size);
 
 			if (unlikely(skb_zcopy_pure(skb) || skb_zcopy_managed(skb))) {
 				if (tcp_downgrade_zcopy_pure(sk, skb))
@@ -1216,20 +1220,18 @@  int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
 				goto wait_for_space;
 
 			err = skb_copy_to_va_nocache(sk, &msg->msg_iter, skb,
-						     page_address(pfrag->page) +
-						     pfrag->offset, copy);
+						     va, copy);
 			if (err)
 				goto do_error;
 
 			/* Update the skb. */
 			if (merge) {
 				skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
+				page_frag_alloc_commit_noref(pfrag, copy);
 			} else {
-				skb_fill_page_desc(skb, i, pfrag->page,
-						   pfrag->offset, copy);
-				page_ref_inc(pfrag->page);
+				skb_fill_page_desc(skb, i, page, frag_offset, copy);
+				page_frag_alloc_commit(pfrag, copy);
 			}
-			pfrag->offset += copy;
 		} else if (zc == MSG_ZEROCOPY)  {
 			/* First append to a fragless skb builds initial
 			 * pure zerocopy skb
@@ -3131,11 +3133,7 @@  int tcp_disconnect(struct sock *sk, int flags)
 
 	WARN_ON(inet->inet_num && !icsk->icsk_bind_hash);
 
-	if (sk->sk_frag.page) {
-		put_page(sk->sk_frag.page);
-		sk->sk_frag.page = NULL;
-		sk->sk_frag.offset = 0;
-	}
+	page_frag_cache_drain(&sk->sk_frag);
 	sk_error_report(sk);
 	return 0;
 }
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 16c48df8df4c..43208092b89c 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3970,9 +3970,12 @@  static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn)
 	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_fastopen_request *fo = tp->fastopen_req;
-	struct page_frag *pfrag = sk_page_frag(sk);
+	struct page_frag_cache *pfrag = sk_page_frag(sk);
+	unsigned int offset, size;
 	struct sk_buff *syn_data;
 	int space, err = 0;
+	struct page *page;
+	void *va;
 
 	tp->rx_opt.mss_clamp = tp->advmss;  /* If MSS is not cached */
 	if (!tcp_fastopen_cookie_check(sk, &tp->rx_opt.mss_clamp, &fo->cookie))
@@ -3991,30 +3994,31 @@  static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn)
 
 	space = min_t(size_t, space, fo->size);
 
-	if (space &&
-	    !skb_page_frag_refill(min_t(size_t, space, PAGE_SIZE),
-				  pfrag, sk->sk_allocation))
-		goto fallback;
+	if (space) {
+		size = min_t(size_t, space, PAGE_SIZE);
+		page = page_frag_alloc_prepare(pfrag, &offset, &size, &va,
+					       sk->sk_allocation);
+		if (!page)
+			goto fallback;
+	}
+
 	syn_data = tcp_stream_alloc_skb(sk, sk->sk_allocation, false);
 	if (!syn_data)
 		goto fallback;
 	memcpy(syn_data->cb, syn->cb, sizeof(syn->cb));
 	if (space) {
-		space = min_t(size_t, space, pfrag->size - pfrag->offset);
+		space = min_t(size_t, space, size);
 		space = tcp_wmem_schedule(sk, space);
 	}
 	if (space) {
-		space = copy_page_from_iter(pfrag->page, pfrag->offset,
-					    space, &fo->data->msg_iter);
+		space = _copy_from_iter(va, space, &fo->data->msg_iter);
 		if (unlikely(!space)) {
 			tcp_skb_tsorted_anchor_cleanup(syn_data);
 			kfree_skb(syn_data);
 			goto fallback;
 		}
-		skb_fill_page_desc(syn_data, 0, pfrag->page,
-				   pfrag->offset, space);
-		page_ref_inc(pfrag->page);
-		pfrag->offset += space;
+		skb_fill_page_desc(syn_data, 0, page, offset, space);
+		page_frag_alloc_commit(pfrag, space);
 		skb_len_add(syn_data, space);
 		skb_zcopy_set(syn_data, fo->uarg, NULL);
 	}
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index ab504d31f0cd..86086b4bb55c 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -1405,7 +1405,7 @@  static int __ip6_append_data(struct sock *sk,
 			     struct sk_buff_head *queue,
 			     struct inet_cork_full *cork_full,
 			     struct inet6_cork *v6_cork,
-			     struct page_frag *pfrag,
+			     struct page_frag_cache *pfrag,
 			     int getfrag(void *from, char *to, int offset,
 					 int len, int odd, struct sk_buff *skb),
 			     void *from, size_t length, int transhdrlen,
@@ -1746,32 +1746,39 @@  static int __ip6_append_data(struct sock *sk,
 			copy = err;
 			wmem_alloc_delta += copy;
 		} else if (!zc) {
+			unsigned int frag_offset, frag_size;
 			int i = skb_shinfo(skb)->nr_frags;
+			struct page *page;
+			void *va;
 
 			err = -ENOMEM;
-			if (!sk_page_frag_refill(sk, pfrag))
+			page = sk_page_frag_alloc_prepare(sk, pfrag,
+							  &frag_offset,
+							  &frag_size, &va);
+			if (!page)
 				goto error;
 
 			skb_zcopy_downgrade_managed(skb);
-			if (!skb_can_coalesce(skb, i, pfrag->page,
-					      pfrag->offset)) {
+			copy = min_t(int, copy, frag_size);
+
+			if (!skb_can_coalesce(skb, i, page, frag_offset)) {
 				err = -EMSGSIZE;
 				if (i == MAX_SKB_FRAGS)
 					goto error;
 
-				__skb_fill_page_desc(skb, i, pfrag->page,
-						     pfrag->offset, 0);
+				__skb_fill_page_desc(skb, i, page, frag_offset,
+						     copy);
 				skb_shinfo(skb)->nr_frags = ++i;
-				get_page(pfrag->page);
+				page_frag_alloc_commit(pfrag, copy);
+			} else {
+				skb_frag_size_add(
+					&skb_shinfo(skb)->frags[i - 1], copy);
+				page_frag_alloc_commit_noref(pfrag, copy);
 			}
-			copy = min_t(int, copy, pfrag->size - pfrag->offset);
-			if (getfrag(from,
-				    page_address(pfrag->page) + pfrag->offset,
-				    offset, copy, skb->len, skb) < 0)
+
+			if (getfrag(from, va, offset, copy, skb->len, skb) < 0)
 				goto error_efault;
 
-			pfrag->offset += copy;
-			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
 			skb->len += copy;
 			skb->data_len += copy;
 			skb->truesize += copy;
diff --git a/net/kcm/kcmsock.c b/net/kcm/kcmsock.c
index eec6c56b7f3e..e52ddf716fa5 100644
--- a/net/kcm/kcmsock.c
+++ b/net/kcm/kcmsock.c
@@ -803,13 +803,17 @@  static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 	while (msg_data_left(msg)) {
 		bool merge = true;
 		int i = skb_shinfo(skb)->nr_frags;
-		struct page_frag *pfrag = sk_page_frag(sk);
-
-		if (!sk_page_frag_refill(sk, pfrag))
+		struct page_frag_cache *pfrag = sk_page_frag(sk);
+		unsigned int offset, size;
+		struct page *page;
+		void *va;
+
+		page = sk_page_frag_alloc_prepare(sk, pfrag, &offset, &size,
+						  &va);
+		if (!page)
 			goto wait_for_memory;
 
-		if (!skb_can_coalesce(skb, i, pfrag->page,
-				      pfrag->offset)) {
+		if (!skb_can_coalesce(skb, i, page, offset)) {
 			if (i == MAX_SKB_FRAGS) {
 				struct sk_buff *tskb;
 
@@ -850,14 +854,12 @@  static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 			if (head != skb)
 				head->truesize += copy;
 		} else {
-			copy = min_t(int, msg_data_left(msg),
-				     pfrag->size - pfrag->offset);
+			copy = min_t(int, msg_data_left(msg), size);
 			if (!sk_wmem_schedule(sk, copy))
 				goto wait_for_memory;
 
 			err = skb_copy_to_va_nocache(sk, &msg->msg_iter, skb,
-						     page_address(pfrag->page) +
-						     pfrag->offset, copy);
+						     va, copy);
 			if (err)
 				goto out_error;
 
@@ -865,13 +867,12 @@  static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 			if (merge) {
 				skb_frag_size_add(
 					&skb_shinfo(skb)->frags[i - 1], copy);
+				page_frag_alloc_commit_noref(pfrag, copy);
 			} else {
-				skb_fill_page_desc(skb, i, pfrag->page,
-						   pfrag->offset, copy);
-				get_page(pfrag->page);
+				skb_fill_page_desc(skb, i, page, offset, copy);
+				page_frag_alloc_commit(pfrag, copy);
 			}
 
-			pfrag->offset += copy;
 		}
 
 		copied += copy;
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 0d536b183a6c..3d27ede2c781 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -960,17 +960,16 @@  static bool mptcp_skb_can_collapse_to(u64 write_seq,
 }
 
 /* we can append data to the given data frag if:
- * - there is space available in the backing page_frag
- * - the data frag tail matches the current page_frag free offset
+ * - the data frag tail matches the current page and offset
  * - the data frag end sequence number matches the current write seq
  */
 static bool mptcp_frag_can_collapse_to(const struct mptcp_sock *msk,
-				       const struct page_frag *pfrag,
+				       const struct page *page,
+				       const unsigned int offset,
 				       const struct mptcp_data_frag *df)
 {
-	return df && pfrag->page == df->page &&
-		pfrag->size - pfrag->offset > 0 &&
-		pfrag->offset == (df->offset + df->data_len) &&
+	return df && page == df->page &&
+		offset == (df->offset + df->data_len) &&
 		df->data_seq + df->data_len == msk->write_seq;
 }
 
@@ -1085,30 +1084,36 @@  static void mptcp_enter_memory_pressure(struct sock *sk)
 /* ensure we get enough memory for the frag hdr, beyond some minimal amount of
  * data
  */
-static bool mptcp_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
+static struct page *mptcp_page_frag_alloc_prepare(struct sock *sk,
+						  struct page_frag_cache *pfrag,
+						  unsigned int *offset,
+						  unsigned int *size, void **va)
 {
-	if (likely(skb_page_frag_refill(32U + sizeof(struct mptcp_data_frag),
-					pfrag, sk->sk_allocation)))
-		return true;
+	struct page *page;
+
+	page = page_frag_alloc_prepare(pfrag, offset, size, va,
+				       sk->sk_allocation);
+	if (likely(page))
+		return page;
 
 	mptcp_enter_memory_pressure(sk);
-	return false;
+	return NULL;
 }
 
 static struct mptcp_data_frag *
-mptcp_carve_data_frag(const struct mptcp_sock *msk, struct page_frag *pfrag,
-		      int orig_offset)
+mptcp_carve_data_frag(const struct mptcp_sock *msk, struct page *page,
+		      unsigned int orig_offset)
 {
 	int offset = ALIGN(orig_offset, sizeof(long));
 	struct mptcp_data_frag *dfrag;
 
-	dfrag = (struct mptcp_data_frag *)(page_to_virt(pfrag->page) + offset);
+	dfrag = (struct mptcp_data_frag *)(page_to_virt(page) + offset);
 	dfrag->data_len = 0;
 	dfrag->data_seq = msk->write_seq;
 	dfrag->overhead = offset - orig_offset + sizeof(struct mptcp_data_frag);
 	dfrag->offset = offset + sizeof(struct mptcp_data_frag);
 	dfrag->already_sent = 0;
-	dfrag->page = pfrag->page;
+	dfrag->page = page;
 
 	return dfrag;
 }
@@ -1795,7 +1800,7 @@  static u32 mptcp_send_limit(const struct sock *sk)
 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
-	struct page_frag *pfrag;
+	struct page_frag_cache *pfrag;
 	size_t copied = 0;
 	int ret = 0;
 	long timeo;
@@ -1834,9 +1839,12 @@  static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 	while (msg_data_left(msg)) {
 		int total_ts, frag_truesize = 0;
 		struct mptcp_data_frag *dfrag;
+		unsigned int offset = 0, size;
 		bool dfrag_collapsed;
-		size_t psize, offset;
+		struct page *page;
 		u32 copy_limit;
+		size_t psize;
+		void *va;
 
 		/* ensure fitting the notsent_lowat() constraint */
 		copy_limit = mptcp_send_limit(sk);
@@ -1847,21 +1855,27 @@  static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 		 * page allocator
 		 */
 		dfrag = mptcp_pending_tail(sk);
-		dfrag_collapsed = mptcp_frag_can_collapse_to(msk, pfrag, dfrag);
+		size = 1;
+		page = page_frag_alloc_probe(pfrag, &offset, &size, &va);
+		dfrag_collapsed = mptcp_frag_can_collapse_to(msk, page, offset,
+							     dfrag);
 		if (!dfrag_collapsed) {
-			if (!mptcp_page_frag_refill(sk, pfrag))
+			size = 32U + sizeof(struct mptcp_data_frag);
+			page = mptcp_page_frag_alloc_prepare(sk, pfrag, &offset,
+							     &size, &va);
+			if (!page)
 				goto wait_for_memory;
 
-			dfrag = mptcp_carve_data_frag(msk, pfrag, pfrag->offset);
+			dfrag = mptcp_carve_data_frag(msk, page, offset);
 			frag_truesize = dfrag->overhead;
+			va += dfrag->overhead;
 		}
 
 		/* we do not bound vs wspace, to allow a single packet.
		 * memory accounting will prevent excessive memory usage
 		 * anyway
 		 */
-		offset = dfrag->offset + dfrag->data_len;
-		psize = pfrag->size - offset;
+		psize = size - frag_truesize;
 		psize = min_t(size_t, psize, msg_data_left(msg));
 		psize = min_t(size_t, psize, copy_limit);
 		total_ts = psize + frag_truesize;
@@ -1869,8 +1883,7 @@  static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 		if (!sk_wmem_schedule(sk, total_ts))
 			goto wait_for_memory;
 
-		ret = do_copy_data_nocache(sk, psize, &msg->msg_iter,
-					   page_address(dfrag->page) + offset);
+		ret = do_copy_data_nocache(sk, psize, &msg->msg_iter, va);
 		if (ret)
 			goto do_error;
 
@@ -1879,7 +1892,6 @@  static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 		copied += psize;
 		dfrag->data_len += psize;
 		frag_truesize += psize;
-		pfrag->offset += frag_truesize;
 		WRITE_ONCE(msk->write_seq, msk->write_seq + psize);
 
 		/* charge data on mptcp pending queue to the msk socket
@@ -1887,11 +1899,14 @@  static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 		 */
 		sk_wmem_queued_add(sk, frag_truesize);
 		if (!dfrag_collapsed) {
-			get_page(dfrag->page);
+			page_frag_alloc_commit(pfrag, frag_truesize);
 			list_add_tail(&dfrag->list, &msk->rtx_queue);
 			if (!msk->first_pending)
 				WRITE_ONCE(msk->first_pending, dfrag);
+		} else {
+			page_frag_alloc_commit_noref(pfrag, frag_truesize);
 		}
+
 		pr_debug("msk=%p dfrag at seq=%llu len=%u sent=%u new=%d", msk,
 			 dfrag->data_seq, dfrag->data_len, dfrag->already_sent,
 			 !dfrag_collapsed);
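
[ Note: mptcp is the only caller here that also uses the probe
  helper: it peeks at the cache's current page and offset to test
  whether the pending dfrag can simply be extended, before committing
  to preparing fresh space. A condensed sketch of that flow, assuming
  (as the usage above implies) that probe reserves nothing and only
  reports the current position when at least *size bytes remain: ]

	size = 1;	/* probe: any room left on the current page? */
	page = page_frag_alloc_probe(pfrag, &offset, &size, &va);
	dfrag_collapsed = mptcp_frag_can_collapse_to(msk, page, offset,
						     dfrag);
	if (!dfrag_collapsed) {
		/* room for the dfrag header plus a minimal payload */
		size = 32U + sizeof(struct mptcp_data_frag);
		page = mptcp_page_frag_alloc_prepare(sk, pfrag, &offset,
						     &size, &va);
		dfrag = mptcp_carve_data_frag(msk, page, offset);
		va += dfrag->overhead;	/* payload starts past the header */
	}

	/* ... copy psize bytes of payload to va ... */

	if (!dfrag_collapsed)
		/* header + data committed; rtx_queue keeps the page ref */
		page_frag_alloc_commit(pfrag, frag_truesize);
	else
		/* dfrag already references this page: advance only */
		page_frag_alloc_commit_noref(pfrag, frag_truesize);
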
diff --git a/net/sched/em_meta.c b/net/sched/em_meta.c
index 8996c73c9779..4da465af972f 100644
--- a/net/sched/em_meta.c
+++ b/net/sched/em_meta.c
@@ -590,7 +590,7 @@  META_COLLECTOR(int_sk_sendmsg_off)
 		*err = -1;
 		return;
 	}
-	dst->value = sk->sk_frag.offset;
+	dst->value = page_frag_cache_page_offset(&sk->sk_frag);
 }
 
 META_COLLECTOR(int_sk_write_pend)
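
[ Note on the tls_device.c conversion below: it splits the old
  tls_append_frag() into two helpers because the record tag
  placeholder can now come from two sources: bytes carved from the
  cache (tracked via commit/commit_noref) and the preallocated
  dummy_page fallback, whose reference count is managed directly with
  get_page(). A condensed sketch of that split, with names as used in
  this patch: ]

	size = prot->tag_size;
	page = page_frag_alloc_pg_prepare(pfrag, &offset, &size,
					  sk->sk_allocation);
	if (unlikely(!page))
		/* no memory: point the frag at dummy_page; the device
		 * overwrites the placeholder bytes, so their content
		 * does not matter
		 */
		tls_append_page(record, dummy_page, 0, prot->tag_size);
	else
		/* normal path: coalesce or fill a new frag, then
		 * commit_noref/commit through the cache accordingly
		 */
		tls_append_frag(record, pfrag, page, offset, prot->tag_size);
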
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index dc063c2c7950..02925c25ae12 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -253,24 +253,42 @@  static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
 }
 
 static void tls_append_frag(struct tls_record_info *record,
-			    struct page_frag *pfrag,
-			    int size)
+			    struct page_frag_cache *pfrag, struct page *page,
+			    unsigned int offset, unsigned int size)
 {
 	skb_frag_t *frag;
 
 	frag = &record->frags[record->num_frags - 1];
-	if (skb_frag_page(frag) == pfrag->page &&
-	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
+	if (skb_frag_page(frag) == page &&
+	    skb_frag_off(frag) + skb_frag_size(frag) == offset) {
 		skb_frag_size_add(frag, size);
+		page_frag_alloc_commit_noref(pfrag, size);
 	} else {
 		++frag;
-		skb_frag_fill_page_desc(frag, pfrag->page, pfrag->offset,
-					size);
+		skb_frag_fill_page_desc(frag, page, offset, size);
 		++record->num_frags;
-		get_page(pfrag->page);
+		page_frag_alloc_commit(pfrag, size);
+	}
+
+	record->len += size;
+}
+
+static void tls_append_page(struct tls_record_info *record, struct page *page,
+			    unsigned int offset, unsigned int size)
+{
+	skb_frag_t *frag;
+
+	frag = &record->frags[record->num_frags - 1];
+	if (skb_frag_page(frag) == page &&
+	    skb_frag_off(frag) + skb_frag_size(frag) == offset) {
+		skb_frag_size_add(frag, size);
+	} else {
+		++frag;
+		skb_frag_fill_page_desc(frag, page, offset, size);
+		++record->num_frags;
+		get_page(page);
 	}
 
-	pfrag->offset += size;
 	record->len += size;
 }
 
@@ -311,11 +329,12 @@  static int tls_push_record(struct sock *sk,
 static void tls_device_record_close(struct sock *sk,
 				    struct tls_context *ctx,
 				    struct tls_record_info *record,
-				    struct page_frag *pfrag,
+				    struct page_frag_cache *pfrag,
 				    unsigned char record_type)
 {
 	struct tls_prot_info *prot = &ctx->prot_info;
-	struct page_frag dummy_tag_frag;
+	unsigned int offset, size;
+	struct page *page;
 
 	/* append tag
 	 * device will fill in the tag, we just need to append a placeholder
@@ -323,13 +342,14 @@  static void tls_device_record_close(struct sock *sk,
 	 * increases frag count)
 	 * if we can't allocate memory now use the dummy page
 	 */
-	if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) &&
-	    !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) {
-		dummy_tag_frag.page = dummy_page;
-		dummy_tag_frag.offset = 0;
-		pfrag = &dummy_tag_frag;
+	size = prot->tag_size;
+	page = page_frag_alloc_pg_prepare(pfrag, &offset, &size,
+					  sk->sk_allocation);
+	if (unlikely(!page)) {
+		tls_append_page(record, dummy_page, 0, prot->tag_size);
+	} else {
+		tls_append_frag(record, pfrag, page, offset, prot->tag_size);
 	}
-	tls_append_frag(record, pfrag, prot->tag_size);
 
 	/* fill prepend */
 	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
@@ -337,57 +357,52 @@  static void tls_device_record_close(struct sock *sk,
 			 record_type);
 }
 
-static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
-				 struct page_frag *pfrag,
+static int tls_create_new_record(struct sock *sk,
+				 struct tls_offload_context_tx *offload_ctx,
+				 struct page_frag_cache *pfrag,
 				 size_t prepend_size)
 {
 	struct tls_record_info *record;
+	unsigned int offset;
+	struct page *page;
 	skb_frag_t *frag;
 
 	record = kmalloc(sizeof(*record), GFP_KERNEL);
 	if (!record)
 		return -ENOMEM;
 
-	frag = &record->frags[0];
-	skb_frag_fill_page_desc(frag, pfrag->page, pfrag->offset,
-				prepend_size);
-
-	get_page(pfrag->page);
-	pfrag->offset += prepend_size;
+	page = page_frag_alloc_pg(pfrag, &offset, prepend_size,
+				  sk->sk_allocation);
+	if (!page) {
+		kfree(record);
+		READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
+		sk_stream_moderate_sndbuf(sk);
+		return -ENOMEM;
+	}
 
+	frag = &record->frags[0];
+	skb_frag_fill_page_desc(frag, page, offset, prepend_size);
 	record->num_frags = 1;
 	record->len = prepend_size;
 	offload_ctx->open_record = record;
 	return 0;
 }
 
-static int tls_do_allocation(struct sock *sk,
-			     struct tls_offload_context_tx *offload_ctx,
-			     struct page_frag *pfrag,
-			     size_t prepend_size)
+static struct page *tls_do_allocation(struct sock *sk,
+				      struct tls_offload_context_tx *ctx,
+				      struct page_frag_cache *pfrag,
+				      size_t prepend_size, unsigned int *offset,
+				      unsigned int *size, void **va)
 {
-	int ret;
-
-	if (!offload_ctx->open_record) {
-		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
-						   sk->sk_allocation))) {
-			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
-			sk_stream_moderate_sndbuf(sk);
-			return -ENOMEM;
-		}
+	if (!ctx->open_record) {
+		int ret;
 
-		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
+		ret = tls_create_new_record(sk, ctx, pfrag, prepend_size);
 		if (ret)
-			return ret;
-
-		if (pfrag->size > pfrag->offset)
-			return 0;
+			return NULL;
 	}
 
-	if (!sk_page_frag_refill(sk, pfrag))
-		return -ENOMEM;
-
-	return 0;
+	return sk_page_frag_alloc_prepare(sk, pfrag, offset, size, va);
 }
 
 static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
@@ -424,8 +439,8 @@  static int tls_push_data(struct sock *sk,
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
 	struct tls_record_info *record;
+	struct page_frag_cache *pfrag;
 	int tls_push_record_flags;
-	struct page_frag *pfrag;
 	size_t orig_size = size;
 	u32 max_open_record_len;
 	bool more = false;
@@ -462,8 +477,13 @@  static int tls_push_data(struct sock *sk,
 	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
 			      prot->prepend_size;
 	do {
-		rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
-		if (unlikely(rc)) {
+		unsigned int frag_offset, frag_size;
+		struct page *page;
+		void *va;
+
+		page = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size,
+					 &frag_offset, &frag_size, &va);
+		if (unlikely(!page)) {
 			rc = sk_stream_wait_memory(sk, &timeo);
 			if (!rc)
 				continue;
@@ -491,8 +511,8 @@  static int tls_push_data(struct sock *sk,
 
 		copy = min_t(size_t, size, max_open_record_len - record->len);
 		if (copy && (flags & MSG_SPLICE_PAGES)) {
-			struct page_frag zc_pfrag;
-			struct page **pages = &zc_pfrag.page;
+			struct page *splice_page;
+			struct page **pages = &splice_page;
 			size_t off;
 
 			rc = iov_iter_extract_pages(iter, &pages,
@@ -504,24 +524,21 @@  static int tls_push_data(struct sock *sk,
 			}
 			copy = rc;
 
-			if (WARN_ON_ONCE(!sendpage_ok(zc_pfrag.page))) {
+			if (WARN_ON_ONCE(!sendpage_ok(splice_page))) {
 				iov_iter_revert(iter, copy);
 				rc = -EIO;
 				goto handle_error;
 			}
 
-			zc_pfrag.offset = off;
-			zc_pfrag.size = copy;
-			tls_append_frag(record, &zc_pfrag, copy);
+			tls_append_page(record, splice_page, off, copy);
 		} else if (copy) {
-			copy = min_t(size_t, copy, pfrag->size - pfrag->offset);
+			copy = min_t(size_t, copy, frag_size);
 
-			rc = tls_device_copy_data(page_address(pfrag->page) +
-						  pfrag->offset, copy,
-						  iter);
+			rc = tls_device_copy_data(va, copy, iter);
 			if (rc)
 				goto handle_error;
-			tls_append_frag(record, pfrag, copy);
+
+			tls_append_frag(record, pfrag, page, frag_offset, copy);
 		}
 
 		size -= copy;