
[5/6] accel/habanalabs/gaudi2: move HMMU page tables to device memory

Message ID 20240102150654.522555-5-ogabbay@kernel.org
State New, archived
Series [1/6] accel/habanalabs: check failure of eventfd_signal

Commit Message

Oded Gabbay Jan. 2, 2024, 3:06 p.m. UTC
From: Farah Kassabri <fkassabri@habana.ai>

Currently, the HMMU page tables reside in host memory, so every page
walk causes the device to access the host over PCIe. This can degrade
PCIe bandwidth in certain scenarios.

To prevent this, move the HMMU page tables to device memory, so that a
miss transaction reads the hops from there instead of going to the
host.

Signed-off-by: Farah Kassabri <fkassabri@habana.ai>
Reviewed-by: Oded Gabbay <ogabbay@kernel.org>
Signed-off-by: Oded Gabbay <ogabbay@kernel.org>
---
 drivers/accel/habanalabs/common/habanalabs.h  |  26 ++
 drivers/accel/habanalabs/common/hw_queue.c    |  17 +
 drivers/accel/habanalabs/common/mmu/Makefile  |   2 +-
 drivers/accel/habanalabs/common/mmu/mmu.c     | 223 ++++++++++-
 drivers/accel/habanalabs/common/mmu/mmu_v1.c  | 352 +++---------------
 drivers/accel/habanalabs/common/mmu/mmu_v2.c  | 338 +++++++++++++++++
 drivers/accel/habanalabs/gaudi/gaudi.c        |   1 +
 drivers/accel/habanalabs/gaudi2/gaudi2.c      | 245 ++++++++----
 drivers/accel/habanalabs/gaudi2/gaudi2P.h     |  12 +-
 .../include/hw_ip/mmu/mmu_general.h           |   2 +
 10 files changed, 836 insertions(+), 382 deletions(-)
 create mode 100644 drivers/accel/habanalabs/common/mmu/mmu_v2.c
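
As an illustration of the change, here is a minimal, self-contained C sketch
(hypothetical names and a simplified two-hop layout, not the driver's actual
structures or API): once the hop tables live in a device-memory region, the
driver writes PTEs into that region and a miss walks the hops entirely inside
it, instead of fetching them from host memory.

#include <stdint.h>
#include <stdio.h>

#define PTE_PRESENT	0x1ull
#define ADDR_MASK	(~0xFFFull)	/* low 12 bits hold flags, as in mmu_general.h */
#define ENTRIES		512		/* 512 PTEs per hop table */

/* stand-in for the device-resident page-table region in HBM */
static uint64_t dev_mem[4 * ENTRIES];

/* write a PTE into "device memory" (what the driver does when mapping) */
static void write_pte(uint64_t pte_addr, uint64_t val)
{
	dev_mem[pte_addr / sizeof(uint64_t)] = val;
}

/* walk two hops for a virtual address; every read hits dev_mem only */
static uint64_t walk(uint64_t hop0_base, uint64_t va)
{
	uint64_t idx0 = (va >> 30) & (ENTRIES - 1);
	uint64_t idx1 = (va >> 21) & (ENTRIES - 1);
	uint64_t pte0, pte1, hop1_base;

	pte0 = dev_mem[hop0_base / sizeof(uint64_t) + idx0];
	if (!(pte0 & PTE_PRESENT))
		return UINT64_MAX;

	hop1_base = pte0 & ADDR_MASK;
	pte1 = dev_mem[hop1_base / sizeof(uint64_t) + idx1];

	return (pte1 & PTE_PRESENT) ? (pte1 & ADDR_MASK) : UINT64_MAX;
}

int main(void)
{
	uint64_t hop0 = 0, hop1 = ENTRIES * sizeof(uint64_t);
	uint64_t va = (3ull << 30) | (5ull << 21);

	/* hop0[3] points to hop1, hop1[5] points to the physical page */
	write_pte(hop0 + 3 * sizeof(uint64_t), hop1 | PTE_PRESENT);
	write_pte(hop1 + 5 * sizeof(uint64_t), 0x40000000ull | PTE_PRESENT);

	printf("pa = 0x%llx\n", (unsigned long long)walk(hop0, va));
	return 0;
}

The sketch only captures the data flow; in the actual patch the region is
carved out of HBM (prop->mmu_pgt_addr) and the driver writes the PTEs through
the PCIe BAR (gaudi2_write_pte), while the HMMU hardware performs the walk.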

Patch

diff --git a/drivers/accel/habanalabs/common/habanalabs.h b/drivers/accel/habanalabs/common/habanalabs.h
index 253873315888..7397ce86b7f0 100644
--- a/drivers/accel/habanalabs/common/habanalabs.h
+++ b/drivers/accel/habanalabs/common/habanalabs.h
@@ -443,18 +443,22 @@  enum hl_collective_mode {
  *                  a CB handle can be provided for jobs on this queue.
  *                  Otherwise, a CB address must be provided.
  * @collective_mode: collective mode of current queue
+ * @q_dram_bd_address: PQ DRAM address, used when the PQ needs to reside in DRAM.
  * @driver_only: true if only the driver is allowed to send a job to this queue,
  *               false otherwise.
  * @binned: True if the queue is binned out and should not be used
  * @supports_sync_stream: True if queue supports sync stream
+ * @dram_bd: True if the BD should be copied to DRAM, needed for a PQ which has been allocated in DRAM
  */
 struct hw_queue_properties {
 	enum hl_queue_type		type;
 	enum queue_cb_alloc_flags	cb_alloc_flags;
 	enum hl_collective_mode		collective_mode;
+	u64				q_dram_bd_address;
 	u8				driver_only;
 	u8				binned;
 	u8				supports_sync_stream;
+	u8				dram_bd;
 };
 
 /**
@@ -1052,6 +1056,8 @@  struct hl_encaps_signals_mgr {
  * @collective_mode: collective mode of current queue
  * @kernel_address: holds the queue's kernel virtual address.
  * @bus_address: holds the queue's DMA address.
+ * @pq_dram_address: holds the DRAM address where the PQ is allocated, used when dram_bd is
+ *                   true in the queue properties.
  * @pi: holds the queue's pi value.
  * @ci: holds the queue's ci value, AS CALCULATED BY THE DRIVER (not real ci).
  * @hw_queue_id: the id of the H/W queue.
@@ -1061,6 +1067,7 @@  struct hl_encaps_signals_mgr {
  * @valid: is the queue valid (we have array of 32 queues, not all of them
  *         exist).
  * @supports_sync_stream: True if queue supports sync stream
+ * @dram_bd: True if the BD should be copied to DRAM, needed for a PQ which has been allocated in DRAM
  */
 struct hl_hw_queue {
 	struct hl_cs_job			**shadow_queue;
@@ -1069,6 +1076,7 @@  struct hl_hw_queue {
 	enum hl_collective_mode			collective_mode;
 	void					*kernel_address;
 	dma_addr_t				bus_address;
+	u64					pq_dram_address;
 	u32					pi;
 	atomic_t				ci;
 	u32					hw_queue_id;
@@ -1077,6 +1085,7 @@  struct hl_hw_queue {
 	u16					int_queue_len;
 	u8					valid;
 	u8					supports_sync_stream;
+	u8					dram_bd;
 };
 
 /**
@@ -3889,6 +3898,7 @@  int hl_mmu_hr_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr, struct hl_mmu_hop_
 							struct hl_hr_mmu_funcs *hr_func);
 int hl_mmu_if_set_funcs(struct hl_device *hdev);
 void hl_mmu_v1_set_funcs(struct hl_device *hdev, struct hl_mmu_funcs *mmu);
+void hl_mmu_v2_set_funcs(struct hl_device *hdev, struct hl_mmu_funcs *mmu);
 void hl_mmu_v2_hr_set_funcs(struct hl_device *hdev, struct hl_mmu_funcs *mmu);
 int hl_mmu_va_to_pa(struct hl_ctx *ctx, u64 virt_addr, u64 *phys_addr);
 int hl_mmu_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
@@ -3896,6 +3906,22 @@  int hl_mmu_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
 u64 hl_mmu_scramble_addr(struct hl_device *hdev, u64 addr);
 u64 hl_mmu_descramble_addr(struct hl_device *hdev, u64 addr);
 bool hl_is_dram_va(struct hl_device *hdev, u64 virt_addr);
+struct pgt_info *hl_mmu_dr_get_pgt_info(struct hl_ctx *ctx, u64 hop_addr);
+void hl_mmu_dr_free_hop(struct hl_ctx *ctx, u64 hop_addr);
+void hl_mmu_dr_free_pgt_node(struct hl_ctx *ctx, struct pgt_info *pgt_info);
+u64 hl_mmu_dr_get_phys_hop0_addr(struct hl_ctx *ctx);
+u64 hl_mmu_dr_get_hop0_addr(struct hl_ctx *ctx);
+void hl_mmu_dr_write_pte(struct hl_ctx *ctx, u64 shadow_pte_addr, u64 val);
+void hl_mmu_dr_write_final_pte(struct hl_ctx *ctx, u64 shadow_pte_addr, u64 val);
+void hl_mmu_dr_clear_pte(struct hl_ctx *ctx, u64 pte_addr);
+u64 hl_mmu_dr_get_phys_addr(struct hl_ctx *ctx, u64 shadow_addr);
+void hl_mmu_dr_get_pte(struct hl_ctx *ctx, u64 hop_addr);
+int hl_mmu_dr_put_pte(struct hl_ctx *ctx, u64 hop_addr);
+u64 hl_mmu_dr_get_alloc_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte, bool *is_new_hop);
+u64 hl_mmu_dr_alloc_hop(struct hl_ctx *ctx);
+void hl_mmu_dr_flush(struct hl_ctx *ctx);
+int hl_mmu_dr_init(struct hl_device *hdev);
+void hl_mmu_dr_fini(struct hl_device *hdev);
 
 int hl_fw_load_fw_to_device(struct hl_device *hdev, const char *fw_name,
 				void __iomem *dst, u32 src_offset, u32 size);
diff --git a/drivers/accel/habanalabs/common/hw_queue.c b/drivers/accel/habanalabs/common/hw_queue.c
index d0087c0ec48c..3d04a7507cce 100644
--- a/drivers/accel/habanalabs/common/hw_queue.c
+++ b/drivers/accel/habanalabs/common/hw_queue.c
@@ -84,6 +84,8 @@  void hl_hw_queue_submit_bd(struct hl_device *hdev, struct hl_hw_queue *q,
 		u32 ctl, u32 len, u64 ptr)
 {
 	struct hl_bd *bd;
+	u64 addr;
+	int i;
 
 	bd = q->kernel_address;
 	bd += hl_pi_2_offset(q->pi);
@@ -91,7 +93,16 @@  void hl_hw_queue_submit_bd(struct hl_device *hdev, struct hl_hw_queue *q,
 	bd->len = cpu_to_le32(len);
 	bd->ptr = cpu_to_le64(ptr);
 
+	if (q->dram_bd)
+		for (i = 0 ; i < 2 ; i++) {
+			addr = q->pq_dram_address +
+			((hl_pi_2_offset(q->pi) * sizeof(struct hl_bd))	+ (i * sizeof(u64)));
+			hdev->asic_funcs->access_dev_mem(hdev, PCI_REGION_DRAM,	addr,
+						(u64 *)(bd) + i, DEBUGFS_WRITE64);
+		}
+
 	q->pi = hl_queue_inc_ptr(q->pi);
+
 	hdev->asic_funcs->ring_doorbell(hdev, q->hw_queue_id, q->pi);
 }
 
@@ -1087,12 +1098,18 @@  int hl_hw_queues_create(struct hl_device *hdev)
 		q->supports_sync_stream =
 				asic->hw_queues_props[i].supports_sync_stream;
 		q->collective_mode = asic->hw_queues_props[i].collective_mode;
+		q->dram_bd = asic->hw_queues_props[i].dram_bd;
+
 		rc = queue_init(hdev, q, i);
 		if (rc) {
 			dev_err(hdev->dev,
 				"failed to initialize queue %d\n", i);
 			goto release_queues;
 		}
+
+		/* Set DRAM PQ address for the queue if it should be at DRAM */
+		if (q->dram_bd)
+			q->pq_dram_address = asic->hw_queues_props[i].q_dram_bd_address;
 	}
 
 	return 0;
diff --git a/drivers/accel/habanalabs/common/mmu/Makefile b/drivers/accel/habanalabs/common/mmu/Makefile
index 1806c524e04a..f4b815bf4f7d 100644
--- a/drivers/accel/habanalabs/common/mmu/Makefile
+++ b/drivers/accel/habanalabs/common/mmu/Makefile
@@ -1,3 +1,3 @@ 
 # SPDX-License-Identifier: GPL-2.0-only
 HL_COMMON_MMU_FILES := common/mmu/mmu.o common/mmu/mmu_v1.o \
-			common/mmu/mmu_v2_hr.o
+			common/mmu/mmu_v2.o common/mmu/mmu_v2_hr.o
diff --git a/drivers/accel/habanalabs/common/mmu/mmu.c b/drivers/accel/habanalabs/common/mmu/mmu.c
index b654302a68fc..fa7919dba783 100644
--- a/drivers/accel/habanalabs/common/mmu/mmu.c
+++ b/drivers/accel/habanalabs/common/mmu/mmu.c
@@ -585,6 +585,8 @@  int hl_mmu_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
 
 int hl_mmu_if_set_funcs(struct hl_device *hdev)
 {
+	struct asic_fixed_properties *prop = &hdev->asic_prop;
+
 	if (hdev->mmu_disable)
 		return 0;
 
@@ -597,8 +599,9 @@  int hl_mmu_if_set_funcs(struct hl_device *hdev)
 	case ASIC_GAUDI2:
 	case ASIC_GAUDI2B:
 	case ASIC_GAUDI2C:
-		/* MMUs in Gaudi2 are always host resident */
-		hl_mmu_v2_hr_set_funcs(hdev, &hdev->mmu_func[MMU_HR_PGT]);
+		hl_mmu_v2_set_funcs(hdev, &hdev->mmu_func[MMU_DR_PGT]);
+		if (prop->pmmu.host_resident)
+			hl_mmu_v2_hr_set_funcs(hdev, &hdev->mmu_func[MMU_HR_PGT]);
 		break;
 	default:
 		dev_err(hdev->dev, "Unrecognized ASIC type %d\n",
@@ -1209,3 +1212,219 @@  int hl_mmu_hr_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr, struct hl_mmu_hop_
 	return 0;
 }
 
+struct pgt_info *hl_mmu_dr_get_pgt_info(struct hl_ctx *ctx, u64 hop_addr)
+{
+	struct pgt_info *pgt_info = NULL;
+
+	hash_for_each_possible(ctx->mmu_shadow_hash, pgt_info, node,
+			(unsigned long) hop_addr)
+		if (hop_addr == pgt_info->shadow_addr)
+			break;
+
+	return pgt_info;
+}
+
+void hl_mmu_dr_free_hop(struct hl_ctx *ctx, u64 hop_addr)
+{
+	struct pgt_info *pgt_info = hl_mmu_dr_get_pgt_info(ctx, hop_addr);
+
+	hl_mmu_dr_free_pgt_node(ctx, pgt_info);
+}
+
+void hl_mmu_dr_free_pgt_node(struct hl_ctx *ctx, struct pgt_info *pgt_info)
+{
+	struct hl_device *hdev = ctx->hdev;
+
+	gen_pool_free(hdev->mmu_priv.dr.mmu_pgt_pool, pgt_info->phys_addr,
+			hdev->asic_prop.mmu_hop_table_size);
+	hash_del(&pgt_info->node);
+	kfree((u64 *) (uintptr_t) pgt_info->shadow_addr);
+	kfree(pgt_info);
+}
+
+u64 hl_mmu_dr_get_phys_hop0_addr(struct hl_ctx *ctx)
+{
+	return ctx->hdev->asic_prop.mmu_pgt_addr +
+			(ctx->asid * ctx->hdev->asic_prop.mmu_hop_table_size);
+}
+
+u64 hl_mmu_dr_get_hop0_addr(struct hl_ctx *ctx)
+{
+	return (u64) (uintptr_t) ctx->hdev->mmu_priv.dr.mmu_shadow_hop0 +
+			(ctx->asid * ctx->hdev->asic_prop.mmu_hop_table_size);
+}
+
+u64 hl_mmu_dr_get_phys_addr(struct hl_ctx *ctx, u64 shadow_addr)
+{
+	u64 page_mask = ctx->hdev->asic_prop.mmu_hop_table_size - 1;
+	u64 shadow_hop_addr = shadow_addr & (~page_mask);
+	u64 pte_offset = shadow_addr & page_mask;
+	u64 phys_hop_addr;
+
+	if (shadow_hop_addr != hl_mmu_dr_get_hop0_addr(ctx))
+		phys_hop_addr = hl_mmu_dr_get_pgt_info(ctx, shadow_hop_addr)->phys_addr;
+	else
+		phys_hop_addr = hl_mmu_dr_get_phys_hop0_addr(ctx);
+
+	return phys_hop_addr + pte_offset;
+}
+
+void hl_mmu_dr_write_pte(struct hl_ctx *ctx, u64 shadow_pte_addr, u64 val)
+{
+	u64 phys_val = hl_mmu_dr_get_phys_addr(ctx, val);
+
+	ctx->hdev->asic_funcs->write_pte(ctx->hdev, hl_mmu_dr_get_phys_addr(ctx, shadow_pte_addr),
+					phys_val);
+
+	*(u64 *) (uintptr_t) shadow_pte_addr = val;
+}
+
+void hl_mmu_dr_write_final_pte(struct hl_ctx *ctx, u64 shadow_pte_addr, u64 val)
+{
+	ctx->hdev->asic_funcs->write_pte(ctx->hdev,
+				hl_mmu_dr_get_phys_addr(ctx, shadow_pte_addr), val);
+	*(u64 *) (uintptr_t) shadow_pte_addr = val;
+}
+
+void hl_mmu_dr_clear_pte(struct hl_ctx *ctx, u64 pte_addr)
+{
+	hl_mmu_dr_write_final_pte(ctx, pte_addr, 0);
+}
+
+void hl_mmu_dr_get_pte(struct hl_ctx *ctx, u64 hop_addr)
+{
+	hl_mmu_dr_get_pgt_info(ctx, hop_addr)->num_of_ptes++;
+}
+
+int hl_mmu_dr_put_pte(struct hl_ctx *ctx, u64 hop_addr)
+{
+	struct pgt_info *pgt_info = hl_mmu_dr_get_pgt_info(ctx, hop_addr);
+	int num_of_ptes_left;
+
+	pgt_info->num_of_ptes--;
+
+	/*
+	 * Need to save the number of ptes left because hl_mmu_free_hop might free
+	 * the pgt_info
+	 */
+	num_of_ptes_left = pgt_info->num_of_ptes;
+	if (!num_of_ptes_left)
+		hl_mmu_dr_free_pgt_node(ctx, pgt_info);
+
+	return num_of_ptes_left;
+}
+
+u64 hl_mmu_dr_alloc_hop(struct hl_ctx *ctx)
+{
+	struct hl_device *hdev = ctx->hdev;
+	struct asic_fixed_properties *prop = &hdev->asic_prop;
+	struct pgt_info *pgt_info;
+	u64 phys_addr, shadow_addr;
+
+	pgt_info = kmalloc(sizeof(*pgt_info), GFP_KERNEL);
+	if (!pgt_info)
+		return ULLONG_MAX;
+
+	phys_addr = (u64) gen_pool_alloc(hdev->mmu_priv.dr.mmu_pgt_pool,
+					prop->mmu_hop_table_size);
+	if (!phys_addr) {
+		dev_err(hdev->dev, "failed to allocate page\n");
+		goto pool_add_err;
+	}
+
+	shadow_addr = (u64) (uintptr_t) kzalloc(prop->mmu_hop_table_size,
+						GFP_KERNEL);
+	if (!shadow_addr)
+		goto shadow_err;
+
+	pgt_info->phys_addr = phys_addr;
+	pgt_info->shadow_addr = shadow_addr;
+	pgt_info->ctx = ctx;
+	pgt_info->num_of_ptes = 0;
+	hash_add(ctx->mmu_shadow_hash, &pgt_info->node, shadow_addr);
+
+	return shadow_addr;
+
+shadow_err:
+	gen_pool_free(hdev->mmu_priv.dr.mmu_pgt_pool,
+			phys_addr, prop->mmu_hop_table_size);
+pool_add_err:
+	kfree(pgt_info);
+
+	return ULLONG_MAX;
+}
+
+u64 hl_mmu_dr_get_alloc_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte, bool *is_new_hop)
+{
+	u64 hop_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
+
+	if (hop_addr == ULLONG_MAX) {
+		hop_addr = hl_mmu_dr_alloc_hop(ctx);
+		*is_new_hop = (hop_addr != ULLONG_MAX);
+	}
+
+	return hop_addr;
+}
+
+void hl_mmu_dr_flush(struct hl_ctx *ctx)
+{
+	/* flush all writes from all cores to reach PCI */
+	mb();
+	ctx->hdev->asic_funcs->read_pte(ctx->hdev, hl_mmu_dr_get_phys_hop0_addr(ctx));
+}
+
+int hl_mmu_dr_init(struct hl_device *hdev)
+{
+	struct asic_fixed_properties *prop = &hdev->asic_prop;
+	int rc;
+
+	hdev->mmu_priv.dr.mmu_pgt_pool =
+			gen_pool_create(__ffs(prop->mmu_hop_table_size), -1);
+
+	if (!hdev->mmu_priv.dr.mmu_pgt_pool) {
+		dev_err(hdev->dev, "Failed to create page gen pool\n");
+		return -ENOMEM;
+	}
+
+	rc = gen_pool_add(hdev->mmu_priv.dr.mmu_pgt_pool, prop->mmu_pgt_addr +
+			prop->mmu_hop0_tables_total_size,
+			prop->dmmu.pgt_size - prop->mmu_hop0_tables_total_size,
+			-1);
+	if (rc) {
+		dev_err(hdev->dev, "Failed to add memory to page gen pool\n");
+		goto err_pool_add;
+	}
+
+	hdev->mmu_priv.dr.mmu_shadow_hop0 = kvcalloc(prop->max_asid,
+						prop->mmu_hop_table_size, GFP_KERNEL);
+	if (ZERO_OR_NULL_PTR(hdev->mmu_priv.dr.mmu_shadow_hop0)) {
+		rc = -ENOMEM;
+		goto err_pool_add;
+	}
+
+	/* MMU H/W init will be done in device hw_init() */
+
+	return 0;
+
+err_pool_add:
+	gen_pool_destroy(hdev->mmu_priv.dr.mmu_pgt_pool);
+
+	return rc;
+}
+
+void hl_mmu_dr_fini(struct hl_device *hdev)
+{
+	/* MMU H/W fini was already done in device hw_fini() */
+
+	if (ZERO_OR_NULL_PTR(hdev->mmu_priv.dr.mmu_shadow_hop0))
+		return;
+
+	kvfree(hdev->mmu_priv.dr.mmu_shadow_hop0);
+	gen_pool_destroy(hdev->mmu_priv.dr.mmu_pgt_pool);
+
+	/* Make sure that if we arrive here again without init was
+	 * called we won't cause kernel panic. This can happen for
+	 * example if we fail during hard reset code at certain points
+	 */
+	hdev->mmu_priv.dr.mmu_shadow_hop0 = NULL;
+}
diff --git a/drivers/accel/habanalabs/common/mmu/mmu_v1.c b/drivers/accel/habanalabs/common/mmu/mmu_v1.c
index d925dc4dd097..64b5c8fbb166 100644
--- a/drivers/accel/habanalabs/common/mmu/mmu_v1.c
+++ b/drivers/accel/habanalabs/common/mmu/mmu_v1.c
@@ -12,166 +12,6 @@ 
 
 #define MMU_V1_MAX_HOPS	(MMU_HOP4 + 1)
 
-static inline u64 get_phys_addr(struct hl_ctx *ctx, u64 shadow_addr);
-
-static struct pgt_info *get_pgt_info(struct hl_ctx *ctx, u64 hop_addr)
-{
-	struct pgt_info *pgt_info = NULL;
-
-	hash_for_each_possible(ctx->mmu_shadow_hash, pgt_info, node,
-				(unsigned long) hop_addr)
-		if (hop_addr == pgt_info->shadow_addr)
-			break;
-
-	return pgt_info;
-}
-
-static void _free_hop(struct hl_ctx *ctx, struct pgt_info *pgt_info)
-{
-	struct hl_device *hdev = ctx->hdev;
-
-	gen_pool_free(hdev->mmu_priv.dr.mmu_pgt_pool, pgt_info->phys_addr,
-			hdev->asic_prop.mmu_hop_table_size);
-	hash_del(&pgt_info->node);
-	kfree((u64 *) (uintptr_t) pgt_info->shadow_addr);
-	kfree(pgt_info);
-}
-
-static void free_hop(struct hl_ctx *ctx, u64 hop_addr)
-{
-	struct pgt_info *pgt_info = get_pgt_info(ctx, hop_addr);
-
-	_free_hop(ctx, pgt_info);
-}
-
-static u64 alloc_hop(struct hl_ctx *ctx)
-{
-	struct hl_device *hdev = ctx->hdev;
-	struct asic_fixed_properties *prop = &hdev->asic_prop;
-	struct pgt_info *pgt_info;
-	u64 phys_addr, shadow_addr;
-
-	pgt_info = kmalloc(sizeof(*pgt_info), GFP_KERNEL);
-	if (!pgt_info)
-		return ULLONG_MAX;
-
-	phys_addr = (u64) gen_pool_alloc(hdev->mmu_priv.dr.mmu_pgt_pool,
-					prop->mmu_hop_table_size);
-	if (!phys_addr) {
-		dev_err(hdev->dev, "failed to allocate page\n");
-		goto pool_add_err;
-	}
-
-	shadow_addr = (u64) (uintptr_t) kzalloc(prop->mmu_hop_table_size,
-						GFP_KERNEL);
-	if (!shadow_addr)
-		goto shadow_err;
-
-	pgt_info->phys_addr = phys_addr;
-	pgt_info->shadow_addr = shadow_addr;
-	pgt_info->ctx = ctx;
-	pgt_info->num_of_ptes = 0;
-	hash_add(ctx->mmu_shadow_hash, &pgt_info->node, shadow_addr);
-
-	return shadow_addr;
-
-shadow_err:
-	gen_pool_free(hdev->mmu_priv.dr.mmu_pgt_pool, phys_addr,
-			prop->mmu_hop_table_size);
-pool_add_err:
-	kfree(pgt_info);
-
-	return ULLONG_MAX;
-}
-
-static inline u64 get_phys_hop0_addr(struct hl_ctx *ctx)
-{
-	return ctx->hdev->asic_prop.mmu_pgt_addr +
-			(ctx->asid * ctx->hdev->asic_prop.mmu_hop_table_size);
-}
-
-static inline u64 get_hop0_addr(struct hl_ctx *ctx)
-{
-	return (u64) (uintptr_t) ctx->hdev->mmu_priv.dr.mmu_shadow_hop0 +
-			(ctx->asid * ctx->hdev->asic_prop.mmu_hop_table_size);
-}
-
-static void flush(struct hl_ctx *ctx)
-{
-	/* flush all writes from all cores to reach PCI */
-	mb();
-	ctx->hdev->asic_funcs->read_pte(ctx->hdev, get_phys_hop0_addr(ctx));
-}
-
-/* transform the value to physical address when writing to H/W */
-static inline void write_pte(struct hl_ctx *ctx, u64 shadow_pte_addr, u64 val)
-{
-	/*
-	 * The value to write is actually the address of the next shadow hop +
-	 * flags at the 12 LSBs.
-	 * Hence in order to get the value to write to the physical PTE, we
-	 * clear the 12 LSBs and translate the shadow hop to its associated
-	 * physical hop, and add back the original 12 LSBs.
-	 */
-	u64 phys_val = get_phys_addr(ctx, val & HOP_PHYS_ADDR_MASK) |
-				(val & FLAGS_MASK);
-
-	ctx->hdev->asic_funcs->write_pte(ctx->hdev,
-					get_phys_addr(ctx, shadow_pte_addr),
-					phys_val);
-
-	*(u64 *) (uintptr_t) shadow_pte_addr = val;
-}
-
-/* do not transform the value to physical address when writing to H/W */
-static inline void write_final_pte(struct hl_ctx *ctx, u64 shadow_pte_addr,
-					u64 val)
-{
-	ctx->hdev->asic_funcs->write_pte(ctx->hdev,
-					get_phys_addr(ctx, shadow_pte_addr),
-					val);
-	*(u64 *) (uintptr_t) shadow_pte_addr = val;
-}
-
-/* clear the last and present bits */
-static inline void clear_pte(struct hl_ctx *ctx, u64 pte_addr)
-{
-	/* no need to transform the value to physical address */
-	write_final_pte(ctx, pte_addr, 0);
-}
-
-static inline void get_pte(struct hl_ctx *ctx, u64 hop_addr)
-{
-	get_pgt_info(ctx, hop_addr)->num_of_ptes++;
-}
-
-/*
- * put_pte - decrement the num of ptes and free the hop if possible
- *
- * @ctx: pointer to the context structure
- * @hop_addr: addr of the hop
- *
- * This function returns the number of ptes left on this hop. If the number is
- * 0, it means the pte was freed.
- */
-static inline int put_pte(struct hl_ctx *ctx, u64 hop_addr)
-{
-	struct pgt_info *pgt_info = get_pgt_info(ctx, hop_addr);
-	int num_of_ptes_left;
-
-	pgt_info->num_of_ptes--;
-
-	/*
-	 * Need to save the number of ptes left because free_hop might free
-	 * the pgt_info
-	 */
-	num_of_ptes_left = pgt_info->num_of_ptes;
-	if (!num_of_ptes_left)
-		_free_hop(ctx, pgt_info);
-
-	return num_of_ptes_left;
-}
-
 static inline u64 get_hop_pte_addr(struct hl_ctx *ctx, struct hl_mmu_properties *mmu_prop,
 					u64 *hop_addr_arr, u64 virt_addr, enum mmu_hop_num hop_idx)
 {
@@ -183,35 +23,6 @@  static inline u64 get_hop_pte_addr(struct hl_ctx *ctx, struct hl_mmu_properties
 			ctx->hdev->asic_prop.mmu_pte_size * ((virt_addr & mask) >> shift);
 }
 
-static inline u64 get_alloc_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte,
-						bool *is_new_hop)
-{
-	u64 hop_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
-
-	if (hop_addr == ULLONG_MAX) {
-		hop_addr = alloc_hop(ctx);
-		*is_new_hop = (hop_addr != ULLONG_MAX);
-	}
-
-	return hop_addr;
-}
-
-/* translates shadow address inside hop to a physical address */
-static inline u64 get_phys_addr(struct hl_ctx *ctx, u64 shadow_addr)
-{
-	u64 page_mask = (ctx->hdev->asic_prop.mmu_hop_table_size - 1);
-	u64 shadow_hop_addr = shadow_addr & ~page_mask;
-	u64 pte_offset = shadow_addr & page_mask;
-	u64 phys_hop_addr;
-
-	if (shadow_hop_addr != get_hop0_addr(ctx))
-		phys_hop_addr = get_pgt_info(ctx, shadow_hop_addr)->phys_addr;
-	else
-		phys_hop_addr = get_phys_hop0_addr(ctx);
-
-	return phys_hop_addr + pte_offset;
-}
-
 static int dram_default_mapping_init(struct hl_ctx *ctx)
 {
 	struct hl_device *hdev = ctx->hdev;
@@ -236,9 +47,9 @@  static int dram_default_mapping_init(struct hl_ctx *ctx)
 	if (!ctx->dram_default_hops)
 		return -ENOMEM;
 
-	hop0_addr = get_hop0_addr(ctx);
+	hop0_addr = hl_mmu_dr_get_hop0_addr(ctx);
 
-	hop1_addr = alloc_hop(ctx);
+	hop1_addr = hl_mmu_dr_alloc_hop(ctx);
 	if (hop1_addr == ULLONG_MAX) {
 		dev_err(hdev->dev, "failed to alloc hop 1\n");
 		rc = -ENOMEM;
@@ -247,7 +58,7 @@  static int dram_default_mapping_init(struct hl_ctx *ctx)
 
 	ctx->dram_default_hops[total_hops - 1] = hop1_addr;
 
-	hop2_addr = alloc_hop(ctx);
+	hop2_addr = hl_mmu_dr_alloc_hop(ctx);
 	if (hop2_addr == ULLONG_MAX) {
 		dev_err(hdev->dev, "failed to alloc hop 2\n");
 		rc = -ENOMEM;
@@ -257,7 +68,7 @@  static int dram_default_mapping_init(struct hl_ctx *ctx)
 	ctx->dram_default_hops[total_hops - 2] = hop2_addr;
 
 	for (i = 0 ; i < num_of_hop3 ; i++) {
-		ctx->dram_default_hops[i] = alloc_hop(ctx);
+		ctx->dram_default_hops[i] = hl_mmu_dr_alloc_hop(ctx);
 		if (ctx->dram_default_hops[i] == ULLONG_MAX) {
 			dev_err(hdev->dev, "failed to alloc hop 3, i: %d\n", i);
 			rc = -ENOMEM;
@@ -268,18 +79,18 @@  static int dram_default_mapping_init(struct hl_ctx *ctx)
 
 	/* need only pte 0 in hops 0 and 1 */
 	pte_val = (hop1_addr & HOP_PHYS_ADDR_MASK) | PAGE_PRESENT_MASK;
-	write_pte(ctx, hop0_addr, pte_val);
+	hl_mmu_dr_write_pte(ctx, hop0_addr, pte_val);
 
 	pte_val = (hop2_addr & HOP_PHYS_ADDR_MASK) | PAGE_PRESENT_MASK;
-	write_pte(ctx, hop1_addr, pte_val);
-	get_pte(ctx, hop1_addr);
+	hl_mmu_dr_write_pte(ctx, hop1_addr, pte_val);
+	hl_mmu_dr_get_pte(ctx, hop1_addr);
 
 	hop2_pte_addr = hop2_addr;
 	for (i = 0 ; i < num_of_hop3 ; i++) {
 		pte_val = (ctx->dram_default_hops[i] & HOP_PHYS_ADDR_MASK) |
 				PAGE_PRESENT_MASK;
-		write_pte(ctx, hop2_pte_addr, pte_val);
-		get_pte(ctx, hop2_addr);
+		hl_mmu_dr_write_pte(ctx, hop2_pte_addr, pte_val);
+		hl_mmu_dr_get_pte(ctx, hop2_addr);
 		hop2_pte_addr += HL_PTE_SIZE;
 	}
 
@@ -289,23 +100,23 @@  static int dram_default_mapping_init(struct hl_ctx *ctx)
 	for (i = 0 ; i < num_of_hop3 ; i++) {
 		hop3_pte_addr = ctx->dram_default_hops[i];
 		for (j = 0 ; j < HOP_PTE_ENTRIES_512 ; j++) {
-			write_final_pte(ctx, hop3_pte_addr, pte_val);
-			get_pte(ctx, ctx->dram_default_hops[i]);
+			hl_mmu_dr_write_final_pte(ctx, hop3_pte_addr, pte_val);
+			hl_mmu_dr_get_pte(ctx, ctx->dram_default_hops[i]);
 			hop3_pte_addr += HL_PTE_SIZE;
 		}
 	}
 
-	flush(ctx);
+	hl_mmu_dr_flush(ctx);
 
 	return 0;
 
 hop3_err:
 	for (i = 0 ; i < hop3_allocated ; i++)
-		free_hop(ctx, ctx->dram_default_hops[i]);
+		hl_mmu_dr_free_hop(ctx, ctx->dram_default_hops[i]);
 
-	free_hop(ctx, hop2_addr);
+	hl_mmu_dr_free_hop(ctx, hop2_addr);
 hop2_err:
-	free_hop(ctx, hop1_addr);
+	hl_mmu_dr_free_hop(ctx, hop1_addr);
 hop1_err:
 	kfree(ctx->dram_default_hops);
 
@@ -329,7 +140,7 @@  static void dram_default_mapping_fini(struct hl_ctx *ctx)
 	do_div(num_of_hop3, prop->dram_page_size);
 	do_div(num_of_hop3, HOP_PTE_ENTRIES_512);
 
-	hop0_addr = get_hop0_addr(ctx);
+	hop0_addr = hl_mmu_dr_get_hop0_addr(ctx);
 	/* add hop1 and hop2 */
 	total_hops = num_of_hop3 + 2;
 	hop1_addr = ctx->dram_default_hops[total_hops - 1];
@@ -338,101 +149,26 @@  static void dram_default_mapping_fini(struct hl_ctx *ctx)
 	for (i = 0 ; i < num_of_hop3 ; i++) {
 		hop3_pte_addr = ctx->dram_default_hops[i];
 		for (j = 0 ; j < HOP_PTE_ENTRIES_512 ; j++) {
-			clear_pte(ctx, hop3_pte_addr);
-			put_pte(ctx, ctx->dram_default_hops[i]);
+			hl_mmu_dr_clear_pte(ctx, hop3_pte_addr);
+			hl_mmu_dr_put_pte(ctx, ctx->dram_default_hops[i]);
 			hop3_pte_addr += HL_PTE_SIZE;
 		}
 	}
 
 	hop2_pte_addr = hop2_addr;
 	for (i = 0 ; i < num_of_hop3 ; i++) {
-		clear_pte(ctx, hop2_pte_addr);
-		put_pte(ctx, hop2_addr);
+		hl_mmu_dr_clear_pte(ctx, hop2_pte_addr);
+		hl_mmu_dr_put_pte(ctx, hop2_addr);
 		hop2_pte_addr += HL_PTE_SIZE;
 	}
 
-	clear_pte(ctx, hop1_addr);
-	put_pte(ctx, hop1_addr);
-	clear_pte(ctx, hop0_addr);
+	hl_mmu_dr_clear_pte(ctx, hop1_addr);
+	hl_mmu_dr_put_pte(ctx, hop1_addr);
+	hl_mmu_dr_clear_pte(ctx, hop0_addr);
 
 	kfree(ctx->dram_default_hops);
 
-	flush(ctx);
-}
-
-/**
- * hl_mmu_v1_init() - initialize the MMU module.
- * @hdev: habanalabs device structure.
- *
- * This function does the following:
- * - Create a pool of pages for pgt_infos.
- * - Create a shadow table for pgt
- *
- * Return: 0 for success, non-zero for failure.
- */
-static int hl_mmu_v1_init(struct hl_device *hdev)
-{
-	struct asic_fixed_properties *prop = &hdev->asic_prop;
-	int rc;
-
-	hdev->mmu_priv.dr.mmu_pgt_pool =
-			gen_pool_create(__ffs(prop->mmu_hop_table_size), -1);
-
-	if (!hdev->mmu_priv.dr.mmu_pgt_pool) {
-		dev_err(hdev->dev, "Failed to create page gen pool\n");
-		return -ENOMEM;
-	}
-
-	rc = gen_pool_add(hdev->mmu_priv.dr.mmu_pgt_pool, prop->mmu_pgt_addr +
-			prop->mmu_hop0_tables_total_size,
-			prop->mmu_pgt_size - prop->mmu_hop0_tables_total_size,
-			-1);
-	if (rc) {
-		dev_err(hdev->dev, "Failed to add memory to page gen pool\n");
-		goto err_pool_add;
-	}
-
-	hdev->mmu_priv.dr.mmu_shadow_hop0 = kvcalloc(prop->max_asid, prop->mmu_hop_table_size,
-										GFP_KERNEL);
-	if (ZERO_OR_NULL_PTR(hdev->mmu_priv.dr.mmu_shadow_hop0)) {
-		rc = -ENOMEM;
-		goto err_pool_add;
-	}
-
-	/* MMU H/W init will be done in device hw_init() */
-
-	return 0;
-
-err_pool_add:
-	gen_pool_destroy(hdev->mmu_priv.dr.mmu_pgt_pool);
-
-	return rc;
-}
-
-/**
- * hl_mmu_v1_fini() - release the MMU module.
- * @hdev: habanalabs device structure.
- *
- * This function does the following:
- * - Disable MMU in H/W.
- * - Free the pgt_infos pool.
- *
- * All contexts should be freed before calling this function.
- */
-static void hl_mmu_v1_fini(struct hl_device *hdev)
-{
-	/* MMU H/W fini was already done in device hw_fini() */
-
-	if (!ZERO_OR_NULL_PTR(hdev->mmu_priv.dr.mmu_shadow_hop0)) {
-		kvfree(hdev->mmu_priv.dr.mmu_shadow_hop0);
-		gen_pool_destroy(hdev->mmu_priv.dr.mmu_pgt_pool);
-
-		/* Make sure that if we arrive here again without init was
-		 * called we won't cause kernel panic. This can happen for
-		 * example if we fail during hard reset code at certain points
-		 */
-		hdev->mmu_priv.dr.mmu_shadow_hop0 = NULL;
-	}
+	hl_mmu_dr_flush(ctx);
 }
 
 /**
@@ -476,7 +212,7 @@  static void hl_mmu_v1_ctx_fini(struct hl_ctx *ctx)
 		dev_err_ratelimited(hdev->dev,
 			"pgt_info of addr 0x%llx of asid %d was not destroyed, num_ptes: %d\n",
 			pgt_info->phys_addr, ctx->asid, pgt_info->num_of_ptes);
-		_free_hop(ctx, pgt_info);
+		hl_mmu_dr_free_pgt_node(ctx, pgt_info);
 	}
 }
 
@@ -495,7 +231,7 @@  static int hl_mmu_v1_unmap(struct hl_ctx *ctx,
 
 	for (hop_idx = MMU_HOP0; hop_idx < MMU_HOP4; hop_idx++) {
 		if (hop_idx == MMU_HOP0) {
-			hop_addr[hop_idx] = get_hop0_addr(ctx);
+			hop_addr[hop_idx] = hl_mmu_dr_get_hop0_addr(ctx);
 		} else {
 			hop_addr[hop_idx] = hl_mmu_get_next_hop_addr(ctx, curr_pte);
 			if (hop_addr[hop_idx] == ULLONG_MAX)
@@ -546,30 +282,30 @@  static int hl_mmu_v1_unmap(struct hl_ctx *ctx,
 		}
 
 		hop_idx = MMU_HOP3;
-		write_final_pte(ctx, hop_pte_addr[hop_idx], default_pte);
-		put_pte(ctx, hop_addr[hop_idx]);
+		hl_mmu_dr_write_final_pte(ctx, hop_pte_addr[hop_idx], default_pte);
+		hl_mmu_dr_put_pte(ctx, hop_addr[hop_idx]);
 	} else {
 		if (!(curr_pte & PAGE_PRESENT_MASK))
 			goto not_mapped;
 
 		if (hop_addr[MMU_HOP4])
-			clear_pte(ctx, hop_pte_addr[MMU_HOP4]);
+			hl_mmu_dr_clear_pte(ctx, hop_pte_addr[MMU_HOP4]);
 		else
-			clear_pte(ctx, hop_pte_addr[MMU_HOP3]);
+			hl_mmu_dr_clear_pte(ctx, hop_pte_addr[MMU_HOP3]);
 
-		if (hop_addr[MMU_HOP4] && !put_pte(ctx, hop_addr[MMU_HOP4]))
+		if (hop_addr[MMU_HOP4] && !hl_mmu_dr_put_pte(ctx, hop_addr[MMU_HOP4]))
 			clear_hop3 = true;
 
 		if (!clear_hop3)
 			goto mapped;
 
 		for (hop_idx = MMU_HOP3; hop_idx >= 0; hop_idx--) {
-			clear_pte(ctx, hop_pte_addr[hop_idx]);
+			hl_mmu_dr_clear_pte(ctx, hop_pte_addr[hop_idx]);
 
 			if (hop_idx == MMU_HOP0)
 				break;
 
-			if (put_pte(ctx, hop_addr[hop_idx]))
+			if (hl_mmu_dr_put_pte(ctx, hop_addr[hop_idx]))
 				goto mapped;
 		}
 	}
@@ -616,10 +352,10 @@  static int hl_mmu_v1_map(struct hl_ctx *ctx, u64 virt_addr, u64 phys_addr,
 
 	for (hop_idx = MMU_HOP0; hop_idx < num_hops; hop_idx++) {
 		if (hop_idx == MMU_HOP0) {
-			hop_addr[hop_idx] = get_hop0_addr(ctx);
+			hop_addr[hop_idx] = hl_mmu_dr_get_hop0_addr(ctx);
 		} else {
 			hop_addr[hop_idx] =
-					get_alloc_next_hop_addr(ctx, curr_pte, &hop_new[hop_idx]);
+				hl_mmu_dr_get_alloc_next_hop_addr(ctx, curr_pte, &hop_new[hop_idx]);
 			if (hop_addr[hop_idx] == ULLONG_MAX)
 				goto err;
 		}
@@ -666,27 +402,27 @@  static int hl_mmu_v1_map(struct hl_ctx *ctx, u64 virt_addr, u64 phys_addr,
 	curr_pte = (phys_addr & HOP_PHYS_ADDR_MASK) | mmu_prop->last_mask
 			| PAGE_PRESENT_MASK;
 
-	write_final_pte(ctx, hop_pte_addr[num_hops - 1], curr_pte);
+	hl_mmu_dr_write_final_pte(ctx, hop_pte_addr[num_hops - 1], curr_pte);
 
 	for (hop_idx = MMU_HOP1; hop_idx < num_hops; hop_idx++) {
 		prev_hop = hop_idx - 1;
 
 		if (hop_new[hop_idx]) {
 			curr_pte = (hop_addr[hop_idx] & HOP_PHYS_ADDR_MASK) | PAGE_PRESENT_MASK;
-			write_pte(ctx, hop_pte_addr[prev_hop], curr_pte);
+			hl_mmu_dr_write_pte(ctx, hop_pte_addr[prev_hop], curr_pte);
 			if (hop_idx != MMU_HOP1)
-				get_pte(ctx, hop_addr[prev_hop]);
+				hl_mmu_dr_get_pte(ctx, hop_addr[prev_hop]);
 		}
 	}
 
-	get_pte(ctx, hop_addr[num_hops - 1]);
+	hl_mmu_dr_get_pte(ctx, hop_addr[num_hops - 1]);
 
 	return 0;
 
 err:
 	for (hop_idx = num_hops; hop_idx > MMU_HOP0; hop_idx--) {
 		if (hop_new[hop_idx])
-			free_hop(ctx, hop_addr[hop_idx]);
+			hl_mmu_dr_free_hop(ctx, hop_addr[hop_idx]);
 	}
 
 	return rc;
@@ -752,7 +488,7 @@  static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
 	if (is_huge)
 		used_hops--;
 
-	hops->hop_info[0].hop_addr = get_phys_hop0_addr(ctx);
+	hops->hop_info[0].hop_addr = hl_mmu_dr_get_phys_hop0_addr(ctx);
 	hops->hop_info[0].hop_pte_addr =
 			hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, 0,
 					hops->hop_info[0].hop_addr, virt_addr);
@@ -801,13 +537,13 @@  static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
  */
 void hl_mmu_v1_set_funcs(struct hl_device *hdev, struct hl_mmu_funcs *mmu)
 {
-	mmu->init = hl_mmu_v1_init;
-	mmu->fini = hl_mmu_v1_fini;
+	mmu->init = hl_mmu_dr_init;
+	mmu->fini = hl_mmu_dr_fini;
 	mmu->ctx_init = hl_mmu_v1_ctx_init;
 	mmu->ctx_fini = hl_mmu_v1_ctx_fini;
 	mmu->map = hl_mmu_v1_map;
 	mmu->unmap = hl_mmu_v1_unmap;
-	mmu->flush = flush;
+	mmu->flush = hl_mmu_dr_flush;
 	mmu->swap_out = hl_mmu_v1_swap_out;
 	mmu->swap_in = hl_mmu_v1_swap_in;
 	mmu->get_tlb_info = hl_mmu_v1_get_tlb_info;
diff --git a/drivers/accel/habanalabs/common/mmu/mmu_v2.c b/drivers/accel/habanalabs/common/mmu/mmu_v2.c
new file mode 100644
index 000000000000..4bc0268fff1c
--- /dev/null
+++ b/drivers/accel/habanalabs/common/mmu/mmu_v2.c
@@ -0,0 +1,338 @@ 
+// SPDX-License-Identifier: GPL-2.0
+
+/*
+ * Copyright 2016-2020 HabanaLabs, Ltd.
+ * All Rights Reserved.
+ */
+
+#include "../habanalabs.h"
+#include "../../include/hw_ip/mmu/mmu_general.h"
+#include "../../include/hw_ip/mmu/mmu_v2_0.h"
+
+#include <linux/slab.h>
+
+/**
+ * hl_mmu_v2_ctx_init() - initialize a context for using the MMU module.
+ * @ctx: pointer to the context structure to initialize.
+ *
+ * Initialize a hash to hold all page table hops related to this
+ * context.
+ * Return: 0 on success, non-zero otherwise.
+ */
+static int hl_mmu_v2_ctx_init(struct hl_ctx *ctx)
+{
+	hash_init(ctx->mmu_shadow_hash);
+
+	return 0;
+}
+
+/*
+ * hl_mmu_v2_ctx_fini - disable a ctx from using the mmu module
+ *
+ * @ctx: pointer to the context structure
+ *
+ * This function does the following:
+ * - Warn if the context still holds pgts in its shadow hash
+ * - Free any pgts which were not freed yet, returning their device
+ *   memory to the pool and freeing their shadow copies
+ */
+static void hl_mmu_v2_ctx_fini(struct hl_ctx *ctx)
+{
+	struct hl_device *hdev = ctx->hdev;
+	struct pgt_info *pgt_info;
+	struct hlist_node *tmp;
+	int i;
+
+	if (!hash_empty(ctx->mmu_shadow_hash))
+		dev_err(hdev->dev, "ctx %d is freed while it has pgts in use\n",
+			ctx->asid);
+
+	hash_for_each_safe(ctx->mmu_shadow_hash, i, tmp, pgt_info, node) {
+		dev_err_ratelimited(hdev->dev,
+			"pgt_info of addr 0x%llx of asid %d was not destroyed, num_ptes: %d\n",
+			pgt_info->phys_addr, ctx->asid, pgt_info->num_of_ptes);
+		hl_mmu_dr_free_pgt_node(ctx, pgt_info);
+	}
+}
+
+static int hl_mmu_v2_unmap(struct hl_ctx *ctx,	u64 virt_addr, bool is_dram_addr)
+{
+	u64 hop_addr[MMU_ARCH_6_HOPS] = { 0 }, hop_pte_addr[MMU_ARCH_6_HOPS] = { 0 }, curr_pte,
+							scrambled_virt_addr;
+	struct asic_fixed_properties *prop = &ctx->hdev->asic_prop;
+	struct hl_device *hdev = ctx->hdev;
+	struct hl_mmu_properties *mmu_prop;
+	bool is_huge = false;
+	int i, hop_last;
+
+	/* device-resident page tables in V2 are allowed only for the HMMU */
+	if (!is_dram_addr)
+		return -EINVAL;
+
+	mmu_prop = &prop->dmmu;
+
+	hop_last = mmu_prop->num_hops - 1;
+
+	scrambled_virt_addr = hdev->asic_funcs->scramble_addr(hdev, virt_addr);
+
+	hop_addr[0] = hl_mmu_dr_get_hop0_addr(ctx);
+	hop_pte_addr[0] = hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, 0,
+					hop_addr[0], scrambled_virt_addr);
+	if (hop_pte_addr[0] == U64_MAX)
+		return -EFAULT;
+
+	curr_pte = *(u64 *) (uintptr_t) hop_pte_addr[0];
+
+	for (i = 1 ; i < mmu_prop->num_hops ; i++) {
+		hop_addr[i] = hl_mmu_get_next_hop_addr(ctx, curr_pte);
+		if (hop_addr[i] == ULLONG_MAX)
+			goto not_mapped;
+
+		hop_pte_addr[i] = hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, i,
+					hop_addr[i], scrambled_virt_addr);
+		if (hop_pte_addr[i] == U64_MAX)
+			return -EFAULT;
+
+		curr_pte = *(u64 *) (uintptr_t) hop_pte_addr[i];
+
+		if ((i <= hop_last) && (curr_pte & mmu_prop->last_mask)) {
+			hop_last = i;
+			is_huge = true;
+			break;
+		}
+	}
+
+	if (is_dram_addr && !is_huge) {
+		dev_err(hdev->dev, "DRAM unmapping should use huge pages only\n");
+		return -EFAULT;
+	}
+
+	if (!(curr_pte & PAGE_PRESENT_MASK))
+		goto not_mapped;
+
+	for (i = hop_last ; i > 0 ; i--) {
+		hl_mmu_dr_clear_pte(ctx, hop_pte_addr[i]);
+		if (hl_mmu_dr_put_pte(ctx, hop_addr[i]))
+			goto mapped;
+	}
+	hl_mmu_dr_clear_pte(ctx, hop_pte_addr[0]);
+
+mapped:
+	return 0;
+
+not_mapped:
+	dev_err(hdev->dev, "virt addr 0x%llx is not mapped to phys addr\n",
+		virt_addr);
+
+	return -EINVAL;
+}
+
+static int hl_mmu_v2_map(struct hl_ctx *ctx, u64 virt_addr, u64 phys_addr,
+							u32 page_size, bool is_dram_addr)
+{
+	u64 hop_addr[MMU_ARCH_6_HOPS] = { 0 }, hop_pte_addr[MMU_ARCH_6_HOPS] = { 0 },
+			curr_pte = 0, scrambled_virt_addr, scrambled_phys_addr;
+	struct asic_fixed_properties *prop = &ctx->hdev->asic_prop;
+	bool hop_new[MMU_ARCH_6_HOPS] = { false };
+	struct hl_device *hdev = ctx->hdev;
+	struct hl_mmu_properties *mmu_prop;
+	int rc, i, hop_last;
+
+	/* device-resident page tables in V2 are allowed only for the HMMU */
+	if (!is_dram_addr)
+		return -EINVAL;
+
+	mmu_prop = &prop->dmmu;
+
+	hop_last = mmu_prop->num_hops - 1;
+
+	scrambled_virt_addr = hdev->asic_funcs->scramble_addr(hdev, virt_addr);
+	scrambled_phys_addr = hdev->asic_funcs->scramble_addr(hdev, phys_addr);
+
+	/* First hop is preallocated, therefore it is treated differently */
+	hop_addr[0] = hl_mmu_dr_get_hop0_addr(ctx);
+	hop_pte_addr[0] = hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, 0,
+						hop_addr[0], scrambled_virt_addr);
+	curr_pte = *(u64 *) (uintptr_t) hop_pte_addr[0];
+
+	/* Handle hop1 to hop_last */
+	for (i = 1 ; i <= hop_last ; i++) {
+		hop_addr[i] = hl_mmu_dr_get_alloc_next_hop_addr(ctx, curr_pte, &hop_new[i]);
+		if (hop_addr[i] == ULLONG_MAX) {
+			rc = -ENOMEM;
+			goto err;
+		}
+
+		hop_pte_addr[i] = hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, i,
+					hop_addr[i], scrambled_virt_addr);
+		if (hop_pte_addr[i] == U64_MAX) {
+			rc = -EINVAL;
+			goto err;
+		}
+
+		if (!hop_pte_addr[i]) {
+			rc = -EINVAL;
+			goto err;
+		}
+
+		curr_pte = *(u64 *) (uintptr_t) hop_pte_addr[i];
+	}
+
+	if (curr_pte & PAGE_PRESENT_MASK) {
+		dev_err(hdev->dev,
+			"mapping already exists for virt_addr 0x%llx\n",
+				virt_addr);
+
+		for (i = 0 ; i <= hop_last ; i++)
+			dev_dbg(hdev->dev, "hop%d pte: 0x%llx (0x%llx)\n",
+				i, *(u64 *) (uintptr_t) hop_pte_addr[i],
+				hop_pte_addr[i]);
+
+		rc = -EINVAL;
+		goto err;
+	}
+
+	curr_pte = (scrambled_phys_addr & HOP_PHYS_ADDR_MASK)
+					| mmu_prop->last_mask | PAGE_PRESENT_MASK;
+
+	/* Write the PTEs */
+	hl_mmu_dr_write_final_pte(ctx, hop_pte_addr[hop_last], curr_pte);
+
+	/* for each new hop, add its address to the table of previous-hop */
+	for (i = 1 ; i <= hop_last ; i++) {
+		if (hop_new[i]) {
+			curr_pte = (hop_addr[i] & HOP_PHYS_ADDR_MASK) | PAGE_PRESENT_MASK;
+			hl_mmu_dr_write_pte(ctx, hop_pte_addr[i - 1], curr_pte);
+
+			if (i - 1)
+				hl_mmu_dr_get_pte(ctx, hop_addr[i - 1]);
+		}
+	}
+	hl_mmu_dr_get_pte(ctx, hop_addr[hop_last]);
+
+	return 0;
+
+err:
+	for (i = 1 ; i <= hop_last ; i++)
+		if (hop_new[i] && (hop_addr[i] != U64_MAX))
+			hl_mmu_dr_free_hop(ctx, hop_addr[i]);
+
+	return rc;
+}
+
+/*
+ * hl_mmu_v2_swap_out - marks all mapping of the given ctx as swapped out
+ *
+ * @ctx: pointer to the context structure
+ *
+ */
+static void hl_mmu_v2_swap_out(struct hl_ctx *ctx)
+{
+
+}
+
+/*
+ * hl_mmu_v2_swap_in - marks all mapping of the given ctx as swapped in
+ *
+ * @ctx: pointer to the context structure
+ *
+ */
+static void hl_mmu_v2_swap_in(struct hl_ctx *ctx)
+{
+
+}
+
+static int hl_mmu_v2_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr, struct hl_mmu_hop_info *hops)
+{
+	struct asic_fixed_properties *prop = &ctx->hdev->asic_prop;
+	struct hl_device *hdev = ctx->hdev;
+	struct hl_mmu_properties *mmu_prop;
+	bool is_dram_addr;
+	int i;
+
+	is_dram_addr = hl_mem_area_inside_range(virt_addr, prop->dmmu.page_size,
+						prop->dmmu.start_addr,
+						prop->dmmu.end_addr);
+
+	/* device-resident page tables in V2 are allowed only for the HMMU */
+	if (!is_dram_addr)
+		return -EINVAL;
+
+	mmu_prop = &prop->dmmu;
+	hops->range_type = HL_VA_RANGE_TYPE_DRAM;
+
+	hops->scrambled_vaddr = hdev->asic_funcs->scramble_addr(hdev, virt_addr);
+
+	hops->hop_info[0].hop_addr = hl_mmu_dr_get_phys_hop0_addr(ctx);
+	hops->hop_info[0].hop_pte_addr = hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, 0,
+						hops->hop_info[0].hop_addr,
+							hops->scrambled_vaddr);
+	if (hops->hop_info[0].hop_pte_addr == U64_MAX)
+		return -EFAULT;
+
+	hops->hop_info[0].hop_pte_val = hdev->asic_funcs->read_pte(hdev,
+						hops->hop_info[0].hop_pte_addr);
+	if (hops->hop_info[0].hop_pte_val == U64_MAX)
+		return -EFAULT;
+
+	for (i = 1 ; i < mmu_prop->num_hops ; i++) {
+		hops->hop_info[i].hop_addr =
+			hl_mmu_get_next_hop_addr(ctx, hops->hop_info[i - 1].hop_pte_val);
+		if (hops->hop_info[i].hop_addr == ULLONG_MAX)
+			return -EFAULT;
+
+		hops->hop_info[i].hop_pte_addr =
+				hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, i,
+						hops->hop_info[i].hop_addr,
+						hops->scrambled_vaddr);
+		if (hops->hop_info[i].hop_pte_addr == U64_MAX)
+			return -EFAULT;
+
+		hops->hop_info[i].hop_pte_val =
+				hdev->asic_funcs->read_pte(hdev,
+					hops->hop_info[i].hop_pte_addr);
+
+		if (!(hops->hop_info[i].hop_pte_val & PAGE_PRESENT_MASK))
+			return -EFAULT;
+
+		if (hops->hop_info[i].hop_pte_val & mmu_prop->last_mask)
+			break;
+	}
+
+	/* if passed over all hops then no last hop was found */
+	if (i == mmu_prop->num_hops)
+		return -EFAULT;
+
+	if (!(hops->hop_info[i].hop_pte_val & PAGE_PRESENT_MASK))
+		return -EFAULT;
+
+	if (hops->scrambled_vaddr != virt_addr)
+		hops->unscrambled_paddr = hdev->asic_funcs->descramble_addr
+				(hdev, hops->hop_info[i].hop_pte_val);
+	else
+		hops->unscrambled_paddr = hops->hop_info[i].hop_pte_val;
+
+	hops->used_hops = i + 1;
+
+	return 0;
+}
+
+/*
+ * hl_mmu_v2_set_funcs - set the MMU functions for working with MMU v2
+ *
+ * @hdev: pointer to the device structure
+ * @mmu: pointer to the MMU functions structure
+ */
+void hl_mmu_v2_set_funcs(struct hl_device *hdev, struct hl_mmu_funcs *mmu)
+{
+	mmu->init = hl_mmu_dr_init;
+	mmu->fini = hl_mmu_dr_fini;
+	mmu->ctx_init = hl_mmu_v2_ctx_init;
+	mmu->ctx_fini = hl_mmu_v2_ctx_fini;
+	mmu->map = hl_mmu_v2_map;
+	mmu->unmap = hl_mmu_v2_unmap;
+	mmu->flush = hl_mmu_dr_flush;
+	mmu->swap_out = hl_mmu_v2_swap_out;
+	mmu->swap_in = hl_mmu_v2_swap_in;
+	mmu->get_tlb_info = hl_mmu_v2_get_tlb_info;
+}
diff --git a/drivers/accel/habanalabs/gaudi/gaudi.c b/drivers/accel/habanalabs/gaudi/gaudi.c
index 53292d4c15c8..dde3839fe0e0 100644
--- a/drivers/accel/habanalabs/gaudi/gaudi.c
+++ b/drivers/accel/habanalabs/gaudi/gaudi.c
@@ -649,6 +649,7 @@  static int gaudi_set_fixed_properties(struct hl_device *hdev)
 	prop->dmmu.start_addr = (VA_HOST_SPACE_START + VA_HOST_SPACE_SIZE / 2);
 	prop->dmmu.end_addr = VA_HOST_SPACE_END;
 	prop->dmmu.page_size = PAGE_SIZE_2MB;
+	prop->dmmu.pgt_size = prop->mmu_pgt_size;
 
 	prop->cfg_size = CFG_SIZE;
 	prop->max_asid = MAX_ASID;
diff --git a/drivers/accel/habanalabs/gaudi2/gaudi2.c b/drivers/accel/habanalabs/gaudi2/gaudi2.c
index fd01525b1ea2..c95809120a07 100644
--- a/drivers/accel/habanalabs/gaudi2/gaudi2.c
+++ b/drivers/accel/habanalabs/gaudi2/gaudi2.c
@@ -2308,11 +2308,26 @@  static int set_number_of_functional_hbms(struct hl_device *hdev)
 	return 0;
 }
 
+bool gaudi2_is_edma_queue_id(u32 queue_id)
+{
+
+	switch (queue_id) {
+	case GAUDI2_QUEUE_ID_DCORE0_EDMA_0_0...GAUDI2_QUEUE_ID_DCORE0_EDMA_1_3:
+	case GAUDI2_QUEUE_ID_DCORE1_EDMA_0_0...GAUDI2_QUEUE_ID_DCORE1_EDMA_1_3:
+	case GAUDI2_QUEUE_ID_DCORE2_EDMA_0_0...GAUDI2_QUEUE_ID_DCORE2_EDMA_1_3:
+	case GAUDI2_QUEUE_ID_DCORE3_EDMA_0_0...GAUDI2_QUEUE_ID_DCORE3_EDMA_1_3:
+		return true;
+	default:
+		return false;
+	}
+}
+
 static int gaudi2_set_dram_properties(struct hl_device *hdev)
 {
 	struct asic_fixed_properties *prop = &hdev->asic_prop;
-	u32 basic_hbm_page_size;
-	int rc;
+	u64 hbm_drv_base_offset = 0, edma_pq_base_addr;
+	u32 basic_hbm_page_size, edma_idx = 0;
+	int rc, i;
 
 	rc = set_number_of_functional_hbms(hdev);
 	if (rc)
@@ -2356,9 +2371,35 @@  static int gaudi2_set_dram_properties(struct hl_device *hdev)
 	prop->dmmu.start_addr = prop->dram_base_address +
 			(prop->dram_page_size *
 				DIV_ROUND_UP_SECTOR_T(prop->dram_size, prop->dram_page_size));
-
 	prop->dmmu.end_addr = prop->dmmu.start_addr + prop->dram_page_size *
 			div_u64((VA_HBM_SPACE_END - prop->dmmu.start_addr), prop->dmmu.page_size);
+	/*
+	 * The driver can't share a (48MB) HBM page with the F/W, to prevent the F/W from
+	 * blocking the driver's part with a range register, so it must start at the next (48MB) page
+	 */
+	hbm_drv_base_offset = roundup(CPU_FW_IMAGE_SIZE, prop->num_functional_hbms * SZ_8M);
+
+	/*
+	 * The NIC driver section and the HMMU page tables section in the HBM need
+	 * to fit in the remaining size of the first DRAM page after taking the
+	 * F/W image size into account
+	 */
+
+	/* Reserve region in HBM for HMMU page tables */
+	prop->mmu_pgt_addr = DRAM_PHYS_BASE + hbm_drv_base_offset +
+				((prop->dram_page_size - hbm_drv_base_offset) -
+				(HMMU_PAGE_TABLES_SIZE + EDMA_PQS_SIZE + EDMA_SCRATCHPAD_SIZE));
+
+	/* Set EDMA PQs HBM addresses */
+	edma_pq_base_addr = prop->mmu_pgt_addr + HMMU_PAGE_TABLES_SIZE;
+
+	for (i = 0 ; i < GAUDI2_QUEUE_ID_CPU_PQ ; i++) {
+		if (gaudi2_is_edma_queue_id(i)) {
+			prop->hw_queues_props[i].q_dram_bd_address = edma_pq_base_addr +
+							(edma_idx * HL_QUEUE_SIZE_IN_BYTES);
+			edma_idx++;
+		}
+	}
 
 	return 0;
 }
@@ -2368,7 +2409,7 @@  static int gaudi2_set_fixed_properties(struct hl_device *hdev)
 	struct asic_fixed_properties *prop = &hdev->asic_prop;
 	struct hw_queue_properties *q_props;
 	u32 num_sync_stream_queues = 0;
-	int i;
+	int i, rc;
 
 	prop->max_queues = GAUDI2_QUEUE_ID_SIZE;
 	prop->hw_queues_props = kcalloc(prop->max_queues, sizeof(struct hw_queue_properties),
@@ -2391,6 +2432,9 @@  static int gaudi2_set_fixed_properties(struct hl_device *hdev)
 		}
 
 		q_props[i].cb_alloc_flags = CB_ALLOC_USER;
+
+		if (gaudi2_is_edma_queue_id(i))
+			q_props[i].dram_bd = 1;
 	}
 
 	q_props[GAUDI2_QUEUE_ID_CPU_PQ].type = QUEUE_TYPE_CPU;
@@ -2419,40 +2463,39 @@  static int gaudi2_set_fixed_properties(struct hl_device *hdev)
 
 	prop->rotator_enabled_mask = BIT(NUM_OF_ROT) - 1;
 
-	if (hdev->pldm)
-		prop->mmu_pgt_size = 0x800000; /* 8MB */
-	else
-		prop->mmu_pgt_size = MMU_PAGE_TABLES_INITIAL_SIZE;
+	prop->max_asid = 2;
 
+	prop->dmmu.pgt_size = HMMU_PAGE_TABLES_SIZE;
 	prop->mmu_pte_size = HL_PTE_SIZE;
 	prop->mmu_hop_table_size = HOP_TABLE_SIZE_512_PTE;
-	prop->mmu_hop0_tables_total_size = HOP0_512_PTE_TABLES_TOTAL_SIZE;
+	prop->mmu_hop0_tables_total_size = HOP_TABLE_SIZE_512_PTE * prop->max_asid;
 
 	prop->dmmu.hop_shifts[MMU_HOP0] = DHOP0_SHIFT;
 	prop->dmmu.hop_shifts[MMU_HOP1] = DHOP1_SHIFT;
 	prop->dmmu.hop_shifts[MMU_HOP2] = DHOP2_SHIFT;
 	prop->dmmu.hop_shifts[MMU_HOP3] = DHOP3_SHIFT;
-	prop->dmmu.hop_shifts[MMU_HOP4] = DHOP4_SHIFT;
 	prop->dmmu.hop_masks[MMU_HOP0] = DHOP0_MASK;
 	prop->dmmu.hop_masks[MMU_HOP1] = DHOP1_MASK;
 	prop->dmmu.hop_masks[MMU_HOP2] = DHOP2_MASK;
 	prop->dmmu.hop_masks[MMU_HOP3] = DHOP3_MASK;
-	prop->dmmu.hop_masks[MMU_HOP4] = DHOP4_MASK;
 	prop->dmmu.page_size = PAGE_SIZE_1GB;
-	prop->dmmu.num_hops = MMU_ARCH_6_HOPS;
+	prop->dmmu.num_hops = MMU_ARCH_4_HOPS;
 	prop->dmmu.last_mask = LAST_MASK;
-	prop->dmmu.host_resident = 1;
+	prop->dmmu.host_resident = 0;
 	prop->dmmu.hop_table_size = prop->mmu_hop_table_size;
 	prop->dmmu.hop0_tables_total_size = prop->mmu_hop0_tables_total_size;
 
-	/*
-	 * this is done in order to be able to validate FW descriptor (i.e. validating that
-	 * the addresses and allocated space for FW image does not cross memory bounds).
-	 * for this reason we set the DRAM size to the minimum possible and later it will
-	 * be modified according to what reported in the cpucp info packet
+	/* As we need to set the pgt address in DRAM for HMMU init, we cannot
+	 * wait for the F/W cpucp info to set the DRAM properties, since MMU init
+	 * comes before H/W init
 	 */
-	prop->dram_size = (GAUDI2_HBM_NUM - 1) * SZ_16G;
+	rc = hdev->asic_funcs->set_dram_properties(hdev);
+	if (rc)
+		goto free_qprops;
+
+	prop->mmu_pgt_size = PMMU_PAGE_TABLES_SIZE;
 
+	prop->pmmu.pgt_size = prop->mmu_pgt_size;
 	hdev->pmmu_huge_range = true;
 	prop->pmmu.host_resident = 1;
 	prop->pmmu.num_hops = MMU_ARCH_6_HOPS;
@@ -2516,7 +2559,6 @@  static int gaudi2_set_fixed_properties(struct hl_device *hdev)
 	prop->max_num_of_engines = GAUDI2_ENGINE_ID_SIZE;
 	prop->num_engine_cores = CPU_ID_MAX;
 	prop->cfg_size = CFG_SIZE;
-	prop->max_asid = MAX_ASID;
 	prop->num_of_events = GAUDI2_EVENT_SIZE;
 
 	prop->supports_engine_modes = true;
@@ -2560,6 +2602,10 @@  static int gaudi2_set_fixed_properties(struct hl_device *hdev)
 	prop->hbw_flush_reg = mmPCIE_WRAP_SPECIAL_GLBL_SPARE_0;
 
 	return 0;
+
+free_qprops:
+	kfree(prop->hw_queues_props);
+	return rc;
 }
 
 static int gaudi2_pci_bars_map(struct hl_device *hdev)
@@ -3033,6 +3079,25 @@  static int gaudi2_fetch_psoc_frequency(struct hl_device *hdev)
 	return 0;
 }
 
+static int gaudi2_mmu_clear_pgt_range(struct hl_device *hdev)
+{
+	struct gaudi2_device *gaudi2 = hdev->asic_specific;
+	struct asic_fixed_properties *prop = &hdev->asic_prop;
+	int rc;
+
+	if (!(gaudi2->hw_cap_initialized & HW_CAP_MMU_MASK))
+		return 0;
+
+	if (prop->dmmu.host_resident)
+		return 0;
+
+	rc = gaudi2_memset_device_memory(hdev, prop->mmu_pgt_addr, prop->dmmu.pgt_size, 0);
+	if (rc)
+		dev_err(hdev->dev, "Failed to clear mmu pgt\n");
+
+	return rc;
+}
+
 static int gaudi2_early_init(struct hl_device *hdev)
 {
 	struct asic_fixed_properties *prop = &hdev->asic_prop;
@@ -3258,6 +3323,12 @@  static int gaudi2_late_init(struct hl_device *hdev)
 		goto disable_pci_access;
 	}
 
+	rc = gaudi2_mmu_clear_pgt_range(hdev);
+	if (rc) {
+		dev_err(hdev->dev, "Failed to clear MMU page tables range\n");
+		goto disable_pci_access;
+	}
+
 	gaudi2_init_arcs(hdev);
 
 	rc = gaudi2_scrub_arcs_dccm(hdev);
@@ -3697,13 +3768,7 @@  static int gaudi2_sw_init(struct hl_device *hdev)
 
 	spin_lock_init(&gaudi2->hw_queues_lock);
 
-	gaudi2->scratchpad_kernel_address = hl_asic_dma_alloc_coherent(hdev, PAGE_SIZE,
-							&gaudi2->scratchpad_bus_address,
-							GFP_KERNEL | __GFP_ZERO);
-	if (!gaudi2->scratchpad_kernel_address) {
-		rc = -ENOMEM;
-		goto free_virt_msix_db_mem;
-	}
+	gaudi2->scratchpad_bus_address = prop->mmu_pgt_addr + HMMU_PAGE_TABLES_SIZE + EDMA_PQS_SIZE;
 
 	gaudi2_user_mapped_blocks_init(hdev);
 
@@ -3727,7 +3792,7 @@  static int gaudi2_sw_init(struct hl_device *hdev)
 
 	rc = gaudi2_special_blocks_iterator_config(hdev);
 	if (rc)
-		goto free_scratchpad_mem;
+		goto free_virt_msix_db_mem;
 
 	rc = gaudi2_test_queues_msgs_alloc(hdev);
 	if (rc)
@@ -3737,9 +3802,6 @@  static int gaudi2_sw_init(struct hl_device *hdev)
 
 special_blocks_free:
 	gaudi2_special_blocks_iterator_free(hdev);
-free_scratchpad_mem:
-	hl_asic_dma_free_coherent(hdev, PAGE_SIZE, gaudi2->scratchpad_kernel_address,
-				  gaudi2->scratchpad_bus_address);
 free_virt_msix_db_mem:
 	hl_cpu_accessible_dma_pool_free(hdev, prop->pmmu.page_size, gaudi2->virt_msix_db_cpu_addr);
 free_cpu_accessible_dma_pool:
@@ -3770,9 +3832,6 @@  static int gaudi2_sw_fini(struct hl_device *hdev)
 	hl_asic_dma_free_coherent(hdev, HL_CPU_ACCESSIBLE_MEM_SIZE, hdev->cpu_accessible_dma_mem,
 						hdev->cpu_accessible_dma_address);
 
-	hl_asic_dma_free_coherent(hdev, PAGE_SIZE, gaudi2->scratchpad_kernel_address,
-					gaudi2->scratchpad_bus_address);
-
 	dma_pool_destroy(hdev->dma_pool);
 
 	kfree(gaudi2);
@@ -4962,10 +5021,17 @@  static void gaudi2_init_qman_pq(struct hl_device *hdev, u32 reg_base,
 		q = &hdev->kernel_queues[queue_id_base + pq_id];
 		pq_offset = pq_id * 4;
 
-		WREG32(reg_base + QM_PQ_BASE_LO_0_OFFSET + pq_offset,
-				lower_32_bits(q->bus_address));
-		WREG32(reg_base + QM_PQ_BASE_HI_0_OFFSET + pq_offset,
-				upper_32_bits(q->bus_address));
+		if (q->dram_bd) {
+			WREG32(reg_base + QM_PQ_BASE_LO_0_OFFSET + pq_offset,
+					lower_32_bits(q->pq_dram_address));
+			WREG32(reg_base + QM_PQ_BASE_HI_0_OFFSET + pq_offset,
+					upper_32_bits(q->pq_dram_address));
+		} else {
+			WREG32(reg_base + QM_PQ_BASE_LO_0_OFFSET + pq_offset,
+					lower_32_bits(q->bus_address));
+			WREG32(reg_base + QM_PQ_BASE_HI_0_OFFSET + pq_offset,
+					upper_32_bits(q->bus_address));
+		}
 		WREG32(reg_base + QM_PQ_SIZE_0_OFFSET + pq_offset, ilog2(HL_QUEUE_LENGTH));
 		WREG32(reg_base + QM_PQ_PI_0_OFFSET + pq_offset, 0);
 		WREG32(reg_base + QM_PQ_CI_0_OFFSET + pq_offset, 0);
@@ -5852,7 +5918,8 @@  static int gaudi2_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_har
 	return rc;
 }
 
-static int gaudi2_mmu_update_hop0_addr(struct hl_device *hdev, u32 stlb_base)
+static int gaudi2_mmu_update_hop0_addr(struct hl_device *hdev, u32 stlb_base,
+									bool host_resident_pgt)
 {
 	struct asic_fixed_properties *prop = &hdev->asic_prop;
 	u64 hop0_addr;
@@ -5864,7 +5931,11 @@  static int gaudi2_mmu_update_hop0_addr(struct hl_device *hdev, u32 stlb_base)
 		max_asid = min((u32) 8, max_asid);
 
 	for (asid = 0 ; asid < max_asid ; asid++) {
-		hop0_addr = hdev->mmu_priv.hr.mmu_asid_hop0[asid].phys_addr;
+		if (host_resident_pgt)
+			hop0_addr = hdev->mmu_priv.hr.mmu_asid_hop0[asid].phys_addr;
+		else
+			hop0_addr = prop->mmu_pgt_addr + (asid * prop->mmu_hop_table_size);
+
 		rc = gaudi2_mmu_update_asid_hop0_addr(hdev, stlb_base, asid, hop0_addr);
 		if (rc) {
 			dev_err(hdev->dev, "failed to set hop0 addr for asid %d\n", asid);
@@ -5875,7 +5946,8 @@  static int gaudi2_mmu_update_hop0_addr(struct hl_device *hdev, u32 stlb_base)
 	return 0;
 }
 
-static int gaudi2_mmu_init_common(struct hl_device *hdev, u32 mmu_base, u32 stlb_base)
+static int gaudi2_mmu_init_common(struct hl_device *hdev, u32 mmu_base, u32 stlb_base,
+								bool host_resident_pgt)
 {
 	u32 status, timeout_usec;
 	int rc;
@@ -5898,7 +5970,7 @@  static int gaudi2_mmu_init_common(struct hl_device *hdev, u32 mmu_base, u32 stlb
 	if (rc)
 		dev_notice_ratelimited(hdev->dev, "Timeout when waiting for MMU SRAM init\n");
 
-	rc = gaudi2_mmu_update_hop0_addr(hdev, stlb_base);
+	rc = gaudi2_mmu_update_hop0_addr(hdev, stlb_base, host_resident_pgt);
 	if (rc)
 		return rc;
 
@@ -5922,6 +5994,7 @@  static int gaudi2_mmu_init_common(struct hl_device *hdev, u32 mmu_base, u32 stlb
 
 static int gaudi2_pci_mmu_init(struct hl_device *hdev)
 {
+	struct asic_fixed_properties *prop = &hdev->asic_prop;
 	struct gaudi2_device *gaudi2 = hdev->asic_specific;
 	u32 mmu_base, stlb_base;
 	int rc;
@@ -5961,7 +6034,7 @@  static int gaudi2_pci_mmu_init(struct hl_device *hdev)
 
 	WREG32(mmu_base + MMU_SPI_SEI_MASK_OFFSET, GAUDI2_PMMU_SPI_SEI_ENABLE_MASK);
 
-	rc = gaudi2_mmu_init_common(hdev, mmu_base, stlb_base);
+	rc = gaudi2_mmu_init_common(hdev, mmu_base, stlb_base, prop->pmmu.host_resident);
 	if (rc)
 		return rc;
 
@@ -6013,7 +6086,7 @@  static int gaudi2_dcore_hmmu_init(struct hl_device *hdev, int dcore_id,
 
 	WREG32(mmu_base + MMU_SPI_SEI_MASK_OFFSET, GAUDI2_HMMU_SPI_SEI_ENABLE_MASK);
 
-	rc = gaudi2_mmu_init_common(hdev, mmu_base, stlb_base);
+	rc = gaudi2_mmu_init_common(hdev, mmu_base, stlb_base, prop->dmmu.host_resident);
 	if (rc)
 		return rc;
 
@@ -7051,7 +7124,7 @@  static int gaudi2_test_queues(struct hl_device *hdev)
 
 	/* send test message on all enabled Qs */
 	for (i = GAUDI2_QUEUE_ID_PDMA_0_0 ; i < GAUDI2_QUEUE_ID_CPU_PQ; i++) {
-		if (!gaudi2_is_queue_enabled(hdev, i))
+		if (!gaudi2_is_queue_enabled(hdev, i) || gaudi2_is_edma_queue_id(i))
 			continue;
 
 		msg_info = &gaudi2->queues_test_info[i - GAUDI2_QUEUE_ID_PDMA_0_0];
@@ -7068,7 +7141,7 @@  static int gaudi2_test_queues(struct hl_device *hdev)
 
 	/* verify that all messages were processed */
 	for (i = GAUDI2_QUEUE_ID_PDMA_0_0 ; i < GAUDI2_QUEUE_ID_CPU_PQ; i++) {
-		if (!gaudi2_is_queue_enabled(hdev, i))
+		if (!gaudi2_is_queue_enabled(hdev, i) || gaudi2_is_edma_queue_id(i))
 			continue;
 
 		rc = gaudi2_test_queue_wait_completion(hdev, i, sob_val);
@@ -8988,7 +9061,6 @@  static void gaudi2_handle_page_error(struct hl_device *hdev, u64 mmu_base, bool
 	if (is_pmmu) {
 		dev_err_ratelimited(hdev->dev, "PMMU page fault on va 0x%llx\n", addr);
 	} else {
-
 		addr = gaudi2_mmu_descramble_addr(hdev, addr);
 		addr &= HW_UNSCRAMBLED_BITS_MASK;
 		dev_err_ratelimited(hdev->dev, "HMMU page fault on va range 0x%llx - 0x%llx\n",
@@ -10255,11 +10327,11 @@  static void gaudi2_handle_eqe(struct hl_device *hdev, struct hl_eq_entry *eq_ent
 }
 
 static int gaudi2_memset_memory_chunk_using_edma_qm(struct hl_device *hdev,
-			struct packet_lin_dma *lin_dma_pkt, dma_addr_t pkt_dma_addr,
-			u32 hw_queue_id, u32 size, u64 addr, u32 val)
+			struct packet_lin_dma *lin_dma_pkt,
+			u64 phys_addr, u32 hw_queue_id, u32 size, u64 addr, u32 val)
 {
 	u32 ctl, pkt_size;
-	int rc = 0;
+	int rc = 0, i;
 
 	ctl = FIELD_PREP(GAUDI2_PKT_CTL_OPCODE_MASK, PACKET_LIN_DMA);
 	ctl |= FIELD_PREP(GAUDI2_PKT_LIN_DMA_CTL_MEMSET_MASK, 1);
@@ -10273,7 +10345,12 @@  static int gaudi2_memset_memory_chunk_using_edma_qm(struct hl_device *hdev,
 
 	pkt_size = sizeof(struct packet_lin_dma);
 
-	rc = hl_hw_queue_send_cb_no_cmpl(hdev, hw_queue_id, pkt_size, pkt_dma_addr);
+	for (i = 0; i < 3; i++)
+		rc = hdev->asic_funcs->access_dev_mem(hdev, PCI_REGION_DRAM,
+				phys_addr + (i * sizeof(u64)),
+				((u64 *)(lin_dma_pkt)) + i, DEBUGFS_WRITE64);
+
+	rc = hl_hw_queue_send_cb_no_cmpl(hdev, hw_queue_id, pkt_size, phys_addr);
 	if (rc)
 		dev_err(hdev->dev, "Failed to send lin dma packet to H/W queue %d\n",
 				hw_queue_id);
@@ -10288,12 +10365,11 @@  static int gaudi2_memset_device_memory(struct hl_device *hdev, u64 addr, u64 siz
 					GAUDI2_QUEUE_ID_DCORE2_EDMA_0_0,
 					GAUDI2_QUEUE_ID_DCORE3_EDMA_0_0};
 	u32 chunk_size, dcore, edma_idx, sob_offset, sob_addr, comp_val,
-		old_mmubp, mmubp, num_of_pkts, busy, pkt_size;
+		old_mmubp, mmubp, num_of_pkts, busy, pkt_size, cb_len;
 	u64 comp_addr, cur_addr = addr, end_addr = addr + size;
 	struct asic_fixed_properties *prop = &hdev->asic_prop;
+	int rc = 0, dma_num = 0, i;
 	void *lin_dma_pkts_arr;
-	dma_addr_t pkt_dma_addr;
-	int rc = 0, dma_num = 0;
 
 	if (prop->edma_enabled_mask == 0) {
 		dev_info(hdev->dev, "non of the EDMA engines is enabled - skip dram scrubbing\n");
@@ -10311,9 +10387,19 @@  static int gaudi2_memset_device_memory(struct hl_device *hdev, u64 addr, u64 siz
 	/* Calculate how many lin dma pkts we'll need */
 	num_of_pkts = div64_u64(round_up(size, SZ_2G), SZ_2G);
 	pkt_size = sizeof(struct packet_lin_dma);
+	cb_len = pkt_size * num_of_pkts;
+
+	/*
+	 * If we are not scrubbing the HMMU or NIC reserved sections in HBM,
+	 * then this is the scrubbing of the user section. As we use the start of the user
+	 * section to store the CB of the EDMA QM, shift the start address of the scrubbing
+	 * accordingly and scrub the CB section before leaving this function.
+	 */
+	if ((addr >= prop->dram_user_base_address) &&
+				(addr < prop->dram_user_base_address + cb_len))
+		cur_addr += (prop->dram_user_base_address + cb_len) - addr;
 
-	lin_dma_pkts_arr = hl_asic_dma_alloc_coherent(hdev, pkt_size * num_of_pkts,
-					&pkt_dma_addr, GFP_KERNEL);
+	lin_dma_pkts_arr = kvcalloc(num_of_pkts, pkt_size, GFP_KERNEL);
 	if (!lin_dma_pkts_arr)
 		return -ENOMEM;
 
@@ -10359,7 +10445,7 @@  static int gaudi2_memset_device_memory(struct hl_device *hdev, u64 addr, u64 siz
 
 				rc = gaudi2_memset_memory_chunk_using_edma_qm(hdev,
 					(struct packet_lin_dma *)lin_dma_pkts_arr + dma_num,
-					pkt_dma_addr + dma_num * pkt_size,
+					prop->dram_user_base_address + (dma_num * pkt_size),
 					edma_queues_id[dcore] + edma_idx * 4,
 					chunk_size, cur_addr, val);
 				if (rc)
@@ -10368,14 +10454,16 @@  static int gaudi2_memset_device_memory(struct hl_device *hdev, u64 addr, u64 siz
 				dma_num++;
 				cur_addr += chunk_size;
 				if (cur_addr == end_addr)
-					break;
+					goto edma_wait;
 			}
 		}
 	}
 
+edma_wait:
 	rc = hl_poll_timeout(hdev, sob_addr, busy, (busy == dma_num), 1000, 1000000);
 	if (rc) {
-		dev_err(hdev->dev, "DMA Timeout during HBM scrubbing\n");
+		dev_err(hdev->dev, "DMA Timeout during HBM scrubbing (sob: 0x%x, dma_num: 0x%x)\n",
+						busy, dma_num);
 		goto end;
 	}
 end:
@@ -10396,8 +10484,16 @@  static int gaudi2_memset_device_memory(struct hl_device *hdev, u64 addr, u64 siz
 		}
 	}
 
+	memset(lin_dma_pkts_arr, 0, sizeof(u64));
+
+	/* Zero the HBM area where we copied the CB */
+	for (i = 0; i < cb_len / sizeof(u64); i += sizeof(u64))
+		rc = hdev->asic_funcs->access_dev_mem(hdev, PCI_REGION_DRAM,
+			prop->dram_user_base_address + i,
+				(u64 *)(lin_dma_pkts_arr), DEBUGFS_WRITE64);
 	WREG32(sob_addr, 0);
-	hl_asic_dma_free_coherent(hdev, pkt_size * num_of_pkts, lin_dma_pkts_arr, pkt_dma_addr);
+
+	kfree(lin_dma_pkts_arr);
 
 	return rc;
 }
@@ -11455,7 +11551,7 @@  static int gaudi2_mmu_get_real_page_size(struct hl_device *hdev, struct hl_mmu_p
 	return 0;
 
 page_size_err:
-	dev_err(hdev->dev, "page size of %u is not %uKB aligned, can't map\n",
+	dev_err(hdev->dev, "page size of 0x%X is not 0x%X aligned, can't map\n",
 							page_size, mmu_prop->page_size >> 10);
 	return -EFAULT;
 }
@@ -11475,6 +11571,29 @@  int gaudi2_send_device_activity(struct hl_device *hdev, bool open)
 	return hl_fw_send_device_activity(hdev, open);
 }
 
+static u64 gaudi2_read_pte(struct hl_device *hdev, u64 addr)
+{
+	struct gaudi2_device *gaudi2 = hdev->asic_specific;
+	u64 val;
+
+	if (hdev->reset_info.hard_reset_pending)
+		return U64_MAX;
+
+	val = readq(hdev->pcie_bar[DRAM_BAR_ID] + (addr - gaudi2->dram_bar_cur_addr));
+
+	return val;
+}
+
+static void gaudi2_write_pte(struct hl_device *hdev, u64 addr, u64 val)
+{
+	struct gaudi2_device *gaudi2 = hdev->asic_specific;
+
+	if (hdev->reset_info.hard_reset_pending)
+		return;
+
+	writeq(val, hdev->pcie_bar[DRAM_BAR_ID] + (addr - gaudi2->dram_bar_cur_addr));
+}
+
 static const struct hl_asic_funcs gaudi2_funcs = {
 	.early_init = gaudi2_early_init,
 	.early_fini = gaudi2_early_fini,
@@ -11511,8 +11630,8 @@  static const struct hl_asic_funcs gaudi2_funcs = {
 	.add_device_attr = gaudi2_add_device_attr,
 	.handle_eqe = gaudi2_handle_eqe,
 	.get_events_stat = gaudi2_get_events_stat,
-	.read_pte = NULL,
-	.write_pte = NULL,
+	.read_pte = gaudi2_read_pte,
+	.write_pte = gaudi2_write_pte,
 	.mmu_invalidate_cache = gaudi2_mmu_invalidate_cache,
 	.mmu_invalidate_cache_range = gaudi2_mmu_invalidate_cache_range,
 	.mmu_prefetch_cache_range = NULL,
diff --git a/drivers/accel/habanalabs/gaudi2/gaudi2P.h b/drivers/accel/habanalabs/gaudi2/gaudi2P.h
index 9b9eef0d97d6..bc508c9cee5c 100644
--- a/drivers/accel/habanalabs/gaudi2/gaudi2P.h
+++ b/drivers/accel/habanalabs/gaudi2/gaudi2P.h
@@ -19,8 +19,6 @@ 
 #define GAUDI2_LINUX_FW_FILE	"habanalabs/gaudi2/gaudi2-fit.itb"
 #define GAUDI2_BOOT_FIT_FILE	"habanalabs/gaudi2/gaudi2-boot-fit.itb"
 
-#define MMU_PAGE_TABLES_INITIAL_SIZE	0x10000000	/* 256MB */
-
 #define GAUDI2_CPU_TIMEOUT_USEC		30000000	/* 30s */
 
 #define NUMBER_OF_PDMA_QUEUES		2
@@ -109,13 +107,11 @@ 
 /* DRAM Memory Map */
 
 #define CPU_FW_IMAGE_SIZE			0x10000000	/* 256MB */
-
-/* This define should be used only when working in a debug mode without dram.
- * When working with dram, the driver size will be calculated dynamically.
- */
-#define NIC_DEFAULT_DRV_SIZE			0x20000000	/* 512MB */
-
 #define CPU_FW_IMAGE_ADDR			DRAM_PHYS_BASE
+#define PMMU_PAGE_TABLES_SIZE			0x10000000      /* 256MB */
+#define EDMA_PQS_SIZE				SZ_2M
+#define EDMA_SCRATCHPAD_SIZE			SZ_1M
+#define HMMU_PAGE_TABLES_SIZE			SZ_1M
 
 #define NIC_NUMBER_OF_PORTS			NIC_NUMBER_OF_ENGINES
 
diff --git a/drivers/accel/habanalabs/include/hw_ip/mmu/mmu_general.h b/drivers/accel/habanalabs/include/hw_ip/mmu/mmu_general.h
index d408feecd483..b4a5e95be354 100644
--- a/drivers/accel/habanalabs/include/hw_ip/mmu/mmu_general.h
+++ b/drivers/accel/habanalabs/include/hw_ip/mmu/mmu_general.h
@@ -26,6 +26,8 @@ 
 #define LAST_MASK			0x0000000000800ull
 #define FLAGS_MASK			0x0000000000FFFull
 
+#define MMU_ARCH_3_HOPS			3
+#define MMU_ARCH_4_HOPS			4
 #define MMU_ARCH_5_HOPS			5
 #define MMU_ARCH_6_HOPS			6