@@ -1504,6 +1504,7 @@ static int check_version(unsigned int cmd, struct dm_ioctl __user *user)
static int copy_params(struct dm_ioctl __user *user, struct dm_ioctl **param)
{
struct dm_ioctl tmp, *dmi;
+ int secure_data;
if (copy_from_user(&tmp, user, sizeof(tmp) - sizeof(tmp.data)))
return -EFAULT;
@@ -1511,23 +1512,27 @@ static int copy_params(struct dm_ioctl __user *user, struct dm_ioctl **param)
if (tmp.data_size < (sizeof(tmp) - sizeof(tmp.data)))
return -EINVAL;
+ secure_data = tmp.flags & DM_SECURE_DATA_FLAG;
+
dmi = vmalloc(tmp.data_size);
- if (!dmi)
+ if (!dmi) {
+ if (secure_data && clear_user(user, tmp.data_size))
+ return -EFAULT;
return -ENOMEM;
+ }
if (copy_from_user(dmi, user, tmp.data_size))
goto bad;
/* Wipe the user buffer so we do not return it to userspace */
- if ((tmp.flags & DM_SECURE_DATA_FLAG) &&
- clear_user(user, tmp.data_size))
+ if (secure_data && clear_user(user, tmp.data_size))
goto bad;
*param = dmi;
return 0;
bad:
- if (tmp.flags & DM_SECURE_DATA_FLAG)
+ if (secure_data)
memset(dmi, 0, tmp.data_size);
vfree(dmi);
return -EFAULT;