From 7d8251d66250773628ccfe43c5d596b24abbfdff Mon Sep 17 00:00:00 2001 From: hsj <2383755847@qq.com> Date: Sun, 31 May 2026 23:31:20 +0800 Subject: [PATCH 1/2] feat(cute): support fine-grained matrix loads --- src/main/scala/ABMatrixReg.scala | 55 +++++---- src/main/scala/AMemoryLoader.scala | 61 ++++++++-- src/main/scala/BMemoryLoader.scala | 140 ++++++++++++---------- src/main/scala/CDataController.scala | 3 + src/main/scala/CMatrixReg.scala | 49 ++++++-- src/main/scala/CMemoryLoader.scala | 75 +++++++++--- src/main/scala/CUTEParameters.scala | 17 ++- src/main/scala/CUTETOP.scala | 11 ++ src/main/scala/TaskController.scala | 170 +++++++++++++++++++++------ 9 files changed, 432 insertions(+), 149 deletions(-) diff --git a/src/main/scala/ABMatrixReg.scala b/src/main/scala/ABMatrixReg.scala index 340ceb9..2dba0b6 100644 --- a/src/main/scala/ABMatrixReg.scala +++ b/src/main/scala/ABMatrixReg.scala @@ -33,7 +33,8 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // MemoryLoader端的信号 val MemoryLoaderBankAddr = io.MatrixRegIO.FromMemoryLoader.BankAddr val MemoryLoaderData = io.MatrixRegIO.FromMemoryLoader.Data - + // 【修改 :提取掩码信号】 + val MemoryLoaderByteMask = io.MatrixRegIO.FromMemoryLoader.ByteMask // 写优先的MatrixReg控制逻辑 // write_go: 只要有写入请求(正常写或零填充)就为true val write_go = MemoryLoaderBankAddr.zip(MemoryLoaderData).map{case (a, b) => a.valid && b.valid}.reduce(_||_) @@ -49,17 +50,23 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // SRAM下一拍返回结果,所以使用上一拍的ready作为valid io.MatrixRegIO.FromDataController.Data.valid := RegNext(read_go) val debug_s1_bank_addr = RegNext(DataControllerBankAddr) + + when(io.MatrixRegIO.FromDataController.BankAddr.fire) { + if (YJPAMLDebugEnable || YJPBMLDebugEnable) { + printf("[ABMatrixReg_ReadReq(%d)] addr0=%d\n", scp_id.U, io.MatrixRegIO.FromDataController.BankAddr.bits(0)) + } + } // 实例化多个SRAM作为多个bank val sram_banks = (0 until ABMatrixRegNBanks) map { i => - // 使用SRAMTemplate替代SyncReadMem - // singlePort=true: 单端口SRAM,支持读写冲突处理 - // latency=1: 读延迟为1拍 + // 【修改 :重构 SRAMTemplate 物理映射语义】 + // 将原本宽字长、单 way 的 SRAM,转变为 1 Byte 为颗粒度、多 way 的 SRAM。 + // 综合工具(DC/Genus)会将其自动识别为:带有 Byte Write Enable (BWEB) 引脚的单个宏单元,或由标准单元组成的寄存器堆,绝不会产生碎片化拥塞。 val bank = Module(new SRAMTemplate( - gen = UInt((ABMatrixRegEntryByteSize*8).W), - set = ABMatrixRegBankNEntrys, - way = 1, + gen = UInt(8.W), // 核心修改:基础数据单元改为 1 Byte + set = ABMatrixRegBankNEntrys, // 深度保持不变 + way = ABMatrixRegEntryByteSize, // 核心修改:相联度(way)数量等于掩码(Byte)数量 singlePort = true, latency = 1, hasMbist = false, @@ -75,10 +82,8 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ when(RegNext(read_go)) { - // 输出读的信息 - if (YJPDebugEnable) - { - printf("[ABMatrixReg_Read(%d)]Bank(%d): debug_s1_bank_addr = %d, s1_bank_read_data = %x\n", + if (YJPAMLDebugEnable) { + printf("[ABMatrixReg_ReadResp(%d)]Bank(%d): debug_s1_bank_addr = %d, s1_bank_read_data = %x\n", scp_id.U, i.U, debug_s1_bank_addr(0), s1_bank_read_data) } } @@ -89,18 +94,22 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // 写数据逻辑 val s0_bank_write_addr = MemoryLoaderBankAddr(i).bits val s0_bank_write_data = MemoryLoaderData(i).bits - val s0_bank_write_valid = MemoryLoaderBankAddr(i).valid && MemoryLoaderData(i).valid + + // 【修改 :提取单 Bank 掩码】 + val s0_bank_write_mask = MemoryLoaderByteMask(i).bits + // 写握手必须同时满足 addr, data, mask 皆有效 + val s0_bank_write_valid = MemoryLoaderBankAddr(i).valid && MemoryLoaderData(i).valid && MemoryLoaderByteMask(i).valid // 最终的写入控制 val s0_final_write_valid = write_go && s0_bank_write_valid val s0_final_write_addr = MemoryLoaderBankAddr(i).bits val s0_final_write_data = MemoryLoaderData(i).bits - when(write_go && s0_bank_write_valid){ - if (YJPDebugEnable) - { - printf("[ABMatrixReg_Write(%d)]Bank(%d): s0_bank_write_addr = %d, s0_bank_write_data = %x\n", - scp_id.U, i.U, s0_bank_write_addr, s0_bank_write_data) + + when(s0_final_write_valid) { + if (YJPAMLDebugEnable) { + printf("[ABMatrixReg_Write(%d)]Bank(%d): s0_bank_write_addr = %d, s0_bank_write_data = %x, mask = %b\n", + scp_id.U, i.U, s0_bank_write_addr, s0_final_write_data, s0_bank_write_mask) } } @@ -110,13 +119,19 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ bank.io.r.req.valid := bank_read_valid bank.io.r.req.bits.setIdx := s0_bank_read_addr // 读响应在下一拍返回(latency=1) - s1_bank_read_data := bank.io.r.resp.data(0) + // 【修改 :零开销数据拼装】 + // bank.io.r.resp.data 此时是一个 Vec(way, UInt(8.W)) + // 使用 .asUInt 将 Vec 无缝强制转换为大位宽 UInt,不仅代码整洁,而且在综合时完全是一根线(Wire),没有任何面积和时序延迟开销。 + s1_bank_read_data := bank.io.r.resp.data.asUInt // 连接SRAMTemplate的写接口 - // 使用apply方法设置写请求,way=1时waymask为None bank.io.w.req.valid := s0_final_write_valid bank.io.w.req.bits.setIdx := s0_final_write_addr - bank.io.w.req.bits.data(0) := s0_final_write_data + bank.io.w.req.bits.waymask.get := s0_bank_write_mask + + // 使用 .asTypeOf(Vec) 将大宽度的 s0_final_write_data (如 256.W) 直接解包为等宽的 Vec(32, UInt(8.W))。 + // 这取代了臃肿的 for 循环位截取,对后端极度友好,综合后就是干净的连线(Assign)。 + bank.io.w.req.bits.data := s0_final_write_data.asTypeOf(Vec(ABMatrixRegEntryByteSize, UInt(8.W))) bank } diff --git a/src/main/scala/AMemoryLoader.scala b/src/main/scala/AMemoryLoader.scala index 42bd017..e0d1f7c 100644 --- a/src/main/scala/AMemoryLoader.scala +++ b/src/main/scala/AMemoryLoader.scala @@ -12,6 +12,7 @@ import org.chipsalliance.cde.config._ class ASourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId = UInt(log2Ceil(ABMatrixRegNBanks).W) val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntrys).W) + val MatrixRegisTail = Bool() } class AMemoryLoader(implicit p: Parameters) extends CuteModule{ @@ -28,6 +29,8 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.Data.map(_.valid := false.B) io.ToMatrixRegIO.Data.map(_.bits := DontCare) + io.ToMatrixRegIO.ByteMask.map(_.valid := false.B) + io.ToMatrixRegIO.ByteMask.map(_.bits := Fill(ABMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid io.LocalMMUIO.Response.ready := false.B @@ -53,6 +56,7 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until ABMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.BankAddr(i).bits + difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.ByteMask(i).bits difftestAmuFinish.data(i * 4 + 0) := io.ToMatrixRegIO.Data(i).bits(63,0) difftestAmuFinish.data(i * 4 + 1) := io.ToMatrixRegIO.Data(i).bits(127,64) difftestAmuFinish.data(i * 4 + 2) := io.ToMatrixRegIO.Data(i).bits(191,128) @@ -65,6 +69,9 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val Tensor_Block_BaseAddr = Reg(UInt(MMUAddrWidth.W)) val ApplicationTensor_A_Stride_M = RegInit(0.U(MMUAddrWidth.W)) val dataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) + val HasTail = RegInit(false.B) + val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) + val K_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_M = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val Conherent = RegInit(true.B) @@ -89,6 +96,7 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val MReg_Fill_Table = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntrys).W))))) val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/ABMatrixRegEntryByteSize)+1).W))))) + val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U) val MReg_Fill_Table_Insert_Index = PriorityEncoder(MReg_Fill_Table_Free) val MReg_Fill_Table_Not_Full = MReg_Fill_Table_Free.reduce(_ || _) @@ -121,6 +129,9 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ Tensor_Block_BaseAddr := ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr ApplicationTensor_A_Stride_M := ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_Stride_M dataType := ConfigInfo.ApplicationTensor_A.dataType + HasTail := ConfigInfo.ApplicationTensor_A.HasTail + TailByteMask := ConfigInfo.ApplicationTensor_A.TailByteMask + K_Beat_Count := ConfigInfo.ApplicationTensor_A.K_Beat_Count Is_ZeroLoad := ConfigInfo.LoadTaskInfo.Is_ZeroLoad Is_FullLoad := ConfigInfo.LoadTaskInfo.Is_FullLoad Conherent := ConfigInfo.Conherent @@ -141,13 +152,14 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ CurrentLoaded_BlockTensor_M_Iter := 0.U CurrentLoaded_BlockTensor_K_Iter := 0.U Request_M_Iter_Time := 0.U - MaxRequestIter := MatrixRegTensor_M * MatrixRegTensor_K * ReduceWidthByte.U / outsideDataWidthByte.U + MaxRequestIter := MatrixRegTensor_M * K_Beat_Count Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) MReg_Fill_Table := 0.U.asTypeOf(MReg_Fill_Table) MReg_Fill_Table_MReg_Addr := 0.U.asTypeOf(MReg_Fill_Table_MReg_Addr) MReg_Fill_Table_Time := 0.U.asTypeOf(MReg_Fill_Table_Time) + MReg_Fill_Table_IsTail := VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(false.B)) } is(s_load_working) { io.ToMatrixRegIO.active := true.B @@ -161,6 +173,8 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.BankAddr(i).valid := true.B io.ToMatrixRegIO.Data(i).bits := 0.U io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Fill(ABMatrixRegEntryByteSize, true.B) + io.ToMatrixRegIO.ByteMask(i).valid := true.B } TotalLoadSize := TotalLoadSize + 1.U if (YJPAMLDebugEnable) printf("[AML<%d>]ZeroLoad, TotalLoadSize: %d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize) @@ -171,11 +185,16 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ } when(Is_FullLoad){ - // 矩阵访存顺序:按 M 分 bank 交织,再扫 K。地址 = BaseAddr + M*Stride_M + K*ReduceWidthByte + //先转换成独热码然后进行减一即可计算出掩码 + val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val fullTaskMask = Fill(outsideDataWidthByte, true.B) + val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U)) + // 矩阵访存顺序:按 M 分 bank 交织,再扫 K。地址 = BaseAddr + M*Stride_M + K*64B val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) % ABMatrixRegNBanks.U - val RequestMatrixRegAddr = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) / ABMatrixRegNBanks.U * ReduceGroupSize.U + CurrentLoaded_BlockTensor_K_Iter + val RequestMatrixRegBaseAddr = (((CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + val RequestMatrixRegAddr = RequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) - Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_A_Stride_M + CurrentLoaded_BlockTensor_K_Iter * ReduceWidthByte.U + Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_A_Stride_M + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) val sourceId = Mux(Conherent, io.LocalMMUIO.ConherentRequsetSourceID, io.LocalMMUIO.nonConherentRequsetSourceID) Request.bits.RequestConherent := Conherent Request.bits.RequestSourceID := sourceId.bits @@ -186,13 +205,20 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val TableItem = Wire(new ASourceIdSearch) TableItem.MatrixRegBankId := RequestMatrixRegBankId TableItem.MatrixRegAddr := RequestMatrixRegAddr + TableItem.MatrixRegisTail := RequestBeatIsTail SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt + if (YJPAMLDebugEnable) { + printf("[AML_RequestHandshake<%d>] M_Iter:%d, K_Iter:%d, ReqTime:%d, Addr:%x, BankId:%d, RegAddr:%d, SourceId:%d, Tail:%d\n", + io.DebugInfo.DebugTimeStampe, CurrentLoaded_BlockTensor_M_Iter, CurrentLoaded_BlockTensor_K_Iter, + Request_M_Iter_Time, Request.bits.RequestVirtualAddr, RequestMatrixRegBankId, RequestMatrixRegAddr, sourceId.bits, RequestBeatIsTail) + } + Request_M_Iter_Time := Request_M_Iter_Time + 1.U when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ Request_M_Iter_Time := 0.U - CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + (outsideDataWidthByte.U / ReduceWidthByte.U) - when(CurrentLoaded_BlockTensor_K_Iter + (outsideDataWidthByte.U / ReduceWidthByte.U) === MatrixRegTensor_K){ + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ CurrentLoaded_BlockTensor_K_Iter := 0.U CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U } @@ -217,14 +243,20 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ when(io.LocalMMUIO.Response.fire){ val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID - val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new ASourceIdSearch).MatrixRegBankId - val MatrixRegAddr = SoureceIdSearchTable(sourceId).asTypeOf(new ASourceIdSearch).MatrixRegAddr + val MatrixRegSearch = SoureceIdSearchTable(sourceId).asTypeOf(new ASourceIdSearch) + val MatrixRegBankId = MatrixRegSearch.MatrixRegBankId + val MatrixRegAddr = MatrixRegSearch.MatrixRegAddr val ResponseData = io.LocalMMUIO.Response.bits.ReseponseData val FIFOIndex = Bank_Fill_Search_FIFO_Head(MatrixRegBankId) + if (YJPAMLDebugEnable) { + printf("[AML_ResponseHandshake<%d>] Data:%x, BankId:%d, RegAddr:%d, SourceId:%d, FIFOIndex:%d, Tail:%d\n", io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, sourceId, FIFOIndex, MatrixRegSearch.MatrixRegisTail) + } + MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := MatrixRegSearch.MatrixRegisTail Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), AMemoryLoaderReadFromMemoryFIFODepth) if (YJPAMLDebugEnable){ @@ -236,13 +268,22 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until ABMatrixRegNBanks){ when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) + val fillLowHalf = fillSlot(0) === 0.U + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) Current_Fill_MReg_Time(i) := 1.U val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - io.ToMatrixRegIO.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + (MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + io.ToMatrixRegIO.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot io.ToMatrixRegIO.BankAddr(i).valid := true.B - io.ToMatrixRegIO.Data(i).bits := FIFOData(MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + io.ToMatrixRegIO.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) + io.ToMatrixRegIO.ByteMask(i).valid := true.B + if (YJPAMLDebugEnable) { + printf("[AML_MRegWriteHandshake<%d>] bank:%d, RegAddr:%x, WriteAddr:%x, Data:%x, ByteMask:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), io.ToMatrixRegIO.BankAddr(i).bits, io.ToMatrixRegIO.Data(i).bits, io.ToMatrixRegIO.ByteMask(i).bits, MReg_Fill_Table_Time(CurrentFIFOIndex)) + } MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), AMemoryLoaderReadFromMemoryFIFODepth) diff --git a/src/main/scala/BMemoryLoader.scala b/src/main/scala/BMemoryLoader.scala index b7c2df0..22aaf3b 100644 --- a/src/main/scala/BMemoryLoader.scala +++ b/src/main/scala/BMemoryLoader.scala @@ -21,6 +21,7 @@ import org.chipsalliance.cde.config._ class BSourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId = UInt(log2Ceil(ABMatrixRegNBanks).W) val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntrys).W) + val MatrixRegisTail = Bool() } //对于卷积,数据摆放是[khkwoc][ic],对于矩阵乘,数据摆放是[N][K] @@ -41,6 +42,8 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.Data.map(_.valid := false.B) io.ToMatrixRegIO.Data.map(_.bits := DontCare) + io.ToMatrixRegIO.ByteMask.map(_.valid := false.B) + io.ToMatrixRegIO.ByteMask.map(_.bits := Fill(ABMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid io.LocalMMUIO.Response.ready := false.B @@ -62,6 +65,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until ABMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.BankAddr(i).bits + difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.ByteMask(i).bits difftestAmuFinish.data(i * 4 + 0) := io.ToMatrixRegIO.Data(i).bits(63,0) difftestAmuFinish.data(i * 4 + 1) := io.ToMatrixRegIO.Data(i).bits(127,64) difftestAmuFinish.data(i * 4 + 2) := io.ToMatrixRegIO.Data(i).bits(191,128) @@ -80,9 +84,11 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ val MatrixRegTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - + val HasTail = RegInit(false.B) + val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) + val K_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val dataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) val Tensor_B_BaseVaddr = RegInit(0.U(MMUAddrWidth.W)) - val ApplicationTensor_B_Stride_N = RegInit(0.U(MMUAddrWidth.W)) @@ -115,6 +121,10 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ Tensor_B_BaseVaddr := io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr //这个不重要 Tensor_Block_BaseAddr := io.ConfigInfo.ApplicationTensor_B.BlockTensor_B_BaseVaddr //这个是关键 Conherent := io.ConfigInfo.Conherent + HasTail := io.ConfigInfo.ApplicationTensor_B.HasTail + TailByteMask := io.ConfigInfo.ApplicationTensor_B.TailByteMask + K_Beat_Count := io.ConfigInfo.ApplicationTensor_B.K_Beat_Count + dataType := io.ConfigInfo.ApplicationTensor_B.dataType ApplicationTensor_B_Stride_N := io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_Stride_N //下一个N,需要增加多少地址偏移量 if(YJPBMLDebugEnable) { @@ -149,11 +159,11 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //如果是memoryload_state === s_load_working,那么我们就要开始取数 //如果是memoryload_state === s_load_end,那么我们就要结束取数 val TotalLoadSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)+1).W)) //总共要加载的数据量 - val CurrentLoaded_BlockTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val CurrentLoaded_BlockTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val TotalRequestSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) + val CurrentLoaded_BlockTensor_N_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val CurrentLoaded_BlockTensor_K_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val Request_N_Iter_Time = RegInit(0.U(log2Ceil(Matrix_MN).W)) - val MaxBlockTensor_N_Index = MatrixRegTensor_N - val MaxBlockTensor_K_Index = MatrixRegTensor_K //一个cam来存储访存请求的source_id对应的MatrixReg的地址和bank号 //用sourceid做索引,存储MatrixReg的地址和bank号,是一组寄存器 @@ -165,6 +175,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ val MReg_Fill_Table = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntrys).W)))))//记录这个LLC回的数是在scp的哪个地址 val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/ABMatrixRegEntryByteSize)+1).W)))))//记录这个LLC回的数需要回填的次数,完成就可以将数据释放了 + val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U)//记录这个FIFO能否能填数据 val MReg_Fill_Table_Valid = MReg_Fill_Table_Time.map(_ =/= 0.U)//记录这个FIFO里的数据是否有效 val MReg_Fill_Table_Insert_Index = PriorityEncoder(MReg_Fill_Table_Free)//返回第一个空位的index @@ -190,69 +201,61 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ is(s_load_init) { memoryload_state := s_load_working TotalLoadSize := 0.U - CurrentLoaded_BlockTensor_N := 0.U - CurrentLoaded_BlockTensor_K := 0.U - MaxRequestIter := MatrixRegTensor_K * MatrixRegTensor_N * ReduceWidthByte.U / (outsideDataWidthByte.U) //总共要发出的访存请求的次数 + TotalRequestSize := 0.U + CurrentLoaded_BlockTensor_N_Iter := 0.U + CurrentLoaded_BlockTensor_K_Iter := 0.U + Request_N_Iter_Time := 0.U + MaxRequestIter := MatrixRegTensor_N * K_Beat_Count //总共要发出的访存请求的次数 + Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) + Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) + Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) + MReg_Fill_Table := 0.U.asTypeOf(MReg_Fill_Table) + MReg_Fill_Table_MReg_Addr := 0.U.asTypeOf(MReg_Fill_Table_MReg_Addr) + MReg_Fill_Table_Time := 0.U.asTypeOf(MReg_Fill_Table_Time) + MReg_Fill_Table_IsTail := VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(false.B)) } is(s_load_working) { io.ToMatrixRegIO.active := true.B //根据不同的MemoryOrder,执行不同的访存模式 - //只要Request是ready,我们发出的访存请求就会被MMU送往总线,我们可以发出下一个访存请求 - //不用担心乘法电路延迟,再不济,可以提前几个周期将乘法结果算好,做成fifo送进来 - Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_N * ApplicationTensor_B_Stride_N) + (CurrentLoaded_BlockTensor_K * ReduceWidthByte.U) + //先转换成独热码然后进行减一即可计算出掩码 + val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U)) + // 访存顺序与AML保持一致:先沿N维分4个bank发射,再推进K维,最后推进下一组N block + val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) % ABMatrixRegNBanks.U + val RequestMatrixRegBaseAddr = (((CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + val RequestMatrixRegAddr = RequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) + + Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) * ApplicationTensor_B_Stride_N + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) val sourceId = Mux(Conherent,io.LocalMMUIO.ConherentRequsetSourceID,io.LocalMMUIO.nonConherentRequsetSourceID) Request.bits.RequestConherent := Conherent Request.bits.RequestSourceID := sourceId.bits Request.bits.RequestType_isWrite := false.B - Request.valid := true.B - when(CurrentLoaded_BlockTensor_N === MaxBlockTensor_N_Index || CurrentLoaded_BlockTensor_K === MaxBlockTensor_K_Index)//Is_invalid_IH_IW时,不发出访存请求,尝试直接0填充 - { - Request.valid := false.B - } - - //数据在MatrixReg中的编排 - //数据会先排K,再排M - //AVector一定是不同M的数据,K不断送入,直到K迭代完成,再换新的M, - // K 0 1 2 3 4 5 6 7 time AVector MatrixRegData也这么排布 - // M 0 0 8 g o {bank[0] [1] [2] [3]} - // 0 0 1 2 3 4 5 6 7 1 1 9 h p |addr 0 | 0 8 g o - // 1 8 9 a b c d e f 2 2 a i q | 1 | 1 9 h p - // 2 g h i j k l m n 3 3 b j r | 2 | 2 a i q - // 3 o p g r s t u v 4 4 c k s | 3 | 3 b j r - // 4 w x y z ....... 5 5 d l t | 4 | 4 c k s - // 5 !.............. 6 6 e m u | 5 | 5 d l t - // 6 @.............. 7 7 f n v | 6 | 6 e m u - // 7 #.............. 8 w ! @ # | 7 | 7 f n v - // 8 $.............. 9 ....... | ........................... - // - // - // 在内存中的排布则是 0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z ....... + Request.valid := (TotalRequestSize < MaxRequestIter) - when(Request.fire && sourceId.valid){//符合条件的话,这条访存请求一定会被发出 //Request.ready表明了LocalMMU会处理这条访存请求,sourceID valid,表明这条访存请求的sourceID是被LocalMMU认可有效才发送到这个模块的 val TableItem = Wire(new BSourceIdSearch) - TableItem.MatrixRegBankId := CurrentLoaded_BlockTensor_N % ABMatrixRegNBanks.U - TableItem.MatrixRegAddr := ((CurrentLoaded_BlockTensor_N / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + CurrentLoaded_BlockTensor_K + TableItem.MatrixRegBankId := RequestMatrixRegBankId + TableItem.MatrixRegAddr := RequestMatrixRegAddr + TableItem.MatrixRegisTail := RequestBeatIsTail SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt - if (YJPBMLDebugEnable) - { - //输出id和request的信息 - printf("[BML<%d>]sourceId:%d,MatrixRegBankId:%d,MatrixRegAddr:%d\n",io.DebugInfo.DebugTimeStampe,sourceId.bits,TableItem.MatrixRegBankId,TableItem.MatrixRegAddr) - //输出这次request的信息 - printf("[BML<%d>]RequestVirtualAddr:%x,RequestConherent:%d,RequestSourceID:%d,RequestType_isWrite:%d\n",io.DebugInfo.DebugTimeStampe,Request.bits.RequestVirtualAddr,Request.bits.RequestConherent,Request.bits.RequestSourceID,Request.bits.RequestType_isWrite) + if (YJPBMLDebugEnable) { + printf("[BML_RequestHandshake<%d>] sourceId:%d, MatrixRegBankId:%d, MatrixRegAddr:%d, RequestVirtualAddr:%x, RequestConherent:%d, RequestType_isWrite:%d, Tail:%d\n",io.DebugInfo.DebugTimeStampe,sourceId.bits,TableItem.MatrixRegBankId,TableItem.MatrixRegAddr,Request.bits.RequestVirtualAddr,Request.bits.RequestConherent,Request.bits.RequestType_isWrite,RequestBeatIsTail) } - when(CurrentLoaded_BlockTensor_N < MaxBlockTensor_N_Index){ - when(CurrentLoaded_BlockTensor_K + MAX_Fill_Times.U < MaxBlockTensor_K_Index){ - //根据不同的内存Order,计算出访存请求的地址 - CurrentLoaded_BlockTensor_K := CurrentLoaded_BlockTensor_K + MAX_Fill_Times.U - }.otherwise{ - CurrentLoaded_BlockTensor_K := 0.U - CurrentLoaded_BlockTensor_N := CurrentLoaded_BlockTensor_N + 1.U + Request_N_Iter_Time := Request_N_Iter_Time + 1.U + when(Request_N_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) === MatrixRegTensor_N - 1.U){ + Request_N_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + Matrix_MN.U } } + when(TotalRequestSize =/= MaxRequestIter){ + TotalRequestSize := TotalRequestSize + 1.U + } } val current_fill_fifo_full = WireInit(false.B) when(io.LocalMMUIO.Response.valid) @@ -275,11 +278,16 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //Trick注意这个设计,是doublebuffer的,AB只能是doublebuffer,回数一定是不会堵的,而且我们有时间对数据进行压缩解压缩~ //如果要做release设计,要么数据位宽翻倍,腾出周期来使得有空泡能给写任务进行,要么就是数据位宽不变,将读写端口变成独立的读和独立的写端口 val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID - val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch).MatrixRegBankId - val MatrixRegAddr = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch).MatrixRegAddr + val searchEntry = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch) + val MatrixRegBankId = searchEntry.MatrixRegBankId + val MatrixRegAddr = searchEntry.MatrixRegAddr val ResponseData = io.LocalMMUIO.Response.bits.ReseponseData val FIFOIndex = Bank_Fill_Search_FIFO_Head(MatrixRegBankId)//该bank的fill_fifo_index,标注了它当前在fillfifo的哪个位置,我们一共有bank个fill_fifo + if (YJPBMLDebugEnable) { + printf("[BML_ResponseHandshake<%d>] ResponseData:%x, MatrixRegBankId:%d, MatrixRegAddr:%d, SourceId:%d, FIFOIndex:%d, Tail:%d\n",io.DebugInfo.DebugTimeStampe,ResponseData,MatrixRegBankId,MatrixRegAddr,sourceId,FIFOIndex,searchEntry.MatrixRegisTail) + } + if (!ABMLNeedMRegFillTable) { TotalLoadSize := TotalLoadSize + 1.U @@ -288,9 +296,11 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ when(MatrixRegBankId === i.U) { io.ToMatrixRegIO.BankAddr(i).bits := MatrixRegAddr - io.ToMatrixRegIO.Data(i).bits := ResponseData + io.ToMatrixRegIO.Data(i).bits := ResponseData(255, 0) io.ToMatrixRegIO.BankAddr(i).valid := true.B io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Mux(searchEntry.MatrixRegisTail, tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B)) + io.ToMatrixRegIO.ByteMask(i).valid := true.B } } } @@ -298,6 +308,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := searchEntry.MatrixRegisTail Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), BMemoryLoaderReadFromMemoryFIFODepth) @@ -324,21 +335,30 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until ABMatrixRegNBanks){ when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val fillLowHalf = fillSlot(0) === 0.U + val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) Current_Fill_MReg_Time(i) := 1.U val MatrixRegWriteRequest = io.ToMatrixRegIO val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + (MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot MatrixRegWriteRequest.BankAddr(i).valid := true.B - MatrixRegWriteRequest.Data(i).bits := FIFOData(MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + MatrixRegWriteRequest.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) MatrixRegWriteRequest.Data(i).valid := true.B + MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) + MatrixRegWriteRequest.ByteMask(i).valid := true.B + if (YJPBMLDebugEnable) { + printf("[BML_MRegWriteHandshake<%d>] bankid: %d, CurrentFIFOIndex: %d, ScartchPadAddr: %x, BankAddr: %x, Data: %x, ByteMask: %x\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits, MatrixRegWriteRequest.ByteMask(i).bits) + } MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), BMemoryLoaderReadFromMemoryFIFODepth) } - if (YJPCMLDebugEnable) + if (YJPBMLDebugEnable) { //输出fill_time 和 fifoindex printf("[BML BMemoryLoader_Load<%d>]bankid: %d,CurrentFIFOIndex %d,ScartchPadAddr: %x, MReg_Fill_Table_Time(CurrentFIFOIndex): %d\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) @@ -355,7 +375,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ { TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size } - if (YJPCMLDebugEnable) + if (YJPBMLDebugEnable) { when(Current_Load_Fill_Size =/= 0.U) { @@ -365,7 +385,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //状态机切换 when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ memoryload_state := s_load_end - if (YJPCMLDebugEnable) + if (YJPBMLDebugEnable) { printf("[BMemoryLoader_Load<%d>]LoadEnd\n",io.DebugInfo.DebugTimeStampe) } @@ -383,4 +403,4 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ } } } -} \ No newline at end of file +} diff --git a/src/main/scala/CDataController.scala b/src/main/scala/CDataController.scala index d901107..189bcfc 100644 --- a/src/main/scala/CDataController.scala +++ b/src/main/scala/CDataController.scala @@ -36,6 +36,8 @@ class CDataController(implicit p: Parameters) extends CuteModule{ io.FromMatrixRegIO.WriteBankAddr.map(_.bits := DontCare) io.FromMatrixRegIO.WriteRequestData.map(_.valid := false.B) io.FromMatrixRegIO.WriteRequestData.map(_.bits := DontCare) + io.FromMatrixRegIO.WriteRequestByteMask.map(_.valid := false.B) + io.FromMatrixRegIO.WriteRequestByteMask.map(_.bits := Fill(CMatrixRegEntryByteSize, true.B)) io.ConfigInfo.MicroTaskEndValid := false.B io.ConfigInfo.MicroTaskReady := false.B io.ConfigInfo.MicroTask_TEComputeEndValid := false.B @@ -65,6 +67,7 @@ class CDataController(implicit p: Parameters) extends CuteModule{ for (i <- 0 until CMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.FromMatrixRegIO.WriteBankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.FromMatrixRegIO.WriteBankAddr(i).bits + difftestAmuFinish.bankMask(i) := Fill(32, true.B) difftestAmuFinish.data(i * 4 + 0) := io.FromMatrixRegIO.WriteRequestData(i).bits(63,0) difftestAmuFinish.data(i * 4 + 1) := io.FromMatrixRegIO.WriteRequestData(i).bits(127,64) difftestAmuFinish.data(i * 4 + 2) := io.FromMatrixRegIO.WriteRequestData(i).bits(191,128) diff --git a/src/main/scala/CMatrixReg.scala b/src/main/scala/CMatrixReg.scala index e0120b8..891f33d 100644 --- a/src/main/scala/CMatrixReg.scala +++ b/src/main/scala/CMatrixReg.scala @@ -43,6 +43,29 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ io.MatrixRegIO.FromDataController.ReadWriteResponse := io.MatrixRegIO.FromDataController.ReadWriteRequest io.MatrixRegIO.FromMemoryLoader.ReadWriteResponse := io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest + when(io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest(MatrixRegTaskType.ReadFromMemoryLoaderIndex)) { + for (i <- 0 until CMatrixRegNBanks) { + when(io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.BankAddr(i).valid) { + if (YJPCMLDebugEnable) { + printf("[CMatrixReg_CMLReadReq(%d)] bank=%d addr=%x\n", scp_id.U, i.U, io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.BankAddr(i).bits) + } + } + } + } + + when(io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest(MatrixRegTaskType.WriteFromMemoryLoaderIndex)) { + for (i <- 0 until CMatrixRegNBanks) { + when(io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.BankAddr(i).valid && io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).valid) { + if (YJPCMLDebugEnable) { + printf("[CMatrixReg_CMLWriteReq(%d)] bank=%d addr=%x data=%x mask=%x\n", scp_id.U, i.U, + io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.BankAddr(i).bits, + io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).bits, + io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.ByteMask(i).bits) + } + } + } + } + //记录当前拍回数应该返回给哪条数据线 val request = io.MatrixRegIO.FromDataController.ReadWriteRequest | io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest val PreRequest = RegNext(request) @@ -59,6 +82,7 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ val write_request_per_bank_addr = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U(CMatrixRegBankNEntrys.W)))) val write_request_per_bank_data= WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U((8*CMatrixRegEntryByteSize).W)))) + val write_request_per_bank_mask = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(Fill(CMatrixRegEntryByteSize, true.B)))) val write_request_per_bank_valid = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(false.B))) for( i <- 0 until CMatrixRegNBanks) { @@ -68,6 +92,7 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ write_request_per_bank_addr(i) := Mux(decode_request.IsWriteFromDataController, io.MatrixRegIO.FromDataController.WriteBankAddr(i).bits, io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.BankAddr(i).bits) write_request_per_bank_data(i) := Mux(decode_request.IsWriteFromDataController, io.MatrixRegIO.FromDataController.WriteRequestData(i).bits, io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).bits) + write_request_per_bank_mask(i) := Mux(decode_request.IsWriteFromDataController, Fill(CMatrixRegEntryByteSize, true.B), io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.ByteMask(i).bits) write_request_per_bank_valid(i) := Mux(decode_request.IsWriteFromDataController, io.MatrixRegIO.FromDataController.WriteRequestData(i).valid, io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).valid) } @@ -78,9 +103,9 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ // 两个单口SRAM,奇偶地址各自负责,期望奇偶地址读写错开,奇读偶写,偶读奇写 val bankDepthHalf = (CMatrixRegBankNEntrys + 1) / 2 val evenBank = Module(new SRAMTemplate( - gen = UInt((8*CMatrixRegEntryByteSize).W), + gen = UInt(8.W), set = bankDepthHalf, - way = 1, + way = CMatrixRegEntryByteSize, singlePort = true, latency = 1, hasMbist = false, @@ -89,9 +114,9 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ evenBank.suggestName("CUTE-C-MatrixReg-Even-SRAM") val oddBank = Module(new SRAMTemplate( - gen = UInt((8*CMatrixRegEntryByteSize).W), + gen = UInt(8.W), set = bankDepthHalf, - way = 1, + way = CMatrixRegEntryByteSize, singlePort = true, latency = 1, hasMbist = false, @@ -122,6 +147,12 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ io.MatrixRegIO.FromDataController.ReadResponseData(i).valid := decode_pre_request.IsReadFromDataController && read_request_response_valid(i) io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(i).valid := decode_pre_request.IsReadFromMemoryLoader && read_request_response_valid(i) + when(io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(i).valid) { + if (YJPCMLDebugEnable) { + printf("[CMatrixReg_CMLReadResp(%d)] bank=%d addr=%x data=%x\n", scp_id.U, i.U, debug_s1_bank_addr, s1_bank_read_data) + } + } + //单口读路径:奇偶分流 // 偶地址SRAM读请求 evenBank.io.r.req.valid := s0_bank_read_valid && s0_read_is_even @@ -132,20 +163,22 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ oddBank.io.r.req.bits.setIdx := s0_read_idx // 读数据在下一拍返回(latency=1) - val even_read_data = evenBank.io.r.resp.data(0) - val odd_read_data = oddBank.io.r.resp.data(0) + val even_read_data = evenBank.io.r.resp.data.asUInt + val odd_read_data = oddBank.io.r.resp.data.asUInt s1_bank_read_data := Mux(s1_read_is_even, even_read_data, odd_read_data) //单口写路径:奇偶分流 // 偶地址SRAM写请求 evenBank.io.w.req.valid := s0_bank_write_valid && s0_write_is_even evenBank.io.w.req.bits.setIdx := s0_write_idx - evenBank.io.w.req.bits.data(0) := s0_bank_write_data + evenBank.io.w.req.bits.waymask.get := write_request_per_bank_mask(i) + evenBank.io.w.req.bits.data := s0_bank_write_data.asTypeOf(Vec(CMatrixRegEntryByteSize, UInt(8.W))) // 奇地址SRAM写请求 oddBank.io.w.req.valid := s0_bank_write_valid && !s0_write_is_even oddBank.io.w.req.bits.setIdx := s0_write_idx - oddBank.io.w.req.bits.data(0) := s0_bank_write_data + oddBank.io.w.req.bits.waymask.get := write_request_per_bank_mask(i) + oddBank.io.w.req.bits.data := s0_bank_write_data.asTypeOf(Vec(CMatrixRegEntryByteSize, UInt(8.W))) //单口SRAM读写不能同拍同时有效,分别对奇偶SRAM进行断言 val even_conflict = s0_bank_read_valid && s0_read_is_even && s0_bank_write_valid && s0_write_is_even diff --git a/src/main/scala/CMemoryLoader.scala b/src/main/scala/CMemoryLoader.scala index a75f353..7c30faa 100644 --- a/src/main/scala/CMemoryLoader.scala +++ b/src/main/scala/CMemoryLoader.scala @@ -19,6 +19,7 @@ import freechips.rocketchip.util.SeqToAugmentedSeq class CSourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId =UInt(log2Ceil(CMatrixRegNBanks).W) val MatrixRegAddr = UInt(log2Ceil(CMatrixRegBankNEntrys).W) + val MatrixRegisTail = Bool() } class CMemoryLoader(implicit p: Parameters) extends CuteModule{ @@ -40,6 +41,8 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.WriteRequestToMatrixReg.Data.map(_.valid := false.B) io.ToMatrixRegIO.WriteRequestToMatrixReg.Data.map(_.bits := DontCare) + io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask.map(_.valid := false.B) + io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask.map(_.bits := Fill(CMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid io.LocalMMUIO.Response.ready := false.B @@ -65,6 +68,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until CMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).bits + difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).bits difftestAmuFinish.data(i * 4 + 0) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(63,0) difftestAmuFinish.data(i * 4 + 1) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(127,64) difftestAmuFinish.data(i * 4 + 2) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(191,128) @@ -104,6 +108,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val ApplicationTensor_C_Stride_M = RegInit(0.U(MMUAddrWidth.W)) val ApplicationTensor_D_Stride_M = RegInit(0.U(MMUAddrWidth.W)) + val HasTail = RegInit(false.B) + val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) + val N_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val Is_ZeroLoad = RegInit(false.B) val Is_FullLoad = RegInit(false.B) @@ -119,6 +126,21 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ when(io.ConfigInfo.MicroTaskReady && io.ConfigInfo.MicroTaskValid){ state := s_mm_task CurrentMatrixRegId := io.ConfigInfo.MatrixRegId + if (YJPCMLDebugEnable) { + printf("[CMemoryLoader_TaskHandshake<%d>] valid=%d ready=%d matrixRegId=%d isLoad=%d isStore=%d coher=%d transpose=%d M=%d N=%d baseC=%x baseD=%x\n", + io.DebugInfo.DebugTimeStampe, + io.ConfigInfo.MicroTaskValid, + io.ConfigInfo.MicroTaskReady, + io.ConfigInfo.MatrixRegId, + io.ConfigInfo.IsLoadMicroTask, + io.ConfigInfo.IsStoreMicroTask, + io.ConfigInfo.Conherent, + io.ConfigInfo.Is_Transpose, + io.ConfigInfo.MatrixRegTensor_M, + io.ConfigInfo.MatrixRegTensor_N, + io.ConfigInfo.ApplicationTensor_C.BlockTensor_C_BaseVaddr, + io.ConfigInfo.ApplicationTensor_D.BlockTensor_D_BaseVaddr) + } assert( !(io.ConfigInfo.IsLoadMicroTask === true.B && io.ConfigInfo.IsStoreMicroTask === true.B), "CMemoryLoader: Load and Store MicroTask cannot be enabled at the same time" @@ -128,6 +150,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ Tensor_Block_BaseAddr := io.ConfigInfo.ApplicationTensor_C.BlockTensor_C_BaseVaddr ApplicationTensor_C_Stride_M := io.ConfigInfo.ApplicationTensor_C.ApplicationTensor_C_Stride_M IsConherent := io.ConfigInfo.Conherent + HasTail := io.ConfigInfo.ApplicationTensor_C.HasTail + TailByteMask := io.ConfigInfo.ApplicationTensor_C.TailByteMask + N_Beat_Count := io.ConfigInfo.ApplicationTensor_C.N_Beat_Count Is_ZeroLoad := io.ConfigInfo.LoadTaskInfo.Is_ZeroLoad Is_FullLoad := io.ConfigInfo.LoadTaskInfo.Is_FullLoad @@ -202,6 +227,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val MReg_Fill_Table = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(CMatrixRegBankNEntrys).W)))))//记录这个LLC回的数是在scp的哪个地址 val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/CMatrixRegEntryByteSize)+1).W)))))//记录这个LLC回的数需要回填的次数,完成就可以将数据释放了 + val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U)//记录这个FIFO能否能填数据 val MReg_Fill_Table_Valid = MReg_Fill_Table_Time.map(_ =/= 0.U)//记录这个FIFO里的数据是否有效 val MReg_Fill_Table_Insert_Index = PriorityEncoder(MReg_Fill_Table_Free)//返回第一个空位的index @@ -242,13 +268,14 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ TotalRequestSize := 0.U CurrentLoaded_BlockTensor_M_Iter := 0.U CurrentLoaded_BlockTensor_N_Iter := 0.U - MaxRequestIter := MatrixRegTensor_M * MatrixRegTensor_N * ResultWidthByte.U / (outsideDataWidthByte.U) //总共要发出的访存请求的次数 + MaxRequestIter := MatrixRegTensor_M * N_Beat_Count //总共要发出的访存请求的次数 Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) MReg_Fill_Table := 0.U.asTypeOf(MReg_Fill_Table) MReg_Fill_Table_MReg_Addr := 0.U.asTypeOf(MReg_Fill_Table_MReg_Addr) MReg_Fill_Table_Time := 0.U.asTypeOf(MReg_Fill_Table_Time) + MReg_Fill_Table_IsTail := VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(false.B)) Request_M_Iter_Time := 0.U MReg_Fill_Table_Head := 0.U MReg_Fill_Table_Tail := 0.U @@ -299,10 +326,12 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ when(Is_FullLoad) { + val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_N_Iter === (N_Beat_Count - 1.U)) val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) % CMatrixRegNBanks.U //访存请求落在哪个MatrixRegBank上 - val RequestMatrixRegAddr = (CurrentLoaded_BlockTensor_M_Iter / CMatrixRegNBanks.U * MatrixRegTensor_N / Matrix_MN.U) + (CurrentLoaded_BlockTensor_N_Iter / Matrix_MN.U) //该访存请求的第零号数据,落在哪个MatrixRegBank的哪个地址上 + val RequestMatrixRegAddr = ((CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time ) / CMatrixRegNBanks.U ) * (Tensor_MN.U / Matrix_MN.U) + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(MAX_Fill_Times)) //该访存请求的第零号数据,落在哪个MatrixRegBank的哪个地址上 - ReadRequest.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_C_Stride_M + CurrentLoaded_BlockTensor_N_Iter * C_DataType + ReadRequest.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_C_Stride_M + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(outsideDataWidthByte)) // val CurrentBankID = RequestMatrixRegBankId // val CurrentFIFOIndex = FromMemoryLoaderReadFIFOHead @@ -320,13 +349,14 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val TableItem = Wire(new CSourceIdSearch) TableItem.MatrixRegBankId := RequestMatrixRegBankId TableItem.MatrixRegAddr := RequestMatrixRegAddr + TableItem.MatrixRegisTail := RequestBeatIsTail SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt Request_M_Iter_Time := Request_M_Iter_Time + 1.U//连续的跨bank去访存 when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ Request_M_Iter_Time := 0.U - CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + outsideDataWidthByte.U / C_DataType - when(CurrentLoaded_BlockTensor_N_Iter + outsideDataWidthByte.U / C_DataType === MatrixRegTensor_N){ + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + 1.U + when(CurrentLoaded_BlockTensor_N_Iter + 1.U === N_Beat_Count){ CurrentLoaded_BlockTensor_N_Iter := 0.U CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U } @@ -340,9 +370,8 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ //TODO:这里数据读取量定死了,需要为了支持边界情况,改一改 //不过我们保证了数据是256bit对齐的~剩下的就是Tensor_M和Tensor_K不满足的情况思考好就行了 //输出request的次数 - if (YJPCMLDebugEnable) - { - printf("[CMemoryLoader_Load<%d>]RequestMatrixRegAddr: %x,RequestMatrixRegBankId: %x,CurrentLoaded_BlockTensor_N_Iter: %x,CurrentLoaded_BlockTensor_M_Iter: %x,Request_M_Iter_Time: %x,RequestVirtualAddr: %x, RequestSourceID: %x, RequestConherent: %x, RequestType_isWrite: %x, RequestTimes: %d\n", io.DebugInfo.DebugTimeStampe, RequestMatrixRegAddr,RequestMatrixRegBankId,CurrentLoaded_BlockTensor_N_Iter,CurrentLoaded_BlockTensor_M_Iter,Request_M_Iter_Time,ReadRequest.bits.RequestVirtualAddr, ReadRequest.bits.RequestSourceID, ReadRequest.bits.RequestConherent, ReadRequest.bits.RequestType_isWrite, TotalRequestSize) + if (YJPCMLDebugEnable) { + printf("[CMemoryLoader_LoadReq<%d>] RequestMatrixRegAddr: %x,RequestMatrixRegBankId: %x,CurrentLoaded_BlockTensor_N_Iter: %x,CurrentLoaded_BlockTensor_M_Iter: %x,Request_M_Iter_Time: %x,RequestVirtualAddr: %x, RequestSourceID: %x, RequestConherent: %x, RequestType_isWrite: %x, RequestTimes: %d\n", io.DebugInfo.DebugTimeStampe, RequestMatrixRegAddr,RequestMatrixRegBankId,CurrentLoaded_BlockTensor_N_Iter,CurrentLoaded_BlockTensor_M_Iter,Request_M_Iter_Time,ReadRequest.bits.RequestVirtualAddr, ReadRequest.bits.RequestSourceID, ReadRequest.bits.RequestConherent, ReadRequest.bits.RequestType_isWrite, TotalRequestSize) } when(TotalRequestSize === MaxRequestIter){ //assert! @@ -374,14 +403,14 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := SoureceIdSearchTable(sourceId).asTypeOf(new CSourceIdSearch).MatrixRegisTail Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), CMemoryLoaderReadFromMemoryFIFODepth) //输出回填的数据 - if (YJPCMLDebugEnable) - { - printf("[CMemoryLoader_Load<%d>]ResponseData: %x, MatrixRegBankId: %x, MatrixRegAddr: %x, FIFOIndex: %x\n",io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, FIFOIndex) + if (YJPCMLDebugEnable) { + printf("[CMemoryLoader_LoadResp<%d>]ResponseData: %x, MatrixRegBankId: %x, MatrixRegAddr: %x, FIFOIndex: %x, sourceId: %x\n",io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, FIFOIndex, sourceId) } } @@ -395,12 +424,26 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ { Current_Fill_MReg_Time(i) := 1.U val MatrixRegWriteRequest = io.ToMatrixRegIO.WriteRequestToMatrixReg - val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*CMatrixRegEntryByteSize).W))))) + val FIFOData = WireInit(VecInit(Seq.fill(MAX_Fill_Times)(0.U((8 * CMatrixRegEntryByteSize).W)))) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) + val fullByteMask = Fill(CMatrixRegEntryByteSize, true.B) + val tailByteMaskVec = Wire(Vec(MAX_Fill_Times, UInt(CMatrixRegEntryByteSize.W))) + for (j <- 0 until MAX_Fill_Times) { + val high = (j + 1) * CMatrixRegEntryByteSize - 1 + val low = j * CMatrixRegEntryByteSize + tailByteMaskVec(j) := tailTaskMask(high, low) + } FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + (MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot MatrixRegWriteRequest.BankAddr(i).valid := true.B - MatrixRegWriteRequest.Data(i).bits := FIFOData(MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + MatrixRegWriteRequest.Data(i).bits := FIFOData(fillSlot) MatrixRegWriteRequest.Data(i).valid := true.B + MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail, tailByteMaskVec(fillSlot), fullByteMask) + MatrixRegWriteRequest.ByteMask(i).valid := true.B + if (YJPCMLDebugEnable) { + printf("[CMemoryLoader_MRegWriteHandshake<%d>] bankid: %d, CurrentFIFOIndex: %d, ScartchPadAddr: %x, BankAddr: %x, Data: %x, ByteMask: %x\n", io.DebugInfo.DebugTimeStampe, i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits, MatrixRegWriteRequest.ByteMask(i).bits) + } MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ @@ -457,6 +500,8 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).valid := true.B io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits := 0.U io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).valid := true.B + io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).bits := Fill(CMatrixRegEntryByteSize, true.B) + io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).valid := true.B } when(io.ToMatrixRegIO.ReadWriteResponse(MatrixRegTaskType.WriteFromMemoryLoaderIndex) === true.B) @@ -920,4 +965,4 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } } } -} \ No newline at end of file +} diff --git a/src/main/scala/CUTEParameters.scala b/src/main/scala/CUTEParameters.scala index ae9bda1..56545b2 100644 --- a/src/main/scala/CUTEParameters.scala +++ b/src/main/scala/CUTEParameters.scala @@ -80,7 +80,7 @@ case class MatrixIsaParams( def enable4BitDst: Boolean = false - def enable8BitDst: Boolean = false + def enable8BitDst: Boolean = true //开启8位类型,仅仅用于load测试 def enable16BitDst: Boolean = enableFp8Fp16 || enableFp8Bf16 || enableFp16Fp16 @@ -750,6 +750,10 @@ class ApplicationTensor_A_Info()(implicit p: Parameters) extends CuteBundle{ // val BlockTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W))//可能没有了 val ApplicationTensor_A_Stride_M = (UInt(MMUAddrWidth.W))//下一个M需要增加多少的地址偏移量 val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + //细粒度控制参数增加 + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) } class AMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ @@ -778,6 +782,10 @@ class ApplicationTensor_B_Info()(implicit p: Parameters) extends CuteBundle{ val BlockTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) val ApplicationTensor_B_Stride_N = (UInt(MMUAddrWidth.W))//下一个N需要增加多少的地址偏移量 val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + //细粒度控制参数增加 + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) } class BMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ @@ -804,6 +812,9 @@ class ApplicationTensor_C_Info()(implicit p: Parameters) extends CuteBundle{ val BlockTensor_C_BaseVaddr = (UInt(MMUAddrWidth.W)) val ApplicationTensor_C_Stride_M = (UInt(MMUAddrWidth.W))//下一个M需要增加多少的地址偏移量 val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val N_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) } class ApplicationTensor_D_Info()(implicit p: Parameters) extends CuteBundle{ @@ -887,6 +898,8 @@ class ABMemoryLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val BankAddr = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(log2Ceil(ABMatrixRegBankNEntrys).W)))) //bankdata是对nbanks个bank,各自bank的行数据,是一个vec,有nbanks个元素,每个元素是一个UInt,UInt的宽度是ReduceWidthByte*8 val Data = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(ABMatrixRegEntryBitSize.W)))) + // 每个bit控制对应1个byte是否写入 + val ByteMask = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(ABMatrixRegEntryByteSize.W)))) } class CDataControlMatrixRegIO(implicit p: Parameters) extends CuteBundle{ @@ -896,6 +909,7 @@ class CDataControlMatrixRegIO(implicit p: Parameters) extends CuteBundle{ //bankdata是对nbanks个bank,各自bank的行数据,是一个vec,有nbanks个元素,每个元素是一个UInt val ReadResponseData = (Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryBitSize.W)))) val WriteRequestData = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryBitSize.W))))) + val WriteRequestByteMask = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryByteSize.W))))) //chosen是选择该MatrixReg的信号,是一个bool,我们做doublebuffer,选择其一供数,选择其一加载数据 val ReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) val ReadWriteResponse = Output(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) @@ -910,6 +924,7 @@ class CMemoryLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val WriteRequestToMatrixReg = (new Bundle{ val BankAddr = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(log2Ceil(CMatrixRegBankNEntrys).W))))) val Data = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(CMatrixRegEntryBitSize.W))))) + val ByteMask = Flipped(Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryByteSize.W)))) }) val ReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) diff --git a/src/main/scala/CUTETOP.scala b/src/main/scala/CUTETOP.scala index 15dbe2d..8e58117 100644 --- a/src/main/scala/CUTETOP.scala +++ b/src/main/scala/CUTETOP.scala @@ -99,6 +99,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ //MemoryLoader的请求 ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.BankAddr := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.BankAddr) ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.Data := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.Data) + ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ByteMask := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ByteMask) } // C MatrixReg @@ -113,6 +114,8 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.BankAddr(b).bits := DontCare dest.Data(b).valid := false.B dest.Data(b).bits := DontCare + dest.ByteMask(b).valid := false.B + dest.ByteMask(b).bits := DontCare } } @@ -122,6 +125,8 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.BankAddr(b).bits := src.BankAddr(b).bits dest.Data(b).valid := src.Data(b).valid dest.Data(b).bits := src.Data(b).bits + dest.ByteMask(b).valid := src.ByteMask(b).valid + dest.ByteMask(b).bits := src.ByteMask(b).bits } } @@ -199,6 +204,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadRequestToMatrixReg.BankAddr(bank).valid := CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).valid dest.WriteRequestToMatrixReg.BankAddr(bank).valid := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).valid dest.WriteRequestToMatrixReg.Data(bank).valid := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).valid + dest.WriteRequestToMatrixReg.ByteMask(bank).valid := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).valid } dest.ReadWriteRequest := CML.io.ToMatrixRegIO.ReadWriteRequest }.otherwise { @@ -206,6 +212,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadRequestToMatrixReg.BankAddr(bank).valid := false.B dest.WriteRequestToMatrixReg.BankAddr(bank).valid := false.B dest.WriteRequestToMatrixReg.Data(bank).valid := false.B + dest.WriteRequestToMatrixReg.ByteMask(bank).valid := false.B } } @@ -213,6 +220,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadRequestToMatrixReg.BankAddr(bank).bits := CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).bits dest.WriteRequestToMatrixReg.BankAddr(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).bits dest.WriteRequestToMatrixReg.Data(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).bits + dest.WriteRequestToMatrixReg.ByteMask(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).bits } } @@ -238,6 +246,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadBankAddr(bank).valid := CDC.io.FromMatrixRegIO.ReadBankAddr(bank).valid dest.WriteBankAddr(bank).valid := CDC.io.FromMatrixRegIO.WriteBankAddr(bank).valid dest.WriteRequestData(bank).valid := CDC.io.FromMatrixRegIO.WriteRequestData(bank).valid + dest.WriteRequestByteMask(bank).valid := CDC.io.FromMatrixRegIO.WriteRequestByteMask(bank).valid } dest.ReadWriteRequest := CDC.io.FromMatrixRegIO.ReadWriteRequest }.otherwise { @@ -245,6 +254,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadBankAddr(bank).valid := false.B dest.WriteBankAddr(bank).valid := false.B dest.WriteRequestData(bank).valid := false.B + dest.WriteRequestByteMask(bank).valid := false.B } } @@ -252,6 +262,7 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadBankAddr(bank).bits := CDC.io.FromMatrixRegIO.ReadBankAddr(bank).bits dest.WriteBankAddr(bank).bits := CDC.io.FromMatrixRegIO.WriteBankAddr(bank).bits dest.WriteRequestData(bank).bits := CDC.io.FromMatrixRegIO.WriteRequestData(bank).bits + dest.WriteRequestByteMask(bank).bits := CDC.io.FromMatrixRegIO.WriteRequestByteMask(bank).bits } } diff --git a/src/main/scala/TaskController.scala b/src/main/scala/TaskController.scala index 9df40d2..334f3a0 100644 --- a/src/main/scala/TaskController.scala +++ b/src/main/scala/TaskController.scala @@ -293,6 +293,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { // Completion handshakes and scoreboard updates io.AML_MicroTask_Config.MicroTaskEndReady := pendingLoadA when(pendingLoadA && io.AML_MicroTask_Config.MicroTaskEndValid) { + if (YJPTASKDebugEnable) { + printf("[TaskController_LoadAFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingLoadAReg, pendingLoadAFifoIdx, io.AML_MicroTask_Config.MicroTaskEndValid, io.AML_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.load_finish_a := true.B scoreboard.io.update.load_finish_a_reg := pendingLoadAReg pendingLoadA := false.B @@ -309,6 +312,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.BML_MicroTask_Config.MicroTaskEndReady := pendingLoadB when(pendingLoadB && io.BML_MicroTask_Config.MicroTaskEndValid) { + if (YJPTASKDebugEnable) { + printf("[TaskController_LoadBFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingLoadBReg, pendingLoadBFifoIdx, io.BML_MicroTask_Config.MicroTaskEndValid, io.BML_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.load_finish_b := true.B scoreboard.io.update.load_finish_b_reg := pendingLoadBReg pendingLoadB := false.B @@ -326,6 +332,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.CML_MicroTask_Config.MicroTaskEndReady := pendingLoadC || pendingStore when(io.CML_MicroTask_Config.MicroTaskEndValid) { when(pendingLoadC) { + if (YJPTASKDebugEnable) { + printf("[TaskController_LoadCFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingLoadCReg, pendingLoadCFifoIdx, io.CML_MicroTask_Config.MicroTaskEndValid, io.CML_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.load_finish_c := true.B scoreboard.io.update.load_finish_c_reg := pendingLoadCReg pendingLoadC := false.B @@ -339,6 +348,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { loadCFinishEvent.isAcc := true.B loadCFinishEventEn := true.B }.elsewhen(pendingStore) { + if (YJPTASKDebugEnable) { + printf("[TaskController_StoreCFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingStoreReg, pendingStoreFifoIdx, io.CML_MicroTask_Config.MicroTaskEndValid, io.CML_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.store_finish := true.B scoreboard.io.update.store_finish_c_reg := pendingStoreReg pendingStore := false.B @@ -355,6 +367,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.ADC_MicroTask_Config.MicroTaskEndReady := pendingComputeA when(pendingComputeA && io.ADC_MicroTask_Config.MicroTaskEndValid) { + if (YJPTASKDebugEnable) { + printf("[TaskController_ComputeAFinish<%d>] aReg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingComputeAReg, pendingComputeAFifoIdx, io.ADC_MicroTask_Config.MicroTaskEndValid, io.ADC_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.compute_read_finish_a := true.B scoreboard.io.update.compute_read_finish_a_reg := pendingComputeAReg pendingComputeA := false.B @@ -373,6 +388,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.BDC_MicroTask_Config.MicroTaskEndReady := pendingComputeB when(pendingComputeB && io.BDC_MicroTask_Config.MicroTaskEndValid) { + if (YJPTASKDebugEnable) { + printf("[TaskController_ComputeBFinish<%d>] bReg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingComputeBReg, pendingComputeBFifoIdx, io.BDC_MicroTask_Config.MicroTaskEndValid, io.BDC_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.compute_read_finish_b := true.B scoreboard.io.update.compute_read_finish_b_reg := pendingComputeBReg pendingComputeB := false.B @@ -391,6 +409,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.CDC_MicroTask_Config.MicroTaskEndReady := pendingComputeC when(pendingComputeC && io.CDC_MicroTask_Config.MicroTaskEndValid) { + if (YJPTASKDebugEnable) { + printf("[TaskController_ComputeCFinish<%d>] cReg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingComputeCReg, pendingComputeCFifoIdx, io.CDC_MicroTask_Config.MicroTaskEndValid, io.CDC_MicroTask_Config.MicroTaskEndReady) + } scoreboard.io.update.compute_write_finish_c := true.B scoreboard.io.update.compute_write_finish_c_reg := pendingComputeCReg pendingComputeC := false.B @@ -572,28 +593,80 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { val issueZeroTr = issueFire && isMzeroTr val issueRelease = issueFire && isRelease + val loadDataType = MuxLookup(lsuInfo.widths, ElementDataType.DataTypeWidth32)(Seq( + Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, + Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, + Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, + Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 + )) + val loadKBytes = MuxLookup(lsuInfo.widths, lsuInfo.column)(Seq( + Bundles.MSew.e8 -> lsuInfo.column, + Bundles.MSew.e16 -> (lsuInfo.column << 1), + Bundles.MSew.e32 -> (lsuInfo.column << 2), + Bundles.MSew.e4 -> (lsuInfo.column >> 1) + )) + val loadKBytes_for_B = MuxLookup(lsuInfo.widths, lsuInfo.row)(Seq( + Bundles.MSew.e8 -> lsuInfo.row, + Bundles.MSew.e16 -> (lsuInfo.row << 1), + Bundles.MSew.e32 -> (lsuInfo.row << 2), + Bundles.MSew.e4 -> (lsuInfo.row >> 1) + )) + val loadKBeatCount_for_B = (loadKBytes_for_B + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) + val loadKBeatCount = (loadKBytes + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) + val loadHasTail = MuxLookup(loadDataType, false.B)(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0).orR, + ElementDataType.DataTypeWidth16 -> lsuInfo.column(4, 0).orR, + ElementDataType.DataTypeWidth32 -> lsuInfo.column(3, 0).orR + )) + val loadTailByteMask = MuxLookup(loadDataType, 0.U(log2Ceil(outsideDataWidthByte + 1).W))(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0), + ElementDataType.DataTypeWidth16 -> Cat(lsuInfo.column(4, 0), 0.U(1.W)), + ElementDataType.DataTypeWidth32 -> Cat(lsuInfo.column(3, 0), 0.U(2.W)) + )) + val loadHasTail_for_B = MuxLookup(loadDataType, false.B)(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.row(5, 0).orR, + ElementDataType.DataTypeWidth16 -> lsuInfo.row(4, 0).orR, + ElementDataType.DataTypeWidth32 -> lsuInfo.row(3, 0).orR + )) + val loadTailByteMask_for_B = MuxLookup(loadDataType, 0.U(log2Ceil(outsideDataWidthByte + 1).W))(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.row(5, 0), + ElementDataType.DataTypeWidth16 -> Cat(lsuInfo.row(4, 0), 0.U(1.W)), + ElementDataType.DataTypeWidth32 -> Cat(lsuInfo.row(3, 0), 0.U(2.W)) + )) + val loadNBytes = MuxLookup(loadDataType, lsuInfo.column << 2)(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.column, + ElementDataType.DataTypeWidth16 -> (lsuInfo.column << 1), + ElementDataType.DataTypeWidth32 -> (lsuInfo.column << 2) + )) + val loadNBeatCount = (loadNBytes + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) + val loadNHasTail = MuxLookup(loadDataType, false.B)(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0).orR, + ElementDataType.DataTypeWidth16 -> lsuInfo.column(4, 0).orR, + ElementDataType.DataTypeWidth32 -> lsuInfo.column(3, 0).orR + )) + val loadNTailByteMask = MuxLookup(loadDataType, 0.U(log2Ceil(outsideDataWidthByte + 1).W))(Seq( + ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0), + ElementDataType.DataTypeWidth16 -> Cat(lsuInfo.column(4, 0), 0.U(1.W)), + ElementDataType.DataTypeWidth32 -> Cat(lsuInfo.column(3, 0), 0.U(2.W)) + )) + when(issueLoad) { val regIdx = lsuInfo.ms(1, 0) val loadIdx = loadAllocIdx + assert(lsuInfo.stride(5, 0) === 0.U, "TaskController load stride must be 64B aligned") + when(needA) { io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr := lsuInfo.baseAddr io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_Stride_M := lsuInfo.stride - io.AML_MicroTask_Config.ApplicationTensor_A.dataType := MuxLookup(lsuInfo.widths, ElementDataType.DataTypeWidth32)(Seq( - Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, - Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, - Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, - Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 - )) + io.AML_MicroTask_Config.ApplicationTensor_A.dataType := loadDataType + io.AML_MicroTask_Config.ApplicationTensor_A.HasTail := loadHasTail + io.AML_MicroTask_Config.ApplicationTensor_A.TailByteMask := loadTailByteMask + io.AML_MicroTask_Config.ApplicationTensor_A.K_Beat_Count := loadKBeatCount io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B io.AML_MicroTask_Config.MatrixRegTensor_M := lsuInfo.row - io.AML_MicroTask_Config.MatrixRegTensor_K := MuxLookup(lsuInfo.widths, lsuInfo.column)(Seq( - Bundles.MSew.e8 -> lsuInfo.column, - Bundles.MSew.e16 -> lsuInfo.column * 2.U, - Bundles.MSew.e32 -> lsuInfo.column * 4.U, - Bundles.MSew.e4 -> lsuInfo.column / 2.U - )) / ReduceWidthByte.U + io.AML_MicroTask_Config.MatrixRegTensor_K := loadKBeatCount io.AML_MicroTask_Config.MatrixRegId := regIdx if (EnableDifftest) { io.AML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get @@ -603,6 +676,11 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.AML_MicroTask_Config.Conherent := true.B io.AML_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueAML<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d kBeat=%d tail=%d tailMask=%d base=%x\n", + io.DebugTimeStampe, regIdx, loadIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.AML_MicroTask_Config.Conherent, + loadKBeatCount, loadHasTail, loadTailByteMask, lsuInfo.baseAddr) + } pendingLoadA := true.B pendingLoadAReg := regIdx @@ -612,19 +690,12 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr := lsuInfo.baseAddr io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_Stride_N := lsuInfo.stride io.BML_MicroTask_Config.ApplicationTensor_B.BlockTensor_B_BaseVaddr := lsuInfo.baseAddr - io.BML_MicroTask_Config.ApplicationTensor_B.dataType := MuxLookup(lsuInfo.widths, ElementDataType.DataTypeWidth32)(Seq( - Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, - Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, - Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, - Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 - )) + io.BML_MicroTask_Config.ApplicationTensor_B.dataType := loadDataType + io.BML_MicroTask_Config.ApplicationTensor_B.HasTail := loadHasTail_for_B + io.BML_MicroTask_Config.ApplicationTensor_B.TailByteMask := loadTailByteMask_for_B + io.BML_MicroTask_Config.ApplicationTensor_B.K_Beat_Count := loadKBeatCount_for_B io.BML_MicroTask_Config.MatrixRegTensor_N := lsuInfo.column - io.BML_MicroTask_Config.MatrixRegTensor_K := MuxLookup(lsuInfo.widths, lsuInfo.row)(Seq( - Bundles.MSew.e8 -> lsuInfo.row, - Bundles.MSew.e16 -> lsuInfo.row * 2.U, - Bundles.MSew.e32 -> lsuInfo.row * 4.U, - Bundles.MSew.e4 -> lsuInfo.row / 2.U - )) / ReduceWidthByte.U + io.BML_MicroTask_Config.MatrixRegTensor_K := loadKBeatCount_for_B io.BML_MicroTask_Config.MatrixRegId := regIdx if (EnableDifftest) { io.BML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get @@ -632,6 +703,11 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { } io.BML_MicroTask_Config.Conherent := true.B io.BML_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueBML<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d kBeat=%d tail=%d tailMask=%d base=%x\n", + io.DebugTimeStampe, regIdx, loadIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.BML_MicroTask_Config.Conherent, + loadKBeatCount_for_B, loadHasTail_for_B, loadTailByteMask_for_B, lsuInfo.baseAddr) + } pendingLoadB := true.B pendingLoadBReg := regIdx pendingLoadBFifoIdx := loadIdx @@ -640,12 +716,10 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_BaseVaddr := lsuInfo.baseAddr io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_Stride_M := lsuInfo.stride io.CML_MicroTask_Config.ApplicationTensor_C.BlockTensor_C_BaseVaddr := lsuInfo.baseAddr - io.CML_MicroTask_Config.ApplicationTensor_C.dataType := MuxLookup(lsuInfo.widths, ElementDataType.DataTypeWidth32)(Seq( - Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, - Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, - Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, - Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 - )) + io.CML_MicroTask_Config.ApplicationTensor_C.dataType := loadDataType + io.CML_MicroTask_Config.ApplicationTensor_C.HasTail := loadNHasTail + io.CML_MicroTask_Config.ApplicationTensor_C.TailByteMask := loadNTailByteMask + io.CML_MicroTask_Config.ApplicationTensor_C.N_Beat_Count := loadNBeatCount io.CML_MicroTask_Config.Conherent := true.B io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B @@ -660,6 +734,11 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.CML_MicroTask_Config.IsLoadMicroTask := true.B io.CML_MicroTask_Config.IsStoreMicroTask := false.B io.CML_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueCMLLoad<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d transpose=%d nBeat=%d tail=%d tailMask=%d base=%x\n", + io.DebugTimeStampe, regIdx, loadIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.CML_MicroTask_Config.Conherent, + lsuInfo.transpose, loadNBeatCount, loadNHasTail, loadNTailByteMask, lsuInfo.baseAddr) + } io.CML_MicroTask_Config.Is_Transpose := lsuInfo.transpose pendingLoadC := true.B @@ -714,6 +793,10 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.CML_MicroTask_Config.IsLoadMicroTask := true.B io.CML_MicroTask_Config.IsStoreMicroTask := false.B io.CML_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueZeroAcc<%d>] reg=%d fifo=%d M=%d N=%d\n", + io.DebugTimeStampe, regIdx, loadIdx, io.CML_MicroTask_Config.MatrixRegTensor_M, io.CML_MicroTask_Config.MatrixRegTensor_N) + } io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := true.B io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B io.CML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B @@ -771,6 +854,10 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.AML_MicroTask_Config.MatrixRegTensor_K := cuteParams.Tensor_K.U / ReduceWidthByte.U io.AML_MicroTask_Config.MatrixRegId := regIdx io.AML_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueZeroTr<%d>] reg=%d fifo=%d M=%d K=%d\n", + io.DebugTimeStampe, regIdx, loadIdx, io.AML_MicroTask_Config.MatrixRegTensor_M, io.AML_MicroTask_Config.MatrixRegTensor_K) + } io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := true.B io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B @@ -827,10 +914,6 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { val cReg = Mux(isMma, mmaInfo.md(1, 0), arithInfo.md(1, 0)) val computeIdx = computeIssueIdx - io.ADC_MicroTask_Config.MicroTaskValid := true.B - io.BDC_MicroTask_Config.MicroTaskValid := true.B - io.CDC_MicroTask_Config.MicroTaskValid := true.B - val mVal = mmaInfo.mtilem val nVal = mmaInfo.mtilen // val kVal = mmaInfo.mtilek @@ -841,6 +924,14 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { Bundles.MSew.e4 -> mmaInfo.mtilek / 2.U, )) + io.ADC_MicroTask_Config.MicroTaskValid := true.B + io.BDC_MicroTask_Config.MicroTaskValid := true.B + io.CDC_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueMMA<%d>] aReg=%d bReg=%d cReg=%d fifo=%d m=%d n=%d k=%d isMma=%d isFp=%d\n", + io.DebugTimeStampe, aReg, bReg, cReg, computeIdx, mVal, nVal, kVal, isMma, Mux(isMma, mmaInfo.isfp, false.B)) + } + io.ADC_MicroTask_Config.ApplicationTensor_A.dataType := ElementDataType.DataTypeWidth8 io.ADC_MicroTask_Config.MatrixRegTensor_M := mVal io.ADC_MicroTask_Config.MatrixRegTensor_N := nVal @@ -929,6 +1020,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { when(issueStore) { val regIdx = lsuInfo.ms(1, 0) val storeIdx = storeIssueIdx + assert(lsuInfo.stride(5, 0) === 0.U, "TaskController store stride must be 64B aligned") io.CML_MicroTask_Config.ApplicationTensor_D.ApplicationTensor_D_BaseVaddr := lsuInfo.baseAddr io.CML_MicroTask_Config.ApplicationTensor_D.ApplicationTensor_D_Stride_M := lsuInfo.stride @@ -951,6 +1043,11 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.CML_MicroTask_Config.IsStoreMicroTask := true.B io.CML_MicroTask_Config.MicroTaskValid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueCMLStore<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d transpose=%d base=%x dataType=%d\n", + io.DebugTimeStampe, regIdx, storeIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.CML_MicroTask_Config.Conherent, + lsuInfo.transpose, lsuInfo.baseAddr, io.CML_MicroTask_Config.ApplicationTensor_D.dataType) + } if (EnableDifftest) { io.CML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get io.CML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get @@ -981,6 +1078,9 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { when(issueRelease) { io.ygjkctrl.mrelease.valid := true.B + if (YJPTASKDebugEnable) { + printf("[TaskController_IssueRelease<%d>] token=%d\n", io.DebugTimeStampe, releaseInfo.tokenRd) + } io.ygjkctrl.mrelease.bits.tokenRd(releaseInfo.tokenRd) := true.B releaseIssueEvent.eventType := 0.U releaseIssueEvent.token := releaseInfo.tokenRd @@ -995,6 +1095,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { difftestAmuFinish.pc := headEntry.ctrl.pc.get difftestAmuFinish.bankValid.foreach(_ := false.B) difftestAmuFinish.bankAddr.foreach(_ := 0.U) + difftestAmuFinish.bankMask.foreach(_ := 0.U) difftestAmuFinish.data.foreach(_ := 0.U) difftestAmuFinish.finish := io.ygjkctrl.mrelease.valid } @@ -1016,4 +1117,3 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { releaseEventTable.log(releaseIssueEvent, releaseIssueEventEn, "ReleaseIssue", clock, reset) } - From 32709ddd2934f0c72d3757b7fcf536af9f1cde9a Mon Sep 17 00:00:00 2001 From: hsj <2383755847@qq.com> Date: Fri, 5 Jun 2026 00:02:26 +0800 Subject: [PATCH 2/2] feat(load): implement fine-grained access and int8 transpose load --- src/main/scala/ABMatrixReg.scala | 50 +- src/main/scala/ADataController.scala | 2 +- src/main/scala/AMemoryLoader.scala | 618 +++++++- src/main/scala/BDataController.scala | 2 +- src/main/scala/BMemoryLoader.scala | 429 ++++-- src/main/scala/Bundles.scala | 74 +- src/main/scala/CDataController.scala | 25 +- src/main/scala/CMatrixReg.scala | 43 +- src/main/scala/CMemoryLoader.scala | 364 ++--- src/main/scala/CUTE2YGJK.scala | 3 +- src/main/scala/CUTEParameters.scala | 1046 ++++++++----- src/main/scala/CUTETOP.scala | 215 ++- src/main/scala/LocalMMU.scala | 209 ++- src/main/scala/TaskController.scala | 2035 +++++++++++++++----------- 14 files changed, 3394 insertions(+), 1721 deletions(-) diff --git a/src/main/scala/ABMatrixReg.scala b/src/main/scala/ABMatrixReg.scala index 2dba0b6..f8ffd74 100644 --- a/src/main/scala/ABMatrixReg.scala +++ b/src/main/scala/ABMatrixReg.scala @@ -14,7 +14,7 @@ import utility.sram.SRAMTemplate class ABMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val FromDataController = new ABDataControlMatrixRegIO - val FromMemoryLoader = new ABMemoryLoaderMatrixRegIO + val FromMemoryLoader = new ABMemoryLoaderMatrixRegIO } class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ @@ -22,7 +22,7 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ val MatrixRegIO = new ABMatrixRegIO }) - + // 读写优先级逻辑:写优先于读 // 目前在Loader、DataController里面都加了FIFO,能保证一些堵的情况的发生 @@ -33,12 +33,12 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // MemoryLoader端的信号 val MemoryLoaderBankAddr = io.MatrixRegIO.FromMemoryLoader.BankAddr val MemoryLoaderData = io.MatrixRegIO.FromMemoryLoader.Data - // 【修改 :提取掩码信号】 val MemoryLoaderByteMask = io.MatrixRegIO.FromMemoryLoader.ByteMask + // 写优先的MatrixReg控制逻辑 // write_go: 只要有写入请求(正常写或零填充)就为true val write_go = MemoryLoaderBankAddr.zip(MemoryLoaderData).map{case (a, b) => a.valid && b.valid}.reduce(_||_) - + // read_go: 只有在没有写请求时才允许读,实现写优先 val read_go = io.MatrixRegIO.FromDataController.BankAddr.valid && !MemoryLoaderBankAddr.map(_.valid).reduce(_||_) && @@ -50,23 +50,15 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // SRAM下一拍返回结果,所以使用上一拍的ready作为valid io.MatrixRegIO.FromDataController.Data.valid := RegNext(read_go) val debug_s1_bank_addr = RegNext(DataControllerBankAddr) - - when(io.MatrixRegIO.FromDataController.BankAddr.fire) { - if (YJPAMLDebugEnable || YJPBMLDebugEnable) { - printf("[ABMatrixReg_ReadReq(%d)] addr0=%d\n", scp_id.U, io.MatrixRegIO.FromDataController.BankAddr.bits(0)) - } - } // 实例化多个SRAM作为多个bank val sram_banks = (0 until ABMatrixRegNBanks) map { i => - // 【修改 :重构 SRAMTemplate 物理映射语义】 - // 将原本宽字长、单 way 的 SRAM,转变为 1 Byte 为颗粒度、多 way 的 SRAM。 - // 综合工具(DC/Genus)会将其自动识别为:带有 Byte Write Enable (BWEB) 引脚的单个宏单元,或由标准单元组成的寄存器堆,绝不会产生碎片化拥塞。 + // Use byte-wide ways so SRAMTemplate waymask becomes byte write enable. val bank = Module(new SRAMTemplate( - gen = UInt(8.W), // 核心修改:基础数据单元改为 1 Byte - set = ABMatrixRegBankNEntrys, // 深度保持不变 - way = ABMatrixRegEntryByteSize, // 核心修改:相联度(way)数量等于掩码(Byte)数量 + gen = UInt(8.W), + set = ABMatrixRegBankNEntries, + way = ABMatrixRegEntryByteSize, singlePort = true, latency = 1, hasMbist = false, @@ -82,8 +74,10 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ when(RegNext(read_go)) { - if (YJPAMLDebugEnable) { - printf("[ABMatrixReg_ReadResp(%d)]Bank(%d): debug_s1_bank_addr = %d, s1_bank_read_data = %x\n", + // 输出读的信息 + if (YJPDebugEnable) + { + printf("[ABMatrixReg_Read(%d)]Bank(%d): debug_s1_bank_addr = %d, s1_bank_read_data = %x\n", scp_id.U, i.U, debug_s1_bank_addr(0), s1_bank_read_data) } } @@ -94,10 +88,7 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // 写数据逻辑 val s0_bank_write_addr = MemoryLoaderBankAddr(i).bits val s0_bank_write_data = MemoryLoaderData(i).bits - - // 【修改 :提取单 Bank 掩码】 val s0_bank_write_mask = MemoryLoaderByteMask(i).bits - // 写握手必须同时满足 addr, data, mask 皆有效 val s0_bank_write_valid = MemoryLoaderBankAddr(i).valid && MemoryLoaderData(i).valid && MemoryLoaderByteMask(i).valid // 最终的写入控制 @@ -105,11 +96,11 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ val s0_final_write_addr = MemoryLoaderBankAddr(i).bits val s0_final_write_data = MemoryLoaderData(i).bits - - when(s0_final_write_valid) { - if (YJPAMLDebugEnable) { - printf("[ABMatrixReg_Write(%d)]Bank(%d): s0_bank_write_addr = %d, s0_bank_write_data = %x, mask = %b\n", - scp_id.U, i.U, s0_bank_write_addr, s0_final_write_data, s0_bank_write_mask) + when(write_go && s0_bank_write_valid){ + if (YJPDebugEnable) + { + printf("[ABMatrixReg_Write(%d)]Bank(%d): s0_bank_write_addr = %d, s0_bank_write_data = %x, mask = %x\n", + scp_id.U, i.U, s0_bank_write_addr, s0_bank_write_data, s0_bank_write_mask) } } @@ -119,21 +110,14 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ bank.io.r.req.valid := bank_read_valid bank.io.r.req.bits.setIdx := s0_bank_read_addr // 读响应在下一拍返回(latency=1) - // 【修改 :零开销数据拼装】 - // bank.io.r.resp.data 此时是一个 Vec(way, UInt(8.W)) - // 使用 .asUInt 将 Vec 无缝强制转换为大位宽 UInt,不仅代码整洁,而且在综合时完全是一根线(Wire),没有任何面积和时序延迟开销。 s1_bank_read_data := bank.io.r.resp.data.asUInt // 连接SRAMTemplate的写接口 bank.io.w.req.valid := s0_final_write_valid bank.io.w.req.bits.setIdx := s0_final_write_addr bank.io.w.req.bits.waymask.get := s0_bank_write_mask - - // 使用 .asTypeOf(Vec) 将大宽度的 s0_final_write_data (如 256.W) 直接解包为等宽的 Vec(32, UInt(8.W))。 - // 这取代了臃肿的 for 循环位截取,对后端极度友好,综合后就是干净的连线(Assign)。 bank.io.w.req.bits.data := s0_final_write_data.asTypeOf(Vec(ABMatrixRegEntryByteSize, UInt(8.W))) bank } } - diff --git a/src/main/scala/ADataController.scala b/src/main/scala/ADataController.scala index b4862b8..cf85f02 100644 --- a/src/main/scala/ADataController.scala +++ b/src/main/scala/ADataController.scala @@ -151,7 +151,7 @@ class ADataController(implicit p: Parameters) extends CuteModule{ printf("[ADataController<%d>]ADataController: M_IteratorMax is %d, N_IteratorMax is %d, K_IteratorMax is %d\n",io.DebugInfo.DebugTimeStampe, M_IteratorMax, N_IteratorMax, K_IteratorMax) } //MTE循环的最外层是M,然后是N,最后是K,所以这里在同步信号的ComputeGo的协同下,执行Max_Caculate_Iter次取数 - val next_addr = Wire(UInt(ABMatrixRegBankNEntrys.W)) + val next_addr = Wire(UInt(ABMatrixRegBankNEntries.W)) next_addr := M_Iterator * K_IteratorMax + K_Iterator MatrixRegRequestBankAddr.bits.foreach(_ := next_addr) diff --git a/src/main/scala/AMemoryLoader.scala b/src/main/scala/AMemoryLoader.scala index e0d1f7c..dcd4539 100644 --- a/src/main/scala/AMemoryLoader.scala +++ b/src/main/scala/AMemoryLoader.scala @@ -11,8 +11,255 @@ import org.chipsalliance.cde.config._ class ASourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId = UInt(log2Ceil(ABMatrixRegNBanks).W) - val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntrys).W) + val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntries).W) val MatrixRegisTail = Bool() + val BeatIndex = UInt(log2Ceil(ABMatrixRegEntryByteSize).W) +} + +class TransDataPacket(implicit p: Parameters) extends CuteBundle { + val data = UInt(8.W) + val mask = Bool() + val entry_offset = UInt(log2Ceil(ABMatrixRegEntryByteSize).W) +} + +class TransAlignPipe(bankId: Int)(implicit p: Parameters) extends CuteModule { + private val transLoadSize = Trans_Load_Size + private val transLoadSizeBits = log2Ceil(transLoadSize) + private val entryOffsetBits = log2Ceil(ABMatrixRegEntryByteSize) + private val routerLatency = 3 + private val drainCycles = transLoadSize + 2 + routerLatency // 两级 barrel shifter 和 OOORouter 三级流水都需要排空。 + + val io = IO(new Bundle { + val in_data = Input(UInt(outsideDataWidth.W)) + val in_mask = Input(UInt(outsideDataWidthByte.W)) + val resp_beat_cnt = Input(UInt(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val entry_offset = Input(UInt(entryOffsetBits.W)) + val debug_time = Input(UInt(64.W)) + val is_drain_trigger = Input(Bool()) + val in_valid = Input(Bool()) + val bus_stall = Output(Bool()) + val empty = Output(Bool()) + val out = Decoupled(Vec(transLoadSize, new TransDataPacket)) + }) + + require(isPow2(transLoadSize), "Trans_Load_Size must be power of 2") + require(transLoadSize == 8 || transLoadSize == 16, "Trans_Load_Size currently supports 8 or 16") + + val s_normal :: s_drain :: Nil = Enum(2) + val state = RegInit(s_normal) + val drain_cnt = RegInit(0.U(log2Ceil(drainCycles + 1).W)) + + val out_valid_now = Wire(Bool()) + + val inBytes = io.in_data.asTypeOf(Vec(outsideDataWidthByte, UInt(8.W))) + val inMaskBits = io.in_mask.asBools + val bankData = VecInit((0 until transLoadSize).map { i => + val byteIdx = bankId + i * ABMatrixRegNBanks + Mux(inMaskBits(byteIdx), inBytes(byteIdx), 0.U(8.W)) + }) + // Tail/full mask is consumed here only to zero invalid payload bytes. The + // downstream write mask is regenerated from BeatIndex. + + def rotateLeft[T <: Data](data: Vec[T], amount: UInt): Vec[T] = { + VecInit((0 until transLoadSize).map { i => + val idx = (i.U(transLoadSizeBits.W) + amount.pad(transLoadSizeBits))(transLoadSizeBits - 1, 0) + data(idx) + }) + } + + val shift_amt = io.resp_beat_cnt(transLoadSizeBits - 1, 0) + val shift_low = shift_amt(1, 0) + val shift_high_amt = if (transLoadSizeBits > 2) Cat(shift_amt(transLoadSizeBits - 1, 2), 0.U(2.W)) else 0.U(transLoadSizeBits.W) + + val coarseData = rotateLeft(bankData, shift_high_amt) + + val stage1Data = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(8.W)))) + val stage1Valid = RegInit(false.B) + val stage1ShiftLow = RegInit(0.U(2.W)) + val stage1EntryOffset = RegInit(0.U(entryOffsetBits.W)) + + val stage2Data = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(8.W)))) + val stage2Valid = RegInit(false.B) + val stage2EntryOffset = RegInit(0.U(entryOffsetBits.W)) + + val fineData = rotateLeft(stage1Data, stage1ShiftLow) + + val laneData = Seq.tabulate(transLoadSize) { i => + RegInit(VecInit(Seq.fill(i + 1)(0.U(8.W)))) + } + // laneValid only marks that a response packet occupies this triangular + // pipeline slot. Tail/full masks have already zeroed invalid payload bytes; + // SRAM write masks are regenerated later from BeatIndex. + val laneValid = Seq.tabulate(transLoadSize) { i => + RegInit(VecInit(Seq.fill(i + 1)(false.B))) + } + // offset 与最后一级 lane 的物理位置绑定,而不是一拍内所有 lane 共享。 + // 新输入的 offset 进入 offset(0),随后沿 offset(0)->offset(1)->... 移动, + // 对齐“同一笔输入数据依次出现在最后一级第 0、1、...、N-1 个位置”的时空轨迹。 + val entryOffsetPipe = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(entryOffsetBits.W)))) + + val laneValidBits = laneValid.zipWithIndex.flatMap { case (validVec, laneIdx) => + (0 to laneIdx).map(validVec(_)) + } + val any_valid_in_lane_pipe = laneValidBits.reduce(_ || _) + val any_valid_in_pipe = stage1Valid || stage2Valid || any_valid_in_lane_pipe + val input_accept = state === s_normal && io.in_valid + val pipe_advance = input_accept || (state === s_drain) + + if (bankId == 0) { + when(input_accept) { + printf("[AML_TRANS_ALIGN_IN<%d>] beatCnt:%d shift:%d shiftHigh:%d shiftLow:%d entry:%d raw:%x mask:%x bankData:%x coarse:%x drainTrig:%d\n", + io.debug_time, io.resp_beat_cnt, shift_amt, shift_high_amt, shift_low, + io.entry_offset, io.in_data, io.in_mask, bankData.asUInt, + coarseData.asUInt, io.is_drain_trigger.asUInt) + } + } + + // MatrixReg 写侧恒 ready,因此转置对齐流水线不再等待末端 valid/ready。 + // drain 期间仍反压总线响应,防止下一组数据在旧组 router 三级流水未排空前进入。 + io.bus_stall := state === s_drain + io.empty := state === s_normal && !any_valid_in_pipe + + when(pipe_advance) { + stage1Data := coarseData + stage1Valid := input_accept + stage1ShiftLow := shift_low + stage1EntryOffset := io.entry_offset + + stage2Data := fineData + stage2Valid := stage1Valid + stage2EntryOffset := stage1EntryOffset + + for (i <- 0 until transLoadSize) { + laneData(i)(0) := stage2Data(i) + laneValid(i)(0) := stage2Valid + for (j <- 1 to i) { + laneData(i)(j) := laneData(i)(j - 1) + laneValid(i)(j) := laneValid(i)(j - 1) + } + } + + entryOffsetPipe(0) := Mux(stage2Valid, stage2EntryOffset, 0.U) + for (i <- 1 until transLoadSize) { + entryOffsetPipe(i) := entryOffsetPipe(i - 1) + } + } + + val outPackets = Wire(Vec(transLoadSize, new TransDataPacket)) + val outValids = Wire(Vec(transLoadSize, Bool())) + for (i <- 0 until transLoadSize) { + val depthIdx = i + outPackets(i).data := laneData(i)(depthIdx) + outPackets(i).mask := laneValid(i)(depthIdx) + outPackets(i).entry_offset := entryOffsetPipe(i) + outValids(i) := laneValid(i)(depthIdx) + } + out_valid_now := outValids.asUInt.orR + + io.out.bits := outPackets + io.out.valid := out_valid_now && pipe_advance + + if (bankId == 0) { + val outDataUInt = VecInit((0 until transLoadSize).map(i => outPackets(i).data)).asUInt + val outEntryUInt = VecInit((0 until transLoadSize).map(i => outPackets(i).entry_offset)).asUInt + when(io.out.valid) { + printf("[AML_TRANS_ALIGN_OUT<%d>] valids:%b data:%x entry:%x\n", + io.debug_time, outValids.asUInt, outDataUInt, outEntryUInt) + } + } + + switch(state) { + is(s_normal) { + when(input_accept && io.is_drain_trigger) { + state := s_drain + drain_cnt := (drainCycles - 1).U + } + } + is(s_drain) { + when(drain_cnt === 0.U) { + state := s_normal + }.otherwise { + drain_cnt := drain_cnt - 1.U + } + } + } +} + +class OOORouter(implicit p: Parameters) extends CuteModule { + private val transLoadSize = Trans_Load_Size + private val entryOffsetBits = log2Ceil(ABMatrixRegEntryByteSize) + private val groupSize = math.max(1, transLoadSize / 4) + + val io = IO(new Bundle { + val in = Flipped(Decoupled(Vec(transLoadSize, new TransDataPacket))) + val final_data = Output(UInt(ABMatrixRegEntryBitSize.W)) + val final_mask = Output(UInt(ABMatrixRegEntryByteSize.W)) + val valid = Output(Bool()) + val empty = Output(Bool()) + }) + + require(ABMatrixRegEntryByteSize == 32, "OOORouter currently targets 32B AB MatrixReg entries") + require(transLoadSize == 8 || transLoadSize == 16, "OOORouter currently supports Trans_Load_Size 8 or 16") + + io.in.ready := true.B + + val inputFire = io.in.valid && io.in.ready + + // Byte mask 与 data 使用同一套 offset 译码和规约树。 + // 不保留历史 life mask,避免在当前拍没有有效 data 的 byte 上误拉高写使能。 + val stage1Data = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(32.W)))) + val stage1Mask = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(4.W)))) + val stage1WordIdx = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(3.W)))) + val stage1Valid = RegInit(VecInit(Seq.fill(transLoadSize)(false.B))) + val stage1AnyValid = RegInit(false.B) + + val stage2Data = RegInit(VecInit(Seq.fill(4)(0.U(ABMatrixRegEntryBitSize.W)))) + val stage2Mask = RegInit(VecInit(Seq.fill(4)(0.U(ABMatrixRegEntryByteSize.W)))) + val stage2Valid = RegInit(false.B) + + val stage3Data = RegInit(0.U(ABMatrixRegEntryBitSize.W)) + val stage3Mask = RegInit(0.U(ABMatrixRegEntryByteSize.W)) + val stage3Valid = RegInit(false.B) + + for (i <- 0 until transLoadSize) { + val pkt = io.in.bits(i) + val byteInWord = pkt.entry_offset(1, 0) + val wordIdx = pkt.entry_offset(entryOffsetBits - 1, 2) + stage1Data(i) := Mux(inputFire && pkt.mask, (pkt.data.pad(32) << (byteInWord << 3))(31, 0), 0.U) + stage1Mask(i) := Mux(inputFire && pkt.mask, UIntToOH(byteInWord, 4).asUInt, 0.U) + stage1WordIdx(i) := wordIdx + stage1Valid(i) := inputFire && pkt.mask + } + val inputAnyValid = io.in.bits.map(_.mask).reduce(_ || _) + stage1AnyValid := inputFire && inputAnyValid + + val expandedData = (0 until transLoadSize).map { i => + Mux(stage1Valid(i), (stage1Data(i).asUInt << (stage1WordIdx(i) << 5))(ABMatrixRegEntryBitSize - 1, 0), 0.U(ABMatrixRegEntryBitSize.W)) + } + val expandedMask = (0 until transLoadSize).map { i => + Mux(stage1Valid(i), (stage1Mask(i).asUInt << (stage1WordIdx(i) << 2))(ABMatrixRegEntryByteSize - 1, 0), 0.U(ABMatrixRegEntryByteSize.W)) + } + + for (g <- 0 until 4) { + val start = g * groupSize + val end = math.min(start + groupSize, transLoadSize) + val dataTerms = (start until end).map(expandedData) + val maskTerms = (start until end).map(expandedMask) + val reducedData = if (dataTerms.nonEmpty) dataTerms.reduce(_ | _) else 0.U(ABMatrixRegEntryBitSize.W) + val reducedMask = if (maskTerms.nonEmpty) maskTerms.reduce(_ | _) else 0.U(ABMatrixRegEntryByteSize.W) + stage2Data(g) := reducedData + stage2Mask(g) := reducedMask + } + stage2Valid := stage1AnyValid + + stage3Data := stage2Data.reduce(_ | _) + stage3Mask := stage2Mask.reduce(_ | _) + stage3Valid := stage2Valid + + io.final_data := stage3Data + io.final_mask := stage3Mask + io.valid := stage3Valid + io.empty := !stage1AnyValid && !stage2Valid && !stage3Valid } class AMemoryLoader(implicit p: Parameters) extends CuteModule{ @@ -33,6 +280,7 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.ByteMask.map(_.bits := Fill(ABMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid + io.LocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) io.LocalMMUIO.Response.ready := false.B io.ConfigInfo.MicroTaskEndValid := false.B io.ConfigInfo.MicroTaskReady := false.B @@ -43,24 +291,36 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ dontTouch(io) if (EnableDifftest) { + DifftestModule.addCppMacro("CONFIG_DIFF_AMU_AB_WORDS_PER_BANK", ABMatrixRegEntryBitSize / 64) + DifftestModule.addCppMacro("CONFIG_DIFF_AMU_AB_REG_SIZE_BYTES", ABMatrixRegSize) val pcReg = RegInit(0.U(64.W)) when (io.ConfigInfo.MicroTaskValid) { pcReg := io.ConfigInfo.pc.get } - val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent, delay = 0, dontCare = true) + val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent(ABMatrixRegNBanks, DiffAmuFinishWordsPerBank), delay = 0, dontCare = true) difftestAmuFinish.coreid := io.ConfigInfo.coreid.get difftestAmuFinish.index := 0.U difftestAmuFinish.valid := (io.ToMatrixRegIO.BankAddr.map(_.valid).reduce(_||_) || (io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady)) difftestAmuFinish.pc := pcReg + val eventWordsPerBank = difftestAmuFinish.data.length / ABMatrixRegNBanks + val abMRegWordsPerBank = ABMatrixRegEntryBitSize / 64 + require(difftestAmuFinish.data.length % ABMatrixRegNBanks == 0, "DiffAmuFinishEvent.data should divide by AB bank count") + require(ABMatrixRegEntryBitSize % 64 == 0, s"ABMatrixRegEntryBitSize must be 64-bit aligned, got $ABMatrixRegEntryBitSize") + require(abMRegWordsPerBank <= eventWordsPerBank, s"DiffAmuFinishEvent only supports up to $eventWordsPerBank words per AB bank, got $abMRegWordsPerBank") for (i <- 0 until ABMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.BankAddr(i).bits difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.ByteMask(i).bits - difftestAmuFinish.data(i * 4 + 0) := io.ToMatrixRegIO.Data(i).bits(63,0) - difftestAmuFinish.data(i * 4 + 1) := io.ToMatrixRegIO.Data(i).bits(127,64) - difftestAmuFinish.data(i * 4 + 2) := io.ToMatrixRegIO.Data(i).bits(191,128) - difftestAmuFinish.data(i * 4 + 3) := io.ToMatrixRegIO.Data(i).bits(255,192) + for (w <- 0 until eventWordsPerBank) { + if (w < abMRegWordsPerBank) { + val lo = w * 64 + val hi = lo + 63 + difftestAmuFinish.data(i * eventWordsPerBank + w) := io.ToMatrixRegIO.Data(i).bits(hi, lo) + } else { + difftestAmuFinish.data(i * eventWordsPerBank + w) := 0.U(64.W) + } + } } difftestAmuFinish.finish := io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady } @@ -68,33 +328,40 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val ConfigInfo = io.ConfigInfo val Tensor_Block_BaseAddr = Reg(UInt(MMUAddrWidth.W)) val ApplicationTensor_A_Stride_M = RegInit(0.U(MMUAddrWidth.W)) - val dataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) val HasTail = RegInit(false.B) val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) val K_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_M = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val MatrixRegTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val Conherent = RegInit(true.B) + val Is_Transpose = RegInit(false.B) val Is_ZeroLoad = RegInit(false.B) val Is_FullLoad = RegInit(false.B) val s_idle :: s_mm_task :: Nil = Enum(2) val state = RegInit(s_idle) - val s_load_idle :: s_load_init :: s_load_working :: s_load_end :: Nil = Enum(4) + val s_load_idle :: s_load_init :: s_load_working :: s_load_quiesce :: s_load_end :: Nil = Enum(5) val memoryload_state = RegInit(s_load_idle) + private val transposeEndDrainCycles = Trans_Load_Size + 2 + 3 + val transposeEndDrainCnt = RegInit(0.U(log2Ceil(transposeEndDrainCycles + 1).W)) - val TotalLoadSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize)+1).W)) + val TotalLoadSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*outsideDataWidthByte)+1).W)) val TotalRequestSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) val CurrentLoaded_BlockTensor_M_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val CurrentLoaded_BlockTensor_K_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val Request_M_Iter_Time = RegInit(0.U(log2Ceil(Matrix_MN).W)) + val Request_M_Iter_Time = RegInit(0.U(log2Ceil(math.max(Matrix_MN, ABMatrixRegEntryByteSize)).W)) + + val group_req_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_resp_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_size_reg = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + + private val transBaseAddrBits = log2Ceil(ABMatrixRegBankNEntries) val SoureceIdSearchTable = RegInit(VecInit(Seq.fill(SoureceMaxNum)(0.U((new ASourceIdSearch).getWidth.W)))) val MaxRequestIter = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) val MReg_Fill_Table = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) - val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntrys).W))))) + val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntries).W))))) val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/ABMatrixRegEntryByteSize)+1).W))))) val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U) @@ -107,15 +374,43 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val Bank_Fill_Search_FIFO_Tail = RegInit((VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(log2Ceil(AMemoryLoaderReadFromMemoryFIFODepth).W))))) val Bank_Fill_Search_FIFO_Full = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(false.B))) val Bank_Fill_Search_FIFO_Empty = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(true.B))) - val Bank_Fill_Valid = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(false.B))) - val Have_Bank_Fill = Bank_Fill_Valid.reduce(_ || _) for(i <- 0 until ABMatrixRegNBanks){ Bank_Fill_Search_FIFO_Full(i) := Bank_Fill_Search_FIFO_Tail(i) === WrapInc(Bank_Fill_Search_FIFO_Head(i), AMemoryLoaderReadFromMemoryFIFODepth) Bank_Fill_Search_FIFO_Empty(i) := Bank_Fill_Search_FIFO_Head(i) === Bank_Fill_Search_FIFO_Tail(i) - Bank_Fill_Valid(i) := Bank_Fill_Search_FIFO_Head(i) =/= Bank_Fill_Search_FIFO_Tail(i) } + val transAlignPipes = Seq.tabulate(ABMatrixRegNBanks)(i => Module(new TransAlignPipe(i))) + val transRouters = Seq.tabulate(ABMatrixRegNBanks)(_ => Module(new OOORouter)) + val transPipeInValid = WireInit(false.B) + val transPipeInData = WireInit(0.U(outsideDataWidth.W)) + val transPipeInMask = WireInit(0.U(outsideDataWidthByte.W)) + val transPipeRespBeatCnt = WireInit(0.U(group_resp_cnt.getWidth.W)) + val transPipeEntryOffset = WireInit(0.U(log2Ceil(ABMatrixRegEntryByteSize).W)) + val transPipeDrainTrigger = WireInit(false.B) + val transBusStall = transAlignPipes.map(_.io.bus_stall).reduce(_ || _) + val transWriteBaseAddr = RegInit(0.U(transBaseAddrBits.W)) + val transWriteAddrCnt = RegInit(0.U(log2Ceil(Trans_Load_Size).W)) + + for (i <- 0 until ABMatrixRegNBanks) { + transAlignPipes(i).io.in_data := transPipeInData + transAlignPipes(i).io.in_mask := transPipeInMask + transAlignPipes(i).io.resp_beat_cnt := transPipeRespBeatCnt + transAlignPipes(i).io.entry_offset := transPipeEntryOffset + transAlignPipes(i).io.debug_time := io.DebugInfo.DebugTimeStampe + transAlignPipes(i).io.is_drain_trigger := transPipeDrainTrigger + transAlignPipes(i).io.in_valid := transPipeInValid + transRouters(i).io.in <> transAlignPipes(i).io.out + } + val transAlignEmpty = transAlignPipes.map(_.io.empty).reduce(_ && _) + val transRouterEmpty = transRouters.map(_.io.empty).reduce(_ && _) + val transPipelineEmpty = transAlignEmpty && transRouterEmpty + val transRouterValidVec = VecInit(transRouters.map(_.io.valid)).asUInt + val transRouterWriteValid = transRouters.map(_.io.valid).reduce(_ || _) + val transWriteAddrOffset = transWriteAddrCnt * ReduceGroupSize.U + val transWriteAddrWide = transWriteBaseAddr + transWriteAddrOffset + val transWriteAddr = transWriteAddrWide(transBaseAddrBits - 1, 0) + val Request = io.LocalMMUIO.Request when(state === s_idle){ @@ -124,22 +419,22 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ state := s_mm_task memoryload_state := s_load_init MatrixRegTensor_M := ConfigInfo.MatrixRegTensor_M - MatrixRegTensor_K := ConfigInfo.MatrixRegTensor_K CurrentMatrixRegId := ConfigInfo.MatrixRegId Tensor_Block_BaseAddr := ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr ApplicationTensor_A_Stride_M := ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_Stride_M - dataType := ConfigInfo.ApplicationTensor_A.dataType HasTail := ConfigInfo.ApplicationTensor_A.HasTail TailByteMask := ConfigInfo.ApplicationTensor_A.TailByteMask K_Beat_Count := ConfigInfo.ApplicationTensor_A.K_Beat_Count Is_ZeroLoad := ConfigInfo.LoadTaskInfo.Is_ZeroLoad Is_FullLoad := ConfigInfo.LoadTaskInfo.Is_FullLoad Conherent := ConfigInfo.Conherent + Is_Transpose := ConfigInfo.Is_Transpose if(YJPAMLDebugEnable){ - printf("[AML<%d>]AMemoryLoader Task Start, MatrixRegTensor_M:%d, MatrixRegTensor_K:%d, BaseVaddr:%x, Stride_M:%x, dataType:%d, Is_ZeroLoad:%d, Is_FullLoad:%d\n", + printf("[AML<%d>]AMemoryLoader Task Start, MatrixRegTensor_M:%d, MatrixRegTensor_K:%d, BaseVaddr:%x, Stride_M:%x, dataType:%d, Is_Transpose:%d, Is_ZeroLoad:%d, Is_FullLoad:%d\n", io.DebugInfo.DebugTimeStampe, ConfigInfo.MatrixRegTensor_M, ConfigInfo.MatrixRegTensor_K, ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr, ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_Stride_M, - ConfigInfo.ApplicationTensor_A.dataType, ConfigInfo.LoadTaskInfo.Is_ZeroLoad.asUInt, ConfigInfo.LoadTaskInfo.Is_FullLoad.asUInt) + ConfigInfo.ApplicationTensor_A.dataType, ConfigInfo.Is_Transpose.asUInt, + ConfigInfo.LoadTaskInfo.Is_ZeroLoad.asUInt, ConfigInfo.LoadTaskInfo.Is_FullLoad.asUInt) } } } @@ -152,6 +447,12 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ CurrentLoaded_BlockTensor_M_Iter := 0.U CurrentLoaded_BlockTensor_K_Iter := 0.U Request_M_Iter_Time := 0.U + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + transposeEndDrainCnt := 0.U + transWriteBaseAddr := 0.U + transWriteAddrCnt := 0.U MaxRequestIter := MatrixRegTensor_M * K_Beat_Count Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) @@ -167,7 +468,7 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ "Error! AML Load Task Type: Exactly one of Is_ZeroLoad, Is_FullLoad should be true!") when(Is_ZeroLoad){ - val Max_ZeroLoad_Write_Times = ABMatrixRegBankNEntrys + val Max_ZeroLoad_Write_Times = ABMatrixRegBankNEntries for (i <- 0 until ABMatrixRegNBanks){ io.ToMatrixRegIO.BankAddr(i).bits := TotalLoadSize io.ToMatrixRegIO.BankAddr(i).valid := true.B @@ -189,23 +490,54 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) val fullTaskMask = Fill(outsideDataWidthByte, true.B) val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U)) + + // Transpose request path reuses the normal iter registers: + // CurrentLoaded_BlockTensor_M_Iter -> large_M group index + // CurrentLoaded_BlockTensor_K_Iter -> K/M-beat index + // Request_M_Iter_Time -> small_M inside current group + val transpose_large_m_base = CurrentLoaded_BlockTensor_M_Iter * ABMatrixRegEntryByteSize.U + val transpose_current_m = transpose_large_m_base + Request_M_Iter_Time + val transpose_group_in_range = transpose_large_m_base < MatrixRegTensor_M + val transpose_group_remain = Mux(transpose_group_in_range, MatrixRegTensor_M - transpose_large_m_base, 0.U) + val current_group_size = Wire(UInt(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + current_group_size := Mux( + transpose_group_remain < ABMatrixRegEntryByteSize.U, + transpose_group_remain(current_group_size.getWidth - 1, 0), + ABMatrixRegEntryByteSize.U(current_group_size.getWidth.W) + ) + val group_has_no_requests = group_req_cnt === 0.U && group_resp_cnt === 0.U + val group_is_idle = group_has_no_requests && transPipelineEmpty + val active_group_size = Mux(group_has_no_requests, current_group_size, group_size_reg) + val transpose_group_can_issue = Mux( + group_has_no_requests, + group_is_idle && (current_group_size =/= 0.U), + group_req_cnt < active_group_size + ) + val transpose_req_enable = (TotalRequestSize < MaxRequestIter) && transpose_group_can_issue + // 矩阵访存顺序:按 M 分 bank 交织,再扫 K。地址 = BaseAddr + M*Stride_M + K*64B - val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) % ABMatrixRegNBanks.U - val RequestMatrixRegBaseAddr = (((CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) / ABMatrixRegNBanks.U) * ReduceGroupSize.U) - val RequestMatrixRegAddr = RequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) + val RequestMatrixRegMIndex = Mux(Is_Transpose, transpose_current_m, CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) + val NormalRequestMatrixRegBankId = RequestMatrixRegMIndex % ABMatrixRegNBanks.U + val NormalRequestMatrixRegBaseAddr = ((RequestMatrixRegMIndex / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + val NormalRequestMatrixRegAddr = NormalRequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) + val TransposeRequestMatrixRegAddr = CurrentLoaded_BlockTensor_K_Iter * (Trans_Load_Size * ReduceGroupSize).U + CurrentLoaded_BlockTensor_M_Iter + val RequestMatrixRegBankId = Mux(Is_Transpose, 0.U, NormalRequestMatrixRegBankId) + val RequestMatrixRegAddr = Mux(Is_Transpose, TransposeRequestMatrixRegAddr, NormalRequestMatrixRegAddr) - Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_A_Stride_M + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) + Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + RequestMatrixRegMIndex * ApplicationTensor_A_Stride_M + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) val sourceId = Mux(Conherent, io.LocalMMUIO.ConherentRequsetSourceID, io.LocalMMUIO.nonConherentRequsetSourceID) Request.bits.RequestConherent := Conherent Request.bits.RequestSourceID := sourceId.bits Request.bits.RequestType_isWrite := false.B - Request.valid := (TotalRequestSize < MaxRequestIter) + Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + Request.valid := Mux(Is_Transpose, transpose_req_enable, TotalRequestSize < MaxRequestIter) when(Request.fire){ val TableItem = Wire(new ASourceIdSearch) TableItem.MatrixRegBankId := RequestMatrixRegBankId TableItem.MatrixRegAddr := RequestMatrixRegAddr TableItem.MatrixRegisTail := RequestBeatIsTail + TableItem.BeatIndex := Request_M_Iter_Time SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt if (YJPAMLDebugEnable) { @@ -214,13 +546,41 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ Request_M_Iter_Time, Request.bits.RequestVirtualAddr, RequestMatrixRegBankId, RequestMatrixRegAddr, sourceId.bits, RequestBeatIsTail) } - Request_M_Iter_Time := Request_M_Iter_Time + 1.U - when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ - Request_M_Iter_Time := 0.U - CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U - when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ - CurrentLoaded_BlockTensor_K_Iter := 0.U - CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U + when(Is_Transpose) { + printf("[AML_TRANS_REQ<%d>] totalReq:%d groupReq:%d groupResp:%d idle:%d largeM:%d kBeat:%d beat:%d groupSize:%d activeGroupSize:%d regBase:%d vaddr:%x source:%d tail:%d\n", + io.DebugInfo.DebugTimeStampe, TotalRequestSize, group_req_cnt, group_resp_cnt, + group_is_idle.asUInt, CurrentLoaded_BlockTensor_M_Iter, CurrentLoaded_BlockTensor_K_Iter, + Request_M_Iter_Time, current_group_size, active_group_size, RequestMatrixRegAddr, + Request.bits.RequestVirtualAddr, sourceId.bits, RequestBeatIsTail.asUInt) + + val small_m_reach_group_boundary = Request_M_Iter_Time === (ABMatrixRegEntryByteSize - 1).U + val small_m_reach_tensor_boundary = transpose_current_m === (MatrixRegTensor_M - 1.U) + val small_m_wrap = small_m_reach_group_boundary || small_m_reach_tensor_boundary + val k_wrap = CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U) + + when(group_is_idle) { + group_size_reg := current_group_size + } + group_req_cnt := group_req_cnt + 1.U + + Request_M_Iter_Time := Request_M_Iter_Time + 1.U + when(small_m_wrap) { + Request_M_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(k_wrap) { + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + 1.U + } + } + }.otherwise { + Request_M_Iter_Time := Request_M_Iter_Time + 1.U + when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ + Request_M_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U + } } } when(TotalRequestSize =/= MaxRequestIter){ @@ -234,12 +594,12 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ } val current_fill_fifo_full = WireInit(false.B) - when(io.LocalMMUIO.Response.valid){ + when(io.LocalMMUIO.Response.valid && !Is_Transpose){ val respSourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID val MatrixRegBankId = SoureceIdSearchTable(respSourceId).asTypeOf(new ASourceIdSearch).MatrixRegBankId current_fill_fifo_full := Bank_Fill_Search_FIFO_Full(MatrixRegBankId) } - io.LocalMMUIO.Response.ready := MReg_Fill_Table_Not_Full && !current_fill_fifo_full + io.LocalMMUIO.Response.ready := Mux(Is_Transpose, !transBusStall, MReg_Fill_Table_Not_Full && !current_fill_fifo_full) when(io.LocalMMUIO.Response.fire){ val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID @@ -253,58 +613,172 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ printf("[AML_ResponseHandshake<%d>] Data:%x, BankId:%d, RegAddr:%d, SourceId:%d, FIFOIndex:%d, Tail:%d\n", io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, sourceId, FIFOIndex, MatrixRegSearch.MatrixRegisTail) } - MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData - MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr - MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U - MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := MatrixRegSearch.MatrixRegisTail - Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index - Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), AMemoryLoaderReadFromMemoryFIFODepth) + when(Is_Transpose) { + val next_group_resp_cnt = group_resp_cnt + 1.U + val drain_trigger = next_group_resp_cnt === active_group_size + + printf("[AML_TRANS_RESP<%d>] source:%d data:%x tail:%d mask:%x base:%d beatIndex:%d respCnt:%d nextResp:%d activeGroupSize:%d drainTrig:%d writeBase:%d writeCnt:%d\n", + io.DebugInfo.DebugTimeStampe, sourceId, ResponseData, MatrixRegSearch.MatrixRegisTail.asUInt, + Mux(MatrixRegSearch.MatrixRegisTail, tailTaskMask, fullTaskMask), MatrixRegAddr, + MatrixRegSearch.BeatIndex, group_resp_cnt, next_group_resp_cnt, active_group_size, + drain_trigger.asUInt, transWriteBaseAddr, transWriteAddrCnt) + + transPipeInValid := true.B + transPipeInData := ResponseData + transPipeInMask := Mux(MatrixRegSearch.MatrixRegisTail, tailTaskMask, fullTaskMask) + transPipeRespBeatCnt := group_resp_cnt + transPipeEntryOffset := MatrixRegSearch.BeatIndex + transPipeDrainTrigger := drain_trigger + + when(group_resp_cnt === 0.U) { + transWriteBaseAddr := MatrixRegAddr + transWriteAddrCnt := 0.U + } + + when(next_group_resp_cnt === active_group_size) { + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + }.otherwise { + group_resp_cnt := next_group_resp_cnt + } + }.otherwise { + MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData + MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr + MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := MatrixRegSearch.MatrixRegisTail + Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index + Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), AMemoryLoaderReadFromMemoryFIFODepth) + } if (YJPAMLDebugEnable){ printf("[AML<%d>]Response, Data:%x, BankId:%d, RegAddr:%d\n", io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr) } } - val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) - for (i <- 0 until ABMatrixRegNBanks){ - when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ - val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) - val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) - val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) - val fillLowHalf = fillSlot(0) === 0.U - val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) - Current_Fill_MReg_Time(i) := 1.U - val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) - FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - io.ToMatrixRegIO.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot - io.ToMatrixRegIO.BankAddr(i).valid := true.B - io.ToMatrixRegIO.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) - io.ToMatrixRegIO.Data(i).valid := true.B - io.ToMatrixRegIO.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) - io.ToMatrixRegIO.ByteMask(i).valid := true.B + when(Is_Transpose) { + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid if (YJPAMLDebugEnable) { - printf("[AML_MRegWriteHandshake<%d>] bank:%d, RegAddr:%x, WriteAddr:%x, Data:%x, ByteMask:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), io.ToMatrixRegIO.BankAddr(i).bits, io.ToMatrixRegIO.Data(i).bits, io.ToMatrixRegIO.ByteMask(i).bits, MReg_Fill_Table_Time(CurrentFIFOIndex)) + when(routerValid) { + printf("[AML_TransposeWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) + } } - MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U - when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ - Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), AMemoryLoaderReadFromMemoryFIFODepth) + } + when(transRouterWriteValid) { + printf("[AML_TRANS_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x totalLoad:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + TotalLoadSize, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + // 转置路径以真实写侧 valid 计数;任务结束只看请求、响应和流水线是否全部静默。 + val Current_Load_Fill_Size = transRouterWriteValid.asUInt + val nextTotalLoadSize = TotalLoadSize + Current_Load_Fill_Size + val transposeDone = TotalRequestSize === MaxRequestIter && + group_req_cnt === 0.U && group_resp_cnt === 0.U && + transPipelineEmpty + TotalLoadSize := nextTotalLoadSize + if(YJPAMLDebugEnable){ + when(Current_Load_Fill_Size =/= 0.U) { + printf("[AML_TransposeLoad<%d>]TotalLoadSize:%d, FillSize:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size) } - if (YJPAMLDebugEnable){ - printf("[AML<%d>]Fill bank:%d, RegAddr:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) + } + when(transposeDone){ + memoryload_state := s_load_quiesce + transposeEndDrainCnt := (transposeEndDrainCycles - 1).U + if (YJPAMLDebugEnable) printf("[AML<%d>]TransposeFullLoadEnd\n", io.DebugInfo.DebugTimeStampe) + } + }.otherwise { + val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) + for (i <- 0 until ABMatrixRegNBanks){ + when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ + val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) + val fillLowHalf = fillSlot(0) === 0.U + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) + Current_Fill_MReg_Time(i) := 1.U + val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) + FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) + io.ToMatrixRegIO.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot + io.ToMatrixRegIO.BankAddr(i).valid := true.B + io.ToMatrixRegIO.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) + io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) + io.ToMatrixRegIO.ByteMask(i).valid := true.B + if (YJPAMLDebugEnable) { + printf("[AML_MRegWriteHandshake<%d>] bank:%d, RegAddr:%x, WriteAddr:%x, Data:%x, ByteMask:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), io.ToMatrixRegIO.BankAddr(i).bits, io.ToMatrixRegIO.Data(i).bits, io.ToMatrixRegIO.ByteMask(i).bits, MReg_Fill_Table_Time(CurrentFIFOIndex)) + } + MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U + when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ + Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), AMemoryLoaderReadFromMemoryFIFODepth) + } + if (YJPAMLDebugEnable){ + printf("[AML<%d>]Fill bank:%d, RegAddr:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) + } } } - } - val Current_Load_Fill_Size = PopCount(Current_Fill_MReg_Time.asUInt) - TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size - if(YJPAMLDebugEnable){ - when(Current_Load_Fill_Size =/= 0.U) { - printf("[AML<%d>]TotalLoadSize:%d, FillSize:%d, Max:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size, MaxRequestIter * MAX_Fill_Times.U) + val Current_Load_Fill_Size = PopCount(Current_Fill_MReg_Time.asUInt) + TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size + if(YJPAMLDebugEnable){ + when(Current_Load_Fill_Size =/= 0.U) { + printf("[AML<%d>]TotalLoadSize:%d, FillSize:%d, Max:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size, MaxRequestIter * MAX_Fill_Times.U) + } + } + when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ + memoryload_state := s_load_end + if (YJPAMLDebugEnable) printf("[AML<%d>]FullLoadEnd\n", io.DebugInfo.DebugTimeStampe) } } - when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ - memoryload_state := s_load_end - if (YJPAMLDebugEnable) printf("[AML<%d>]FullLoadEnd\n", io.DebugInfo.DebugTimeStampe) + } + } + is(s_load_quiesce) { + io.ToMatrixRegIO.active := true.B + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPAMLDebugEnable) { + when(routerValid) { + printf("[AML_TransposeQuiesceWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) + } } } + when(transRouterWriteValid) { + printf("[AML_TRANS_QUIESCE_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x drainCnt:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + transposeEndDrainCnt, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + when(transposeEndDrainCnt === 0.U) { + memoryload_state := s_load_end + if (YJPAMLDebugEnable) printf("[AML<%d>]TransposeQuiesceEnd\n", io.DebugInfo.DebugTimeStampe) + }.otherwise { + transposeEndDrainCnt := transposeEndDrainCnt - 1.U + } } is(s_load_end) { ConfigInfo.MicroTaskEndValid := true.B diff --git a/src/main/scala/BDataController.scala b/src/main/scala/BDataController.scala index 4c58d1d..d94fa3c 100644 --- a/src/main/scala/BDataController.scala +++ b/src/main/scala/BDataController.scala @@ -146,7 +146,7 @@ class BDataController(implicit p: Parameters) extends CuteModule{ // printf("[BDataController<%d>]BDataController: M_IteratorMax is %d, N_IteratorMax is %d, K_IteratorMax is %d\n",io.DebugInfo.DebugTimeStampe, M_IteratorMax, N_IteratorMax, K_IteratorMax) // } //MTE循环的最外层是M,然后是N,最后是K,所以这里在同步信号的ComputeGo的协同下,执行Max_Caculate_Iter次取数 - val next_addr = Wire(UInt(ABMatrixRegBankNEntrys.W)) + val next_addr = Wire(UInt(ABMatrixRegBankNEntries.W)) next_addr := N_Iterator * K_IteratorMax + K_Iterator MatrixRegRequestBankAddr.bits.foreach(_ := next_addr) diff --git a/src/main/scala/BMemoryLoader.scala b/src/main/scala/BMemoryLoader.scala index 22aaf3b..7d70bfd 100644 --- a/src/main/scala/BMemoryLoader.scala +++ b/src/main/scala/BMemoryLoader.scala @@ -20,8 +20,9 @@ import org.chipsalliance.cde.config._ class BSourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId = UInt(log2Ceil(ABMatrixRegNBanks).W) - val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntrys).W) + val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntries).W) val MatrixRegisTail = Bool() + val BeatIndex = UInt(log2Ceil(ABMatrixRegEntryByteSize).W) } //对于卷积,数据摆放是[khkwoc][ic],对于矩阵乘,数据摆放是[N][K] @@ -46,6 +47,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.ByteMask.map(_.bits := Fill(ABMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid + io.LocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) io.LocalMMUIO.Response.ready := false.B io.ConfigInfo.MicroTaskEndValid := false.B io.ConfigInfo.MicroTaskReady := false.B @@ -55,21 +57,31 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ when (io.ConfigInfo.MicroTaskValid) { pcReg := io.ConfigInfo.pc.get } - val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent, delay = 0, dontCare = true) + val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent(ABMatrixRegNBanks, DiffAmuFinishWordsPerBank), delay = 0, dontCare = true) // 默认值初始化 difftestAmuFinish.coreid := io.ConfigInfo.coreid.get difftestAmuFinish.index := 1.U difftestAmuFinish.valid := (io.ToMatrixRegIO.BankAddr.map(_.valid).reduce(_||_) || (io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady)) difftestAmuFinish.pc := pcReg + val eventWordsPerBank = difftestAmuFinish.data.length / ABMatrixRegNBanks + val abMRegWordsPerBank = ABMatrixRegEntryBitSize / 64 + require(difftestAmuFinish.data.length % ABMatrixRegNBanks == 0, "DiffAmuFinishEvent.data should divide by AB bank count") + require(ABMatrixRegEntryBitSize % 64 == 0, s"ABMatrixRegEntryBitSize must be 64-bit aligned, got $ABMatrixRegEntryBitSize") + require(abMRegWordsPerBank <= eventWordsPerBank, s"DiffAmuFinishEvent only supports up to $eventWordsPerBank words per AB bank, got $abMRegWordsPerBank") for (i <- 0 until ABMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.BankAddr(i).bits difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.ByteMask(i).bits - difftestAmuFinish.data(i * 4 + 0) := io.ToMatrixRegIO.Data(i).bits(63,0) - difftestAmuFinish.data(i * 4 + 1) := io.ToMatrixRegIO.Data(i).bits(127,64) - difftestAmuFinish.data(i * 4 + 2) := io.ToMatrixRegIO.Data(i).bits(191,128) - difftestAmuFinish.data(i * 4 + 3) := io.ToMatrixRegIO.Data(i).bits(255,192) + for (w <- 0 until eventWordsPerBank) { + if (w < abMRegWordsPerBank) { + val lo = w * 64 + val hi = lo + 63 + difftestAmuFinish.data(i * eventWordsPerBank + w) := io.ToMatrixRegIO.Data(i).bits(hi, lo) + } else { + difftestAmuFinish.data(i * eventWordsPerBank + w) := 0.U(64.W) + } + } } difftestAmuFinish.finish := io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady } @@ -98,11 +110,14 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //访存状态机,用来配合流水线刷新 - val s_load_idle :: s_load_init :: s_load_working :: s_load_end :: Nil = Enum(4) + val s_load_idle :: s_load_init :: s_load_working :: s_load_quiesce :: s_load_end :: Nil = Enum(5) val memoryload_state = RegInit(s_load_idle) + private val transposeEndDrainCycles = Trans_Load_Size + 2 + 3 + val transposeEndDrainCnt = RegInit(0.U(log2Ceil(transposeEndDrainCycles + 1).W)) val Tensor_Block_BaseAddr = Reg(UInt(MMUAddrWidth.W)) //分块矩阵的基地址 val Conherent = RegInit(true.B) //是否一致性访存的标志位,由TaskController提供 + val Is_Transpose = RegInit(false.B) //是否转置load,由TaskController提供 //如果configinfo有效 @@ -121,6 +136,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ Tensor_B_BaseVaddr := io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr //这个不重要 Tensor_Block_BaseAddr := io.ConfigInfo.ApplicationTensor_B.BlockTensor_B_BaseVaddr //这个是关键 Conherent := io.ConfigInfo.Conherent + Is_Transpose := io.ConfigInfo.Is_Transpose HasTail := io.ConfigInfo.ApplicationTensor_B.HasTail TailByteMask := io.ConfigInfo.ApplicationTensor_B.TailByteMask K_Beat_Count := io.ConfigInfo.ApplicationTensor_B.K_Beat_Count @@ -129,7 +145,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ if(YJPBMLDebugEnable) { printf("[BML<%d>]BMemoryLoader Task Start\n",io.DebugInfo.DebugTimeStampe) - printf("[BML<%d>]MatrixRegTensor_N:%d,MatrixRegTensor_K:%d\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.MatrixRegTensor_N,io.ConfigInfo.MatrixRegTensor_K) + printf("[BML<%d>]MatrixRegTensor_N:%d,MatrixRegTensor_K:%d,Is_Transpose:%d\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.MatrixRegTensor_N,io.ConfigInfo.MatrixRegTensor_K,io.ConfigInfo.Is_Transpose) printf("[BML<%d>]Tensor_B_BaseVaddr:%x,Tensor_Block_BaseAddr:%x\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr,io.ConfigInfo.ApplicationTensor_B.BlockTensor_B_BaseVaddr) printf("[BML<%d>]ApplicationTensor_B_Stride_N:%x\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_Stride_N) } @@ -162,7 +178,13 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ val TotalRequestSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) val CurrentLoaded_BlockTensor_N_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val CurrentLoaded_BlockTensor_K_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val Request_N_Iter_Time = RegInit(0.U(log2Ceil(Matrix_MN).W)) + val Request_N_Iter_Time = RegInit(0.U(log2Ceil(math.max(Matrix_MN, ABMatrixRegEntryByteSize)).W)) + + val group_req_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_resp_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_size_reg = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + + private val transBaseAddrBits = log2Ceil(ABMatrixRegBankNEntries) //一个cam来存储访存请求的source_id对应的MatrixReg的地址和bank号 @@ -173,7 +195,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ val MaxRequestIter = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) val MReg_Fill_Table = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) - val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntrys).W)))))//记录这个LLC回的数是在scp的哪个地址 + val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntries).W)))))//记录这个LLC回的数是在scp的哪个地址 val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/ABMatrixRegEntryByteSize)+1).W)))))//记录这个LLC回的数需要回填的次数,完成就可以将数据释放了 val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U)//记录这个FIFO能否能填数据 @@ -195,6 +217,37 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ Bank_Fill_Search_FIFO_Empty(i) := Bank_Fill_Search_FIFO_Head(i) === Bank_Fill_Search_FIFO_Tail(i)//这个bank不需要写scp } + val transAlignPipes = Seq.tabulate(ABMatrixRegNBanks)(i => Module(new TransAlignPipe(i))) + val transRouters = Seq.tabulate(ABMatrixRegNBanks)(_ => Module(new OOORouter)) + val transPipeInValid = WireInit(false.B) + val transPipeInData = WireInit(0.U(outsideDataWidth.W)) + val transPipeInMask = WireInit(0.U(outsideDataWidthByte.W)) + val transPipeRespBeatCnt = WireInit(0.U(group_resp_cnt.getWidth.W)) + val transPipeEntryOffset = WireInit(0.U(log2Ceil(ABMatrixRegEntryByteSize).W)) + val transPipeDrainTrigger = WireInit(false.B) + val transBusStall = transAlignPipes.map(_.io.bus_stall).reduce(_ || _) + val transWriteBaseAddr = RegInit(0.U(transBaseAddrBits.W)) + val transWriteAddrCnt = RegInit(0.U(log2Ceil(Trans_Load_Size).W)) + + for (i <- 0 until ABMatrixRegNBanks) { + transAlignPipes(i).io.in_data := transPipeInData + transAlignPipes(i).io.in_mask := transPipeInMask + transAlignPipes(i).io.resp_beat_cnt := transPipeRespBeatCnt + transAlignPipes(i).io.entry_offset := transPipeEntryOffset + transAlignPipes(i).io.debug_time := io.DebugInfo.DebugTimeStampe + transAlignPipes(i).io.is_drain_trigger := transPipeDrainTrigger + transAlignPipes(i).io.in_valid := transPipeInValid + transRouters(i).io.in <> transAlignPipes(i).io.out + } + val transAlignEmpty = transAlignPipes.map(_.io.empty).reduce(_ && _) + val transRouterEmpty = transRouters.map(_.io.empty).reduce(_ && _) + val transPipelineEmpty = transAlignEmpty && transRouterEmpty + val transRouterValidVec = VecInit(transRouters.map(_.io.valid)).asUInt + val transRouterWriteValid = transRouters.map(_.io.valid).reduce(_ || _) + val transWriteAddrOffset = transWriteAddrCnt * ReduceGroupSize.U + val transWriteAddrWide = transWriteBaseAddr + transWriteAddrOffset + val transWriteAddr = transWriteAddrWide(transBaseAddrBits - 1, 0) + val Request = io.LocalMMUIO.Request switch(memoryload_state) { @@ -205,6 +258,12 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ CurrentLoaded_BlockTensor_N_Iter := 0.U CurrentLoaded_BlockTensor_K_Iter := 0.U Request_N_Iter_Time := 0.U + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + transposeEndDrainCnt := 0.U + transWriteBaseAddr := 0.U + transWriteAddrCnt := 0.U MaxRequestIter := MatrixRegTensor_N * K_Beat_Count //总共要发出的访存请求的次数 Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) @@ -220,19 +279,50 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //先转换成独热码然后进行减一即可计算出掩码 val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val fullTaskMask = Fill(outsideDataWidthByte, true.B) val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U)) - // 访存顺序与AML保持一致:先沿N维分4个bank发射,再推进K维,最后推进下一组N block - val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) % ABMatrixRegNBanks.U - val RequestMatrixRegBaseAddr = (((CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) / ABMatrixRegNBanks.U) * ReduceGroupSize.U) - val RequestMatrixRegAddr = RequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) - Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) * ApplicationTensor_B_Stride_N + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) + // Transpose request path reuses the normal iter registers: + // CurrentLoaded_BlockTensor_N_Iter -> large_N group index + // CurrentLoaded_BlockTensor_K_Iter -> K/N-beat index + // Request_N_Iter_Time -> small_N inside current group + val transpose_large_n_base = CurrentLoaded_BlockTensor_N_Iter * ABMatrixRegEntryByteSize.U + val transpose_current_n = transpose_large_n_base + Request_N_Iter_Time + val transpose_group_in_range = transpose_large_n_base < MatrixRegTensor_N + val transpose_group_remain = Mux(transpose_group_in_range, MatrixRegTensor_N - transpose_large_n_base, 0.U) + val current_group_size = Wire(UInt(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + current_group_size := Mux( + transpose_group_remain < ABMatrixRegEntryByteSize.U, + transpose_group_remain(current_group_size.getWidth - 1, 0), + ABMatrixRegEntryByteSize.U(current_group_size.getWidth.W) + ) + val group_has_no_requests = group_req_cnt === 0.U && group_resp_cnt === 0.U + val group_is_idle = group_has_no_requests && transPipelineEmpty + val active_group_size = Mux(group_has_no_requests, current_group_size, group_size_reg) + val transpose_group_can_issue = Mux( + group_has_no_requests, + group_is_idle && (current_group_size =/= 0.U), + group_req_cnt < active_group_size + ) + val transpose_req_enable = (TotalRequestSize < MaxRequestIter) && transpose_group_can_issue + + // 访存顺序与AML保持一致:先沿N维分4个bank发射,再推进K维,最后推进下一组N block + val RequestMatrixRegNIndex = Mux(Is_Transpose, transpose_current_n, CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) + val NormalRequestMatrixRegBankId = RequestMatrixRegNIndex % ABMatrixRegNBanks.U + val NormalRequestMatrixRegBaseAddr = ((RequestMatrixRegNIndex / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + val NormalRequestMatrixRegAddr = NormalRequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) + val TransposeRequestMatrixRegAddr = CurrentLoaded_BlockTensor_K_Iter * (Trans_Load_Size * ReduceGroupSize).U + CurrentLoaded_BlockTensor_N_Iter + val RequestMatrixRegBankId = Mux(Is_Transpose, 0.U, NormalRequestMatrixRegBankId) + val RequestMatrixRegAddr = Mux(Is_Transpose, TransposeRequestMatrixRegAddr, NormalRequestMatrixRegAddr) + + Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + RequestMatrixRegNIndex * ApplicationTensor_B_Stride_N + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) val sourceId = Mux(Conherent,io.LocalMMUIO.ConherentRequsetSourceID,io.LocalMMUIO.nonConherentRequsetSourceID) Request.bits.RequestConherent := Conherent Request.bits.RequestSourceID := sourceId.bits Request.bits.RequestType_isWrite := false.B - Request.valid := (TotalRequestSize < MaxRequestIter) + Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + Request.valid := Mux(Is_Transpose, transpose_req_enable, TotalRequestSize < MaxRequestIter) when(Request.fire && sourceId.valid){//符合条件的话,这条访存请求一定会被发出 //Request.ready表明了LocalMMU会处理这条访存请求,sourceID valid,表明这条访存请求的sourceID是被LocalMMU认可有效才发送到这个模块的 @@ -240,17 +330,46 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ TableItem.MatrixRegBankId := RequestMatrixRegBankId TableItem.MatrixRegAddr := RequestMatrixRegAddr TableItem.MatrixRegisTail := RequestBeatIsTail + TableItem.BeatIndex := Request_N_Iter_Time SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt if (YJPBMLDebugEnable) { printf("[BML_RequestHandshake<%d>] sourceId:%d, MatrixRegBankId:%d, MatrixRegAddr:%d, RequestVirtualAddr:%x, RequestConherent:%d, RequestType_isWrite:%d, Tail:%d\n",io.DebugInfo.DebugTimeStampe,sourceId.bits,TableItem.MatrixRegBankId,TableItem.MatrixRegAddr,Request.bits.RequestVirtualAddr,Request.bits.RequestConherent,Request.bits.RequestType_isWrite,RequestBeatIsTail) } - Request_N_Iter_Time := Request_N_Iter_Time + 1.U - when(Request_N_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) === MatrixRegTensor_N - 1.U){ - Request_N_Iter_Time := 0.U - CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U - when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ - CurrentLoaded_BlockTensor_K_Iter := 0.U - CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + Matrix_MN.U + when(Is_Transpose) { + printf("[BML_TRANS_REQ<%d>] totalReq:%d groupReq:%d groupResp:%d idle:%d largeN:%d kBeat:%d beat:%d groupSize:%d activeGroupSize:%d regBase:%d vaddr:%x source:%d tail:%d\n", + io.DebugInfo.DebugTimeStampe, TotalRequestSize, group_req_cnt, group_resp_cnt, + group_is_idle.asUInt, CurrentLoaded_BlockTensor_N_Iter, CurrentLoaded_BlockTensor_K_Iter, + Request_N_Iter_Time, current_group_size, active_group_size, RequestMatrixRegAddr, + Request.bits.RequestVirtualAddr, sourceId.bits, RequestBeatIsTail.asUInt) + + val small_n_reach_group_boundary = Request_N_Iter_Time === (ABMatrixRegEntryByteSize - 1).U + val small_n_reach_tensor_boundary = transpose_current_n === (MatrixRegTensor_N - 1.U) + val small_n_wrap = small_n_reach_group_boundary || small_n_reach_tensor_boundary + val k_wrap = CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U) + + when(group_is_idle) { + group_size_reg := current_group_size + } + group_req_cnt := group_req_cnt + 1.U + + Request_N_Iter_Time := Request_N_Iter_Time + 1.U + when(small_n_wrap) { + Request_N_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(k_wrap) { + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + 1.U + } + } + }.otherwise { + Request_N_Iter_Time := Request_N_Iter_Time + 1.U + when(Request_N_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) === MatrixRegTensor_N - 1.U){ + Request_N_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + Matrix_MN.U + } } } when(TotalRequestSize =/= MaxRequestIter){ @@ -258,7 +377,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ } } val current_fill_fifo_full = WireInit(false.B) - when(io.LocalMMUIO.Response.valid) + when(io.LocalMMUIO.Response.valid && !Is_Transpose) { val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch).MatrixRegBankId @@ -267,13 +386,12 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //接受访存的返回值 //一个cam来存储访存请求的source_id对应的MatrixReg的地址和bank号 //根据response的sourceid,找到对应的MatrixReg的Fill_Table的队伍头的索引,填充到Fill_Table中 - if (ABMLNeedMRegFillTable) - { - io.LocalMMUIO.Response.ready := MReg_Fill_Table_Not_Full && (current_fill_fifo_full === false.B) - } else - { - io.LocalMMUIO.Response.ready := true.B + val normalRespReady = if (ABMLNeedMRegFillTable) { + MReg_Fill_Table_Not_Full && (current_fill_fifo_full === false.B) + } else { + true.B } + io.LocalMMUIO.Response.ready := Mux(Is_Transpose, !transBusStall, normalRespReady) when(io.LocalMMUIO.Response.fire){ //Trick注意这个设计,是doublebuffer的,AB只能是doublebuffer,回数一定是不会堵的,而且我们有时间对数据进行压缩解压缩~ //如果要做release设计,要么数据位宽翻倍,腾出周期来使得有空泡能给写任务进行,要么就是数据位宽不变,将读写端口变成独立的读和独立的写端口 @@ -288,30 +406,61 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ printf("[BML_ResponseHandshake<%d>] ResponseData:%x, MatrixRegBankId:%d, MatrixRegAddr:%d, SourceId:%d, FIFOIndex:%d, Tail:%d\n",io.DebugInfo.DebugTimeStampe,ResponseData,MatrixRegBankId,MatrixRegAddr,sourceId,FIFOIndex,searchEntry.MatrixRegisTail) } - if (!ABMLNeedMRegFillTable) - { - TotalLoadSize := TotalLoadSize + 1.U - for (i <- 0 until ABMatrixRegNBanks) + when(Is_Transpose) { + val next_group_resp_cnt = group_resp_cnt + 1.U + val drain_trigger = next_group_resp_cnt === active_group_size + + printf("[BML_TRANS_RESP<%d>] source:%d data:%x tail:%d mask:%x base:%d beatIndex:%d respCnt:%d nextResp:%d activeGroupSize:%d drainTrig:%d writeBase:%d writeCnt:%d\n", + io.DebugInfo.DebugTimeStampe, sourceId, ResponseData, searchEntry.MatrixRegisTail.asUInt, + Mux(searchEntry.MatrixRegisTail, tailTaskMask, fullTaskMask), MatrixRegAddr, + searchEntry.BeatIndex, group_resp_cnt, next_group_resp_cnt, active_group_size, + drain_trigger.asUInt, transWriteBaseAddr, transWriteAddrCnt) + + transPipeInValid := true.B + transPipeInData := ResponseData + transPipeInMask := Mux(searchEntry.MatrixRegisTail, tailTaskMask, fullTaskMask) + transPipeRespBeatCnt := group_resp_cnt + transPipeEntryOffset := searchEntry.BeatIndex + transPipeDrainTrigger := drain_trigger + + when(group_resp_cnt === 0.U) { + transWriteBaseAddr := MatrixRegAddr + transWriteAddrCnt := 0.U + } + + when(next_group_resp_cnt === active_group_size) { + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + }.otherwise { + group_resp_cnt := next_group_resp_cnt + } + }.otherwise { + if (!ABMLNeedMRegFillTable) { - when(MatrixRegBankId === i.U) + TotalLoadSize := TotalLoadSize + 1.U + for (i <- 0 until ABMatrixRegNBanks) { - io.ToMatrixRegIO.BankAddr(i).bits := MatrixRegAddr - io.ToMatrixRegIO.Data(i).bits := ResponseData(255, 0) - io.ToMatrixRegIO.BankAddr(i).valid := true.B - io.ToMatrixRegIO.Data(i).valid := true.B - io.ToMatrixRegIO.ByteMask(i).bits := Mux(searchEntry.MatrixRegisTail, tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B)) - io.ToMatrixRegIO.ByteMask(i).valid := true.B + when(MatrixRegBankId === i.U) + { + io.ToMatrixRegIO.BankAddr(i).bits := MatrixRegAddr + io.ToMatrixRegIO.Data(i).bits := ResponseData(255, 0) + io.ToMatrixRegIO.BankAddr(i).valid := true.B + io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Mux(searchEntry.MatrixRegisTail, tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B)) + io.ToMatrixRegIO.ByteMask(i).valid := true.B + } } } - } - MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData - MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr - MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U - MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := searchEntry.MatrixRegisTail + MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData + MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr + MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := searchEntry.MatrixRegisTail - Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index - Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), BMemoryLoaderReadFromMemoryFIFODepth) + Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index + Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), BMemoryLoaderReadFromMemoryFIFODepth) + } //需要一个fifo?TODO:需要fifo的设计是可能这里会堵,实际上我们满吞吐的doublebuff的设计,咱们这里是不会堵的,直接填就完事了?还是等总线上去握手? //MatrixReg->MemoryLoader->MMU->Memory Bus->Memory上的长组合逻辑链,可以实现一下,为后续的开发做准备 //否则就靠软件来保证数据流和访存流,保证访存流的稳定性,一定不会堵,就可以省下这个长组合逻辑的延迟? @@ -327,70 +476,152 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ } } - // Fill_Table的回填优先级最高,一旦有回填任务就立即执行 - val HasScarhpadWrite = Have_Bank_Fill - val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) - if (ABMLNeedMRegFillTable) - { - for (i <- 0 until ABMatrixRegNBanks){ - when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ - val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) - val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) - val fillLowHalf = fillSlot(0) === 0.U - val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) - val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) - Current_Fill_MReg_Time(i) := 1.U - val MatrixRegWriteRequest = io.ToMatrixRegIO - val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) - FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot - MatrixRegWriteRequest.BankAddr(i).valid := true.B - MatrixRegWriteRequest.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) - MatrixRegWriteRequest.Data(i).valid := true.B - MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) - MatrixRegWriteRequest.ByteMask(i).valid := true.B - if (YJPBMLDebugEnable) { - printf("[BML_MRegWriteHandshake<%d>] bankid: %d, CurrentFIFOIndex: %d, ScartchPadAddr: %x, BankAddr: %x, Data: %x, ByteMask: %x\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits, MatrixRegWriteRequest.ByteMask(i).bits) - } - - MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U - when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ - Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), BMemoryLoaderReadFromMemoryFIFODepth) + when(Is_Transpose) { + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPBMLDebugEnable) { + when(routerValid) { + printf("[BML_TransposeWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) } - - if (YJPBMLDebugEnable) - { - //输出fill_time 和 fifoindex - printf("[BML BMemoryLoader_Load<%d>]bankid: %d,CurrentFIFOIndex %d,ScartchPadAddr: %x, MReg_Fill_Table_Time(CurrentFIFOIndex): %d\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) - printf("[BML BMemoryLoader_Load<%d>]bankid: %d,ScartchPadAddr: %x, BankAddr: %x, Data: %x\n", io.DebugInfo.DebugTimeStampe,i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits) + } + } + when(transRouterWriteValid) { + printf("[BML_TRANS_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x totalLoad:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + TotalLoadSize, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + val Current_Load_Fill_Size = transRouterWriteValid.asUInt + val nextTotalLoadSize = TotalLoadSize + Current_Load_Fill_Size + val transposeDone = TotalRequestSize === MaxRequestIter && + group_req_cnt === 0.U && group_resp_cnt === 0.U && + transPipelineEmpty + TotalLoadSize := nextTotalLoadSize + if(YJPBMLDebugEnable){ + when(Current_Load_Fill_Size =/= 0.U) { + printf("[BML_TransposeLoad<%d>]TotalLoadSize:%d, FillSize:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size) + } + } + when(transposeDone){ + memoryload_state := s_load_quiesce + transposeEndDrainCnt := (transposeEndDrainCycles - 1).U + if (YJPBMLDebugEnable) printf("[BML<%d>]TransposeFullLoadEnd\n", io.DebugInfo.DebugTimeStampe) + } + }.otherwise { + // Fill_Table的回填优先级最高,一旦有回填任务就立即执行 + val HasScarhpadWrite = Have_Bank_Fill + val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) + if (ABMLNeedMRegFillTable) + { + for (i <- 0 until ABMatrixRegNBanks){ + when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ + val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val fillLowHalf = fillSlot(0) === 0.U + val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) + Current_Fill_MReg_Time(i) := 1.U + val MatrixRegWriteRequest = io.ToMatrixRegIO + val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) + FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) + MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot + MatrixRegWriteRequest.BankAddr(i).valid := true.B + MatrixRegWriteRequest.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) + MatrixRegWriteRequest.Data(i).valid := true.B + MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) + MatrixRegWriteRequest.ByteMask(i).valid := true.B + if (YJPBMLDebugEnable) { + printf("[BML_MRegWriteHandshake<%d>] bankid: %d, CurrentFIFOIndex: %d, ScartchPadAddr: %x, BankAddr: %x, Data: %x, ByteMask: %x\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits, MatrixRegWriteRequest.ByteMask(i).bits) + } + + MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U + when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ + Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), BMemoryLoaderReadFromMemoryFIFODepth) + } + + if (YJPBMLDebugEnable) + { + //输出fill_time 和 fifoindex + printf("[BML BMemoryLoader_Load<%d>]bankid: %d,CurrentFIFOIndex %d,ScartchPadAddr: %x, MReg_Fill_Table_Time(CurrentFIFOIndex): %d\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) + printf("[BML BMemoryLoader_Load<%d>]bankid: %d,ScartchPadAddr: %x, BankAddr: %x, Data: %x\n", io.DebugInfo.DebugTimeStampe,i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits) + } } } } - } - val Current_Load_Fill_Size = WireInit(0.U((log2Ceil(ABMatrixRegNBanks)+1).W)) - Current_Load_Fill_Size := PopCount(Current_Fill_MReg_Time.asUInt) + val Current_Load_Fill_Size = WireInit(0.U((log2Ceil(ABMatrixRegNBanks)+1).W)) + Current_Load_Fill_Size := PopCount(Current_Fill_MReg_Time.asUInt) - if (ABMLNeedMRegFillTable) - { - TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size - } - if (YJPBMLDebugEnable) - { - when(Current_Load_Fill_Size =/= 0.U) + if (ABMLNeedMRegFillTable) { - printf("[BMemoryLoader_Load<%d>]Current_Load_Fill_Size: %d, TotalLoadSize: %d, MaxLoadSize: %d\n",io.DebugInfo.DebugTimeStampe, Current_Load_Fill_Size, TotalLoadSize, MaxRequestIter * MAX_Fill_Times.U) + TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size } - } - //状态机切换 - when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ - memoryload_state := s_load_end if (YJPBMLDebugEnable) { - printf("[BMemoryLoader_Load<%d>]LoadEnd\n",io.DebugInfo.DebugTimeStampe) + when(Current_Load_Fill_Size =/= 0.U) + { + printf("[BMemoryLoader_Load<%d>]Current_Load_Fill_Size: %d, TotalLoadSize: %d, MaxLoadSize: %d\n",io.DebugInfo.DebugTimeStampe, Current_Load_Fill_Size, TotalLoadSize, MaxRequestIter * MAX_Fill_Times.U) + } + } + //状态机切换 + when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ + memoryload_state := s_load_end + if (YJPBMLDebugEnable) + { + printf("[BMemoryLoader_Load<%d>]LoadEnd\n",io.DebugInfo.DebugTimeStampe) + } } } } + is(s_load_quiesce) { + io.ToMatrixRegIO.active := true.B + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPBMLDebugEnable) { + when(routerValid) { + printf("[BML_TransposeQuiesceWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) + } + } + } + when(transRouterWriteValid) { + printf("[BML_TRANS_QUIESCE_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x drainCnt:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + transposeEndDrainCnt, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + when(transposeEndDrainCnt === 0.U) { + memoryload_state := s_load_end + if (YJPBMLDebugEnable) printf("[BML<%d>]TransposeQuiesceEnd\n", io.DebugInfo.DebugTimeStampe) + }.otherwise { + transposeEndDrainCnt := transposeEndDrainCnt - 1.U + } + } is(s_load_end) { io.ConfigInfo.MicroTaskEndValid := true.B when(io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady){ diff --git a/src/main/scala/Bundles.scala b/src/main/scala/Bundles.scala index 40f9af0..12d36e7 100644 --- a/src/main/scala/Bundles.scala +++ b/src/main/scala/Bundles.scala @@ -79,17 +79,34 @@ object Bundles { } class AmuMmaIO extends Bundle { + // rounding mode (xmfrm/xmxrm) val rm = UInt(3.W) // 52 : 50 + // dest matrix register index val md = UInt(4.W) // 49 : 46 + // whether saturate (xmsaten) val sat = Bool() // 45 + // src matrix register indices val ms1 = UInt(4.W) // 44 : 41 val ms2 = UInt(4.W) // 40 : 37 + + // the scale of mma operations, m/n/k val mtilem = Mtilex() // 36 : 28 val mtilen = Mtilex() // 27 : 19 val mtilek = Mtilex() // 18 : 10 - val types2 = UInt(3.W) // 9 : 8 + + // the type of source matrices + // - lower 2 bits stands for the element width: + // - 0: e8, 1: e16, 2: e32, 3: e4 + // - the highest bit determines the specific type: + // - 0 for unsigned and 1 for signed when isfp is false + // - 0 for e5m2 and 1 for e4m3 when the type is 8-bit fp + // - 0 for fp16 and 1 for bf16 when the type is 16-bit fp + // - 0 for fp32 and 1 for tf32 when the type is 32-bit fp + val types2 = UInt(3.W) // 9 : 7 val types1 = UInt(3.W) // 6 : 4 + // the same as types1/2, but for destination matrix val typed = UInt(3.W) // 3 : 1 + // whether floating point mma val isfp = Bool() // 0 } @@ -108,14 +125,23 @@ object Bundles { val transpose = Bool() // 120 // whether accumulation register val isacc = Bool() // 119 + // whether matrix A val isA = Bool() // 118 + // whether matrix B val isB = Bool() // 117 + // the address of the first element of the matrix val baseAddr = UInt(48.W) // 116 : 69 + // the stride of the matrix val stride = UInt(48.W) // 68 : 21 - + + // the number of rows of the matrix val row = Mtilex() // 20 : 12 + // the number of columns of the matrix val column = Mtilex() // 11 : 3 + // the width of elements in the matrix, see also MSew + // 0: e8, 1: e16, 2: e32, 3: e64, 7: e4 + // other values are reserved val widths = MtypeMSew() // 2 : 0 } @@ -136,7 +162,12 @@ object Bundles { } class AmuArithIO extends Bundle { + // Only support mzero1r currently + + // dest matrix register index val md = UInt(4.W) // 12 : 9 + // operation type + // see also package.scala val opType = UInt(9.W) // 8 : 0 } @@ -174,4 +205,41 @@ object Bundles { def releaseOp() : UInt = "b10".U def arithOp() : UInt = "b11".U } -} \ No newline at end of file +} + +object CutePerfEventCounts { + val Backend = 9 + val Mem = 12 +} + +class TaskControllerPerfProbe(implicit p: Parameters) extends CuteBundle { + val ownedWork = Bool() + val retire = Bool() + val loadADone = Bool() + val loadBDone = Bool() + val loadCDone = Bool() + val storeDone = Bool() + val compDone = Bool() + val releaseDone = Bool() + val mmaNonfpDone = Bool() + val mmaFp16Done = Bool() + val mmaBf16Done = Bool() + val mmaTf32Done = Bool() + val amlActive = Bool() + val bmlActive = Bool() + val cmlLoadActive = Bool() + val mteActive = Bool() + val cmlStoreActive = Bool() +} + +class LocalMMUPerfProbe extends Bundle { + val rdReq = Bool() + val wrReq = Bool() + val rd32BReq = UInt(6.W) + val wr32BReq = UInt(6.W) +} + +class CutePerfToCoreIO(implicit p: Parameters) extends CuteBundle { + val backendEvents = Vec(CutePerfEventCounts.Backend, UInt(6.W)) + val memEvents = Vec(CutePerfEventCounts.Mem, UInt(6.W)) +} diff --git a/src/main/scala/CDataController.scala b/src/main/scala/CDataController.scala index 189bcfc..756c251 100644 --- a/src/main/scala/CDataController.scala +++ b/src/main/scala/CDataController.scala @@ -36,8 +36,6 @@ class CDataController(implicit p: Parameters) extends CuteModule{ io.FromMatrixRegIO.WriteBankAddr.map(_.bits := DontCare) io.FromMatrixRegIO.WriteRequestData.map(_.valid := false.B) io.FromMatrixRegIO.WriteRequestData.map(_.bits := DontCare) - io.FromMatrixRegIO.WriteRequestByteMask.map(_.valid := false.B) - io.FromMatrixRegIO.WriteRequestByteMask.map(_.bits := Fill(CMatrixRegEntryByteSize, true.B)) io.ConfigInfo.MicroTaskEndValid := false.B io.ConfigInfo.MicroTaskReady := false.B io.ConfigInfo.MicroTask_TEComputeEndValid := false.B @@ -57,21 +55,32 @@ class CDataController(implicit p: Parameters) extends CuteModule{ when (io.ConfigInfo.MicroTaskValid) { pcReg := io.ConfigInfo.pc.get } - val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent, delay = 0, dontCare = true) + val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent(CMatrixRegNBanks, DiffAmuFinishWordsPerBank), delay = 0, dontCare = true) // 默认值初始化 difftestAmuFinish.coreid := io.ConfigInfo.coreid.get difftestAmuFinish.index := 3.U difftestAmuFinish.valid := (io.FromMatrixRegIO.WriteBankAddr.map(_.valid).reduce(_||_) || (io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady)) difftestAmuFinish.pc := pcReg + // DiffAmuFinishEvent packing is parameterized by words-per-bank; C payload may be narrower (e.g. 2Tops). + val eventWordsPerBank = difftestAmuFinish.data.length / CMatrixRegNBanks + val cMRegWordsPerBank = CMatrixRegEntryBitSize / 64 + require(difftestAmuFinish.data.length % CMatrixRegNBanks == 0, "DiffAmuFinishEvent.data should divide by C bank count") + require(CMatrixRegEntryBitSize % 64 == 0, s"CMatrixRegEntryBitSize must be 64-bit aligned, got $CMatrixRegEntryBitSize") + require(cMRegWordsPerBank <= eventWordsPerBank, s"DiffAmuFinishEvent only supports up to $eventWordsPerBank words per C bank, got $cMRegWordsPerBank") for (i <- 0 until CMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.FromMatrixRegIO.WriteBankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.FromMatrixRegIO.WriteBankAddr(i).bits - difftestAmuFinish.bankMask(i) := Fill(32, true.B) - difftestAmuFinish.data(i * 4 + 0) := io.FromMatrixRegIO.WriteRequestData(i).bits(63,0) - difftestAmuFinish.data(i * 4 + 1) := io.FromMatrixRegIO.WriteRequestData(i).bits(127,64) - difftestAmuFinish.data(i * 4 + 2) := io.FromMatrixRegIO.WriteRequestData(i).bits(191,128) - difftestAmuFinish.data(i * 4 + 3) := io.FromMatrixRegIO.WriteRequestData(i).bits(255,192) + difftestAmuFinish.bankMask(i) := Fill(CMatrixRegEntryByteSize, true.B) + for (w <- 0 until eventWordsPerBank) { + if (w < cMRegWordsPerBank) { + val lo = w * 64 + val hi = lo + 63 + difftestAmuFinish.data(i * eventWordsPerBank + w) := io.FromMatrixRegIO.WriteRequestData(i).bits(hi, lo) + } else { + difftestAmuFinish.data(i * eventWordsPerBank + w) := 0.U(64.W) + } + } } difftestAmuFinish.finish := io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady } diff --git a/src/main/scala/CMatrixReg.scala b/src/main/scala/CMatrixReg.scala index 891f33d..affe2f3 100644 --- a/src/main/scala/CMatrixReg.scala +++ b/src/main/scala/CMatrixReg.scala @@ -41,33 +41,13 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ //按照目前的设计,可以服务所有请求 io.MatrixRegIO.FromDataController.ReadWriteResponse := io.MatrixRegIO.FromDataController.ReadWriteRequest - io.MatrixRegIO.FromMemoryLoader.ReadWriteResponse := io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest - - when(io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest(MatrixRegTaskType.ReadFromMemoryLoaderIndex)) { - for (i <- 0 until CMatrixRegNBanks) { - when(io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.BankAddr(i).valid) { - if (YJPCMLDebugEnable) { - printf("[CMatrixReg_CMLReadReq(%d)] bank=%d addr=%x\n", scp_id.U, i.U, io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.BankAddr(i).bits) - } - } - } - } - - when(io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest(MatrixRegTaskType.WriteFromMemoryLoaderIndex)) { - for (i <- 0 until CMatrixRegNBanks) { - when(io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.BankAddr(i).valid && io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).valid) { - if (YJPCMLDebugEnable) { - printf("[CMatrixReg_CMLWriteReq(%d)] bank=%d addr=%x data=%x mask=%x\n", scp_id.U, i.U, - io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.BankAddr(i).bits, - io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).bits, - io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.ByteMask(i).bits) - } - } - } - } + io.MatrixRegIO.FromMemoryLoader.LoadReadWriteResponse := io.MatrixRegIO.FromMemoryLoader.LoadReadWriteRequest + io.MatrixRegIO.FromMemoryLoader.StoreReadWriteResponse := io.MatrixRegIO.FromMemoryLoader.StoreReadWriteRequest //记录当前拍回数应该返回给哪条数据线 - val request = io.MatrixRegIO.FromDataController.ReadWriteRequest | io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest + val request = io.MatrixRegIO.FromDataController.ReadWriteRequest | + io.MatrixRegIO.FromMemoryLoader.LoadReadWriteRequest | + io.MatrixRegIO.FromMemoryLoader.StoreReadWriteRequest val PreRequest = RegNext(request) val decode_request = new MatrixRegTaskDecode(request) @@ -76,11 +56,11 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ assert(!(decode_request.IsReadFromDataController && decode_request.IsReadFromMemoryLoader), "CMatrixReg: ReadFromDataController and ReadFromMemoryLoader should not be both true at the same time") assert(!(decode_request.IsWriteFromDataController && decode_request.IsWriteFromMemoryLoader), "CMatrixReg: WriteFromDataController and WriteFromMemoryLoader should not be both true at the same time") - val read_request_per_bank_addr = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U(CMatrixRegBankNEntrys.W)))) + val read_request_per_bank_addr = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U(CMatrixRegBankNEntries.W)))) val read_request_per_bank_valid = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(false.B))) val read_request_response_valid = RegInit(VecInit(Seq.fill(CMatrixRegNBanks)(false.B))) - val write_request_per_bank_addr = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U(CMatrixRegBankNEntrys.W)))) + val write_request_per_bank_addr = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U(CMatrixRegBankNEntries.W)))) val write_request_per_bank_data= WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U((8*CMatrixRegEntryByteSize).W)))) val write_request_per_bank_mask = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(Fill(CMatrixRegEntryByteSize, true.B)))) val write_request_per_bank_valid = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(false.B))) @@ -101,7 +81,7 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ val sram_banks = (0 until CMatrixRegNBanks) map { i => // 两个单口SRAM,奇偶地址各自负责,期望奇偶地址读写错开,奇读偶写,偶读奇写 - val bankDepthHalf = (CMatrixRegBankNEntrys + 1) / 2 + val bankDepthHalf = (CMatrixRegBankNEntries + 1) / 2 val evenBank = Module(new SRAMTemplate( gen = UInt(8.W), set = bankDepthHalf, @@ -147,12 +127,6 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ io.MatrixRegIO.FromDataController.ReadResponseData(i).valid := decode_pre_request.IsReadFromDataController && read_request_response_valid(i) io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(i).valid := decode_pre_request.IsReadFromMemoryLoader && read_request_response_valid(i) - when(io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(i).valid) { - if (YJPCMLDebugEnable) { - printf("[CMatrixReg_CMLReadResp(%d)] bank=%d addr=%x data=%x\n", scp_id.U, i.U, debug_s1_bank_addr, s1_bank_read_data) - } - } - //单口读路径:奇偶分流 // 偶地址SRAM读请求 evenBank.io.r.req.valid := s0_bank_read_valid && s0_read_is_even @@ -189,4 +163,3 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ (evenBank, oddBank) } } - diff --git a/src/main/scala/CMemoryLoader.scala b/src/main/scala/CMemoryLoader.scala index 7c30faa..c247004 100644 --- a/src/main/scala/CMemoryLoader.scala +++ b/src/main/scala/CMemoryLoader.scala @@ -18,7 +18,7 @@ import freechips.rocketchip.util.SeqToAugmentedSeq class CSourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId =UInt(log2Ceil(CMatrixRegNBanks).W) - val MatrixRegAddr = UInt(log2Ceil(CMatrixRegBankNEntrys).W) + val MatrixRegAddr = UInt(log2Ceil(CMatrixRegBankNEntries).W) val MatrixRegisTail = Bool() } @@ -26,85 +26,117 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val io = IO(new Bundle{ val ToMatrixRegIO = Flipped(new CMemoryLoaderMatrixRegIO) val ConfigInfo = Flipped(new CMLMicroTaskConfigIO) - val LocalMMUIO = Flipped(new LocalMMUIO) + val LoadLocalMMUIO = Flipped(new LocalMMUIO) + val StoreLocalMMUIO = Flipped(new LocalMMUIO) val DebugInfo = Input(new DebugInfoIO) - val MatrixRegId = Output(UInt(CMatrixRegIdWidth.W)) + val LoadMatrixRegId = Output(UInt(CMatrixRegIdWidth.W)) + val StoreMatrixRegId = Output(UInt(CMatrixRegIdWidth.W)) }) - // 对外统一使用 ToMatrixRegIO - - io.ConfigInfo.MicroTaskEndValid := false.B - io.ConfigInfo.MicroTaskReady := false.B + io.ConfigInfo.LoadMicroTaskEndValid := false.B + io.ConfigInfo.StoreMicroTaskEndValid := false.B + io.ConfigInfo.LoadMicroTaskReady := false.B + io.ConfigInfo.StoreMicroTaskReady := false.B io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr.map(_.valid := false.B) io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr.map(_.valid := false.B) io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.WriteRequestToMatrixReg.Data.map(_.valid := false.B) io.ToMatrixRegIO.WriteRequestToMatrixReg.Data.map(_.bits := DontCare) + io.ToMatrixRegIO.LoadReadWriteRequest := 0.U + io.ToMatrixRegIO.StoreReadWriteRequest := 0.U + io.LoadLocalMMUIO.Request.valid := false.B + io.LoadLocalMMUIO.Request.bits := DontCare + io.LoadLocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + io.LoadLocalMMUIO.Response.ready := false.B + io.StoreLocalMMUIO.Request.valid := false.B + io.StoreLocalMMUIO.Request.bits := DontCare + io.StoreLocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + io.StoreLocalMMUIO.Response.ready := false.B io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask.map(_.valid := false.B) io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask.map(_.bits := Fill(CMatrixRegEntryByteSize, true.B)) - io.LocalMMUIO.Request.valid := false.B - io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid - io.LocalMMUIO.Response.ready := false.B - - val ConfigInfo = io.ConfigInfo - val CurrentMatrixRegId = RegInit(0.U(CMatrixRegIdWidth.W)) - io.MatrixRegId := CurrentMatrixRegId + val CurrentLoadMatrixRegId = RegInit(0.U(CMatrixRegIdWidth.W)) + val CurrentStoreMatrixRegId = RegInit(0.U(CMatrixRegIdWidth.W)) + io.LoadMatrixRegId := CurrentLoadMatrixRegId + io.StoreMatrixRegId := CurrentStoreMatrixRegId + + val LoadPcReg = if (EnableDifftest) Some(RegInit(0.U(64.W))) else None + val StorePcReg = if (EnableDifftest) Some(RegInit(0.U(64.W))) else None // Difftest interface if (EnableDifftest) { - val pcReg = RegInit(0.U(64.W)) - when (io.ConfigInfo.MicroTaskValid) { - pcReg := io.ConfigInfo.pc.get - } - val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent, delay = 0, dontCare = true) - // 默认值初始化 - difftestAmuFinish.coreid := io.ConfigInfo.coreid.get - difftestAmuFinish.index := 2.U - difftestAmuFinish.valid := (io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr.map(_.valid).reduce(_||_) - || (io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady)) - difftestAmuFinish.pc := pcReg + DifftestModule.addCppMacro("CONFIG_DIFF_AMU_C_WORDS_PER_BANK", CMatrixRegEntryBitSize / 64) + DifftestModule.addCppMacro("CONFIG_DIFF_AMU_C_REG_SIZE_BYTES", CMatrixRegSize) + + val loadWriteAny = io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr.map(_.valid).reduce(_||_) + val loadFinishAny = io.ConfigInfo.LoadMicroTaskEndValid && io.ConfigInfo.LoadMicroTaskEndReady + val storeFinishAny = io.ConfigInfo.StoreMicroTaskEndValid && io.ConfigInfo.StoreMicroTaskEndReady + + val difftestLoadFinish = DifftestModule(new DiffAmuFinishEvent(CMatrixRegNBanks, DiffAmuFinishWordsPerBank), delay = 0, dontCare = true) + difftestLoadFinish.coreid := io.ConfigInfo.coreid.get + difftestLoadFinish.index := 2.U + difftestLoadFinish.valid := loadWriteAny || loadFinishAny + difftestLoadFinish.pc := LoadPcReg.get + + val eventWordsPerBank = difftestLoadFinish.data.length / CMatrixRegNBanks + val cMRegWordsPerBank = CMatrixRegEntryBitSize / 64 + require(difftestLoadFinish.data.length % CMatrixRegNBanks == 0, "DiffAmuFinishEvent.data should divide by C bank count") + require(CMatrixRegEntryBitSize % 64 == 0, s"CMatrixRegEntryBitSize must be 64-bit aligned, got $CMatrixRegEntryBitSize") + require(cMRegWordsPerBank <= eventWordsPerBank, s"DiffAmuFinishEvent only supports up to $eventWordsPerBank words per C bank, got $cMRegWordsPerBank") for (i <- 0 until CMatrixRegNBanks) { - difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).valid - difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).bits - difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).bits - difftestAmuFinish.data(i * 4 + 0) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(63,0) - difftestAmuFinish.data(i * 4 + 1) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(127,64) - difftestAmuFinish.data(i * 4 + 2) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(191,128) - difftestAmuFinish.data(i * 4 + 3) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(255,192) + difftestLoadFinish.bankValid(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).valid + difftestLoadFinish.bankAddr(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).bits + difftestLoadFinish.bankMask(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).bits + for (w <- 0 until eventWordsPerBank) { + if (w < cMRegWordsPerBank) { + val lo = w * 64 + val hi = lo + 63 + difftestLoadFinish.data(i * eventWordsPerBank + w) := io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits(hi, lo) + } else { + difftestLoadFinish.data(i * eventWordsPerBank + w) := 0.U(64.W) + } + } } - difftestAmuFinish.finish := io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady + difftestLoadFinish.finish := loadFinishAny + + // Store path has no per-bank writeback payload, only finish handshake. + val difftestStoreFinish = DifftestModule(new DiffAmuFinishEvent(CMatrixRegNBanks, DiffAmuFinishWordsPerBank), delay = 0, dontCare = true) + difftestStoreFinish.coreid := io.ConfigInfo.coreid.get + difftestStoreFinish.index := 5.U + difftestStoreFinish.valid := storeFinishAny + difftestStoreFinish.pc := StorePcReg.get + difftestStoreFinish.bankValid.foreach(_ := false.B) + difftestStoreFinish.bankAddr.foreach(_ := 0.U) + difftestStoreFinish.bankMask.foreach(_ := 0.U) + difftestStoreFinish.data.foreach(_ := 0.U) + difftestStoreFinish.finish := storeFinishAny } - val MatrixRegTensor_M = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val MatrixRegTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - - val Tensor_C_BaseVaddr = RegInit(0.U(MMUAddrWidth.W)) - val Tensor_D_BaseVaddr = RegInit(0.U(MMUAddrWidth.W)) - - - //任务状态机 先来个简单的,顺序读取所有分块矩阵 - val s_idle :: s_mm_task :: Nil = Enum(2) - val state = RegInit(s_idle) + val LoadMatrixRegTensor_M = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val LoadMatrixRegTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val StoreMatrixRegTensor_M = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val StoreMatrixRegTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) //访存读状态机,用来配合流水线刷新 val s_load_idle :: s_load_init :: s_load_working :: s_load_end :: Nil = Enum(4) val memoryload_state = RegInit(s_load_idle) - val MemoryOrder_LoadConfig = RegInit(MemoryOrderType.OrderTypeUndef) //访存写状态机,用来配合流水线刷新 val s_store_idle :: s_store_init :: s_store_working :: s_store_end :: Nil = Enum(4) val memorystore_state = RegInit(s_store_idle) - val Tensor_Block_BaseAddr = Reg(UInt(MMUAddrWidth.W)) //分块矩阵的基地址 + val LoadTensorBlockBaseAddr = Reg(UInt(MMUAddrWidth.W)) + val StoreTensorBlockBaseAddr = Reg(UInt(MMUAddrWidth.W)) - val IsConherent = RegInit(true.B) //是否一致性访存的标志位,由TaskController提供 - val Is_Transpose = RegInit(false.B) //是否转置的标志位,由TaskController提供 + val IsLoadConherent = RegInit(true.B) + val IsStoreConherent = RegInit(true.B) + val IsStoreTranspose = RegInit(false.B) val HasScarhpadRead = WireInit(false.B) val HasScarhpadWrite = WireInit(false.B) - io.ToMatrixRegIO.ReadWriteRequest := Cat(HasScarhpadRead,Cat(HasScarhpadWrite,Cat(0.U(1.W),0.U(1.W)))) + io.ToMatrixRegIO.LoadReadWriteRequest := Cat(0.U(1.W), Cat(HasScarhpadWrite, Cat(0.U(1.W), 0.U(1.W)))) + io.ToMatrixRegIO.StoreReadWriteRequest := Cat(HasScarhpadRead, Cat(0.U(1.W), Cat(0.U(1.W), 0.U(1.W)))) val ApplicationTensor_C_Stride_M = RegInit(0.U(MMUAddrWidth.W)) val ApplicationTensor_D_Stride_M = RegInit(0.U(MMUAddrWidth.W)) @@ -116,74 +148,53 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val Is_FullLoad = RegInit(false.B) val Is_RepeatRowLoad = RegInit(false.B) - val C_DataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) + val C_DataWidth = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) val D_DataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) + val loadTaskIdle = memoryload_state === s_load_idle + val storeTaskIdle = memorystore_state === s_store_idle + io.ConfigInfo.LoadMicroTaskReady := loadTaskIdle + io.ConfigInfo.StoreMicroTaskReady := storeTaskIdle + + when(io.ConfigInfo.LoadMicroTaskValid && io.ConfigInfo.LoadMicroTaskReady){ + CurrentLoadMatrixRegId := io.ConfigInfo.MatrixRegId + LoadTensorBlockBaseAddr := io.ConfigInfo.ApplicationTensor_C.BlockTensor_C_BaseVaddr + ApplicationTensor_C_Stride_M := io.ConfigInfo.ApplicationTensor_C.ApplicationTensor_C_Stride_M + IsLoadConherent := io.ConfigInfo.Conherent + LoadMatrixRegTensor_M := io.ConfigInfo.MatrixRegTensor_M + LoadMatrixRegTensor_N := io.ConfigInfo.MatrixRegTensor_N + HasTail := io.ConfigInfo.ApplicationTensor_C.HasTail + TailByteMask := io.ConfigInfo.ApplicationTensor_C.TailByteMask + N_Beat_Count := io.ConfigInfo.ApplicationTensor_C.N_Beat_Count + + Is_ZeroLoad := io.ConfigInfo.LoadTaskInfo.Is_ZeroLoad + Is_FullLoad := io.ConfigInfo.LoadTaskInfo.Is_FullLoad + Is_RepeatRowLoad := io.ConfigInfo.LoadTaskInfo.Is_RepeatRowLoad + val peDataType = new FReducePEDataType + C_DataWidth := peDataType.CdataByteWidth(io.ConfigInfo.ApplicationTensor_C.dataType) + memoryload_state := s_load_init + if (EnableDifftest) { + LoadPcReg.get := io.ConfigInfo.pc.get + } + } - when(state === s_idle) - { - io.ConfigInfo.MicroTaskReady := true.B - //如果configinfo有效 - when(io.ConfigInfo.MicroTaskReady && io.ConfigInfo.MicroTaskValid){ - state := s_mm_task - CurrentMatrixRegId := io.ConfigInfo.MatrixRegId - if (YJPCMLDebugEnable) { - printf("[CMemoryLoader_TaskHandshake<%d>] valid=%d ready=%d matrixRegId=%d isLoad=%d isStore=%d coher=%d transpose=%d M=%d N=%d baseC=%x baseD=%x\n", - io.DebugInfo.DebugTimeStampe, - io.ConfigInfo.MicroTaskValid, - io.ConfigInfo.MicroTaskReady, - io.ConfigInfo.MatrixRegId, - io.ConfigInfo.IsLoadMicroTask, - io.ConfigInfo.IsStoreMicroTask, - io.ConfigInfo.Conherent, - io.ConfigInfo.Is_Transpose, - io.ConfigInfo.MatrixRegTensor_M, - io.ConfigInfo.MatrixRegTensor_N, - io.ConfigInfo.ApplicationTensor_C.BlockTensor_C_BaseVaddr, - io.ConfigInfo.ApplicationTensor_D.BlockTensor_D_BaseVaddr) - } - assert( - !(io.ConfigInfo.IsLoadMicroTask === true.B && io.ConfigInfo.IsStoreMicroTask === true.B), - "CMemoryLoader: Load and Store MicroTask cannot be enabled at the same time" - ) - when(io.ConfigInfo.IsLoadMicroTask === true.B){ - memoryload_state := s_load_init - Tensor_Block_BaseAddr := io.ConfigInfo.ApplicationTensor_C.BlockTensor_C_BaseVaddr - ApplicationTensor_C_Stride_M := io.ConfigInfo.ApplicationTensor_C.ApplicationTensor_C_Stride_M - IsConherent := io.ConfigInfo.Conherent - HasTail := io.ConfigInfo.ApplicationTensor_C.HasTail - TailByteMask := io.ConfigInfo.ApplicationTensor_C.TailByteMask - N_Beat_Count := io.ConfigInfo.ApplicationTensor_C.N_Beat_Count - - Is_ZeroLoad := io.ConfigInfo.LoadTaskInfo.Is_ZeroLoad - Is_FullLoad := io.ConfigInfo.LoadTaskInfo.Is_FullLoad - Is_RepeatRowLoad := io.ConfigInfo.LoadTaskInfo.Is_RepeatRowLoad - - C_DataType := io.ConfigInfo.ApplicationTensor_C.dataType - if(YJPCMLDebugEnable) - { - printf("[CMemoryLoader_Load<%d>]Load C Tensor Start, Tensor_Block_BaseAddr: %x, ApplicationTensor_C_Stride_M: %x, IsConherent: %x,MatrixRegTensor_M: %x,MatrixRegTensor_N: %x,C_DataType(zero,full,repeatrow) :(%d,%d,%d)\n", io.DebugInfo.DebugTimeStampe, io.ConfigInfo.ApplicationTensor_C.BlockTensor_C_BaseVaddr, io.ConfigInfo.ApplicationTensor_C.ApplicationTensor_C_Stride_M, io.ConfigInfo.Conherent,io.ConfigInfo.MatrixRegTensor_M,io.ConfigInfo.MatrixRegTensor_N,io.ConfigInfo.LoadTaskInfo.Is_ZeroLoad.asUInt,io.ConfigInfo.LoadTaskInfo.Is_FullLoad.asUInt,io.ConfigInfo.LoadTaskInfo.Is_RepeatRowLoad.asUInt) - } - - } - when(io.ConfigInfo.IsStoreMicroTask === true.B){ - memorystore_state := s_store_init - Tensor_Block_BaseAddr := io.ConfigInfo.ApplicationTensor_D.BlockTensor_D_BaseVaddr - IsConherent := io.ConfigInfo.Conherent - ApplicationTensor_D_Stride_M := io.ConfigInfo.ApplicationTensor_D.ApplicationTensor_D_Stride_M - Is_Transpose := io.ConfigInfo.Is_Transpose - - D_DataType := io.ConfigInfo.ApplicationTensor_D.dataType - if(YJPCMLDebugEnable) - { - printf("[CMemoryLoader_Start<%d>]Store D Tensor Start, Tensor_Block_BaseAddr: %x, ApplicationTensor_D_Stride_M: %x, IsConherent: %x, Is_Transpose: %x,MatrixRegTensor_M: %x,MatrixRegTensor_N: %x\n", io.DebugInfo.DebugTimeStampe, io.ConfigInfo.ApplicationTensor_D.BlockTensor_D_BaseVaddr, io.ConfigInfo.ApplicationTensor_D.ApplicationTensor_D_Stride_M, io.ConfigInfo.Conherent, io.ConfigInfo.Is_Transpose,io.ConfigInfo.MatrixRegTensor_M,io.ConfigInfo.MatrixRegTensor_N) - } - - } - MatrixRegTensor_M := io.ConfigInfo.MatrixRegTensor_M - MatrixRegTensor_N := io.ConfigInfo.MatrixRegTensor_N + when(io.ConfigInfo.StoreMicroTaskValid && io.ConfigInfo.StoreMicroTaskReady){ + CurrentStoreMatrixRegId := io.ConfigInfo.MatrixRegId + StoreTensorBlockBaseAddr := io.ConfigInfo.ApplicationTensor_D.BlockTensor_D_BaseVaddr + IsStoreConherent := io.ConfigInfo.Conherent + ApplicationTensor_D_Stride_M := io.ConfigInfo.ApplicationTensor_D.ApplicationTensor_D_Stride_M + IsStoreTranspose := io.ConfigInfo.Is_Transpose + StoreMatrixRegTensor_M := io.ConfigInfo.MatrixRegTensor_M + StoreMatrixRegTensor_N := io.ConfigInfo.MatrixRegTensor_N + D_DataType := io.ConfigInfo.ApplicationTensor_D.dataType + memorystore_state := s_store_init + if (EnableDifftest) { + StorePcReg.get := io.ConfigInfo.pc.get } } + assert(!(io.ConfigInfo.LoadMicroTaskValid && io.ConfigInfo.StoreMicroTaskValid), + "CMemoryLoader: split channels share one config payload, load/store valid cannot be high together") + //三个张量的虚拟地址,肯定得是连续的,这个可以交给操作系统和编译器来保证 @@ -225,7 +236,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ //2.给每个准备回填数据的bank,找到其对应的Fill_FIFO的index,在这个fill_fifo[index]的filltime+1,如果filltime==MAX_Fill_Times,那么这个数据就用完了 //3.更新FIFO,更新Tail,更新Table val MReg_Fill_Table = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) - val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(CMatrixRegBankNEntrys).W)))))//记录这个LLC回的数是在scp的哪个地址 + val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(CMatrixRegBankNEntries).W)))))//记录这个LLC回的数是在scp的哪个地址 val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/CMatrixRegEntryByteSize)+1).W)))))//记录这个LLC回的数需要回填的次数,完成就可以将数据释放了 val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U)//记录这个FIFO能否能填数据 @@ -260,7 +271,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val Request_M_Iter_Time = RegInit(0.U(log2Ceil(Matrix_MN).W)) // val Fill_N_Iter_Time = RegInit(0.U(log2Ceil(Tensor_MN).W)) //读数请求 - val ReadRequest = io.LocalMMUIO.Request + val ReadRequest = io.LoadLocalMMUIO.Request switch(memoryload_state) { is(s_load_init) { memoryload_state := s_load_working @@ -268,7 +279,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ TotalRequestSize := 0.U CurrentLoaded_BlockTensor_M_Iter := 0.U CurrentLoaded_BlockTensor_N_Iter := 0.U - MaxRequestIter := MatrixRegTensor_M * N_Beat_Count //总共要发出的访存请求的次数 + MaxRequestIter := LoadMatrixRegTensor_M * N_Beat_Count //总共要发出的访存请求的次数 Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) @@ -329,19 +340,20 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_N_Iter === (N_Beat_Count - 1.U)) val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) % CMatrixRegNBanks.U //访存请求落在哪个MatrixRegBank上 - val RequestMatrixRegAddr = ((CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time ) / CMatrixRegNBanks.U ) * (Tensor_MN.U / Matrix_MN.U) + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(MAX_Fill_Times)) //该访存请求的第零号数据,落在哪个MatrixRegBank的哪个地址上 + val RequestMatrixRegAddr = ((CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) / CMatrixRegNBanks.U) * (Tensor_MN.U / Matrix_MN.U) + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(MAX_Fill_Times)) //该访存请求的第零号数据,落在哪个MatrixRegBank的哪个地址上 - ReadRequest.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_C_Stride_M + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(outsideDataWidthByte)) + ReadRequest.bits.RequestVirtualAddr := LoadTensorBlockBaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_C_Stride_M + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(outsideDataWidthByte)) // val CurrentBankID = RequestMatrixRegBankId // val CurrentFIFOIndex = FromMemoryLoaderReadFIFOHead - val sourceId = Mux(IsConherent,io.LocalMMUIO.ConherentRequsetSourceID,io.LocalMMUIO.nonConherentRequsetSourceID) + val sourceId = Mux(IsLoadConherent,io.LoadLocalMMUIO.ConherentRequsetSourceID,io.LoadLocalMMUIO.nonConherentRequsetSourceID) - ReadRequest.bits.RequestConherent := IsConherent + ReadRequest.bits.RequestConherent := IsLoadConherent ReadRequest.bits.RequestSourceID := sourceId.bits ReadRequest.bits.RequestType_isWrite := false.B + ReadRequest.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) ReadRequest.valid := (TotalRequestSize < MaxRequestIter) //确定这个访存请求一定会发出 @@ -353,7 +365,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt Request_M_Iter_Time := Request_M_Iter_Time + 1.U//连续的跨bank去访存 - when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ + when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === LoadMatrixRegTensor_M - 1.U){ Request_M_Iter_Time := 0.U CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + 1.U when(CurrentLoaded_BlockTensor_N_Iter + 1.U === N_Beat_Count){ @@ -370,8 +382,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ //TODO:这里数据读取量定死了,需要为了支持边界情况,改一改 //不过我们保证了数据是256bit对齐的~剩下的就是Tensor_M和Tensor_K不满足的情况思考好就行了 //输出request的次数 - if (YJPCMLDebugEnable) { - printf("[CMemoryLoader_LoadReq<%d>] RequestMatrixRegAddr: %x,RequestMatrixRegBankId: %x,CurrentLoaded_BlockTensor_N_Iter: %x,CurrentLoaded_BlockTensor_M_Iter: %x,Request_M_Iter_Time: %x,RequestVirtualAddr: %x, RequestSourceID: %x, RequestConherent: %x, RequestType_isWrite: %x, RequestTimes: %d\n", io.DebugInfo.DebugTimeStampe, RequestMatrixRegAddr,RequestMatrixRegBankId,CurrentLoaded_BlockTensor_N_Iter,CurrentLoaded_BlockTensor_M_Iter,Request_M_Iter_Time,ReadRequest.bits.RequestVirtualAddr, ReadRequest.bits.RequestSourceID, ReadRequest.bits.RequestConherent, ReadRequest.bits.RequestType_isWrite, TotalRequestSize) + if (YJPCMLDebugEnable) + { + printf("[CMemoryLoader_Load<%d>]RequestMatrixRegAddr: %x,RequestMatrixRegBankId: %x,CurrentLoaded_BlockTensor_N_Iter: %x,CurrentLoaded_BlockTensor_M_Iter: %x,Request_M_Iter_Time: %x,RequestVirtualAddr: %x, RequestSourceID: %x, RequestConherent: %x, RequestType_isWrite: %x, RequestTimes: %d\n", io.DebugInfo.DebugTimeStampe, RequestMatrixRegAddr,RequestMatrixRegBankId,CurrentLoaded_BlockTensor_N_Iter,CurrentLoaded_BlockTensor_M_Iter,Request_M_Iter_Time,ReadRequest.bits.RequestVirtualAddr, ReadRequest.bits.RequestSourceID, ReadRequest.bits.RequestConherent, ReadRequest.bits.RequestType_isWrite, TotalRequestSize) } when(TotalRequestSize === MaxRequestIter){ //assert! @@ -382,22 +395,22 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } val current_fill_fifo_full = WireInit(false.B) - when(io.LocalMMUIO.Response.valid) + when(io.LoadLocalMMUIO.Response.valid) { - val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID + val sourceId = io.LoadLocalMMUIO.Response.bits.ReseponseSourceID val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new CSourceIdSearch).MatrixRegBankId current_fill_fifo_full := Bank_Fill_Search_FIFO_Full(MatrixRegBankId) } - io.LocalMMUIO.Response.ready := MReg_Fill_Table_Not_Full && (current_fill_fifo_full === false.B) + io.LoadLocalMMUIO.Response.ready := MReg_Fill_Table_Not_Full && (current_fill_fifo_full === false.B) //接受访存的返回值 //一个cam来存储访存请求的source_id对应的MatrixReg的地址和bank号 //根据response的sourceid,找到对应的MatrixReg的地址和bank号,回填数据 - when(io.LocalMMUIO.Response.fire){ - val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID + when(io.LoadLocalMMUIO.Response.fire){ + val sourceId = io.LoadLocalMMUIO.Response.bits.ReseponseSourceID val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new CSourceIdSearch).MatrixRegBankId val MatrixRegAddr = SoureceIdSearchTable(sourceId).asTypeOf(new CSourceIdSearch).MatrixRegAddr - val ResponseData = io.LocalMMUIO.Response.bits.ReseponseData + val ResponseData = io.LoadLocalMMUIO.Response.bits.ReseponseData val FIFOIndex = Bank_Fill_Search_FIFO_Head(MatrixRegBankId)//该bank的fill_fifo_index,标注了它当前在fillfifo的哪个位置,我们一共有bank个fill_fifo MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData @@ -409,8 +422,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), CMemoryLoaderReadFromMemoryFIFODepth) //输出回填的数据 - if (YJPCMLDebugEnable) { - printf("[CMemoryLoader_LoadResp<%d>]ResponseData: %x, MatrixRegBankId: %x, MatrixRegAddr: %x, FIFOIndex: %x, sourceId: %x\n",io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, FIFOIndex, sourceId) + if (YJPCMLDebugEnable) + { + printf("[CMemoryLoader_Load<%d>]ResponseData: %x, MatrixRegBankId: %x, MatrixRegAddr: %x, FIFOIndex: %x\n",io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, FIFOIndex) } } @@ -420,11 +434,11 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until CMatrixRegNBanks){ when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) - when(io.ToMatrixRegIO.ReadWriteResponse(MatrixRegTaskType.WriteFromMemoryLoaderIndex) === true.B) + when(io.ToMatrixRegIO.LoadReadWriteResponse(MatrixRegTaskType.WriteFromMemoryLoaderIndex) === true.B) { Current_Fill_MReg_Time(i) := 1.U val MatrixRegWriteRequest = io.ToMatrixRegIO.WriteRequestToMatrixReg - val FIFOData = WireInit(VecInit(Seq.fill(MAX_Fill_Times)(0.U((8 * CMatrixRegEntryByteSize).W)))) + val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*CMatrixRegEntryByteSize).W))))) val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) val fullByteMask = Fill(CMatrixRegEntryByteSize, true.B) @@ -441,9 +455,6 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ MatrixRegWriteRequest.Data(i).valid := true.B MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail, tailByteMaskVec(fillSlot), fullByteMask) MatrixRegWriteRequest.ByteMask(i).valid := true.B - if (YJPCMLDebugEnable) { - printf("[CMemoryLoader_MRegWriteHandshake<%d>] bankid: %d, CurrentFIFOIndex: %d, ScartchPadAddr: %x, BankAddr: %x, Data: %x, ByteMask: %x\n", io.DebugInfo.DebugTimeStampe, i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits, MatrixRegWriteRequest.ByteMask(i).bits) - } MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ @@ -492,19 +503,17 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ { //给所有的bank发出写0的请求 HasScarhpadWrite := true.B - //每次写所有bank的一个entry,总共要写CMatrixRegBankNEntrys次 - val Max_ZeroLoad_Write_Times = CMatrixRegBankNEntrys + //每次写所有bank的一个entry,总共要写CMatrixRegBankNEntries次 + val Max_ZeroLoad_Write_Times = CMatrixRegBankNEntries for (i <- 0 until CMatrixRegNBanks) { io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).bits := TotalLoadSize io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).valid := true.B io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).bits := 0.U io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(i).valid := true.B - io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).bits := Fill(CMatrixRegEntryByteSize, true.B) - io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).valid := true.B } - when(io.ToMatrixRegIO.ReadWriteResponse(MatrixRegTaskType.WriteFromMemoryLoaderIndex) === true.B) + when(io.ToMatrixRegIO.LoadReadWriteResponse(MatrixRegTaskType.WriteFromMemoryLoaderIndex) === true.B) { TotalLoadSize := TotalLoadSize + 1.U if (YJPCMLDebugEnable) @@ -526,11 +535,11 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ // { // //由于RepeatRowLoad的特殊性,我们一次Load需要写MReg很多次,导致我们的FIFO在被写满时,会导致长时间的TL无法握手。 // //故,我们针对这样的情况,我们需要为每一个发出的访存请求预留一个FIFO的空位,这样就可以保证TL握手成功,从而不浪费访存带宽,这样可能会导致整体延迟增加(但不会低到阻碍吞吐),但我们的访存带宽利用率一定不会低 - // //获取整个Row的数据,然后重复填充,Row的总数据量为Tensor_N*C_DataType + // //获取整个Row的数据,然后重复填充,Row的总数据量为Tensor_N*C_DataWidth // val sourceId = Mux(IsConherent,io.LocalMMUIO.ConherentRequsetSourceID,io.LocalMMUIO.nonConherentRequsetSourceID) - // val Max_RepeatRowLoad_Memory_Load_Times = Tensor_MN.U * C_DataType / outsideDataWidthByte.U //总共要发出的访存请求的次数 + // val Max_RepeatRowLoad_Memory_Load_Times = Tensor_MN.U * C_DataWidth / outsideDataWidthByte.U //总共要发出的访存请求的次数 // val Max_MReg_Write_Times = Tensor_MN*Tensor_MN*ResultWidthByte/CMatrixReg_Total_Bandwidth //总共要写入MReg的次数 - // ReadRequest.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + CurrentLoaded_BlockTensor_N_Iter * C_DataType + // ReadRequest.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + CurrentLoaded_BlockTensor_N_Iter * C_DataWidth // ReadRequest.bits.RequestConherent := IsConherent // ReadRequest.bits.RequestSourceID := sourceId.bits // ReadRequest.bits.RequestType_isWrite := false.B @@ -542,7 +551,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ // val Per_Memory_Load_Have_Data_Write_Group = (outsideDataWidthByte/CMatrixRegEntryByteSize)//每次Memory的load,有几组数据要写回 // val Per_Write_MReg_Addr_Add = (Tensor_MN / Matrix_MN).U //一组数据Per_Data_Repeat_Times迭代中,下一次写入的scp地址的增量 - // // val Load_Time = CurrentLoaded_BlockTensor_N_Iter / (outsideDataWidthByte.U/C_DataType) + // // val Load_Time = CurrentLoaded_BlockTensor_N_Iter / (outsideDataWidthByte.U/C_DataWidth) // //向量的访存顺序 // //01,23,45,67..... @@ -571,7 +580,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ // TableItem.MatrixRegAddr := TotalRequestSize * Per_Memory_Load_Have_Data_Write_Group.U//这个数据的第一个数据,落在哪个MatrixRegBank的哪个地址上 // SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt - // CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + outsideDataWidthByte.U / C_DataType + // CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + outsideDataWidthByte.U / C_DataWidth // Repeat_Fill_Request_Infight := Repeat_Fill_Request_Infight + 1.U // if (YJPCMLDebugEnable) // { @@ -628,7 +637,6 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ // HasScarhpadWrite := true.B - // when(io.ToMatrixRegIO.ReadWriteResponse(MatrixRegTaskType.WriteFromMemoryLoaderIndex) === true.B) // { // Repeat_Fill_Times := Repeat_Fill_Times + 1.U // TotalLoadSize := TotalLoadSize + 1.U @@ -666,10 +674,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ // } } is(s_load_end) { - io.ConfigInfo.MicroTaskEndValid := true.B - when(io.ConfigInfo.MicroTaskEndReady && io.ConfigInfo.MicroTaskEndValid){ + io.ConfigInfo.LoadMicroTaskEndValid := true.B + when(io.ConfigInfo.LoadMicroTaskEndReady && io.ConfigInfo.LoadMicroTaskEndValid){ memoryload_state := s_load_idle - state := s_idle if (YJPCMLDebugEnable) { printf("[CMemoryLoader_Load<%d>]Load Finish\n",io.DebugInfo.DebugTimeStampe) @@ -716,17 +723,17 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val FireTimes = RegInit(0.U(log2Ceil(CMatrixRegNBanks).W)) when(Write_Mem_Wait_Table.reduce(_||_)){ - io.LocalMMUIO.Response.ready := true.B + io.StoreLocalMMUIO.Response.ready := true.B } - when(io.LocalMMUIO.Response.fire){ - val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID + when(io.StoreLocalMMUIO.Response.fire){ + val sourceId = io.StoreLocalMMUIO.Response.bits.ReseponseSourceID Write_Mem_Wait_Table(sourceId) := false.B } - val M_Get_IteratorMax = Mux(Is_Transpose, (MatrixRegTensor_M / (Matrix_MN.U * 2.U) + (MatrixRegTensor_M % (Matrix_MN.U * 2.U) =/= 0.U)) * 2.U, (MatrixRegTensor_M / Matrix_MN.U) + ((MatrixRegTensor_M % Matrix_MN.U) =/= 0.U)) - val N_Get_IteratorMax = WireInit(0.U(log2Ceil(CMatrixRegBankNEntrys).W)) - N_Get_IteratorMax := (MatrixRegTensor_N / Matrix_MN.U) - val transpose_scp_addr = WireInit(0.U(log2Ceil(CMatrixRegBankNEntrys).W)) + val M_Get_IteratorMax = Mux(IsStoreTranspose, (StoreMatrixRegTensor_M / (Matrix_MN.U * 2.U) + (StoreMatrixRegTensor_M % (Matrix_MN.U * 2.U) =/= 0.U)) * 2.U, (StoreMatrixRegTensor_M / Matrix_MN.U) + ((StoreMatrixRegTensor_M % Matrix_MN.U) =/= 0.U)) + val N_Get_IteratorMax = WireInit(0.U(log2Ceil(CMatrixRegBankNEntries).W)) + N_Get_IteratorMax := (StoreMatrixRegTensor_N / Matrix_MN.U) + val transpose_scp_addr = WireInit(0.U(log2Ceil(CMatrixRegBankNEntries).W)) // val Max_Caculate_Iter = M_Get_IteratorMax * N_Get_IteratorMax @@ -764,12 +771,12 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ FromMatrixRegReadFIFO := 0.U.asTypeOf(FromMatrixRegReadFIFO) FromMatrixRegReadFIFOHead := 0.U FromMatrixRegReadFIFOTail := 0.U - Max_Load_MReg_Time := MatrixRegTensor_M * MatrixRegTensor_N * D_DataType / CMatrixReg_Total_Bandwidth.U//总共要发对SCAP的访存次数 - Max_Store_Memory_Time := Mux(Is_Transpose, M_Get_IteratorMax * Matrix_MN.U, MatrixRegTensor_M) * MatrixRegTensor_N * D_DataType / outsideDataWidthByte.U//总共要发对LLC的访存次数 - // MaxIncStoreScpRequestSize := Mux(Is_Transpose, MatrixRegTensor_N, M_Get_IteratorMax * Matrix_MN.U) * Mux(Is_Transpose, MatrixRegTensor_M, MatrixRegTensor_N) * D_DataType / CMatrixReg_Total_Bandwidth.U - MaxIncStoreScpRequestSize := M_Get_IteratorMax * Matrix_MN.U * MatrixRegTensor_N * D_DataType / CMatrixReg_Total_Bandwidth.U - MaxIncStoreRequestSize := (Mux(Is_Transpose, MatrixRegTensor_N, MatrixRegTensor_M) / Matrix_MN.U * Matrix_MN.U) * Mux(Is_Transpose, MatrixRegTensor_M, MatrixRegTensor_N) * D_DataType / CMatrixReg_Total_Bandwidth.U - Max_Load_Scp_Tail_SubMajor_Iter := Mux(Is_Transpose, MatrixRegTensor_N, MatrixRegTensor_M) % Matrix_MN.U + Max_Load_MReg_Time := StoreMatrixRegTensor_M * StoreMatrixRegTensor_N * D_DataType / CMatrixReg_Total_Bandwidth.U//总共要发对SCAP的访存次数 + Max_Store_Memory_Time := Mux(IsStoreTranspose, M_Get_IteratorMax * Matrix_MN.U, StoreMatrixRegTensor_M) * StoreMatrixRegTensor_N * D_DataType / outsideDataWidthByte.U//总共要发对LLC的访存次数 + // MaxIncStoreScpRequestSize := Mux(IsStoreTranspose, StoreMatrixRegTensor_N, M_Get_IteratorMax * Matrix_MN.U) * Mux(IsStoreTranspose, StoreMatrixRegTensor_M, StoreMatrixRegTensor_N) * D_DataType / CMatrixReg_Total_Bandwidth.U + MaxIncStoreScpRequestSize := M_Get_IteratorMax * Matrix_MN.U * StoreMatrixRegTensor_N * D_DataType / CMatrixReg_Total_Bandwidth.U + MaxIncStoreRequestSize := (Mux(IsStoreTranspose, StoreMatrixRegTensor_N, StoreMatrixRegTensor_M) / Matrix_MN.U * Matrix_MN.U) * Mux(IsStoreTranspose, StoreMatrixRegTensor_M, StoreMatrixRegTensor_N) * D_DataType / CMatrixReg_Total_Bandwidth.U + Max_Load_Scp_Tail_SubMajor_Iter := Mux(IsStoreTranspose, StoreMatrixRegTensor_N, StoreMatrixRegTensor_M) % Matrix_MN.U Current_Load_Scp_Tail_subMajor_Iter := 0.U Current_Load_Scp_addr := 0.U Current_Load_M_iter := 0.U @@ -784,13 +791,13 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ Mux(D_DataType === 2.U, CMatrixReg_Total_Bandwidth.U/2.U, Mux(D_DataType === 4.U, CMatrixReg_Total_Bandwidth.U/4.U, CMatrixReg_Total_Bandwidth.U))) FireTimes := 0.U - Max_BlockTensor_Reduce_DIM := Mux(Is_Transpose, MatrixRegTensor_M, MatrixRegTensor_N) - Max_BlockTensor_Request_Reduce_DIM := Mux(Is_Transpose, M_Get_IteratorMax * Matrix_MN.U, MatrixRegTensor_N) - Max_BlockTensor_Major_DIM := Mux(Is_Transpose, MatrixRegTensor_N, MatrixRegTensor_M) + Max_BlockTensor_Reduce_DIM := Mux(IsStoreTranspose, StoreMatrixRegTensor_M, StoreMatrixRegTensor_N) + Max_BlockTensor_Request_Reduce_DIM := Mux(IsStoreTranspose, M_Get_IteratorMax * Matrix_MN.U, StoreMatrixRegTensor_N) + Max_BlockTensor_Major_DIM := Mux(IsStoreTranspose, StoreMatrixRegTensor_N, StoreMatrixRegTensor_M) if(YJPCMLDebugEnable) { - printf("[CMemoryLoader_Store<%d>]Store D Tensor Start, Max_Load_MReg_Time: %x\n", io.DebugInfo.DebugTimeStampe, MatrixRegTensor_M * MatrixRegTensor_N * D_DataType / CMatrixReg_Total_Bandwidth.U) + printf("[CMemoryLoader_Store<%d>]Store D Tensor Start, Max_Load_MReg_Time: %x\n", io.DebugInfo.DebugTimeStampe, StoreMatrixRegTensor_M * StoreMatrixRegTensor_N * D_DataType / CMatrixReg_Total_Bandwidth.U) } } is(s_store_working) { @@ -805,10 +812,10 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } //根据MatrixReg的仲裁结果,我们可以读取数据了 for (i <- 0 until CMatrixRegNBanks){ - io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(i).bits := Mux(Is_Transpose, transpose_scp_addr, Current_Load_Scp_addr) + io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(i).bits := Mux(IsStoreTranspose, transpose_scp_addr, Current_Load_Scp_addr) io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(i).valid := true.B } - when(io.ToMatrixRegIO.ReadWriteResponse(MatrixRegTaskType.ReadFromMemoryLoaderIndex)){ + when(io.ToMatrixRegIO.StoreReadWriteResponse(MatrixRegTaskType.ReadFromMemoryLoaderIndex)){ TotalStoreRequestSize := TotalStoreRequestSize + 1.U // logic for transpose Current_Load_M_iter := Current_Load_M_iter + 1.U @@ -839,7 +846,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } //只要fifo内的数据有效,就可以写入LLC - val WriteRequest = io.LocalMMUIO.Request + val WriteRequest = io.StoreLocalMMUIO.Request WriteRequest.valid := false.B when(!FromMatrixRegReadFIFOEmpty && Reorder_ToLLC_Reg_Ready_Get){ val Read_Data_list = WireInit(VecInit(Seq.fill(Matrix_MN)(0.U(Per_GetMatrix_NDim_Width.W)))) @@ -880,20 +887,21 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } } - WriteRequest.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentStore_BlockTensor_Major_DIM_Iter + CurrentStore_BlockTensor_SubMajor_DIM_Iter) * ApplicationTensor_D_Stride_M + CurrentStore_BlockTensor_Reduce_DIM_Iter * D_DataType - WriteRequest.bits.RequestConherent := IsConherent - WriteRequest.bits.RequestSourceID := io.LocalMMUIO.ConherentRequsetSourceID.bits + WriteRequest.bits.RequestVirtualAddr := StoreTensorBlockBaseAddr + (CurrentStore_BlockTensor_Major_DIM_Iter + CurrentStore_BlockTensor_SubMajor_DIM_Iter) * ApplicationTensor_D_Stride_M + CurrentStore_BlockTensor_Reduce_DIM_Iter * D_DataType + WriteRequest.bits.RequestConherent := IsStoreConherent + WriteRequest.bits.RequestSourceID := io.StoreLocalMMUIO.ConherentRequsetSourceID.bits WriteRequest.bits.RequestType_isWrite := true.B + WriteRequest.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) WriteRequest.bits.RequestData := Request_Data.asUInt WriteRequest.valid := true.B //只有fire了才能继续 - when(WriteRequest.fire && io.LocalMMUIO.ConherentRequsetSourceID.valid){ + when(WriteRequest.fire && io.StoreLocalMMUIO.ConherentRequsetSourceID.valid){ Send_LLC_Iter := WrapInc(Send_LLC_Iter, Send_LLC_Max_Iter) // if (YJPAfterOpsDebugEnable) // { // printf("[AfterOps<%d>]AfterOps: Send data to Vector, Send_Vector_Iter is %d,Send_Vector_Data is %x\n",io.DebugInfo.DebugTimeStampe, Send_Vector_Iter,io.VectorInterface.VectorDataIn.bits) // } - when(Is_Transpose) { + when(IsStoreTranspose) { when(Send_LLC_Iter === (Send_LLC_Max_Iter - 1).U) { Send_LLC_Iter := 0.U Reorder_ToLLC_Reg_Valid(Reorder_ToLLC_Reg_Send_Index) := false.B @@ -901,7 +909,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } }.otherwise { - when(Send_LLC_Iter === (Send_LLC_Max_Iter - 1).U || (CurrentStore_BlockTensor_Major_DIM_Iter + CurrentStore_BlockTensor_SubMajor_DIM_Iter) === (MatrixRegTensor_M - 1.U)){ + when(Send_LLC_Iter === (Send_LLC_Max_Iter - 1).U || (CurrentStore_BlockTensor_Major_DIM_Iter + CurrentStore_BlockTensor_SubMajor_DIM_Iter) === (StoreMatrixRegTensor_M - 1.U)){ Send_LLC_Iter := 0.U Reorder_ToLLC_Reg_Valid(Reorder_ToLLC_Reg_Send_Index) := false.B Reorder_ToLLC_Reg_Send_Index := WrapInc(Reorder_ToLLC_Reg_Send_Index, 2) @@ -927,7 +935,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } } - Write_Mem_Wait_Table(io.LocalMMUIO.ConherentRequsetSourceID.bits) := true.B + Write_Mem_Wait_Table(io.StoreLocalMMUIO.ConherentRequsetSourceID.bits) := true.B TotalStoreSize := TotalStoreSize + 1.U //输出完成的写回次数 if (YJPCMLDebugEnable) @@ -953,11 +961,10 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ is(s_store_end) { // memorystore_state := s_store_end when(!Write_Mem_Wait_Table.reduce(_||_)) { - io.ConfigInfo.MicroTaskEndValid := true.B + io.ConfigInfo.StoreMicroTaskEndValid := true.B } - when(io.ConfigInfo.MicroTaskEndReady && io.ConfigInfo.MicroTaskEndValid){ + when(io.ConfigInfo.StoreMicroTaskEndReady && io.ConfigInfo.StoreMicroTaskEndValid){ memorystore_state := s_store_idle - state := s_idle if (YJPCMLDebugEnable) { printf("[CMemoryLoader_Store<%d>]Store Finish\n",io.DebugInfo.DebugTimeStampe) @@ -965,4 +972,5 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ } } } + } diff --git a/src/main/scala/CUTE2YGJK.scala b/src/main/scala/CUTE2YGJK.scala index be36652..e2e241e 100644 --- a/src/main/scala/CUTE2YGJK.scala +++ b/src/main/scala/CUTE2YGJK.scala @@ -48,6 +48,7 @@ class CUTE2TLImp(outer: Cute2TL) extends LazyModuleImp(outer) with CUTEImplParam }) val data = io.mmu.Request.bits.RequestData + val mask = io.mmu.Request.bits.RequestMask val busy = RegInit(VecInit(Seq.fill(LLCSourceMaxNum)(false.B))) val id = WireInit(0.U(LLCSourceMaxNumBitSize.W)) @@ -137,7 +138,7 @@ class CUTE2TLImp(outer: Cute2TL) extends LazyModuleImp(outer) with CUTEImplParam tl_out.a.valid := io.mmu.Request.valid && !is_full tl_out.a.bits := Mux1H(Seq( (io.mmu.Request.bits.RequestType_isWrite === 0.U) -> edge.Get(id, io.mmu.Request.bits.RequestPhysicalAddr, log2Ceil(outsideDataWidthByte).U)._2, - (io.mmu.Request.bits.RequestType_isWrite === 1.U) -> edge.Put(id, io.mmu.Request.bits.RequestPhysicalAddr, log2Ceil(outsideDataWidthByte).U, data)._2 + (io.mmu.Request.bits.RequestType_isWrite === 1.U) -> edge.Put(id, io.mmu.Request.bits.RequestPhysicalAddr, log2Ceil(outsideDataWidthByte).U, data, mask)._2 )) // Assign MatrixKey to cooperate with HBL2. diff --git a/src/main/scala/CUTEParameters.scala b/src/main/scala/CUTEParameters.scala index 56545b2..998b6b8 100644 --- a/src/main/scala/CUTEParameters.scala +++ b/src/main/scala/CUTEParameters.scala @@ -41,6 +41,25 @@ class DebugInfoIO()(implicit p: Parameters) extends CuteBundle{ case object CuteParamsKey extends Field[CuteParams] +case object MteComputeType extends Field[UInt] { + val ComputeTypeBitWidth = 4 + val ComputeTypeUndef = 15.U(ComputeTypeBitWidth.W) + + val I8I8I32 = 0.U(ComputeTypeBitWidth.W) + val F16F16F32 = 1.U(ComputeTypeBitWidth.W) + val BF16BF16F32 = 2.U(ComputeTypeBitWidth.W) + val TF32TF32F32 = 3.U(ComputeTypeBitWidth.W) + val I8U8I32 = 4.U(ComputeTypeBitWidth.W) + val U8I8I32 = 5.U(ComputeTypeBitWidth.W) + val U8U8I32 = 6.U(ComputeTypeBitWidth.W) + val Mxfp8e4m3F32 = 7.U(ComputeTypeBitWidth.W) + val Mxfp8e5m2F32 = 8.U(ComputeTypeBitWidth.W) + val Nvfp4F32 = 9.U(ComputeTypeBitWidth.W) + val Mxfp4F32 = 10.U(ComputeTypeBitWidth.W) + val Fp8e4m3F32 = 11.U(ComputeTypeBitWidth.W) + val Fp8e5m2F32 = 12.U(ComputeTypeBitWidth.W) +} + case class MatrixIsaParams( enableInt4Int32: Boolean = false, enableInt8Int32: Boolean = false, @@ -53,6 +72,9 @@ case class MatrixIsaParams( enableFp16Fp32: Boolean = false, enableTf32Fp32: Boolean = false, enableFp32Fp32: Boolean = false, + enableMxfp4Fp32: Boolean = false, + enableMxfp8Fp32: Boolean = false, + enableNvfp4Fp32: Boolean = false, ) { assert(enableInt84Int32 == false, "enableInt84Int32 is not supported now") assert(enableFp32Fp32 == false, "enableFp32Fp32 is not supported now") @@ -61,14 +83,16 @@ case class MatrixIsaParams( enableInt4Int32 || enableInt8Int32 || enableInt84Int32 || enableFp8Fp32 || enableFp8Fp16 || enableFp8Bf16 || enableFp16Fp16 || enableBf16Fp32 || enableFp16Fp32 || - enableFp32Fp32 || enableTf32Fp32 + enableFp32Fp32 || enableTf32Fp32 || + enableMxfp4Fp32 || enableMxfp8Fp32 || enableNvfp4Fp32 def enable4BitSrc: Boolean = - enableInt4Int32 + enableInt4Int32 || enableMxfp4Fp32 || enableNvfp4Fp32 def enable8BitSrc: Boolean = enableInt8Int32 || enableInt84Int32 || - enableFp8Fp32 || enableFp8Fp16 || enableFp8Bf16 + enableFp8Fp32 || enableFp8Fp16 || enableFp8Bf16 || + enableMxfp8Fp32 def enable16BitSrc: Boolean = enableFp16Fp16 || enableBf16Fp32 || enableFp16Fp32 @@ -80,7 +104,7 @@ case class MatrixIsaParams( def enable4BitDst: Boolean = false - def enable8BitDst: Boolean = true //开启8位类型,仅仅用于load测试 + def enable8BitDst: Boolean = false def enable16BitDst: Boolean = enableFp8Fp16 || enableFp8Bf16 || enableFp16Fp16 @@ -93,6 +117,11 @@ case class MatrixIsaParams( enableTf32Fp32 def enable64BitDst: Boolean = false + + def enableScalingFactor: Boolean = + enableNvfp4Fp32 || enableMxfp4Fp32 || enableMxfp8Fp32 + def enableFp4withsf: Boolean = + enableNvfp4Fp32 || enableMxfp4Fp32 } trait CuteParamsKey{ @@ -127,10 +156,29 @@ object CuteParams { Tensor_K = 64, Matrix_MN = 8, ReduceWidthByte = 32, - // Debug = CuteDebugParams.AMLDebugEnable ) - def CUTE_8Tops_128SCP = baseParams.copy( + def CUTE_32Tops = baseParams.copy( + outsideDataWidth = 512, + LLCSourceMaxNum = 64, + MemorysourceMaxNum = 64, + Tensor_MN = 256, + Tensor_K = 64, + Matrix_MN = 16, + ReduceWidthByte = 32, + ) + + def CUTE_16Tops = baseParams.copy( + outsideDataWidth = 512, + LLCSourceMaxNum = 64, + MemorysourceMaxNum = 64, + Tensor_MN = 128, + Tensor_K = 64, + Matrix_MN = 8, + ReduceWidthByte = 64, + ) + + def CUTE_8Tops = baseParams.copy( outsideDataWidth = 512, LLCSourceMaxNum = 64, MemorysourceMaxNum = 64, @@ -149,14 +197,14 @@ object CuteParams { // Debug = CuteDebugParams.AMLDebugEnable ) - def CUTE_32Tops_128SCP = baseParams.copy( + def CUTE_4Tops = baseParams.copy( outsideDataWidth = 512, LLCSourceMaxNum = 64, MemorysourceMaxNum = 64, Tensor_MN = 256, Tensor_K = 64, - Matrix_MN = 16, - ReduceWidthByte = 32, + Matrix_MN = 4, + ReduceWidthByte = 64, MatrixExtension = MatrixIsaParams( enableInt8Int32 = true, enableFp8Fp32 = true, @@ -184,6 +232,71 @@ object CuteParams { enableFp16Fp16 = true, enableBf16Fp32 = true, ), + // Debug = CuteDebugParams.AMLDebugEnable + ) + + def CUTE_1Tops = baseParams.copy( + outsideDataWidth = 512, + LLCSourceMaxNum = 64, + MemorysourceMaxNum = 64, + Tensor_MN = 64, + Tensor_K = 64, + Matrix_MN = 4, + ReduceWidthByte = 64, + // Debug = CuteDebugParams.AMLDebugEnable + ) + + def CUTE_05Tops = baseParams.copy( + outsideDataWidth = 512, + LLCSourceMaxNum = 64, + MemorysourceMaxNum = 64, + Tensor_MN = 64, + Tensor_K = 64, + Matrix_MN = 4, + ReduceWidthByte = 32, + // Debug = CuteDebugParams.AMLDebugEnable + ) + + def CUTE_512SCP(params: CuteParams) = params.copy( + Tensor_MN = 512, + Tensor_K = 64, + ) + + def CUTE_256SCP(params: CuteParams) = params.copy( + Tensor_MN = 256, + Tensor_K = 64, + ) + + def CUTE_128SCP(params: CuteParams) = params.copy( + Tensor_MN = 128, + Tensor_K = 64, + ) + + def CUTE_64SCP(params: CuteParams) = params.copy( + Tensor_MN = 64, + Tensor_K = 64, + ) + + def CUTE_32Tops_512SCP = CUTE_512SCP(CUTE_32Tops) + def CUTE_16Tops_512SCP = CUTE_512SCP(CUTE_16Tops) + def CUTE_8Tops_512SCP = CUTE_512SCP(CUTE_8Tops) + def CUTE_4Tops_512SCP = CUTE_512SCP(CUTE_4Tops) + def CUTE_16Tops_256SCP = CUTE_256SCP(CUTE_16Tops) + def CUTE_8Tops_256SCP = CUTE_256SCP(CUTE_8Tops) + def CUTE_4Tops_256SCP = CUTE_256SCP(CUTE_4Tops) + def CUTE_2Tops_256SCP = CUTE_256SCP(CUTE_2Tops) + def CUTE_8Tops_128SCP = CUTE_128SCP(CUTE_8Tops) + def CUTE_4Tops_128SCP = CUTE_128SCP(CUTE_4Tops) + def CUTE_2Tops_128SCP = CUTE_128SCP(CUTE_2Tops) + def CUTE_1Tops_128SCP = CUTE_128SCP(CUTE_1Tops) + def CUTE_4Tops_64SCP = CUTE_64SCP(CUTE_4Tops) + def CUTE_2Tops_64SCP = CUTE_64SCP(CUTE_2Tops) + def CUTE_1Tops_64SCP = CUTE_64SCP(CUTE_1Tops) + def CUTE_05Tops_64SCP = CUTE_64SCP(CUTE_05Tops) + + + def CUTE_4Tops_128SCP_debug = CUTE_4Tops_128SCP.copy( + Debug = CuteDebugParams.AllDebugOn ) def CUTE_2Tops_debug = baseParams.copy( @@ -323,7 +436,7 @@ case class CuteMMUParams( object Cutev3extParams { // NoV3ExtParams: def NoextParams = Cutev3extParams( - TaskCtrl_AutoClear = false, //任务控制器是否自动清除已完成指令 + TaskCtrl_AutoClear = false, //whether the task controller auto-clears completed instructions ) // V3 Base Ext @@ -332,7 +445,7 @@ object Cutev3extParams { } case class Cutev3extParams( - val TaskCtrl_AutoClear :Boolean = true, //任务控制器是否自动清除已完成指令 + val TaskCtrl_AutoClear :Boolean = true, //whether the task controller auto-clears completed instructions ) @@ -343,65 +456,80 @@ object CuteFPEParams { } case class CuteFPEParams( + // currently fixed + val MinGroupSize :Int = 16, + val MinDataTypeWidth : Int = 4, + val ScaleElementWidth : Int = 8, + // + val cmptreelayers :Int = 4, - val P3AddNum :Int = 4, + val fp8cmptreelayers :Int = 4, + // currently fixed and shared with FP4; leave it unchanged + val FP4P0AddNum :Int = 2, ) case class CuteParams( - val outsideDataWidth :Int = 512, //cute对外访存的带宽 - val MemoryDataWidth :Int = 64, //TODO:DRAM的访存通道的数据位宽 + val outsideDataWidth :Int = 512, //CUTE external memory bandwidth + val MemoryDataWidth :Int = 64, //TODO: data width of the DRAM memory channel - val VectorWidth :Int = 256, //向量流水线的宽度 + val VectorWidth :Int = 256, //vector pipeline width val L2NBanks :Int = 4, - val ConvolutionApplicationConfigDataWidth :Int = 32, //卷积相关的配置信息的宽度 - val ConvolutionDIM_Max :Int = 65536, //卷积相关的配置信息的宽度 + val ConvolutionApplicationConfigDataWidth :Int = 32, //width of convolution-related configuration information + val ConvolutionDIM_Max :Int = 65536, //width of convolution-related configuration information val Convolution_Input_Height_Weight_Dim_Max :Int = 16384, - val KernelSizeMax :Int = 16, //卷积核的最大尺寸 - val StrideSizeMax :Int = 4, //步长的最大尺寸 + val KernelSizeMax :Int = 16, //maximum kernel size + val StrideSizeMax :Int = 4, //maximum stride size - val ApplicationMaxTensorSize :Int = 65536, //最大可处理的程序的张量形状, + val ApplicationMaxTensorSize :Int = 65536, //maximum tensor shape a program can handle, - val MMUAddrWidth :Int = 64 , //CUTE MMU的地址宽度 + val MMUAddrWidth :Int = 64 , //CUTE MMU address width - val LLCSourceMaxNum :Int = 64, //LLC总线上的source最大数量 --> 这个参数和LLC的访存延迟强相关,若要满流水,这个sourceMAXnum的数量必须大于LLC的访存延迟 - val MemorysourceMaxNum :Int = 64, //Memory总线上的source最大数量 --> 这个参数和Memory的访存延迟强相关,若要满流水,这个sourceMAXnum的数量必顶大于Memory的访存延迟 + val LLCSourceMaxNum :Int = 64, //maximum number of sources on the LLC bus -> this parameter is tightly coupled to LLC latency; to sustain full throughput, sourceMAXnum must exceed LLC latency + val MemorysourceMaxNum :Int = 64, //maximum number of sources on the memory bus -> this parameter is tightly coupled to memory latency; to sustain full throughput, sourceMAXnum must exceed memory latency - //MatrixReg中保存的张量形状 - val Tensor_MN :Int = 128, //这里指要存的张量的M与N的大小 - val Tensor_K :Int = 64, //这里指要存的张量的K(8bit/elment)的大小 + //tensor shape stored in MatrixReg + val Tensor_MN :Int = 128, //M and N dimensions of the tensor to be stored here + val Tensor_K :Int = 64, //K size of the tensor to be stored here (8-bit/element) - //矩阵乘计算单元MTE的形状 - val Matrix_MN :Int = 4, //Matrix_MN,代表TE执行的矩阵乘法的M与N的大小 - val ReduceWidthByte :Int = 32, //ReduceWidthByte 代表ReducePE进行内积时的数据宽度,单位是字节 - val ResultWidthByte :Int = 4, //ResultWidthByte 代表ReducePE的结果宽度,单位是字节 + //MTE matrix-multiply unit shape + val Matrix_MN :Int = 4, //Matrix_MN: M and N dimensions of the matrix multiply executed by TE + val ReduceWidthByte :Int = 64, //ReduceWidthByte: data width used by ReducePE for inner products, in bytes + val ResultWidthByte :Int = 4, //ResultWidthByte: result width of ReducePE, in bytes - val ResultFIFODepth :Int = 8, //乘累加FIFO的深度 + val ResultFIFODepth :Int = 8, //MAC FIFO depth - val AMemoryLoaderReadFromMemoryFIFODepth :Int = 4, //用于暂存AML的数据到CCSP的FIFO - val BMemoryLoaderReadFromMemoryFIFODepth :Int = 4, //用于暂存BML的数据到CCSP的FIFO - val CMemoryLoaderReadFromMatrixRegFIFODepth :Int = 4, //用于暂存CCSP的数据到CML的FIFO - val CMemoryLoaderReadFromMemoryFIFODepth :Int = 4, //用于暂存CML的数据到CMReg的FIFO + val AMemoryLoaderReadFromMemoryFIFODepth :Int = 4, //FIFO for buffering AML data to CCSP + val BMemoryLoaderReadFromMemoryFIFODepth :Int = 4, //FIFO for buffering BML data to CCSP + val CMemoryLoaderReadFromMatrixRegFIFODepth :Int = 4, //FIFO for buffering CCSP data to CML + val CMemoryLoaderReadFromMemoryFIFODepth :Int = 4, //FIFO for buffering CML data to CMReg - val VecTaskInstBufferDepth :Int = 32, //VecTask的指令缓冲深度 - val VecTaskInstBufferSize :Int = 8, //VecTask的指令缓冲的数量 - val VecTaskDataBufferDepth :Int = 4, //VecTask的指令缓冲深度掩盖从VecInterface到VPU的数据传输延迟即可 + val VecTaskInstBufferDepth :Int = 32, //VecTask instruction buffer depth + val VecTaskInstBufferSize :Int = 8, //number of VecTask instruction buffers + val VecTaskDataBufferDepth :Int = 4, //VecTask buffer depth only needs to cover transfer latency from VecInterface to VPU - val EnableDifftest: Boolean = false, //是否启用DiffTest + // TaskController issue window depth (compile-time only) + val TaskCtrlIssueWindowDepth :Int = 8, - val Debug : CuteDebugParams = CuteDebugParams.NoDebug, //调试参数 - val MMUParams: CuteMMUParams = CuteMMUParams.baseParams, //MMU的参数 + val EnableDifftest: Boolean = false, //whether DiffTest is enabled + + val Debug : CuteDebugParams = CuteDebugParams.NoDebug, //debug parameters + val MMUParams: CuteMMUParams = CuteMMUParams.baseParams, //MMU parameters - val v3config: Cutev3extParams = Cutev3extParams.NoextParams, //v3的扩展参数 + val v3config: Cutev3extParams = Cutev3extParams.NoextParams, //v3 extension parameters - val FPEparams: CuteFPEParams = CuteFPEParams.baseparams, //FPE的参数 + val FPEparams: CuteFPEParams = CuteFPEParams.baseparams, //FPE parameters val MatrixExtension: MatrixIsaParams = MatrixIsaParams() ) { - //所有参数都必须是2的n次方 + //all parameters must be powers of 2 + // require(ReduceWidthByte == 64, "FP8/4 now only support 512 bit reduce width") + require(outsideDataWidth == 512 || outsideDataWidth == 256, "currently only support 512/256 bit outsideDataWidth") + require(ReduceWidthByte == 64 || ReduceWidthByte == 32, "currently only support 512/256 bit ReduceWidth") + require(outsideDataWidth >= ReduceWidthByte * 8, "outsideDataWidth must be larger than or equal to ReduceWidthByte") require((outsideDataWidth & (outsideDataWidth - 1)) == 0, "outsideDataWidth must be power of 2") require((MemoryDataWidth & (MemoryDataWidth - 1)) == 0, "MemoryDataWidth must be power of 2") require((VectorWidth & (VectorWidth - 1)) == 0, "VectorWidth must be power of 2") @@ -411,7 +539,7 @@ case class CuteParams( require((KernelSizeMax & (KernelSizeMax - 1)) == 0, "KernelSizeMax must be power of 2") require((StrideSizeMax & (StrideSizeMax - 1)) == 0, "StrideSizeMax must be power of 2") require((ApplicationMaxTensorSize & (ApplicationMaxTensorSize - 1)) == 0, "ApplicationMaxTensorSize must be power of 2") - require((MMUAddrWidth & (MMUAddrWidth - 1)) == 0, "MMUAddrWidth must be power of 2" ) + // require((MMUAddrWidth & (MMUAddrWidth - 1)) == 0, "MMUAddrWidth must be power of 2" ) require((LLCSourceMaxNum & (LLCSourceMaxNum - 1)) == 0, "LLCSourceMaxNum must be power of 2") require((MemorysourceMaxNum & (MemorysourceMaxNum - 1)) == 0, "MemorysourceMaxNum must be power of 2") require((Tensor_MN & (Tensor_MN - 1)) == 0, "Tensor_MN must be power of 2") @@ -427,152 +555,219 @@ case class CuteParams( require((VecTaskInstBufferDepth & (VecTaskInstBufferDepth - 1)) == 0, "VecTaskInstBufferDepth must be power of 2") require((VecTaskInstBufferSize & (VecTaskInstBufferSize - 1)) == 0, "VecTaskInstBufferSize must be power of 2") require((VecTaskDataBufferDepth & (VecTaskDataBufferDepth - 1)) == 0, "VecTaskDataBufferDepth must be power of 2") + require(Seq(4, 8, 16).contains(TaskCtrlIssueWindowDepth), "TaskCtrlIssueWindowDepth only supports 4/8/16") + require((FPEparams.MinGroupSize == 16), "FPEparams.MinGroupSize must be 16") + require((FPEparams.MinDataTypeWidth == 4), "FPEparams.MinDataTypeWidth must be 4") + require((FPEparams.ScaleElementWidth == 8), "FPEparams.ScaleElementWidth must be 8") def outsideDataWidthByte = outsideDataWidth / 8 def ReduceWidth = ReduceWidthByte * 8 - def ABMLNeedMRegFillTable = ReduceWidthByte < outsideDataWidthByte //内存返回的数据一周期写不完时ABML需要写回缓冲 + def ABMLNeedMRegFillTable = ReduceWidthByte < outsideDataWidthByte //when returned memory data cannot be written in one cycle, ABML needs a writeback buffer def ResultWidth = ResultWidthByte * 8 def ApplicationMaxTensorSizeBitSize = log2Ceil(ApplicationMaxTensorSize) + 1 - def MMUDataWidth = outsideDataWidth //MMU的数据线宽度 - def MMUMaskWidth = MMUDataWidth / 8 //MMU的掩码线宽度 - def MMUDataWidthBitSize = log2Ceil(MMUDataWidth) + 1 //MMU的数据线有效数据位数 + def MMUDataWidth = outsideDataWidth //MMU data bus width + def MMUMaskWidth = MMUDataWidth / 8 //MMU mask width + def MMUDataWidthBitSize = log2Ceil(MMUDataWidth) + 1 //effective data-bit count of the MMU data bus def LLCSourceMaxNumBitSize = log2Ceil(LLCSourceMaxNum) + 1 def MemorysourceMaxNumBitSize = log2Ceil(MemorysourceMaxNum) + 1 def SoureceMaxNum = math.max(LLCSourceMaxNum, MemorysourceMaxNum) def SoureceMaxNumBitSize = log2Ceil(SoureceMaxNum) + 1 - def ReduceGroupSize = Tensor_K/ReduceWidthByte //这里指要存的张量的K的ReduceVector的数量!不是张量的K的大小 + def P3AddNum = ReduceWidth / 4 / FPEparams.MinGroupSize + def P2AddNum = ReduceWidth / (P3AddNum * 16) + + def ReduceGroupSize = Tensor_K/ReduceWidthByte //number of ReduceVectors for K to be stored, not the K size def MatrixRegMaxTensorDim = Math.max(Tensor_MN, Math.max(Tensor_MN, ReduceGroupSize)) def MatrixRegMaxTensorDimBitSize = log2Ceil(MatrixRegMaxTensorDim) + 1 - //A MatrixReg中保存的张量形状为M*K - //A MatrixReg的大小为Tenser_M * ReduceGroupSize * ReduceWidthByte - //128*(4*256/8),单次读的张量为128*128的张量 - //单次计算需要的时间为(128/4)*(128/4)*4 = 4096拍,单次读需要128×4=512拍。 - //需要考虑MatrixReg的顺序读,需要考虑为MatrixReg分bank + //A tensor shape stored in MatrixReg is M*K + //A MatrixReg size is Tenser_M * ReduceGroupSize * ReduceWidthByte + //128*(4*256/8); one read covers a 128*128 tensor + //one compute takes (128/4)*(128/4)*4 = 4096 cycles; one read takes 128*4 = 512 cycles. + //sequential reads must be considered; MatrixReg should be banked def ABMatrixRegSize = Tensor_MN * ReduceGroupSize * ReduceWidthByte //reduce def CMatrixRegSize = Tensor_MN * Tensor_MN * ResultWidthByte //result - //目前的MatrixReg设计,分Tensor_T个bank,每次取Tensor_T个数据,根据取数逻辑,在不同的bank里取不同的数据,然后拼接 - def ABMatrixRegEntryByteSize = ReduceWidthByte //适合向TE供数的带宽 - def CMatrixRegEntryByteSize = Matrix_MN*ResultWidthByte //这个取数和存数的带宽 - def ABMatrixRegEntryBitSize = ReduceWidthByte * 8 //适合向TE供数的带宽 - def CMatrixRegEntryBitSize = Matrix_MN*ResultWidthByte * 8//这个取数和存数的带宽 - def ABMatrixRegNBanks = Matrix_MN //注意这里与Matrix_MN有强相关性,一般是Matrix_MN的整数倍 - def CMatrixRegNBanks = Matrix_MN //方便进行reorder - def ABMatrixReg_Total_Bandwidth = ABMatrixRegNBanks * ABMatrixRegEntryByteSize //ABMatrixReg的总带宽 - def CMatrixReg_Total_Bandwidth = CMatrixRegNBanks * CMatrixRegEntryByteSize //CMatrixReg的总带宽 - def ABMatrixReg_Total_Bandwidth_Bit = ABMatrixRegNBanks * ABMatrixRegEntryByteSize * 8 //ABMatrixReg的总带宽 - def CMatrixReg_Total_Bandwidth_Bit = CMatrixRegNBanks * CMatrixRegEntryByteSize * 8 //CMatrixReg的总带宽 + //the current MatrixReg design is split into Tensor_T banks; each access fetches Tensor_T items from different banks and concatenates them + def ABMatrixRegEntryByteSize = ReduceWidthByte //bandwidth suitable for feeding TE + def CMatrixRegEntryByteSize = Matrix_MN*ResultWidthByte //bandwidth for reads and writes + def ABMatrixRegEntryBitSize = ReduceWidthByte * 8 //bandwidth suitable for feeding TE + def CMatrixRegEntryBitSize = Matrix_MN*ResultWidthByte * 8//bandwidth for reads and writes + // DiffAmuFinishEvent flattens per-bank payload into UInt(64) lanes. + // Use the maximum bank payload width so one event layout works for AB/C paths. + def DiffAmuFinishWordsPerBank = Math.max(ABMatrixRegEntryBitSize, CMatrixRegEntryBitSize) / 64 + def ABMatrixRegNBanks = Matrix_MN //note that this is strongly tied to Matrix_MN and is generally an integer multiple of Matrix_MN + def CMatrixRegNBanks = Matrix_MN //convenient for reorder + def Trans_Load_Size = outsideDataWidthByte / ABMatrixRegNBanks + def ABMatrixReg_Total_Bandwidth = ABMatrixRegNBanks * ABMatrixRegEntryByteSize //total bandwidth of ABMatrixReg + def CMatrixReg_Total_Bandwidth = CMatrixRegNBanks * CMatrixRegEntryByteSize //total bandwidth of CMatrixReg + def ABMatrixReg_Total_Bandwidth_Bit = ABMatrixRegNBanks * ABMatrixRegEntryByteSize * 8 //total bandwidth of ABMatrixReg + def CMatrixReg_Total_Bandwidth_Bit = CMatrixRegNBanks * CMatrixRegEntryByteSize * 8 //total bandwidth of CMatrixReg def ABMatrixRegBankSize = ABMatrixRegSize / ABMatrixRegNBanks def CMatrixRegBankSize = CMatrixRegSize / CMatrixRegNBanks - def ABMatrixRegBankNEntrys = ABMatrixRegBankSize / ABMatrixRegEntryByteSize - def CMatrixRegBankNEntrys = CMatrixRegBankSize / CMatrixRegEntryByteSize - - require(ReduceGroupSize == 2, "ReduceGroupSize must be 2, Wait for update") + def ABMatrixRegBankNEntries = ABMatrixRegBankSize / ABMatrixRegEntryByteSize + def CMatrixRegBankNEntries = CMatrixRegBankSize / CMatrixRegEntryByteSize + + /** + * Scale Factor Parameters + */ + def ScaleWidth = ReduceWidthByte * 8 * FPEparams.ScaleElementWidth / FPEparams.MinDataTypeWidth / FPEparams.MinGroupSize // maximum bit width needed for one group scale + def ABScaleSize = Tensor_MN * ReduceGroupSize * ScaleWidth + def ABScaleNSlices = outsideDataWidth / ScaleWidth / ReduceGroupSize // maximum scale length for one Tensor_K per slice + def ABScaleBankNEntries = ABScaleSize / (ABScaleNSlices * ScaleWidth * ReduceGroupSize) + + // require(ReduceGroupSize == 2, "ReduceGroupSize must be 2, Wait for update") require(outsideDataWidthByte <= Tensor_K, "outsideDataWidthByte must be less than or equal to Tensor_K, or a load will exceed the subtensor in micro load") + require(outsideDataWidthByte % ABMatrixRegNBanks == 0, "outsideDataWidthByte must be divisible by ABMatrixRegNBanks for transpose load") } trait CUTEImplParameters{ implicit val p: Parameters def cuteParams: CuteParams = p(CuteParamsKey) + def cuteMatrixExtension: MatrixIsaParams = cuteParams.MatrixExtension + + def enableMteInt8: Boolean = cuteMatrixExtension.enableInt8Int32 + def enableMteFp8: Boolean = cuteMatrixExtension.enableFp8Fp32 + def enableMteFp16: Boolean = cuteMatrixExtension.enableFp16Fp32 || cuteMatrixExtension.enableFp16Fp16 + def enableMteBf16: Boolean = cuteMatrixExtension.enableBf16Fp32 + def enableMteTf32: Boolean = cuteMatrixExtension.enableTf32Fp32 + def enableMteNvfp4: Boolean = false + def enableMteMxfp8: Boolean = false + def enableMteMxfp4: Boolean = false + def MMUParams: CuteMMUParams = cuteParams.MMUParams def DebugParams: CuteDebugParams = cuteParams.Debug def v3config: Cutev3extParams = cuteParams.v3config def FPEparams: CuteFPEParams = cuteParams.FPEparams - val DecodedAmuCtrlFIFODepth = 8 //解码后的AMU指令FIFO的深度 - val DecodedAmuCtrlFIFODepthBitSize = log2Ceil(DecodedAmuCtrlFIFODepth) //解码后的AMU指令FIFO的深度 - - val ABMatrixRegCount = 4 - val CMatrixRegCount = 4 - val ABMatrixRegIdWidth = log2Ceil(ABMatrixRegCount) - val CMatrixRegIdWidth = log2Ceil(CMatrixRegCount) - - val vpnBits = MMUParams.vpnBits - val ppnBits = MMUParams.ppnBits - val pgIdxBits = MMUParams.pgIdxBits - val vaddrBits = MMUParams.vaddrBits - val paddrBits = MMUParams.paddrBits - val corePAddrBits = MMUParams.corePAddrBits - - val TaskCtrl_AutoClear = v3config.TaskCtrl_AutoClear - - val YJPDebugEnable = DebugParams.YJPDebugEnable - val YJPADCDebugEnable = DebugParams.YJPADCDebugEnable - val YJPBDCDebugEnable = DebugParams.YJPBDCDebugEnable - val YJPCDCDebugEnable = DebugParams.YJPCDCDebugEnable - val YJPAMLDebugEnable = DebugParams.YJPAMLDebugEnable - val YJPBMLDebugEnable = DebugParams.YJPBMLDebugEnable - val YJPCMLDebugEnable = DebugParams.YJPCMLDebugEnable - val YJPTASKDebugEnable = DebugParams.YJPTASKDebugEnable - val YJPVECDebugEnable = DebugParams.YJPVECDebugEnable - val YJPMACDebugEnable = DebugParams.YJPMACDebugEnable - val YJPPEDebugEnable = DebugParams.YJPPEDebugEnable - val YJPAfterOpsDebugEnable = DebugParams.YJPAfterOpsDebugEnable - - val ConvolutionApplicationConfigDataWidth = cuteParams.ConvolutionApplicationConfigDataWidth - val ConvolutionDIM_Max = cuteParams.ConvolutionDIM_Max - val Convolution_Input_Height_Weight_Dim_Max = cuteParams.Convolution_Input_Height_Weight_Dim_Max - val KernelSizeMax = cuteParams.KernelSizeMax - val StrideSizeMax = cuteParams.StrideSizeMax - val outsideDataWidth = cuteParams.outsideDataWidth - val outsideDataWidthByte = cuteParams.outsideDataWidthByte - val MemoryDataWidth = cuteParams.MemoryDataWidth - val ReduceWidthByte = cuteParams.ReduceWidthByte - val ReduceWidth = cuteParams.ReduceWidth - val ABMLNeedMRegFillTable = cuteParams.ABMLNeedMRegFillTable - val ResultWidthByte = cuteParams.ResultWidthByte - val ResultWidth = cuteParams.ResultWidth - val VectorWidth = cuteParams.VectorWidth - val ApplicationMaxTensorSize = cuteParams.ApplicationMaxTensorSize - val ApplicationMaxTensorSizeBitSize = cuteParams.ApplicationMaxTensorSizeBitSize - val MMUAddrWidth = cuteParams.MMUAddrWidth - val MMUDataWidth = cuteParams.MMUDataWidth - val MMUMaskWidth = cuteParams.MMUMaskWidth - val MMUDataWidthBitSize = cuteParams.MMUDataWidthBitSize - val LLCSourceMaxNum = cuteParams.LLCSourceMaxNum - val LLCSourceMaxNumBitSize = cuteParams.LLCSourceMaxNumBitSize - val MemorysourceMaxNum = cuteParams.MemorysourceMaxNum - val MemorysourceMaxNumBitSize = cuteParams.MemorysourceMaxNumBitSize - val SoureceMaxNum = cuteParams.SoureceMaxNum - val SoureceMaxNumBitSize = cuteParams.SoureceMaxNumBitSize - val Tensor_MN = cuteParams.Tensor_MN - val Tensor_K = cuteParams.Tensor_K - val MatrixRegMaxTensorDim = cuteParams.MatrixRegMaxTensorDim - val MatrixRegMaxTensorDimBitSize = cuteParams.MatrixRegMaxTensorDimBitSize - val ABMatrixRegSize = cuteParams.ABMatrixRegSize - val CMatrixRegSize = cuteParams.CMatrixRegSize - val Matrix_MN = cuteParams.Matrix_MN - val ABMatrixRegEntryByteSize = cuteParams.ABMatrixRegEntryByteSize - val CMatrixRegEntryByteSize = cuteParams.CMatrixRegEntryByteSize - val ABMatrixRegEntryBitSize = cuteParams.ABMatrixRegEntryBitSize - val CMatrixRegEntryBitSize = cuteParams.CMatrixRegEntryBitSize - val ABMatrixRegNBanks = cuteParams.ABMatrixRegNBanks - val CMatrixRegNBanks = cuteParams.CMatrixRegNBanks - val ABMatrixReg_Total_Bandwidth = cuteParams.ABMatrixReg_Total_Bandwidth - val CMatrixReg_Total_Bandwidth = cuteParams.CMatrixReg_Total_Bandwidth - val ABMatrixReg_Total_Bandwidth_Bit = cuteParams.ABMatrixReg_Total_Bandwidth_Bit - val CMatrixReg_Total_Bandwidth_Bit = cuteParams.CMatrixReg_Total_Bandwidth_Bit - val ABMatrixRegBankSize = cuteParams.ABMatrixRegBankSize - val CMatrixRegBankSize = cuteParams.CMatrixRegBankSize - val ABMatrixRegBankNEntrys = cuteParams.ABMatrixRegBankNEntrys - val CMatrixRegBankNEntrys = cuteParams.CMatrixRegBankNEntrys - val ResultFIFODepth = cuteParams.ResultFIFODepth - val AMemoryLoaderReadFromMemoryFIFODepth = cuteParams.AMemoryLoaderReadFromMemoryFIFODepth - val BMemoryLoaderReadFromMemoryFIFODepth = cuteParams.BMemoryLoaderReadFromMemoryFIFODepth - val CMemoryLoaderReadFromMatrixRegFIFODepth = cuteParams.CMemoryLoaderReadFromMatrixRegFIFODepth - val CMemoryLoaderReadFromMemoryFIFODepth = cuteParams.CMemoryLoaderReadFromMemoryFIFODepth - val VecTaskInstBufferDepth = cuteParams.VecTaskInstBufferDepth - val VecTaskInstBufferSize = cuteParams.VecTaskInstBufferSize - val VecTaskDataBufferDepth = cuteParams.VecTaskDataBufferDepth - val ReduceGroupSize = cuteParams.ReduceGroupSize - val EnableDifftest = cuteParams.EnableDifftest - val L2NBanks = cuteParams.L2NBanks - - val cmptreelayers = FPEparams.cmptreelayers //FPE的计算树层数 - val P3AddNum :Int = FPEparams.P3AddNum //FPE的P3加法器的数量 - val P2AddNum :Int = ReduceWidth / (P3AddNum * 16) + def TaskCtrlIssueWindowDepth = cuteParams.TaskCtrlIssueWindowDepth + def TaskCtrlIssueWindowDepthBitSize = log2Ceil(TaskCtrlIssueWindowDepth) + + def DecodedAmuCtrlFIFODepth = TaskCtrlIssueWindowDepth // Decoded AMU instruction FIFO depth, tied to the issue window + def DecodedAmuCtrlFIFODepthBitSize = log2Ceil(DecodedAmuCtrlFIFODepth) // Bit width of the decoded AMU instruction FIFO depth + + def ABMatrixRegCount = 4 + def CMatrixRegCount = 4 + def ABMatrixRegIdWidth = log2Ceil(ABMatrixRegCount) + def CMatrixRegIdWidth = log2Ceil(CMatrixRegCount) + def MatrixRegIdWidth = ABMatrixRegIdWidth max CMatrixRegIdWidth + + def vpnBits = MMUParams.vpnBits + def ppnBits = MMUParams.ppnBits + def pgIdxBits = MMUParams.pgIdxBits + def vaddrBits = MMUParams.vaddrBits + def paddrBits = MMUParams.paddrBits + def corePAddrBits = MMUParams.corePAddrBits + + def TaskCtrl_AutoClear = v3config.TaskCtrl_AutoClear + + def YJPDebugEnable = DebugParams.YJPDebugEnable + def YJPADCDebugEnable = DebugParams.YJPADCDebugEnable + def YJPBDCDebugEnable = DebugParams.YJPBDCDebugEnable + def YJPCDCDebugEnable = DebugParams.YJPCDCDebugEnable + def YJPAMLDebugEnable = DebugParams.YJPAMLDebugEnable + def YJPBMLDebugEnable = DebugParams.YJPBMLDebugEnable + def YJPCMLDebugEnable = DebugParams.YJPCMLDebugEnable + def YJPTASKDebugEnable = DebugParams.YJPTASKDebugEnable + def YJPVECDebugEnable = DebugParams.YJPVECDebugEnable + def YJPMACDebugEnable = DebugParams.YJPMACDebugEnable + def YJPPEDebugEnable = DebugParams.YJPPEDebugEnable + def YJPAfterOpsDebugEnable = DebugParams.YJPAfterOpsDebugEnable + + def ConvolutionApplicationConfigDataWidth = cuteParams.ConvolutionApplicationConfigDataWidth + def ConvolutionDIM_Max = cuteParams.ConvolutionDIM_Max + def Convolution_Input_Height_Weight_Dim_Max = cuteParams.Convolution_Input_Height_Weight_Dim_Max + def KernelSizeMax = cuteParams.KernelSizeMax + def StrideSizeMax = cuteParams.StrideSizeMax + def outsideDataWidth = cuteParams.outsideDataWidth + def outsideDataWidthByte = cuteParams.outsideDataWidthByte + def MemoryDataWidth = cuteParams.MemoryDataWidth + def ReduceWidthByte = cuteParams.ReduceWidthByte + def ReduceWidth = cuteParams.ReduceWidth + def mxfp8ScaleWidth = ReduceWidth * 8 / 8 / 32 //total scale width accepted by one PE per cycle [single scale width 8-bit, single element width 4-bit, groupsize 32] + def nvfp4ScaleWidth = ReduceWidth * 8 / 4 / 16 //total scale width accepted by one PE per cycle [single scale width 8-bit, single element width 4-bit, groupsize 16] + def mxfp4ScaleWidth = ReduceWidth * 8 / 4 / 32 //total scale width accepted by one PE per cycle [single scale width 8-bit, single element width 4-bit, groupsize 32] + def ABMLNeedMRegFillTable = cuteParams.ABMLNeedMRegFillTable + def ResultWidthByte = cuteParams.ResultWidthByte + def ResultWidth = cuteParams.ResultWidth + def VectorWidth = cuteParams.VectorWidth + def ApplicationMaxTensorSize = cuteParams.ApplicationMaxTensorSize + def ApplicationMaxTensorSizeBitSize = cuteParams.ApplicationMaxTensorSizeBitSize + def MMUAddrWidth = cuteParams.MMUAddrWidth + def MMUDataWidth = cuteParams.MMUDataWidth + def MMUMaskWidth = cuteParams.MMUMaskWidth + def MMUDataWidthBitSize = cuteParams.MMUDataWidthBitSize + def LLCSourceMaxNum = cuteParams.LLCSourceMaxNum + def LLCSourceMaxNumBitSize = cuteParams.LLCSourceMaxNumBitSize + def MemorysourceMaxNum = cuteParams.MemorysourceMaxNum + def MemorysourceMaxNumBitSize = cuteParams.MemorysourceMaxNumBitSize + def SoureceMaxNum = cuteParams.SoureceMaxNum + def SoureceMaxNumBitSize = cuteParams.SoureceMaxNumBitSize + def Tensor_MN = cuteParams.Tensor_MN + def Tensor_K = cuteParams.Tensor_K + def MatrixRegMaxTensorDim = cuteParams.MatrixRegMaxTensorDim + def MatrixRegMaxTensorDimBitSize = cuteParams.MatrixRegMaxTensorDimBitSize + def ABMatrixRegSize = cuteParams.ABMatrixRegSize + def CMatrixRegSize = cuteParams.CMatrixRegSize + def Matrix_MN = cuteParams.Matrix_MN + def ABMatrixRegEntryByteSize = cuteParams.ABMatrixRegEntryByteSize + def CMatrixRegEntryByteSize = cuteParams.CMatrixRegEntryByteSize + def ABMatrixRegEntryBitSize = cuteParams.ABMatrixRegEntryBitSize + def CMatrixRegEntryBitSize = cuteParams.CMatrixRegEntryBitSize + def DiffAmuFinishWordsPerBank = cuteParams.DiffAmuFinishWordsPerBank + def ABMatrixRegNBanks = cuteParams.ABMatrixRegNBanks + def CMatrixRegNBanks = cuteParams.CMatrixRegNBanks + def Trans_Load_Size = cuteParams.Trans_Load_Size + def ABMatrixReg_Total_Bandwidth = cuteParams.ABMatrixReg_Total_Bandwidth + def CMatrixReg_Total_Bandwidth = cuteParams.CMatrixReg_Total_Bandwidth + def ABMatrixReg_Total_Bandwidth_Bit = cuteParams.ABMatrixReg_Total_Bandwidth_Bit + def CMatrixReg_Total_Bandwidth_Bit = cuteParams.CMatrixReg_Total_Bandwidth_Bit + def ABMatrixRegBankSize = cuteParams.ABMatrixRegBankSize + def CMatrixRegBankSize = cuteParams.CMatrixRegBankSize + def ABMatrixRegBankNEntries = cuteParams.ABMatrixRegBankNEntries + def CMatrixRegBankNEntries = cuteParams.CMatrixRegBankNEntries + def ScaleWidth = cuteParams.ScaleWidth + def ABScaleBankNEntries = cuteParams.ABScaleBankNEntries + def ABScaleNSlices = cuteParams.ABScaleNSlices + def ResultFIFODepth = cuteParams.ResultFIFODepth + def AMemoryLoaderReadFromMemoryFIFODepth = cuteParams.AMemoryLoaderReadFromMemoryFIFODepth + def BMemoryLoaderReadFromMemoryFIFODepth = cuteParams.BMemoryLoaderReadFromMemoryFIFODepth + def CMemoryLoaderReadFromMatrixRegFIFODepth = cuteParams.CMemoryLoaderReadFromMatrixRegFIFODepth + def CMemoryLoaderReadFromMemoryFIFODepth = cuteParams.CMemoryLoaderReadFromMemoryFIFODepth + def VecTaskInstBufferDepth = cuteParams.VecTaskInstBufferDepth + def VecTaskInstBufferSize = cuteParams.VecTaskInstBufferSize + def VecTaskDataBufferDepth = cuteParams.VecTaskDataBufferDepth + def ReduceGroupSize = cuteParams.ReduceGroupSize + def EnableDifftest = cuteParams.EnableDifftest + def L2NBanks = cuteParams.L2NBanks + + def MinGroupSize = FPEparams.MinGroupSize //minimum compute group size of FPE + def MinDataTypeWidth = FPEparams.MinDataTypeWidth //minimum data type width of FPE + def ScaleElementWidth = FPEparams.ScaleElementWidth //FPE scale element width + + def cmptreelayers = FPEparams.cmptreelayers //FPE compute tree depth + def fp8cmptreelayers = FPEparams.fp8cmptreelayers + + def P3AddNum :Int = cuteParams.P3AddNum //number of P3 adders in FPE + def P2AddNum :Int = cuteParams.P2AddNum + + def FP4P0AddNum :Int = FPEparams.FP4P0AddNum + def FP4P1AddNum :Int = 16 / FP4P0AddNum + + def ScaleVecWidth(computeType : UInt) : UInt = { + val scaleVecWidth = Wire(UInt(4.W)) + scaleVecWidth := 0.U + switch(computeType){ + is (MteComputeType.Mxfp8e4m3F32) { scaleVecWidth := (ReduceWidthByte * 8 / 8 / 32).U } + is (MteComputeType.Mxfp8e5m2F32) { scaleVecWidth := (ReduceWidthByte * 8 / 8 / 32).U } + is (MteComputeType.Nvfp4F32) { scaleVecWidth := (ReduceWidthByte * 8 / 4 / 16).U } + is (MteComputeType.Mxfp4F32) { scaleVecWidth := (ReduceWidthByte * 8 / 4 / 32).U } + } + scaleVecWidth + } + + def DEBUG_FP8 = false + def DEBUG_FP4 = false } class CuteModule(implicit val p: Parameters) extends Module with CUTEImplParameters @@ -580,10 +775,10 @@ class CuteBundle(implicit val p: Parameters) extends Bundle with CUTEImplParamet class AfterOpsInterface()(implicit p: Parameters) extends CuteBundle{ - //每拍可接受一个来自CDC的与MReg和TE等宽的数据,并在自己模块内完成数据的拆分、重排、缩放、转置以及其他复杂向量任务 + //can accept per-cycle data from CDC matching MReg and TE width, and perform splitting, reordering, scaling, transposition, and other complex vector tasks inside this module val CDCDataToInterface = DecoupledIO(UInt((ResultWidth*Matrix_MN*Matrix_MN).W)) val InterfaceToCDCData = Flipped(DecoupledIO(UInt((ResultWidth*Matrix_MN*Matrix_MN).W))) - // val CDCStoreAddr = Input(UInt(log2Ceil(CMatrixRegBankNEntrys).W)) + // val CDCStoreAddr = Input(UInt(log2Ceil(CMatrixRegBankNEntries).W)) val VecInstQueueID = UInt(1.W) } @@ -592,15 +787,15 @@ class VPUInterface_Input()(implicit p: Parameters) extends CuteBundle{ val inst_uop = Output(UInt(32.W)) val inst_src0 = Output(UInt(VectorWidth.W)) val inst_src1 = Output(UInt(VectorWidth.W)) - val inst_src0_type = Output(UInt(2.W))//从寄存器还是来自输入 - val inst_src1_type = Output(UInt(2.W))//从寄存器还是来自输入 - val inst_dest_type = Output(UInt(2.W))//写回寄存器还是写回输出 - val stream_id = Output(UInt(log2Ceil(Matrix_MN*Matrix_MN+10).W))//stream data的id + val inst_src0_type = Output(UInt(2.W))//from register or input + val inst_src1_type = Output(UInt(2.W))//from register or input + val inst_dest_type = Output(UInt(2.W))//write back to register or output + val stream_id = Output(UInt(log2Ceil(Matrix_MN*Matrix_MN+10).W))//stream data ID } class VPUInterface_Output()(implicit p: Parameters) extends CuteBundle{ - val stream_id = Output(UInt(log2Ceil(Matrix_MN*Matrix_MN+10).W))//stream data的id - val stream_data = Output(UInt(VectorWidth.W))//stream data还能存一些额外的信息,这些信息也会返回,后续可以用于配置VPU的部分隐式寄存器,或者留存在VPU的的隐式寄存器中,这些寄存器是uop可见的,如下一次的scale,下一次的bias等。 + val stream_id = Output(UInt(log2Ceil(Matrix_MN*Matrix_MN+10).W))//stream data ID + val stream_data = Output(UInt(VectorWidth.W))//stream data can carry extra information that is also returned; later it can be used to configure some VPU implicit registers or kept in VPU implicit registers visible to uops, such as the next scale or bias } class VPUInterfaceIO()(implicit p: Parameters) extends CuteBundle{ @@ -611,7 +806,7 @@ class VPUInterfaceIO()(implicit p: Parameters) extends CuteBundle{ class VectorInterfaceIO()(implicit p: Parameters) extends CuteBundle{ - //每拍可接受一个来自AfterOpsInterface的与VectorWidth等宽的数据 + //can accept per-cycle data from AfterOpsInterface matching VectorWidth val VecTask = DecoupledIO(UInt(log2Ceil(VecTaskInstBufferSize).W)) val VectorDataIn = DecoupledIO(UInt((VectorWidth).W)) val VectorDataOut = Flipped(DecoupledIO(UInt((VectorWidth).W))) @@ -625,7 +820,7 @@ case object StreamStateType extends Field[UInt]{ val Reorder_DIM_M_First = 2.U(StreamStateTypeBitWidth.W) } -//描述计算时,数据流的访问顺序,transpose的时候就是N_M,不transpose的时候就是M_N +//describes the data-flow access order during computation: N_M for transpose, M_N otherwise case object CaculateStreamStateType extends Field[UInt]{ val CaculateStreamStateTypeBitWidth = 4 @@ -655,58 +850,62 @@ class AfterOpsMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_N = (UInt(MatrixRegMaxTensorDimBitSize.W)) - //接受后操作的任务,有可能是重排序,有可能是缩放,有可能是转置,有可能是其他复杂后操作任务 - val Is_Transpose = (Bool()) //是否需要转置 - val Is_Reorder_Only_Ops = (Bool()) //是否只是重排,不需要计算 - val Is_EasyScale_Only_Ops = (Bool()) //是否只是简单的缩放,不需要额外的后操作计算 - val Is_VecFIFO_Ops = (Bool()) //是否真的需要通用VecFIFO的参与 + //tasks that accept post-ops may be reorder, scaling, transpose, or other complex post-processing tasks + val Is_Transpose = (Bool()) //whether transpose is needed + val Is_Reorder_Only_Ops = (Bool()) //whether this is reorder-only, with no computation needed + val Is_EasyScale_Only_Ops = (Bool()) //whether this is simple scaling only, without extra post-op computation + val Is_VecFIFO_Ops = (Bool()) //whether the generic VecFIFO is actually needed - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged val CUTEuop = (new CUTE_uop) } -class ADCMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ - val ApplicationTensor_A = (new Bundle{ - // val ApplicationTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W)) - // val BlockTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W)) - val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) - }) - +class ADCMicroTaskConfigBaseIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegTensor_M = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_N = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegId = UInt(ABMatrixRegIdWidth.W) - val Is_Transpose = (Bool()) //是否需要转置 + val Is_Transpose = (Bool()) //whether transpose is needed - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged } -class BDCMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ - val ApplicationTensor_B = (new Bundle{ - // val ApplicationTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) - // val BlockTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) - val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) - }) +class ADCMicroTaskConfigIO()(implicit p: Parameters) extends ADCMicroTaskConfigBaseIO { + val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) +} + +class ASCMicroTaskConfigIO()(implicit p: Parameters) extends ADCMicroTaskConfigBaseIO { + val computeType = UInt(MteComputeType.ComputeTypeBitWidth.W) +} +class BDCMicroTaskConfigBaseIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegTensor_M = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_N = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegId = UInt(ABMatrixRegIdWidth.W) - val Is_Transpose = (Bool()) //是否需要转置 + val Is_Transpose = (Bool()) //whether transpose is needed - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged +} + +class BDCMicroTaskConfigIO()(implicit p: Parameters) extends BDCMicroTaskConfigBaseIO { + val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) +} + +class BSCMicroTaskConfigIO()(implicit p: Parameters) extends BDCMicroTaskConfigBaseIO { + val computeType = UInt(MteComputeType.ComputeTypeBitWidth.W) } class CDCMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ @@ -727,19 +926,19 @@ class CDCMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegTensor_N = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegId = UInt(CMatrixRegIdWidth.W) - val Is_Transpose = (Bool()) //是否需要转置 - val Is_AfterOps_Tile = (Bool()) //是否是需要执行后操作的Tile,包括转置等 + val Is_Transpose = (Bool()) //whether transpose is needed + val Is_AfterOps_Tile = (Bool()) //whether this is a tile that requires post-ops, including transpose - val Is_Reorder_Only_Ops = (Bool()) //是否只是重排,不需要计算 - val Is_EasyScale_Only_Ops = (Bool()) //是否只是简单的缩放,不需要额外的后操作计算 - val Is_VecFIFO_Ops = (Bool()) //是否真的需要通用VecFIFO的参与 + val Is_Reorder_Only_Ops = (Bool()) //whether this is reorder-only, with no computation needed + val Is_EasyScale_Only_Ops = (Bool()) //whether this is simple scaling only, without extra post-op computation + val Is_VecFIFO_Ops = (Bool()) //whether the generic VecFIFO is actually needed - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 - val MicroTask_TEComputeEndValid = Flipped(Bool())//已完成当前的TE的计算任务(但是还没有完成后操作),但是可以提前释放TE的占用 - val MicroTask_TEComputeEndReady = (Bool()) //已知晓当前的TE的计算任务完成 + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged + val MicroTask_TEComputeEndValid = Flipped(Bool())//current TE compute task completed (post-ops still pending), but TE occupancy can be released early + val MicroTask_TEComputeEndReady = (Bool()) //current TE compute task completion acknowledged val pc = Option.when(EnableDifftest) (UInt(64.W)) val coreid = Option.when(EnableDifftest) (UInt(8.W)) @@ -747,15 +946,20 @@ class CDCMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ class ApplicationTensor_A_Info()(implicit p: Parameters) extends CuteBundle{ val ApplicationTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W)) - // val BlockTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W))//可能没有了 - val ApplicationTensor_A_Stride_M = (UInt(MMUAddrWidth.W))//下一个M需要增加多少的地址偏移量 + // val BlockTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W))//may be gone already + val ApplicationTensor_A_Stride_M = (UInt(MMUAddrWidth.W))//address offset increment for the next M val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) - //细粒度控制参数增加 val HasTail = Bool() val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) } +class ApplicationScale_A_Info()(implicit p: Parameters) extends CuteBundle{ + val ApplicationScale_A_BaseVaddr = (UInt(MMUAddrWidth.W)) + val BlockScale_A_BaseVaddr = (UInt(MMUAddrWidth.W)) // main active field + val computeType = (UInt(MteComputeType.ComputeTypeBitWidth.W)) +} + class AMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val ApplicationTensor_A = new ApplicationTensor_A_Info @@ -766,26 +970,31 @@ class AMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegId = UInt(ABMatrixRegIdWidth.W) - val Conherent = (Bool()) //是否需要coherent + val Conherent = (Bool()) //whether coherence is needed + val Is_Transpose = (Bool()) //whether transpose is needed - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged val pc = Option.when(EnableDifftest) (UInt(64.W)) val coreid = Option.when(EnableDifftest) (UInt(8.W)) } -class ApplicationTensor_B_Info()(implicit p: Parameters) extends CuteBundle{ - val ApplicationTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) - val BlockTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) - val ApplicationTensor_B_Stride_N = (UInt(MMUAddrWidth.W))//下一个N需要增加多少的地址偏移量 - val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) - //细粒度控制参数增加 - val HasTail = Bool() - val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) - val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) +class ASLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ + + val ApplicationScale_A = (new ApplicationScale_A_Info) + + val MatrixRegTensor_M = (UInt(MatrixRegMaxTensorDimBitSize.W)) + val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) + + val Conherent = (Bool()) //whether coherence is needed + + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged } class BMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ @@ -796,21 +1005,53 @@ class BMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegId = UInt(ABMatrixRegIdWidth.W) - val Conherent = (Bool()) //是否需要coherent + val Conherent = (Bool()) //whether coherence is needed + val Is_Transpose = (Bool()) //whether transpose is needed - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged val pc = Option.when(EnableDifftest) (UInt(64.W)) val coreid = Option.when(EnableDifftest) (UInt(8.W)) } +class BSLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ + + val ApplicationScale_B = (new ApplicationScale_B_Info) + + val MatrixRegTensor_N = (UInt(MatrixRegMaxTensorDimBitSize.W)) + val MatrixRegTensor_K = (UInt(MatrixRegMaxTensorDimBitSize.W)) + + val Conherent = (Bool()) //whether coherence is needed + + val MicroTaskReady = Flipped(Bool())//can configure the next task + val MicroTaskValid = (Bool()) //current task configuration is valid + val MicroTaskEndValid = Flipped(Bool())//current task completed + val MicroTaskEndReady = (Bool()) //current task completion acknowledged +} + +class ApplicationTensor_B_Info()(implicit p: Parameters) extends CuteBundle{ + val ApplicationTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) + val BlockTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) + val ApplicationTensor_B_Stride_N = (UInt(MMUAddrWidth.W))//address offset increment for the next N + val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) +} + +class ApplicationScale_B_Info()(implicit p: Parameters) extends CuteBundle{ + val ApplicationScale_B_BaseVaddr = (UInt(MMUAddrWidth.W)) + val BlockScale_B_BaseVaddr = (UInt(MMUAddrWidth.W)) // main active field + val computeType = (UInt(MteComputeType.ComputeTypeBitWidth.W)) +} + class ApplicationTensor_C_Info()(implicit p: Parameters) extends CuteBundle{ val ApplicationTensor_C_BaseVaddr = (UInt(MMUAddrWidth.W)) val BlockTensor_C_BaseVaddr = (UInt(MMUAddrWidth.W)) - val ApplicationTensor_C_Stride_M = (UInt(MMUAddrWidth.W))//下一个M需要增加多少的地址偏移量 + val ApplicationTensor_C_Stride_M = (UInt(MMUAddrWidth.W))//address offset increment for the next M val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) val HasTail = Bool() val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) @@ -820,7 +1061,7 @@ class ApplicationTensor_C_Info()(implicit p: Parameters) extends CuteBundle{ class ApplicationTensor_D_Info()(implicit p: Parameters) extends CuteBundle{ val ApplicationTensor_D_BaseVaddr = (UInt(MMUAddrWidth.W)) val BlockTensor_D_BaseVaddr = (UInt(MMUAddrWidth.W)) - val ApplicationTensor_D_Stride_M = (UInt(MMUAddrWidth.W))//下一个M需要增加多少的地址偏移量 + val ApplicationTensor_D_Stride_M = (UInt(MMUAddrWidth.W))//address offset increment for the next M val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) } @@ -830,7 +1071,7 @@ class LoadTask_Info()(implicit p: Parameters) extends CuteBundle{ val Is_FullLoad = (Bool()) } class CMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ - //就是一个TensorC,是累加寄存器视角的不动的部分 + //this is Tensor C, the invariant part from the accumulator-register perspective val ApplicationTensor_C = (new ApplicationTensor_C_Info) @@ -838,30 +1079,29 @@ class CMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val LoadTaskInfo = (new LoadTask_Info) - val StoreTaskInfo = (new Bundle{ - val Is_ZeroStore = (Bool())//暂时没有传递的参数 - }) - - val Conherent = (Bool()) //是否需要coherent - val Is_Transpose = (Bool()) //是否需要转置 + val Conherent = (Bool()) //whether coherence is needed + val Is_Transpose = (Bool()) //whether transpose is needed val MatrixRegTensor_M = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_N = (UInt(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegId = UInt(CMatrixRegIdWidth.W) - val IsLoadMicroTask = (Bool()) //是否是Load任务 - val IsStoreMicroTask = (Bool()) //是否是Store任务 + val LoadMicroTaskReady = Flipped(Bool())//can configure the next load task + val LoadMicroTaskValid = (Bool()) //current load-task configuration is valid + val LoadMicroTaskEndValid = Flipped(Bool())//current load task completed + val LoadMicroTaskEndReady = (Bool()) //current load task completion acknowledged - val MicroTaskReady = Flipped(Bool())//可配置下一个任务 - val MicroTaskValid = (Bool()) //当前任务的配置信息有效 - val MicroTaskEndValid = Flipped(Bool())//已完成当前任务 - val MicroTaskEndReady = (Bool()) //已知晓当前任务完成 + val StoreMicroTaskReady = Flipped(Bool())//can configure the next store task + val StoreMicroTaskValid = (Bool()) //current store-task configuration is valid + val StoreMicroTaskEndValid = Flipped(Bool())//current store task completed + val StoreMicroTaskEndReady = (Bool()) //current store task completion acknowledged val pc = Option.when(EnableDifftest) (UInt(64.W)) val coreid = Option.when(EnableDifftest) (UInt(8.W)) } class MTEMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ - val dataType = Output(UInt(ElementDataType.DataTypeBitWidth.W)) + val MicroTaskValid = (Bool()) //current task configuration is valid + val computeType = Output(UInt(MteComputeType.ComputeTypeBitWidth.W)) } class MRegControlInfo()(implicit p: Parameters) extends CuteBundle{ @@ -873,44 +1113,60 @@ class MRegControlInfo()(implicit p: Parameters) extends CuteBundle{ val CML_MReg_ID = UInt(CMatrixRegIdWidth.W) } -//从MatrixReg中取数,要明确是从哪个bank里,取第几行的数据,然后完成数据拼接返回 -//从哪个bank里取数据,取第几行的数据,是由datacontrol模块算出来的 -//怎么在bank里编排数据,是由MemoryLoader模块填进去的 -//MemoryLoader模块和datacontrol模块都有窗口期,可以完成数据额外的一些编排如量化、反稀疏、反量化、量化重排等等 -//将MemoryLoader模块和datacontrol模块分开,是为了使用窗口期,让单读写口的MatrixReg可以独立运行 -//有没有能同时读写的SRAM啊?我能保证不写同一块数据,还是先doublebuffer吧.... -//我们考虑到回数的延迟,所以DataControl与MatrixReg之间也是有fifo的。考虑到后续的SRAM是一个简单模块,fifo要加在DataControl里,让MatrixReg尽可能简单。 +//when reading from MatrixReg, it must be clear which bank and which row are accessed, then the data is concatenated and returned +//which bank and which row to read are computed by the data-control module +//how data is arranged within banks is filled by the MemoryLoader module +//MemoryLoader and data-control modules both have a window period, which can perform extra data arrangement such as quantization, sparsification reversal, dequantization, and quantization reorder +//MemoryLoader and data-control are separated to exploit the window period so the single-read/single-write MatrixReg can run independently +//is there an SRAM that can read and write simultaneously? I can guarantee no writes hit the same location, but let us use double buffering first.... +//we account for response latency, so there is also a FIFO between DataControl and MatrixReg. Since the future SRAM is expected to be a simple module, keep the FIFO in DataControl so MatrixReg stays as simple as possible. class ABDataControlMatrixRegIO(implicit p: Parameters) extends CuteBundle{ - //bankaddr是对nbanks个bank,各自bank的行选信号,是一个vec,有nbanks个元素,每个元素是一个UInt,UInt的宽度是log2Ceil(AMatrixRegBankNLines),是输入的需要握手的数据 - val BankAddr = Flipped(DecoupledIO(Vec(ABMatrixRegNBanks, (UInt(log2Ceil(ABMatrixRegBankNEntrys).W))))) - //bankdata是对nbanks个bank,各自bank的行数据,是一个vec,有nbanks个元素,每个元素是一个UInt,UInt的宽度是ReduceWidthByte*8 + //bankaddr is the row-select signal for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is log2Ceil(ABMatrixRegBankNLines), and it is input data that requires handshaking + val BankAddr = Flipped(DecoupledIO(Vec(ABMatrixRegNBanks, (UInt(log2Ceil(ABMatrixRegBankNEntries).W))))) + //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is ReduceWidthByte*8 val Data = Valid(Vec(ABMatrixRegNBanks, UInt(ABMatrixRegEntryBitSize.W))) - //chosen是选择该MatrixReg的信号,是一个bool,我们做doublebuffer,选择其一供数,选择其一加载数据 + //chosen selects the MatrixReg; it is a Bool. We use double buffering, selecting one for output and one for loading // val Chosen = Input(Bool()) } -// 统一的AB MemoryLoader接口,支持ZeroFill功能 +class ABScaleControlMatrixRegIO(implicit p: Parameters) extends CuteBundle{ + //bankaddr is the row-select signal for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is log2Ceil(ABMatrixRegBankNLines), and it is input data that requires handshaking + val BankAddr = Flipped(DecoupledIO(UInt(log2Ceil(ABScaleBankNEntries).W))) + //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is ReduceWidthByte*8 + val Data = Valid(Vec(ABScaleNSlices, UInt((ScaleWidth * ReduceGroupSize).W))) + //chosen selects the MatrixReg; it is a Bool. We use double buffering, selecting one for output and one for loading + // val Chosen = Input(Bool()) +} + +// unified AB MemoryLoader interface, with ZeroFill support class ABMemoryLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ // When MemLoader is working, active will be true. // Otherwise, active will be false. val active = Input(Bool()) - //bankaddr是对nbanks个bank,各自bank的行选信号,是一个vec,有nbanks个元素,每个元素是一个UInt,UInt的宽度是log2Ceil(ABMatrixRegBankNLines),是输入的需要握手的数据 - val BankAddr = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(log2Ceil(ABMatrixRegBankNEntrys).W)))) - //bankdata是对nbanks个bank,各自bank的行数据,是一个vec,有nbanks个元素,每个元素是一个UInt,UInt的宽度是ReduceWidthByte*8 + //bankaddr is the row-select signal for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is log2Ceil(ABMatrixRegBankNLines), and it is input data that requires handshaking + val BankAddr = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(log2Ceil(ABMatrixRegBankNEntries).W)))) + //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is ReduceWidthByte*8 val Data = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(ABMatrixRegEntryBitSize.W)))) - // 每个bit控制对应1个byte是否写入 val ByteMask = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(ABMatrixRegEntryByteSize.W)))) } +class ABScaleLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ + //bankaddr is the row-select signal for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is log2Ceil(AScratchpadBankNLines), and it is input data that requires handshaking + val BankAddr = Flipped(Valid(UInt(log2Ceil(ABScaleBankNEntries).W))) + //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is ReduceWidthByte*8 + val Data = Flipped(Valid(Vec(ABScaleNSlices, UInt((ScaleWidth * ReduceGroupSize).W)))) + //chosen selects the ScratchPad; it is a Bool. We use double buffering, selecting one for output and one for loading + // val Chosen = Input(Bool()) +} + class CDataControlMatrixRegIO(implicit p: Parameters) extends CuteBundle{ - //bankaddr是对nbanks个bank,各自bank的行选信号,是一个vec,有nbanks个元素,每个元素是一个UInt,UInt的宽度是log2Ceil(CMatrixRegBankNLines),是输入的需要握手的数据 - val ReadBankAddr = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(log2Ceil(CMatrixRegBankNEntrys).W))))) - val WriteBankAddr = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(log2Ceil(CMatrixRegBankNEntrys).W))))) - //bankdata是对nbanks个bank,各自bank的行数据,是一个vec,有nbanks个元素,每个元素是一个UInt + //bankaddr is the row-select signal for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is log2Ceil(CMatrixRegBankNLines), and it is input data that requires handshaking + val ReadBankAddr = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(log2Ceil(CMatrixRegBankNEntries).W))))) + val WriteBankAddr = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(log2Ceil(CMatrixRegBankNEntries).W))))) + //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt val ReadResponseData = (Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryBitSize.W)))) val WriteRequestData = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryBitSize.W))))) - val WriteRequestByteMask = Flipped((Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryByteSize.W))))) - //chosen是选择该MatrixReg的信号,是一个bool,我们做doublebuffer,选择其一供数,选择其一加载数据 + //chosen selects the MatrixReg; it is a Bool. We use double buffering, selecting one for output and one for loading val ReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) val ReadWriteResponse = Output(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) // val Chosen = Input(Bool()) @@ -918,36 +1174,38 @@ class CDataControlMatrixRegIO(implicit p: Parameters) extends CuteBundle{ class CMemoryLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val ReadRequestToMatrixReg = (new Bundle{ - val BankAddr = Flipped(Vec(CMatrixRegNBanks, Valid(UInt(log2Ceil(CMatrixRegBankNEntrys).W)))) + val BankAddr = Flipped(Vec(CMatrixRegNBanks, Valid(UInt(log2Ceil(CMatrixRegBankNEntries).W)))) val ReadResponseData = ((Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryBitSize.W))))) }) val WriteRequestToMatrixReg = (new Bundle{ - val BankAddr = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(log2Ceil(CMatrixRegBankNEntrys).W))))) + val BankAddr = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(log2Ceil(CMatrixRegBankNEntries).W))))) val Data = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(CMatrixRegEntryBitSize.W))))) val ByteMask = Flipped(Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryByteSize.W)))) }) - - val ReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) - val ReadWriteResponse = Output(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) + val LoadReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) + val StoreReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) + val LoadReadWriteResponse = Output(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) + val StoreReadWriteResponse = Output(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) // val Chosen = Input(Bool()) } -//LocalMMU的接口 +//LocalMMU interface class LocalMMUIO(implicit p: Parameters) extends CuteBundle{ - //发出的访存请求 + //issued memory request val Request = Flipped(DecoupledIO(new Bundle{ val RequestVirtualAddr = UInt(MMUAddrWidth.W) val RequestConherent = Bool() val RequestData = UInt(MMUDataWidth.W) val RequestSourceID = UInt(SoureceMaxNumBitSize.W) val RequestType_isWrite = Bool() + val RequestMask = UInt(MMUMaskWidth.W) //MMU byte mask })) - //读请求分发到的TL Link的事务编号 + //transaction ID of the TL link to which the read request is dispatched val ConherentRequsetSourceID = Valid(UInt(LLCSourceMaxNumBitSize.W)) val nonConherentRequsetSourceID = Valid(UInt(MemorysourceMaxNumBitSize.W)) - //Memoryloader一定能保证收回! + //the MemoryLoader is guaranteed to receive the response back! val Response = DecoupledIO(new Bundle{ val ReseponseData = UInt(MMUDataWidth.W) val ReseponseConherent = Bool() @@ -957,21 +1215,21 @@ class LocalMMUIO(implicit p: Parameters) extends CuteBundle{ class MMU2TLIO(implicit p: Parameters) extends CuteBundle{ - //发出的访存请求 + //issued memory request val Request = Flipped(DecoupledIO(new Bundle{ val RequestPhysicalAddr = UInt(MMUAddrWidth.W) val RequestConherent = Bool() val RequestData = UInt(MMUDataWidth.W) val RequestSourceID = UInt(SoureceMaxNumBitSize.W) val RequestType_isWrite = Bool() - val RequestMask = UInt(MMUMaskWidth.W) //MMU的Mask + val RequestMask = UInt(MMUMaskWidth.W) //MMU mask val MatrixIsAcc = Bool() // false for A/B matrix (tile matrix register), true for C matrix (accumulation matrix register) })) - //读请求分发到的TL Link的事务编号 + //transaction ID of the TL link to which the read request is dispatched val ConherentRequsetSourceID = Valid(UInt(LLCSourceMaxNumBitSize.W)) val nonConherentRequsetSourceID = Valid(UInt(MemorysourceMaxNumBitSize.W)) - //Memoryloader一定能保证收回! + //the MemoryLoader is guaranteed to receive the response back! val Response = DecoupledIO(new Bundle{ val ReseponseData = UInt(MMUDataWidth.W) val ReseponseConherent = Bool() @@ -979,73 +1237,150 @@ class MMU2TLIO(implicit p: Parameters) extends CuteBundle{ }) } -class FReducePEDataType(dataType: UInt){ +class FReducePEDataType { //0:Int8, 1:FP16, 2:BF16, 3:TF32, 4:I8 * UI8, 5:UI8 * I8, 6:UI8 * UI8 - def AdataByteWidth: Int = dataType match { - case ElementDataType.DataTypeI8I8I32 => 1 - case ElementDataType.DataTypeF16F16F32 => 2 - case ElementDataType.DataTypeBF16BF16F32 => 2 - case ElementDataType.DataTypeTF32TF32F32 => 4 - case ElementDataType.DataTypeI8U8I32 => 1 - case ElementDataType.DataTypeU8I8I32 => 1 - case ElementDataType.DataTypeU8U8I32 => 1 - case _ => 0 //未定义的类型,返回0字节宽度 + def AdataByteWidth(computeType : UInt) : UInt = { + val dataByteWidth = Wire(UInt(3.W)) + dataByteWidth := 0.U + switch(computeType){ + is (MteComputeType.I8I8I32) { dataByteWidth := 1.U } + is (MteComputeType.F16F16F32) { dataByteWidth := 2.U } + is (MteComputeType.BF16BF16F32) { dataByteWidth := 2.U } + is (MteComputeType.TF32TF32F32) { dataByteWidth := 4.U } + is (MteComputeType.I8U8I32) { dataByteWidth := 1.U } + is (MteComputeType.U8I8I32) { dataByteWidth := 1.U } + is (MteComputeType.U8U8I32) { dataByteWidth := 1.U } + is (MteComputeType.Mxfp8e4m3F32) { dataByteWidth := 1.U } + is (MteComputeType.Mxfp8e5m2F32) { dataByteWidth := 1.U } + is (MteComputeType.Fp8e4m3F32) { dataByteWidth := 1.U } + is (MteComputeType.Fp8e5m2F32) { dataByteWidth := 1.U } + } + dataByteWidth + } + + def BdataByteWidth(computeType : UInt) : UInt = { + val dataByteWidth = Wire(UInt(3.W)) + dataByteWidth := 0.U + switch(computeType){ + is (MteComputeType.I8I8I32) { dataByteWidth := 1.U } + is (MteComputeType.F16F16F32) { dataByteWidth := 2.U } + is (MteComputeType.BF16BF16F32) { dataByteWidth := 2.U } + is (MteComputeType.TF32TF32F32) { dataByteWidth := 4.U } + is (MteComputeType.I8U8I32) { dataByteWidth := 1.U } + is (MteComputeType.U8I8I32) { dataByteWidth := 1.U } + is (MteComputeType.U8U8I32) { dataByteWidth := 1.U } + is (MteComputeType.Mxfp8e4m3F32) { dataByteWidth := 1.U } + is (MteComputeType.Mxfp8e5m2F32) { dataByteWidth := 1.U } + is (MteComputeType.Fp8e4m3F32) { dataByteWidth := 1.U } + is (MteComputeType.Fp8e5m2F32) { dataByteWidth := 1.U } + } + dataByteWidth + } + + def AdataBitWidth(computeType : UInt) : UInt = { + val dataBitWidth = Wire(UInt(6.W)) + dataBitWidth := 0.U + switch(computeType){ + is (MteComputeType.I8I8I32) { dataBitWidth := 8.U } + is (MteComputeType.F16F16F32) { dataBitWidth := 16.U } + is (MteComputeType.BF16BF16F32) { dataBitWidth := 16.U } + is (MteComputeType.TF32TF32F32) { dataBitWidth := 32.U } + is (MteComputeType.I8U8I32) { dataBitWidth := 8.U } + is (MteComputeType.U8I8I32) { dataBitWidth := 8.U } + is (MteComputeType.U8U8I32) { dataBitWidth := 8.U } + is (MteComputeType.Mxfp8e4m3F32) { dataBitWidth := 8.U } + is (MteComputeType.Mxfp8e5m2F32) { dataBitWidth := 8.U } + is (MteComputeType.Nvfp4F32) { dataBitWidth := 4.U } + is (MteComputeType.Mxfp4F32) { dataBitWidth := 4.U } + is (MteComputeType.Fp8e4m3F32) { dataBitWidth := 8.U } + is (MteComputeType.Fp8e5m2F32) { dataBitWidth := 8.U } + } + dataBitWidth } - def BdataByteWidth: Int = dataType match { - case ElementDataType.DataTypeI8I8I32 => 1 - case ElementDataType.DataTypeF16F16F32 => 2 - case ElementDataType.DataTypeBF16BF16F32 => 2 - case ElementDataType.DataTypeTF32TF32F32 => 4 - case ElementDataType.DataTypeI8U8I32 => 1 - case ElementDataType.DataTypeU8I8I32 => 1 - case ElementDataType.DataTypeU8U8I32 => 1 - case _ => 0 //未定义的类型,返回0字节宽度 + def BdataBitWidth(computeType : UInt) : UInt = { + val dataBitWidth = Wire(UInt(6.W)) + dataBitWidth := 0.U + switch(computeType){ + is (MteComputeType.I8I8I32) { dataBitWidth := 8.U } + is (MteComputeType.F16F16F32) { dataBitWidth := 16.U } + is (MteComputeType.BF16BF16F32) { dataBitWidth := 16.U } + is (MteComputeType.TF32TF32F32) { dataBitWidth := 32.U } + is (MteComputeType.I8U8I32) { dataBitWidth := 8.U } + is (MteComputeType.U8I8I32) { dataBitWidth := 8.U } + is (MteComputeType.U8U8I32) { dataBitWidth := 8.U } + is (MteComputeType.Mxfp8e4m3F32) { dataBitWidth := 8.U } + is (MteComputeType.Mxfp8e5m2F32) { dataBitWidth := 8.U } + is (MteComputeType.Nvfp4F32) { dataBitWidth := 4.U } + is (MteComputeType.Mxfp4F32) { dataBitWidth := 4.U } + is (MteComputeType.Fp8e4m3F32) { dataBitWidth := 8.U } + is (MteComputeType.Fp8e5m2F32) { dataBitWidth := 8.U } + } + dataBitWidth } - def CdataByteWidth: Int = dataType match { - case ElementDataType.DataTypeI8I8I32 => 4 - case ElementDataType.DataTypeF16F16F32 => 4 - case ElementDataType.DataTypeBF16BF16F32 => 4 - case ElementDataType.DataTypeTF32TF32F32 => 4 - case ElementDataType.DataTypeI8U8I32 => 4 - case ElementDataType.DataTypeU8I8I32 => 4 - case ElementDataType.DataTypeU8U8I32 => 4 - case _ => 0 //未定义的类型,返回0字节宽度 + def CdataByteWidth(computeType : UInt) : UInt = { + val dataByteWidth = Wire(UInt(3.W)) + dataByteWidth := 0.U + switch(computeType){ + is (MteComputeType.I8I8I32) { dataByteWidth := 4.U } + is (MteComputeType.F16F16F32) { dataByteWidth := 4.U } + is (MteComputeType.BF16BF16F32) { dataByteWidth := 4.U } + is (MteComputeType.TF32TF32F32) { dataByteWidth := 4.U } + is (MteComputeType.I8U8I32) { dataByteWidth := 4.U } + is (MteComputeType.U8I8I32) { dataByteWidth := 4.U } + is (MteComputeType.U8U8I32) { dataByteWidth := 4.U } + is (MteComputeType.Mxfp8e4m3F32) { dataByteWidth := 4.U } + is (MteComputeType.Mxfp8e5m2F32) { dataByteWidth := 4.U } + is (MteComputeType.Nvfp4F32) { dataByteWidth := 4.U } + is (MteComputeType.Mxfp4F32) { dataByteWidth := 4.U } + is (MteComputeType.Fp8e4m3F32) { dataByteWidth := 4.U } + is (MteComputeType.Fp8e5m2F32) { dataByteWidth := 4.U } + } + dataByteWidth } - def DdataByteWidth: Int = dataType match { - case ElementDataType.DataTypeI8I8I32 => 4 - case ElementDataType.DataTypeF16F16F32 => 4 - case ElementDataType.DataTypeBF16BF16F32 => 4 - case ElementDataType.DataTypeTF32TF32F32 => 4 - case ElementDataType.DataTypeI8U8I32 => 4 - case ElementDataType.DataTypeU8I8I32 => 4 - case ElementDataType.DataTypeU8U8I32 => 4 - case _ => 0 //未定义的类型,返回0字节宽度 + def DdataByteWidth(computeType : UInt) : UInt = { + val dataByteWidth = Wire(UInt(3.W)) + dataByteWidth := 0.U + switch(computeType){ + is (MteComputeType.I8I8I32) { dataByteWidth := 4.U } + is (MteComputeType.F16F16F32) { dataByteWidth := 4.U } + is (MteComputeType.BF16BF16F32) { dataByteWidth := 4.U } + is (MteComputeType.TF32TF32F32) { dataByteWidth := 4.U } + is (MteComputeType.I8U8I32) { dataByteWidth := 4.U } + is (MteComputeType.U8I8I32) { dataByteWidth := 4.U } + is (MteComputeType.U8U8I32) { dataByteWidth := 4.U } + is (MteComputeType.Mxfp8e4m3F32) { dataByteWidth := 4.U } + is (MteComputeType.Mxfp8e5m2F32) { dataByteWidth := 4.U } + is (MteComputeType.Nvfp4F32) { dataByteWidth := 4.U } + is (MteComputeType.Mxfp4F32) { dataByteWidth := 4.U } + is (MteComputeType.Fp8e4m3F32) { dataByteWidth := 4.U } + is (MteComputeType.Fp8e5m2F32) { dataByteWidth := 4.U } + } + dataByteWidth } + } -//数据类型的样板类 +//data type template class case object ElementDataType extends Field[UInt]{ - val DataTypeBitWidth = 3 + val DataTypeBitWidth = 4 val DataTypeUndef = 0.U(DataTypeBitWidth.W) + val DataTypeI8I8I32 = MteComputeType.I8I8I32 + val DataTypeI8U8I32 = MteComputeType.I8U8I32 + val DataTypeU8I8I32 = MteComputeType.U8I8I32 + val DataTypeU8U8I32 = MteComputeType.U8U8I32 + val DataTypeF16F16F32 = MteComputeType.F16F16F32 + val DataTypeBF16BF16F32 = MteComputeType.BF16BF16F32 + val DataTypeTF32TF32F32 = MteComputeType.TF32TF32F32 val DataTypeWidth32 = 4.U(DataTypeBitWidth.W) val DataTypeWidth16 = 2.U(DataTypeBitWidth.W) val DataTypeWidth8 = 1.U(DataTypeBitWidth.W) val DataTypeWidth4 = 7.U(DataTypeBitWidth.W) - - val DataTypeI8I8I32 = 0.U(DataTypeBitWidth.W) //I8 * I8 * I32 - val DataTypeF16F16F32 = 1.U(DataTypeBitWidth.W) //FP16 * FP16 * FP32 - val DataTypeBF16BF16F32 = 2.U(DataTypeBitWidth.W) //BF16 * BF16 * FP32 - val DataTypeTF32TF32F32 = 3.U(DataTypeBitWidth.W) //TF32 * TF32 * FP32 - val DataTypeI8U8I32 = 4.U(DataTypeBitWidth.W) //I8 * UI8 * I32 - val DataTypeU8I8I32 = 5.U(DataTypeBitWidth.W) //U8 * I8 * I32 - val DataTypeU8U8I32 = 6.U(DataTypeBitWidth.W) //U8 * U8 * I32 - } -//工作任务的样板类 +//work-task template class case object CUTETaskType extends Field[UInt]{ val CUTETaskBitWidth = 8 val TaskTypeUndef = 0.U(CUTETaskBitWidth.W) @@ -1056,19 +1391,19 @@ case object CUTETaskType extends Field[UInt]{ case object CMemoryLoaderTaskType extends Field[UInt]{ val TypeBitWidth = 4 val TaskTypeUndef = 0.U(TypeBitWidth.W) - val TaskTypeTensorZeroLoad = 1.U(TypeBitWidth.W) //直接将数据填充为0,实际上是什么也没做,默认可以写入SRAM,无视以前SRAM里面的数据即可 - val TaskTypeTensorRepeatRowLoad = 2.U(TypeBitWidth.W) //重复加载一行数据,实际上是什么也没做,默认可以写入SRAM,无视以前SRAM里面的数据即可 - val TaskTypeTensorLoad = 3.U(TypeBitWidth.W) //完整的加载所有数据 + val TaskTypeTensorZeroLoad = 1.U(TypeBitWidth.W) //fill the data with 0 directly; in practice this does nothing, and it can write to SRAM by default, ignoring previous SRAM contents + val TaskTypeTensorRepeatRowLoad = 2.U(TypeBitWidth.W) //load one row repeatedly; in practice this does nothing, and it can write to SRAM by default, ignoring previous SRAM contents + val TaskTypeTensorLoad = 3.U(TypeBitWidth.W) //load all data completely } case object MemoryOrderType extends Field[UInt]{ val MemoryOrderTypeBitWidth = 8 val OrderTypeUndef = 0.U(MemoryOrderTypeBitWidth.W) - val OrderType_Mb_Kb = 1.U(MemoryOrderTypeBitWidth.W) //在地址空间中顺序摆放的顺序, Mb在前,Kb在后 - val OrderType_Mb_Nb = 1.U(MemoryOrderTypeBitWidth.W) //在地址空间中顺序摆放的顺序, Mb在前,Nb在后 - val OrderType_Nb_Kb = 1.U(MemoryOrderTypeBitWidth.W) //在地址空间中顺序摆放的顺序, Nb在前,Kb在后 - val OrderType_Nb_Mb = 2.U(MemoryOrderTypeBitWidth.W) //在地址空间中顺序摆放的顺序, Nb在前,Mb在后 - val OrderType_Kb_Mb = 2.U(MemoryOrderTypeBitWidth.W) //在地址空间中顺序摆放的顺序, Kb在前,Mb在后 - val OrderType_Kb_Nb = 2.U(MemoryOrderTypeBitWidth.W) //在地址空间中顺序摆放的顺序, Kb在前,Nb在后 + val OrderType_Mb_Kb = 1.U(MemoryOrderTypeBitWidth.W) //ordering in address space, with Mb first and Kb after + val OrderType_Mb_Nb = 1.U(MemoryOrderTypeBitWidth.W) //ordering in address space, with Mb first and Nb after + val OrderType_Nb_Kb = 1.U(MemoryOrderTypeBitWidth.W) //ordering in address space, with Nb first and Kb after + val OrderType_Nb_Mb = 2.U(MemoryOrderTypeBitWidth.W) //ordering in address space, with Nb first and Mb after + val OrderType_Kb_Mb = 2.U(MemoryOrderTypeBitWidth.W) //ordering in address space, with Kb first and Mb after + val OrderType_Kb_Nb = 2.U(MemoryOrderTypeBitWidth.W) //ordering in address space, with Kb first and Nb after } @@ -1088,11 +1423,11 @@ class MatrixRegTaskDecode(MatrixRegTask: UInt) extends Field[UInt]{ case object MatrixRegTaskType extends Field[UInt]{ val TaskTypeBitWidth = 4 - // 对于单个MatrixReg,其并发的数据来源一共用3个,所以用3bit来表示。 - // 1. DataController对PE的输入数据的对MatrixReg读请求 - // 2. DataController将PE的输出结果送入MatrixReg写请求 - // 3. MemoryLoader对MatrixReg的写请求 - // 我们不知道MatrixReg的读写端口数量,所以用使能信号表示接受的数据来源 + // for a single MatrixReg, there are three concurrent data sources, so they are encoded with 3 bits + // 1. DataController read requests to MatrixReg for PE input data + // 2. DataController write requests that send PE outputs into MatrixReg + // 3. MemoryLoader write requests to MatrixReg + // we do not know the number of MatrixReg read/write ports, so we use enable signals to indicate accepted data sources val EnableReadFromDataController = 1.U(TaskTypeBitWidth.W) val EnableWriteFromDataController = 2.U(TaskTypeBitWidth.W) val EnableWriteFromMemoryLoader = 4.U(TaskTypeBitWidth.W) @@ -1111,10 +1446,13 @@ class MatrixRegTask(implicit p: Parameters) extends CuteBundle{ } case object LocalMMUTaskType extends Field[UInt]{ - val TaskTypeBitWidth = 2 - val TaskTypeMax = 3 + val TaskTypeBitWidth = 3 + val TaskTypeMax = 6 val AFirst = 0.U(TaskTypeBitWidth.W) val BFirst = 1.U(TaskTypeBitWidth.W) - val CFirst = 2.U(TaskTypeBitWidth.W) + val CLoadFirst = 2.U(TaskTypeBitWidth.W) + val CStoreFirst = 3.U(TaskTypeBitWidth.W) + val BScaleFirst = 4.U(TaskTypeBitWidth.W) + val AScaleFirst = 5.U(TaskTypeBitWidth.W) // val DFirst = 3.U(TaskTypeBitWidth.W) } diff --git a/src/main/scala/CUTETOP.scala b/src/main/scala/CUTETOP.scala index 8e58117..9650618 100644 --- a/src/main/scala/CUTETOP.scala +++ b/src/main/scala/CUTETOP.scala @@ -10,6 +10,7 @@ import org.chipsalliance.cde.config._ class CUTETopIO()(implicit p: Parameters) extends CuteBundle{ val mmu2llc = Flipped(new MMU2TLIO) val ctrl2top = Flipped(new YGJKControl) + val perf = Output(new CutePerfToCoreIO) } class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ val io = IO(new CUTETopIO) @@ -24,6 +25,27 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ val BML = Module(new BMemoryLoader) val CMatrixRegs = Seq.tabulate(CMatrixRegCount)(i => Module(new CMatrixReg(i))).toVector + + val ASMRegs = Option.when(cuteMatrixExtension.enableScalingFactor)( + Seq.tabulate(2)(i => Module(new ABScaleMatrixReg)).toVector + ) //双缓冲 + val ASC = Option.when(cuteMatrixExtension.enableScalingFactor)( + Module(new AScaleController) + ) + val ASL = Option.when(cuteMatrixExtension.enableScalingFactor)( + Module(new AScaleLoader) + ) + + val BSMRegs = Option.when(cuteMatrixExtension.enableScalingFactor)( + Seq.tabulate(2)(i => Module(new ABScaleMatrixReg)).toVector + ) //双缓冲 + val BSC = Option.when(cuteMatrixExtension.enableScalingFactor)( + Module(new BScaleController) + ) + val BSL = Option.when(cuteMatrixExtension.enableScalingFactor)( + Module(new BScaleLoader) + ) + val CDC = Module(new CDataController) val CML = Module(new CMemoryLoader) @@ -44,22 +66,54 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ ADC.io.ConfigInfo <> TaskCtrl.io.ADC_MicroTask_Config ADC.io.DebugInfo.DebugTimeStampe := DebugTimeStampe + ASC.foreach { asc => + //ASC的默认输入 + asc.io.FromMatrixRegIO.Data.valid := false.B + asc.io.FromMatrixRegIO.Data.bits := 0.U.asTypeOf(asc.io.FromMatrixRegIO.Data.bits) + asc.io.FromMatrixRegIO.BankAddr.ready := false.B + asc.io.ConfigInfo <> TaskCtrl.io.ASC_MicroTask_Config.get + asc.io.DebugInfo.DebugTimeStampe := DebugTimeStampe + } + //AML的默认输入 AML.io.ConfigInfo <> TaskCtrl.io.AML_MicroTask_Config AML.io.DebugInfo.DebugTimeStampe := DebugTimeStampe AML.io.LocalMMUIO <> MMU.io.ALocalMMUIO + ASL.foreach { asl => + //ASL的默认输入 + asl.io.ConfigInfo <> TaskCtrl.io.ASL_MicroTask_Config.get + asl.io.DebugInfo.DebugTimeStampe := DebugTimeStampe + asl.io.LocalMMUIO <> MMU.io.ASLocalMMUIO.get + } + //BDC的默认输入 BDC.io.FromMatrixRegIO.Data.valid := false.B BDC.io.FromMatrixRegIO.BankAddr.ready := false.B BDC.io.ConfigInfo <> TaskCtrl.io.BDC_MicroTask_Config BDC.io.DebugInfo.DebugTimeStampe := DebugTimeStampe + BSC.foreach { bsc => + //BSC的默认输入 + bsc.io.FromMatrixRegIO.Data.valid := false.B + bsc.io.FromMatrixRegIO.Data.bits := 0.U.asTypeOf(bsc.io.FromMatrixRegIO.Data.bits) + bsc.io.FromMatrixRegIO.BankAddr.ready := false.B + bsc.io.ConfigInfo <> TaskCtrl.io.BSC_MicroTask_Config.get + bsc.io.DebugInfo.DebugTimeStampe := DebugTimeStampe + } + //BML的默认输入 BML.io.ConfigInfo <> TaskCtrl.io.BML_MicroTask_Config BML.io.DebugInfo.DebugTimeStampe := DebugTimeStampe BML.io.LocalMMUIO <> MMU.io.BLocalMMUIO + BSL.foreach { bsl => + //BSL的默认输入 + bsl.io.ConfigInfo <> TaskCtrl.io.BSL_MicroTask_Config.get + bsl.io.DebugInfo.DebugTimeStampe := DebugTimeStampe + bsl.io.LocalMMUIO <> MMU.io.BSLocalMMUIO.get + } + //CDC的默认输入 CDC.io.FromMatrixRegIO.ReadResponseData := 0.U.asTypeOf(CDC.io.FromMatrixRegIO.ReadResponseData) CDC.io.FromMatrixRegIO.ReadWriteResponse := 0.U.asTypeOf(CDC.io.FromMatrixRegIO.ReadWriteResponse) @@ -69,25 +123,57 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ //CML的默认输入 CML.io.ConfigInfo <> TaskCtrl.io.CML_MicroTask_Config CML.io.DebugInfo.DebugTimeStampe := DebugTimeStampe - CML.io.LocalMMUIO <> MMU.io.CLocalMMUIO - CML.io.ToMatrixRegIO.ReadWriteResponse := 0.U + CML.io.LoadLocalMMUIO <> MMU.io.CLoadLocalMMUIO + CML.io.StoreLocalMMUIO <> MMU.io.CStoreLocalMMUIO + CML.io.ToMatrixRegIO.LoadReadWriteResponse := 0.U + CML.io.ToMatrixRegIO.StoreReadWriteResponse := 0.U CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.ReadResponseData := 0.U.asTypeOf(CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.ReadResponseData) - //MTE的默认输入 MTE.io.VectorA <> ADC.io.VectorA MTE.io.VectorB <> BDC.io.VectorB + MTE.io.ScaleA.zip(ASC).foreach { case (scaleA, asc) => + scaleA <> asc.io.ScaleA + } + MTE.io.ScaleB.zip(BSC).foreach { case (scaleB, bsc) => + scaleB <> bsc.io.ScaleB + } MTE.io.MatrixC <> CDC.io.Matrix_C MTE.io.MatrixD <> CDC.io.ResultMatrix_D MTE.io.ConfigInfo <> TaskCtrl.io.MTE_MicroTask_Config MTE.io.DebugInfo.DebugTimeStampe := DebugTimeStampe ADC.io.ComputeGo := MTE.io.ComputeGo BDC.io.ComputeGo := MTE.io.ComputeGo + ASC.foreach(_.io.ComputeGo := MTE.io.ComputeGo) + BSC.foreach(_.io.ComputeGo := MTE.io.ComputeGo) CDC.io.ComputeGo := MTE.io.ComputeGo //后续需要连入CPU的MMU或者IOMMU MMU.io.LastLevelCacheTLIO <> io.mmu2llc io.ctrl2top <> TaskCtrl.io.ygjkctrl + val perf = WireInit(0.U.asTypeOf(new CutePerfToCoreIO)) + perf.backendEvents(0) := TaskCtrl.io.perfProbe.ownedWork + perf.backendEvents(1) := TaskCtrl.io.perfProbe.retire + perf.backendEvents(2) := TaskCtrl.io.perfProbe.compDone + perf.backendEvents(3) := TaskCtrl.io.perfProbe.releaseDone + perf.backendEvents(4) := TaskCtrl.io.perfProbe.mteActive + perf.backendEvents(5) := TaskCtrl.io.perfProbe.mmaNonfpDone + perf.backendEvents(6) := TaskCtrl.io.perfProbe.mmaFp16Done + perf.backendEvents(7) := TaskCtrl.io.perfProbe.mmaBf16Done + perf.backendEvents(8) := TaskCtrl.io.perfProbe.mmaTf32Done + perf.memEvents(0) := TaskCtrl.io.perfProbe.loadADone + perf.memEvents(1) := TaskCtrl.io.perfProbe.loadBDone + perf.memEvents(2) := TaskCtrl.io.perfProbe.loadCDone + perf.memEvents(3) := TaskCtrl.io.perfProbe.storeDone + perf.memEvents(4) := TaskCtrl.io.perfProbe.amlActive + perf.memEvents(5) := TaskCtrl.io.perfProbe.bmlActive + perf.memEvents(6) := TaskCtrl.io.perfProbe.cmlLoadActive + perf.memEvents(7) := TaskCtrl.io.perfProbe.cmlStoreActive + perf.memEvents(8) := MMU.io.perfProbe.rdReq + perf.memEvents(9) := MMU.io.perfProbe.wrReq + perf.memEvents(10) := MMU.io.perfProbe.rd32BReq + perf.memEvents(11) := MMU.io.perfProbe.wr32BReq + io.perf := RegNext(perf, 0.U.asTypeOf(new CutePerfToCoreIO)) //给每个 MatrixReg 的输入进行默认赋值 @@ -102,10 +188,80 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ByteMask := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ByteMask) } + if (cuteMatrixExtension.enableScalingFactor) { + // AB Scale Regs + (ASMRegs.get ++ BSMRegs.get).foreach { reg => + reg.io.FromScaleController.BankAddr.valid := false.B + reg.io.FromScaleController.BankAddr.bits := 0.U.asTypeOf(reg.io.FromScaleController.BankAddr.bits) + reg.io.FromScaleLoader.BankAddr := 0.U.asTypeOf(reg.io.FromScaleLoader.BankAddr) + reg.io.FromScaleLoader.Data := 0.U.asTypeOf(reg.io.FromScaleLoader.Data) + } + } + // C MatrixReg for (i <- 0 until CMatrixRegCount){ CMatrixRegs(i).io.MatrixRegIO.FromDataController.ReadWriteRequest := 0.U - CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ReadWriteRequest := 0.U + CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.LoadReadWriteRequest := 0.U + CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.StoreReadWriteRequest := 0.U + } + + // ============================================ + // A Scale MatrixReg 路由逻辑 (双缓冲) + // ============================================ + def connectScaleControlToRegs( + spadId: UInt, + ScaleCtrlIO: ABScaleControlMatrixRegIO, + ScaleRegs: Seq[ABScaleMatrixReg] + ): Unit = { + // ASC 选择 ScaleRegs,根据 SpadId 选择对应的 MatrixReg + for (spadIdx <- 0 until ScaleRegs.length) { + val dest = ScaleRegs(spadIdx).io.FromScaleController + val ascSel = spadId === spadIdx.U + when(ascSel) { + dest.BankAddr.valid := ScaleCtrlIO.BankAddr.valid + dest.BankAddr.bits := ScaleCtrlIO.BankAddr.bits + }.otherwise { + dest.BankAddr.valid := false.B + dest.BankAddr.bits := DontCare + } + } + + // ASC 接收 ScaleRegs 返回的数据 + val sels = ScaleRegs.indices.map(spadId === _.U) + ScaleCtrlIO.BankAddr.ready := Mux1H(sels zip ScaleRegs.map(_.io.FromScaleController.BankAddr.ready)) + ScaleCtrlIO.Data.valid := Mux1H(sels zip ScaleRegs.map(_.io.FromScaleController.Data.valid)) + ScaleCtrlIO.Data.bits := Mux1H(sels zip ScaleRegs.map(_.io.FromScaleController.Data.bits)) + } + + def connectScaleLoaderToRegs( + spadId: UInt, + ScaleLoaderIO: ABScaleLoaderMatrixRegIO, + ScaleRegs: Seq[ABScaleMatrixReg] + ): Unit = { + // ASL 选择 ScaleRegs,根据 SpadId 选择对应的 MatrixReg + for (spadIdx <- 0 until ScaleRegs.length) { + val dest = ScaleRegs(spadIdx).io.FromScaleLoader + val aslSel = spadId === spadIdx.U + when(aslSel) { + dest.BankAddr := ScaleLoaderIO.BankAddr + dest.Data := ScaleLoaderIO.Data + }.otherwise { + dest.BankAddr.valid := false.B + dest.BankAddr.bits := DontCare + dest.Data.valid := false.B + dest.Data.bits := DontCare + } + } + } + + ASC.zip(ASL).zip(ASMRegs).foreach { case ((asc, asl), asmRegs) => + connectScaleControlToRegs(asc.io.SpadId, asc.io.FromMatrixRegIO, asmRegs) + connectScaleLoaderToRegs(asl.io.SpadId, asl.io.ToMatrixRegIO, asmRegs) + } + + BSC.zip(BSL).zip(BSMRegs).foreach { case ((bsc, bsl), bsmRegs) => + connectScaleControlToRegs(bsc.io.SpadId, bsc.io.FromMatrixRegIO, bsmRegs) + connectScaleLoaderToRegs(bsl.io.SpadId, bsl.io.ToMatrixRegIO, bsmRegs) } def disableABLoaderPort(dest: ABMemoryLoaderMatrixRegIO): Unit = { @@ -198,45 +354,49 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ for (regIdx <- 0 until CMatrixRegCount) { val dest = CMatrixRegs(regIdx).io.MatrixRegIO.FromMemoryLoader - - when(CML.io.MatrixRegId === regIdx.U) { - for (bank <- 0 until CMatrixRegNBanks) { - dest.ReadRequestToMatrixReg.BankAddr(bank).valid := CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).valid - dest.WriteRequestToMatrixReg.BankAddr(bank).valid := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).valid - dest.WriteRequestToMatrixReg.Data(bank).valid := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).valid - dest.WriteRequestToMatrixReg.ByteMask(bank).valid := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).valid - } - dest.ReadWriteRequest := CML.io.ToMatrixRegIO.ReadWriteRequest - }.otherwise { - for (bank <- 0 until CMatrixRegNBanks) { - dest.ReadRequestToMatrixReg.BankAddr(bank).valid := false.B - dest.WriteRequestToMatrixReg.BankAddr(bank).valid := false.B - dest.WriteRequestToMatrixReg.Data(bank).valid := false.B - dest.WriteRequestToMatrixReg.ByteMask(bank).valid := false.B - } + val loadSel = CML.io.LoadMatrixRegId === regIdx.U + val storeSel = CML.io.StoreMatrixRegId === regIdx.U + val loadWriteReq = CML.io.ToMatrixRegIO.LoadReadWriteRequest(MatrixRegTaskType.WriteFromMemoryLoaderIndex) + val storeReadReq = CML.io.ToMatrixRegIO.StoreReadWriteRequest(MatrixRegTaskType.ReadFromMemoryLoaderIndex) + + when(loadSel && storeSel && loadWriteReq && storeReadReq) { + assert(false.B, cf"[CUTETop] CML load/store target the same C MatrixReg($regIdx) in one cycle") } for (bank <- 0 until CMatrixRegNBanks) { + dest.ReadRequestToMatrixReg.BankAddr(bank).valid := storeSel && CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).valid + dest.WriteRequestToMatrixReg.BankAddr(bank).valid := loadSel && CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).valid + dest.WriteRequestToMatrixReg.Data(bank).valid := loadSel && CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).valid + dest.WriteRequestToMatrixReg.ByteMask(bank).valid := loadSel && CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).valid dest.ReadRequestToMatrixReg.BankAddr(bank).bits := CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).bits dest.WriteRequestToMatrixReg.BankAddr(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).bits dest.WriteRequestToMatrixReg.Data(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).bits dest.WriteRequestToMatrixReg.ByteMask(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).bits } + + val loadReqForThisReg = Mux(loadSel, CML.io.ToMatrixRegIO.LoadReadWriteRequest, 0.U) + val storeReqForThisReg = Mux(storeSel, CML.io.ToMatrixRegIO.StoreReadWriteRequest, 0.U) + dest.LoadReadWriteRequest := loadReqForThisReg + dest.StoreReadWriteRequest := storeReqForThisReg } - val cmlSelVec = VecInit((0 until CMatrixRegCount).map(i => CML.io.MatrixRegId === i.U)) + val cmlLoadSelVec = VecInit((0 until CMatrixRegCount).map(i => CML.io.LoadMatrixRegId === i.U)) + val cmlStoreSelVec = VecInit((0 until CMatrixRegCount).map(i => CML.io.StoreMatrixRegId === i.U)) for (bank <- 0 until CMatrixRegNBanks) { val readRespValidChoices = (0 until CMatrixRegCount).map(i => - cmlSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(bank).valid + cmlStoreSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(bank).valid ) val readRespBitsChoices = (0 until CMatrixRegCount).map(i => - cmlSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(bank).bits + cmlStoreSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ReadRequestToMatrixReg.ReadResponseData(bank).bits ) CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.ReadResponseData(bank).valid := Mux1H(readRespValidChoices) CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.ReadResponseData(bank).bits := Mux1H(readRespBitsChoices) } - val cmlRwRespChoices = (0 until CMatrixRegCount).map(i => cmlSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ReadWriteResponse) - CML.io.ToMatrixRegIO.ReadWriteResponse := Mux1H(cmlRwRespChoices) + + val cmlLoadRwRespChoices = (0 until CMatrixRegCount).map(i => cmlLoadSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.LoadReadWriteResponse) + val cmlStoreRwRespChoices = (0 until CMatrixRegCount).map(i => cmlStoreSelVec(i) -> CMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.StoreReadWriteResponse) + CML.io.ToMatrixRegIO.LoadReadWriteResponse := Mux1H(cmlLoadRwRespChoices) + CML.io.ToMatrixRegIO.StoreReadWriteResponse := Mux1H(cmlStoreRwRespChoices) for (regIdx <- 0 until CMatrixRegCount) { val dest = CMatrixRegs(regIdx).io.MatrixRegIO.FromDataController @@ -246,7 +406,6 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadBankAddr(bank).valid := CDC.io.FromMatrixRegIO.ReadBankAddr(bank).valid dest.WriteBankAddr(bank).valid := CDC.io.FromMatrixRegIO.WriteBankAddr(bank).valid dest.WriteRequestData(bank).valid := CDC.io.FromMatrixRegIO.WriteRequestData(bank).valid - dest.WriteRequestByteMask(bank).valid := CDC.io.FromMatrixRegIO.WriteRequestByteMask(bank).valid } dest.ReadWriteRequest := CDC.io.FromMatrixRegIO.ReadWriteRequest }.otherwise { @@ -254,7 +413,6 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadBankAddr(bank).valid := false.B dest.WriteBankAddr(bank).valid := false.B dest.WriteRequestData(bank).valid := false.B - dest.WriteRequestByteMask(bank).valid := false.B } } @@ -262,7 +420,6 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadBankAddr(bank).bits := CDC.io.FromMatrixRegIO.ReadBankAddr(bank).bits dest.WriteBankAddr(bank).bits := CDC.io.FromMatrixRegIO.WriteBankAddr(bank).bits dest.WriteRequestData(bank).bits := CDC.io.FromMatrixRegIO.WriteRequestData(bank).bits - dest.WriteRequestByteMask(bank).bits := CDC.io.FromMatrixRegIO.WriteRequestByteMask(bank).bits } } @@ -282,5 +439,3 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ ) CDC.io.FromMatrixRegIO.ReadWriteResponse := Mux1H(cdcRwRespChoices) } - - diff --git a/src/main/scala/LocalMMU.scala b/src/main/scala/LocalMMU.scala index 147ea27..222c135 100644 --- a/src/main/scala/LocalMMU.scala +++ b/src/main/scala/LocalMMU.scala @@ -9,118 +9,147 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ val io = IO(new Bundle{ val ALocalMMUIO = (new LocalMMUIO) val BLocalMMUIO = (new LocalMMUIO) - val CLocalMMUIO = (new LocalMMUIO) + val BSLocalMMUIO = Option.when(cuteMatrixExtension.enableScalingFactor)(new LocalMMUIO) + val ASLocalMMUIO = Option.when(cuteMatrixExtension.enableScalingFactor)(new LocalMMUIO) + val CLoadLocalMMUIO = (new LocalMMUIO) + val CStoreLocalMMUIO = (new LocalMMUIO) val LastLevelCacheTLIO = Flipped(new MMU2TLIO) + val perfProbe = Output(new LocalMMUPerfProbe) }) - //比较低的性能方式,轮询的方式,但是doublebuffer是可以的 - //访存流的设计,可以通过设计一些标志位来实现~! - //设计当前访存流处于哪种优先级。通过一些参数可以控制这个访存流。 - - val FirstRequestIndex = RegInit(0.U(log2Ceil(LocalMMUTaskType.TaskTypeMax).W))//只有3个请求来源,2位就够用了 - FirstRequestIndex := WrapInc(FirstRequestIndex, LocalMMUTaskType.TaskTypeMax) - // //printf(p"FirstRequestIndex ${FirstRequestIndex}\n") - //选择一个离FirstRequestIndex最近的请求 + val FirstRequestIndex = RegInit(0.U(LocalMMUTaskType.TaskTypeBitWidth.W)) val FirstIndex = FirstRequestIndex - val SecIndex = WrapInc(FirstRequestIndex, LocalMMUTaskType.TaskTypeMax) + val SecIndex = WrapInc(FirstIndex, LocalMMUTaskType.TaskTypeMax) val ThirdIndex = WrapInc(SecIndex, LocalMMUTaskType.TaskTypeMax) - // val ForthIndex = WrapInc(ThirdIndex, LocalMMUTaskType.TaskTypeMax) + val FourthIndex = WrapInc(ThirdIndex, LocalMMUTaskType.TaskTypeMax) + val FifthIndex = WrapInc(FourthIndex, LocalMMUTaskType.TaskTypeMax) + val SixthIndex = WrapInc(FifthIndex, LocalMMUTaskType.TaskTypeMax) - //假设目前只有一个LLC的访存端口。所以只能选择一个LLC的访存请求,进行服务。 - //循环服务和连续顺序服务,要考虑Cache连续读和Memory连续读的性能啊!! - //如果这样循环发出请求,可能会导致访存性能下降了,尤其是Memory,他是有bank切换和line切换的代价的!! - //这里先写一个循环的,后面再修改成局部连续的 - val AllRequestValid = Cat(io.CLocalMMUIO.Request.valid, io.BLocalMMUIO.Request.valid, io.ALocalMMUIO.Request.valid) + // Bit index follows LocalMMUTaskType encoding directly. + val AllRequestValid = Cat( + io.ASLocalMMUIO.map(_.Request.valid).getOrElse(false.B), + io.BSLocalMMUIO.map(_.Request.valid).getOrElse(false.B), + io.CStoreLocalMMUIO.Request.valid, + io.CLoadLocalMMUIO.Request.valid, + io.BLocalMMUIO.Request.valid, + io.ALocalMMUIO.Request.valid + ) val HasRequest = AllRequestValid.orR val ChoseIndex_0 = Mux(AllRequestValid(FirstIndex), FirstIndex, Mux(AllRequestValid(SecIndex), SecIndex, - Mux(AllRequestValid(ThirdIndex), ThirdIndex,LocalMMUTaskType.TaskTypeMax.U))) - - //如果是AFirst,就服务A,如果是B,就服务B,如果是C,就服务C + Mux(AllRequestValid(ThirdIndex), ThirdIndex, + Mux(AllRequestValid(FourthIndex), FourthIndex, + Mux(AllRequestValid(FifthIndex), FifthIndex, + Mux(AllRequestValid(SixthIndex), SixthIndex, LocalMMUTaskType.TaskTypeMax.U)))))) - //这里的设计是,只有一个LLC的访存端口,所以只能选择一个访存请求,进行服务。 - //如果有多个访存端口,就可以同时服务多个访存请求。 + FirstRequestIndex := WrapInc(ChoseIndex_0, LocalMMUTaskType.TaskTypeMax) io.ALocalMMUIO.Request.ready := false.B io.BLocalMMUIO.Request.ready := false.B - io.CLocalMMUIO.Request.ready := false.B + io.ASLocalMMUIO.foreach(_.Request.ready := false.B) + io.BSLocalMMUIO.foreach(_.Request.ready := false.B) + io.CLoadLocalMMUIO.Request.ready := false.B + io.CStoreLocalMMUIO.Request.ready := false.B + io.ALocalMMUIO.ConherentRequsetSourceID.valid := false.B io.BLocalMMUIO.ConherentRequsetSourceID.valid := false.B - io.CLocalMMUIO.ConherentRequsetSourceID.valid := false.B + io.ASLocalMMUIO.foreach(_.ConherentRequsetSourceID.valid := false.B) + io.BSLocalMMUIO.foreach(_.ConherentRequsetSourceID.valid := false.B) + io.CLoadLocalMMUIO.ConherentRequsetSourceID.valid := false.B + io.CStoreLocalMMUIO.ConherentRequsetSourceID.valid := false.B + io.ALocalMMUIO.ConherentRequsetSourceID.bits := DontCare io.BLocalMMUIO.ConherentRequsetSourceID.bits := DontCare - io.CLocalMMUIO.ConherentRequsetSourceID.bits := DontCare + io.ASLocalMMUIO.foreach(_.ConherentRequsetSourceID.bits := DontCare) + io.BSLocalMMUIO.foreach(_.ConherentRequsetSourceID.bits := DontCare) + io.CLoadLocalMMUIO.ConherentRequsetSourceID.bits := DontCare + io.CStoreLocalMMUIO.ConherentRequsetSourceID.bits := DontCare + io.ALocalMMUIO.nonConherentRequsetSourceID.valid := false.B io.BLocalMMUIO.nonConherentRequsetSourceID.valid := false.B - io.CLocalMMUIO.nonConherentRequsetSourceID.valid := false.B + io.ASLocalMMUIO.foreach(_.nonConherentRequsetSourceID.valid := false.B) + io.BSLocalMMUIO.foreach(_.nonConherentRequsetSourceID.valid := false.B) + io.CLoadLocalMMUIO.nonConherentRequsetSourceID.valid := false.B + io.CStoreLocalMMUIO.nonConherentRequsetSourceID.valid := false.B + io.ALocalMMUIO.nonConherentRequsetSourceID.bits := DontCare io.BLocalMMUIO.nonConherentRequsetSourceID.bits := DontCare - io.CLocalMMUIO.nonConherentRequsetSourceID.bits := DontCare - // io.DLocalMMUIO.Request.ready := false.B - //如果sourceid是valid,则LLC可以接受这个请求,开始送入到LLC的访存端口 - //这里得到谁先服务,送入LLC的访存端口,如果这里需要切流水也简单,提前锁定sourceid即可,将TLnode内的sourceid锁定的逻辑放到这里来写 - // val sourceid2port = VecInit(Seq.fill(LLCSourceMaxNum)(RegInit(0.U(log2Ceil(LocalMMUTaskType.TaskTypeMax).W)))) - val sourceid2port = RegInit(VecInit(Seq.fill(LLCSourceMaxNum)(0.U(log2Ceil(LocalMMUTaskType.TaskTypeMax).W)))) - //输出一下sourceid2port的数据类型 - println("[LocalMMU] sourceid2port: " + sourceid2port) + io.ASLocalMMUIO.foreach(_.nonConherentRequsetSourceID.bits := DontCare) + io.BSLocalMMUIO.foreach(_.nonConherentRequsetSourceID.bits := DontCare) + io.CLoadLocalMMUIO.nonConherentRequsetSourceID.bits := DontCare + io.CStoreLocalMMUIO.nonConherentRequsetSourceID.bits := DontCare + val sourceid2port = RegInit(VecInit(Seq.fill(LLCSourceMaxNum)(0.U(log2Ceil(LocalMMUTaskType.TaskTypeMax).W)))) io.LastLevelCacheTLIO.Request.valid := false.B io.LastLevelCacheTLIO.Request.bits := DontCare io.LastLevelCacheTLIO.Response.ready := false.B - // //输出ABC的信息和valid和hasrequest - // printf(p"ALocalMMUIO ${io.ALocalMMUIO.Request.bits} request_valid ${io.ALocalMMUIO.Request.valid} ${io.ALocalMMUIO.Request.ready} ${io.ALocalMMUIO.Response}\n") - // printf(p"BLocalMMUIO ${io.BLocalMMUIO.Request.bits} request_valid ${io.BLocalMMUIO.Request.valid} ${io.BLocalMMUIO.Request.ready} ${io.BLocalMMUIO.Response}\n") - // printf(p"CLocalMMUIO ${io.CLocalMMUIO.Request.bits} request_valid ${io.CLocalMMUIO.Request.valid} ${io.CLocalMMUIO.Request.ready} ${io.CLocalMMUIO.Response}\n") - // //输出io.LastLevelCacheTLIO.ConherentRequsetSourceID - // printf(p"ConherentRequsetSourceID ${io.LastLevelCacheTLIO.ConherentRequsetSourceID}\n") - // printf(p"HasRequest ${HasRequest}\n") - // printf(p"ChoseIndex_0 ${ChoseIndex_0}\n") - // val last_sourceid = RegInit(0.U(LLCSourceMaxNumBitSize.W)) - - - //如果HasRequest,输出其他两个信息 - when(HasRequest) - { - //输出io.LastLevelCacheTLIO.ConherentRequsetSourceID.valid - //输出io.LastLevelCacheTLIO.Request.ready - // printf(p"[localmmu]io.LastLevelCacheTLIO.ConherentRequsetSourceID.valid ${io.LastLevelCacheTLIO.ConherentRequsetSourceID.valid} io.LastLevelCacheTLIO.Request.ready ${io.LastLevelCacheTLIO.Request.ready}\n") - } + val selectedRequestMask = WireDefault(Fill(MMUMaskWidth, 1.U(1.W))) when(io.LastLevelCacheTLIO.ConherentRequsetSourceID.valid && HasRequest) { - // printf(p"last_sourceid ${last_sourceid} last_sourceid2port ${sourceid2port(last_sourceid)}\n") - // last_sourceid := io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits switch(ChoseIndex_0) { is(LocalMMUTaskType.AFirst) { io.ALocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready io.ALocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.ALocalMMUIO.Request.bits.RequestVirtualAddr io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := false.B + selectedRequestMask := io.ALocalMMUIO.Request.bits.RequestMask sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.AFirst - io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B // A matrix is tile matrix register + io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B + } + is(LocalMMUTaskType.AScaleFirst){ + io.ASLocalMMUIO.foreach { asLocalMMUIO => + asLocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready + asLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID + io.LastLevelCacheTLIO.Request.bits.RequestData := asLocalMMUIO.Request.bits.RequestData + io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := asLocalMMUIO.Request.bits.RequestType_isWrite + sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.AScaleFirst + io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B + } } is(LocalMMUTaskType.BFirst) { io.BLocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready io.BLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.BLocalMMUIO.Request.bits.RequestVirtualAddr io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := false.B + selectedRequestMask := io.BLocalMMUIO.Request.bits.RequestMask sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.BFirst - io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B // B matrix is tile matrix register + io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B } - is(LocalMMUTaskType.CFirst) { - io.CLocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready - io.CLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID - io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.CLocalMMUIO.Request.bits.RequestVirtualAddr - io.LastLevelCacheTLIO.Request.bits.RequestData := io.CLocalMMUIO.Request.bits.RequestData - io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := io.CLocalMMUIO.Request.bits.RequestType_isWrite - sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.CFirst - io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := true.B // C matrix is accumulation matrix register + is(LocalMMUTaskType.BScaleFirst) { + io.BSLocalMMUIO.foreach { bsLocalMMUIO => + bsLocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready + bsLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID + io.LastLevelCacheTLIO.Request.bits.RequestData := bsLocalMMUIO.Request.bits.RequestData + io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := bsLocalMMUIO.Request.bits.RequestType_isWrite + sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.BScaleFirst + io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B + } + } + is(LocalMMUTaskType.CLoadFirst) { + io.CLoadLocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready + io.CLoadLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID + io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.CLoadLocalMMUIO.Request.bits.RequestVirtualAddr + io.LastLevelCacheTLIO.Request.bits.RequestData := io.CLoadLocalMMUIO.Request.bits.RequestData + io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := io.CLoadLocalMMUIO.Request.bits.RequestType_isWrite + selectedRequestMask := io.CLoadLocalMMUIO.Request.bits.RequestMask + sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.CLoadFirst + io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := true.B + } + is(LocalMMUTaskType.CStoreFirst) { + io.CStoreLocalMMUIO.Request.ready := io.LastLevelCacheTLIO.Request.ready + io.CStoreLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID + io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.CStoreLocalMMUIO.Request.bits.RequestVirtualAddr + io.LastLevelCacheTLIO.Request.bits.RequestData := io.CStoreLocalMMUIO.Request.bits.RequestData + io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := io.CStoreLocalMMUIO.Request.bits.RequestType_isWrite + selectedRequestMask := io.CStoreLocalMMUIO.Request.bits.RequestMask + sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.CStoreFirst + io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := true.B } } - // TODO: Support Request Mask - io.LastLevelCacheTLIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + io.LastLevelCacheTLIO.Request.bits.RequestMask := selectedRequestMask io.LastLevelCacheTLIO.Request.bits.RequestConherent := true.B io.LastLevelCacheTLIO.Request.bits.RequestSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits io.LastLevelCacheTLIO.Request.valid := true.B @@ -128,23 +157,46 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ io.ALocalMMUIO.Response.bits := io.LastLevelCacheTLIO.Response.bits io.BLocalMMUIO.Response.bits := io.LastLevelCacheTLIO.Response.bits - io.CLocalMMUIO.Response.bits := io.LastLevelCacheTLIO.Response.bits + io.ASLocalMMUIO.foreach(_.Response.bits := io.LastLevelCacheTLIO.Response.bits) + io.BSLocalMMUIO.foreach(_.Response.bits := io.LastLevelCacheTLIO.Response.bits) + io.CLoadLocalMMUIO.Response.bits := io.LastLevelCacheTLIO.Response.bits + io.CStoreLocalMMUIO.Response.bits := io.LastLevelCacheTLIO.Response.bits + io.ALocalMMUIO.Response.valid := false.B io.BLocalMMUIO.Response.valid := false.B - io.CLocalMMUIO.Response.valid := false.B + io.ASLocalMMUIO.foreach(_.Response.valid := false.B) + io.BSLocalMMUIO.foreach(_.Response.valid := false.B) + io.CLoadLocalMMUIO.Response.valid := false.B + io.CStoreLocalMMUIO.Response.valid := false.B switch(sourceid2port(io.LastLevelCacheTLIO.Response.bits.ReseponseSourceID)) { is(LocalMMUTaskType.AFirst) { io.ALocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid io.LastLevelCacheTLIO.Response.ready := io.ALocalMMUIO.Response.ready } + is(LocalMMUTaskType.AScaleFirst) { + io.ASLocalMMUIO.foreach { asLocalMMUIO => + asLocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid + io.LastLevelCacheTLIO.Response.ready := asLocalMMUIO.Response.ready + } + } is(LocalMMUTaskType.BFirst) { io.BLocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid io.LastLevelCacheTLIO.Response.ready := io.BLocalMMUIO.Response.ready } - is(LocalMMUTaskType.CFirst) { - io.CLocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid - io.LastLevelCacheTLIO.Response.ready := io.CLocalMMUIO.Response.ready + is(LocalMMUTaskType.BScaleFirst){ + io.BSLocalMMUIO.foreach { bsLocalMMUIO => + bsLocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid + io.LastLevelCacheTLIO.Response.ready := bsLocalMMUIO.Response.ready + } + } + is(LocalMMUTaskType.CLoadFirst) { + io.CLoadLocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid + io.LastLevelCacheTLIO.Response.ready := io.CLoadLocalMMUIO.Response.ready + } + is(LocalMMUTaskType.CStoreFirst) { + io.CStoreLocalMMUIO.Response.valid := io.LastLevelCacheTLIO.Response.valid + io.LastLevelCacheTLIO.Response.ready := io.CStoreLocalMMUIO.Response.ready } } @@ -152,6 +204,19 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ XSPerfAccumulate("CUTE_MMU_A_wr_request", io.ALocalMMUIO.Request.fire & io.ALocalMMUIO.Request.bits.RequestType_isWrite) XSPerfAccumulate("CUTE_MMU_B_rd_request", io.BLocalMMUIO.Request.fire & !io.BLocalMMUIO.Request.bits.RequestType_isWrite) XSPerfAccumulate("CUTE_MMU_B_wr_request", io.BLocalMMUIO.Request.fire & io.BLocalMMUIO.Request.bits.RequestType_isWrite) - XSPerfAccumulate("CUTE_MMU_C_rd_request", io.CLocalMMUIO.Request.fire & !io.CLocalMMUIO.Request.bits.RequestType_isWrite) - XSPerfAccumulate("CUTE_MMU_C_wr_request", io.CLocalMMUIO.Request.fire & io.CLocalMMUIO.Request.bits.RequestType_isWrite) + + val cLoadRd = io.CLoadLocalMMUIO.Request.fire & !io.CLoadLocalMMUIO.Request.bits.RequestType_isWrite + val cStoreRd = io.CStoreLocalMMUIO.Request.fire & !io.CStoreLocalMMUIO.Request.bits.RequestType_isWrite + val cLoadWr = io.CLoadLocalMMUIO.Request.fire & io.CLoadLocalMMUIO.Request.bits.RequestType_isWrite + val cStoreWr = io.CStoreLocalMMUIO.Request.fire & io.CStoreLocalMMUIO.Request.bits.RequestType_isWrite + XSPerfAccumulate("CUTE_MMU_C_rd_request", cLoadRd || cStoreRd) + XSPerfAccumulate("CUTE_MMU_C_wr_request", cLoadWr || cStoreWr) + + val outReqFire = io.LastLevelCacheTLIO.Request.fire + val outReqIsWr = io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite + val outReqMask32B = PopCount(io.LastLevelCacheTLIO.Request.bits.RequestMask) >> 5 + io.perfProbe.rdReq := outReqFire && !outReqIsWr + io.perfProbe.wrReq := outReqFire && outReqIsWr + io.perfProbe.rd32BReq := Mux(outReqFire && !outReqIsWr, outReqMask32B, 0.U).asUInt + io.perfProbe.wr32BReq := Mux(outReqFire && outReqIsWr, outReqMask32B, 0.U).asUInt } diff --git a/src/main/scala/TaskController.scala b/src/main/scala/TaskController.scala index 334f3a0..aebfcb8 100644 --- a/src/main/scala/TaskController.scala +++ b/src/main/scala/TaskController.scala @@ -3,21 +3,34 @@ package cute import chisel3._ import chisel3.util._ import org.chipsalliance.cde.config._ +import freechips.rocketchip.util._ import cute.Bundles._ -import cute.ElementDataType._ import difftest._ import utility.ChiselDB +/* + * TaskController scheduling overview: + * - Decode AMU instructions into a fixed-size issue window. + * - Build dependencies at enqueue time from static read/write footprints. + * - Scan the window oldest-first and issue at most one ready slot per cycle. + * - Route completion back through FU ownerSlot; retire only from the window head. + * - Release and NopLike stay local; ZeroAcc and ZeroTr map to load-like aliases. + */ class TaskControllerIO(implicit p: Parameters) extends CuteBundle { val ygjkctrl = Flipped(new YGJKControl) val ADC_MicroTask_Config = new ADCMicroTaskConfigIO val BDC_MicroTask_Config = new BDCMicroTaskConfigIO + val ASC_MicroTask_Config = Option.when(cuteMatrixExtension.enableScalingFactor)(new ASCMicroTaskConfigIO) + val BSC_MicroTask_Config = Option.when(cuteMatrixExtension.enableScalingFactor)(new BSCMicroTaskConfigIO) val CDC_MicroTask_Config = new CDCMicroTaskConfigIO val AML_MicroTask_Config = new AMLMicroTaskConfigIO val BML_MicroTask_Config = new BMLMicroTaskConfigIO + val ASL_MicroTask_Config = Option.when(cuteMatrixExtension.enableScalingFactor)(new ASLMicroTaskConfigIO) + val BSL_MicroTask_Config = Option.when(cuteMatrixExtension.enableScalingFactor)(new BSLMicroTaskConfigIO) val CML_MicroTask_Config = new CMLMicroTaskConfigIO val MTE_MicroTask_Config = new MTEMicroTaskConfigIO val DebugTimeStampe = Input(UInt(32.W)) + val perfProbe = Output(new TaskControllerPerfProbe) } abstract class BaseTaskController(implicit p: Parameters) extends CuteModule { @@ -40,16 +53,28 @@ class DecodedAmuCtrlEntry(implicit p: Parameters) extends CuteBundle { val writeValid = Vec(MaxWriteRegs, Bool()) } +object TaskCtrlOpKind extends ChiselEnum { + val NopLike, LoadA, LoadB, LoadC, Compute, Store, Release, ZeroAcc, ZeroTr = Value +} + class TaskController(implicit p: Parameters) extends BaseTaskController { import NewTaskController._ dontTouch(io) + // Scheduler flow: decode -> window/dependency build -> oldest-ready issue -> done/retire/enqueue. + // Release/NopLike stay scheduler-local; ZeroAcc/ZeroTr remain mzero-like load aliases. + + private val WinDepth = TaskCtrlIssueWindowDepth + private val SlotIdxWidth = log2Ceil(WinDepth) + private val WinCountWidth = log2Ceil(WinDepth + 1) + private val SeqIdWidth = 16 + + // ===================== Default output assignments ===================== io.ygjkctrl.mrelease.valid := false.B io.ygjkctrl.mrelease.bits := 0.U.asTypeOf(new MreleaseIO) - // 默认输出赋值 - io.ADC_MicroTask_Config.ApplicationTensor_A.dataType := 0.U + io.ADC_MicroTask_Config.dataType := 0.U io.ADC_MicroTask_Config.MatrixRegTensor_M := 0.U io.ADC_MicroTask_Config.MatrixRegTensor_K := 0.U io.ADC_MicroTask_Config.MatrixRegTensor_N := 0.U @@ -58,7 +83,18 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.ADC_MicroTask_Config.MicroTaskValid := false.B io.ADC_MicroTask_Config.MicroTaskEndReady := false.B - io.BDC_MicroTask_Config.ApplicationTensor_B.dataType := 0.U + io.ASC_MicroTask_Config.foreach { cfg => + cfg.MatrixRegTensor_M := 0.U + cfg.MatrixRegTensor_K := 0.U + cfg.MatrixRegTensor_N := 0.U + cfg.MatrixRegId := 0.U + cfg.computeType := MteComputeType.ComputeTypeUndef + cfg.Is_Transpose := false.B + cfg.MicroTaskValid := false.B + cfg.MicroTaskEndReady := false.B + } + + io.BDC_MicroTask_Config.dataType := 0.U io.BDC_MicroTask_Config.MatrixRegTensor_M := 0.U io.BDC_MicroTask_Config.MatrixRegTensor_K := 0.U io.BDC_MicroTask_Config.MatrixRegTensor_N := 0.U @@ -67,6 +103,17 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.BDC_MicroTask_Config.MicroTaskValid := false.B io.BDC_MicroTask_Config.MicroTaskEndReady := false.B + io.BSC_MicroTask_Config.foreach { cfg => + cfg.MatrixRegTensor_M := 0.U + cfg.MatrixRegTensor_K := 0.U + cfg.MatrixRegTensor_N := 0.U + cfg.MatrixRegId := 0.U + cfg.computeType := MteComputeType.ComputeTypeUndef + cfg.Is_Transpose := false.B + cfg.MicroTaskValid := false.B + cfg.MicroTaskEndReady := false.B + } + io.CDC_MicroTask_Config.ApplicationTensor_C.dataType := 0.U io.CDC_MicroTask_Config.ApplicationTensor_D.dataType := 0.U io.CDC_MicroTask_Config.MatrixRegTensor_M := 0.U @@ -91,6 +138,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.AML_MicroTask_Config.MatrixRegTensor_M := 0.U io.AML_MicroTask_Config.MatrixRegTensor_K := 0.U io.AML_MicroTask_Config.Conherent := false.B + io.AML_MicroTask_Config.Is_Transpose := false.B io.AML_MicroTask_Config.MatrixRegId := 0.U io.AML_MicroTask_Config.MicroTaskValid := false.B io.AML_MicroTask_Config.MicroTaskEndReady := false.B @@ -99,10 +147,20 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.AML_MicroTask_Config.coreid.get := 0.U } + io.ASL_MicroTask_Config.foreach { cfg => + cfg.ApplicationScale_A := 0.U.asTypeOf(cfg.ApplicationScale_A) + cfg.MatrixRegTensor_M := 0.U + cfg.MatrixRegTensor_K := 0.U + cfg.Conherent := false.B + cfg.MicroTaskValid := false.B + cfg.MicroTaskEndReady := false.B + } + io.BML_MicroTask_Config.ApplicationTensor_B := 0.U.asTypeOf(io.BML_MicroTask_Config.ApplicationTensor_B) io.BML_MicroTask_Config.MatrixRegTensor_N := 0.U io.BML_MicroTask_Config.MatrixRegTensor_K := 0.U io.BML_MicroTask_Config.Conherent := false.B + io.BML_MicroTask_Config.Is_Transpose := false.B io.BML_MicroTask_Config.MatrixRegId := 0.U io.BML_MicroTask_Config.MicroTaskValid := false.B io.BML_MicroTask_Config.MicroTaskEndReady := false.B @@ -111,30 +169,37 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.BML_MicroTask_Config.coreid.get := 0.U } + io.BSL_MicroTask_Config.foreach { cfg => + cfg.ApplicationScale_B := 0.U.asTypeOf(cfg.ApplicationScale_B) + cfg.MatrixRegTensor_N := 0.U + cfg.MatrixRegTensor_K := 0.U + cfg.Conherent := false.B + cfg.MicroTaskValid := false.B + cfg.MicroTaskEndReady := false.B + } + io.CML_MicroTask_Config.ApplicationTensor_C := 0.U.asTypeOf(io.CML_MicroTask_Config.ApplicationTensor_C) io.CML_MicroTask_Config.ApplicationTensor_D := 0.U.asTypeOf(io.CML_MicroTask_Config.ApplicationTensor_D) io.CML_MicroTask_Config.LoadTaskInfo := 0.U.asTypeOf(io.CML_MicroTask_Config.LoadTaskInfo) - io.CML_MicroTask_Config.StoreTaskInfo := 0.U.asTypeOf(io.CML_MicroTask_Config.StoreTaskInfo) io.CML_MicroTask_Config.Conherent := false.B io.CML_MicroTask_Config.Is_Transpose := false.B io.CML_MicroTask_Config.MatrixRegTensor_M := 0.U io.CML_MicroTask_Config.MatrixRegTensor_N := 0.U io.CML_MicroTask_Config.MatrixRegId := 0.U - io.CML_MicroTask_Config.IsLoadMicroTask := false.B - io.CML_MicroTask_Config.IsStoreMicroTask := false.B - io.CML_MicroTask_Config.MicroTaskValid := false.B - io.CML_MicroTask_Config.MicroTaskEndReady := false.B + io.CML_MicroTask_Config.LoadMicroTaskValid := false.B + io.CML_MicroTask_Config.LoadMicroTaskEndReady := false.B + io.CML_MicroTask_Config.StoreMicroTaskValid := false.B + io.CML_MicroTask_Config.StoreMicroTaskEndReady := false.B if (EnableDifftest) { io.CML_MicroTask_Config.pc.get := 0.U io.CML_MicroTask_Config.coreid.get := 0.U } - io.MTE_MicroTask_Config.dataType := 0.U + io.MTE_MicroTask_Config.MicroTaskValid := false.B + io.MTE_MicroTask_Config.computeType := MteComputeType.ComputeTypeUndef + io.perfProbe := 0.U.asTypeOf(new TaskControllerPerfProbe) - // Scoreboard实例 - private val scoreboard = Module(new Scoreboard) - - // ===================== ChiselDB 事件定义 ===================== + // ===================== ChiselDB event definitions ===================== private val TileDimWidth = Bundles.Mtilex.width private val LoadFifoIdxWidth = 4 private val ComputeFifoIdxWidth = 4 @@ -149,6 +214,8 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { val column = UInt(TileDimWidth.W) val transpose = Bool() val isAcc = Bool() + val slotId = UInt(SlotIdxWidth.W) + val seqId = UInt(SeqIdWidth.W) } class ComputeEventEntry extends Bundle { @@ -162,6 +229,8 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { val mtilek = UInt(TileDimWidth.W) val isMma = Bool() val isFp = Bool() + val slotId = UInt(SlotIdxWidth.W) + val seqId = UInt(SeqIdWidth.W) } class StoreEventEntry extends Bundle { @@ -172,11 +241,15 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { val column = UInt(TileDimWidth.W) val transpose = Bool() val isAcc = Bool() + val slotId = UInt(SlotIdxWidth.W) + val seqId = UInt(SeqIdWidth.W) } class ReleaseEventEntry extends Bundle { val eventType = UInt(2.W) val token = UInt(5.W) + val slotId = UInt(SlotIdxWidth.W) + val seqId = UInt(SeqIdWidth.W) } private val loadEventTable = ChiselDB.createTable("CUTELoadEvent", new LoadEventEntry, basicDB = true) @@ -212,895 +285,1169 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { private val releaseIssueEvent = WireInit(0.U.asTypeOf(new ReleaseEventEntry)) private val releaseIssueEventEn = WireInit(false.B) - scoreboard.io.update.load_allocate := false.B - scoreboard.io.update.load_alloc_a_reg := 0.U - scoreboard.io.update.load_alloc_b_reg := 0.U - scoreboard.io.update.load_alloc_c_reg := 0.U - scoreboard.io.update.load_alloc_has_a := false.B - scoreboard.io.update.load_alloc_has_b := false.B - scoreboard.io.update.load_alloc_has_c := false.B - scoreboard.io.update.load_alloc_fifo_idx := 0.U - - scoreboard.io.update.load_finish_a := false.B - scoreboard.io.update.load_finish_a_reg := 0.U - scoreboard.io.update.load_finish_b := false.B - scoreboard.io.update.load_finish_b_reg := 0.U - scoreboard.io.update.load_finish_c := false.B - scoreboard.io.update.load_finish_c_reg := 0.U - - scoreboard.io.update.compute_issue := false.B - scoreboard.io.update.compute_issue_a_reg := 0.U - scoreboard.io.update.compute_issue_b_reg := 0.U - scoreboard.io.update.compute_issue_c_reg := 0.U - scoreboard.io.update.compute_issue_fifo_idx := 0.U - - scoreboard.io.update.compute_read_finish_a := false.B - scoreboard.io.update.compute_read_finish_a_reg := 0.U - scoreboard.io.update.compute_read_finish_b := false.B - scoreboard.io.update.compute_read_finish_b_reg := 0.U - - scoreboard.io.update.compute_write_finish_c := false.B - scoreboard.io.update.compute_write_finish_c_reg := 0.U - - scoreboard.io.update.store_issue := false.B - scoreboard.io.update.store_issue_c_reg := 0.U - scoreboard.io.update.store_issue_fifo_idx := 0.U - - scoreboard.io.update.store_finish := false.B - scoreboard.io.update.store_finish_c_reg := 0.U - - // Pending bookkeeping for outstanding micro tasks - val loadAllocIdx = RegInit(0.U(2.W)) - val computeIssueIdx = RegInit(0.U(2.W)) - val storeIssueIdx = RegInit(0.U(2.W)) - - val pendingLoadA = RegInit(false.B) - val pendingLoadAReg = RegInit(0.U(2.W)) - val pendingLoadAFifoIdx = RegInit(0.U(LoadFifoIdxWidth.W)) - val pendingLoadB = RegInit(false.B) - val pendingLoadBReg = RegInit(0.U(2.W)) - val pendingLoadBFifoIdx = RegInit(0.U(LoadFifoIdxWidth.W)) - val pendingLoadC = RegInit(false.B) - val pendingLoadCReg = RegInit(0.U(2.W)) - val pendingLoadCFifoIdx = RegInit(0.U(LoadFifoIdxWidth.W)) - val pendingLoadRow = RegInit(0.U(TileDimWidth.W)) - val pendingLoadColumn = RegInit(0.U(TileDimWidth.W)) - val pendingLoadTranspose = RegInit(false.B) - - val pendingComputeA = RegInit(false.B) - val pendingComputeAReg = RegInit(0.U(2.W)) - val pendingComputeAFifoIdx = RegInit(0.U(ComputeFifoIdxWidth.W)) - val pendingComputeB = RegInit(false.B) - val pendingComputeBReg = RegInit(0.U(2.W)) - val pendingComputeBFifoIdx = RegInit(0.U(ComputeFifoIdxWidth.W)) - val pendingComputeC = RegInit(false.B) - val pendingComputeCReg = RegInit(0.U(2.W)) - val pendingComputeCFifoIdx = RegInit(0.U(ComputeFifoIdxWidth.W)) - val pendingComputeM = RegInit(0.U(TileDimWidth.W)) - val pendingComputeN = RegInit(0.U(TileDimWidth.W)) - val pendingComputeK = RegInit(0.U(TileDimWidth.W)) - val pendingComputeIsMma = RegInit(false.B) - val pendingComputeIsFp = RegInit(false.B) - - val pendingStore = RegInit(false.B) - val pendingStoreReg = RegInit(0.U(2.W)) - val pendingStoreFifoIdx = RegInit(0.U(StoreFifoIdxWidth.W)) - val pendingStoreRow = RegInit(0.U(TileDimWidth.W)) - val pendingStoreColumn = RegInit(0.U(TileDimWidth.W)) - val pendingStoreTranspose = RegInit(false.B) - val pendingStoreIsAcc = RegInit(false.B) - - // Completion handshakes and scoreboard updates - io.AML_MicroTask_Config.MicroTaskEndReady := pendingLoadA - when(pendingLoadA && io.AML_MicroTask_Config.MicroTaskEndValid) { - if (YJPTASKDebugEnable) { - printf("[TaskController_LoadAFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingLoadAReg, pendingLoadAFifoIdx, io.AML_MicroTask_Config.MicroTaskEndValid, io.AML_MicroTask_Config.MicroTaskEndReady) - } - scoreboard.io.update.load_finish_a := true.B - scoreboard.io.update.load_finish_a_reg := pendingLoadAReg - pendingLoadA := false.B - loadAFinishEvent.eventType := 2.U - loadAFinishEvent.regId := pendingLoadAReg - loadAFinishEvent.fifoIdx := pendingLoadAFifoIdx - loadAFinishEvent.needMask := "b001".U - loadAFinishEvent.row := pendingLoadRow - loadAFinishEvent.column := pendingLoadColumn - loadAFinishEvent.transpose := pendingLoadTranspose - loadAFinishEvent.isAcc := false.B - loadAFinishEventEn := true.B + // ===================== Decoded instruction FIFO ===================== + private val decodedFifo = Module(new Queue(new DecodedAmuCtrlEntry, DecodedAmuCtrlFIFODepth)) + + decodedFifo.io.enq.valid := io.ygjkctrl.amuCtrl.valid + io.ygjkctrl.amuCtrl.ready := decodedFifo.io.enq.ready + + val amuCtrlBits = io.ygjkctrl.amuCtrl.bits + val decEntryEnq = Wire(new DecodedAmuCtrlEntry) + decEntryEnq.ctrl := amuCtrlBits + + for (i <- 0 until MaxReadRegs) { + decEntryEnq.readRegs(i) := 0.U + decEntryEnq.readValid(i) := false.B } + for (i <- 0 until MaxWriteRegs) { + decEntryEnq.writeRegs(i) := 0.U + decEntryEnq.writeValid(i) := false.B + } + + val enqMma = decodeMma(amuCtrlBits) + val enqLsu = decodeLsu(amuCtrlBits) + val enqArith = decodeArith(amuCtrlBits) - io.BML_MicroTask_Config.MicroTaskEndReady := pendingLoadB - when(pendingLoadB && io.BML_MicroTask_Config.MicroTaskEndValid) { - if (YJPTASKDebugEnable) { - printf("[TaskController_LoadBFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingLoadBReg, pendingLoadBFifoIdx, io.BML_MicroTask_Config.MicroTaskEndValid, io.BML_MicroTask_Config.MicroTaskEndReady) + when(amuCtrlBits.isMma()) { + val mma = enqMma + decEntryEnq.readRegs(0) := mma.ms1 + decEntryEnq.readRegs(1) := mma.ms2 + decEntryEnq.readRegs(2) := mma.md + decEntryEnq.readValid(0) := true.B + decEntryEnq.readValid(1) := true.B + decEntryEnq.readValid(2) := true.B + decEntryEnq.writeRegs(0) := mma.md + decEntryEnq.writeValid(0) := true.B + } + when(amuCtrlBits.isMls()) { + val lsu = enqLsu + when(lsu.ls === 0.U) { + decEntryEnq.writeRegs(0) := lsu.ms + decEntryEnq.writeValid(0) := true.B + }.otherwise { + decEntryEnq.readRegs(0) := lsu.ms + decEntryEnq.readValid(0) := true.B } - scoreboard.io.update.load_finish_b := true.B - scoreboard.io.update.load_finish_b_reg := pendingLoadBReg - pendingLoadB := false.B - loadBFinishEvent.eventType := 2.U - loadBFinishEvent.regId := pendingLoadBReg - loadBFinishEvent.fifoIdx := pendingLoadBFifoIdx - loadBFinishEvent.needMask := "b010".U - loadBFinishEvent.row := pendingLoadRow - loadBFinishEvent.column := pendingLoadColumn - loadBFinishEvent.transpose := pendingLoadTranspose - loadBFinishEvent.isAcc := false.B - loadBFinishEventEn := true.B } - - io.CML_MicroTask_Config.MicroTaskEndReady := pendingLoadC || pendingStore - when(io.CML_MicroTask_Config.MicroTaskEndValid) { - when(pendingLoadC) { - if (YJPTASKDebugEnable) { - printf("[TaskController_LoadCFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingLoadCReg, pendingLoadCFifoIdx, io.CML_MicroTask_Config.MicroTaskEndValid, io.CML_MicroTask_Config.MicroTaskEndReady) - } - scoreboard.io.update.load_finish_c := true.B - scoreboard.io.update.load_finish_c_reg := pendingLoadCReg - pendingLoadC := false.B - loadCFinishEvent.eventType := 2.U - loadCFinishEvent.regId := pendingLoadCReg - loadCFinishEvent.fifoIdx := pendingLoadCFifoIdx - loadCFinishEvent.needMask := "b100".U - loadCFinishEvent.row := pendingLoadRow - loadCFinishEvent.column := pendingLoadColumn - loadCFinishEvent.transpose := pendingLoadTranspose - loadCFinishEvent.isAcc := true.B - loadCFinishEventEn := true.B - }.elsewhen(pendingStore) { - if (YJPTASKDebugEnable) { - printf("[TaskController_StoreCFinish<%d>] reg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingStoreReg, pendingStoreFifoIdx, io.CML_MicroTask_Config.MicroTaskEndValid, io.CML_MicroTask_Config.MicroTaskEndReady) - } - scoreboard.io.update.store_finish := true.B - scoreboard.io.update.store_finish_c_reg := pendingStoreReg - pendingStore := false.B - storeFinishEvent.eventType := 1.U - storeFinishEvent.regId := pendingStoreReg - storeFinishEvent.fifoIdx := pendingStoreFifoIdx - storeFinishEvent.row := pendingStoreRow - storeFinishEvent.column := pendingStoreColumn - storeFinishEvent.transpose := pendingStoreTranspose - storeFinishEvent.isAcc := pendingStoreIsAcc - storeFinishEventEn := true.B + when(amuCtrlBits.isArith()) { + val arith = enqArith + // Current CUTE flow only supports mzero-style marith. + // Unknown marith opType is treated as NopLike (no read/write footprint). + when(isMzeroLike(arith)) { + decEntryEnq.writeRegs(0) := arith.md + decEntryEnq.writeValid(0) := true.B } } - io.ADC_MicroTask_Config.MicroTaskEndReady := pendingComputeA - when(pendingComputeA && io.ADC_MicroTask_Config.MicroTaskEndValid) { - if (YJPTASKDebugEnable) { - printf("[TaskController_ComputeAFinish<%d>] aReg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingComputeAReg, pendingComputeAFifoIdx, io.ADC_MicroTask_Config.MicroTaskEndValid, io.ADC_MicroTask_Config.MicroTaskEndReady) - } - scoreboard.io.update.compute_read_finish_a := true.B - scoreboard.io.update.compute_read_finish_a_reg := pendingComputeAReg - pendingComputeA := false.B - computeReadAFinishEvent.eventType := 1.U - computeReadAFinishEvent.aReg := pendingComputeAReg - computeReadAFinishEvent.bReg := pendingComputeBReg - computeReadAFinishEvent.cReg := pendingComputeCReg - computeReadAFinishEvent.fifoIdx := pendingComputeAFifoIdx - computeReadAFinishEvent.mtilem := pendingComputeM - computeReadAFinishEvent.mtilen := pendingComputeN - computeReadAFinishEvent.mtilek := pendingComputeK - computeReadAFinishEvent.isMma := pendingComputeIsMma - computeReadAFinishEvent.isFp := pendingComputeIsFp - computeReadAFinishEventEn := true.B + decodedFifo.io.enq.bits := decEntryEnq + + // ===================== Issue window state ===================== + class IssueWindowSlot(implicit p: Parameters) extends CuteBundle { + val valid = Bool() + val issued = Bool() + val completed = Bool() + val readADone = Bool() + val readBDone = Bool() + val waitCompleteMask = UInt(WinDepth.W) + val waitReadAMask = UInt(WinDepth.W) + val waitReadBMask = UInt(WinDepth.W) + val opKind = TaskCtrlOpKind() + val entry = new DecodedAmuCtrlEntry + val seqId = UInt(SeqIdWidth.W) + val fifoIdx = UInt(4.W) } - io.BDC_MicroTask_Config.MicroTaskEndReady := pendingComputeB - when(pendingComputeB && io.BDC_MicroTask_Config.MicroTaskEndValid) { - if (YJPTASKDebugEnable) { - printf("[TaskController_ComputeBFinish<%d>] bReg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingComputeBReg, pendingComputeBFifoIdx, io.BDC_MicroTask_Config.MicroTaskEndValid, io.BDC_MicroTask_Config.MicroTaskEndReady) - } - scoreboard.io.update.compute_read_finish_b := true.B - scoreboard.io.update.compute_read_finish_b_reg := pendingComputeBReg - pendingComputeB := false.B - computeReadBFinishEvent.eventType := 2.U - computeReadBFinishEvent.aReg := pendingComputeAReg - computeReadBFinishEvent.bReg := pendingComputeBReg - computeReadBFinishEvent.cReg := pendingComputeCReg - computeReadBFinishEvent.fifoIdx := pendingComputeBFifoIdx - computeReadBFinishEvent.mtilem := pendingComputeM - computeReadBFinishEvent.mtilen := pendingComputeN - computeReadBFinishEvent.mtilek := pendingComputeK - computeReadBFinishEvent.isMma := pendingComputeIsMma - computeReadBFinishEvent.isFp := pendingComputeIsFp - computeReadBFinishEventEn := true.B + class FuTracker extends Bundle { + val busy = Bool() + val ownerSlot = UInt(SlotIdxWidth.W) } - io.CDC_MicroTask_Config.MicroTaskEndReady := pendingComputeC - when(pendingComputeC && io.CDC_MicroTask_Config.MicroTaskEndValid) { - if (YJPTASKDebugEnable) { - printf("[TaskController_ComputeCFinish<%d>] cReg=%d fifo=%d endValid=%d endReady=%d\n", io.DebugTimeStampe, pendingComputeCReg, pendingComputeCFifoIdx, io.CDC_MicroTask_Config.MicroTaskEndValid, io.CDC_MicroTask_Config.MicroTaskEndReady) - } - scoreboard.io.update.compute_write_finish_c := true.B - scoreboard.io.update.compute_write_finish_c_reg := pendingComputeCReg - pendingComputeC := false.B - computeWriteCFinishEvent.eventType := 3.U - computeWriteCFinishEvent.aReg := pendingComputeAReg - computeWriteCFinishEvent.bReg := pendingComputeBReg - computeWriteCFinishEvent.cReg := pendingComputeCReg - computeWriteCFinishEvent.fifoIdx := pendingComputeCFifoIdx - computeWriteCFinishEvent.mtilem := pendingComputeM - computeWriteCFinishEvent.mtilen := pendingComputeN - computeWriteCFinishEvent.mtilek := pendingComputeK - computeWriteCFinishEvent.isMma := pendingComputeIsMma - computeWriteCFinishEvent.isFp := pendingComputeIsFp - computeWriteCFinishEventEn := true.B + val slots = RegInit(VecInit(Seq.fill(WinDepth)(0.U.asTypeOf(new IssueWindowSlot)))) + val winHead = RegInit(0.U(SlotIdxWidth.W)) + val winTail = RegInit(0.U(SlotIdxWidth.W)) + val winCount = RegInit(0.U(WinCountWidth.W)) + val seqIdAlloc = RegInit(0.U(SeqIdWidth.W)) + + val loadFifoIdxAlloc = RegInit(0.U(LoadFifoIdxWidth.W)) + val computeIssueIdx = RegInit(0.U(ComputeFifoIdxWidth.W)) + val storeIssueIdx = RegInit(0.U(StoreFifoIdxWidth.W)) + + val fuAML = RegInit(0.U.asTypeOf(new FuTracker)) + val fuBML = RegInit(0.U.asTypeOf(new FuTracker)) + val fuCMLLoad = RegInit(0.U.asTypeOf(new FuTracker)) + val fuCompute = RegInit(0.U.asTypeOf(new FuTracker)) + val fuCMLStore = RegInit(0.U.asTypeOf(new FuTracker)) + + private def slotOH(idx: UInt): UInt = UIntToOH(idx, WinDepth) + private def abIdx(x: UInt): UInt = x(ABMatrixRegIdWidth - 1, 0) + private def cIdx(x: UInt): UInt = x(CMatrixRegIdWidth - 1, 0) + + private def decodeMma(ctrl: AmuCtrlIO): AmuMmaIO = ctrl.data.asTypeOf(new AmuMmaIO) + private def decodeLsu(ctrl: AmuCtrlIO): AmuLsuIO = ctrl.data.asTypeOf(new AmuLsuIO) + private def decodeArith(ctrl: AmuCtrlIO): AmuArithIO = ctrl.data.asTypeOf(new AmuArithIO) + private def isMzeroLike(arith: AmuArithIO): Bool = arith.opType(8, 2) === "b1101110".U + + private def isReadAB(op: TaskCtrlOpKind.Type, mma: AmuMmaIO, lsu: AmuLsuIO, reg: UInt, treatStoreAsABRead: Bool): Bool = { + (op === TaskCtrlOpKind.Compute) && (abIdx(mma.ms1) === abIdx(reg) || abIdx(mma.ms2) === abIdx(reg)) || + (op === TaskCtrlOpKind.Store) && treatStoreAsABRead && (abIdx(lsu.ms) === abIdx(reg)) } - // 解码后的指令FIFO - private val decodedFifo = Module(new Queue(new DecodedAmuCtrlEntry, DecodedAmuCtrlFIFODepth)) + private def isReadC(op: TaskCtrlOpKind.Type, mma: AmuMmaIO, lsu: AmuLsuIO, reg: UInt, treatStoreAsCRead: Bool): Bool = { + ((op === TaskCtrlOpKind.Compute) && (cIdx(mma.md) === cIdx(reg))) || + ((op === TaskCtrlOpKind.Store) && treatStoreAsCRead && (cIdx(lsu.ms) === cIdx(reg))) + } - // FIFO出队暂未使用 - decodedFifo.io.deq.ready := false.B + private def isWriteAB(op: TaskCtrlOpKind.Type, lsu: AmuLsuIO, arith: AmuArithIO, reg: UInt): Bool = { + ((op === TaskCtrlOpKind.LoadA || op === TaskCtrlOpKind.LoadB) && (abIdx(lsu.ms) === abIdx(reg))) || + ((op === TaskCtrlOpKind.ZeroTr) && (abIdx(arith.md) === abIdx(reg))) + } - // AMU指令译码 - decodedFifo.io.enq.valid := io.ygjkctrl.amuCtrl.valid - io.ygjkctrl.amuCtrl.ready := decodedFifo.io.enq.ready + private def isWriteC(op: TaskCtrlOpKind.Type, mma: AmuMmaIO, lsu: AmuLsuIO, arith: AmuArithIO, reg: UInt): Bool = { + ((op === TaskCtrlOpKind.Compute) && (cIdx(mma.md) === cIdx(reg))) || + ((op === TaskCtrlOpKind.LoadC) && (cIdx(lsu.ms) === cIdx(reg))) || + ((op === TaskCtrlOpKind.ZeroAcc) && (cIdx(arith.md) === cIdx(reg))) + } - val amuCtrlBits = io.ygjkctrl.amuCtrl.bits + private def isReadAOfCompute(op: TaskCtrlOpKind.Type, mma: AmuMmaIO, reg: UInt): Bool = { + (op === TaskCtrlOpKind.Compute) && (abIdx(mma.ms1) === abIdx(reg)) + } - val entry = Wire(new DecodedAmuCtrlEntry) - entry.ctrl := amuCtrlBits + private def isReadBOfCompute(op: TaskCtrlOpKind.Type, mma: AmuMmaIO, reg: UInt): Bool = { + (op === TaskCtrlOpKind.Compute) && (abIdx(mma.ms2) === abIdx(reg)) + } - // 默认清零 - for (i <- 0 until MaxReadRegs) { - entry.readRegs(i) := 0.U - entry.readValid(i) := false.B + // ===================== Done handshakes (level-ready) ===================== + val amlOwnerValid = fuAML.busy && slots(fuAML.ownerSlot).valid + val bmlOwnerValid = fuBML.busy && slots(fuBML.ownerSlot).valid + val cmlLoadOwnerValid = fuCMLLoad.busy && slots(fuCMLLoad.ownerSlot).valid + val cmlStoreOwnerValid = fuCMLStore.busy && slots(fuCMLStore.ownerSlot).valid + val computeOwnerValid = fuCompute.busy && slots(fuCompute.ownerSlot).valid + + io.AML_MicroTask_Config.MicroTaskEndReady := amlOwnerValid + io.BML_MicroTask_Config.MicroTaskEndReady := bmlOwnerValid + io.CML_MicroTask_Config.LoadMicroTaskEndReady := cmlLoadOwnerValid + io.CML_MicroTask_Config.StoreMicroTaskEndReady := cmlStoreOwnerValid + io.ADC_MicroTask_Config.MicroTaskEndReady := computeOwnerValid + io.BDC_MicroTask_Config.MicroTaskEndReady := computeOwnerValid + io.CDC_MicroTask_Config.MicroTaskEndReady := computeOwnerValid + + val amlDone = io.AML_MicroTask_Config.MicroTaskEndValid && io.AML_MicroTask_Config.MicroTaskEndReady + val bmlDone = io.BML_MicroTask_Config.MicroTaskEndValid && io.BML_MicroTask_Config.MicroTaskEndReady + val cmlLoadDone = io.CML_MicroTask_Config.LoadMicroTaskEndValid && io.CML_MicroTask_Config.LoadMicroTaskEndReady + val cmlStoreDone = io.CML_MicroTask_Config.StoreMicroTaskEndValid && io.CML_MicroTask_Config.StoreMicroTaskEndReady + val adcDone = io.ADC_MicroTask_Config.MicroTaskEndValid && io.ADC_MicroTask_Config.MicroTaskEndReady + val bdcDone = io.BDC_MicroTask_Config.MicroTaskEndValid && io.BDC_MicroTask_Config.MicroTaskEndReady + val cdcDone = io.CDC_MicroTask_Config.MicroTaskEndValid && io.CDC_MicroTask_Config.MicroTaskEndReady + + when(io.AML_MicroTask_Config.MicroTaskEndValid) { + assert(amlOwnerValid, "TaskController: AML done without valid owner") } - for (i <- 0 until MaxWriteRegs) { - entry.writeRegs(i) := 0.U - entry.writeValid(i) := false.B + when(io.BML_MicroTask_Config.MicroTaskEndValid) { + assert(bmlOwnerValid, "TaskController: BML done without valid owner") + } + when(io.CML_MicroTask_Config.LoadMicroTaskEndValid) { + assert(cmlLoadOwnerValid, "TaskController: CML-load done without valid owner") + } + when(io.CML_MicroTask_Config.StoreMicroTaskEndValid) { + assert(cmlStoreOwnerValid, "TaskController: CML-store done without valid owner") + } + when(io.ADC_MicroTask_Config.MicroTaskEndValid) { + assert(computeOwnerValid, "TaskController: ADC done without valid owner") + } + when(io.BDC_MicroTask_Config.MicroTaskEndValid) { + assert(computeOwnerValid, "TaskController: BDC done without valid owner") + } + when(io.CDC_MicroTask_Config.MicroTaskEndValid) { + assert(computeOwnerValid, "TaskController: CDC done without valid owner") } - when(amuCtrlBits.isMma()) { - val mma = amuCtrlBits.data.asTypeOf(new AmuMmaIO) - entry.readRegs(0) := mma.ms1 - entry.readRegs(1) := mma.ms2 - entry.readRegs(2) := mma.md - entry.readValid(0) := true.B - entry.readValid(1) := true.B - entry.readValid(2) := true.B - entry.writeRegs(0) := mma.md - entry.writeValid(0) := true.B + // ===================== Enqueue-side classification ===================== + val deqValid = decodedFifo.io.deq.valid + val deqEntry = decodedFifo.io.deq.bits + + val deqCtrl = deqEntry.ctrl + val deqIsMma = deqCtrl.isMma() + val deqIsArith = deqCtrl.isArith() + val deqIsLsu = deqCtrl.isMls() + val deqIsRelease = deqCtrl.isRelease() + + val deqLsu = decodeLsu(deqCtrl) + val deqArith = decodeArith(deqCtrl) + + val deqIsLoad = deqIsLsu && deqLsu.ls === 0.U + val deqLoadSelOH = Cat(deqLsu.isacc, deqLsu.isB, deqLsu.isA) + val deqLoadSelOneHot = PopCount(deqLoadSelOH) === 1.U + val deqIsMzeroLike = deqIsArith && isMzeroLike(deqArith) + + when(deqValid && deqIsLoad) { + assert(deqLoadSelOneHot, "TaskController: MLS load selector must be onehot among isA/isB/isacc") } - when(amuCtrlBits.isMls()) { - val lsu = amuCtrlBits.data.asTypeOf(new AmuLsuIO) - when(lsu.ls === 0.U) { // Load: 写寄存器 - entry.writeRegs(0) := lsu.ms - entry.writeValid(0) := true.B - }.otherwise { // Store: 读寄存器 - entry.readRegs(0) := lsu.ms - entry.readValid(0) := true.B - } + + val deqOpKind = Wire(TaskCtrlOpKind()) + deqOpKind := TaskCtrlOpKind.NopLike + when(deqIsMma) { + deqOpKind := TaskCtrlOpKind.Compute + }.elsewhen(deqIsLsu && deqLsu.ls === 1.U) { + deqOpKind := TaskCtrlOpKind.Store + }.elsewhen(deqIsLoad && deqLoadSelOneHot && deqLsu.isA) { + deqOpKind := TaskCtrlOpKind.LoadA + }.elsewhen(deqIsLoad && deqLoadSelOneHot && deqLsu.isB) { + deqOpKind := TaskCtrlOpKind.LoadB + }.elsewhen(deqIsLoad && deqLoadSelOneHot && deqLsu.isacc) { + deqOpKind := TaskCtrlOpKind.LoadC + }.elsewhen(deqIsRelease) { + deqOpKind := TaskCtrlOpKind.Release + }.elsewhen(deqIsMzeroLike && deqArith.md(2) === 1.U) { + deqOpKind := TaskCtrlOpKind.ZeroAcc + }.elsewhen(deqIsMzeroLike && deqArith.md(2) === 0.U) { + deqOpKind := TaskCtrlOpKind.ZeroTr } - when(amuCtrlBits.isArith()) { - val arith = amuCtrlBits.data.asTypeOf(new AmuArithIO) - entry.readRegs(0) := arith.md - entry.readValid(0) := true.B - entry.writeRegs(0) := arith.md - entry.writeValid(0) := true.B + + // ===================== Oldest-ready issue selection (based on cycle-start state) ===================== + val completedVec = VecInit(slots.map(s => s.valid && s.completed)).asUInt + val readAVec = VecInit(slots.map(s => s.valid && s.readADone)).asUInt + val readBVec = VecInit(slots.map(s => s.valid && s.readBDone)).asUInt + + val readyBySlot = Wire(Vec(WinDepth, Bool())) + readyBySlot.foreach(_ := false.B) + for (slotIdx <- 0 until WinDepth) { + val slot = slots(slotIdx) + val slotAge = (slotIdx.U + WinDepth.U - winHead)(SlotIdxWidth - 1, 0) + val inWindow = slotAge < winCount + val slotCanConsider = inWindow && slot.valid && !slot.issued + + val slotDepReady = Mux( + slotCanConsider, + ((slot.waitCompleteMask & (~completedVec)(WinDepth - 1, 0)) === 0.U) && + ((slot.waitReadAMask & (~readAVec)(WinDepth - 1, 0)) === 0.U) && + ((slot.waitReadBMask & (~readBVec)(WinDepth - 1, 0)) === 0.U), + false.B + ) + + val slotFuReady = Mux( + slotCanConsider, + MuxLookup(slot.opKind.asUInt, true.B)(Seq( + TaskCtrlOpKind.LoadA.asUInt -> (!fuAML.busy && io.AML_MicroTask_Config.MicroTaskReady), + TaskCtrlOpKind.LoadB.asUInt -> (!fuBML.busy && io.BML_MicroTask_Config.MicroTaskReady), + TaskCtrlOpKind.LoadC.asUInt -> (!fuCMLLoad.busy && io.CML_MicroTask_Config.LoadMicroTaskReady), + TaskCtrlOpKind.ZeroAcc.asUInt -> (!fuCMLLoad.busy && io.CML_MicroTask_Config.LoadMicroTaskReady), + TaskCtrlOpKind.ZeroTr.asUInt -> (!fuAML.busy && io.AML_MicroTask_Config.MicroTaskReady), + TaskCtrlOpKind.Store.asUInt -> (!fuCMLStore.busy && io.CML_MicroTask_Config.StoreMicroTaskReady), + TaskCtrlOpKind.Compute.asUInt -> (!fuCompute.busy && io.ADC_MicroTask_Config.MicroTaskReady && io.BDC_MicroTask_Config.MicroTaskReady && io.CDC_MicroTask_Config.MicroTaskReady), + TaskCtrlOpKind.Release.asUInt -> true.B, + TaskCtrlOpKind.NopLike.asUInt -> true.B + )), + false.B + ) + + readyBySlot(slotIdx) := slotCanConsider && slotDepReady && slotFuReady } - decodedFifo.io.enq.bits := entry - - // 仅查询队首指令能否发射 - val headValid = decodedFifo.io.deq.valid - val headEntry = decodedFifo.io.deq.bits - - val isMma = headEntry.ctrl.isMma() - val isArith = headEntry.ctrl.isArith() - val isLsu = headEntry.ctrl.isMls() - val isRelease = headEntry.ctrl.isRelease() - - val lsuInfo = headEntry.ctrl.data.asTypeOf(new AmuLsuIO) - val mmaInfo = headEntry.ctrl.data.asTypeOf(new AmuMmaIO) - val arithInfo = headEntry.ctrl.data.asTypeOf(new AmuArithIO) - val releaseInfo = headEntry.ctrl.data.asTypeOf(new AmuReleaseIO) - - val isLoad = isLsu && headEntry.writeValid(0) - val isStore = isLsu && !headEntry.writeValid(0) && headEntry.readValid(0) - val arithDestIsAcc = arithInfo.md(2) === 1.U - val isMzeroAcc = isArith && arithDestIsAcc - val isMzeroTr = isArith && !arithDestIsAcc - - val scoreboardReq = WireInit(0.U.asTypeOf(new QueryReq)) - val scoreboardReqValid = WireInit(false.B) - when(headValid) { - when(isLoad) { - scoreboardReqValid := true.B - val fuType = MuxCase(ScoreboardFuType.AML, Seq( - lsuInfo.isacc -> ScoreboardFuType.CML, - lsuInfo.isA -> ScoreboardFuType.AML, - lsuInfo.isB -> ScoreboardFuType.BML - )) - scoreboardReq.fuType := fuType - scoreboardReq.dest.valid := true.B - scoreboardReq.dest.bits.is_acc := lsuInfo.isacc - scoreboardReq.dest.bits.regIdx := lsuInfo.ms(ScoreboardConsts.RegIdxWidth - 1, 0) - } - when(isMma) { - scoreboardReqValid := true.B - scoreboardReq.fuType := ScoreboardFuType.Compute - scoreboardReq.dest.valid := true.B - scoreboardReq.dest.bits.is_acc := true.B - scoreboardReq.dest.bits.regIdx := mmaInfo.md(ScoreboardConsts.RegIdxWidth - 1, 0) - scoreboardReq.src1.valid := true.B - scoreboardReq.src1.bits.is_acc := false.B - scoreboardReq.src1.bits.regIdx := mmaInfo.ms1(ScoreboardConsts.RegIdxWidth - 1, 0) - scoreboardReq.src2.valid := true.B - scoreboardReq.src2.bits.is_acc := false.B - scoreboardReq.src2.bits.regIdx := mmaInfo.ms2(ScoreboardConsts.RegIdxWidth - 1, 0) - scoreboardReq.src3.valid := true.B - scoreboardReq.src3.bits.is_acc := true.B - scoreboardReq.src3.bits.regIdx := mmaInfo.md(ScoreboardConsts.RegIdxWidth - 1, 0) - } - when(isArith) { - scoreboardReqValid := true.B - scoreboardReq.fuType := Mux(arithDestIsAcc, ScoreboardFuType.CML, ScoreboardFuType.AML) - scoreboardReq.dest.valid := true.B - scoreboardReq.dest.bits.is_acc := arithDestIsAcc - scoreboardReq.dest.bits.regIdx := arithInfo.md(ScoreboardConsts.RegIdxWidth - 1, 0) + val readyByAge = VecInit(readyBySlot.rotate(winHead)) + val issueFound = readyByAge.asUInt.orR + val issueAgeOH = PriorityEncoderOH(readyByAge) + val issueSlotOH = VecInit(issueAgeOH.rotateRight(winHead)) + val issueSlotIdx = WireInit(0.U(SlotIdxWidth.W)) + when(issueFound) { + issueSlotIdx := OHToUInt(issueSlotOH) + } + + val issueFire = issueFound + val issueSlot = slots(issueSlotIdx) + + val deqStoreReadsAB = deqIsLsu && (deqLsu.ls === 1.U) && !deqLsu.isacc + val deqStoreReadsC = deqIsLsu && (deqLsu.ls === 1.U) && deqLsu.isacc + + // ===================== Retirement lookahead (after done convergence) ===================== + val headSlot = slots(winHead) + + val headDoneByFu = + (amlDone && fuAML.busy && fuAML.ownerSlot === winHead) || + (bmlDone && fuBML.busy && fuBML.ownerSlot === winHead) || + (cmlLoadDone && fuCMLLoad.busy && fuCMLLoad.ownerSlot === winHead) || + (cmlStoreDone && fuCMLStore.busy && fuCMLStore.ownerSlot === winHead) || + (cdcDone && fuCompute.busy && fuCompute.ownerSlot === winHead) + + val headCompletedAfterDone = headSlot.completed || headDoneByFu + val retireFire = headSlot.valid && headCompletedAfterDone + + // Allow retire + enqueue in the same cycle, including full-window reuse of the retired slot + val windowFull = winCount === WinDepth.U + val enqueueCanFire = deqValid && (!windowFull || retireFire) + decodedFifo.io.deq.ready := enqueueCanFire + val enqueueFire = decodedFifo.io.deq.fire + + val enqueueSlotIdx = Mux(windowFull, winHead, winTail) + val ownedWork = deqValid || (winCount =/= 0.U) + + // ===================== Issue dispatch bridge ===================== + val issueCtrl = issueSlot.entry.ctrl + val issueLsu = decodeLsu(issueCtrl) + val issueMma = decodeMma(issueCtrl) + val issueArith = decodeArith(issueCtrl) + + private def computeKFromMsew(k: UInt, msew: UInt): UInt = { + MuxLookup(msew(1, 0), k)(Seq( + Bundles.MSew.e8 -> k, + Bundles.MSew.e16 -> (k << 1), + Bundles.MSew.e32 -> (k << 2), + Bundles.MSew.e4 -> (k >> 1) + )) + } + + private def loadDataType(widths: UInt): UInt = { + MuxLookup(widths, ElementDataType.DataTypeWidth32)(Seq( + Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, + Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, + Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, + Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 + )) + } + + private val loadByteCountBits = MatrixRegMaxTensorDimBitSize + 2 + + private def loadByteCount(dim: UInt, widths: UInt): UInt = { + val dimWide = dim.pad(loadByteCountBits) + val byteCount = WireDefault((dimWide << 2)(loadByteCountBits - 1, 0)) + switch(widths) { + is(Bundles.MSew.e8) { byteCount := dimWide } + is(Bundles.MSew.e16) { byteCount := (dimWide << 1)(loadByteCountBits - 1, 0) } + is(Bundles.MSew.e32) { byteCount := (dimWide << 2)(loadByteCountBits - 1, 0) } + is(Bundles.MSew.e4) { byteCount := dimWide >> 1 } } - when(isStore) { - scoreboardReqValid := true.B - scoreboardReq.fuType := ScoreboardFuType.CML - scoreboardReq.src1.valid := true.B - scoreboardReq.src1.bits.is_acc := lsuInfo.isacc - scoreboardReq.src1.bits.regIdx := lsuInfo.ms(ScoreboardConsts.RegIdxWidth - 1, 0) + byteCount + } + + private def loadBeatCount(dim: UInt, widths: UInt): UInt = { + val beatCount = (loadByteCount(dim, widths) + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) + beatCount.pad(MatrixRegMaxTensorDimBitSize)(MatrixRegMaxTensorDimBitSize - 1, 0) + } + + private def loadHasTail(dim: UInt, widths: UInt): Bool = { + loadByteCount(dim, widths)(log2Ceil(outsideDataWidthByte) - 1, 0).orR + } + + private def loadTailByteMask(dim: UInt, widths: UInt): UInt = { + Cat(0.U((log2Ceil(outsideDataWidthByte + 1) - log2Ceil(outsideDataWidthByte)).W), loadByteCount(dim, widths)(log2Ceil(outsideDataWidthByte) - 1, 0)) + } + + private def decodeMmaComputeType(mma: AmuMmaIO): UInt = { + val computeType = WireInit(MteComputeType.ComputeTypeUndef) + when(mma.isfp) { + when(mma.types1 === "b001".U && mma.types2 === "b001".U) { + computeType := MteComputeType.F16F16F32 + }.elsewhen(mma.types1 === "b000".U && mma.types2 === "b000".U) { + computeType := MteComputeType.Fp8e5m2F32 + }.elsewhen(mma.types1 === "b100".U && mma.types2 === "b100".U) { + computeType := MteComputeType.Fp8e4m3F32 + }.elsewhen(mma.types1 === "b101".U && mma.types2 === "b101".U) { + computeType := MteComputeType.BF16BF16F32 + }.elsewhen(mma.types1 === "b010".U && mma.types2 === "b010".U) { + computeType := MteComputeType.ComputeTypeUndef // FP32FP32FP32 is unsupported. + }.elsewhen(mma.types1 === "b110".U && mma.types2 === "b110".U) { + computeType := MteComputeType.TF32TF32F32 + }.elsewhen(mma.types1 === "b011".U && mma.types2 === "b011".U) { + computeType := MteComputeType.Nvfp4F32 + }.otherwise { + computeType := MteComputeType.ComputeTypeUndef + } + }.otherwise { + when(mma.types1 === "b000".U && mma.types2 === "b000".U) { + computeType := MteComputeType.U8U8I32 + }.elsewhen(mma.types1 === "b100".U && mma.types2 === "b000".U) { + computeType := MteComputeType.I8U8I32 + }.elsewhen(mma.types1 === "b000".U && mma.types2 === "b100".U) { + computeType := MteComputeType.U8I8I32 + }.elsewhen(mma.types1 === "b100".U && mma.types2 === "b100".U) { + computeType := MteComputeType.I8I8I32 + }.otherwise { + computeType := MteComputeType.ComputeTypeUndef + } } + computeType } - scoreboard.io.query.req.valid := scoreboardReqValid - scoreboard.io.query.req.bits := scoreboardReq - - val mmaUnitsReady = io.ADC_MicroTask_Config.MicroTaskReady && - io.BDC_MicroTask_Config.MicroTaskReady && - io.CDC_MicroTask_Config.MicroTaskReady - val storeUnitsReady = io.CML_MicroTask_Config.MicroTaskReady - val zeroAccUnitsReady = io.CML_MicroTask_Config.MicroTaskReady - val zeroTrUnitsReady = io.AML_MicroTask_Config.MicroTaskReady - - val needA = isLoad && lsuInfo.isA - val needB = isLoad && lsuInfo.isB - val needC = isLoad && lsuInfo.isacc - - val loadUnitsReady = (!needA || io.AML_MicroTask_Config.MicroTaskReady) && - (!needB || io.BML_MicroTask_Config.MicroTaskReady) && - (!needC || io.CML_MicroTask_Config.MicroTaskReady) - - // val storeReadersEmpty = scoreboard.io.debug.c_reg_reader_counts.map(_ === 0.U).reduce(_ && _) - val releaseReady = !pendingStore // && storeReadersEmpty - - val scoreboardReqReady = !scoreboardReqValid || scoreboard.io.query.req.ready - - val headReady = MuxCase(true.B, Seq( - (isLoad) -> (scoreboardReqReady && loadUnitsReady), - (isStore) -> (scoreboardReqReady && storeUnitsReady), - (isMma) -> (scoreboardReqReady && mmaUnitsReady), - (isMzeroAcc) -> (scoreboardReqReady && zeroAccUnitsReady), - (isMzeroTr) -> (scoreboardReqReady && zeroTrUnitsReady), - (isRelease) -> releaseReady - )) - - decodedFifo.io.deq.ready := headValid && headReady - - val issueFire = decodedFifo.io.deq.fire - val issueLoad = issueFire && isLoad - val issueStore = issueFire && isStore - val issueMma = issueFire && isMma - val issueZeroAcc = issueFire && isMzeroAcc - val issueZeroTr = issueFire && isMzeroTr - val issueRelease = issueFire && isRelease - - val loadDataType = MuxLookup(lsuInfo.widths, ElementDataType.DataTypeWidth32)(Seq( - Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, - Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, - Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, - Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 - )) - val loadKBytes = MuxLookup(lsuInfo.widths, lsuInfo.column)(Seq( - Bundles.MSew.e8 -> lsuInfo.column, - Bundles.MSew.e16 -> (lsuInfo.column << 1), - Bundles.MSew.e32 -> (lsuInfo.column << 2), - Bundles.MSew.e4 -> (lsuInfo.column >> 1) - )) - val loadKBytes_for_B = MuxLookup(lsuInfo.widths, lsuInfo.row)(Seq( - Bundles.MSew.e8 -> lsuInfo.row, - Bundles.MSew.e16 -> (lsuInfo.row << 1), - Bundles.MSew.e32 -> (lsuInfo.row << 2), - Bundles.MSew.e4 -> (lsuInfo.row >> 1) - )) - val loadKBeatCount_for_B = (loadKBytes_for_B + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) - val loadKBeatCount = (loadKBytes + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) - val loadHasTail = MuxLookup(loadDataType, false.B)(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0).orR, - ElementDataType.DataTypeWidth16 -> lsuInfo.column(4, 0).orR, - ElementDataType.DataTypeWidth32 -> lsuInfo.column(3, 0).orR - )) - val loadTailByteMask = MuxLookup(loadDataType, 0.U(log2Ceil(outsideDataWidthByte + 1).W))(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0), - ElementDataType.DataTypeWidth16 -> Cat(lsuInfo.column(4, 0), 0.U(1.W)), - ElementDataType.DataTypeWidth32 -> Cat(lsuInfo.column(3, 0), 0.U(2.W)) - )) - val loadHasTail_for_B = MuxLookup(loadDataType, false.B)(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.row(5, 0).orR, - ElementDataType.DataTypeWidth16 -> lsuInfo.row(4, 0).orR, - ElementDataType.DataTypeWidth32 -> lsuInfo.row(3, 0).orR - )) - val loadTailByteMask_for_B = MuxLookup(loadDataType, 0.U(log2Ceil(outsideDataWidthByte + 1).W))(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.row(5, 0), - ElementDataType.DataTypeWidth16 -> Cat(lsuInfo.row(4, 0), 0.U(1.W)), - ElementDataType.DataTypeWidth32 -> Cat(lsuInfo.row(3, 0), 0.U(2.W)) - )) - val loadNBytes = MuxLookup(loadDataType, lsuInfo.column << 2)(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.column, - ElementDataType.DataTypeWidth16 -> (lsuInfo.column << 1), - ElementDataType.DataTypeWidth32 -> (lsuInfo.column << 2) - )) - val loadNBeatCount = (loadNBytes + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) - val loadNHasTail = MuxLookup(loadDataType, false.B)(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0).orR, - ElementDataType.DataTypeWidth16 -> lsuInfo.column(4, 0).orR, - ElementDataType.DataTypeWidth32 -> lsuInfo.column(3, 0).orR - )) - val loadNTailByteMask = MuxLookup(loadDataType, 0.U(log2Ceil(outsideDataWidthByte + 1).W))(Seq( - ElementDataType.DataTypeWidth8 -> lsuInfo.column(5, 0), - ElementDataType.DataTypeWidth16 -> Cat(lsuInfo.column(4, 0), 0.U(1.W)), - ElementDataType.DataTypeWidth32 -> Cat(lsuInfo.column(3, 0), 0.U(2.W)) - )) - - when(issueLoad) { - val regIdx = lsuInfo.ms(1, 0) - val loadIdx = loadAllocIdx - assert(lsuInfo.stride(5, 0) === 0.U, "TaskController load stride must be 64B aligned") - - when(needA) { - io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr := lsuInfo.baseAddr - io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_Stride_M := lsuInfo.stride - io.AML_MicroTask_Config.ApplicationTensor_A.dataType := loadDataType - io.AML_MicroTask_Config.ApplicationTensor_A.HasTail := loadHasTail - io.AML_MicroTask_Config.ApplicationTensor_A.TailByteMask := loadTailByteMask - io.AML_MicroTask_Config.ApplicationTensor_A.K_Beat_Count := loadKBeatCount - io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B - io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B - io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B - io.AML_MicroTask_Config.MatrixRegTensor_M := lsuInfo.row - io.AML_MicroTask_Config.MatrixRegTensor_K := loadKBeatCount - io.AML_MicroTask_Config.MatrixRegId := regIdx - if (EnableDifftest) { - io.AML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.AML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get + when(issueFire) { + switch(issueSlot.opKind) { + is(TaskCtrlOpKind.LoadA) { + val regIdx = issueLsu.ms(1, 0) + val matrixDim = Mux(issueLsu.transpose, issueLsu.column, issueLsu.row) + val reduceDim = Mux(issueLsu.transpose, issueLsu.row, issueLsu.column) + val kVal = loadBeatCount(reduceDim, issueLsu.widths) + + io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr := issueLsu.baseAddr + io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_Stride_M := issueLsu.stride + io.AML_MicroTask_Config.ApplicationTensor_A.dataType := loadDataType(issueLsu.widths) + io.AML_MicroTask_Config.ApplicationTensor_A.HasTail := loadHasTail(reduceDim, issueLsu.widths) + io.AML_MicroTask_Config.ApplicationTensor_A.TailByteMask := loadTailByteMask(reduceDim, issueLsu.widths) + io.AML_MicroTask_Config.ApplicationTensor_A.K_Beat_Count := kVal + io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B + io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B + io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B + io.AML_MicroTask_Config.MatrixRegTensor_M := matrixDim + io.AML_MicroTask_Config.MatrixRegTensor_K := kVal + io.AML_MicroTask_Config.MatrixRegId := regIdx + io.AML_MicroTask_Config.Conherent := true.B + io.AML_MicroTask_Config.Is_Transpose := issueLsu.transpose + io.AML_MicroTask_Config.MicroTaskValid := true.B + if (EnableDifftest) { + io.AML_MicroTask_Config.pc.get := issueCtrl.pc.get + io.AML_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + loadAllocateEvent.eventType := 0.U + loadAllocateEvent.regId := regIdx + loadAllocateEvent.fifoIdx := issueSlot.fifoIdx + loadAllocateEvent.needMask := "b001".U + loadAllocateEvent.row := issueLsu.row + loadAllocateEvent.column := issueLsu.column + loadAllocateEvent.transpose := issueLsu.transpose + loadAllocateEvent.isAcc := false.B + loadAllocateEvent.slotId := issueSlotIdx + loadAllocateEvent.seqId := issueSlot.seqId + loadAllocateEventEn := true.B + + loadIssueEvent.eventType := 1.U + loadIssueEvent.regId := regIdx + loadIssueEvent.fifoIdx := issueSlot.fifoIdx + loadIssueEvent.needMask := "b001".U + loadIssueEvent.row := issueLsu.row + loadIssueEvent.column := issueLsu.column + loadIssueEvent.transpose := issueLsu.transpose + loadIssueEvent.isAcc := false.B + loadIssueEvent.slotId := issueSlotIdx + loadIssueEvent.seqId := issueSlot.seqId + loadIssueEventEn := true.B } - io.AML_MicroTask_Config.Conherent := true.B + is(TaskCtrlOpKind.LoadB) { + val regIdx = issueLsu.ms(1, 0) + val matrixDim = Mux(issueLsu.transpose, issueLsu.row, issueLsu.column) + val reduceDim = Mux(issueLsu.transpose, issueLsu.column, issueLsu.row) + val kVal = loadBeatCount(reduceDim, issueLsu.widths) + + io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr := issueLsu.baseAddr + io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_Stride_N := issueLsu.stride + io.BML_MicroTask_Config.ApplicationTensor_B.BlockTensor_B_BaseVaddr := issueLsu.baseAddr + io.BML_MicroTask_Config.ApplicationTensor_B.dataType := loadDataType(issueLsu.widths) + io.BML_MicroTask_Config.ApplicationTensor_B.HasTail := loadHasTail(reduceDim, issueLsu.widths) + io.BML_MicroTask_Config.ApplicationTensor_B.TailByteMask := loadTailByteMask(reduceDim, issueLsu.widths) + io.BML_MicroTask_Config.ApplicationTensor_B.K_Beat_Count := kVal + io.BML_MicroTask_Config.MatrixRegTensor_N := matrixDim + io.BML_MicroTask_Config.MatrixRegTensor_K := kVal + io.BML_MicroTask_Config.MatrixRegId := regIdx + io.BML_MicroTask_Config.Conherent := true.B + io.BML_MicroTask_Config.Is_Transpose := issueLsu.transpose + io.BML_MicroTask_Config.MicroTaskValid := true.B + if (EnableDifftest) { + io.BML_MicroTask_Config.pc.get := issueCtrl.pc.get + io.BML_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + loadAllocateEvent.eventType := 0.U + loadAllocateEvent.regId := regIdx + loadAllocateEvent.fifoIdx := issueSlot.fifoIdx + loadAllocateEvent.needMask := "b010".U + loadAllocateEvent.row := issueLsu.row + loadAllocateEvent.column := issueLsu.column + loadAllocateEvent.transpose := issueLsu.transpose + loadAllocateEvent.isAcc := false.B + loadAllocateEvent.slotId := issueSlotIdx + loadAllocateEvent.seqId := issueSlot.seqId + loadAllocateEventEn := true.B + + loadIssueEvent.eventType := 1.U + loadIssueEvent.regId := regIdx + loadIssueEvent.fifoIdx := issueSlot.fifoIdx + loadIssueEvent.needMask := "b010".U + loadIssueEvent.row := issueLsu.row + loadIssueEvent.column := issueLsu.column + loadIssueEvent.transpose := issueLsu.transpose + loadIssueEvent.isAcc := false.B + loadIssueEvent.slotId := issueSlotIdx + loadIssueEvent.seqId := issueSlot.seqId + loadIssueEventEn := true.B + } - io.AML_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueAML<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d kBeat=%d tail=%d tailMask=%d base=%x\n", - io.DebugTimeStampe, regIdx, loadIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.AML_MicroTask_Config.Conherent, - loadKBeatCount, loadHasTail, loadTailByteMask, lsuInfo.baseAddr) + is(TaskCtrlOpKind.LoadC) { + val regIdx = issueLsu.ms(1, 0) + val matrixDimM = Mux(issueLsu.transpose, issueLsu.column, issueLsu.row) + val matrixDimN = Mux(issueLsu.transpose, issueLsu.row, issueLsu.column) + val nVal = loadBeatCount(matrixDimN, issueLsu.widths) + + io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_BaseVaddr := issueLsu.baseAddr + io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_Stride_M := issueLsu.stride + io.CML_MicroTask_Config.ApplicationTensor_C.BlockTensor_C_BaseVaddr := issueLsu.baseAddr + io.CML_MicroTask_Config.ApplicationTensor_C.dataType := loadDataType(issueLsu.widths) + io.CML_MicroTask_Config.ApplicationTensor_C.HasTail := loadHasTail(matrixDimN, issueLsu.widths) + io.CML_MicroTask_Config.ApplicationTensor_C.TailByteMask := loadTailByteMask(matrixDimN, issueLsu.widths) + io.CML_MicroTask_Config.ApplicationTensor_C.N_Beat_Count := nVal + io.CML_MicroTask_Config.Conherent := true.B + io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B + io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B + io.CML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B + io.CML_MicroTask_Config.MatrixRegTensor_M := matrixDimM + io.CML_MicroTask_Config.MatrixRegTensor_N := matrixDimN + io.CML_MicroTask_Config.MatrixRegId := regIdx + io.CML_MicroTask_Config.Is_Transpose := issueLsu.transpose + io.CML_MicroTask_Config.LoadMicroTaskValid := true.B + io.CML_MicroTask_Config.StoreMicroTaskValid := false.B + if (EnableDifftest) { + io.CML_MicroTask_Config.pc.get := issueCtrl.pc.get + io.CML_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + loadAllocateEvent.eventType := 0.U + loadAllocateEvent.regId := regIdx + loadAllocateEvent.fifoIdx := issueSlot.fifoIdx + loadAllocateEvent.needMask := "b100".U + loadAllocateEvent.row := issueLsu.row + loadAllocateEvent.column := issueLsu.column + loadAllocateEvent.transpose := issueLsu.transpose + loadAllocateEvent.isAcc := true.B + loadAllocateEvent.slotId := issueSlotIdx + loadAllocateEvent.seqId := issueSlot.seqId + loadAllocateEventEn := true.B + + loadIssueEvent.eventType := 1.U + loadIssueEvent.regId := regIdx + loadIssueEvent.fifoIdx := issueSlot.fifoIdx + loadIssueEvent.needMask := "b100".U + loadIssueEvent.row := issueLsu.row + loadIssueEvent.column := issueLsu.column + loadIssueEvent.transpose := issueLsu.transpose + loadIssueEvent.isAcc := true.B + loadIssueEvent.slotId := issueSlotIdx + loadIssueEvent.seqId := issueSlot.seqId + loadIssueEventEn := true.B } - - pendingLoadA := true.B - pendingLoadAReg := regIdx - pendingLoadAFifoIdx := loadIdx - } - when(needB) { - io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr := lsuInfo.baseAddr - io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_Stride_N := lsuInfo.stride - io.BML_MicroTask_Config.ApplicationTensor_B.BlockTensor_B_BaseVaddr := lsuInfo.baseAddr - io.BML_MicroTask_Config.ApplicationTensor_B.dataType := loadDataType - io.BML_MicroTask_Config.ApplicationTensor_B.HasTail := loadHasTail_for_B - io.BML_MicroTask_Config.ApplicationTensor_B.TailByteMask := loadTailByteMask_for_B - io.BML_MicroTask_Config.ApplicationTensor_B.K_Beat_Count := loadKBeatCount_for_B - io.BML_MicroTask_Config.MatrixRegTensor_N := lsuInfo.column - io.BML_MicroTask_Config.MatrixRegTensor_K := loadKBeatCount_for_B - io.BML_MicroTask_Config.MatrixRegId := regIdx - if (EnableDifftest) { - io.BML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.BML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get + + is(TaskCtrlOpKind.ZeroAcc) { + val regIdx = issueArith.md(1, 0) + + io.CML_MicroTask_Config.ApplicationTensor_C.dataType := ElementDataType.DataTypeWidth32 + io.CML_MicroTask_Config.MatrixRegTensor_M := cuteParams.Tensor_MN.U + io.CML_MicroTask_Config.MatrixRegTensor_N := cuteParams.Tensor_MN.U + io.CML_MicroTask_Config.MatrixRegId := regIdx + io.CML_MicroTask_Config.LoadMicroTaskValid := true.B + io.CML_MicroTask_Config.StoreMicroTaskValid := false.B + io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := true.B + io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B + io.CML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B + io.CML_MicroTask_Config.Conherent := true.B + io.CML_MicroTask_Config.Is_Transpose := false.B + if (EnableDifftest) { + io.CML_MicroTask_Config.pc.get := issueCtrl.pc.get + io.CML_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + loadAllocateEvent.eventType := 0.U + loadAllocateEvent.regId := regIdx + loadAllocateEvent.fifoIdx := issueSlot.fifoIdx + loadAllocateEvent.needMask := "b100".U + loadAllocateEvent.row := 0.U + loadAllocateEvent.column := 0.U + loadAllocateEvent.transpose := false.B + loadAllocateEvent.isAcc := true.B + loadAllocateEvent.slotId := issueSlotIdx + loadAllocateEvent.seqId := issueSlot.seqId + loadAllocateEventEn := true.B + + loadIssueEvent.eventType := 1.U + loadIssueEvent.regId := regIdx + loadIssueEvent.fifoIdx := issueSlot.fifoIdx + loadIssueEvent.needMask := "b100".U + loadIssueEvent.row := 0.U + loadIssueEvent.column := 0.U + loadIssueEvent.transpose := false.B + loadIssueEvent.isAcc := true.B + loadIssueEvent.slotId := issueSlotIdx + loadIssueEvent.seqId := issueSlot.seqId + loadIssueEventEn := true.B } - io.BML_MicroTask_Config.Conherent := true.B - io.BML_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueBML<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d kBeat=%d tail=%d tailMask=%d base=%x\n", - io.DebugTimeStampe, regIdx, loadIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.BML_MicroTask_Config.Conherent, - loadKBeatCount_for_B, loadHasTail_for_B, loadTailByteMask_for_B, lsuInfo.baseAddr) + + is(TaskCtrlOpKind.ZeroTr) { + val regIdx = issueArith.md(1, 0) + + io.AML_MicroTask_Config.ApplicationTensor_A.dataType := ElementDataType.DataTypeWidth8 + io.AML_MicroTask_Config.MatrixRegTensor_M := cuteParams.Tensor_MN.U + io.AML_MicroTask_Config.MatrixRegTensor_K := cuteParams.Tensor_K.U / ReduceWidthByte.U + io.AML_MicroTask_Config.MatrixRegId := regIdx + io.AML_MicroTask_Config.MicroTaskValid := true.B + io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := true.B + io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B + io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B + io.AML_MicroTask_Config.Conherent := true.B + io.AML_MicroTask_Config.Is_Transpose := false.B + if (EnableDifftest) { + io.AML_MicroTask_Config.pc.get := issueCtrl.pc.get + io.AML_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + loadAllocateEvent.eventType := 0.U + loadAllocateEvent.regId := regIdx + loadAllocateEvent.fifoIdx := issueSlot.fifoIdx + loadAllocateEvent.needMask := "b001".U + loadAllocateEvent.row := 0.U + loadAllocateEvent.column := 0.U + loadAllocateEvent.transpose := false.B + loadAllocateEvent.isAcc := false.B + loadAllocateEvent.slotId := issueSlotIdx + loadAllocateEvent.seqId := issueSlot.seqId + loadAllocateEventEn := true.B + + loadIssueEvent.eventType := 1.U + loadIssueEvent.regId := regIdx + loadIssueEvent.fifoIdx := issueSlot.fifoIdx + loadIssueEvent.needMask := "b001".U + loadIssueEvent.row := 0.U + loadIssueEvent.column := 0.U + loadIssueEvent.transpose := false.B + loadIssueEvent.isAcc := false.B + loadIssueEvent.slotId := issueSlotIdx + loadIssueEvent.seqId := issueSlot.seqId + loadIssueEventEn := true.B } - pendingLoadB := true.B - pendingLoadBReg := regIdx - pendingLoadBFifoIdx := loadIdx - } - when(needC) { - io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_BaseVaddr := lsuInfo.baseAddr - io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_Stride_M := lsuInfo.stride - io.CML_MicroTask_Config.ApplicationTensor_C.BlockTensor_C_BaseVaddr := lsuInfo.baseAddr - io.CML_MicroTask_Config.ApplicationTensor_C.dataType := loadDataType - io.CML_MicroTask_Config.ApplicationTensor_C.HasTail := loadNHasTail - io.CML_MicroTask_Config.ApplicationTensor_C.TailByteMask := loadNTailByteMask - io.CML_MicroTask_Config.ApplicationTensor_C.N_Beat_Count := loadNBeatCount - io.CML_MicroTask_Config.Conherent := true.B - io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B - io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B - io.CML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B - io.CML_MicroTask_Config.MatrixRegTensor_M := lsuInfo.row - io.CML_MicroTask_Config.MatrixRegTensor_N := lsuInfo.column - io.CML_MicroTask_Config.MatrixRegId := regIdx - if (EnableDifftest) { - io.CML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.CML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get + + is(TaskCtrlOpKind.Compute) { + val aReg = issueMma.ms1(1, 0) + val bReg = issueMma.ms2(1, 0) + val cReg = issueMma.md(1, 0) + val mVal = issueMma.mtilem + val nVal = issueMma.mtilen + val kVal = computeKFromMsew(issueMma.mtilek, issueMma.types1) + val mmaComputeType = decodeMmaComputeType(issueMma) + + io.ADC_MicroTask_Config.MicroTaskValid := true.B + io.BDC_MicroTask_Config.MicroTaskValid := true.B + io.CDC_MicroTask_Config.MicroTaskValid := true.B + + io.ADC_MicroTask_Config.dataType := ElementDataType.DataTypeWidth8 + io.ADC_MicroTask_Config.MatrixRegTensor_M := mVal + io.ADC_MicroTask_Config.MatrixRegTensor_N := nVal + io.ADC_MicroTask_Config.MatrixRegTensor_K := kVal / ReduceWidthByte.U + io.ADC_MicroTask_Config.MatrixRegId := aReg + io.ADC_MicroTask_Config.Is_Transpose := false.B + + io.BDC_MicroTask_Config.dataType := ElementDataType.DataTypeWidth8 + io.BDC_MicroTask_Config.MatrixRegTensor_M := mVal + io.BDC_MicroTask_Config.MatrixRegTensor_N := nVal + io.BDC_MicroTask_Config.MatrixRegTensor_K := kVal / ReduceWidthByte.U + io.BDC_MicroTask_Config.MatrixRegId := bReg + io.BDC_MicroTask_Config.Is_Transpose := false.B + + io.CDC_MicroTask_Config.ApplicationTensor_C.dataType := ElementDataType.DataTypeWidth32 + io.CDC_MicroTask_Config.MatrixRegTensor_M := mVal + io.CDC_MicroTask_Config.MatrixRegTensor_N := nVal + io.CDC_MicroTask_Config.MatrixRegTensor_K := kVal / ReduceWidthByte.U + io.CDC_MicroTask_Config.MatrixRegId := cReg + io.CDC_MicroTask_Config.Is_Transpose := false.B + io.CDC_MicroTask_Config.Is_AfterOps_Tile := false.B + if (EnableDifftest) { + io.CDC_MicroTask_Config.pc.get := issueCtrl.pc.get + io.CDC_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + io.MTE_MicroTask_Config.MicroTaskValid := true.B + io.MTE_MicroTask_Config.computeType := mmaComputeType + + computeIssueEvent.eventType := 0.U + computeIssueEvent.aReg := aReg + computeIssueEvent.bReg := bReg + computeIssueEvent.cReg := cReg + computeIssueEvent.fifoIdx := issueSlot.fifoIdx + computeIssueEvent.mtilem := mVal + computeIssueEvent.mtilen := nVal + computeIssueEvent.mtilek := kVal + computeIssueEvent.isMma := true.B + computeIssueEvent.isFp := issueMma.isfp + computeIssueEvent.slotId := issueSlotIdx + computeIssueEvent.seqId := issueSlot.seqId + computeIssueEventEn := true.B } - io.CML_MicroTask_Config.IsLoadMicroTask := true.B - io.CML_MicroTask_Config.IsStoreMicroTask := false.B - io.CML_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueCMLLoad<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d transpose=%d nBeat=%d tail=%d tailMask=%d base=%x\n", - io.DebugTimeStampe, regIdx, loadIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.CML_MicroTask_Config.Conherent, - lsuInfo.transpose, loadNBeatCount, loadNHasTail, loadNTailByteMask, lsuInfo.baseAddr) + + is(TaskCtrlOpKind.Store) { + val regIdx = issueLsu.ms(1, 0) + + io.CML_MicroTask_Config.ApplicationTensor_D.ApplicationTensor_D_BaseVaddr := issueLsu.baseAddr + io.CML_MicroTask_Config.ApplicationTensor_D.ApplicationTensor_D_Stride_M := issueLsu.stride + io.CML_MicroTask_Config.ApplicationTensor_D.BlockTensor_D_BaseVaddr := issueLsu.baseAddr + io.CML_MicroTask_Config.ApplicationTensor_D.dataType := loadDataType(issueLsu.widths) + io.CML_MicroTask_Config.Conherent := true.B + io.CML_MicroTask_Config.Is_Transpose := issueLsu.transpose + io.CML_MicroTask_Config.MatrixRegTensor_M := issueLsu.row + io.CML_MicroTask_Config.MatrixRegTensor_N := issueLsu.column + io.CML_MicroTask_Config.MatrixRegId := regIdx + io.CML_MicroTask_Config.LoadMicroTaskValid := false.B + io.CML_MicroTask_Config.StoreMicroTaskValid := true.B + if (EnableDifftest) { + io.CML_MicroTask_Config.pc.get := issueCtrl.pc.get + io.CML_MicroTask_Config.coreid.get := issueCtrl.coreid.get + } + + storeIssueEvent.eventType := 0.U + storeIssueEvent.regId := regIdx + storeIssueEvent.fifoIdx := issueSlot.fifoIdx + storeIssueEvent.row := issueLsu.row + storeIssueEvent.column := issueLsu.column + storeIssueEvent.transpose := issueLsu.transpose + storeIssueEvent.isAcc := issueLsu.isacc + storeIssueEvent.slotId := issueSlotIdx + storeIssueEvent.seqId := issueSlot.seqId + storeIssueEventEn := true.B } - io.CML_MicroTask_Config.Is_Transpose := lsuInfo.transpose - - pendingLoadC := true.B - pendingLoadCReg := regIdx - pendingLoadCFifoIdx := loadIdx - } - when(needA || needB || needC) { - scoreboard.io.update.load_allocate := true.B - scoreboard.io.update.load_alloc_fifo_idx := loadIdx - scoreboard.io.update.load_alloc_a_reg := Mux(needA, regIdx, 0.U) - scoreboard.io.update.load_alloc_b_reg := Mux(needB, regIdx, 0.U) - scoreboard.io.update.load_alloc_c_reg := Mux(needC, regIdx, 0.U) - scoreboard.io.update.load_alloc_has_a := needA - scoreboard.io.update.load_alloc_has_b := needB - scoreboard.io.update.load_alloc_has_c := needC - loadAllocIdx := loadAllocIdx + 1.U - pendingLoadRow := lsuInfo.row - pendingLoadColumn := lsuInfo.column - pendingLoadTranspose := lsuInfo.transpose - - loadAllocateEvent.eventType := 0.U - loadAllocateEvent.regId := regIdx - loadAllocateEvent.fifoIdx := loadIdx - loadAllocateEvent.needMask := Cat(needC.asUInt, needB.asUInt, needA.asUInt) - loadAllocateEvent.row := lsuInfo.row - loadAllocateEvent.column := lsuInfo.column - loadAllocateEvent.transpose := lsuInfo.transpose - loadAllocateEvent.isAcc := lsuInfo.isacc - loadAllocateEventEn := true.B + is(TaskCtrlOpKind.Release) { + val issueRelease = issueCtrl.data.asTypeOf(new AmuReleaseIO) + io.ygjkctrl.mrelease.valid := true.B + io.ygjkctrl.mrelease.bits.tokenRd(issueRelease.tokenRd) := true.B + + releaseIssueEvent.eventType := 0.U + releaseIssueEvent.token := issueRelease.tokenRd + releaseIssueEvent.slotId := issueSlotIdx + releaseIssueEvent.seqId := issueSlot.seqId + releaseIssueEventEn := true.B + } + + is(TaskCtrlOpKind.NopLike) { + // NopLike no-op: issue/complete in scheduler only. + } } + } - loadIssueEvent.eventType := 1.U - loadIssueEvent.regId := regIdx - loadIssueEvent.fifoIdx := loadIdx - loadIssueEvent.needMask := Cat(needC.asUInt, needB.asUInt, needA.asUInt) - loadIssueEvent.row := lsuInfo.row - loadIssueEvent.column := lsuInfo.column - loadIssueEvent.transpose := lsuInfo.transpose - loadIssueEvent.isAcc := lsuInfo.isacc - loadIssueEventEn := true.B + // ===================== State update: done -> retire -> enqueue -> issue ===================== + when(amlDone) { + val owner = fuAML.ownerSlot + assert(slots(owner).valid, "TaskController: AML owner slot invalid on done") + slots(owner).completed := true.B + fuAML.busy := false.B } - when(issueZeroAcc) { - val regIdx = arithInfo.md(1, 0) - val loadIdx = loadAllocIdx - - io.CML_MicroTask_Config.ApplicationTensor_C.dataType := ElementDataType.DataTypeWidth32 - io.CML_MicroTask_Config.MatrixRegTensor_M := cuteParams.Tensor_MN.U - io.CML_MicroTask_Config.MatrixRegTensor_N := cuteParams.Tensor_MN.U - io.CML_MicroTask_Config.MatrixRegId := regIdx - io.CML_MicroTask_Config.IsLoadMicroTask := true.B - io.CML_MicroTask_Config.IsStoreMicroTask := false.B - io.CML_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueZeroAcc<%d>] reg=%d fifo=%d M=%d N=%d\n", - io.DebugTimeStampe, regIdx, loadIdx, io.CML_MicroTask_Config.MatrixRegTensor_M, io.CML_MicroTask_Config.MatrixRegTensor_N) - } - io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := true.B - io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B - io.CML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B - io.CML_MicroTask_Config.Conherent := true.B - io.CML_MicroTask_Config.Is_Transpose := false.B - if (EnableDifftest) { - io.CML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.CML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get - } + when(bmlDone) { + val owner = fuBML.ownerSlot + assert(slots(owner).valid, "TaskController: BML owner slot invalid on done") + slots(owner).completed := true.B + fuBML.busy := false.B + } - scoreboard.io.update.load_allocate := true.B - scoreboard.io.update.load_alloc_fifo_idx := loadIdx - scoreboard.io.update.load_alloc_a_reg := 0.U - scoreboard.io.update.load_alloc_b_reg := 0.U - scoreboard.io.update.load_alloc_c_reg := regIdx - scoreboard.io.update.load_alloc_has_a := false.B - scoreboard.io.update.load_alloc_has_b := false.B - scoreboard.io.update.load_alloc_has_c := true.B - loadAllocIdx := loadAllocIdx + 1.U - - pendingLoadC := true.B - pendingLoadCReg := regIdx - pendingLoadCFifoIdx := loadIdx - pendingLoadRow := 0.U - pendingLoadColumn := 0.U - pendingLoadTranspose := false.B - - loadAllocateEvent.eventType := 0.U - loadAllocateEvent.regId := regIdx - loadAllocateEvent.fifoIdx := loadIdx - loadAllocateEvent.needMask := "b100".U - loadAllocateEvent.row := 0.U - loadAllocateEvent.column := 0.U - loadAllocateEvent.transpose := false.B - loadAllocateEvent.isAcc := true.B - loadAllocateEventEn := true.B - - loadIssueEvent.eventType := 1.U - loadIssueEvent.regId := regIdx - loadIssueEvent.fifoIdx := loadIdx - loadIssueEvent.needMask := "b100".U - loadIssueEvent.row := 0.U - loadIssueEvent.column := 0.U - loadIssueEvent.transpose := false.B - loadIssueEvent.isAcc := true.B - loadIssueEventEn := true.B + when(cmlLoadDone) { + val owner = fuCMLLoad.ownerSlot + assert(slots(owner).valid, "TaskController: CML-load owner slot invalid on done") + slots(owner).completed := true.B + fuCMLLoad.busy := false.B } - when(issueZeroTr) { - val regIdx = arithInfo.md(1, 0) - val loadIdx = loadAllocIdx - - io.AML_MicroTask_Config.ApplicationTensor_A.dataType := ElementDataType.DataTypeWidth8 - io.AML_MicroTask_Config.MatrixRegTensor_M := cuteParams.Tensor_MN.U - io.AML_MicroTask_Config.MatrixRegTensor_K := cuteParams.Tensor_K.U / ReduceWidthByte.U - io.AML_MicroTask_Config.MatrixRegId := regIdx - io.AML_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueZeroTr<%d>] reg=%d fifo=%d M=%d K=%d\n", - io.DebugTimeStampe, regIdx, loadIdx, io.AML_MicroTask_Config.MatrixRegTensor_M, io.AML_MicroTask_Config.MatrixRegTensor_K) - } - io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := true.B - io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B - io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B - io.AML_MicroTask_Config.Conherent := true.B - if (EnableDifftest) { - io.AML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.AML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get + when(cmlStoreDone) { + val owner = fuCMLStore.ownerSlot + assert(slots(owner).valid, "TaskController: CML-store owner slot invalid on done") + slots(owner).completed := true.B + fuCMLStore.busy := false.B + } + + when(adcDone) { + val owner = fuCompute.ownerSlot + assert(slots(owner).valid, "TaskController: ADC owner slot invalid on done") + assert(!slots(owner).readADone, "TaskController: duplicated ADC done") + slots(owner).readADone := true.B + + val slot = slots(owner) + val mma = decodeMma(slot.entry.ctrl) + computeReadAFinishEvent.eventType := 1.U + computeReadAFinishEvent.aReg := mma.ms1(1, 0) + computeReadAFinishEvent.bReg := mma.ms2(1, 0) + computeReadAFinishEvent.cReg := mma.md(1, 0) + computeReadAFinishEvent.fifoIdx := slot.fifoIdx + computeReadAFinishEvent.mtilem := mma.mtilem + computeReadAFinishEvent.mtilen := mma.mtilen + computeReadAFinishEvent.mtilek := computeKFromMsew(mma.mtilek, mma.types1) + computeReadAFinishEvent.isMma := true.B + computeReadAFinishEvent.isFp := mma.isfp + computeReadAFinishEvent.slotId := owner + computeReadAFinishEvent.seqId := slot.seqId + computeReadAFinishEventEn := true.B + } + + when(bdcDone) { + val owner = fuCompute.ownerSlot + assert(slots(owner).valid, "TaskController: BDC owner slot invalid on done") + assert(!slots(owner).readBDone, "TaskController: duplicated BDC done") + slots(owner).readBDone := true.B + + val slot = slots(owner) + val mma = decodeMma(slot.entry.ctrl) + computeReadBFinishEvent.eventType := 2.U + computeReadBFinishEvent.aReg := mma.ms1(1, 0) + computeReadBFinishEvent.bReg := mma.ms2(1, 0) + computeReadBFinishEvent.cReg := mma.md(1, 0) + computeReadBFinishEvent.fifoIdx := slot.fifoIdx + computeReadBFinishEvent.mtilem := mma.mtilem + computeReadBFinishEvent.mtilen := mma.mtilen + computeReadBFinishEvent.mtilek := computeKFromMsew(mma.mtilek, mma.types1) + computeReadBFinishEvent.isMma := true.B + computeReadBFinishEvent.isFp := mma.isfp + computeReadBFinishEvent.slotId := owner + computeReadBFinishEvent.seqId := slot.seqId + computeReadBFinishEventEn := true.B + } + + when(cdcDone) { + val owner = fuCompute.ownerSlot + assert(slots(owner).valid, "TaskController: CDC owner slot invalid on done") + slots(owner).completed := true.B + fuCompute.busy := false.B + + val slot = slots(owner) + val mma = decodeMma(slot.entry.ctrl) + computeWriteCFinishEvent.eventType := 3.U + computeWriteCFinishEvent.aReg := mma.ms1(1, 0) + computeWriteCFinishEvent.bReg := mma.ms2(1, 0) + computeWriteCFinishEvent.cReg := mma.md(1, 0) + computeWriteCFinishEvent.fifoIdx := slot.fifoIdx + computeWriteCFinishEvent.mtilem := mma.mtilem + computeWriteCFinishEvent.mtilen := mma.mtilen + computeWriteCFinishEvent.mtilek := computeKFromMsew(mma.mtilek, mma.types1) + computeWriteCFinishEvent.isMma := true.B + computeWriteCFinishEvent.isFp := mma.isfp + computeWriteCFinishEvent.slotId := owner + computeWriteCFinishEvent.seqId := slot.seqId + computeWriteCFinishEventEn := true.B + } + + when(cdcDone) { + val owner = fuCompute.ownerSlot + // CDC can legally arrive in the same cycle as ADC/BDC. + assert(slots(owner).readADone || adcDone, "TaskController: CDC done before ADC done") + assert(slots(owner).readBDone || bdcDone, "TaskController: CDC done before BDC done") + } + + when(retireFire) { + val retireOH = slotOH(winHead) + + for (i <- 0 until WinDepth) { + when(slots(i).valid) { + slots(i).waitCompleteMask := slots(i).waitCompleteMask & (~retireOH)(WinDepth - 1, 0) + slots(i).waitReadAMask := slots(i).waitReadAMask & (~retireOH)(WinDepth - 1, 0) + slots(i).waitReadBMask := slots(i).waitReadBMask & (~retireOH)(WinDepth - 1, 0) + } } - scoreboard.io.update.load_allocate := true.B - scoreboard.io.update.load_alloc_fifo_idx := loadIdx - scoreboard.io.update.load_alloc_a_reg := regIdx - scoreboard.io.update.load_alloc_b_reg := 0.U - scoreboard.io.update.load_alloc_c_reg := 0.U - scoreboard.io.update.load_alloc_has_a := true.B - scoreboard.io.update.load_alloc_has_b := false.B - scoreboard.io.update.load_alloc_has_c := false.B - loadAllocIdx := loadAllocIdx + 1.U - - pendingLoadA := true.B - pendingLoadAReg := regIdx - pendingLoadAFifoIdx := loadIdx - pendingLoadRow := 0.U - pendingLoadColumn := 0.U - pendingLoadTranspose := false.B - - loadAllocateEvent.eventType := 0.U - loadAllocateEvent.regId := regIdx - loadAllocateEvent.fifoIdx := loadIdx - loadAllocateEvent.needMask := "b100".U - loadAllocateEvent.row := 0.U - loadAllocateEvent.column := 0.U - loadAllocateEvent.transpose := false.B - loadAllocateEvent.isAcc := false.B - loadAllocateEventEn := true.B - - loadIssueEvent.eventType := 1.U - loadIssueEvent.regId := regIdx - loadIssueEvent.fifoIdx := loadIdx - loadIssueEvent.needMask := "b100".U - loadIssueEvent.row := 0.U - loadIssueEvent.column := 0.U - loadIssueEvent.transpose := false.B - loadIssueEvent.isAcc := false.B - loadIssueEventEn := true.B + slots(winHead).valid := false.B + slots(winHead).issued := false.B + slots(winHead).completed := false.B + slots(winHead).readADone := false.B + slots(winHead).readBDone := false.B + slots(winHead).waitCompleteMask := 0.U + slots(winHead).waitReadAMask := 0.U + slots(winHead).waitReadBMask := 0.U + slots(winHead).opKind := TaskCtrlOpKind.NopLike + slots(winHead).entry := 0.U.asTypeOf(new DecodedAmuCtrlEntry) + slots(winHead).seqId := 0.U + slots(winHead).fifoIdx := 0.U } - val mmaDataType = RegInit(0.U(3.W)) - io.MTE_MicroTask_Config.dataType := mmaDataType - - when(issueMma) { - val aReg = Mux(isMma, mmaInfo.ms1(1, 0), arithInfo.md(1, 0)) - val bReg = Mux(isMma, mmaInfo.ms2(1, 0), arithInfo.md(1, 0)) - val cReg = Mux(isMma, mmaInfo.md(1, 0), arithInfo.md(1, 0)) - val computeIdx = computeIssueIdx - - val mVal = mmaInfo.mtilem - val nVal = mmaInfo.mtilen - // val kVal = mmaInfo.mtilek - val kVal = MuxLookup(mmaInfo.types1(1, 0), mmaInfo.mtilek)(Seq( - Bundles.MSew.e8 -> mmaInfo.mtilek, - Bundles.MSew.e16 -> mmaInfo.mtilek * 2.U, - Bundles.MSew.e32 -> mmaInfo.mtilek * 4.U, - Bundles.MSew.e4 -> mmaInfo.mtilek / 2.U, - )) + when(enqueueFire) { + val newSlot = WireInit(0.U.asTypeOf(new IssueWindowSlot)) + newSlot.valid := true.B + newSlot.issued := false.B + newSlot.completed := false.B + newSlot.readADone := false.B + newSlot.readBDone := false.B + newSlot.waitCompleteMask := 0.U + newSlot.waitReadAMask := 0.U + newSlot.waitReadBMask := 0.U + newSlot.opKind := deqOpKind + newSlot.entry := deqEntry + newSlot.seqId := seqIdAlloc + newSlot.fifoIdx := 0.U + + when( + deqOpKind === TaskCtrlOpKind.LoadA || + deqOpKind === TaskCtrlOpKind.LoadB || + deqOpKind === TaskCtrlOpKind.LoadC || + deqOpKind === TaskCtrlOpKind.ZeroAcc || + deqOpKind === TaskCtrlOpKind.ZeroTr + ) { + newSlot.fifoIdx := loadFifoIdxAlloc + }.elsewhen(deqOpKind === TaskCtrlOpKind.Compute) { + newSlot.fifoIdx := computeIssueIdx + }.elsewhen(deqOpKind === TaskCtrlOpKind.Store) { + newSlot.fifoIdx := storeIssueIdx + } - io.ADC_MicroTask_Config.MicroTaskValid := true.B - io.BDC_MicroTask_Config.MicroTaskValid := true.B - io.CDC_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueMMA<%d>] aReg=%d bReg=%d cReg=%d fifo=%d m=%d n=%d k=%d isMma=%d isFp=%d\n", - io.DebugTimeStampe, aReg, bReg, cReg, computeIdx, mVal, nVal, kVal, isMma, Mux(isMma, mmaInfo.isfp, false.B)) + val newReadsAB = deqOpKind === TaskCtrlOpKind.Compute || (deqOpKind === TaskCtrlOpKind.Store && deqStoreReadsAB) + val newReadsC = deqOpKind === TaskCtrlOpKind.Compute || (deqOpKind === TaskCtrlOpKind.Store && deqStoreReadsC) + val newWritesAB = + deqOpKind === TaskCtrlOpKind.LoadA || + deqOpKind === TaskCtrlOpKind.LoadB || + deqOpKind === TaskCtrlOpKind.ZeroTr + val newWritesC = + deqOpKind === TaskCtrlOpKind.LoadC || + deqOpKind === TaskCtrlOpKind.ZeroAcc || + deqOpKind === TaskCtrlOpKind.Compute + + val depCompleteBits = Wire(Vec(WinDepth, Bool())) + val depReadABits = Wire(Vec(WinDepth, Bool())) + val depReadBBits = Wire(Vec(WinDepth, Bool())) + + // Build dependencies only against strictly older in-window slots (ring age from head). + // Age base is post-retire view because state update priority is done -> retire -> enqueue -> issue. + val depOlderMask = Wire(Vec(WinDepth, Bool())) + depOlderMask.foreach(_ := false.B) + val depHead = Mux(retireFire, (winHead + 1.U)(SlotIdxWidth - 1, 0), winHead) + val depCount = winCount - retireFire.asUInt + for (age <- 0 until WinDepth) { + val idx = (depHead + age.U)(SlotIdxWidth - 1, 0) + when(age.U < depCount) { + depOlderMask(idx) := true.B + } } - io.ADC_MicroTask_Config.ApplicationTensor_A.dataType := ElementDataType.DataTypeWidth8 - io.ADC_MicroTask_Config.MatrixRegTensor_M := mVal - io.ADC_MicroTask_Config.MatrixRegTensor_N := nVal - io.ADC_MicroTask_Config.MatrixRegTensor_K := kVal / ReduceWidthByte.U // TODO: It's not hardware-friendly, but it's ok for now - io.ADC_MicroTask_Config.MatrixRegId := aReg - io.ADC_MicroTask_Config.Is_Transpose := false.B - - io.BDC_MicroTask_Config.ApplicationTensor_B.dataType := ElementDataType.DataTypeWidth8 - io.BDC_MicroTask_Config.MatrixRegTensor_M := mVal - io.BDC_MicroTask_Config.MatrixRegTensor_N := nVal - io.BDC_MicroTask_Config.MatrixRegTensor_K := kVal / ReduceWidthByte.U // TODO: It's not hardware-friendly, but it's ok for now - io.BDC_MicroTask_Config.MatrixRegId := bReg - io.BDC_MicroTask_Config.Is_Transpose := false.B - - io.CDC_MicroTask_Config.ApplicationTensor_C.dataType := ElementDataType.DataTypeWidth32 - io.CDC_MicroTask_Config.MatrixRegTensor_M := mVal - io.CDC_MicroTask_Config.MatrixRegTensor_N := nVal - io.CDC_MicroTask_Config.MatrixRegTensor_K := kVal / ReduceWidthByte.U // TODO: It's not hardware-friendly, but it's ok for now - io.CDC_MicroTask_Config.MatrixRegId := cReg - io.CDC_MicroTask_Config.Is_Transpose := false.B - io.CDC_MicroTask_Config.Is_AfterOps_Tile := false.B - if (EnableDifftest) { - io.CDC_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.CDC_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get + for (j <- 0 until WinDepth) { + val older = slots(j) + val olderValid = older.valid && depOlderMask(j) + val olderCtrl = older.entry.ctrl + val olderMma = decodeMma(olderCtrl) + val olderLsu = decodeLsu(olderCtrl) + val olderArith = decodeArith(olderCtrl) + + val olderIsStore = older.opKind === TaskCtrlOpKind.Store + val olderStoreReadsAB = olderIsStore && !olderLsu.isacc + val olderStoreReadsC = olderIsStore && olderLsu.isacc + + val olderReadABHit = + isReadAB(older.opKind, olderMma, olderLsu, deqEntry.writeRegs(0), olderStoreReadsAB) + val olderReadCHit = + isReadC(older.opKind, olderMma, olderLsu, deqEntry.writeRegs(0), olderStoreReadsC) + + val olderWriteABHitForNewRead = + (deqEntry.readValid(0) && isWriteAB(older.opKind, olderLsu, olderArith, deqEntry.readRegs(0))) || + (deqEntry.readValid(1) && isWriteAB(older.opKind, olderLsu, olderArith, deqEntry.readRegs(1))) || + (deqEntry.readValid(2) && isWriteAB(older.opKind, olderLsu, olderArith, deqEntry.readRegs(2))) + val olderWriteCHitForNewRead = + (deqEntry.readValid(0) && isWriteC(older.opKind, olderMma, olderLsu, olderArith, deqEntry.readRegs(0))) || + (deqEntry.readValid(1) && isWriteC(older.opKind, olderMma, olderLsu, olderArith, deqEntry.readRegs(1))) || + (deqEntry.readValid(2) && isWriteC(older.opKind, olderMma, olderLsu, olderArith, deqEntry.readRegs(2))) + val olderWriteABHitForNewWrite = deqEntry.writeValid(0) && isWriteAB(older.opKind, olderLsu, olderArith, deqEntry.writeRegs(0)) + val olderWriteCHitForNewWrite = deqEntry.writeValid(0) && isWriteC(older.opKind, olderMma, olderLsu, olderArith, deqEntry.writeRegs(0)) + + val olderWritesAB = + older.opKind === TaskCtrlOpKind.LoadA || + older.opKind === TaskCtrlOpKind.LoadB || + older.opKind === TaskCtrlOpKind.ZeroTr + val olderWritesC = + older.opKind === TaskCtrlOpKind.LoadC || + older.opKind === TaskCtrlOpKind.ZeroAcc || + older.opKind === TaskCtrlOpKind.Compute + + val olderReadsABByComputeA = isReadAOfCompute(older.opKind, olderMma, deqEntry.writeRegs(0)) + val olderReadsABByComputeB = isReadBOfCompute(older.opKind, olderMma, deqEntry.writeRegs(0)) + + val depCompleteJ = Mux( + deqOpKind === TaskCtrlOpKind.Release, + olderValid && (older.opKind === TaskCtrlOpKind.Store), + olderValid && ( + (olderWritesAB && newReadsAB && olderWriteABHitForNewRead) || + (olderWritesC && newReadsC && olderWriteCHitForNewRead) || + (olderWritesAB && newWritesAB && olderWriteABHitForNewWrite) || + (olderWritesC && newWritesC && olderWriteCHitForNewWrite) || + (newWritesAB && olderReadABHit && !olderReadsABByComputeA && !olderReadsABByComputeB) || + (newWritesC && olderReadCHit) + ) + ) + + depCompleteBits(j) := depCompleteJ + depReadABits(j) := (deqOpKind =/= TaskCtrlOpKind.Release) && olderValid && newWritesAB && olderReadsABByComputeA + depReadBBits(j) := (deqOpKind =/= TaskCtrlOpKind.Release) && olderValid && newWritesAB && olderReadsABByComputeB } - when (mmaInfo.isfp) { - when (mmaInfo.types1 === "b001".U && mmaInfo.types2 === "b001".U) { - mmaDataType := DataTypeF16F16F32 - }.elsewhen (mmaInfo.types1 === "b101".U && mmaInfo.types2 === "b101".U) { - mmaDataType := DataTypeBF16BF16F32 - }.elsewhen (mmaInfo.types1 === "b110".U && mmaInfo.types2 === "b110".U) { - mmaDataType := DataTypeTF32TF32F32 - }.otherwise { - mmaDataType := 7.U + newSlot.waitCompleteMask := depCompleteBits.asUInt + newSlot.waitReadAMask := depReadABits.asUInt + newSlot.waitReadBMask := depReadBBits.asUInt + + slots(enqueueSlotIdx) := newSlot + } + + when(issueFire) { + slots(issueSlotIdx).issued := true.B + + switch(issueSlot.opKind) { + is(TaskCtrlOpKind.LoadA) { + fuAML.busy := true.B + fuAML.ownerSlot := issueSlotIdx } - }.otherwise { // !mmaInfo.isfp - when (mmaInfo.types1 === "b000".U && mmaInfo.types2 === "b000".U) { - mmaDataType := DataTypeU8U8I32 - }.elsewhen (mmaInfo.types1 === "b100".U && mmaInfo.types2 === "b000".U) { - mmaDataType := DataTypeI8U8I32 - }.elsewhen (mmaInfo.types1 === "b000".U && mmaInfo.types2 === "b100".U) { - mmaDataType := DataTypeU8I8I32 - }.elsewhen (mmaInfo.types1 === "b100".U && mmaInfo.types2 === "b100".U) { - mmaDataType := DataTypeI8I8I32 - }.otherwise { - mmaDataType := 7.U + is(TaskCtrlOpKind.LoadB) { + fuBML.busy := true.B + fuBML.ownerSlot := issueSlotIdx + } + is(TaskCtrlOpKind.LoadC) { + fuCMLLoad.busy := true.B + fuCMLLoad.ownerSlot := issueSlotIdx + } + is(TaskCtrlOpKind.ZeroAcc) { + fuCMLLoad.busy := true.B + fuCMLLoad.ownerSlot := issueSlotIdx + } + is(TaskCtrlOpKind.ZeroTr) { + fuAML.busy := true.B + fuAML.ownerSlot := issueSlotIdx + } + is(TaskCtrlOpKind.Compute) { + fuCompute.busy := true.B + fuCompute.ownerSlot := issueSlotIdx + } + is(TaskCtrlOpKind.Store) { + fuCMLStore.busy := true.B + fuCMLStore.ownerSlot := issueSlotIdx + } + is(TaskCtrlOpKind.Release) { + slots(issueSlotIdx).completed := true.B + } + is(TaskCtrlOpKind.NopLike) { + slots(issueSlotIdx).completed := true.B } } - - scoreboard.io.update.compute_issue := true.B - scoreboard.io.update.compute_issue_a_reg := aReg - scoreboard.io.update.compute_issue_b_reg := bReg - scoreboard.io.update.compute_issue_c_reg := cReg - scoreboard.io.update.compute_issue_fifo_idx := computeIdx - computeIssueIdx := computeIssueIdx + 1.U - - pendingComputeA := true.B - pendingComputeAReg := aReg - pendingComputeAFifoIdx := computeIdx - pendingComputeB := true.B - pendingComputeBReg := bReg - pendingComputeBFifoIdx := computeIdx - pendingComputeC := true.B - pendingComputeCReg := cReg - pendingComputeCFifoIdx := computeIdx - pendingComputeM := mVal - pendingComputeN := nVal - pendingComputeK := kVal - pendingComputeIsMma := isMma - pendingComputeIsFp := Mux(isMma, mmaInfo.isfp, false.B) - - computeIssueEvent.eventType := 0.U - computeIssueEvent.aReg := aReg - computeIssueEvent.bReg := bReg - computeIssueEvent.cReg := cReg - computeIssueEvent.fifoIdx := computeIdx - computeIssueEvent.mtilem := mVal - computeIssueEvent.mtilen := nVal - computeIssueEvent.mtilek := kVal - computeIssueEvent.isMma := isMma - computeIssueEvent.isFp := Mux(isMma, mmaInfo.isfp, false.B) - computeIssueEventEn := true.B } - when(issueStore) { - val regIdx = lsuInfo.ms(1, 0) - val storeIdx = storeIssueIdx - assert(lsuInfo.stride(5, 0) === 0.U, "TaskController store stride must be 64B aligned") - - io.CML_MicroTask_Config.ApplicationTensor_D.ApplicationTensor_D_BaseVaddr := lsuInfo.baseAddr - io.CML_MicroTask_Config.ApplicationTensor_D.ApplicationTensor_D_Stride_M := lsuInfo.stride - io.CML_MicroTask_Config.ApplicationTensor_D.BlockTensor_D_BaseVaddr := lsuInfo.baseAddr - io.CML_MicroTask_Config.ApplicationTensor_D.dataType := MuxLookup(lsuInfo.widths, ElementDataType.DataTypeWidth32)(Seq( - Bundles.MSew.e8 -> ElementDataType.DataTypeWidth8, - Bundles.MSew.e16 -> ElementDataType.DataTypeWidth16, - Bundles.MSew.e32 -> ElementDataType.DataTypeWidth32, - Bundles.MSew.e4 -> ElementDataType.DataTypeWidth4 - )) - - io.CML_MicroTask_Config.StoreTaskInfo.Is_ZeroStore := false.B - io.CML_MicroTask_Config.Conherent := true.B - io.CML_MicroTask_Config.Is_Transpose := lsuInfo.transpose - io.CML_MicroTask_Config.MatrixRegTensor_M := lsuInfo.row - io.CML_MicroTask_Config.MatrixRegTensor_N := lsuInfo.column - io.CML_MicroTask_Config.MatrixRegId := regIdx - - io.CML_MicroTask_Config.IsLoadMicroTask := false.B - io.CML_MicroTask_Config.IsStoreMicroTask := true.B - - io.CML_MicroTask_Config.MicroTaskValid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueCMLStore<%d>] reg=%d fifo=%d row=%d col=%d stride=%x coher=%d transpose=%d base=%x dataType=%d\n", - io.DebugTimeStampe, regIdx, storeIdx, lsuInfo.row, lsuInfo.column, lsuInfo.stride, io.CML_MicroTask_Config.Conherent, - lsuInfo.transpose, lsuInfo.baseAddr, io.CML_MicroTask_Config.ApplicationTensor_D.dataType) + // Pointer and count updates (retire + enqueue) + when(retireFire && !enqueueFire) { + winHead := (winHead + 1.U)(SlotIdxWidth - 1, 0) + winCount := winCount - 1.U + }.elsewhen(!retireFire && enqueueFire) { + winTail := (winTail + 1.U)(SlotIdxWidth - 1, 0) + winCount := winCount + 1.U + }.elsewhen(retireFire && enqueueFire) { + val nextHead = (winHead + 1.U)(SlotIdxWidth - 1, 0) + when(windowFull) { + // retire slot is immediately reused by enqueue; head/tail advance together by one. + winHead := nextHead + winTail := nextHead + winCount := winCount + }.otherwise { + winHead := nextHead + winTail := (winTail + 1.U)(SlotIdxWidth - 1, 0) + winCount := winCount } - if (EnableDifftest) { - io.CML_MicroTask_Config.pc.get := headEntry.ctrl.pc.get - io.CML_MicroTask_Config.coreid.get := headEntry.ctrl.coreid.get + } + + when(enqueueFire) { + seqIdAlloc := seqIdAlloc + 1.U + when( + deqOpKind === TaskCtrlOpKind.LoadA || + deqOpKind === TaskCtrlOpKind.LoadB || + deqOpKind === TaskCtrlOpKind.LoadC || + deqOpKind === TaskCtrlOpKind.ZeroAcc || + deqOpKind === TaskCtrlOpKind.ZeroTr + ) { + loadFifoIdxAlloc := loadFifoIdxAlloc + 1.U + }.elsewhen(deqOpKind === TaskCtrlOpKind.Compute) { + computeIssueIdx := computeIssueIdx + 1.U + }.elsewhen(deqOpKind === TaskCtrlOpKind.Store) { + storeIssueIdx := storeIssueIdx + 1.U } + } + + // ===================== Finish events (load/store) ===================== + when(amlDone) { + val owner = fuAML.ownerSlot + val slot = slots(owner) + val op = slot.opKind + val lsu = decodeLsu(slot.entry.ctrl) + val arith = decodeArith(slot.entry.ctrl) + val reg = Mux(op === TaskCtrlOpKind.ZeroTr, arith.md(1, 0), lsu.ms(1, 0)) - scoreboard.io.update.store_issue := true.B - scoreboard.io.update.store_issue_c_reg := regIdx - scoreboard.io.update.store_issue_fifo_idx := storeIdx - storeIssueIdx := storeIssueIdx + 1.U - - pendingStore := true.B - pendingStoreReg := regIdx - pendingStoreFifoIdx := storeIdx - pendingStoreRow := lsuInfo.row - pendingStoreColumn := lsuInfo.column - pendingStoreTranspose := lsuInfo.transpose - pendingStoreIsAcc := lsuInfo.isacc - - storeIssueEvent.eventType := 0.U - storeIssueEvent.regId := regIdx - storeIssueEvent.fifoIdx := storeIdx - storeIssueEvent.row := lsuInfo.row - storeIssueEvent.column := lsuInfo.column - storeIssueEvent.transpose := lsuInfo.transpose - storeIssueEvent.isAcc := lsuInfo.isacc - storeIssueEventEn := true.B + loadAFinishEvent.eventType := 2.U + loadAFinishEvent.regId := reg + loadAFinishEvent.fifoIdx := slot.fifoIdx + loadAFinishEvent.needMask := "b001".U + loadAFinishEvent.row := Mux(op === TaskCtrlOpKind.ZeroTr, 0.U, lsu.row) + loadAFinishEvent.column := Mux(op === TaskCtrlOpKind.ZeroTr, 0.U, lsu.column) + loadAFinishEvent.transpose := Mux(op === TaskCtrlOpKind.ZeroTr, false.B, lsu.transpose) + loadAFinishEvent.isAcc := false.B + loadAFinishEvent.slotId := owner + loadAFinishEvent.seqId := slot.seqId + loadAFinishEventEn := true.B } - when(issueRelease) { - io.ygjkctrl.mrelease.valid := true.B - if (YJPTASKDebugEnable) { - printf("[TaskController_IssueRelease<%d>] token=%d\n", io.DebugTimeStampe, releaseInfo.tokenRd) - } - io.ygjkctrl.mrelease.bits.tokenRd(releaseInfo.tokenRd) := true.B - releaseIssueEvent.eventType := 0.U - releaseIssueEvent.token := releaseInfo.tokenRd - releaseIssueEventEn := true.B + when(bmlDone) { + val owner = fuBML.ownerSlot + val slot = slots(owner) + val lsu = decodeLsu(slot.entry.ctrl) + + loadBFinishEvent.eventType := 2.U + loadBFinishEvent.regId := lsu.ms(1, 0) + loadBFinishEvent.fifoIdx := slot.fifoIdx + loadBFinishEvent.needMask := "b010".U + loadBFinishEvent.row := lsu.row + loadBFinishEvent.column := lsu.column + loadBFinishEvent.transpose := lsu.transpose + loadBFinishEvent.isAcc := false.B + loadBFinishEvent.slotId := owner + loadBFinishEvent.seqId := slot.seqId + loadBFinishEventEn := true.B } + when(cmlLoadDone) { + val owner = fuCMLLoad.ownerSlot + val slot = slots(owner) + val op = slot.opKind + val lsu = decodeLsu(slot.entry.ctrl) + val arith = decodeArith(slot.entry.ctrl) + val reg = Mux(op === TaskCtrlOpKind.ZeroAcc, arith.md(1, 0), lsu.ms(1, 0)) + val row = Mux(op === TaskCtrlOpKind.ZeroAcc, 0.U, lsu.row) + val col = Mux(op === TaskCtrlOpKind.ZeroAcc, 0.U, lsu.column) + + loadCFinishEvent.eventType := 2.U + loadCFinishEvent.regId := reg + loadCFinishEvent.fifoIdx := slot.fifoIdx + loadCFinishEvent.needMask := "b100".U + loadCFinishEvent.row := row + loadCFinishEvent.column := col + loadCFinishEvent.transpose := Mux(op === TaskCtrlOpKind.ZeroAcc, false.B, lsu.transpose) + loadCFinishEvent.isAcc := true.B + loadCFinishEvent.slotId := owner + loadCFinishEvent.seqId := slot.seqId + loadCFinishEventEn := true.B + } + + when(cmlStoreDone) { + val owner = fuCMLStore.ownerSlot + val slot = slots(owner) + val lsu = decodeLsu(slot.entry.ctrl) + + storeFinishEvent.eventType := 1.U + storeFinishEvent.regId := lsu.ms(1, 0) + storeFinishEvent.fifoIdx := slot.fifoIdx + storeFinishEvent.row := lsu.row + storeFinishEvent.column := lsu.column + storeFinishEvent.transpose := lsu.transpose + storeFinishEvent.isAcc := lsu.isacc + storeFinishEvent.slotId := owner + storeFinishEvent.seqId := slot.seqId + storeFinishEventEn := true.B + } + + // ===================== Release DiffTest alignment ===================== if (EnableDifftest) { - val difftestAmuFinish = DifftestModule(new DiffAmuFinishEvent, delay = 0, dontCare = true) - difftestAmuFinish.coreid := io.ygjkctrl.amuCtrl.bits.coreid.get - difftestAmuFinish.index := 4.U - difftestAmuFinish.valid := io.ygjkctrl.mrelease.valid - difftestAmuFinish.pc := headEntry.ctrl.pc.get - difftestAmuFinish.bankValid.foreach(_ := false.B) - difftestAmuFinish.bankAddr.foreach(_ := 0.U) - difftestAmuFinish.bankMask.foreach(_ := 0.U) - difftestAmuFinish.data.foreach(_ := 0.U) - difftestAmuFinish.finish := io.ygjkctrl.mrelease.valid + val releaseFinish = DifftestModule(new DiffAmuFinishEvent(CMatrixRegNBanks, DiffAmuFinishWordsPerBank), delay = 0, dontCare = true) + val releaseIssueOwnerSlot = issueSlotIdx + val releaseIssueOwnerEntry = issueSlot.entry + + releaseFinish.coreid := Mux(issueFire && issueSlot.opKind === TaskCtrlOpKind.Release, releaseIssueOwnerEntry.ctrl.coreid.get, 0.U) + releaseFinish.index := 4.U + releaseFinish.valid := issueFire && issueSlot.opKind === TaskCtrlOpKind.Release + releaseFinish.pc := Mux(issueFire && issueSlot.opKind === TaskCtrlOpKind.Release, releaseIssueOwnerEntry.ctrl.pc.get, 0.U) + releaseFinish.bankValid.foreach(_ := false.B) + releaseFinish.bankAddr.foreach(_ := 0.U) + releaseFinish.bankMask.foreach(_ := 0.U) + releaseFinish.data.foreach(_ := 0.U) + releaseFinish.finish := issueFire && issueSlot.opKind === TaskCtrlOpKind.Release + + when(issueFire && issueSlot.opKind === TaskCtrlOpKind.Release) { + dontTouch(releaseIssueOwnerSlot) + } } - // ===================== ChiselDB 日志提交 ===================== + // ===================== ChiselDB log commit ===================== loadEventTable.log(loadAllocateEvent, loadAllocateEventEn, "LoadAllocate", clock, reset) loadEventTable.log(loadIssueEvent, loadIssueEventEn, "LoadIssue", clock, reset) loadEventTable.log(loadAFinishEvent, loadAFinishEventEn, "LoadFinish", clock, reset) @@ -1116,4 +1463,24 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { storeEventTable.log(storeFinishEvent, storeFinishEventEn, "StoreFinish", clock, reset) releaseEventTable.log(releaseIssueEvent, releaseIssueEventEn, "ReleaseIssue", clock, reset) + + val mmaDoneType = decodeMmaComputeType(decodeMma(slots(fuCompute.ownerSlot).entry.ctrl)) + val releaseDone = issueFire && issueSlot.opKind === TaskCtrlOpKind.Release + io.perfProbe.ownedWork := ownedWork + io.perfProbe.retire := retireFire + io.perfProbe.loadADone := loadAFinishEventEn + io.perfProbe.loadBDone := loadBFinishEventEn + io.perfProbe.loadCDone := loadCFinishEventEn + io.perfProbe.storeDone := storeFinishEventEn + io.perfProbe.compDone := computeWriteCFinishEventEn + io.perfProbe.releaseDone := releaseDone + io.perfProbe.mmaNonfpDone := computeWriteCFinishEventEn && !decodeMma(slots(fuCompute.ownerSlot).entry.ctrl).isfp + io.perfProbe.mmaFp16Done := computeWriteCFinishEventEn && (mmaDoneType === MteComputeType.F16F16F32) + io.perfProbe.mmaBf16Done := computeWriteCFinishEventEn && (mmaDoneType === MteComputeType.BF16BF16F32) + io.perfProbe.mmaTf32Done := computeWriteCFinishEventEn && (mmaDoneType === MteComputeType.TF32TF32F32) + io.perfProbe.amlActive := fuAML.busy + io.perfProbe.bmlActive := fuBML.busy + io.perfProbe.cmlLoadActive := fuCMLLoad.busy + io.perfProbe.mteActive := fuCompute.busy + io.perfProbe.cmlStoreActive := fuCMLStore.busy }