diff mbox series

[v2,1/1] lib/string: Add strscpy_pad() function

Message ID 20190225041534.27186-2-tobin@kernel.org (mailing list archive)
State New, archived
Headers show
Series lib/string: Add strscpy_pad() function | expand

Commit Message

Tobin C. Harding Feb. 25, 2019, 4:15 a.m. UTC
We have a function to copy strings safely and we have a function to copy
strings and zero the tail of the destination (if source string is
shorter than destination buffer) but we do not have a function to do
both at once.  This means developers must write this themselves if they
desire this functionality.  This is a chore, and also leaves us open to
off by one errors unnecessarily.

Add a function that calls strscpy() then memset()s the tail to zero if
the source string is shorter than the destination buffer.

Add test module for the new code.

Signed-off-by: Tobin C. Harding <tobin@kernel.org>
---
 include/linux/string.h |   4 +
 lib/Kconfig.debug      |   3 +
 lib/Makefile           |   1 +
 lib/string.c           |  47 +++++++++--
 lib/test_strscpy.c     | 175 +++++++++++++++++++++++++++++++++++++++++
 5 files changed, 223 insertions(+), 7 deletions(-)
 create mode 100644 lib/test_strscpy.c

Comments

Andy Shevchenko Feb. 25, 2019, 8:19 a.m. UTC | #1
On Mon, Feb 25, 2019 at 6:17 AM Tobin C. Harding <tobin@kernel.org> wrote:
>
> We have a function to copy strings safely and we have a function to copy
> strings and zero the tail of the destination (if source string is
> shorter than destination buffer) but we do not have a function to do
> both at once.  This means developers must write this themselves if they
> desire this functionality.  This is a chore, and also leaves us open to
> off by one errors unnecessarily.
>
> Add a function that calls strscpy() then memset()s the tail to zero if
> the source string is shorter than the destination buffer.
>
> Add test module for the new code.

> --- /dev/null
> +++ b/lib/test_strscpy.c
> @@ -0,0 +1,175 @@

> +// SPDX-License-Identifier: GPL-2.0

> +MODULE_LICENSE("GPL");

License mismatch.

Do we need a separate module for this test?
Tobin Harding Feb. 25, 2019, 9:31 p.m. UTC | #2
On Mon, Feb 25, 2019 at 10:19:47AM +0200, Andy Shevchenko wrote:
> On Mon, Feb 25, 2019 at 6:17 AM Tobin C. Harding <tobin@kernel.org> wrote:
> >
> > We have a function to copy strings safely and we have a function to copy
> > strings and zero the tail of the destination (if source string is
> > shorter than destination buffer) but we do not have a function to do
> > both at once.  This means developers must write this themselves if they
> > desire this functionality.  This is a chore, and also leaves us open to
> > off by one errors unnecessarily.
> >
> > Add a function that calls strscpy() then memset()s the tail to zero if
> > the source string is shorter than the destination buffer.
> >
> > Add test module for the new code.
> 
> > --- /dev/null
> > +++ b/lib/test_strscpy.c
> > @@ -0,0 +1,175 @@
> 
> > +// SPDX-License-Identifier: GPL-2.0
> 
> > +MODULE_LICENSE("GPL");
> 
> License mismatch.

Thanks, will re-spin with 

  // SPDX-License-Identifier: GPL-2.0+

> Do we need a separate module for this test?

Separate as in not in lib/test_string.h?  I intend on moving the test
into that file once I've done some cleanup in tools/testing/selftest/lib/

I also tried to do this without using a module using
tools/testing/selftest/kselftest_harness.h but I could not get the
compiler to see read the patched version of
linux/include/linux/string.h?

Related question if you feel like answering it; why are test modules for
lib/ in lib/ and not in tools/testing/?

Very much open to suggestions on current best practices for kernel testing.

thanks,
Tobin.
Tobin Harding Feb. 25, 2019, 9:37 p.m. UTC | #3
On Tue, Feb 26, 2019 at 08:31:09AM +1100, Tobin C. Harding wrote:
> On Mon, Feb 25, 2019 at 10:19:47AM +0200, Andy Shevchenko wrote:
> > On Mon, Feb 25, 2019 at 6:17 AM Tobin C. Harding <tobin@kernel.org> wrote:
> > >
> > > We have a function to copy strings safely and we have a function to copy
> > > strings and zero the tail of the destination (if source string is
> > > shorter than destination buffer) but we do not have a function to do
> > > both at once.  This means developers must write this themselves if they
> > > desire this functionality.  This is a chore, and also leaves us open to
> > > off by one errors unnecessarily.
> > >
> > > Add a function that calls strscpy() then memset()s the tail to zero if
> > > the source string is shorter than the destination buffer.
> > >
> > > Add test module for the new code.
> > 
> > > --- /dev/null
> > > +++ b/lib/test_strscpy.c
> > > @@ -0,0 +1,175 @@
> > 
> > > +// SPDX-License-Identifier: GPL-2.0
> > 
> > > +MODULE_LICENSE("GPL");
> > 
> > License mismatch.
> 
> Thanks, will re-spin with 
> 
>   // SPDX-License-Identifier: GPL-2.0+
> 
> > Do we need a separate module for this test?
> 
> Separate as in not in lib/test_string.h?  I intend on moving the test
> into that file once I've done some cleanup in tools/testing/selftest/lib/
> 
> I also tried to do this without using a module using
> tools/testing/selftest/kselftest_harness.h but I could not get the
> compiler to see read the patched version of
> linux/include/linux/string.h?
> 
> Related question if you feel like answering it; why are test modules for
> lib/ in lib/ and not in tools/testing/?
> 
> Very much open to suggestions on current best practices for kernel testing.

I just read Andy mirskFrom: "Tobin C. Harding" <me@tobin.cc>
To: Andy Shevchenko <andy.shevchenko@gmail.com>
Cc: "Tobin C. Harding" <tobin@kernel.org>,
	Kees Cook <keescook@chromium.org>, Jann Horn <jannh@google.com>,
	Andy Shevchenko <andriy.shevchenko@linux.intel.com>,
	Randy Dunlap <rdunlap@infradead.org>,
	Rasmus Villemoes <linux@rasmusvillemoes.dk>,
	Stephen Rothwell <sfr@canb.auug.org.au>,
	Andy Lutomirski <luto@amacapital.net>,
	Daniel Micay <danielmicay@gmail.com>, Arnd Bergmann <arnd@arndb.de>,
	Miguel Ojeda <miguel.ojeda.sandonis@gmail.com>,
	"Gustavo A. R. Silva" <gustavo@embeddedor.com>,
	Shuah Khan <shuah@kernel.org>,
	Greg Kroah-Hartman <gregkh@linuxfoundation.org>,
	Alexander Shishkin <alexander.shishkin@linux.intel.com>,
	Kernel Hardening <kernel-hardening@lists.openwall.com>,
	Linux Kernel Mailing List <linux-kernel@vger.kernel.org>
Bcc: 
Subject: Re: [PATCH v2 1/1] lib/string: Add strscpy_pad() function
Reply-To: 
In-Reply-To: <20190225213109.GB5177@eros.localdomain>
X-Mailer: Mutt 1.11.3 (2019-02-01)

On Tue, Feb 26, 2019 at 08:31:09AM +1100, Tobin C. Harding wrote:
> On Mon, Feb 25, 2019 at 10:19:47AM +0200, Andy Shevchenko wrote:
> > On Mon, Feb 25, 2019 at 6:17 AM Tobin C. Harding <tobin@kernel.org> wrote:
> > >
> > > We have a function to copy strings safely and we have a function to copy
> > > strings and zero the tail of the destination (if source string is
> > > shorter than destination buffer) but we do not have a function to do
> > > both at once.  This means developers must write this themselves if they
> > > desire this functionality.  This is a chore, and also leaves us open to
> > > off by one errors unnecessarily.
> > >
> > > Add a function that calls strscpy() then memset()s the tail to zero if
> > > the source string is shorter than the destination buffer.
> > >
> > > Add test module for the new code.
> > 
> > > --- /dev/null
> > > +++ b/lib/test_strscpy.c
> > > @@ -0,0 +1,175 @@
> > 
> > > +// SPDX-License-Identifier: GPL-2.0
> > 
> > > +MODULE_LICENSE("GPL");
> > 
> > License mismatch.
> 
> Thanks, will re-spin with 
> 
>   // SPDX-License-Identifier: GPL-2.0+
> 
> > Do we need a separate module for this test?
> 
> Separate as in not in lib/test_string.h?  I intend on moving the test
> into that file once I've done some cleanup in tools/testing/selftest/lib/
> 
> I also tried to do this without using a module using
> tools/testing/selftest/kselftest_harness.h but I could not get the
> compiler to see read the patched version of
> linux/include/linux/string.h?
> 
> Related question if you feel like answering it; why are test modules for
> lib/ in lib/ and not in tools/testing/?
> 
> Very much open to suggestions on current best practices for kernel testing.

I just read Andy mirskFrom: "Tobin C. Harding" <me@tobin.cc>
To: Andy Shevchenko <andy.shevchenko@gmail.com>
Cc: "Tobin C. Harding" <tobin@kernel.org>,
	Kees Cook <keescook@chromium.org>, Jann Horn <jannh@google.com>,
	Andy Shevchenko <andriy.shevchenko@linux.intel.com>,
	Randy Dunlap <rdunlap@infradead.org>,
	Rasmus Villemoes <linux@rasmusvillemoes.dk>,
	Stephen Rothwell <sfr@canb.auug.org.au>,
	Andy Lutomirski <luto@amacapital.net>,
	Daniel Micay <danielmicay@gmail.com>, Arnd Bergmann <arnd@arndb.de>,
	Miguel Ojeda <miguel.ojeda.sandonis@gmail.com>,
	"Gustavo A. R. Silva" <gustavo@embeddedor.com>,
	Shuah Khan <shuah@kernel.org>,
	Greg Kroah-Hartman <gregkh@linuxfoundation.org>,
	Alexander Shishkin <alexander.shishkin@linux.intel.com>,
	Kernel Hardening <kernel-hardening@lists.openwall.com>,
	Linux Kernel Mailing List <linux-kernel@vger.kernel.org>
Bcc: 
Subject: Re: [PATCH v2 1/1] lib/string: Add strscpy_pad() function
Reply-To: 
In-Reply-To: <20190225213109.GB5177@eros.localdomain>
X-Mailer: Mutt 1.11.3 (2019-02-01)

On Tue, Feb 26, 2019 at 08:31:09AM +1100, Tobin C. Harding wrote:
> On Mon, Feb 25, 2019 at 10:19:47AM +0200, Andy Shevchenko wrote:
> > On Mon, Feb 25, 2019 at 6:17 AM Tobin C. Harding <tobin@kernel.org> wrote:
> > >
> > > We have a function to copy strings safely and we have a function to copy
> > > strings and zero the tail of the destination (if source string is
> > > shorter than destination buffer) but we do not have a function to do
> > > both at once.  This means developers must write this themselves if they
> > > desire this functionality.  This is a chore, and also leaves us open to
> > > off by one errors unnecessarily.
> > >
> > > Add a function that calls strscpy() then memset()s the tail to zero if
> > > the source string is shorter than the destination buffer.
> > >
> > > Add test module for the new code.
> > 
> > > --- /dev/null
> > > +++ b/lib/test_strscpy.c
> > > @@ -0,0 +1,175 @@
> > 
> > > +// SPDX-License-Identifier: GPL-2.0
> > 
> > > +MODULE_LICENSE("GPL");
> > 
> > License mismatch.
> 
> Thanks, will re-spin with 
> 
>   // SPDX-License-Identifier: GPL-2.0+
> 
> > Do we need a separate module for this test?
> 
> Separate as in not in lib/test_string.h?  I intend on moving the test
> into that file once I've done some cleanup in tools/testing/selftest/lib/
> 
> I also tried to do this without using a module using
> tools/testing/selftest/kselftest_harness.h but I could not get the
> compiler to see read the patched version of
> linux/include/linux/string.h?
> 
> Related question if you feel like answering it; why are test modules for
> lib/ in lib/ and not in tools/testing/?
> 
> Very much open to suggestions on current best practices for kernel testing.

I just read Andy Lutomirski's patch

       [PATCH 2/2] uaccess: Add a selftest for strncpy_from_user()

That's a better approach for this one also, right?  Put the test in
string.c and ifdef guard it with a config option.

thanks,
Tobin.
Kees Cook Feb. 25, 2019, 9:38 p.m. UTC | #4
On Sun, Feb 24, 2019 at 8:16 PM Tobin C. Harding <tobin@kernel.org> wrote:
>
> We have a function to copy strings safely and we have a function to copy
> strings and zero the tail of the destination (if source string is
> shorter than destination buffer) but we do not have a function to do
> both at once.  This means developers must write this themselves if they
> desire this functionality.  This is a chore, and also leaves us open to
> off by one errors unnecessarily.
>
> Add a function that calls strscpy() then memset()s the tail to zero if
> the source string is shorter than the destination buffer.
>
> Add test module for the new code.
>
> Signed-off-by: Tobin C. Harding <tobin@kernel.org>
> [...]
> +ssize_t strscpy_pad(char *dest, const char *src, size_t count)
> +{
> +       ssize_t written;
> +
> +       written = strscpy(dest, src, count);
> +       if (written < 0 || written == count - 1)
> +               return written;

*thread merge* Yeah, good point. written will be -E2BIG for both count
= 0 and count = 1.

> +
> +       memset(dest + written + 1, 0, count - written - 1);
> +
> +       return written;
> +}
> +EXPORT_SYMBOL(strscpy_pad);
> +
>  #ifndef __HAVE_ARCH_STRCAT
>  /**
>   * strcat - Append one %NUL-terminated string to another
> diff --git a/lib/test_strscpy.c b/lib/test_strscpy.c
> new file mode 100644
> index 000000000000..5ec6a196f4e2
> --- /dev/null
> +++ b/lib/test_strscpy.c
> @@ -0,0 +1,175 @@
> +// SPDX-License-Identifier: GPL-2.0
> +
> +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
> +
> +#include <linux/init.h>
> +#include <linux/kernel.h>
> +#include <linux/module.h>
> +#include <linux/printk.h>
> +#include <linux/string.h>
> +
> +/*
> + * Kernel module for testing 'strscpy' family of functions.
> + */
> +
> +static unsigned total_tests __initdata;
> +static unsigned failed_tests __initdata;
> +
> +static void __init do_test(int count, char *src, int expected,
> +                          int chars, int terminator, int pad)
> +{
> +       char buf[6];

I would make this "6" a define, since you use it in more than one
place. Actually... no, never mind. The other places can use
"sizeof(buf)" instead of "6". Noted below...

But I'd add an explicit check for "expected + 1 < sizeof(buf)" just
for future test-addition sanity.

> +       int written;
> +       int poison;
> +       int index;
> +       int i;
> +       const char POISON = 'z';
> +
> +       total_tests++;
> +       memset(buf, POISON, sizeof(buf));
> +
> +       /* Verify the return value */
> +

Needless blank line.

> +       written = strscpy_pad(buf, src, count);
> +       if ((written) != (expected)) {
> +               pr_err("%d != %d (written, expected)\n", written, expected);
> +               goto fail;
> +       }
> +
> +       /* Verify the state of the buffer */
> +

Same.

> +       if (count && written == -E2BIG) {
> +               if (strncmp(buf, src, count - 1) != 0) {
> +                       pr_err("buffer state invalid for -E2BIG\n");
> +                       goto fail;
> +               }
> +               if (buf[count - 1] != '\0') {
> +                       pr_err("too big string is not null terminated correctly\n");
> +                       goto fail;
> +               }
> +       }
> +
> +       /* Verify the copied content */
> +       for (i = 0; i < chars; i++) {
> +               if (buf[i] != src[i]) {
> +                       pr_err("buf[i]==%c != src[i]==%c\n", buf[i], src[i]);
> +                       goto fail;
> +               }
> +       }
> +
> +       /* Verify the null terminator */
> +       if (terminator) {
> +               if (buf[count - 1] != '\0') {
> +                       pr_err("string is not null terminated correctly\n");
> +                       goto fail;
> +               }
> +       }
> +
> +       /* Verify the padding */
> +       for (i = 0; i < pad; i++) {
> +               index = chars + terminator + i;
> +               if (buf[index] != '\0') {
> +                       pr_err("padding missing at index: %d\n", i);
> +                       goto fail;
> +               }
> +       }
> +
> +       /* Verify the rest is left untouched */
> +       poison = 6 - chars - terminator - pad;

instead of "6", use "sizeof(buf)".

> +       for (i = 0; i < poison; i++) {
> +               index = 6 - 1 - i; /* Check from the end back */

Same.

> +               if (buf[index] != POISON) {
> +                       pr_err("poison value missing at index: %d\n", i);
> +                       goto fail;
> +               }
> +       }
> +
> +       return;
> +fail:
> +       pr_info("%s(%d, '%s', %d, %d, %d, %d)\n", __func__,
> +               count, src, expected, chars, terminator, pad);
> +       failed_tests++;
> +}
> +
> +static void __init test_fully(void)
> +{
> +       /* do_test(count, src, expected, chars, terminator, pad) */
> +
> +       do_test(0, "a", -E2BIG, 0, 0, 0);
> +       do_test(0, "", -E2BIG, 0, 0, 0);
> +
> +       do_test(1, "a", -E2BIG, 0, 1, 0);
> +       do_test(1, "", 0, 0, 1, 0);
> +
> +       do_test(2, "ab", -E2BIG, 1, 1, 0);
> +       do_test(2, "a", 1, 1, 1, 0);
> +       do_test(2, "", 0, 0, 1, 1);
> +
> +       do_test(3, "abc", -E2BIG, 2, 1, 0);
> +       do_test(3, "ab", 2, 2, 1, 0);
> +       do_test(3, "a", 1, 1, 1, 1);
> +       do_test(3, "", 0, 0, 1, 2);
> +
> +       do_test(4, "abcd", -E2BIG, 3, 1, 0);
> +       do_test(4, "abc", 3, 3, 1, 0);
> +       do_test(4, "ab", 2, 2, 1, 1);
> +       do_test(4, "a", 1, 1, 1, 2);
> +       do_test(4, "", 0, 0, 1, 3);
> +}
> +
> +static void __init test_basic(void)
> +{
> +       char buf[6];
> +       int written;
> +
> +       memset(buf, 'a', sizeof(buf));
> +
> +       total_tests++;
> +       written = strscpy_pad(buf, "bb", 4);
> +       if (written != 2)
> +               failed_tests++;
> +
> +       /* Correctly copied */
> +       total_tests++;
> +       if (buf[0] != 'b' || buf[1] != 'b')
> +               failed_tests++;
> +
> +       /* Correctly padded */
> +       total_tests++;
> +       if (buf[2] != '\0' || buf[3] != '\0')
> +               failed_tests++;
> +
> +       /* Only touched what it was supposed to */
> +       total_tests++;
> +       if (buf[4] != 'a' || buf[5] != 'a')
> +               failed_tests++;
> +}

I don't think you need to keep "test_basic". Anything it tests can be
rewritten in terms of do_test().

> +
> +static int __init test_strscpy_init(void)
> +{
> +       pr_info("loaded.\n");
> +
> +       test_basic();
> +       if (failed_tests)
> +               goto out;
> +
> +       test_fully();
> +
> +out:
> +       if (failed_tests == 0)
> +               pr_info("all %u tests passed\n", total_tests);
> +       else
> +               pr_warn("failed %u out of %u tests\n", failed_tests, total_tests);
> +
> +       return failed_tests ? -EINVAL : 0;
> +}
> +module_init(test_strscpy_init);
> +
> +static void __exit test_strscpy_exit(void)
> +{
> +       pr_info("unloaded.\n");
> +}
> +module_exit(test_strscpy_exit);
> +
> +MODULE_AUTHOR("Tobin C. Harding <tobin@kernel.org>");
> +MODULE_LICENSE("GPL");
> --
> 2.20.1
>

Otherwise, yes, looks good!
Tobin Harding Feb. 27, 2019, 4:40 a.m. UTC | #5
On Mon, Feb 25, 2019 at 01:38:44PM -0800, Kees Cook wrote:
> On Sun, Feb 24, 2019 at 8:16 PM Tobin C. Harding <tobin@kernel.org> wrote:
> >
> > We have a function to copy strings safely and we have a function to copy
> > strings and zero the tail of the destination (if source string is
> > shorter than destination buffer) but we do not have a function to do
> > both at once.  This means developers must write this themselves if they
> > desire this functionality.  This is a chore, and also leaves us open to
> > off by one errors unnecessarily.
> >
> > Add a function that calls strscpy() then memset()s the tail to zero if
> > the source string is shorter than the destination buffer.
> >
> > Add test module for the new code.
> >
> > Signed-off-by: Tobin C. Harding <tobin@kernel.org>
> > [...]
> > +ssize_t strscpy_pad(char *dest, const char *src, size_t count)
> > +{
> > +       ssize_t written;
> > +
> > +       written = strscpy(dest, src, count);
> > +       if (written < 0 || written == count - 1)
> > +               return written;
> 
> *thread merge* Yeah, good point. written will be -E2BIG for both count
> = 0 and count = 1.
> 
> > +
> > +       memset(dest + written + 1, 0, count - written - 1);
> > +
> > +       return written;
> > +}
> > +EXPORT_SYMBOL(strscpy_pad);
> > +
> >  #ifndef __HAVE_ARCH_STRCAT
> >  /**
> >   * strcat - Append one %NUL-terminated string to another
> > diff --git a/lib/test_strscpy.c b/lib/test_strscpy.c
> > new file mode 100644
> > index 000000000000..5ec6a196f4e2
> > --- /dev/null
> > +++ b/lib/test_strscpy.c
> > @@ -0,0 +1,175 @@
> > +// SPDX-License-Identifier: GPL-2.0
> > +
> > +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
> > +
> > +#include <linux/init.h>
> > +#include <linux/kernel.h>
> > +#include <linux/module.h>
> > +#include <linux/printk.h>
> > +#include <linux/string.h>
> > +
> > +/*
> > + * Kernel module for testing 'strscpy' family of functions.
> > + */
> > +
> > +static unsigned total_tests __initdata;
> > +static unsigned failed_tests __initdata;
> > +
> > +static void __init do_test(int count, char *src, int expected,
> > +                          int chars, int terminator, int pad)
> > +{
> > +       char buf[6];
> 
> I would make this "6" a define, since you use it in more than one
> place. Actually... no, never mind. The other places can use
> "sizeof(buf)" instead of "6". Noted below...
> 
> But I'd add an explicit check for "expected + 1 < sizeof(buf)" just
> for future test-addition sanity.
> 
> > +       int written;
> > +       int poison;
> > +       int index;
> > +       int i;
> > +       const char POISON = 'z';
> > +
> > +       total_tests++;
> > +       memset(buf, POISON, sizeof(buf));
> > +
> > +       /* Verify the return value */
> > +
> 
> Needless blank line.
> 
> > +       written = strscpy_pad(buf, src, count);
> > +       if ((written) != (expected)) {
> > +               pr_err("%d != %d (written, expected)\n", written, expected);
> > +               goto fail;
> > +       }
> > +
> > +       /* Verify the state of the buffer */
> > +
> 
> Same.
> 
> > +       if (count && written == -E2BIG) {
> > +               if (strncmp(buf, src, count - 1) != 0) {
> > +                       pr_err("buffer state invalid for -E2BIG\n");
> > +                       goto fail;
> > +               }
> > +               if (buf[count - 1] != '\0') {
> > +                       pr_err("too big string is not null terminated correctly\n");
> > +                       goto fail;
> > +               }
> > +       }
> > +
> > +       /* Verify the copied content */
> > +       for (i = 0; i < chars; i++) {
> > +               if (buf[i] != src[i]) {
> > +                       pr_err("buf[i]==%c != src[i]==%c\n", buf[i], src[i]);
> > +                       goto fail;
> > +               }
> > +       }
> > +
> > +       /* Verify the null terminator */
> > +       if (terminator) {
> > +               if (buf[count - 1] != '\0') {
> > +                       pr_err("string is not null terminated correctly\n");
> > +                       goto fail;
> > +               }
> > +       }
> > +
> > +       /* Verify the padding */
> > +       for (i = 0; i < pad; i++) {
> > +               index = chars + terminator + i;
> > +               if (buf[index] != '\0') {
> > +                       pr_err("padding missing at index: %d\n", i);
> > +                       goto fail;
> > +               }
> > +       }
> > +
> > +       /* Verify the rest is left untouched */
> > +       poison = 6 - chars - terminator - pad;
> 
> instead of "6", use "sizeof(buf)".
> 
> > +       for (i = 0; i < poison; i++) {
> > +               index = 6 - 1 - i; /* Check from the end back */
> 
> Same.
> 
> > +               if (buf[index] != POISON) {
> > +                       pr_err("poison value missing at index: %d\n", i);
> > +                       goto fail;
> > +               }
> > +       }
> > +
> > +       return;
> > +fail:
> > +       pr_info("%s(%d, '%s', %d, %d, %d, %d)\n", __func__,
> > +               count, src, expected, chars, terminator, pad);
> > +       failed_tests++;
> > +}
> > +
> > +static void __init test_fully(void)
> > +{
> > +       /* do_test(count, src, expected, chars, terminator, pad) */
> > +
> > +       do_test(0, "a", -E2BIG, 0, 0, 0);
> > +       do_test(0, "", -E2BIG, 0, 0, 0);
> > +
> > +       do_test(1, "a", -E2BIG, 0, 1, 0);
> > +       do_test(1, "", 0, 0, 1, 0);
> > +
> > +       do_test(2, "ab", -E2BIG, 1, 1, 0);
> > +       do_test(2, "a", 1, 1, 1, 0);
> > +       do_test(2, "", 0, 0, 1, 1);
> > +
> > +       do_test(3, "abc", -E2BIG, 2, 1, 0);
> > +       do_test(3, "ab", 2, 2, 1, 0);
> > +       do_test(3, "a", 1, 1, 1, 1);
> > +       do_test(3, "", 0, 0, 1, 2);
> > +
> > +       do_test(4, "abcd", -E2BIG, 3, 1, 0);
> > +       do_test(4, "abc", 3, 3, 1, 0);
> > +       do_test(4, "ab", 2, 2, 1, 1);
> > +       do_test(4, "a", 1, 1, 1, 2);
> > +       do_test(4, "", 0, 0, 1, 3);
> > +}
> > +
> > +static void __init test_basic(void)
> > +{
> > +       char buf[6];
> > +       int written;
> > +
> > +       memset(buf, 'a', sizeof(buf));
> > +
> > +       total_tests++;
> > +       written = strscpy_pad(buf, "bb", 4);
> > +       if (written != 2)
> > +               failed_tests++;
> > +
> > +       /* Correctly copied */
> > +       total_tests++;
> > +       if (buf[0] != 'b' || buf[1] != 'b')
> > +               failed_tests++;
> > +
> > +       /* Correctly padded */
> > +       total_tests++;
> > +       if (buf[2] != '\0' || buf[3] != '\0')
> > +               failed_tests++;
> > +
> > +       /* Only touched what it was supposed to */
> > +       total_tests++;
> > +       if (buf[4] != 'a' || buf[5] != 'a')
> > +               failed_tests++;
> > +}
> 
> I don't think you need to keep "test_basic". Anything it tests can be
> rewritten in terms of do_test().
> 
> > +
> > +static int __init test_strscpy_init(void)
> > +{
> > +       pr_info("loaded.\n");
> > +
> > +       test_basic();
> > +       if (failed_tests)
> > +               goto out;
> > +
> > +       test_fully();
> > +
> > +out:
> > +       if (failed_tests == 0)
> > +               pr_info("all %u tests passed\n", total_tests);
> > +       else
> > +               pr_warn("failed %u out of %u tests\n", failed_tests, total_tests);
> > +
> > +       return failed_tests ? -EINVAL : 0;
> > +}
> > +module_init(test_strscpy_init);
> > +
> > +static void __exit test_strscpy_exit(void)
> > +{
> > +       pr_info("unloaded.\n");
> > +}
> > +module_exit(test_strscpy_exit);
> > +
> > +MODULE_AUTHOR("Tobin C. Harding <tobin@kernel.org>");
> > +MODULE_LICENSE("GPL");
> > --
> > 2.20.1
> >
> 
> Otherwise, yes, looks good!

Cool, thanks.  Will fix up as suggested and re-spin.

	Tobin
diff mbox series

Patch

diff --git a/include/linux/string.h b/include/linux/string.h
index 7927b875f80c..bfe95bf5d07e 100644
--- a/include/linux/string.h
+++ b/include/linux/string.h
@@ -31,6 +31,10 @@  size_t strlcpy(char *, const char *, size_t);
 #ifndef __HAVE_ARCH_STRSCPY
 ssize_t strscpy(char *, const char *, size_t);
 #endif
+
+/* Wraps calls to strscpy()/memset(), no arch specific code required */
+ssize_t strscpy_pad(char *dest, const char *src, size_t count);
+
 #ifndef __HAVE_ARCH_STRCAT
 extern char * strcat(char *, const char *);
 #endif
diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug
index d4df5b24d75e..fb629a0c6272 100644
--- a/lib/Kconfig.debug
+++ b/lib/Kconfig.debug
@@ -1805,6 +1805,9 @@  config TEST_HEXDUMP
 config TEST_STRING_HELPERS
 	tristate "Test functions located in the string_helpers module at runtime"
 
+config TEST_STRSCPY
+	tristate "Test strscpy*() family of functions  at runtime"
+
 config TEST_KSTRTOX
 	tristate "Test kstrto*() family of functions at runtime"
 
diff --git a/lib/Makefile b/lib/Makefile
index e1b59da71418..59519926cbc6 100644
--- a/lib/Makefile
+++ b/lib/Makefile
@@ -42,6 +42,7 @@  obj-y += bcd.o div64.o sort.o parser.o debug_locks.o random32.o \
 obj-$(CONFIG_STRING_SELFTEST) += test_string.o
 obj-y += string_helpers.o
 obj-$(CONFIG_TEST_STRING_HELPERS) += test-string_helpers.o
+obj-$(CONFIG_TEST_STRSCPY) += test_strscpy.o
 obj-y += hexdump.o
 obj-$(CONFIG_TEST_HEXDUMP) += test_hexdump.o
 obj-y += kstrtox.o
diff --git a/lib/string.c b/lib/string.c
index 38e4ca08e757..209444cb36d6 100644
--- a/lib/string.c
+++ b/lib/string.c
@@ -159,11 +159,9 @@  EXPORT_SYMBOL(strlcpy);
  * @src: Where to copy the string from
  * @count: Size of destination buffer
  *
- * Copy the string, or as much of it as fits, into the dest buffer.
- * The routine returns the number of characters copied (not including
- * the trailing NUL) or -E2BIG if the destination buffer wasn't big enough.
- * The behavior is undefined if the string buffers overlap.
- * The destination buffer is always NUL terminated, unless it's zero-sized.
+ * Copy the string, or as much of it as fits, into the dest buffer.  The
+ * behavior is undefined if the string buffers overlap.  The destination
+ * buffer is always NUL terminated, unless it's zero-sized.
  *
  * Preferred to strlcpy() since the API doesn't require reading memory
  * from the src string beyond the specified "count" bytes, and since
@@ -173,8 +171,10 @@  EXPORT_SYMBOL(strlcpy);
  *
  * Preferred to strncpy() since it always returns a valid string, and
  * doesn't unnecessarily force the tail of the destination buffer to be
- * zeroed.  If the zeroing is desired, it's likely cleaner to use strscpy()
- * with an overflow test, then just memset() the tail of the dest buffer.
+ * zeroed.  If zeroing is desired please use strscpy_pad().
+ *
+ * Return: The number of characters copied (not including the trailing
+ *         %NUL) or -E2BIG if the destination buffer wasn't big enough.
  */
 ssize_t strscpy(char *dest, const char *src, size_t count)
 {
@@ -237,6 +237,39 @@  ssize_t strscpy(char *dest, const char *src, size_t count)
 EXPORT_SYMBOL(strscpy);
 #endif
 
+/**
+ * strscpy_pad() - Copy a C-string into a sized buffer
+ * @dest: Where to copy the string to
+ * @src: Where to copy the string from
+ * @count: Size of destination buffer
+ *
+ * Copy the string, or as much of it as fits, into the dest buffer.  The
+ * behavior is undefined if the string buffers overlap.  The destination
+ * buffer is always NUL terminated, unless it's zero-sized.
+ *
+ * If the source string is shorter than the destination buffer, zeros
+ * the tail of the destination buffer.
+ *
+ * For full explanation of why you may want to consider using the
+ * 'strscpy' functions please see the function docstring for strscpy().
+ *
+ * Return: The number of characters copied (not including the trailing
+ *         %NUL) or -E2BIG if the destination buffer wasn't big enough.
+ */
+ssize_t strscpy_pad(char *dest, const char *src, size_t count)
+{
+	ssize_t written;
+
+	written = strscpy(dest, src, count);
+	if (written < 0 || written == count - 1)
+		return written;
+
+	memset(dest + written + 1, 0, count - written - 1);
+
+	return written;
+}
+EXPORT_SYMBOL(strscpy_pad);
+
 #ifndef __HAVE_ARCH_STRCAT
 /**
  * strcat - Append one %NUL-terminated string to another
diff --git a/lib/test_strscpy.c b/lib/test_strscpy.c
new file mode 100644
index 000000000000..5ec6a196f4e2
--- /dev/null
+++ b/lib/test_strscpy.c
@@ -0,0 +1,175 @@ 
+// SPDX-License-Identifier: GPL-2.0
+
+#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
+
+#include <linux/init.h>
+#include <linux/kernel.h>
+#include <linux/module.h>
+#include <linux/printk.h>
+#include <linux/string.h>
+
+/*
+ * Kernel module for testing 'strscpy' family of functions.
+ */
+
+static unsigned total_tests __initdata;
+static unsigned failed_tests __initdata;
+
+static void __init do_test(int count, char *src, int expected,
+			   int chars, int terminator, int pad)
+{
+	char buf[6];
+	int written;
+	int poison;
+	int index;
+	int i;
+	const char POISON = 'z';
+
+	total_tests++;
+	memset(buf, POISON, sizeof(buf));
+
+	/* Verify the return value */
+
+	written = strscpy_pad(buf, src, count);
+	if ((written) != (expected)) {
+		pr_err("%d != %d (written, expected)\n", written, expected);
+		goto fail;
+	}
+
+	/* Verify the state of the buffer */
+
+	if (count && written == -E2BIG) {
+		if (strncmp(buf, src, count - 1) != 0) {
+			pr_err("buffer state invalid for -E2BIG\n");
+			goto fail;
+		}
+		if (buf[count - 1] != '\0') {
+			pr_err("too big string is not null terminated correctly\n");
+			goto fail;
+		}
+	}
+
+	/* Verify the copied content */
+	for (i = 0; i < chars; i++) {
+		if (buf[i] != src[i]) {
+			pr_err("buf[i]==%c != src[i]==%c\n", buf[i], src[i]);
+			goto fail;
+		}
+	}
+
+	/* Verify the null terminator */
+	if (terminator) {
+		if (buf[count - 1] != '\0') {
+			pr_err("string is not null terminated correctly\n");
+			goto fail;
+		}
+	}
+
+	/* Verify the padding */
+	for (i = 0; i < pad; i++) {
+		index = chars + terminator + i;
+		if (buf[index] != '\0') {
+			pr_err("padding missing at index: %d\n", i);
+			goto fail;
+		}
+	}
+
+	/* Verify the rest is left untouched */
+	poison = 6 - chars - terminator - pad;
+	for (i = 0; i < poison; i++) {
+		index = 6 - 1 - i; /* Check from the end back */
+		if (buf[index] != POISON) {
+			pr_err("poison value missing at index: %d\n", i);
+			goto fail;
+		}
+	}
+
+	return;
+fail:
+	pr_info("%s(%d, '%s', %d, %d, %d, %d)\n", __func__,
+		count, src, expected, chars, terminator, pad);
+	failed_tests++;
+}
+
+static void __init test_fully(void)
+{
+	/* do_test(count, src, expected, chars, terminator, pad) */
+
+	do_test(0, "a", -E2BIG, 0, 0, 0);
+	do_test(0, "", -E2BIG, 0, 0, 0);
+
+	do_test(1, "a", -E2BIG, 0, 1, 0);
+	do_test(1, "", 0, 0, 1, 0);
+
+	do_test(2, "ab", -E2BIG, 1, 1, 0);
+	do_test(2, "a", 1, 1, 1, 0);
+	do_test(2, "", 0, 0, 1, 1);
+
+	do_test(3, "abc", -E2BIG, 2, 1, 0);
+	do_test(3, "ab", 2, 2, 1, 0);
+	do_test(3, "a", 1, 1, 1, 1);
+	do_test(3, "", 0, 0, 1, 2);
+
+	do_test(4, "abcd", -E2BIG, 3, 1, 0);
+	do_test(4, "abc", 3, 3, 1, 0);
+	do_test(4, "ab", 2, 2, 1, 1);
+	do_test(4, "a", 1, 1, 1, 2);
+	do_test(4, "", 0, 0, 1, 3);
+}
+
+static void __init test_basic(void)
+{
+	char buf[6];
+	int written;
+
+	memset(buf, 'a', sizeof(buf));
+
+	total_tests++;
+	written = strscpy_pad(buf, "bb", 4);
+	if (written != 2)
+		failed_tests++;
+
+	/* Correctly copied */
+	total_tests++;
+	if (buf[0] != 'b' || buf[1] != 'b')
+		failed_tests++;
+
+	/* Correctly padded */
+	total_tests++;
+	if (buf[2] != '\0' || buf[3] != '\0')
+		failed_tests++;
+
+	/* Only touched what it was supposed to */
+	total_tests++;
+	if (buf[4] != 'a' || buf[5] != 'a')
+		failed_tests++;
+}
+
+static int __init test_strscpy_init(void)
+{
+	pr_info("loaded.\n");
+
+	test_basic();
+	if (failed_tests)
+		goto out;
+
+	test_fully();
+
+out:
+	if (failed_tests == 0)
+		pr_info("all %u tests passed\n", total_tests);
+	else
+		pr_warn("failed %u out of %u tests\n", failed_tests, total_tests);
+
+	return failed_tests ? -EINVAL : 0;
+}
+module_init(test_strscpy_init);
+
+static void __exit test_strscpy_exit(void)
+{
+	pr_info("unloaded.\n");
+}
+module_exit(test_strscpy_exit);
+
+MODULE_AUTHOR("Tobin C. Harding <tobin@kernel.org>");
+MODULE_LICENSE("GPL");